{
"cells": [
{
"cell_type": "markdown",
"id": "3471b94b",
"metadata": {
"cell_style": "center",
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# 2022-09-16 Differentiation"
]
},
{
"cell_type": "markdown",
"id": "4159be0e",
"metadata": {
"cell_style": "split",
"slideshow": {
"slide_type": ""
}
},
"source": [
"## Last time\n",
"\n",
"* Newton's method for systems\n",
"* Bratu nonlinear PDE\n",
"* p-Laplacian\n",
"\n",
"## Today\n",
"\n",
"* NLsolve Newton solver library\n",
" * p-Laplacian robustness\n",
" * diagnostics\n",
"* Algorithmic differentiation via Zygote\n",
"* Symbolic differentiation\n",
"* Structured by-hand differentiation\n",
"* Concept of PDE-based inference (inverse problems)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e82d841d",
"metadata": {
"hideOutput": true,
"slideshow": {
"slide_type": "skip"
}
},
"outputs": [
{
"data": {
"text/plain": [
"newton (generic function with 1 method)"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using Plots\n",
"default(linewidth=3)\n",
"using LinearAlgebra\n",
"using SparseArrays\n",
"\n",
"function vander(x, k=nothing)\n",
" if k === nothing\n",
" k = length(x)\n",
" end\n",
" V = ones(length(x), k)\n",
" for j = 2:k\n",
" V[:, j] = V[:, j-1] .* x\n",
" end\n",
" V\n",
"end\n",
"\n",
"function fdstencil(source, target, k)\n",
" # kth derivative stencil from source to target\n",
" x = source .- target\n",
" V = vander(x)\n",
" rhs = zero(x)'\n",
" rhs[k+1] = factorial(k)\n",
" rhs / V\n",
"end\n",
"\n",
"function my_spy(A)\n",
" cmax = norm(vec(A), Inf)\n",
" s = max(1, ceil(120 / size(A, 1)))\n",
" spy(A, marker=(:square, s), c=:diverging_rainbow_bgymr_45_85_c67_n256, clims=(-cmax, cmax))\n",
"end\n",
"\n",
"\n",
"function advdiff_sparse(n, kappa, wind, forcing)\n",
" x = LinRange(-1, 1, n)\n",
" xstag = (x[1:end-1] + x[2:end]) / 2\n",
" rhs = forcing.(x)\n",
" kappa_stag = kappa.(xstag)\n",
" rows = [1, n]\n",
" cols = [1, n]\n",
" vals = [1., 1.] # diagonal entries (float)\n",
" rhs[[1,n]] .= 0 # boundary condition\n",
" for i in 2:n-1\n",
" flux_L = -kappa_stag[i-1] * fdstencil(x[i-1:i], xstag[i-1], 1) +\n",
" wind * (wind > 0 ? [1 0] : [0 1])\n",
" flux_R = -kappa_stag[i] * fdstencil(x[i:i+1], xstag[i], 1) +\n",
" wind * (wind > 0 ? [1 0] : [0 1])\n",
" weights = fdstencil(xstag[i-1:i], x[i], 1)\n",
" append!(rows, [i,i,i])\n",
" append!(cols, i-1:i+1)\n",
" append!(vals, weights[1] * [flux_L..., 0] + weights[2] * [0, flux_R...])\n",
" end\n",
" L = sparse(rows, cols, vals)\n",
" x, L, rhs\n",
"end\n",
"\n",
"function newton(residual, jacobian, u0; maxits=20)\n",
" u = u0\n",
" uhist = [copy(u)]\n",
" normhist = []\n",
" for k in 1:maxits\n",
" f = residual(u)\n",
" push!(normhist, norm(f))\n",
" J = jacobian(u)\n",
" delta_u = - J \\ f\n",
" u += delta_u\n",
" push!(uhist, copy(u))\n",
" end\n",
" uhist, normhist\n",
"end"
]
},
{
"cell_type": "markdown",
"id": "f7fc5a18",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# $p$-Laplacian\n",
"\n",
"$$ -\\big(|u_x|^{p-2} u_x\\big)_x = 0 $$"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "93077c39",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"plaplace_f (generic function with 1 method)"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plaplace_p = 1.5\n",
"plaplace_forcing = 0.\n",
"function plaplace_f(u)\n",
" n = length(u)\n",
" h = 2 / (n - 1)\n",
" u = copy(u)\n",
" f = copy(u)\n",
" u[[1, n]] = [0, 1]\n",
" f[n] -= 1\n",
" for i in 2:n-1\n",
" u_xstag = diff(u[i-1:i+1]) / h\n",
" kappa_stag = abs.(u_xstag) .^ (plaplace_p - 2)\n",
" f[i] = [1/h, -1/h]' *\n",
" (kappa_stag .* u_xstag) - plaplace_forcing\n",
" end\n",
" f\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "262d0ca1",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"plaplace_J (generic function with 1 method)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function plaplace_J(u)\n",
" n = length(u)\n",
" h = 2 / (n - 1)\n",
" u = copy(u)\n",
" u[[1, n]] = [0, 1]\n",
" rows = [1, n]\n",
" cols = [1, n]\n",
" vals = [1., 1.] # diagonal entries (float)\n",
" for i in 2:n-1\n",
" js = i-1:i+1\n",
" u_xstag = diff(u[js]) / h\n",
" kappa_stag = abs.(u_xstag) .^ (plaplace_p - 2)\n",
" fi = [h, -h]' * (kappa_stag .* u_xstag)\n",
" fi_ujs = [-kappa_stag[1]/h^2,\n",
" sum(kappa_stag)/h^2,\n",
" -kappa_stag[2]/h^2]\n",
" mask = 1 .< js .< n\n",
" append!(rows, [i,i,i][mask])\n",
" append!(cols, js[mask])\n",
" append!(vals, fi_ujs[mask])\n",
" end\n",
" sparse(rows, cols, vals)\n",
"end"
]
},
{
"cell_type": "markdown",
"id": "626c75c1",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Try solving"
]
},
{
"cell_type": "code",
"execution_count": 106,
"id": "9205e9b1",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 106,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"n = 20\n",
"x = collect(LinRange(-1., 1, n))\n",
"u0 = (1 .+ x) / 2\n",
"\n",
"plaplace_p = 1.3 # try different values\n",
"plaplace_forcing = 1\n",
"uhist, normhist = newton(plaplace_f, plaplace_J, u0; maxits=20);\n",
"plot(normhist, marker=:circle, yscale=:log10, ylims=(1e-10, 1e3))"
]
},
{
"cell_type": "code",
"execution_count": 107,
"id": "c11c47e3",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 107,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plot(x, uhist[1:5:end], marker=:auto, legend=:bottomright)"
]
},
{
"cell_type": "markdown",
"id": "536e54f6",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"## What's wrong?"
]
},
{
"cell_type": "markdown",
"id": "046e6066",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Using Zygote to differentiate"
]
},
{
"cell_type": "code",
"execution_count": 108,
"id": "ed016e68",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"(6.0,)"
]
},
"execution_count": 108,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using Zygote\n",
"\n",
"gradient(x -> x^2, 3)"
]
},
{
"cell_type": "code",
"execution_count": 109,
"id": "9b9a9cb8",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"([-7.683385553661419, 21.58727725682051, -13.903891703159092],)"
]
},
"execution_count": 109,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function plaplace_fpoint(u, h)\n",
" u_xstag = diff(u) / h\n",
" kappa_stag = abs.(u_xstag) .^ (plaplace_p - 2)\n",
" [1/h, -1/h]' * (kappa_stag .* u_xstag)\n",
"end\n",
"\n",
"gradient(u -> plaplace_fpoint(u, .1), [0., .7, 1.])"
]
},
{
"cell_type": "markdown",
"id": "a1f099a3",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# p-Laplacian with Zygote"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "15656b6a",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"plaplace_fzygote (generic function with 1 method)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function plaplace_fzygote(u)\n",
" n = length(u)\n",
" h = 2 / (n - 1)\n",
" u = copy(u)\n",
" f = copy(u)\n",
" u[[1, n]] = [0, 1]\n",
" f[n] -= 1\n",
" for i in 2:n-1\n",
" f[i] = plaplace_fpoint(u[i-1:i+1], h) - plaplace_forcing\n",
" end\n",
" f\n",
"end"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "3165e987",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"plaplace_Jzygote (generic function with 1 method)"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function plaplace_Jzygote(u)\n",
" n = length(u)\n",
" h = 2 / (n - 1)\n",
" u = copy(u)\n",
" u[[1, n]] = [0, 1]\n",
" rows = [1, n]\n",
" cols = [1, n]\n",
" vals = [1., 1.] # diagonal entries (float)\n",
" for i in 2:n-1\n",
" js = i-1:i+1\n",
" fi_ujs = gradient(ujs -> plaplace_fpoint(ujs, h), u[js])[1]\n",
" mask = 1 .< js .< n\n",
" append!(rows, [i,i,i][mask])\n",
" append!(cols, js[mask])\n",
" append!(vals, fi_ujs[mask])\n",
" end\n",
" sparse(rows, cols, vals)\n",
"end"
]
},
{
"cell_type": "markdown",
"id": "cb99e8bd",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Experiment with parameters"
]
},
{
"cell_type": "code",
"execution_count": 117,
"id": "cca878c9",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 117,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plaplace_p = 1.4\n",
"plaplace_forcing = 1\n",
"u0 = (x .+ 1)\n",
"uhist, normhist = newton(plaplace_fzygote, plaplace_Jzygote, u0; maxits=10);\n",
"plot(normhist, marker=:auto, yscale=:log10)"
]
},
{
"cell_type": "code",
"execution_count": 118,
"id": "84ef8973",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 118,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plot(x, uhist, marker=:auto, legend=:topleft)"
]
},
{
"cell_type": "markdown",
"id": "c525279a",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* Can you fix the `plaplace_J` to differentiate correctly (analytically)?\n",
"* What is causing Newton to diverge?\n",
"* How might we fix or work around it?"
]
},
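{
"cell_type": "markdown",
"id": "b7d2f0a1",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"One way to approach the first question: for $s \\ne 0$, the flux $|s|^{p-2} s$ has derivative $(p-1)|s|^{p-2}$, so the pointwise gradient carries a factor $(p-1)$ that `plaplace_J` omits. The sketch below checks that claim against Zygote. The names `fpoint` and `fpoint_grad` are hypothetical, and `p` is passed as an argument instead of being read from the global `plaplace_p`.\n",
"\n",
"```julia\n",
"using Zygote\n",
"\n",
"# Pointwise residual, as in plaplace_fpoint but with p as an argument\n",
"function fpoint(u, h, p)\n",
"    u_xstag = diff(u) / h\n",
"    kappa_stag = abs.(u_xstag) .^ (p - 2)\n",
"    [1/h, -1/h]' * (kappa_stag .* u_xstag)\n",
"end\n",
"\n",
"# Hand-derived gradient (assumes u_xstag has no zero entries)\n",
"function fpoint_grad(u, h, p)\n",
"    u_xstag = diff(u) / h\n",
"    dflux = (p - 1) * abs.(u_xstag) .^ (p - 2)  # d(flux)/d(u_x) at each staggered point\n",
"    [-dflux[1], dflux[1] + dflux[2], -dflux[2]] / h^2\n",
"end\n",
"\n",
"u, h, p = [0., .7, 1.], .1, 1.3\n",
"g_zygote = gradient(v -> fpoint(v, h, p), u)[1]\n",
"g_hand = fpoint_grad(u, h, p)\n",
"maximum(abs.(g_zygote - g_hand))  # should be at rounding level\n",
"```\n",
"\n",
"If the two agree, scaling the entries of `fi_ujs` in `plaplace_J` by `plaplace_p - 1` is a candidate fix."
]
},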
{
"cell_type": "markdown",
"id": "fccfb912",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# A model problem for Newton divergence; compare NLsolve"
]
},
{
"cell_type": "code",
"execution_count": 122,
"id": "6e0a2b3a",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 122,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"k = 1. # what happens as this is changed?\n",
"fk(x) = tanh.(k*x)\n",
"Jk(x) = reshape(k ./ cosh.(k*x).^2, 1, 1)\n",
"xhist, normhist = newton(fk, Jk, [2.], maxits=10)\n",
"plot(xhist, fk.(xhist), marker=:circle, legend=:bottomright)\n",
"plot!(fk, xlims=(-20, 2), linewidth=1)"
]
},
{
"cell_type": "code",
"execution_count": 126,
"id": "4c6a7675",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Iter f(x) inf-norm Step 2-norm \n",
"------ -------------- --------------\n",
" 0 1.499909e+00 NaN\n",
" 1 5.000000e-01 5.000000e+00\n",
" 2 3.788284e-02 5.000000e-01\n",
" 3 8.529212e-04 4.816956e-02\n",
" 4 4.836986e-07 1.135938e-03\n",
" 5 1.559863e-13 6.449311e-07\n"
]
},
{
"data": {
"text/plain": [
"Results of Nonlinear Solver Algorithm\n",
" * Algorithm: Trust-region with dogleg and autoscaling\n",
" * Starting Point: [5.0]\n",
" * Zero: [-0.5493061443338468]\n",
" * Inf-norm of residuals: 0.000000\n",
" * Iterations: 5\n",
" * Convergence: true\n",
" * |x - x'| < 0.0e+00: false\n",
" * |f(x)| < 1.0e-08: true\n",
" * Function Calls (f): 6\n",
" * Jacobian Calls (df/dx): 6"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"using NLsolve\n",
"\n",
"nlsolve(x -> fk(x) .+ .5, Jk, [5.], show_trace=true)"
]
},
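{
"cell_type": "markdown",
"id": "c4e19a52",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"The trace above reports a trust-region method. A simpler globalization, sketched here as an illustration and not as what NLsolve does internally, is to halve the Newton step until the residual norm decreases. `newton_ls` is a hypothetical variant of `newton`, and its simple-decrease test stands in for a proper Armijo condition.\n",
"\n",
"```julia\n",
"using LinearAlgebra\n",
"\n",
"# Newton with step halving: accept a step only if it reduces the residual norm\n",
"function newton_ls(residual, jacobian, u0; maxits=20, atol=1e-10)\n",
"    u = copy(u0)\n",
"    normhist = Float64[]\n",
"    for k in 1:maxits\n",
"        f = residual(u)\n",
"        fnorm = norm(f)\n",
"        push!(normhist, fnorm)\n",
"        fnorm < atol && break\n",
"        delta_u = -(jacobian(u) \\ f)\n",
"        alpha = 1.0\n",
"        # simple decrease test (an assumption; not an Armijo condition)\n",
"        while norm(residual(u + alpha * delta_u)) >= fnorm && alpha > 1e-8\n",
"            alpha /= 2\n",
"        end\n",
"        u = u + alpha * delta_u\n",
"    end\n",
"    u, normhist\n",
"end\n",
"\n",
"k = 1.\n",
"fk(x) = tanh.(k*x)\n",
"Jk(x) = reshape(k ./ cosh.(k*x).^2, 1, 1)\n",
"u_ls, normhist_ls = newton_ls(fk, Jk, [2.])\n",
"```\n",
"\n",
"From the same starting point `[2.]` used with `newton` above, the full step overshoots into the flat part of `tanh`, so the first iterations take a shortened step before full Newton steps are accepted near the root."
]
},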
{
"cell_type": "markdown",
"id": "3b075a71",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# [NLsolve](https://juliapackages.com/p/nlsolve) for p-Laplacian"
]
},
{
"cell_type": "code",
"execution_count": 133,
"id": "a178b039",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 133,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plaplace_p = 1.5\n",
"plaplace_forcing = 1\n",
"u0 = (x .+ 1) / 2\n",
"sol = nlsolve(plaplace_fzygote, plaplace_Jzygote, u0, \n",
" store_trace=true)\n",
"plaplace_p = 1.3\n",
"sol = nlsolve(plaplace_fzygote, plaplace_Jzygote, sol.zero, \n",
" store_trace=true)\n",
"fnorms = [sol.trace[i].fnorm for i in 1:sol.iterations]\n",
"plot(fnorms, marker=:circle, yscale=:log10)"
]
},
{
"cell_type": "code",
"execution_count": 134,
"id": "70fadcef",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n"
]
},
"execution_count": 134,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"plot(x, [sol.initial_x sol.zero], label=[\"initial\" \"solution\"])"
]
},
{
"cell_type": "markdown",
"id": "35e00ea4",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# What about symbolic differentiation?"
]
},
{
"cell_type": "code",
"execution_count": 136,
"id": "a525f7ef",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/latex": [
"\\begin{equation}\n",
"\\frac{dtanh(x)}{dx}\n",
"\\end{equation}\n"
],
"text/plain": [
"Differential(x)(tanh(x))"
]
},
"execution_count": 136,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import Symbolics: Differential, expand_derivatives, @variables\n",
"@variables x\n",
"Dx = Differential(x)\n",
"y = tanh(k*x)\n",
"Dx(y)"
]
},
{
"cell_type": "code",
"execution_count": 137,
"id": "dd580f10",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/latex": [
"\\begin{equation}\n",
"1 - \\tanh^{2}\\left( x \\right)\n",
"\\end{equation}\n"
],
"text/plain": [
"1 - (tanh(x)^2)"
]
},
"execution_count": 137,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"expand_derivatives(Dx(y))"
]
},
{
"cell_type": "markdown",
"id": "476bdc75",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Cool, what about composition of functions"
]
},
{
"cell_type": "code",
"execution_count": 140,
"id": "364e7dc8",
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"\\begin{equation}\n",
"\\frac{\\left( \\frac{\\cos\\left( x^{\\pi} \\right)}{x} - 3.141592653589793 x^{2.141592653589793} \\log\\left( x \\right) \\sin\\left( x^{\\pi} \\right) \\right) \\cos\\left( \\cos^{3.141592653589793}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{3.141592653589793} \\right)}{\\cos\\left( x^{\\pi} \\right) \\log\\left( x \\right)} - \\left( \\frac{3.141592653589793 \\cos^{3.141592653589793}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{2.141592653589793}}{x} - 9.869604401089358 \\cos^{2.141592653589793}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{3.141592653589793} x^{2.141592653589793} \\sin\\left( x^{\\pi} \\right) \\right) \\sin\\left( \\cos^{3.141592653589793}\\left( x^{\\pi} \\right) \\left( \\log\\left( x \\right) \\right)^{3.141592653589793} \\right) \\log\\left( \\log\\left( x \\right) \\cos\\left( x^{\\pi} \\right) \\right)\n",
"\\end{equation}\n"
],
"text/plain": [
"((x^-1)*cos(x^π) - 3.141592653589793(x^2.141592653589793)*log(x)*sin(x^π))*(log(x)^-1)*(cos(x^π)^-1)*cos((log(x)^3.141592653589793)*(cos(x^π)^3.141592653589793)) - (3.141592653589793(x^-1)*(log(x)^2.141592653589793)*(cos(x^π)^3.141592653589793) - 9.869604401089358(x^2.141592653589793)*(log(x)^3.141592653589793)*(cos(x^π)^2.141592653589793)*sin(x^π))*sin((log(x)^3.141592653589793)*(cos(x^π)^3.141592653589793))*log(log(x)*cos(x^π))"
]
},
"execution_count": 140,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = x\n",
"for _ in 1:2\n",
" y = cos(y^pi) * log(y)\n",
"end\n",
"expand_derivatives(Dx(y))"
]
},
{
"cell_type": "markdown",
"id": "139d9cc0",
"metadata": {
"slideshow": {
"slide_type": "fragment"
}
},
"source": [
"* The size of these expressions can grow **exponentially** with the depth of composition."
]
},
{
"cell_type": "markdown",
"id": "af45c2a5",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Hand coding derivatives: it's all chain rule and associativity\n",
"$$ df = f'(x) dx $$"
]
},
{
"cell_type": "code",
"execution_count": 141,
"id": "dcf1032a",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"(-1.5346823414986814, (-34.03241959914049,))"
]
},
"execution_count": 141,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function f(x)\n",
" y = x\n",
" for _ in 1:2\n",
" a = y^pi\n",
" b = cos(a)\n",
" c = log(y)\n",
" y = b * c\n",
" end\n",
" y\n",
"end\n",
"\n",
"f(1.9), gradient(f, 1.9)"
]
},
{
"cell_type": "code",
"execution_count": 142,
"id": "07ec7ef4",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"-34.03241959914048"
]
},
"execution_count": 142,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function df(x, dx)\n",
" y = x\n",
" dy = dx\n",
" for _ in 1:2\n",
" a = y^pi\n",
" da = pi * y^(pi-1) * dy\n",
" b = cos(a)\n",
" db = -sin(a) * da\n",
" c = log(y)\n",
" dc = 1/y * dy\n",
" y = b * c\n",
" dy = db * c + b * dc\n",
" end\n",
" dy\n",
"end\n",
"\n",
"df(1.9, 1)"
]
},
{
"cell_type": "markdown",
"id": "7c0de5ea",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# We can go the other way\n",
"\n",
"We can differentiate a composition $h(g(f(x)))$ as\n",
"\n",
"\\begin{align}\n",
" \\operatorname{d} h &= h' \\operatorname{d} g \\\\\n",
" \\operatorname{d} g &= g' \\operatorname{d} f \\\\\n",
" \\operatorname{d} f &= f' \\operatorname{d} x.\n",
"\\end{align}\n",
"\n",
"What we've done above is called \"forward mode\", and amounts to placing the parentheses in the chain rule as\n",
"\n",
"$$ \\operatorname d h = \\frac{dh}{dg} \\left(\\frac{dg}{df} \\left(\\frac{df}{dx} \\operatorname d x \\right) \\right) .$$\n",
"\n",
"The expression means the same thing if we rearrange the parentheses,\n",
"\n",
"$$ \\operatorname d h = \\left( \\left( \\left( \\frac{dh}{dg} \\right) \\frac{dg}{df} \\right) \\frac{df}{dx} \\right) \\operatorname d x $$\n",
"\n",
"which we can compute in reverse order via\n",
"\n",
"$$ \\underbrace{\\bar x}_{\\frac{dh}{dx}} = \\underbrace{\\bar g \\frac{dg}{df}}_{\\bar f} \\frac{df}{dx} .$$\n",
"\n",
"Here $\\bar g = \\frac{dh}{dg}$, and each barred quantity is the derivative of the output $h$ with respect to that variable."
]
},
{
"cell_type": "markdown",
"id": "163a414f",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# A reverse mode example\n",
"\n",
"$$ \\underbrace{\\bar x}_{\\frac{dh}{dx}} = \\underbrace{\\bar g \\frac{dg}{df}}_{\\bar f} \\frac{df}{dx} .$$"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "a4a4a693",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"(-0.32484122107701546, (-1.2559761698835525,))"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function g(x)\n",
" a = x^pi\n",
" b = cos(a)\n",
" c = log(x)\n",
" y = b * c\n",
" y\n",
"end\n",
"(g(1.4), gradient(g, 1.4))"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "0d51a1d8",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"data": {
"text/plain": [
"-1.2559761698835525"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"function gback(x, y_)\n",
" a = x^pi\n",
" b = cos(a)\n",
" c = log(x)\n",
" y = b * c\n",
" # backward pass\n",
" c_ = y_ * b \n",
" b_ = c * y_\n",
" a_ = -sin(a) * b_\n",
" x_ = 1/x * c_ + pi * x^(pi-1) * a_\n",
" x_\n",
"end\n",
"gback(1.4, 1)"
]
},
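{
"cell_type": "markdown",
"id": "d90a3b64",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"Zygote exposes the same two-pass structure through `pullback`, which runs the forward pass and returns a closure for the backward pass. A minimal sketch, with `g` redefined as a one-liner so the block is self-contained; seeding the closure with 1 plays the role of `y_` in `gback`.\n",
"\n",
"```julia\n",
"using Zygote\n",
"\n",
"g(x) = cos(x^pi) * log(x)\n",
"\n",
"y, back = pullback(g, 1.4)   # forward pass; `back` is the backward-pass closure\n",
"x_ = back(1.0)[1]            # backward pass seeded with y_ = 1\n",
"(y, x_)\n",
"```\n",
"\n",
"The pair should match `(g(1.4), gradient(g, 1.4)[1])` from the cells above."
]
},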
{
"cell_type": "markdown",
"id": "ec931e73",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Kinds of algorithmic differentiation\n",
"\n",
"* Source transformation: Fortran code in, Fortran code out\n",
" * Duplicates compiler features, usually incomplete language coverage\n",
" * Produces efficient code\n",
"* Operator overloading: C++ types\n",
" * Hard to vectorize\n",
" * Loops are effectively unrolled/inefficient\n",
"* Just-in-time compilation: tightly coupled with compiler\n",
" * JIT lag\n",
" * Needs dynamic language features (JAX) or tight integration with compiler (Zygote, Enzyme)\n",
" * Some [sharp bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow)"
]
},
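{
"cell_type": "markdown",
"id": "e2f57c83",
"metadata": {
"slideshow": {
"slide_type": "subslide"
}
},
"source": [
"The operator-overloading approach can be illustrated with a toy dual-number type. This is a sketch for intuition only and is not how Zygote works; `Dual` and its fields are hypothetical names, and only the operations needed by the polynomial `h2` are overloaded.\n",
"\n",
"```julia\n",
"# Toy forward-mode dual number: a value and its derivative\n",
"struct Dual\n",
"    val::Float64\n",
"    der::Float64\n",
"end\n",
"\n",
"Base.:+(a::Dual, b::Dual) = Dual(a.val + b.val, a.der + b.der)\n",
"Base.:+(a::Dual, b::Real) = Dual(a.val + b, a.der)\n",
"# Product rule\n",
"Base.:*(a::Dual, b::Dual) = Dual(a.val * b.val, a.der * b.val + a.val * b.der)\n",
"\n",
"h2(x) = ((x * x) + 3) * x\n",
"\n",
"d = h2(Dual(4.0, 1.0))   # seed dx = 1\n",
"(d.val, d.der)           # value and derivative of x^3 + 3x at x = 4\n",
"```\n",
"\n",
"The derivative is carried through every arithmetic operation alongside the value, which is the forward-mode pattern written by hand in `df` above."
]
},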
{
"cell_type": "markdown",
"id": "2b1493b0",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# How does Zygote work?"
]
},
{
"cell_type": "code",
"execution_count": 144,
"id": "a4846311",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[90m; @ In[144]:1 within `h1`\u001b[39m\n",
"\u001b[95mdefine\u001b[39m \u001b[36mdouble\u001b[39m \u001b[93m@julia_h1_18050\u001b[39m\u001b[33m(\u001b[39m\u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[33m)\u001b[39m \u001b[0m#0 \u001b[33m{\u001b[39m\n",
"\u001b[91mtop:\u001b[39m\n",
"\u001b[90m; ┌ @ intfuncs.jl:322 within `literal_pow`\u001b[39m\n",
"\u001b[90m; │┌ @ operators.jl:591 within `*` @ float.jl:385\u001b[39m\n",
" \u001b[0m%1 \u001b[0m= \u001b[96m\u001b[1mfmul\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[0m, \u001b[0m%0\n",
" \u001b[0m%2 \u001b[0m= \u001b[96m\u001b[1mfmul\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%1\u001b[0m, \u001b[0m%0\n",
"\u001b[90m; └└\u001b[39m\n",
"\u001b[90m; ┌ @ promotion.jl:389 within `*` @ float.jl:385\u001b[39m\n",
" \u001b[0m%3 \u001b[0m= \u001b[96m\u001b[1mfmul\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[0m, \u001b[33m3.000000e+00\u001b[39m\n",
"\u001b[90m; └\u001b[39m\n",
"\u001b[90m; ┌ @ float.jl:383 within `+`\u001b[39m\n",
" \u001b[0m%4 \u001b[0m= \u001b[96m\u001b[1mfadd\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%3\u001b[0m, \u001b[0m%2\n",
"\u001b[90m; └\u001b[39m\n",
" \u001b[96m\u001b[1mret\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%4\n",
"\u001b[33m}\u001b[39m\n"
]
}
],
"source": [
"h1(x) = x^3 + 3*x\n",
"h2(x) = ((x * x) + 3) * x\n",
"@code_llvm h1(4.)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "7d6a5aae",
"metadata": {
"cell_style": "split"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[90m; @ /home/jed/.julia/packages/Zygote/xGkZ5/src/compiler/interface.jl:95 within `gradient`\u001b[39m\n",
"\u001b[95mdefine\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[93m@julia_gradient_15323\u001b[39m\u001b[33m(\u001b[39m\u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[33m)\u001b[39m \u001b[0m#0 \u001b[33m{\u001b[39m\n",
"\u001b[91mtop:\u001b[39m\n",
"\u001b[90m; @ /home/jed/.julia/packages/Zygote/xGkZ5/src/compiler/interface.jl:97 within `gradient`\u001b[39m\n",
"\u001b[90m; ┌ @ /home/jed/.julia/packages/Zygote/xGkZ5/src/compiler/interface.jl:45 within `#60`\u001b[39m\n",
"\u001b[90m; │┌ @ In[54]:1 within `Pullback`\u001b[39m\n",
"\u001b[90m; ││┌ @ /home/jed/.julia/packages/Zygote/xGkZ5/src/compiler/chainrules.jl:206 within `ZBack`\u001b[39m\n",
"\u001b[90m; │││┌ @ /home/jed/.julia/packages/Zygote/xGkZ5/src/lib/number.jl:12 within `literal_pow_pullback`\u001b[39m\n",
"\u001b[90m; ││││┌ @ intfuncs.jl:321 within `literal_pow`\u001b[39m\n",
"\u001b[90m; │││││┌ @ float.jl:385 within `*`\u001b[39m\n",
" \u001b[0m%1 \u001b[0m= \u001b[96m\u001b[1mfmul\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%0\u001b[0m, \u001b[0m%0\n",
"\u001b[90m; ││││└└\u001b[39m\n",
"\u001b[90m; ││││┌ @ promotion.jl:389 within `*` @ float.jl:385\u001b[39m\n",
" \u001b[0m%2 \u001b[0m= \u001b[96m\u001b[1mfmul\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%1\u001b[0m, \u001b[33m3.000000e+00\u001b[39m\n",
"\u001b[90m; │└└└└\u001b[39m\n",
"\u001b[90m; │┌ @ /home/jed/.julia/packages/Zygote/xGkZ5/src/lib/lib.jl:17 within `accum`\u001b[39m\n",
"\u001b[90m; ││┌ @ float.jl:383 within `+`\u001b[39m\n",
" \u001b[0m%3 \u001b[0m= \u001b[96m\u001b[1mfadd\u001b[22m\u001b[39m \u001b[36mdouble\u001b[39m \u001b[0m%2\u001b[0m, \u001b[33m3.000000e+00\u001b[39m\n",
"\u001b[90m; └└└\u001b[39m\n",
"\u001b[90m; @ /home/jed/.julia/packages/Zygote/xGkZ5/src/compiler/interface.jl:98 within `gradient`\u001b[39m\n",
" \u001b[0m%.fca.0.insert \u001b[0m= \u001b[96m\u001b[1minsertvalue\u001b[22m\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[95mzeroinitializer\u001b[39m\u001b[0m, \u001b[36mdouble\u001b[39m \u001b[0m%3\u001b[0m, \u001b[33m0\u001b[39m\n",
" \u001b[96m\u001b[1mret\u001b[22m\u001b[39m \u001b[33m[\u001b[39m\u001b[33m1\u001b[39m \u001b[0mx \u001b[36mdouble\u001b[39m\u001b[33m]\u001b[39m \u001b[0m%.fca.0.insert\n",
"\u001b[33m}\u001b[39m\n"
]
}
],
"source": [
"@code_llvm gradient(h1, 4.)"
]
},
{
"cell_type": "markdown",
"id": "b6180a8e",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Forward or reverse?\n",
"\n",
"It all depends on the shape of the Jacobian: the number of inputs versus the number of outputs.\n",
"\n",
"$$ \\operatorname d h = \\frac{dh}{dg} \\left(\\frac{dg}{df} \\left(\\frac{df}{dx} \\operatorname d x \\right) \\right) .$$\n",
"\n",
"$$ \\operatorname d h = \\left( \\left( \\left( \\frac{dh}{dg} \\right) \\frac{dg}{df} \\right) \\frac{df}{dx} \\right) \\operatorname d x $$\n",
"\n",
"* One input, many outputs: use forward mode\n",
" * \"One input\" can be looking in one direction\n",
"* Many inputs, one output: use reverse mode\n",
" * Will need to traverse execution backwards (\"tape\")\n",
" * Hierarchical checkpointing\n",
"* About square? Forward is usually a bit more efficient."
]
},
{
"cell_type": "markdown",
"id": "80af4c12",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Inference using PDE-based models\n",
"\n",
""
]
},
{
"cell_type": "markdown",
"id": "d88ae97a",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# [Compressible Blasius boundary layer](https://en.wikipedia.org/wiki/Blasius_boundary_layer#Compressible_Blasius_boundary_layer)\n",
"\n",
"* Activity will solve this 1D nonlinear PDE"
]
}
],
"metadata": {
"@webio": {
"lastCommId": null,
"lastKernelId": null
},
"celltoolbar": "Slideshow",
"kernelspec": {
"display_name": "Julia 1.7.2",
"language": "julia",
"name": "julia-1.7"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.8.1"
},
"rise": {
"enable_chalkboard": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}