Kotlin∇: Type-safe Symbolic Differentiation for Kotlin
Kotlin∇ is a type-safe automatic differentiation framework in Kotlin. It allows users to express differentiable programs with higher-dimensional data structures and operators. We attempt to restrict syntactically valid constructions to those which are algebraically valid and can be checked at compile time. By enforcing these constraints in the type system, Kotlin∇ eliminates certain classes of runtime errors that may occur during the execution of a differentiable program. Thanks to type inference in the language, most types may be safely omitted by the end user. Kotlin∇ strives to be expressive, safe, and notationally similar to mathematics. It is currently pre-release and offers no stability guarantees at this time.
Table of contents
- Introduction
- Supported features
- Usage
- Visualization
- Testing and gradient checking
- How does it work?
- Formal Grammar
- Comparison to other frameworks
- Citation
- Special thanks
Introduction
Inspired by Stalin∇, Autograd, DiffSharp, Myia, Nexus, Tangent, Lantern et al., Kotlin∇ attempts to port recent advancements in automatic differentiation (AD) to the Kotlin language. AD is useful for gradient descent and has a variety of applications in numerical optimization and machine learning. Our implementation adds a number of experimental ideas, including compile-time shape safety, algebraic simplification and numerical stability checking with property-based testing. We aim to provide an algebraically grounded implementation of AD for shape-safe tensor operations. Tensors in Kotlin∇ are represented as multidimensional arrays.
Features
Kotlin∇ currently supports the following features:
- Arithmetical operations on scalars, vectors and matrices
- Shape-safe vector and matrix algebra
- Partial and higher-order differentiation on scalars
- Property-based testing for numerical gradient checking
- Recovery of symbolic derivatives from AD
Additionally, it aims to support:
- PyTorch-style define-by-run semantics
- N-dimensional tensors and higher-order tensor operators
- Fully general AD over control flow, variable reassignment (via delegation), and array programming, possibly using a typed IR such as Myia
All of these features are implemented without access to bytecode or special compiler tricks, just using higher-order functions and lambdas as shown in Lambda the Ultimate Backpropagator, embedded DSLs à la Lightweight Modular Staging, and ordinary generics. Please see below for a more detailed feature comparison.
Usage
Installation
Kotlin∇ is hosted on the GitHub Package Registry. If you have not already done so, first generate a new GitHub Personal Access Token with the read:packages permission.
Gradle
Gradle users should write their GPR credentials to the ~/.gradle/gradle.properties file as follows:
gpr.user=<USERNAME>
gpr.key=<PERSONAL_ACCESS_TOKEN>
Ensure GRADLE_USER_HOME points to ~/.gradle. Then add a repository and dependency to the build.gradle.kts file:
repositories {
    maven("https://maven.pkg.github.com/breandan/kotlingrad") {
        credentials {
            // Reads the values written to ~/.gradle/gradle.properties above
            username = project.findProperty("gpr.user") as String?
            password = project.findProperty("gpr.key") as String?
        }
    }
}
dependencies {
    implementation("edu.umontreal:kotlingrad:<VERSION>")
}
Finally, run gradle dependencies to ensure the requested dependency can be downloaded.
For additional help, please refer to the GPR Gradle configuration instructions.
Maven
Maven users should refer to these instructions.
Notation
Kotlin∇ operators are higher-order functions, which take at most two inputs and return a single output, all of which are functions with the same numerical type, and whose shape is denoted using superscript in the rightmost column below.
| Infix † | Prefix | Postfix ‡ | Operator Type Signature |
|---|---|---|---|
| a(b), a of b | | | (a: ℝ^τ→ℝ^π, b: ℝ^λ→ℝ^τ) → (ℝ^λ→ℝ^π) |
| a + b, a - b | plus(a, b), minus(a, b) | | (a: ℝ^τ→ℝ^π, b: ℝ^λ→ℝ^π) → (ℝ^?→ℝ^π) |
| a * b, a.times(b) | times(a, b) | | (a: ℝ^τ→ℝ^(m×n), b: ℝ^λ→ℝ^(n×p)) → (ℝ^?→ℝ^(m×p)) |
| a / b §, a.div(b) | div(a, b) | | (a: ℝ^τ→ℝ^(m×n), b: ℝ^λ→ℝ^(p×n)) → (ℝ^?→ℝ^(m×p)) |
| | -a, +a | a.unaryMinus(), a.unaryPlus() | (a: ℝ^τ→ℝ^π) → (ℝ^τ→ℝ^π) |
| | sin(a), cos(a), tan(a) | a.sin(), a.cos(), a.tan() | (a: ℝ→ℝ) → (ℝ→ℝ) |
| | ln(a), log(a) | a.ln(), a.log() | (a: ℝ^τ→ℝ^(m×m)) → (ℝ^τ→ℝ^(m×m)) |
| a.log(b) | log(a, b) | | (a: ℝ^τ→ℝ^(m×m), b: ℝ^λ→ℝ^(m×m)) → (ℝ^?→ℝ) |
| a.pow(b) | pow(a, b) | | (a: ℝ^τ→ℝ^(m×m), b: ℝ^λ→ℝ) → (ℝ^?→ℝ^(m×m)) |
| a.pow(1.0/2), a.root(3) | sqrt(a), cbrt(a) | a.sqrt(), a.cbrt() | (a: ℝ^τ→ℝ^(m×m)) → (ℝ^τ→ℝ^(m×m)) |
| a.d(b), d(a) / d(b) | grad(a)[b] | | (a: C(ℝ^τ→ℝ)∗, b: C(ℝ^λ→ℝ)) → (ℝ^?→ℝ) |
| | grad(a) | a.grad() | (a: C(ℝ^τ→ℝ)) → (ℝ^τ→ℝ^τ) |
| a.d(b), a.grad(b) | grad(a, b) | | (a: C(ℝ^τ→ℝ), b: C(ℝ^λ→ℝ^n)) → (ℝ^?→ℝ^n) |
| | divg(a) | a.divg() | (a: C(ℝ^τ→ℝ^m)) → (ℝ^τ→ℝ) |
| | curl(a) | a.curl() | (a: C(ℝ³→ℝ³)) → (ℝ³→ℝ³) |
| | grad(a) | a.grad() | (a: C(ℝ^τ→ℝ^m)) → (ℝ^τ→ℝ^(m×τ)) |
| a.d(b), a.grad(b) | grad(a, b) | | (a: C(ℝ^τ→ℝ^m), b: C(ℝ^λ→ℝ^n)) → (ℝ^?→ℝ^(m×n)) |
| | hess(a) | a.hess() | (a: C(ℝ^τ→ℝ)) → (ℝ^τ→ℝ^(τ×τ)) |
| | lapl(a) | a.lapl() | (a: C(ℝ^τ→ℝ)) → (ℝ^τ→ℝ^τ) |
ℝ can be a Double, Float or BigDecimal. Specialized operators are defined for subsets of ℝ, e.g. Int, Short or BigInteger for subsets of ℤ; however, differentiation is only defined for continuous functions on ℝ.
† a and b are higher-order functions. These may be constants (e.g. 0, 1.0), variables (e.g. Var("x")) or expressions (e.g. x + 1, 2 * x + y).
‡ For infix notation, . is optional. Parentheses are also optional depending on precedence.
§ Matrix division is defined iff B is invertible, although it could be possible to redefine this operator using the Moore-Penrose inverse.
∗ Where C(ℝ^m) is the space of all continuous functions over ℝ^m. If the function is not over ℝ, it will fail at compile time. If the function is over ℝ but not continuously differentiable at the point under consideration, it will fail at runtime.
? While it would be nice to infer a union type bound over the inputs of binary functions, it is likely impossible using the Kotlin type system [without great effort](core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/VariableCapture.kt). Otherwise, if the user desires compile-time shape safety when invoking higher-order functions with literal values, they will need to specify the combined input type explicitly, or wait for a runtime exception.
τ, λ Arbitrarily shaped tensors.
Shape Safety
Shape safety is an important concept in Kotlin∇. There are three broad strategies for handling shape errors:
- Hide the error somehow by implicitly reshaping or broadcasting arrays
- Announce the error at runtime, with a relevant message, e.g. InvalidArgumentError
- Do not allow programs which can result in a shape error to compile
In Kotlin∇, we use the last strategy to check the shape of tensor operations. Consider the following program:
// Inferred type: Vec<Double, D2>
val a = Vec(1.0, 2.0)
// Inferred type: Vec<Double, D3>
val b = Vec(1.0, 2.0, 3.0)
val c = b + b
// Does not compile, shape mismatch
// a + b
Attempting to sum two vectors whose shapes do not match will fail to compile, and they must be explicitly resized.
// Inferred type: Mat<Double, D1, D4>
val a = Mat1x4(1.0, 2.0, 3.0, 4.0)
// Inferred type: Mat<Double, D4, D1>
val b = Mat4x1(1.0, 2.0, 3.0, 4.0)
val c = a * b
// Does not compile, inner dimension mismatch
// a * a
// b * b
Similarly, attempting to multiply two matrices whose inner dimensions do not match will fail to compile.
val a = Mat2x4(
1.0, 2.0, 3.0, 4.0,
5.0, 6.0, 7.0, 8.0
)
val b = Mat4x2(
1.0, 2.0,
3.0, 4.0,
5.0, 6.0,
7.0, 8.0
)
// Types are optional, but encouraged
val c: Mat<Double, D2, D2> = a * b
val d = Mat2x1(1.0, 2.0)
val e = c * d
val f = Mat3x1(1.0, 2.0, 3.0)
// Does not compile, inner dimension mismatch
// e * f
Explicit types are optional but encouraged. Type inference helps preserve shape information over long programs.
fun someMatFun(m: Mat<Double, D3, D1>): Mat<Double, D3, D3> = ...
fun someMatFun(m: Mat<Double, D2, D2>) = ...
When writing a function, it is mandatory to declare the input type(s), but the return type may be omitted. Shape safety is currently supported up to rank-2 tensors, i.e. matrices.
Variable Capture
Kotlin∇ provides a DSL with support for type-safe variable capture with variadic currying. Consider the following example:
val q = X + Y * Z + Y + 0.0
val p0 = q(X to 1.0, Y to 2.0, Z to 3.0) // Name resolution
val p1 = q(X to 1.0, Y to 1.0)(Z to 1.0) // Variadic currying
val p3 = q(Z to 1.0)(X to 1.0, Y to 1.0) // Any order is possible
val p4 = q(Z to 1.0)(X to 1.0)(Y to 1.0) // Proper currying
val p5 = q(Z to 1.0)(X to 1.0) // Returns a partially applied function
val p6 = (X + Z + 0)(Y to 1.0) // Does not compile
For further details, please refer to [the implementation](core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/VariableCapture.kt).
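The linked implementation is not reproduced here. As a rough, self-contained sketch of the underlying idea, each variable below is its own type (the objects X and Y), so a binding such as X to 1.0 has type Pair<X, Double>; overload resolution then only accepts bindings for variables that are still free. BinaryFn and UnaryFn are hypothetical names, and the sketch only handles two variables:

```kotlin
// Each variable is a distinct type, so bindings can be told apart at compile time.
object X
object Y

// A function with two free variables, X and Y (hypothetical, for illustration only).
class BinaryFn(private val f: (Double, Double) -> Double) {
    // Full application, accepted in either order.
    @JvmName("applyXY")
    operator fun invoke(x: Pair<X, Double>, y: Pair<Y, Double>): Double = f(x.second, y.second)

    @JvmName("applyYX")
    operator fun invoke(y: Pair<Y, Double>, x: Pair<X, Double>): Double = f(x.second, y.second)

    // Partial application returns a function of the remaining free variable only.
    @JvmName("bindX")
    operator fun invoke(x: Pair<X, Double>): UnaryFn<Y> = UnaryFn { yv -> f(x.second, yv) }

    @JvmName("bindY")
    operator fun invoke(y: Pair<Y, Double>): UnaryFn<X> = UnaryFn { xv -> f(xv, y.second) }
}

// A function with one free variable V: only a binding for V type-checks.
class UnaryFn<V>(private val f: (Double) -> Double) {
    operator fun invoke(v: Pair<V, Double>): Double = f(v.second)
}

fun main() {
    val q = BinaryFn { x, y -> x + y * 3 }
    println(q(X to 1.0, Y to 2.0)) // 7.0
    println(q(Y to 2.0)(X to 1.0)) // 7.0
    // q(X to 1.0)(X to 2.0)       // Does not compile: only Y is still free
}
```

The @JvmName annotations are required because the overloads differ only in generic type arguments, which erase to the same JVM signature. Enumerating every combination this way grows quickly with the number of variables, which is the source code overhead this technique incurs.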
Example
The following example shows how to derive higher-order partials of a function z of type ℝ²→ℝ:
import edu.umontreal.kotlingrad.numerical.DoublePrecision
@Suppress("NonAsciiCharacters", "LocalVariableName")
fun main() {
    with(DoublePrecision) {
        val x = Var("x")
        val y = Var("y")
        val z = x * (sin(x * y) + y) * 4      // Infix notation
        val `∂z∕∂x` = d(z) / d(x)             // Leibniz notation [Christianson, 2012]
        val `∂z∕∂y` = d(z) / d(y)             // Partial derivatives
        val `∂²z∕∂x²` = d(`∂z∕∂x`) / d(x)     // Higher order derivatives
        val `∂²z∕∂x∂y` = d(`∂z∕∂x`) / d(y)    // Higher order partials
        val `∇z` = z.grad()                   // Gradient operator
        val values = mapOf(x to 0, y to 1)
        val indVar = z.variables.joinToString(", ")
        print("z($indVar) \t\t\t= $z\n" +
            "z($values) \t\t\t= ${z(values)}\n" +
            "∂z($values)/∂x \t\t= $`∂z∕∂x` \n\t\t\t\t= " + `∂z∕∂x`(values) + "\n" +
            "∂z($values)/∂y \t\t= $`∂z∕∂y` \n\t\t\t\t= " + `∂z∕∂y`(values) + "\n" +
            "∂²z($values)/∂x² \t\t= $`∂²z∕∂x²` \n\t\t\t\t= " + `∂²z∕∂x²`(values) + "\n" +
            "∂²z($values)/∂x∂y \t\t= $`∂²z∕∂x∂y` \n\t\t\t\t= " + `∂²z∕∂x∂y`(values) + "\n" +
            "∇z($values) \t\t\t= $`∇z` \n\t\t\t\t= [${`∇z`[x]!!(values)}, ${`∇z`[y]!!(values)}]ᵀ")
    }
}
Any backticks and Unicode characters above are simply for readability and have no effect on the behavior. Running [this program](samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/HelloKotlinGrad.kt) via ./gradlew HelloKotlinGrad should produce the following output:
z(x, y) = x * (sin(x * y) + y) * 4
z({x=0, y=1}) = 0.0
∂z({x=0, y=1})/∂x = (sin(x * y) + y  x * cos(x * y) * y) * 4
= 4.0
∂z({x=0, y=1})/∂y = x * (1  cos(x * y) * x) * 4
= 0.0
∂²z({x=0, y=1})/∂x² = x * (1  cos(x * y) * x) * 4
= 8.0
∂²z({x=0, y=1})/∂x∂y = (x * (sin(x * y) * x * y + cos(x * y)) + 1  cos(x * y) * x) * 4
= 4.0
∇z({x=0, y=1}) = {x=(sin(x * y) + y  x * cos(x * y) * y) * 4, y=x * (1  cos(x * y) * x) * 4}
= [4.0, 0.0]ᵀ
Visualization tools
Kotlin∇ provides various graphical tools that can be used for visual debugging.
Dataflow
Kotlin∇ functions are a type of directed acyclic graph, called dataflow graphs (DFGs). For example, running the expression ((1 + x * 2 - 3 + y + z / y).d(y).d(x) + z / y * 3 - 2).render() will display the following DFG:
[](samples/src/main/resources/dataflow.svg)
Plotting
To generate the [sample 2D plots](samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/Plot2D.kt) below, run ./gradlew Plot2D.
Plotting is also possible in higher dimensions, [for example](samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/Plot3D.kt) in 3D via ./gradlew Plot3D:
[](samples/src/main/resources/ripple.png) [](samples/src/main/resources/pulsar.png) [](samples/src/main/resources/starquake.png) [](samples/src/main/resources/novaflux.png)
Testing
To run [the tests](core/src/test/kotlin/edu/umontreal/kotlingrad), execute: ./gradlew test
Kotlin∇ claims to eliminate certain runtime errors, but how do we know the proposed implementation is correct? One method, borrowed from the Haskell community, is called property-based testing (PBT), closely related to metamorphic testing. Notable implementations include QuickCheck, Hypothesis and ScalaTest (ported to Kotlin in KotlinTest). PBT uses algebraic properties to verify the result of an operation by constructing semantically equivalent but syntactically distinct expressions, which should produce the same answer. Kotlin∇ uses two such equivalences to validate its AD implementation:
- Analytic differentiation: manually differentiate and compare the values returned on a subset of the domain with AD.
- Finite difference approximation: sample the space of symbolic (differentiable) functions, comparing results of AD to FD.
For example, consider the following test, which checks whether the analytical derivative and the automatic derivative, when evaluated at a given point, are equal to each other within the limits of numerical precision:
val x = Var("x")
val y = Var("y")
val z = y * (sin(x * y) - x)               // Function under test
val `∂z∕∂x` = d(z) / d(x)                  // Automatic derivative
val manualDx = y * (cos(x * y) * y - 1)    // Analytical derivative
"∂z/∂x should be y * (cos(x * y) * y - 1)" {
    NumericalGenerator.assertAll { ẋ, ẏ ->
        // Evaluate the results at a given seed
        val autoEval = `∂z∕∂x`(x to ẋ, y to ẏ)
        val manualEval = manualDx(x to ẋ, y to ẏ)
        // Should pass iff Δ(autoEval, manualEval) < Ɛ
        autoEval shouldBeApproximately manualEval
    }
}
PBT will search the input space for two numerical values ẋ and ẏ which violate the specification, then "shrink" them to discover pass/fail boundary values. We can construct a similar test using finite differences:
"d(sin x)/dx should be equal to (sin(x + dx) - sin(x)) / dx" {
    NumericalGenerator.assertAll { ẋ ->
        val f = sin(x)
        val `df∕dx` = d(f) / d(x)
        val adEval = `df∕dx`(ẋ)
        val dx = 1E-8
        // Since ẋ is a raw numeric type, sin => kotlin.math.sin
        val fdEval = (sin(ẋ + dx) - sin(ẋ)) / dx
        adEval shouldBeApproximately fdEval
    }
}
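The same finite-difference idea can be sketched without a test framework. The helper below, findCounterexample (a hypothetical name, not part of Kotlin∇ or KotlinTest), samples random points with a fixed seed and compares a claimed derivative against a central difference, (f(x + h) - f(x - h)) / 2h, whose truncation error is O(h²) rather than the O(h) of the forward difference used above. Unlike a real PBT library, it does not shrink failing inputs:

```kotlin
import kotlin.math.abs
import kotlin.math.cos
import kotlin.math.max
import kotlin.math.sin
import kotlin.random.Random

// Returns the first sampled x where `derivative` disagrees with a central difference of `f`,
// or null if every sample agrees within the (relative-or-absolute) tolerance.
fun findCounterexample(
    f: (Double) -> Double,
    derivative: (Double) -> Double,
    lo: Double = -10.0,
    hi: Double = 10.0,
    samples: Int = 1000,
    h: Double = 1e-5,
    tolerance: Double = 1e-6,
    seed: Int = 42,
): Double? {
    val rng = Random(seed) // Fixed seed keeps failures reproducible
    repeat(samples) {
        val x = rng.nextDouble(lo, hi)
        val fd = (f(x + h) - f(x - h)) / (2 * h) // Central difference
        val exact = derivative(x)
        if (abs(fd - exact) > tolerance * max(1.0, abs(exact))) return x
    }
    return null
}

fun main() {
    val f = { x: Double -> x * sin(x) }
    println(findCounterexample(f, { x -> sin(x) + x * cos(x) })) // null: the product rule is applied correctly
    println(findCounterexample(f, { x -> x * cos(x) }) != null)  // true: the missing sin(x) term is caught
}
```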
[](samples/src/main/resources/comparison.svg)
Above, we [compare numerical errors](samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/ADSDComparison.kt) for three types of computational differentiation: (1) finite precision automatic differentiation (AD), (2) finite precision symbolic differentiation (SD) and (3) finite precision finite differences (FD) against infinite precision symbolic differentiation (IP). AD and SD both exhibit relative errors (i.e. with respect to each other) several orders of magnitude lower than their absolute errors (i.e. with respect to IP), which roughly agree to within numerical precision. As expected, FD exhibits significantly higher numerical error than AD and SD, due to truncation error and to catastrophic cancellation when subtracting nearly equal floating-point values.
There are many other ways to independently verify the numerical gradient, such as dual numbers or the complex step derivative. Another method is to compare the numerical output against a well-known implementation, such as TensorFlow. We plan to conduct a more thorough comparison of numerical accuracy and performance.
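To illustrate the dual-number approach mentioned above, here is a minimal forward-mode sketch that is not part of Kotlin∇; Dual and derivativeAt are hypothetical names. Each value carries its derivative alongside it, and the arithmetic rules apply the product and chain rules automatically:

```kotlin
import kotlin.math.cos
import kotlin.math.sin

// A dual number value + deriv·ε with ε² = 0.
data class Dual(val value: Double, val deriv: Double) {
    operator fun plus(o: Dual): Dual = Dual(value + o.value, deriv + o.deriv)
    // (a + bε)(c + dε) = ac + (ad + bc)ε, i.e. the product rule
    operator fun times(o: Dual): Dual = Dual(value * o.value, value * o.deriv + deriv * o.value)
}

// Chain rule for elementary functions.
fun sin(d: Dual): Dual = Dual(sin(d.value), cos(d.value) * d.deriv)
fun cos(d: Dual): Dual = Dual(cos(d.value), -sin(d.value) * d.deriv)

// Seeding the input with derivative 1.0 yields df/dx at x.
fun derivativeAt(x: Double, f: (Dual) -> Dual): Double = f(Dual(x, 1.0)).deriv

fun main() {
    println(derivativeAt(3.0) { x -> x * x + x })      // 7.0, since d/dx (x² + x) = 2x + 1
    println(derivativeAt(0.0) { x -> sin(x) * cos(x) }) // 1.0, since d/dx (sin x cos x) = cos 2x
}
```

Because ε² = 0 is built into times, the result is exact up to floating-point rounding, with no step size to tune, which makes dual numbers a convenient independent check on AD output.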
How?
To understand the core of Kotlin∇'s AD implementation, please refer to the [toy example](core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyExample.kt).
This project relies on a few Kotlin-native language features, which together enable a concise, flexible and type-safe user interface. The following features have proven beneficial to the development of Kotlin∇:
Operator overloading
Operator overloading enables concise notation for arithmetic on abstract types, where the types encode algebraic structures, e.g. [Group](core/src/main/kotlin/edu/umontreal/kotlingrad/algebra/Group.kt), [Ring](core/src/main/kotlin/edu/umontreal/kotlingrad/algebra/Ring.kt), and [Field](core/src/main/kotlin/edu/umontreal/kotlingrad/algebra/Field.kt). These abstractions are extensible to other kinds of mathematical structures, such as complex numbers and quaternions.
For example, suppose we have an interface Group, which overloads the operators + and *, and is defined like so:
interface Group<T: Group<T>> {
    operator fun plus(addend: T): T
    operator fun times(multiplicand: T): T
}
Here, we specify a recursive type bound using a method known as F-bounded quantification to ensure that operations return the concrete type variable T, rather than something more abstract like Group. Imagine a class Expr which has implemented Group. It can be used as follows:
fun <T: Group<T>> cubed(t: T): T = t * t * t
fun <E: Expr<E>> twiceExprCubed(e: E): E = cubed(e) + cubed(e)
Like Python, Kotlin supports overloading a limited set of operators, which are evaluated using a fixed precedence. In the current version of Kotlin∇, operators do not perform any computation; they simply construct a directed acyclic graph representing the symbolic expression. Expressions are only evaluated when invoked as a function.
First-class functions
With higher-order functions and lambdas, Kotlin treats functions as first-class citizens. This allows us to represent mathematical functions and programming functions with the same underlying abstractions (typed FP). A number of recent papers have demonstrated the expressiveness of this paradigm for automatic differentiation.
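To make the F-bounded pattern concrete, the sketch below implements the Group interface for Gaussian integers (a + bi with integer parts), an illustrative type that is not part of Kotlin∇. Because cubed is bounded by T: Group<T>, it returns GaussInt rather than Group. Strictly speaking, a structure with both + and * is a ring; the sketch follows the document's naming:

```kotlin
interface Group<T : Group<T>> {
    operator fun plus(addend: T): T
    operator fun times(multiplicand: T): T
}

// Gaussian integers re + im·i; `operator` is inherited from the interface declarations.
data class GaussInt(val re: Int, val im: Int) : Group<GaussInt> {
    override fun plus(addend: GaussInt) = GaussInt(re + addend.re, im + addend.im)
    override fun times(multiplicand: GaussInt) = GaussInt(
        re * multiplicand.re - im * multiplicand.im,
        re * multiplicand.im + im * multiplicand.re
    )
}

// The recursive bound guarantees the concrete type T is returned, not Group.
fun <T : Group<T>> cubed(t: T): T = t * t * t

fun main() {
    val i = GaussInt(0, 1)
    println(cubed(i)) // GaussInt(re=0, im=-1), since i³ = -i
}
```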
In Kotlin∇, all expressions can be treated as functions. For example:
fun <T: Group<T>> makePoly(x: Var<T>, y: Var<T>) = x * y + y * y + x * x
val x: Var<Double> = Var()
val y: Var<Double> = Var()
val f = makePoly(x, y)
val z = f(1.0, 2.0) // Returns a value
println(z) // Prints: 7
Currently, it is only possible to represent functions where all inputs and outputs share a single type. In future iterations, it is possible to extend support for building functions with varying input/output types and enforce constraints on both, using covariant and contravariant type bounds.
Coroutines
Coroutines are a generalization of subroutines for non-preemptive multitasking, typically implemented using continuations. One form of continuation, known as shift/reset (a.k.a. delimited continuations), is sufficient for implementing reverse-mode AD with operator overloading alone (without any additional data structures), as described by Wang et al. in Shift/Reset the Penultimate Backpropagator and later in Backpropagation with Continuation Callbacks. Delimited continuations can be implemented using Kotlin coroutines and would be an interesting extension to this work. Please stay tuned!
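Kotlin∇ does not implement this yet. As a rough illustration of the callback style from Backpropagation with Continuation Callbacks, the sketch below passes continuations explicitly instead of using coroutines or shift/reset; NumR, plus, times and gradient are hypothetical names. Each operation runs the rest of the forward pass inside its continuation k, and only after k returns does it propagate adjoints backwards, so the call stack itself plays the role of the tape:

```kotlin
// A value paired with a mutable adjoint (gradient accumulator).
class NumR(val x: Double, var d: Double = 0.0)

fun plus(a: NumR, b: NumR, k: (NumR) -> Unit) {
    val y = NumR(a.x + b.x)
    k(y)           // Forward: continue with the rest of the program
    a.d += y.d     // Backward: runs after the continuation returns
    b.d += y.d
}

fun times(a: NumR, b: NumR, k: (NumR) -> Unit) {
    val y = NumR(a.x * b.x)
    k(y)
    a.d += b.x * y.d
    b.d += a.x * y.d
}

// Runs f on a fresh input, seeds the final output's adjoint with 1.0, and returns df/dx.
fun gradient(x0: Double, f: (NumR, (NumR) -> Unit) -> Unit): Double {
    val x = NumR(x0)
    f(x) { result -> result.d = 1.0 }
    return x.d
}

fun main() {
    // f(x) = x * x + x, so f'(x) = 2x + 1
    val df = gradient(3.0) { x, k -> times(x, x) { xx -> plus(xx, x, k) } }
    println(df) // 7.0
}
```

The nested lambdas are exactly what shift/reset (or coroutines) would hide from the user, letting ordinary operator overloading express the same control flow.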
Extension Functions
Extension functions augment external classes with new properties and methods. Via context-oriented programming, Kotlin∇ can expose its custom extensions (e.g. in [DoublePrecision](core/src/main/kotlin/edu/umontreal/kotlingrad/numerical/Protocol.kt)) to [consumers](samples/src/main/kotlin/edu/umontreal/kotlingrad/samples/HelloKotlinGrad.kt) without requiring subclasses or inheritance.
sealed class Expr<T> {
    operator fun plus(addend: Expr<T>): Expr<T> = Sum(this, addend)
    operator fun times(multiplicand: Expr<T>): Expr<T> = Prod(this, multiplicand)
}
data class Const<T>(val number: T) : Expr<T>()
data class Sum<T>(val e1: Expr<T>, val e2: Expr<T>) : Expr<T>()
data class Prod<T>(val e1: Expr<T>, val e2: Expr<T>) : Expr<T>()
object DoubleContext {
    // Only visible inside a DoubleContext scope, e.g. with(DoubleContext) { ... }
    operator fun Number.times(expr: Expr<Double>) = Const(toDouble()) * expr
}
Now, we can use the context to define another extension, Expr.multiplyByTwo, which computes the product inside a DoubleContext, using the operator overload we defined above:
fun Expr<Double>.multiplyByTwo() = with(DoubleContext) { 2 * this@multiplyByTwo } // Uses `*` operator in DoubleContext; plain `this` would be DoubleContext here
Extensions can also be defined in another file or context and imported on demand.
Algebraic data types
Algebraic data types (ADTs) in the form of sealed classes (a.k.a. sum types) facilitate a limited form of pattern matching over a closed set of subclasses. When matching against subclasses of a sealed class, the compiler forces the author to provide an exhaustive control flow over all concrete subtypes of an abstract class. Consider the following classes:
open class Const<T: Fun<T>>(val number: Number) : Fun<T>()
class Sum<T: Fun<T>>(val left: Fun<T>, val right: Fun<T>) : Fun<T>(left, right)
class Prod<T: Fun<T>>(val left: Fun<T>, val right: Fun<T>) : Fun<T>(left, right)
class Var<T: Fun<T>>: Fun<T>() { override val variables: Set<Var<T>> = setOf(this) }
class Zero<T: Fun<T>>: Const<T>(0.0)
class One<T: Fun<T>>: Const<T>(1.0)
When branching on the type of a sealed class, consumers must explicitly handle every case, since incomplete control flow will not compile rather than fail silently at runtime. Let us now consider a simplified definition of Fun, a sealed class which defines the behavior of function invocation and differentiation, using a restricted form of pattern matching. It can be constructed with a set of Vars, and can be invoked with a map from variables to values:
sealed class Fun<X: Fun<X>>(open val variables: Set<Var<X>>): Group<Fun<X>> {
    constructor(vararg fns: Fun<X>): this(fns.flatMap { it.variables }.toSet())
    // Since the subclasses of Fun are a closed set, no `else ->` branch is required.
    operator fun invoke(map: Map<Var<X>, X>): Fun<X> = when (this) {
        is Const -> this
        is Var -> map.getOrElse(this) { this } // Partial application is permitted
        is Prod -> left(map) * right(map)      // Smart casting implicitly casts after checking
        is Sum -> left(map) + right(map)
    }
    fun d(variable: Var<X>): Fun<X> = when (this) {
        is Const -> Zero()
        is Var -> if (variable == this) One() else Zero()
        // Product rule: d(u*v)/dx = du/dx * v + u * dv/dx
        is Prod -> left.d(variable) * right + left * right.d(variable)
        is Sum -> left.d(variable) + right.d(variable)
    }
    override operator fun plus(addend: Fun<X>): Fun<X> = Sum(this, addend)
    override operator fun times(multiplicand: Fun<X>): Fun<X> = Prod(this, multiplicand)
}
Kotlin's smart casting implicitly downcasts the abstract type Fun to a subtype, such as Sum, after an is Sum check, so its member left can be accessed directly. Without smart casting, we would need an explicit cast such as (this as Sum).left, which would throw a ClassCastException at runtime if the cast were mistaken. Because Fun is sealed, the compiler can additionally verify that each when expression covers every subclass, which is why no else branch is needed.
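Because Fun above is parameterized by an abstract X, it cannot be evaluated as written. The following sketch specializes the same sealed-class design to Double so that it runs; Fun, Const, Var, Sum and Prod here are hypothetical standalone classes, not Kotlin∇'s. No simplification is performed, so derivatives keep redundant terms such as multiplication by 0.0:

```kotlin
sealed class Fun {
    operator fun plus(addend: Fun): Fun = Sum(this, addend)
    operator fun times(multiplicand: Fun): Fun = Prod(this, multiplicand)

    // Exhaustive `when`: adding a subclass without a branch here is a compile error.
    // Throws NoSuchElementException if a variable has no binding.
    operator fun invoke(values: Map<Var, Double>): Double = when (this) {
        is Const -> value
        is Var -> values.getValue(this)
        is Sum -> left(values) + right(values)
        is Prod -> left(values) * right(values)
    }

    fun d(v: Var): Fun = when (this) {
        is Const -> Const(0.0)
        is Var -> Const(if (this == v) 1.0 else 0.0)
        is Sum -> left.d(v) + right.d(v)
        // Product rule: d(u*w) = du*w + u*dw
        is Prod -> left.d(v) * right + left * right.d(v)
    }
}

data class Const(val value: Double) : Fun()
data class Var(val name: String) : Fun()
data class Sum(val left: Fun, val right: Fun) : Fun()
data class Prod(val left: Fun, val right: Fun) : Fun()

fun main() {
    val x = Var("x")
    val y = Var("y")
    val z = x * y + x * x
    val dzdx = z.d(x) // Symbolic, unsimplified: y + 2x after evaluation
    println(dzdx(mapOf(x to 3.0, y to 2.0))) // 8.0
}
```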
Multiple Dispatch
In conjunction with ADTs, Kotlin∇ also uses multiple dispatch to instantiate the most specific result type of applying an operator based on the type of its operands. While multiple dispatch is not an explicit language feature, it can be emulated using inheritance.
Building on the previous example, a common task in AD is to simplify a graph. This is useful in order to minimize the total number of calculations required, improving numerical stability. We can eagerly simplify expressions based on algebraic rules of replacement. Smart casting allows us to access members of a class after checking its type, without explicitly casting it:
override fun times(multiplicand: Function<X>): Function<X> = when {
    this == zero -> this
    this == one -> multiplicand
    multiplicand == one -> this
    multiplicand == zero -> multiplicand
    this == multiplicand -> pow(two)
    this is Const && multiplicand is Const -> const(value * multiplicand.value)
    // Further simplification is possible using rules of replacement
    else -> Prod(this, multiplicand)
}
val result = Const(2.0) * Sum(Var(2.0), Const(3.0)) // Sum(Prod(Const(2.0), Var(2.0)), Const(6.0))
This allows us to put all related control flow on a single abstract class which is inherited by subclasses, simplifying readability, debugging and refactoring.
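The snippet above relies on Kotlin∇ internals (zero, one, two, pow, const). A self-contained sketch of the same idea, with hypothetical classes Expr, Const, Var, Sum and Prod, eagerly applies a few rewrite rules, including distribution of a constant over a sum, and reproduces the shape of the result in the comment above:

```kotlin
sealed class Expr {
    // Multiplication applies rules of replacement before building a node.
    operator fun times(other: Expr): Expr = when {
        this == Const(0.0) || other == Const(0.0) -> Const(0.0)
        this == Const(1.0) -> other
        other == Const(1.0) -> this
        // Smart casts make `value`, `left` and `right` accessible after the `is` checks.
        this is Const && other is Const -> Const(value * other.value)
        this is Const && other is Sum -> this * other.left + this * other.right
        else -> Prod(this, other)
    }

    operator fun plus(other: Expr): Expr = when {
        this == Const(0.0) -> other
        other == Const(0.0) -> this
        this is Const && other is Const -> Const(value + other.value)
        else -> Sum(this, other)
    }
}

data class Const(val value: Double) : Expr()
data class Var(val name: String) : Expr()
data class Sum(val left: Expr, val right: Expr) : Expr()
data class Prod(val left: Expr, val right: Expr) : Expr()

fun main() {
    println(Const(2.0) * Sum(Var("x"), Const(3.0)))
    // Sum(left=Prod(left=Const(value=2.0), right=Var(name=x)), right=Const(value=6.0))
}
```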
Shape-safe Tensor Operations
While first-class dependent types are useful for ensuring arbitrary shape safety (e.g. when concatenating and reshaping matrices), they are unnecessary for simple equality checking (such as when multiplying two matrices).∗ When the shape of a tensor is known at compile time, it is possible to encode this information using a less powerful type system, as long as it supports subtyping and parametric polymorphism (a.k.a. generics). In practice, we can implement shape-checked tensor arithmetic in languages like Java, Kotlin, C++, C# or TypeScript, which accept generic type parameters. In Kotlin, whose type system is less expressive than Java's, we use the following strategy.
Shape safety is currently supported up to rank-2 tensors, i.e. matrices. To perform dimension checking in our type system, we first enumerate a list of integer type literals as a chain of subtypes, C <: C - 1 <: C - 2 <: ... <: 1 <: 0, where C is the largest fixed-length dimension we wish to represent, which can be specified by the user prior to compilation. This guarantees linear space and time complexity for subtype checking, with a constant upper bound.
@file:Suppress("ClassName")
interface Nat<T: D0> { val i: Int } // Used for certain type bounds
sealed class D0(open val i: Int = 0) { companion object: D0(), Nat<D0> }
sealed class D1(override val i: Int = 1): D0(i) { companion object: D1(), Nat<D1> }
sealed class D2(override val i: Int = 2): D1(i) { companion object: D2(), Nat<D2> }
sealed class D3(override val i: Int = 3): D2(i) { companion object: D3(), Nat<D3> }
//...
Next, we overload the call operator to emulate instantiating a collection literal, using arity to infer its dimensionality. Consider the rank-1 case for length inference on vector literals:
open class Vec<E, Len: D1> constructor(val contents: List<E>) {
    companion object {
        operator fun <T> invoke(t: T): Vec<T, D1> = Vec(listOf(t))
        operator fun <T> invoke(t0: T, t1: T): Vec<T, D2> = Vec(listOf(t0, t1))
        operator fun <T> invoke(t0: T, t1: T, t2: T): Vec<T, D3> = Vec(listOf(t0, t1, t2))
    }
}
Finally, we encode length as a parameter of the operand type. Since integer literals are a chain of subtypes, we only need to define one operator using the highest literal, and can rely on Liskov substitution to preserve shape safety for all subtypes.
@JvmName("floatVecPlus") infix operator fun <C: D1, V: Vec<Float, C>> V.plus(v: V): Vec<Float, C> =
    Vec(contents.zip(v.contents).map { it.first + it.second })
The operator + can now be used like so. Incompatible operands will cause a type error:
val one = Vec(1, 2, 3) + Vec(1, 2, 3)          // Always runs safely
val add = Vec(1, 2, 3) + Vec(D3, listOf(...))  // May fail at runtime
val vec: Vec<Int, D2> = Vec(1, 2, 3)           // Does not compile
val sum = Vec(1, 2) + add                      // Does not compile
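Putting these pieces together, here is a compact, self-contained sketch of length-indexed vectors. The names mirror the snippets above, but these are simplified stand-ins (Int-only addition, only three length literals, no runtime-length constructor), not Kotlin∇'s actual classes:

```kotlin
// Length literals as a chain of subtypes: D3 <: D2 <: D1.
open class D1
open class D2 : D1()
open class D3 : D2()

// Len is a phantom type parameter: it only exists at compile time.
class Vec<E, Len : D1> internal constructor(val contents: List<E>) {
    companion object {
        // The arity of the call selects the length type.
        operator fun <T> invoke(t0: T, t1: T): Vec<T, D2> = Vec(listOf(t0, t1))
        operator fun <T> invoke(t0: T, t1: T, t2: T): Vec<T, D3> = Vec(listOf(t0, t1, t2))
    }
}

// Both operands must carry the same Len, so mismatched lengths are rejected by the type checker.
operator fun <Len : D1> Vec<Int, Len>.plus(other: Vec<Int, Len>): Vec<Int, Len> =
    Vec(contents.zip(other.contents) { a, b -> a + b })

fun main() {
    val sum = Vec(1, 2, 3) + Vec(4, 5, 6) // Inferred type: Vec<Int, D3>
    println(sum.contents)                 // [5, 7, 9]
    // Vec(1, 2) + Vec(1, 2, 3)           // Does not compile: Len cannot be both D2 and D3
}
```

Because Vec is invariant in Len, Vec<Int, D2> is neither a subtype nor a supertype of Vec<Int, D3>, so no single Len satisfies both operands of a mismatched sum.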
A similar syntax is available for [matrices](core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyMatrixExample.kt) and higher-rank [tensors](core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyTensorExample.kt). For example, Kotlin∇ can infer the shape of multiplying two matrices, and will not compile if their inner dimensions do not match:
open class Mat<X, R: D1, C: D1>(vararg val rows: Vec<X, C>)
fun <X> Mat1x2(d0: X, d1: X): Mat<X, D1, D2> = Mat(Vec(d0, d1))
fun <X> Mat2x1(d0: X, d1: X): Mat<X, D2, D1> = Mat(Vec(d0), Vec(d1))
// ...
operator fun <Q: D1, R: D1, S: D1> Mat<Int, Q, R>.times(m: Mat<Int, R, S>): Mat<Int, Q, S> = TODO()
// Inferred type: Mat<Int, D4, D4>
val l = Mat4x4(
1, 2, 3, 4,
5, 6, 7, 8,
9, 0, 0, 0,
9, 0, 0, 0
)
// Inferred type: Mat<Int, D4, D3>
val m = Mat4x3(
1, 1, 1,
2, 2, 2,
3, 3, 3,
4, 4, 4
)
// Inferred type: Mat<Int, D4, D3>
val lm = l * m
// m * m // Compile error: Expected Mat<3, *>, found Mat<4, 3>
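The times operator above is left as TODO(). A self-contained sketch with a concrete row-by-column product, using simplified stand-ins for D1 to D3 and Mat (hypothetical and Int-only), shows how Q, R and S flow through the signature:

```kotlin
// Dimension literals as a chain of subtypes.
open class D1
open class D2 : D1()
open class D3 : D2()

// R and C are phantom type parameters recording the number of rows and columns.
class Mat<R : D1, C : D1> internal constructor(val rows: List<List<Int>>)

fun Mat2x3(a: Int, b: Int, c: Int, d: Int, e: Int, f: Int): Mat<D2, D3> =
    Mat(listOf(listOf(a, b, c), listOf(d, e, f)))

fun Mat3x2(a: Int, b: Int, c: Int, d: Int, e: Int, f: Int): Mat<D3, D2> =
    Mat(listOf(listOf(a, b), listOf(c, d), listOf(e, f)))

// The inner dimension R must match exactly; the result has shape Q x S.
operator fun <Q : D1, R : D1, S : D1> Mat<Q, R>.times(m: Mat<R, S>): Mat<Q, S> =
    Mat(rows.map { row ->
        m.rows.first().indices.map { j -> row.indices.sumOf { k -> row[k] * m.rows[k][j] } }
    })

fun main() {
    val p = Mat2x3(1, 2, 3, 4, 5, 6) * Mat3x2(1, 2, 3, 4, 5, 6) // Inferred type: Mat<D2, D2>
    println(p.rows) // [[22, 28], [49, 64]]
    // Mat3x2(1, 2, 3, 4, 5, 6) * Mat3x2(1, 2, 3, 4, 5, 6)       // Does not compile: D2 vs D3
}
```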
[Further examples](core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/ToyMatrixExample.kt) are provided for shape-safe matrix operations such as addition, subtraction and transposition.
A similar technique is possible in Haskell, which is capable of a more powerful form of type-level computation, type arithmetic. Type arithmetic makes it easy to express convolutional arithmetic and other arithmetic operations on shape variables (say, splitting a vector in half), which is currently not possible in Kotlin∇, or would require enumerating every possible combination of type literals.
∗ Many less powerful type systems are still capable of performing arbitrary computation in the type checker. As specified, Java's type system is known to be Turing complete. It may be possible to emulate a limited form of dependent types in Java by exploiting this property, although this may not be computationally tractable due to the practical limitations noted by Grigore.
Ideal API (WIP)
The current API is experimental, but can be improved in many ways. Currently, Kotlin∇ does not infer a function's input dimensionality (i.e. free variables and their corresponding shape). While it is possible to perform variable capture over a small alphabet using [type-safe currying](core/src/main/kotlin/edu/umontreal/kotlingrad/experimental/VariableCapture.kt), this technique incurs a large source code overhead. It may be possible to reduce the footprint using phantom types or some form of union type bound (cf. Kotlin, Java).
When the shape of an N-dimensional array is known at compile time, we can use [type-level integers](core/src/main/kotlin/edu/umontreal/kotlingrad/dependent) to ensure shape-conforming tensor operations (inspired by Nexus and others).
Allowing users to specify a matrix's structure in its type signature (e.g. Singular, Symmetric, Orthogonal, Unitary, Hermitian, Toeplitz) would allow us to specialize derivation over such matrices (cf. section 2.8 of The Matrix Cookbook).
Scalar functions
A function's type would ideally encode arity, based on the number of unique variables:
val f = x * y + sin(2 * x + 3 * y) // f: BinaryFunction<Double> "
val g = f(x to 1.0) // g: UnaryFunction<Double> == y + sin(2 + 3 * y)
val h = f(x to 0.0, y to 0.0) // h: Const<Double> == 0 + sin(0 + 0) == 0
However, inferring arity for arbitrary expressions at compile time would be difficult in the Kotlin type system. Instead, we can have the user specify it directly.
val x = Var(1.0) // x: Variable<Double> inferred type
val y = Var(1.0) // y: Variable<Double> "
val f = Fun(D2) { x * y + sin(2 * x + 3 * y) } // f: BinaryFunction<Double> "
val g = f(x to 1.0) // g: UnaryFunction<Double> == y + sin(2 + 3 * y)
val h = f(x to 0.0, y to 0.0) // h: Const<Double> == 0 + sin(0 + 0) == 0
Grammar
Below is the approximate BNF grammar for Kotlin∇. This is incomplete and subject to change without notice.
type = "Double" | "Float" | "Int" | "BigInteger" | "BigDouble";
nat = "1" | ... | "99";
output = "Fun<" type "Real>" | "VFun<" type "Real," nat ">" | "MFun<" type "Real," nat "," nat ">";
int = "0" | nat int;
float = int "." int;
num = type "(" int ")" | type "(" float ")";
var = "x" | "y" | "z" | "ONE" | "ZERO" | "E" | "Var()";
signOp = "+" | "-";
binOp = signOp | "*" | "/" | "pow";
trigOp = "sin" | "cos" | "tan" | "asin" | "acos" | "atan" | "asinh" | "acosh" | "atanh";
unaryOp = signOp | trigOp | "sqrt" | "log" | "ln" | "exp";
exp = var | num | unaryOp exp | var binOp exp | "(" exp ")";
expList = exp | exp "," expList;
linOp = signOp | "*" | " dot ";
vec = "Vec(" expList ")" | "Vec" nat "(" expList ")";
vecExp = vec | signOp vecExp | exp "*" vecExp | vec linOp vecExp | vecExp ".norm(" int ")";
mat = "Mat" nat "x" nat "(" expList ")";
matExp = mat | signOp matExp | exp linOp matExp | vecExp linOp matExp | mat linOp matExp;
anyExp = exp | vecExp | matExp | derivative | invocation;
bindings = exp " to " exp | exp " to " exp "," bindings;
invocation = anyExp "(" bindings ")";
derivative = "d(" anyExp ") / d(" exp ")" | anyExp ".d(" exp ")" | anyExp ".d(" expList ")";
gradient = exp ".grad()";
Comparison
Unlike certain frameworks which simply wrap an existing AD library in a type-safe DSL, Kotlin∇ contains a fully shape-safe implementation of algorithmic differentiation, written in pure Kotlin. By doing so, it can leverage Kotlin language features such as typed functional programming, as well as interoperability with other languages on the JVM platform. Furthermore, it implements symbolic differentiation, which, unlike Wengert tape or dual-number based ADs, allows it to calculate derivatives of arbitrarily high order with zero extra engineering required. Further details can be found below.
| Framework | Language | SD¹ | AD² | HD³ | DP⁴ | FP⁵ | TS⁶ | SS⁷ | DT⁸ | MP⁹ |
|---|---|---|---|---|---|---|---|---|---|---|
| Kotlin∇ | Kotlin | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :construction: |
| DiffSharp | F# | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: |
| TensorFlow.FSharp | F# | :x: | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: |
| Nexus | Scala | :x: | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: |
| Lantern | Scala | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: |
| JAutoDiff | Java | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: | :heavy_check_mark: | :x: | :x: | :x: |
| Eclipse DL4J | Java | :x: | :construction: | :x: | :x: | :x: | :heavy_check_mark: | :x: | :x: | :x: |
| Halide | C++ | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :heavy_check_mark: | :x: | :x: | :x: |
| Tensor Safe | Haskell | :x: | :x: | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: |
| HaskTorch | Haskell | :x: | :x: | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: |
| Grenade | Haskell | :x: | :x: | :x: | :x: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: |
| Stalin∇ | Scheme | :x: | :heavy_check_mark: | :x: | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: |
| Myia | Python | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: | :construction: |
| Autograd | Python | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: | :x: | :x: | :x: |
| JAX | Python | :x: | :heavy_check_mark: | :x: | :heavy_check_mark: | :heavy_check_mark: | :x: | :x: | :x: | :construction: |
| Tangent | Python | :x: | :heavy_check_mark: | :x: | :x: | :x: | :x: | :x: | :x: | :x: |
¹ Symbolic differentiation
² Automatic differentiation
³ Higher-order differentiation
⁴ Differentiable programming
⁵ Functional programming
⁶ Compile-time type safety
⁷ Compile-time shape safety
⁸ Dependently typed
⁹ Multiplatform
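Compile-time shape safety (SS⁷) means that a dimension mismatch is a type error rather than a runtime exception. A minimal sketch of one way to obtain this in Kotlin, assuming phantom type parameters that encode lengths; the names `Dim`, `D2`, `D3`, `Vec`, and `Mat` are hypothetical, only loosely echo the `Vec` and `Mat` rules of the grammar, and are not Kotlin∇'s actual classes.

```kotlin
// Hypothetical sketch of compile-time shape safety via phantom type parameters.
// Dim, D2, D3, Vec, and Mat are illustrative names, NOT Kotlin∇'s actual classes.
sealed class Dim(val size: Int)
object D2 : Dim(2)
object D3 : Dim(3)

class Vec<L : Dim>(val len: L, val data: List<Double>) {
    init { require(data.size == len.size) { "expected ${len.size} elements" } }

    // Only type-checks when both operands share the same length type L.
    operator fun plus(o: Vec<L>): Vec<L> = Vec(len, data.zip(o.data) { a, b -> a + b })

    infix fun dot(o: Vec<L>): Double = data.zip(o.data) { a, b -> a * b }.sum()
}

class Mat<R : Dim, C : Dim>(val rows: R, val cols: C, val data: List<List<Double>>) {
    init { require(data.size == rows.size && data.all { it.size == cols.size }) { "shape mismatch" } }

    // (R x C) matrix times a length-C vector gives a length-R vector;
    // the inner dimension C must match at compile time.
    operator fun times(v: Vec<C>): Vec<R> =
        Vec(rows, data.map { row -> row.zip(v.data) { a, b -> a * b }.sum() })
}

fun main() {
    val v = Vec(D3, listOf(1.0, 2.0, 3.0))
    val m = Mat(D2, D3, listOf(listOf(1.0, 0.0, 0.0), listOf(0.0, 1.0, 1.0)))
    val w: Vec<D2> = m * v // a 2 x 3 matrix times a length-3 vector
    println(w.data)        // prints [1.0, 5.0]
    println(v dot v)       // prints 14.0
    // m * w and v + w would not compile: the phantom length types disagree.
}
```

In this sketch the element count is still checked by `require` when a value is constructed; the phantom types then guarantee that operands agree in every later operation without further runtime checks.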
Citation
If you would like to cite Kotlin∇, please use the following BibTeX entry:
@misc{considine2019kotlingrad,
  author = {Considine, Breandan and Famelis, Michalis and Paull, Liam},
  title = {Kotlin$\nabla$: A Shape-Safe e{DSL} for Differentiable Programming},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/breandan/kotlingrad}},
}
References
To the author's knowledge, Kotlin∇ is the first AD implementation in native Kotlin. While the particular synthesis of these ideas (i.e., shape-safe functional AD using generic types) is unique, it has been influenced by a long list of prior work in AD. Below is a list of projects and publications that helped inspire this work.
Automatic Differentiation
 The Simple Essence of Automatic Differentiation
 Reverse-Mode AD in a Functional Framework: Lambda the Ultimate Backpropagator
 First-Class Automatic Differentiation in Swift: A Manifesto
 AD and the danger of confusing infinitesimals
 Automatic differentiation in PyTorch
 Automatic differentiation in machine learning: a survey
 The (JAX) Autodiff Cookbook
 Automatic differentiation in ML: Where we are and where we should be going
 A Leibniz Notation for Automatic Differentiation
Differentiable Programming
 Neural Networks, Types, and Functional Programming
 Backpropagation with Continuation Callbacks: Foundations for Efficient and Expressive Differentiable Programming
 Backprop as Functor: A compositional perspective on supervised learning
 Demystifying Differentiable Programming: Shift/Reset the Penultimate Backpropagator
 Efficient Differentiable Programming in a Functional Array-Processing Language
 Operational Calculus for Differentiable Programming
 Differentiable Functional Programming
 Differentiable Programming for Image Processing and Deep Learning in Halide
 Software 2.0
Calculus
 The Matrix Calculus You Need For Deep Learning
 Backpropagation in matrix notation
 Matrix derivatives, from the Matrix Cookbook
 Div, Grad, Curl and All That
Computer Algebra
 A Design Proposal for an Object Oriented Algebraic Library
 On Using Generics for Implementing Algebraic Structures
 How to turn a scripting language into a domain-specific language for computer algebra
 Evaluation of a Java Computer Algebra System
 jalgebra: An abstract algebra library for Java
 Typesafe Abstractions for Tensor Operations
 Generalized Algebraic Data Types and Object-Oriented Programming
Computational Mathematics
 KMath - Kotlin mathematics extensions library
 An introduction to context-oriented programming in Kotlin
 COJAC - Numerical sniffing tool and enriching number wrapper for Java
 chebfun - Allows representing functions as Chebyshev polynomials, for easy symbolic differentiation (or integration)
Neural Networks
 Hacker's Guide to Neural Networks
 Tricks from Deep Learning
 Practical Dependent Types in Haskell: Type-Safe Neural Networks
 A guide to convolutional arithmetic for deep learning
Type Systems
Domain-Specific Languages
Automated Testing
 DeepTest: Automated Testing of Deep-Neural-Network-driven Autonomous Cars
 QuickCheck: A Lightweight Tool for Random Testing of Haskell Programs
 Learning to Discover Efficient Mathematical Identities
Libraries
 TensorFlow.FSharp: An eDSL for writing numerical models in F# with support for interactive tensor shape-checking
 Stalin∇, a brutally optimizing compiler for the VLAD language, a pure dialect of Scheme with first-class automatic differentiation operators
 Autograd - Efficiently computes derivatives of NumPy code
 DiffSharp, a functional AD library implemented in the F# language
 Myia - SCT-based AD, adapted from Pearlmutter & Siskind's "Reverse Mode AD in a functional framework"
 Nexus - Type-safe tensors, deep learning and probabilistic programming in Scala
 Tangent - "Source-to-Source Debuggable Derivatives in Pure Python"
 Grenade - composable, dependently typed, practical, and fast RNNs in Haskell
 Lantern - a framework in Scala, based on delimited continuations and multi-stage programming
Special Thanks
The following individuals have helped shape this project through their enthusiasm and thoughtful feedback. Please check out their work.