# 2023-09-20 Nonlinear solvers#

## Last time#

• Cost of linear solvers

• Assembling sparse matrices

• Rootfinding intro

## Today#

• Newton’s method for systems

• Bratu nonlinear PDE

• p-Laplacian

• Techniques for differentiation

using Plots
default(linewidth=3)
using LinearAlgebra
using SparseArrays

"""
    vander(x, k=nothing)

Build the `length(x)`-by-`k` Vandermonde matrix whose `j`th column holds
`x .^ (j - 1)`. When `k` is `nothing`, the matrix is square.
"""
function vander(x, k=nothing)
    ncols = k === nothing ? length(x) : k
    V = ones(length(x), ncols)
    # Each column is the previous column scaled elementwise by x.
    for j in 2:ncols
        @views V[:, j] .= V[:, j-1] .* x
    end
    return V
end

"""
    fdstencil(source, target, k)

kth derivative stencil from source to target.

Returns a row (adjoint) vector `w` with one weight per entry of `source`,
obtained by solving `w * V = rhs`, where `V` is the Vandermonde matrix of
`source .- target` and `rhs` is zero except for `factorial(k)` in position
`k + 1`.
"""
function fdstencil(source, target, k)
    # Shift so the expansion point sits at the origin.
    x = source .- target
    V = vander(x)
    # Row right-hand side; the adjoint keeps the result a row of weights.
    rhs = zero(x)'
    rhs[k+1] = factorial(k)
    return rhs / V
end

"""
    my_spy(A)

Spy plot of `A` with square markers and a diverging color map whose limits
are symmetric about zero.
"""
function my_spy(A)
    # Largest entry magnitude sets the symmetric color limits.
    biggest = norm(vec(A), Inf)
    # Marker size shrinks as the row count grows, but never drops below 1.
    msize = max(1, ceil(120 / size(A, 1)))
    return spy(
        A,
        marker=(:square, msize),
        c=:diverging_rainbow_bgymr_45_85_c67_n256,
        clims=(-biggest, biggest),
    )
end

"""
    advdiff_sparse(n, kappa, wind, forcing)

Assemble a sparse finite-difference operator for an advection-diffusion
problem on `n` equally spaced points of `[-1, 1]`.

# Arguments
- `n`: number of grid points.
- `kappa`: function of position, evaluated at the staggered midpoints.
- `wind`: scalar advection speed; its sign selects the upwind side.
- `forcing`: function of position, evaluated at the grid points.

# Returns
The tuple `(x, L, rhs)`: the grid, the sparse matrix, and the right-hand
side. Rows `1` and `n` of `L` are identity rows with `rhs` set to zero.
"""
function advdiff_sparse(n, kappa, wind, forcing)
    x = LinRange(-1, 1, n)
    xstag = (x[1:end-1] + x[2:end]) / 2
    rhs = forcing.(x)
    kappa_stag = kappa.(xstag)
    rows = [1, n]
    cols = [1, n]
    vals = [1., 1.] # diagonals entries (float)
    rhs[[1,n]] .= 0 # boundary condition
    for i in 2:n-1
        # Flux at each staggered point: diffusive part plus upwinded advection.
        flux_L = -kappa_stag[i-1] * fdstencil(x[i-1:i], xstag[i-1], 1) +
            wind * (wind > 0 ? [1 0] : [0 1])
        flux_R = -kappa_stag[i] * fdstencil(x[i:i+1], xstag[i], 1) +
            wind * (wind > 0 ? [1 0] : [0 1])
        # Difference the two fluxes to get the divergence at x[i].
        weights = fdstencil(xstag[i-1:i], x[i], 1)
        append!(rows, [i,i,i])
        append!(cols, i-1:i+1)
        append!(vals, weights[1] *  [flux_L..., 0] + weights[2] * [0, flux_R...])
    end
    L = sparse(rows, cols, vals)
    return x, L, rhs
end

advdiff_sparse (generic function with 1 method)


# Assembly cost#

# Grid size, and the spacing h = 2/n that enters the plot title below.
n = 400; h = 2/n
# Constant diffusivity and advection speed for this experiment.
kappa = 1
wind = 100
# Assemble with the constant diffusivity x -> kappa and forcing `one`.
x, L, rhs = advdiff_sparse(n, x -> kappa, wind, one)
@show minimum(diff(x))
plot(x, L \ rhs, legend=:none, title="Pe_h $(wind*h/kappa)")

minimum(diff(x)) = 0.005012531328320691

n = 10000
@time advdiff_sparse(n, one, 1, one);

0.103166 seconds (757.67 k allocations: 44.279 MiB, 13.23% gc time, 51.47% compilation time)

# It's also possible to dynamically insert.
# (But performance this way is generally poor.)
A = spzeros(5, 5)
A[1,1] = 3
A

5×5 SparseMatrixCSC{Float64, Int64} with 1 stored entry:
3.0 ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅
⋅ ⋅ ⋅ ⋅ ⋅

## Newton-Raphson Method#

Much of numerical analysis reduces to Taylor series, the approximation

$f(x) = f(x_0) + f'(x_0) (x-x_0) + f''(x_0) (x - x_0)^2 / 2 + \underbrace{\dotsb}_{O((x-x_0)^3)}$
centered on some reference point $$x_0$$. In numerical computation, it is exceedingly rare to look beyond the first-order approximation

$\tilde f_{x_0}(x) = f(x_0) + f'(x_0)(x - x_0) .$
Since $$\tilde f_{x_0}(x)$$ is a linear function, we can explicitly compute the unique solution of $$\tilde f_{x_0}(x) = 0$$ as

$x = x_0 - f(x_0) / f'(x_0) .$
This is Newton’s Method (aka Newton-Raphson or Newton-Raphson-Simpson) for finding the roots of differentiable functions.

# An implementation#

"""
    newton(f, fp, x0; tol=1e-8, verbose=false)

Scalar Newton iteration for `f(x) = 0`, with derivative `fp`, started at `x0`.

Returns `(x, f(x), k)` at the first iteration `k` where `abs(f(x)) < tol`.
If that never happens within 100 iterations, the loop falls through and the
function returns `nothing`.
"""
function newton(f, fp, x0; tol=1e-8, verbose=false)
    x = x0
    for k in 1:100 # max number of iterations
        fx = f(x)
        fpx = fp(x)
        if verbose
            println("[$k] x=$x f(x)=$fx  f'(x)=$fpx")
        end
        # Converged when the residual (not the step) is small enough.
        if abs(fx) < tol
            return x, fx, k
        end
        x = x - fx / fpx
    end
end
# Scale factor applied to both f and f' in the call below.
eps = 1
# Find the root of cos(x) - x from x0 = 1, printing every iterate.
newton(x -> eps*(cos(x) - x), x -> eps*(-sin(x) - 1), 1; tol=1e-15, verbose=true)

[1] x=1  f(x)=-0.45969769413186023  f'(x)=-1.8414709848078965
[2] x=0.7503638678402439  f(x)=-0.018923073822117442  f'(x)=-1.6819049529414878
[3] x=0.7391128909113617  f(x)=-4.6455898990771516e-5  f'(x)=-1.6736325442243012
[4] x=0.739085133385284  f(x)=-2.847205804457076e-10  f'(x)=-1.6736120293089505
[5] x=0.7390851332151607  f(x)=0.0  f'(x)=-1.6736120291832148

(0.7390851332151607, 0.0, 5)


# That’s really fast!#

• 10 digits of accuracy in 4 iterations.

• How is this convergence test different from the one we used for bisection?

• How can this break down?

$x_{k+1} = x_k - f(x_k)/f'(x_k)$
# Start just past 3π/2, where f'(x) = -sin(x) - 1 is nearly zero.
newton(x -> cos(x) - x, x -> -sin(x) - 1, 3pi/2 + 0.001; verbose=true)

[1] x=4.71338898038469  f(x)=-4.712388980551356  f'(x)=-4.999999583255033e-7
[2] x=-9.424774033259554e6  f(x)=9.424773326519862e6  f'(x)=-0.292526319358701
[3] x=2.2793774188936606e7  f(x)=-2.2793773685912155e7  f'(x)=-0.1357278194645779
[4] x=-1.451436005941369e8  f(x)=1.4514359959518498e8  f'(x)=-0.954228185423985
[5] x=6.962155463251442e6  f(x)=-6.962156426217826e6  f'(x)=-1.2696214789257265
[6] x=1.4785081386104291e6  f(x)=-1.4785074154248144e6  f'(x)=-0.30934627585724994
[7] x=-3.3009494815843664e6  f(x)=3.3009485852473797e6  f'(x)=-0.5566265608245774
[8] x=2.6293255314466488e6  f(x)=-2.629324971066165e6  f'(x)=-1.8282353007342522
[9] x=1.191149072426204e6  f(x)=-1.1911491529894928e6  f'(x)=-1.996749495363329
[10] x=594604.9612317495  f(x)=-594605.3799884742  f'(x)=-1.9080984558468124
[11] x=282983.00159623957  f(x)=-282982.3813435215  f'(x)=-1.7844020434228265
[12] x=124396.32972676173  f(x)=-124396.58315146118  f'(x)=-1.9673551166508052
[13] x=61165.964197618225  f(x)=-61165.30029265608  f'(x)=-0.25218304295619787
[14] x=-181377.30745003175  f(x)=181378.1343703368  f'(x)=-0.4376808654587586
[15] x=233029.9666267637  f(x)=-233029.27716158552  f'(x)=-0.2756811695825775
[16] x=-612255.4676828291  f(x)=612254.4726584505  f'(x)=-0.9003682479628197
[17] x=67749.04607445002  f(x)=-67749.87112776139  f'(x)=-0.43494510585341273
[18] x=-88017.44081319592  f(x)=88016.59385697359  f'(x)=-0.46833736500554857
[19] x=99916.72894054736  f(x)=-99916.67435713837  f'(x)=-1.9985092145110306
[20] x=49921.125401844285  f(x)=-49920.78000649864  f'(x)=-1.9384572740428156
[21] x=24168.285411523193  f(x)=-24169.285325359368  f'(x)=-0.9868729201258163
[22] x=-322.492309156416  f(x)=322.03136204007603  f'(x)=-0.11257239397354768
[23] x=2538.1678462609693  f(x)=-2537.1962752925397  f'(x)=-0.7632514977779352
[24] x=-786.0264496940531  f(x)=786.835485634744  f'(x)=-0.4122408259582405
[25] x=1122.6527399659478  f(x)=-1123.1026223978365  f'(x)=-0.1069122117742306
[26] x=-9382.252207370868  f(x)=9382.366211539233  f'(x)=-0.006519728632929378
[27] x=1.4296908656753965e6  f(x)=-1.4296915427004495e6  f'(x)=-1.7359599701784154
[28] x=606116.8390499374  f(x)=-606117.8374607187  f'(x)=-1.056355229556066
[29] x=32334.629711663933  f(x)=-32334.418635880542  f'(x)=-1.977469699615
[30] x=15983.219299572242  f(x)=-15982.860786650146  f'(x)=-0.06647523616614692
[31] x=-224450.08652130357  f(x)=224449.5467361156  f'(x)=-0.15819720193163844
[32] x=1.1943458466274173e6  f(x)=-1.1943448867767288e6  f'(x)=-1.2805114186409385
[33] x=261636.5639219666  f(x)=-261636.54858054023  f'(x)=-0.00011768660659039476
[34] x=-2.2229016966364274e9  f(x)=2.2229016976220207e9  f'(x)=-0.8308681058553692
[35] x=4.524948939374385e8  f(x)=-4.524948931067944e8  f'(x)=-1.5568037952043454
[36] x=1.618385540051478e8  f(x)=-1.6183855340340227e8  f'(x)=-0.20131213176424612
[37] x=-6.420799778045952e8  f(x)=6.420799773342474e8  f'(x)=-0.1175188970268981
[38] x=4.821551774859711e9  f(x)=-4.821551775408336e9  f'(x)=-1.83606872070359
[39] x=2.195532540916934e9  f(x)=-2.195532541703071e9  f'(x)=-0.38194772670535804
[40] x=-3.5527214430549297e9  f(x)=3.552721443896677e9  f'(x)=-0.46012834477022624
[41] x=4.168431761864005e9  f(x)=-4.1684317613180594e9  f'(x)=-0.16217937446467123
[42] x=-2.1534169293772083e10  f(x)=2.1534169292778805e10  f'(x)=-0.8842466247514228
[43] x=2.818956508469387e9  f(x)=-2.818956508872447e9  f'(x)=-0.08482651138273156
[44] x=-3.0413065685081673e10  f(x)=3.0413065684447895e10  f'(x)=-0.22648534861494496
[45] x=1.0386964121157559e11  f(x)=-1.038696412105759e11  f'(x)=-1.0245010081930062
[46] x=2.4840492210132446e9  f(x)=-2.4840492218501883e9  f'(x)=-1.5472892202452475
[47] x=8.786291163889372e8  f(x)=-8.786291153951594e8  f'(x)=-1.1113803373531725
[48] x=8.805447162990355e7  f(x)=-8.805447250994958e7  f'(x)=-1.4748883914875157
[49] x=2.8352006671764433e7  f(x)=-2.8352005673152175e7  f'(x)=-1.0526645650645092
[50] x=1.4184452941999547e6  f(x)=-1.418446170272365e6  f'(x)=-0.5178204359576648
[51] x=-1.3208173377374457e6  f(x)=1.3208165601529954e6  f'(x)=-1.6287785163756425
[52] x=-509892.74186785135  f(x)=509893.6935340123  f'(x)=-1.3071343648749028
[53] x=-119807.60048086708  f(x)=119808.58486489067  f'(x)=-1.1760343549171952
[54] x=-17932.52824003411  f(x)=17933.47829843498  f'(x)=-0.6879278369794791
[55] x=8136.308248944799  f(x)=-8135.3938290237475  f'(x)=-0.5952331436689382
[56] x=-5531.267079401769  f(x)=5530.793621226314  f'(x)=-0.11918369900687953
[57] x=40874.35438823757  f(x)=-40874.970005261646  f'(x)=-1.7880454807060322
[58] x=18014.21439385574  f(x)=-18013.26582687264  f'(x)=-1.3165764971923002
[59] x=4332.316026926583  f(x)=-4333.314241970898  f'(x)=-0.9402779328797781
[60] x=-276.2301171898789  f(x)=277.20377530272515  f'(x)=-1.2280128928125196
[61] x=-50.496513800874595  f(x)=51.46994455496929  f'(x)=-0.7710184134419699
[62] x=16.25927783475946  f(x)=-17.111114513115744  f'(x)=-0.476192522574438
[63] x=-19.673908223282744  f(x)=20.3529408118591  f'(x)=-0.2658918719621115
[64] x=56.8720225028709  f(x)=-55.92384771275381  f'(x)=-1.317749220906071
[65] x=14.43310709656837  f(x)=-14.724746360782682  f'(x)=-1.9565283788620889
[66] x=6.907151163669226  f(x)=-6.095583403367469  f'(x)=-1.584258307975837
[67] x=3.059556757706712  f(x)=-4.056193700318225  f'(x)=-1.0819439114393261
[68] x=-0.6894302807454107  f(x)=1.4610388180463496  f'(x)=-0.3639023147626572
[69] x=3.3254900942438557  f(x)=-4.308628559355858  f'(x)=-0.8171373236080834
[70] x=-1.947342800392553  f(x)=1.5796317043954309  f'(x)=-0.07005992134944783
[71] x=20.5995238527468  f(x)=-20.77773835306402  f'(x)=-1.9839916625036431
[72] x=10.126829463071727  f(x)=-10.890348427734228  f'(x)=-0.3542145940014567
[73] x=-20.618229075339144  f(x)=20.421641041658564  f'(x)=-0.019513822120065027
[74] x=1025.903611495303  f(x)=-1026.07635079201  f'(x)=-1.984967580874203
[75] x=508.98013089109384  f(x)=-507.9810178496695  f'(x)=-1.0421085556139327
[76] x=21.52513868982072  f(x)=-22.418506944115425  f'(x)=-1.4493252298930428
[77] x=6.056901139184923  f(x)=-5.082394341847138  f'(x)=-0.7756420227795489
[78] x=-0.4955988448078781  f(x)=1.375282926638113  f'(x)=-0.5244414692443249
[79] x=2.1267775447247375  f(x)=-2.654554528677306  f'(x)=-1.8493829850014245
[80] x=0.6914044769470746  f(x)=0.07894677588289933  f'(x)=-1.6376197513120885
[81] x=0.7396127229819957  f(x)=-0.0008830834262884002  f'(x)=-1.6740018691679532
[82] x=0.7390851946425396  f(x)=-1.028056018093082e-7  f'(x)=-1.673612074583276
[83] x=0.7390851332151614  f(x)=-1.3322676295501878e-15  f'(x)=-1.6736120291832153

(0.7390851332151614, -1.3322676295501878e-15, 83)


# Convergence of fixed-point iteration#

Consider the iteration

$x_{k+1} = g(x_k)$
where $$g$$ is a continuously differentiable function. Suppose that there exists a fixed point $$x_* = g(x_*)$$. By the mean value theorem, we have that
$x_{k+1} - x_* = g(x_k) - g(x_*) = g'(c_k) (x_k - x_*)$
for some $$c_k$$ between $$x_k$$ and $$x_*$$.

Taking absolute values,

$|e_{k+1}| = |g'(c_k)| |e_k|,$
so the error converges to zero if $$|g'(c_k)| \le \rho < 1$$ for all $$k$$.

## Exercise#

• Write Newton’s method for $$f(x) = 0$$ from initial guess $$x_0$$ as a fixed point method.

• Suppose the Newton iterates $$x_k$$ converge to a simple root $$x_*$$, $$x_k \to x_*$$. What is $$\lvert g'(x_*) \rvert$$ for Newton’s method?

# Newton for systems of equations#

Let $$\mathbf u \in \mathbb R^n$$ and consider the function $$\mathbf f(\mathbf u) \in \mathbb R^n$$. Then

$\mathbf f(\mathbf u) = \mathbf f(\mathbf u_0) + \underbrace{\mathbf f'(\mathbf u_0)}_{n\times n} \underbrace{\delta \mathbf u}_{n\times 1} + \frac 1 2 \underbrace{\mathbf f''(\mathbf u_0)}_{n\times n \times n} \underbrace{(\delta \mathbf u)^2}_{n\times n} + O(|\delta\mathbf u|^3).$

We drop all but the first two terms and name the Jacobian matrix $$J(\mathbf u) = \mathbf f'(\mathbf u)$$,

$\mathbf f(\mathbf u) \approx \mathbf f(\mathbf u_0) + J \delta \mathbf u .$
Solving the right hand side equal to zero yields

$\begin{aligned} J(\mathbf u_k) \, \delta \mathbf u &= -\mathbf f(\mathbf u_k) \\ \mathbf u_{k+1} &= \mathbf u_k + \delta \mathbf u \end{aligned} \tag{13}$

# Newton in code#

"""
    newton(residual, jacobian, u0; maxits=20)

Newton's method for a system `residual(u) = 0`.

Runs exactly `maxits` iterations (there is no convergence test). Each
iteration solves `jacobian(u) * delta_u = -residual(u)` and updates `u`.

# Returns
- `uhist`: the iterates, starting with a copy of `u0` (`maxits + 1` entries).
- `normhist`: `norm(residual(u))` at the start of each iteration
  (`maxits` entries, stored as `Float64`).
"""
function newton(residual, jacobian, u0; maxits=20)
    u = u0
    uhist = [copy(u)]
    normhist = Float64[]
    for _ in 1:maxits
        f = residual(u)
        push!(normhist, norm(f))
        J = jacobian(u)
        # Negate the solution vector rather than the whole matrix.
        delta_u = -(J \ f)
        u += delta_u
        push!(uhist, copy(u))
    end
    return uhist, normhist
end

newton (generic function with 1 method)

"""
    residual(u)

Residual of the 2-by-2 system: the unit-circle equation `x^2 + y^2 - 1`
and the parabola equation `x^2 - y`, evaluated at `u = (x, y)`.
"""
function residual(u)
    a, b = u
    asq = a^2
    return [asq + b^2 - 1, asq - b]
end
"""
    jacobian(u)

2-by-2 Jacobian at `u = (x, y)` of the pair `x^2 + y^2 - 1`, `x^2 - y`.
"""
function jacobian(u)
    a, b = u
    twice_a = 2a
    return [twice_a 2b; twice_a -1]
end

# Run Newton on the 2-by-2 system from the initial guess (0.1, 2).
uhist, normhist = newton(residual, jacobian, [0.1, 2])
# Residual norm per iteration on a logarithmic vertical axis.
plot(normhist, marker=:circle, yscale=:log10)


# Plotting the trajectory#

# Stack the iterates as columns: row 1 holds x, row 2 holds y.
xy = hcat(uhist...)
plot(xy[1,:], xy[2,:], marker=:circle)
# 50 points on the unit circle, written as complex exponentials.
circle = exp.(1im*LinRange(0, 2*pi, 50))
plot!(real(circle), imag(circle))
# Overlay the parabola y = x^2.
# NOTE(review): `axes=:equal` — confirm Plots honors this; `aspect_ratio=:equal` is the documented attribute.
plot!(x -> x^2, xlims=(-2, 5), ylims=(-2, 4), axes=:equal, legend=:bottomright)


# Bratu problem#

$-(u_x)_x - \lambda e^{u}= 0$
"""
    bratu_f(u; lambda=.5)

Residual of the finite-difference Bratu problem `-(u_x)_x - lambda*exp(u) = 0`
on `length(u)` points with spacing `2 / (n - 1)`.

The first and last entries are the boundary residuals `u[1]` and `u[n] - 1`.
Interior entries use the boundary values 0 and 1 in place of `u[1]`, `u[n]`.
"""
function bratu_f(u; lambda=.5)
    npts = length(u)
    spacing = 2 / (npts - 1)
    # Negated second-derivative stencil on a uniform 3-point neighborhood.
    stencil = -fdstencil([-spacing, 0, spacing], 0, 2)
    ubc = copy(u)
    resid = copy(u)
    # Impose the boundary values on the working copy only.
    ubc[1] = 0
    ubc[npts] = 1
    resid[npts] -= 1
    for i in 2:npts-1
        window = ubc[i-1:i+1]
        resid[i] = stencil * window - lambda * exp(ubc[i])
    end
    return resid
end

bratu_f (generic function with 1 method)

"""
    bratu_J(u; lambda=.5)

Sparse Jacobian paired with `bratu_f`, assembled from (row, column, value)
triplets. Rows `1` and `n` hold a single unit diagonal entry; each interior
row holds the 3-point stencil with `-lambda * exp(u[i])` added on the diagonal.
"""
function bratu_J(u; lambda=.5)
    npts = length(u)
    spacing = 2 / (npts - 1)
    stencil = -fdstencil([-spacing, 0, spacing], 0, 2)
    rows = [1, npts]
    cols = [1, npts]
    vals = [1., 1.] # unit diagonal for the two boundary rows (float)
    for i in 2:npts-1
        # Derivative of the reaction term touches only the diagonal.
        reaction = [0 -lambda*exp(u[i]) 0]
        append!(rows, [i, i, i])
        append!(cols, i-1:i+1)
        append!(vals, stencil + reaction)
    end
    return sparse(rows, cols, vals)
end

bratu_J (generic function with 1 method)


# Solving Bratu#

# Grid of n points on [-1, 1].
n = 50
x = collect(LinRange(-1., 1, n))
# Initial guess: the straight line from 0 at x = -1 to 1 at x = 1.
u0 = (1 .+ x) ./ 2
uhist, normhist = newton(bratu_f, bratu_J, u0);

plot(normhist, marker=:circle, yscale=:log10)

# Plot the last stored Newton iterate.
plot(x, uhist[end], marker=:circle)


# $$p$$-Laplacian#

$-(|u_x|^{p-2} u_x)_x = 0$
# Exponent p and constant forcing, read as globals by the plaplace_* functions.
plaplace_p = 1.5
plaplace_forcing = 0.
"""
    plaplace_f(u)

Residual of the finite-difference p-Laplacian `-(|u_x|^(p-2) u_x)_x` minus
the forcing, with `p` and the forcing read from the globals `plaplace_p`
and `plaplace_forcing`.

The first and last entries are the boundary residuals `u[1]` and `u[n] - 1`.
Interior entries use the boundary values 0 and 1 in place of `u[1]`, `u[n]`.
"""
function plaplace_f(u)
    npts = length(u)
    spacing = 2 / (npts - 1)
    ubc = copy(u)
    resid = copy(u)
    ubc[[1, npts]] = [0, 1]
    resid[npts] -= 1
    for i in 2:npts-1
        # Slopes on the two staggered intervals next to point i.
        slopes = diff(ubc[i-1:i+1]) / spacing
        conduct = abs.(slopes) .^ (plaplace_p - 2)
        flux = conduct .* slopes
        # Left flux minus right flux, divided by the spacing.
        resid[i] = [1/spacing, -1/spacing]' * flux - plaplace_forcing
    end
    return resid
end

plaplace_f (generic function with 1 method)

"""
    plaplace_J(u)

Sparse approximate Jacobian paired with `plaplace_f`, assembled from
(row, column, value) triplets.

Rows `1` and `n` hold a single unit diagonal entry. Interior rows use the
coefficient `|u_x|^(p-2)` frozen at the current iterate, so its dependence
on `u` is not differentiated. Columns `1` and `n` are left out of interior
rows.
"""
function plaplace_J(u)
    n = length(u)
    h = 2 / (n - 1)
    u = copy(u)
    u[[1, n]] = [0, 1]
    rows = [1, n]
    cols = [1, n]
    vals = [1., 1.] # diagonals entries (float)
    for i in 2:n-1
        js = i-1:i+1
        u_xstag = diff(u[js]) / h
        kappa_stag = abs.(u_xstag) .^ (plaplace_p - 2)
        # Derivative of row i with respect to u[js], holding kappa_stag fixed.
        fi_ujs = [-kappa_stag[1]/h^2,
                  sum(kappa_stag)/h^2,
                  -kappa_stag[2]/h^2]
        # Keep only interior columns.
        mask = 1 .< js .< n
        append!(rows, [i, i, i][mask])
        append!(cols, js[mask])
        append!(vals, fi_ujs[mask])
    end
    return sparse(rows, cols, vals)
end

plaplace_J (generic function with 1 method)


# Try solving#

n = 20
x = collect(LinRange(-1., 1, n))
# Initial guess: the straight line from 0 at x = -1 to 2 at x = 1.
u0 = (1 .+ x)

plaplace_p = 1.3 # try different values
plaplace_forcing = 1
uhist, normhist = newton(plaplace_f, plaplace_J, u0; maxits=20);
plot(normhist, marker=:circle, yscale=:log10, ylims=(1e-10, 1e3))

# Plot every fifth stored iterate, beginning with the initial guess.
plot(x, uhist[1:5:end], marker=:auto, legend=:bottomright)


# Using Zygote to differentiate#

using Zygote


WARNING: using Zygote.jacobian in module Main conflicts with an existing identifier.

(6.0,)

"""
    plaplace_fpoint(u, h)

p-Laplacian residual contribution at the middle of the three values in `u`,
for spacing `h`. The exponent comes from the global `plaplace_p`.
"""
function plaplace_fpoint(u, h)
    slopes = diff(u) / h
    conduct = abs.(slopes) .^ (plaplace_p - 2)
    flux = conduct .* slopes
    # Left flux minus right flux, divided by the spacing.
    return [1/h, -1/h]' * flux
end

# Gradient of the point residual for a 3-value stencil with spacing 0.1.
gradient(u -> plaplace_fpoint(u, .1), [0., .7, 1.])

([-7.683385553661419, 21.58727725682051, -13.903891703159092],)


# p-Laplacian with Zygote#

"""
    plaplace_fzygote(u)

p-Laplacian residual with each interior entry computed by `plaplace_fpoint`,
minus the global `plaplace_forcing`.

The first and last entries are the boundary residuals `u[1]` and `u[n] - 1`.
Interior entries use the boundary values 0 and 1 in place of `u[1]`, `u[n]`.
"""
function plaplace_fzygote(u)
    npts = length(u)
    spacing = 2 / (npts - 1)
    ubc = copy(u)
    resid = copy(u)
    ubc[[1, npts]] = [0, 1]
    resid[npts] -= 1
    for i in 2:npts-1
        window = ubc[i-1:i+1]
        resid[i] = plaplace_fpoint(window, spacing) - plaplace_forcing
    end
    return resid
end

plaplace_fzygote (generic function with 1 method)

"""
    plaplace_Jzygote(u)

Sparse Jacobian paired with `plaplace_fzygote`, assembled from
(row, column, value) triplets.

Rows `1` and `n` hold a single unit diagonal entry. Each interior row is the
`gradient` of `plaplace_fpoint` with respect to its three stencil values.
Columns `1` and `n` are left out of interior rows.
"""
function plaplace_Jzygote(u)
    n = length(u)
    h = 2 / (n - 1)
    u = copy(u)
    u[[1, n]] = [0, 1]
    rows = [1, n]
    cols = [1, n]
    vals = [1., 1.] # diagonals entries (float)
    for i in 2:n-1
        js = i-1:i+1
        fi_ujs = gradient(ujs -> plaplace_fpoint(ujs, h), u[js])[1]
        # Keep only interior columns.
        mask = 1 .< js .< n
        append!(rows, [i, i, i][mask])
        append!(cols, js[mask])
        append!(vals, fi_ujs[mask])
    end
    return sparse(rows, cols, vals)
end

plaplace_Jzygote (generic function with 1 method)


# Test it out#

# Reset the global exponent and forcing before solving with the Zygote Jacobian.
plaplace_p = 1.5
plaplace_forcing = .1
uhist, normhist = newton(plaplace_fzygote, plaplace_Jzygote, u0; maxits=20);
plot(normhist, marker=:auto, yscale=:log10)

# Plot every stored iterate as its own series.
plot(x, uhist, marker=:auto, legend=:topleft)

• Can you fix the plaplace_J to differentiate correctly (analytically)?

• What is causing Newton to diverge?

• How might we fix or work around it?

# A model problem for Newton divergence#

k = 1. # what happens as this is changed
# Elementwise tanh(k x), so the residual accepts the 1-element vector iterate.
fk(x) = tanh.(k*x)
# Derivative k / cosh(k x)^2 as a 1-by-1 matrix. The division is broadcast
# (`./`) because `/` with a vector denominator is not elementwise division.
Jk(x) = reshape(k ./ cosh.(k*x).^2, 1, 1)
xhist, normhist = newton(fk, Jk, [1.], maxits=10)
plot(xhist, fk.(xhist), marker=:circle, legend=:bottomright)
plot!(fk, xlims=(-2, 2), linewidth=1)


import Symbolics: Differential, expand_derivatives, @variables
# Declare x as a symbolic variable and build the operator d/dx.
@variables x
Dx = Differential(x)
y = tanh(k*x)
# Applying Dx gives the derivative in unexpanded form (see the output below).
Dx(y)

$$$\frac{\mathrm{d} \tanh\left( x \right)}{\mathrm{d}x}$$$
# Carry out the differentiation symbolically.
expand_derivatives(Dx(y))

$$$1 - \tanh^{2}\left( x \right)$$$

y = x
# One application of y -> cos(y^pi) * log(y); raise the loop count to nest it.
for _ in 1:1
y = cos(y^pi) * log(y)
end
expand_derivatives(Dx(y))

$$$\frac{\cos\left( x^{\pi} \right)}{x} - 3.1416 x^{2.1416} \log\left( x \right) \sin\left( x^{\pi} \right)$$$
• The size of these expressions can grow exponentially.

# Hand coding derivatives: it’s all chain rule and associativity#

$df = f'(x) dx$
"""
    f(x)

Apply the map `y -> cos(y^pi) * log(y)` twice, starting from `x`.
"""
function f(x)
    val = x
    for _ in 1:2
        val = cos(val^pi) * log(val)
    end
    return val
end


(-1.5346823414986814, (-34.03241959914049,))

"""
    df(x, dx)

Hand-coded forward-mode derivative of `f`: each intermediate value of the
two-pass computation is propagated together with its differential, and the
final differential is returned for input perturbation `dx`.
"""
function df(x, dx)
    val, dval = x, dx
    for _ in 1:2
        pw = val^pi
        dpw = pi * val^(pi-1) * dval
        cs = cos(pw)
        dcs = -sin(pw) * dpw
        lg = log(val)
        dlg = 1/val * dval
        # Product rule; both right-hand sides use this pass's values.
        val, dval = cs * lg, dcs * lg + cs * dlg
    end
    return dval
end

# Derivative of f at x = 1.9 for the unit perturbation dx = 1.
df(1.9, 1)

-34.03241959914048


# We can go the other way#

We can differentiate a composition $$h(g(f(x)))$$ as

$\begin{aligned} \operatorname{d} h &= h' \operatorname{d} g \\ \operatorname{d} g &= g' \operatorname{d} f \\ \operatorname{d} f &= f' \operatorname{d} x. \end{aligned} \tag{14}$

What we’ve done above is called “forward mode”, and amounts to placing the parentheses in the chain rule like

$\operatorname d h = \frac{dh}{dg} \left(\frac{dg}{df} \left(\frac{df}{dx} \operatorname d x \right) \right) .$

The expression means the same thing if we rearrange the parentheses,

$\operatorname d h = \left( \left( \left( \frac{dh}{dg} \right) \frac{dg}{df} \right) \frac{df}{dx} \right) \operatorname d x$

which we can compute in reverse order via

$\underbrace{\bar x}_{\frac{dh}{dx}} = \underbrace{\bar g \frac{dg}{df}}_{\bar f} \frac{df}{dx} .$

# A reverse mode example#

$\underbrace{\bar x}_{\frac{dh}{dx}} = \underbrace{\bar g \frac{dg}{df}}_{\bar f} \frac{df}{dx} .$
"""
    g(x)

Evaluate `cos(x^pi) * log(x)`.
"""
function g(x)
    cs = cos(x^pi)
    lg = log(x)
    return cs * lg
end

(-0.32484122107701546, (-1.2559761698835525,))

"""
    gback(x, y_)

Hand-coded reverse-mode derivative of `g`: a forward pass recomputes the
intermediates of `cos(x^pi) * log(x)`, then the output adjoint `y_` is
propagated backward to give the adjoint of `x`.
"""
function gback(x, y_)
    # forward pass
    pw = x^pi
    cs = cos(pw)
    lg = log(x)
    out = cs * lg  # value of g(x); the adjoints below do not use it
    # backward pass
    lg_ = y_ * cs
    cs_ = lg * y_
    pw_ = -sin(pw) * cs_
    return 1/x * lg_ + pi * x^(pi-1) * pw_
end
# Derivative of g at x = 1.4, seeded with output adjoint 1.
gback(1.4, 1)

-1.2559761698835525


# Kinds of algorithmic differentiation#

• Source transformation: Fortran code in, Fortran code out

• Duplicates compiler features, usually incomplete language coverage

• Produces efficient code

• Hard to vectorize

• Loops are effectively unrolled/inefficient

• Just-in-time compilation: tightly coupled with compiler

• JIT lag

• Needs dynamic language features (JAX) or tight integration with compiler (Zygote, Enzyme)

• Some sharp bits

# How does Zygote work?#

# Two algebraically equivalent ways of writing x^3 + 3x.
h1(x) = x^3 + 3*x
h2(x) = ((x * x)  + 3) * x
# Show the LLVM IR generated for h1 called with a Float64 argument.
@code_llvm h1(4.)

;  @ In[32]:1 within h1
define double @julia_h1_6764(double %0) #0 {
top:
; ┌ @ intfuncs.jl:320 within literal_pow
; │┌ @ operators.jl:578 within * @ float.jl:410
%1 = fmul double %0, %0
%2 = fmul double %1, %0
; └└
; ┌ @ promotion.jl:411 within * @ float.jl:410
%3 = fmul double %0, 3.000000e+00
; └
; ┌ @ float.jl:408 within +
%4 = fadd double %3, %2
; └
ret double %4
}

# Show the LLVM IR generated for the gradient of h1 at a Float64 argument.
@code_llvm gradient(h1, 4.)

;  @ /home/jed/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:95 within gradient
define [1 x double] @julia_gradient_6807(double %0) #0 {
top:
;  @ /home/jed/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:97 within gradient
; ┌ @ /home/jed/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:45 within #75
; │┌ @ In[32]:1 within Pullback
; ││┌ @ /home/jed/.julia/packages/Zygote/4rucm/src/compiler/chainrules.jl:211 within ZBack
; │││┌ @ /home/jed/.julia/packages/Zygote/4rucm/src/lib/number.jl:12 within literal_pow_pullback
; ││││┌ @ intfuncs.jl:319 within literal_pow
; │││││┌ @ float.jl:410 within *
%1 = fmul double %0, %0
; ││││└└
; ││││┌ @ promotion.jl:411 within * @ float.jl:410
%2 = fmul double %1, 3.000000e+00
; │└└└└
; │┌ @ /home/jed/.julia/packages/Zygote/4rucm/src/lib/lib.jl:17 within accum
; ││┌ @ float.jl:408 within +
%3 = fadd double %2, 3.000000e+00
; └└└
;  @ /home/jed/.julia/packages/Zygote/4rucm/src/compiler/interface.jl:98 within gradient
%.fca.0.insert = insertvalue [1 x double] zeroinitializer, double %3, 0
ret [1 x double] %.fca.0.insert
}


# Forward or reverse?#

It all depends on the shape.

$\operatorname d h = \frac{dh}{dg} \left(\frac{dg}{df} \left(\frac{df}{dx} \operatorname d x \right) \right) .$
$\operatorname d h = \left( \left( \left( \frac{dh}{dg} \right) \frac{dg}{df} \right) \frac{df}{dx} \right) \operatorname d x$
• One input, many outputs: use forward mode

• “One input” can be looking in one direction

• Many inputs, one output: use reverse mode

• Will need to traverse execution backwards (“tape”)

• Hierarchical checkpointing

• About square? Forward is usually a bit more efficient.