<h1>On the Difficulty of Extrapolation with NN Scaling</h1>
<hr />
<p>As deep learning models get bigger and bigger, doing any form of hyperparameter tuning is becoming prohibitively expensive as each training run can <a href="https://venturebeat.com/2020/06/01/ai-machine-learning-openai-gpt-3-size-isnt-everything/">cost millions</a>. Recently, there has been a surge of interest in understanding how the performance improves as model size increases
[<a href="https://arxiv.org/abs/2001.08361">1</a>,
<a href="https://arxiv.org/abs/2010.14701">2</a>,
<a href="https://arxiv.org/abs/2112.11446">3</a>,
<a href="https://arxiv.org/abs/2005.14165">4</a>].
Understanding this scaling could enable research at smaller, cheaper scales to more easily transfer to larger, more expensive, but more performant settings.
By leveraging small scale experiments performed at multiple model sizes, one can find simple functions (often power-law relationships) that can predict performance on larger models before spending the compute needed to train them.</p>
<p>While great in theory, this has difficulties in practice. If one is not careful, extrapolating scaling performance can mislead, causing companies to invest millions to train a model that performs no better than considerably smaller models.
In this post, we’ll walk through an example showing how this can happen, as well as one reason why it does.</p>
<p>As a toy task to study these effects, let’s say our goal is to train an ImageNet classifier using a ridiculously wide MLP with 3 hidden layers.
We will start small with hidden sizes of 64, 128, and 256. We use these to pick hyperparameters, in this case a learning rate of 3e-4 for Adam. We also fix the length of training to 30k weight updates with batches of 128 images.</p>
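<p>As a rough sketch of this setup (the initialization scheme and input resolution below are my own assumptions, not the settings used for these experiments):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import jax
import jax.numpy as jnp

def init_mlp(key, hsize, n_in=64 * 64 * 3, n_out=1000):
    # 3 hidden layers of width hsize; scaled-normal init is an assumption.
    sizes = [n_in, hsize, hsize, hsize, n_out]
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def mlp(params, x):
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b  # logits over the 1000 ImageNet classes
</code></pre></div></div>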
<p>Next, we will seek to understand how our model changes with hidden size. We’ll train models of a range of sizes and plot how performance changes.</p>
<div style="text-align:center">
<img src="/assets/images/nn_scaling_blog/pre.png" />
<figcaption class="caption">
Performance of 8 different models with different hidden sizes (shown in blue). The fitted linear regression (dashed black) should ideally be able to predict loss at a given hidden size.
</figcaption>
</div>
<p>The data looks surprisingly linear on this log-log plot. Great, we found our “law”!
We can find the coefficients of this relation with least squares: <code class="language-plaintext highlighter-rouge">loss(hsize) = 7.0 - 0.275 log(hsize)</code>. Empirically, this seems to hold for more than two orders of magnitude in hidden size.</p>
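<p>Finding such a fit is a one-line least-squares problem. Here is a minimal sketch; since the raw measurements aren’t reproduced here, the data points below are synthetic, generated from the fitted relation above plus a little noise:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

# Synthetic stand-ins for the measured losses, drawn from the fitted
# relation loss(hsize) = 7.0 - 0.275 log(hsize) (natural log assumed).
hsizes = np.array([64, 91, 128, 181, 256, 362, 512, 724])
rng = np.random.default_rng(0)
losses = 7.0 - 0.275 * np.log(hsizes) + rng.normal(0, 0.01, len(hsizes))

# Least-squares fit of loss against log(hidden size).
slope, intercept = np.polyfit(np.log(hsizes), losses, 1)
print(f"loss(hsize) = {intercept:.2f} {slope:+.3f} * log(hsize)")
print("predicted loss at hsize=8192:", intercept + slope * np.log(8192))
</code></pre></div></div>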
<p>All excited about our nice looking interpolation, we thought we could extrapolate a little over one order of magnitude in hidden size to train a bigger model. However, to our dismay, we find the performance dramatically off of our predicted curve.</p>
<div style="text-align:center">
<img src="/assets/images/nn_scaling_blog/post.png" />
<figcaption class="caption">
The performance achieved with the larger model (shown in red) is quite poor and greatly underperforms our prediction from the smaller scale models (dashed-black line).
</figcaption>
</div>
<p>In the real world, a mistake like this could cost thousands or even millions of dollars given how big models are these days.
At the >100B parameter scale, doing any form of experimenting to figure out what is wrong with a model is near impossible.
Luckily, we are working on a small scale and thus can afford the luxury of being exhaustive with our experiments – in this case we can run 12 model sizes, each with 12 different learning rates (with 3 random initializations apiece), totalling 432 trials.</p>
<div style="margin-left:-100px; margin-right:-100px">
<img src="/assets/images/nn_scaling_blog/4pane.png" />
<figcaption class="caption">
The results of training 12 different model sizes with 12 different learning rates. Each figure shows a different representation of this data. In (a) we show loss achieved for different hidden sizes with learning rate shown in color. Our extrapolation before was with a single learning rate. In (b) we show loss for a given learning rate with hidden size in color. Larger models reach a lower loss, but need a smaller learning rate. In (c) we show a heat map showing learning rate vs hidden size. Each pixel here is the results of a full training run. In (d) we look at what the optimal learning rate is for a given hidden size.
</figcaption>
</div>
<p>With this data, the story becomes quite clear and should come as no surprise.
As we increase model size, the optimal learning rate shrinks.
We can also see that if we simply train with a smaller learning rate, we would come close to our originally predicted performance at a given model size.
We could even model the relationship between the optimal learning rate and model size then use this model to come up with yet another prediction.
The plot of optimal learning rate vs hidden size (d) appears to be a power law (linear on log-log), so incorporating this wouldn’t be much trouble.</p>
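<p>As a sketch of what that might look like (the optimal learning rates below are illustrative placeholders, not the measured values from panel (d)):</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import numpy as np

# Illustrative placeholders: best learning rate found at each hidden size.
hsizes = np.array([64, 128, 256, 512, 1024])
opt_lrs = np.array([4e-4, 3e-4, 2e-4, 1.4e-4, 1e-4])  # not real measurements

# A power law lr_opt(hsize) = c * hsize**k is linear on log-log axes.
k, log_c = np.polyfit(np.log(hsizes), np.log(opt_lrs), 1)
lr_for = lambda hsize: np.exp(log_c) * hsize**k
print("suggested learning rate for hsize=8192:", lr_for(8192))
</code></pre></div></div>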
<p>Even with this correction, how do we know we are not tricking ourselves again with some other hyperparameter which will wreak havoc in the next order of magnitude of hidden size?
Learning rate seems to be important, but what about learning rate schedules?
What about other optimization parameters?
What about architecture decisions? What about relationships between width and depth? What about initialization? What about precision of floating point numbers (or lack thereof)? In many cases, the default and accepted values for a variety of hyperparameters are all set at a relatively small scale – who’s to say they work with larger models?</p>
<p>Issues of scaling relationships seem to keep popping up as more and more folks train bigger and bigger models.
Even simple things like scaling learning rates with model size, as shown here, are not always done (e.g. when specifying a finetuning procedure for language models).
To its credit, the <a href="https://arxiv.org/abs/2001.08361">original scaling law paper</a> discusses many of these issues (width/depth scaling, relationship with LR, <a href="https://arxiv.org/abs/1812.06162">effect of batch size</a>), but also acknowledges that it neglects to study many others.
They also discuss relationships with compute budget and dataset size, but I don’t discuss or vary those here. The scaling laws they propose are designed under the assumption that the underlying model is trained with the best-performing hyperparameters.</p>
<p>So what can we do about potentially misleading extrapolations? In an ideal world, we would have a full understanding of how every aspect of our model changes with scale and use this understanding to design larger scale models. Without this, extrapolation seems fraught and could potentially result in an expensive mistake. Reaching this point of full understanding, however, feels impossible given just how many factors are at play. Tuning every parameter at every scale is not a solution either given the computational costs.</p>
<p>One potential solution is to use scaling laws to predict the <strong>best case</strong> performance.
As one scales up, if the performance goes off the power law relation, one should see this as a signal that something is not tuned or set up properly.
I have heard this is the mindset OpenAI often uses.
Put another way, when scaling doesn’t work as expected, it might mean something interesting is happening.
Knowing what to do about this, or what parameters to tune to fix this performance degradation can be extremely challenging.</p>
<p>In my opinion, one must balance using scaling laws to extrapolate performance at larger scales with actually evaluating performance at those scales.
In some sense this is obvious, and a rough approximation of what is done in practice.
As the study of scaling develops, I hope this balance can be made more explicit and that one can make more use of scaling relationships to enable more research at a small scale.
Take this particular example: while we found that naively scaling with a fixed learning rate does not extrapolate, we did find a power law relationship between model size and learning rate which leads to models that do extrapolate within the tested model sizes. Is there some other factor that we are missing if we try to extrapolate to even larger models? Possibly.
It’s hard to know without running the experiment.</p>
<h3>Acknowledgments:</h3>
<p>This blog is possible thanks to Google Research (my current employer) for the compute to perform these experiments.
I would also like to thank <a href="https://twitter.com/imordatch">Igor Mordatch</a>, <a href="https://twitter.com/ethansdyer">Ethan Dyer</a>, <a href="https://twitter.com/jaschasd">Jascha Sohl-Dickstein</a>, and <a href="https://twitter.com/chipro">Chip Huyen</a> for reviewing early versions of this post.</p>
<h1>Exploring hyperparameter meta-loss landscapes with Jax</h1>
<hr />
<p>A common mantra of the deep learning community is to differentiate through all the things, e.g. <a href="https://arxiv.org/abs/2006.12057">differentiable renderers</a>, <a href="https://papers.nips.cc/paper/2018/file/842424a1d0595b76ec4fa03c46e8d755-Paper.pdf">differentiable physics</a>, differentiable programming languages [<a href="https://github.com/FluxML/Zygote.jl">julia</a>, <a href="https://github.com/google-research/dex-lang">dex</a>, <a href="https://github.com/mila-iqia/myia">myia</a>], etc. In my own research, I’ve found that, while one can often compute a gradient, it isn’t always the most useful quantity. This is especially true with “complex” loss landscapes, because the crazier the loss landscape is, the less useful local information (i.e. the gradient) is for finding a global minimum. These complex loss landscapes can emerge from iterative computation such as common optimization procedures for machine learning [<a href="https://arxiv.org/abs/1502.03492">1</a>, <a href="https://arxiv.org/abs/1810.10180">2</a>, <a href="https://arxiv.org/abs/1703.03400">3</a>].</p>
<p>In this post, we’ll walk through an example showing how extraordinarily complex meta-loss landscapes can emerge from a relatively simple setting, and how, as a result, gradients of these loss landscapes become a lot less useful.</p>
<p>We’ll do this in a relatively new machine learning library: <a href="https://github.com/google/jax">Jax</a>. Jax is amazing. I’ve been using it more and more for my research and have noticed it gaining traction at Google. Despite this, very few people externally have even heard about it, let alone use it. In this post, I also want to show some of the features that make doing this type of exploration easy.</p>
<h1 id="why-jax">Why Jax?</h1>
<p>At its core, Jax provides an API nearly identical to Numpy. There is no built-in neural network library (no torch.nn or tf.keras, for example), though there are many Jax NN libraries under development (<a href="https://github.com/google/flax">flax</a> and <a href="https://github.com/deepmind/dm-haiku">haiku</a> are my favorites). Instead, one composes a set of simple operations into whatever computation is desired.</p>
<p>To make things fast, Jax offers a just-in-time compiler (jit) which compiles your functions to run on CPU, GPU, and TPU. This setup provides both really fast execution (Jax models are among the fastest in the recent <a href="https://mlperf.org/training-results-0-7">MLPerf benchmarks</a>) and flexibility (Jax not only does deep learning well, but has also been used for <a href="https://github.com/google/jax-md">molecular dynamics simulation</a>, and to design a <a href="https://arxiv.org/abs/2009.00196">Stellarator</a> – a device to contain plasma). Jax builds upon XLA, a linear algebra compiler which is responsible for taking these low-level computation graphs and producing a fast-executing program.</p>
<p>One core concept of Jax is <strong>function transforms</strong>. One can transform some function that operates on numerical data into another function operating on similar data. For example, <code class="language-plaintext highlighter-rouge">jax.jit</code> is a function transform: <code class="language-plaintext highlighter-rouge">jax.jit(f)</code> returns a function with the same interface that will be compiled to run fast. Jax’s automatic differentiation is written this way too: <code class="language-plaintext highlighter-rouge">jax.grad(f)(x)</code> will return the derivative of f evaluated at x.</p>
<p>My favorite feature of Jax is <strong>auto vectorization</strong>. Say we want to evaluate <code class="language-plaintext highlighter-rouge">f</code> on not one, but multiple inputs. We could do <code class="language-plaintext highlighter-rouge">[f(x) for x in xs]</code>, but this will execute <code class="language-plaintext highlighter-rouge">f</code> many times. When leveraging deep learning accelerators, this type of sequential computation can slow things down greatly. Instead we can use vmap: <code class="language-plaintext highlighter-rouge">jax.vmap(f)(xs)</code>. What this will do is <em>vectorize</em> f and then execute a single vectorized application. Let’s say <code class="language-plaintext highlighter-rouge">f = lambda x: np.sum(x**2)</code>, defined on rank 1 arrays. The vectorized code will instead run something like <code class="language-plaintext highlighter-rouge">np.sum(x**2, axis=1)</code>, which operates on rank 2 arrays and returns a vector. Because we are executing a smaller number of ops on more data, we can leverage hardware a lot better.</p>
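<p>As a tiny self-contained example of this:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import jax
import jax.numpy as jnp

f = lambda x: jnp.sum(x**2)          # defined on rank-1 arrays
xs = jnp.arange(12.0).reshape(4, 3)  # a batch of four rank-1 inputs
print(jax.vmap(f)(xs))               # same result as jnp.sum(xs**2, axis=1)
</code></pre></div></div>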
<p>These program transformations are <strong>composable</strong> too. One can mix and match <code class="language-plaintext highlighter-rouge">grad</code>’s, <code class="language-plaintext highlighter-rouge">vmap</code>’s, and <code class="language-plaintext highlighter-rouge">jit</code>’s. I can write a complex machine learning model, <code class="language-plaintext highlighter-rouge">g</code>, that operates on a single example, vectorizes it, and performs backprop through this auto vectorization like this:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">batch_loss</span><span class="p">(</span><span class="n">theta</span><span class="p">,</span> <span class="n">xs</span><span class="p">):</span>
<span class="k">return</span> <span class="n">np</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">partial</span><span class="p">(</span><span class="n">loss_fn</span><span class="p">,</span> <span class="n">theta</span><span class="p">))(</span><span class="n">xs</span><span class="p">))</span>
<span class="n">grad_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">grad</span><span class="p">(</span><span class="n">batch_loss</span><span class="p">))</span>
<span class="n">dl_dtheta</span> <span class="o">=</span> <span class="n">grad_fn</span><span class="p">(</span><span class="n">theta</span><span class="p">,</span> <span class="n">xs</span><span class="p">)</span>
</code></pre></div></div>
<p>Now instead, let’s say I want to compute <em>per example</em> gradients. This would usually be difficult in libraries like TF or Pytorch, but in Jax it’s just the composition of function transforms in a different order.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">per_example_grad_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">grad</span><span class="p">(</span><span class="n">loss_fn</span><span class="p">)))</span>
<span class="n">batch_of_dl_dtheta</span> <span class="o">=</span> <span class="n">per_example_grad_fn</span><span class="p">(</span><span class="n">theta</span><span class="p">,</span> <span class="n">xs</span><span class="p">)</span>
</code></pre></div></div>
<h1 id="exploring-meta-loss-landscapes">Exploring meta-loss landscapes</h1>
<p>One of the simplest examples of meta-learning is hyperparameter search. On the surface, one might expect that changing hyperparameters results in a predictable change in the performance of optimization. In some settings, this couldn’t be more wrong. I will demonstrate one such example here on a simple 1D loss landscape.</p>
<p>In the process, I will be demonstrating some cool features of Jax, including:</p>
<ul>
<li>Auto vectorization (<code class="language-plaintext highlighter-rouge">jax.vmap</code>), used to visualize both the inner-loss landscape (the problem we are trying to train) and the outer-loss landscape (how performance on that problem changes with hyperparameters)</li>
<li>Gradient computation (<code class="language-plaintext highlighter-rouge">jax.grad</code>, <code class="language-plaintext highlighter-rouge">jax.value_and_grad</code>) through complex, unrolled optimization procedures</li>
<li>Compilation (<code class="language-plaintext highlighter-rouge">jax.jit</code>) sprinkled around my code to greatly speed things up</li>
</ul>
<p>For ease of reproducing, all the code is in <a href="https://colab.research.google.com/drive/12nB1nrvLJsu_3bEYJzd4jCftnIGLoMaK?usp=sharing">this colab</a> and can be run on GPU or TPU which are both available for free.</p>
<p>First, some imports:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="kn">import</span> <span class="nn">jax</span>
<span class="kn">import</span> <span class="nn">jax.numpy</span> <span class="k">as</span> <span class="n">jnp</span>
<span class="kn">from</span> <span class="nn">matplotlib</span> <span class="kn">import</span> <span class="n">pylab</span> <span class="k">as</span> <span class="n">plt</span>
</code></pre></div></div>
<h2 id="inner-problem-a-1d-optimization-problem-optimized-with-sgdm">Inner problem: A 1D optimization problem optimized with SGDM</h2>
<p>Let’s start by defining a simple loss function we want to minimize, in this case a 1D problem. I am calling it the <code class="language-plaintext highlighter-rouge">inner_loss</code>. To make things interesting, let’s not just use a quadratic, but instead something a little funky looking so that we have a more complex, non-convex loss:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">jax</span><span class="p">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">inner_loss</span><span class="p">(</span><span class="n">x</span><span class="p">):</span>
<span class="k">return</span> <span class="n">jnp</span><span class="p">.</span><span class="n">log</span><span class="p">(</span><span class="n">x</span><span class="o">**</span><span class="mi">2</span> <span class="o">+</span> <span class="mf">1.0</span> <span class="o">+</span> <span class="n">jnp</span><span class="p">.</span><span class="n">sin</span><span class="p">(</span><span class="n">x</span><span class="o">*</span><span class="mi">3</span><span class="p">))</span> <span class="o">+</span> <span class="mf">1.5</span>
</code></pre></div></div>
<p>We can then use auto vectorization (vmap) to visualize this loss surface.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">xs</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mi">5</span><span class="p">,</span> <span class="mi">5</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">xs</span><span class="p">,</span> <span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">inner_loss</span><span class="p">)(</span><span class="n">xs</span><span class="p">))</span>
</code></pre></div></div>
<p><img src="/assets/images/jax_hparam_metaopt/image7.png" alt="Plot" /></p>
<p>This loss has a number of interesting characteristics. It is non-convex, and has 2 minima – a global minimum at approximately <code class="language-plaintext highlighter-rouge">x=-0.5</code>, with a value of zero, and a local minimum at approximately <code class="language-plaintext highlighter-rouge">x=1.5</code> with a value of about 2.</p>
<p>We can use SGD + momentum with a given learning rate and momentum to train this for 50 iterations. We will make use of gradient computation via <code class="language-plaintext highlighter-rouge">jax.value_and_grad</code> of the inner problem as well as <code class="language-plaintext highlighter-rouge">jax.jit</code> to speed things up.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="mf">4.</span> <span class="c1"># initial inner parameter
</span><span class="n">v</span> <span class="o">=</span> <span class="mf">0.0</span> <span class="c1"># initial momentum accumulator
</span>
<span class="n">lr</span> <span class="o">=</span> <span class="mf">0.1</span>
<span class="n">mom</span> <span class="o">=</span> <span class="mf">0.9</span>
<span class="n">value_and_grad_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">inner_loss</span><span class="p">))</span>
<span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">50</span><span class="p">):</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">value_and_grad_fn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">v</span><span class="o">=</span> <span class="n">mom</span><span class="o">*</span><span class="n">v</span> <span class="o">+</span> <span class="n">grad</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">v</span>
<span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">losses</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"inner-step"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"inner-loss"</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/assets/images/jax_hparam_metaopt/image3.png" alt="Plot" /></p>
<p>We can see we first descend into the global minimum, then move out (due to momentum) to settle in the local minimum.</p>
<h2 id="outer-loss-optimization-performance-as-a-function-of-optimizer-hyperparameters">Outer-loss: Optimization performance as a function of optimizer hyperparameters</h2>
<p>Now, let’s explore how our inner problem training behaves as a function of momentum and learning rate. To do this, we will define an <code class="language-plaintext highlighter-rouge">outer_loss</code> function which, for a given set of hyperparameters, computes the mean inner-loss over the 50 step unroll. This can be considered a measurement of how “good” the learning rate and momentum are at training this inner-problem quickly and to a low loss. We can jit this <em>entire</em> unrolled computation so it runs fast.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">jax</span><span class="p">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">outer_loss</span><span class="p">(</span><span class="n">lr</span><span class="p">,</span> <span class="n">mom</span><span class="p">):</span>
<span class="n">value_and_grad_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">inner_loss</span><span class="p">))</span>
<span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="mf">4.</span>
<span class="n">v</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">50</span><span class="p">):</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">value_and_grad_fn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">v</span><span class="o">=</span> <span class="n">mom</span><span class="o">*</span><span class="n">v</span> <span class="o">+</span> <span class="n">grad</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">v</span>
<span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
<span class="k">return</span> <span class="n">jnp</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">losses</span><span class="p">))</span>
</code></pre></div></div>
<p>To get a sense of this outer-loss function, we can leverage auto vectorization (<code class="language-plaintext highlighter-rouge">jax.vmap</code>) again, and plot the outer-loss as a function of learning rate. Here the <code class="language-plaintext highlighter-rouge">in_axes</code> argument denotes which dimensions to vectorize – in this case, add a batch dimension to argument zero (learning rate), but don’t add any to argument one (momentum).</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">lrs</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">logspace</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">5000</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">semilogx</span><span class="p">(</span><span class="n">lrs</span><span class="p">,</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">outer_loss</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">(</span><span class="mi">0</span><span class="p">,</span> <span class="bp">None</span><span class="p">)))(</span><span class="n">lrs</span><span class="p">,</span> <span class="mf">0.9</span><span class="p">),</span> <span class="s">"-"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylim</span><span class="p">(</span><span class="o">-</span><span class="mf">2.6</span><span class="p">,</span> <span class="mf">5.0</span><span class="p">)</span>
</code></pre></div></div>
<p><img src="/assets/images/jax_hparam_metaopt/image8.png" alt="Plot" /></p>
<p>This leads to a kinda funky result. First, the loss surface is incredibly sensitive to the learning rate – much more than traditionally considered. Despite this being an entirely deterministic optimization problem, there appears to be noise. This means small changes in learning rate compound and produce dramatically different behaviors, reminiscent of <a href="https://en.wikipedia.org/wiki/Chaos_theory">chaos</a>.</p>
<p>We don’t have to stop there though; we can compose vmap and apply it twice to plot the entire outer-loss landscape. Here, we are going to space the momentum values so that <code class="language-plaintext highlighter-rouge">log(1 - momentum)</code> is uniform. This is because momentum values are usually something like 0.9, 0.99, 0.999 and so on, so when plotting it makes sense to space those evenly.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">lrs</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">logspace</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">2000</span><span class="p">)</span>
<span class="n">moms</span> <span class="o">=</span> <span class="mf">1.</span> <span class="o">-</span> <span class="n">jnp</span><span class="p">.</span><span class="n">logspace</span><span class="p">(</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2000</span><span class="p">)</span>
<span class="n">lrs</span><span class="p">,</span> <span class="n">moms</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">meshgrid</span><span class="p">(</span><span class="n">lrs</span><span class="p">,</span> <span class="n">moms</span><span class="p">)</span>
<span class="n">img_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">outer_loss</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">figure</span><span class="p">(</span><span class="n">figsize</span><span class="o">=</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span><span class="mi">10</span><span class="p">))</span>
<span class="n">img</span> <span class="o">=</span> <span class="n">img_fn</span><span class="p">(</span><span class="n">lrs</span><span class="p">,</span> <span class="n">moms</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">imshow</span><span class="p">(</span><span class="n">img</span><span class="p">,</span> <span class="n">extent</span><span class="o">=</span><span class="p">[</span><span class="o">-</span><span class="mi">3</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="o">-</span><span class="mi">3</span><span class="p">],</span> <span class="n">aspect</span><span class="o">=</span><span class="s">"auto"</span><span class="p">,</span> <span class="n">interpolation</span><span class="o">=</span><span class="s">"nearest"</span><span class="p">,</span> <span class="n">vmax</span><span class="o">=</span><span class="mi">10</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"log learning rate"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"log (1 - momentum)"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">colorbar</span><span class="p">()</span>
</code></pre></div></div>
<p><img src="/assets/images/jax_hparam_metaopt/image6.png" alt="Plot" /></p>
<p>The outer-loss function continues to be kinda crazy – full of local minima and quite high curvature. There is also this periodic behavior with learning rate, and what appears to be a low-loss region at low momentum and lower learning rate.</p>
<h2 id="outer-optimization-with-gradients">Outer-optimization with gradients</h2>
<p>Now, on this simple 2D problem, we can find good parameters by simply taking the min of a bunch of random trials. This doesn’t scale well to a lot of hyperparameters though, and would make for a boring example. Another, less common approach that people try (to varying degrees of success) is to use <a href="https://arxiv.org/abs/1502.03492">gradient descent to find hyperparameters</a>.</p>
<p>With Jax, this can be done relatively simply by applying <code class="language-plaintext highlighter-rouge">jax.grad</code> to the outer-loss function. But first, let’s define our outer-loss function to operate on a tuple of parameters (lr, mom) as opposed to each as a separate input. Second, let us change the parametrization so steps in the learning rate / momentum have roughly similar effects regardless of the parameter location. In this case, we can do this by exponentiating the learning rate variable, and taking one minus the exponentiated momentum variable.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="o">@</span><span class="n">jax</span><span class="p">.</span><span class="n">jit</span>
<span class="k">def</span> <span class="nf">outer_loss</span><span class="p">(</span><span class="n">outer_params</span><span class="p">):</span>
<span class="n">log_lr</span><span class="p">,</span> <span class="n">one_minus_log_mom</span> <span class="o">=</span> <span class="n">outer_params</span>
<span class="n">lr</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">power</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">log_lr</span><span class="p">)</span>
<span class="n">mom</span> <span class="o">=</span> <span class="mf">1.0</span> <span class="o">-</span> <span class="n">jnp</span><span class="p">.</span><span class="n">power</span><span class="p">(</span><span class="mi">10</span><span class="p">,</span> <span class="n">one_minus_log_mom</span><span class="p">)</span>
<span class="n">x</span> <span class="o">=</span> <span class="o">-</span><span class="mf">4.</span>
<span class="n">v</span> <span class="o">=</span> <span class="mf">0.0</span>
<span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">50</span><span class="p">):</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">grad</span> <span class="o">=</span> <span class="n">value_and_grad_fn</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
<span class="n">v</span><span class="o">=</span> <span class="n">mom</span><span class="o">*</span><span class="n">v</span> <span class="o">+</span> <span class="n">grad</span>
<span class="n">x</span> <span class="o">=</span> <span class="n">x</span> <span class="o">-</span> <span class="n">lr</span> <span class="o">*</span> <span class="n">v</span>
<span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
<span class="k">return</span> <span class="n">jnp</span><span class="p">.</span><span class="n">mean</span><span class="p">(</span><span class="n">jnp</span><span class="p">.</span><span class="n">asarray</span><span class="p">(</span><span class="n">losses</span><span class="p">))</span>
</code></pre></div></div>
<p>We can compute outer gradients (and the outer-loss value) with this:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">outer_grad_fn</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">value_and_grad</span><span class="p">(</span><span class="n">outer_loss</span><span class="p">))</span>
</code></pre></div></div>
<p>And use a minimal training loop to meta-train / outer-train. This training loop also clips the outer-gradients. As the loss surface visualizations above show, the gradients can have extremely high magnitudes. Without this clipping, outer-training quickly diverges with SGD.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">outer_params</span> <span class="o">=</span> <span class="p">(</span><span class="o">-</span><span class="mf">2.</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">)</span>
<span class="n">outer_params_traj</span> <span class="o">=</span> <span class="p">[</span><span class="n">outer_params</span><span class="p">]</span>
<span class="n">losses</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">grads</span> <span class="o">=</span> <span class="p">[]</span>
<span class="n">alpha</span> <span class="o">=</span> <span class="mf">0.01</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="mi">100</span><span class="p">):</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">outer_grad</span> <span class="o">=</span> <span class="n">outer_grad_fn</span><span class="p">(</span><span class="n">outer_params</span><span class="p">)</span>
<span class="n">grads</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">outer_grad</span><span class="p">)</span>
<span class="n">outer_grad</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">clip</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="o">-</span><span class="mf">10.0</span><span class="p">,</span> <span class="mf">10.0</span><span class="p">),</span> <span class="n">outer_grad</span><span class="p">)</span>
<span class="n">outer_params</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">tree_multimap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">a</span><span class="p">,</span><span class="n">b</span><span class="p">:</span> <span class="n">a</span><span class="o">-</span><span class="n">alpha</span><span class="o">*</span><span class="n">b</span><span class="p">,</span> <span class="n">outer_params</span><span class="p">,</span> <span class="n">outer_grad</span><span class="p">)</span>
<span class="n">outer_params_traj</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">outer_params</span><span class="p">)</span>
<span class="n">losses</span><span class="p">.</span><span class="n">append</span><span class="p">(</span><span class="n">loss</span><span class="p">)</span>
</code></pre></div></div>
<p>We can plot this trajectory on the original loss surface.</p>
<p><img src="/assets/images/jax_hparam_metaopt/image4.png" alt="Plot" /></p>
<p>It looks like we found a local minimum, as we are not moving any further. We can see evidence of this by looking at the loss surface right around the solution found. We can do this easily with <code class="language-plaintext highlighter-rouge">jax.vmap</code> again.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="n">ts</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">vs</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">outer_loss</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">((</span><span class="bp">None</span><span class="p">,</span> <span class="mi">0</span><span class="p">),</span> <span class="p">)))((</span><span class="n">outer_params</span><span class="p">[</span><span class="mi">0</span><span class="p">],</span> <span class="n">outer_params</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="n">ts</span><span class="p">))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">ts</span><span class="p">,</span> <span class="n">vs</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"mom shift"</span><span class="p">)</span>
<span class="n">ts</span> <span class="o">=</span> <span class="n">jnp</span><span class="p">.</span><span class="n">linspace</span><span class="p">(</span><span class="o">-</span><span class="mf">0.05</span><span class="p">,</span> <span class="mf">0.05</span><span class="p">,</span> <span class="mi">100</span><span class="p">)</span>
<span class="n">vs</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">jit</span><span class="p">(</span><span class="n">jax</span><span class="p">.</span><span class="n">vmap</span><span class="p">(</span><span class="n">outer_loss</span><span class="p">,</span> <span class="n">in_axes</span><span class="o">=</span><span class="p">((</span><span class="mi">0</span><span class="p">,</span> <span class="bp">None</span><span class="p">),</span> <span class="p">)))((</span><span class="n">outer_params</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span><span class="o">+</span><span class="n">ts</span><span class="p">,</span> <span class="n">outer_params</span><span class="p">[</span><span class="mi">1</span><span class="p">]))</span>
<span class="n">plt</span><span class="p">.</span><span class="n">plot</span><span class="p">(</span><span class="n">ts</span><span class="p">,</span> <span class="n">vs</span><span class="p">,</span> <span class="n">label</span><span class="o">=</span><span class="s">"lr shift"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">ylabel</span><span class="p">(</span><span class="s">"outer-loss"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">xlabel</span><span class="p">(</span><span class="s">"shift"</span><span class="p">)</span>
<span class="n">plt</span><span class="p">.</span><span class="n">legend</span><span class="p">()</span>
</code></pre></div></div>
<p><img src="/assets/images/jax_hparam_metaopt/image1.png" alt="Plot" /></p>
<p>What I am showing here are shifts in learning rate and momentum centered at our final iterate. At <code class="language-plaintext highlighter-rouge">shift=0</code>, we see we are at what looks to be a local minimum. Moving a bit in either the learning rate or the momentum direction will increase the loss. We also see a really high curvature surface – it will be incredibly easy to get stuck in these ripples.</p>
<p>Overall, we got pretty lucky there though. The solution we found was still pretty good in the grand scheme of things. We can run this same procedure initializing the learning rate and the momentum in a more unstable regime and we will get stuck in a much worse position.</p>
<p><img src="/assets/images/jax_hparam_metaopt/image9.png" alt="Plot" /></p>
<p>And looking at our final hyperparameters, we can see that around <code class="language-plaintext highlighter-rouge">shift=0</code> (the final point of our gradient-based training) we are stuck in an unstable (highly sensitive to outer-parameters) region.</p>
<p><img src="/assets/images/jax_hparam_metaopt/image5.png" alt="Plot" /></p>
<h2 id="outer-optimization-with-es">Outer-optimization with ES</h2>
<p>This illustrates some of the difficulty of trying to use gradient-based methods to train hyperparameters. Instead of using gradients computed via backprop, we can leverage a stochastic algorithm – in this case Evolution Strategies (ES). The core idea is simple, and has been written up in many places before. At a high level, ES samples points around the current outer-parameter value, and moves in the direction of decreased loss. This effectively smooths out the loss surface.</p>
<p>To implement this, we must make use of Jax’s random number generation. Randomness is one area where Jax’s design diverges from Numpy. In particular, Numpy leverages a global random number state. (<code class="language-plaintext highlighter-rouge">np.random.RandomState</code> also exists, which isn’t global, but still relies upon mutated state.) Instead, Jax leverages pure and stateless random number generation. One specifies a random state, the key as it is often called: <code class="language-plaintext highlighter-rouge">key = jax.random.PRNGKey(seed)</code>, and can use this key to generate deterministic, pseudo random numbers: <code class="language-plaintext highlighter-rouge">jax.random.normal(key, ..)</code>. What this means is that with the same key, two calls to <code class="language-plaintext highlighter-rouge">jax.random.normal(key, ..)</code> will return the same value.</p>
<p>To get around this, Jax implements various ways to split / generate new keys. So if one wants to generate two random numbers, one can create 2 keys: <code class="language-plaintext highlighter-rouge">key1, key2 = jax.random.split(key, 2)</code> then generate a number using each key.</p>
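<p>For example:</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code>import jax

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (2,))
b = jax.random.normal(key, (2,))   # same key, so a == b

key1, key2 = jax.random.split(key, 2)
c = jax.random.normal(key1, (2,))  # different keys give independent draws
d = jax.random.normal(key2, (2,))
</code></pre></div></div>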
<p>With this, we can implement ES. In our case, I am using antithetic samples. What this means is we will sample some noise value, run the outer-function on the base parameters plus this noise, and the base parameters minus this noise. We will then compare these two values, and produce a gradient that moves in the lower loss direction.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code><span class="k">def</span> <span class="nf">outer_gradient_es</span><span class="p">(</span><span class="n">outer_params</span><span class="p">,</span> <span class="n">key</span><span class="p">,</span> <span class="n">std</span><span class="p">):</span>
<span class="n">lr_noise</span><span class="p">,</span> <span class="n">mom_noise</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">random</span><span class="p">.</span><span class="n">normal</span><span class="p">(</span><span class="n">key</span><span class="p">,</span> <span class="p">[</span><span class="mi">2</span><span class="p">])</span><span class="o">*</span><span class="n">std</span>
<span class="n">outer_params_pos</span> <span class="o">=</span> <span class="p">(</span><span class="n">outer_params</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">+</span> <span class="n">lr_noise</span><span class="p">,</span> <span class="n">outer_params</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">+</span><span class="n">mom_noise</span><span class="p">)</span>
<span class="n">outer_params_neg</span> <span class="o">=</span> <span class="p">(</span><span class="n">outer_params</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">-</span> <span class="n">lr_noise</span><span class="p">,</span> <span class="n">outer_params</span><span class="p">[</span><span class="mi">1</span><span class="p">]</span><span class="o">-</span><span class="n">mom_noise</span><span class="p">)</span>
<span class="n">pos_loss</span> <span class="o">=</span> <span class="n">outer_loss</span><span class="p">(</span><span class="n">outer_params_pos</span><span class="p">)</span>
<span class="n">neg_loss</span> <span class="o">=</span> <span class="n">outer_loss</span><span class="p">(</span><span class="n">outer_params_neg</span><span class="p">)</span>
<span class="n">factor</span> <span class="o">=</span> <span class="p">(</span><span class="n">pos_loss</span> <span class="o">-</span> <span class="n">neg_loss</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="mi">2</span> <span class="o">*</span> <span class="n">std</span><span class="o">**</span><span class="mi">2</span><span class="p">)</span>
<span class="n">outer_grad</span> <span class="o">=</span> <span class="p">(</span><span class="n">lr_noise</span> <span class="o">*</span> <span class="n">factor</span><span class="p">,</span> <span class="n">mom_noise</span> <span class="o">*</span> <span class="n">factor</span><span class="p">)</span>
<span class="k">return</span> <span class="p">(</span><span class="n">pos_loss</span> <span class="o">+</span> <span class="n">neg_loss</span><span class="p">)</span> <span class="o">/</span> <span class="mi">2</span><span class="p">,</span> <span class="n">outer_grad</span>
</code></pre></div></div>
<p>Outer-training looks similar to before, but now, instead of the backprop gradient, we use the ES estimate above, splitting off a fresh random key each step.</p>
<div class="language-python highlighter-rouge"><div class="highlight"><pre class="highlight"><code> <span class="n">std</span> <span class="o">=</span> <span class="mf">0.1</span>
<span class="n">loss</span><span class="p">,</span> <span class="n">outer_grad</span> <span class="o">=</span> <span class="n">outer_gradient_es</span><span class="p">(</span><span class="n">outer_params</span><span class="p">,</span> <span class="n">key1</span><span class="p">,</span> <span class="n">std</span><span class="p">)</span>
<span class="n">outer_grad</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">tree_map</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">jnp</span><span class="p">.</span><span class="n">clip</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="o">-</span><span class="mf">10.0</span><span class="p">,</span> <span class="mf">10.0</span><span class="p">),</span> <span class="n">outer_grad</span><span class="p">)</span>
<span class="n">outer_params</span> <span class="o">=</span> <span class="n">jax</span><span class="p">.</span><span class="n">tree_multimap</span><span class="p">(</span><span class="k">lambda</span> <span class="n">a</span><span class="p">,</span><span class="n">b</span><span class="p">:</span> <span class="n">a</span><span class="o">-</span><span class="n">alpha</span><span class="o">*</span><span class="n">b</span><span class="p">,</span> <span class="n">outer_params</span><span class="p">,</span> <span class="n">outer_grad</span><span class="p">)</span>
</code></pre></div></div>
<p>Outer-training for a few thousand steps leads to the following trajectory in the 2D space of outer-parameters:</p>
<p><img src="/assets/images/jax_hparam_metaopt/image2.png" alt="Plot" /></p>
<p>Still not perfect, but all things considered fairly good. It reaches a much better solution than the gradient-based method. To make this better, we could anneal the std of ES, use smaller learning rates, or use more samples for ES.</p>
<h1 id="conclusion">Conclusion</h1>
<p>I hope this gives a bit of insight into hyperparameter loss landscapes and why care should be taken when computing gradients through unrolled optimization problems, as well as a bit of why so many researchers are loving Jax. Interested in exploring any of this further? Grab a free GPU / TPU colab and give it a go – Jax is included by default. <a href="https://colab.research.google.com/drive/12nB1nrvLJsu_3bEYJzd4jCftnIGLoMaK?usp=sharing">My notebook can be found here.</a></p>
<h1 id="acknowldgements">Acknowldgements</h1>
<p>Thanks so much to <a href="https://twitter.com/chipro">Chip Huyen</a>, the best writer I know, and my amazing colleagues <a href="https://twitter.com/bucketofkets">C. Daniel Freeman</a> and <a href="https://twitter.com/utkuevci">Utku Evci</a> for feedback on this post. Thanks to <a href="https://twitter.com/mat_kelcey">Mat Kelcey</a> for the much shorter colab + TPU snippet.</p>
<h1>Motion Capture via Probabilistic Inference</h1>
<hr />
<p>In an effort to control my robot arm, I realized it would be good to have a way to measure the position of the arm in 3D. I could go with the more typical solution of motor encoders + inverse kinematics, but this seems error-prone and inaccurate, especially considering just how much mechanical wiggle room exists in my system – everything is 3D printed out of PLA after all. As nothing is hugely rigid, it would be great to have a more absolute measure of position.</p>
<p>As such, I figured it would be fun to build a motion capture system for this purpose! The basic idea is to have a bunch of webcams, add IR filters to them, and take pictures of objects with IR LEDs on them. Then, because we are capturing from multiple points of view, the system is over-constrained and we can work out the true positions.</p>
<p>This post talks about the first steps taken to this end and includes a preliminary discussion of the potential hardware as well as a prototype structuring of the motion capture problem as probabilistic inference! As usual, I have no real idea what I am doing so feedback is appreciated!</p>
<h1 id="hardware">Hardware</h1>
<p>The hardware I settled on is the cheap <a href="https://www.amazon.com/gp/product/B004FHO5Y6/ref=ppx_yo_dt_b_search_asin_title?ie=UTF8&psc=1">Logitech C270 webcam</a> at $20 each. I disassembled these, 3D printed little lens hoods, and attached <a href="https://www.bhphotovideo.com/c/product/292664-REG/LEE_Filters_87CP3_3_x_3_Infrared.html">IR filters</a>. For markers, I took some IR LEDs and sanded down the transparent plastic to give more of a diffuse output. Below, we can see a test shot with one of these cameras. The LED is super clear and should be fairly easy to pick out via some simple computer vision.</p>
<div class="side-by-side">
<div class="toleft">
<a href="/assets/images/mocap1/hardware.jpg"><img src="/assets/images/mocap1/hardware.jpg" /></a>
</div>
<div class="toright">
<a href="/assets/images/mocap1/raw_data_from_camera.jpg"><img src="/assets/images/mocap1/raw_data_from_camera.jpg" /></a>
</div>
<figcaption class="caption">Left: Mechanical setup for a single camera.
Contains 3D printed lens hood, plus IR filter.
Right: Raw image captured of diffused IR Led.
</figcaption>
</div>
<h1 id="software-lets-treat-motion-capture-as-probabilistic-inference">Software: Let’s treat motion capture as probabilistic inference!</h1>
<p>I am a Bayesian at heart, so naturally I turned to these tools when designing this system. Philosophically, this family of methods fits this project quite well – we can build low quality hardware and make up for this in software by performing inference on a model that incorporates a lot of uncertainty! For example, I don’t expect my point tracking to be perfect or any sort of syncing between cameras.</p>
<p>For now though, I started with something simple and entirely in a simulated world. This is great because I can specify (and thus know) the ground truth generative process. In this setup, I am assuming I have a series of 2D point readings (observed data) produced via a generative process defined by multiple virtual cameras (with unknown position and orientation) that observe multiple moving 3D points of unknown positions. The goal is to recover the camera positions and orientation as well as the positions of the points through time.</p>
<h2 id="generative-model">Generative model</h2>
<p>Let’s first consider the generative model we wish to use for inference. Let us call the observed data – the 2D points observed from multiple cameras – \(X\), indexed by \(c\) to denote camera index and \(i\) to denote point index. Let us call the unknown variables we wish to do inference on – the camera positions and 3D point locations – \(\theta\). We can now define \(P(X|\theta)\): the probability that we observe the 2D position readings \(X\) given the underlying world \(\theta\).</p>
<p>To unpack this further, let’s first consider a single 3D point and a single 2D reading. Given a camera and a 3D point, we can project the point into a 2D image. This projection is quite simple and is based on a pinhole camera model. I parameterize the camera by a translation, encoded as a vector with 3 components, and a rotation, encoded by a quaternion with 4 components. We convert these two quantities into a transformation matrix, \(T \in \mathcal{R}^{4,4}\). Given a 3D point, \(m\), we can apply this matrix, then the camera intrinsic matrix \(I\), and then finally normalize by the 3rd component to recover the 2D position, \(X_{ci}\): \(I T m = [x’, y’, s, 1]^{T}\), \(X_{ci} = [x’/s,y’/s]\). See <a href="https://docs.opencv.org/2.4/modules/calib3d/doc/camera_calibration_and_3d_reconstruction.html">here</a> for more info.</p>
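<p>To make this concrete, here is a minimal sketch of that projection written with Jax’s numpy. Everything here – the quaternion convention, padding the intrinsic matrix \(I\) out to 4x4, and the function names – is my own illustrative choice rather than the exact code I ran:</p>
<pre><code class="language-python">import jax.numpy as jnp

def quat_to_rot(q):
    # Rotation matrix from a unit quaternion q = [w, x, y, z].
    w, x, y, z = q
    return jnp.array([
        [1 - 2*(y*y + z*z), 2*(x*y - w*z),     2*(x*z + w*y)],
        [2*(x*y + w*z),     1 - 2*(x*x + z*z), 2*(y*z - w*x)],
        [2*(x*z - w*y),     2*(y*z + w*x),     1 - 2*(x*x + y*y)]])

def project(translation, quaternion, intrinsics, point3d):
    # Build the 4x4 transform T from rotation + translation, apply it
    # and the (4x4-padded) intrinsic matrix I, then divide by depth s.
    T = jnp.eye(4).at[:3, :3].set(quat_to_rot(quaternion))
    T = T.at[:3, 3].set(translation)
    xp, yp, s, _ = intrinsics @ (T @ jnp.append(point3d, 1.0))
    return jnp.array([xp / s, yp / s])  # the 2D reading X_ci
</code></pre>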
<p>Given perfect observations and cameras, this projection would map a single 3D point to a single 2D point. Because the real world is noisy, though, and to make inference easier, let’s instead allow a spread of possible 2D values – in our case, a Gaussian probability centered around the predicted projection. Not all points, however, will be projected onto the image: our camera has a finite view and can only see points in front of it. For points that fall outside the view, let’s just assign a constant (unnormalized) probability.</p>
<p>When we have multiple points, however, we need to know which 3D point belongs to which 2D point. For this, we can simply take the sum of probabilities over all possible permutations of points. The number of permutations grows factorially with the number of points, so this scales poorly, but it should still be computable for the small number of points I expect to have.
Finally, there is more than one camera, and more than one time point. To account for this, we can just take the product of these different components.</p>
\[P(X|\theta) = \sum_{\text{permutations of points}} \; \prod_{\text{cameras}} \prod_{\text{times}} \prod_{\text{points}} P(X_{ci} \mid \text{camera}, \text{point})\]
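<p>As a sketch, the (unnormalized) log probability for a single camera at a single time step might look like the following, with the sum over permutations done in log space for numerical stability. The <code class="language-plaintext highlighter-rouge">project</code> helper and the tuple layout of <code class="language-plaintext highlighter-rouge">camera</code> are carried over from the sketch above, so this is illustrative rather than the exact code I ran:</p>
<pre><code class="language-python">import itertools

import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log_prob_frame(observed_2d, points_3d, camera, sigma):
    # Gaussian log probability around each predicted projection, summed
    # (in log space) over every matching of 3D points to 2D readings.
    per_perm = []
    for perm in itertools.permutations(range(len(points_3d))):
        pred = jnp.stack([project(*camera, points_3d[i]) for i in perm])
        per_perm.append(-jnp.sum((observed_2d - pred) ** 2) / (2 * sigma**2))
    return logsumexp(jnp.array(per_perm))
</code></pre>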
<h2 id="inference">Inference</h2>
<p>Ideally, I will do something proper (e.g. define a prior on \(\theta\), and do <a href="https://en.wikipedia.org/wiki/Markov_chain_Monte_Carlo">MCMC</a> or <a href="https://en.wikipedia.org/wiki/Variational_Bayesian_methods">VI</a>). For now, though, let’s just compute a <a href="https://en.wikipedia.org/wiki/Maximum_a_posteriori_estimation">MAP</a> estimate of \(\theta\). Put simply, let’s just find the single most likely configuration of cameras and 3D points, \(\theta\), from which the observed 2D data, \(X\), could derive. This boils down to an optimization problem:
\(\text{argmax}_{\theta} P(X |\theta)\).
Coming from the machine learning community, what better way to tackle this than by writing all of this in a differentiable programming language / framework (in my case <a href="https://github.com/google/jax">Jax</a>), computing gradients, and doing SGD!</p>
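<p>A minimal version of that loop might look like this – <code class="language-plaintext highlighter-rouge">log_prob_frame</code> from above summed over all cameras and time steps, and plain gradient descent on the resulting negative log likelihood. The pytree layout of <code class="language-plaintext highlighter-rouge">theta</code> is again just an assumption:</p>
<pre><code class="language-python">import jax

def loss(theta, X):
    # Negative log likelihood of all observed 2D readings under theta.
    # theta["cameras"]: list of (translation, quaternion, intrinsics);
    # theta["points"]: 3D point positions per time step (assumed layout).
    total = 0.0
    for c, camera in enumerate(theta["cameras"]):
        for t in range(len(X[c])):
            total = total + log_prob_frame(X[c][t], theta["points"][t],
                                           camera, sigma=0.05)
    return -total

@jax.jit
def sgd_step(theta, X, lr=1e-2):
    # One step of plain gradient descent on the negative log likelihood.
    grads = jax.grad(loss)(theta, X)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, theta, grads)
</code></pre>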
<p>Sadly, though, the loss surface here is horrendous, with local minima and odd symmetries galore. I made some plots varying camera orientation and position to show just how wonky the loss surface is. Further, because Jax is awesome, I also computed eigenvalues of Hessians at different values of sigma!</p>
<div class="side-by-side">
<div class="toleft">
<a href="/assets/images/mocap1/orientation_slice.png"><img src="/assets/images/mocap1/orientation_slice.png" /></a>
</div>
<div class="toright">
<a href="/assets/images/mocap1/translation_slice.png"><img src="/assets/images/mocap1/translation_slice.png" /></a>
</div>
<figcaption class="caption">2D loss surfaces as a function of
camera orientation (left) and camera translation (right). The underlying
loss surfaces are not well behaved, full of local minima.
</figcaption>
</div>
<p><img src="/assets/images/mocap1/hessian_eigen.png" alt="Hessian eigen values" class="smaller-image" style="margin-bottom: 0px;" /></p>
<figcaption class="caption">Hessian of loss surface at random init. The
problem is poorly conditioned (has a wide range of hessian eigen
values). The curvature increases as we increase the standard deviation
of the Gaussian in 2D space (sigma).
</figcaption>
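<p>Computing those eigenvalues is pleasantly short in Jax. Assuming a <code class="language-plaintext highlighter-rouge">flat_loss</code> that takes all of \(\theta\) flattened into a single vector (both names are stand-ins here), it is roughly:</p>
<pre><code class="language-python">import jax
import jax.numpy as jnp

# flat_loss: the same negative log likelihood, but taking theta as one
# flat vector so the Hessian is an ordinary dense matrix.
H = jax.hessian(flat_loss)(theta_flat)
eigs = jnp.linalg.eigvalsh(H)          # real, sorted eigenvalues
print(eigs.min(), eigs.max())          # the spread shows the conditioning
</code></pre>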
<p>To get around this, we can do a few things. First, we can anneal sigma – how sharply we lose probability when our prediction is incorrect. Second, because certain areas of the loss landscape have high curvature, I make use of gradient clipping, a <a href="https://arxiv.org/abs/1211.5063">trick</a> used in the deep learning community to train RNNs. Finally, and somewhat counterintuitively for me, the more data we have, the easier the optimization process seems to be empirically.</p>
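<p>For reference, the gradient clipping used here is the usual global-norm rescaling; a sketch (the threshold is an arbitrary placeholder, not a tuned value):</p>
<pre><code class="language-python">import jax
import jax.numpy as jnp

def clip_by_global_norm(grads, max_norm=1.0):
    # Rescale the entire gradient pytree when its global norm exceeds
    # max_norm, leaving it untouched otherwise.
    norm = jnp.sqrt(sum(jnp.sum(g ** 2)
                        for g in jax.tree_util.tree_leaves(grads)))
    scale = jnp.minimum(1.0, max_norm / (norm + 1e-9))
    return jax.tree_util.tree_map(lambda g: g * scale, grads)
</code></pre>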
<h2 id="results">Results</h2>
<p>Thus far, I have tested all of this with 4 simulated cameras, 2 points, and 100 time points. I hardcoded the points to follow simple curves made from sin and cos. I use this ground truth data to generate the observed data, \(X\), which is a list of 2D positions – 2 for each camera. For optimization, I randomly initialize 3 of the 4 cameras (choosing 1 camera to be fixed to try to pin down extra degrees of freedom), randomly initialize all the 3D points, and optimize using Nesterov momentum with gradient clipping.</p>
<p><img src="/assets/images/mocap1/loss_curve.png" alt="loss curve" class="smaller-image" style="margin-bottom: 0px;" /></p>
<figcaption class="caption">Loss curve, (negative log likelihood of data), over the course of
optimization. At 2000 steps, sigma is increased, causing the spike in
loss.
</figcaption>
<p><img src="/assets/images/mocap1/camera_images.png" alt="Camera projections" class="bigger-image" style="margin-bottom: 0px;" /></p>
<figcaption class="caption">"Images" taken from the 4 cameras. Solid
lines denote ground truth data, points denote inferred points.
</figcaption>
<p>While looking at the results qualitatively, I realized that there were a few unconstrained degrees of freedom in the form of rotations. This is not too hard to fix, as one can observe an object of fixed, known proportions and orientation and rotate the solution to match. For now, though, I cheat and fit a rotation matrix using the ground truth camera positions.</p>
<div class="side-by-side">
<div class="toleft">
<a href="/assets/images/mocap1/unaligned_topdown.png"><img src="/assets/images/mocap1/unaligned_topdown.png" /></a>
</div>
<div class="toright">
<a href="/assets/images/mocap1/aligned_topdown.png"><img src="/assets/images/mocap1/aligned_topdown.png" /></a>
</div>
<figcaption class="caption">Top down view of 3D data both unaligned
(left) and aligned (right). Despite all
cameras seeing the correct things, there is an extra degree of freedom
in the form of rotation causing the predictions to be wrong. We can
rotate the space and recover a good solution.
</figcaption>
</div>
<p><img src="/assets/images/mocap1/3d.png" alt="Camera projections" class="smaller-image" style="margin-bottom: 0px;" /></p>
<figcaption class="caption"> Reconstructed 3D scene. Solid lines denote
ground truth data, with inferred data shown with points.
</figcaption>
<h1 id="next-steps">Next steps</h1>
<p>That’s all for this quick update! Things are still progressing, albeit slowly, on the arm itself. Still, I did add a new axis and am working on revision 2 of the electronics. As for the motion capture system, I will be building a rig to mount cameras, collecting real data, figuring out how to go from 2D images to points, and adding in more uncertainty.</p>
<p>Stay tuned!</p>lukemetzRobot Arm V2: Electronics2019-04-02T12:00:00+00:002019-04-02T12:00:00+00:00https://lukemetz.github.io/electronics-v2<hr />
<p>Things have been busy, but I am trying to make more time for this project! This post is an update on some of the electronics and software for the Arm.</p>
<p>In terms of electronics, the previous design consisted of one base microcontroller that controlled everything, with wires going directly to each motor. This was not scalable, not modular, and led to too many wires everywhere. Additionally, given the new slip ring, I did not have enough high current channels to make this approach possible. As such, I opted for something more distributed. I still have one base controller, connected to a PC for commands, but this controller is now connected to an I2C bus. I2C is a simple, synchronous, master-slave protocol that comes built into many sensors and microcontrollers. For each axis, I have another microcontroller that converts the incoming commands into motor drive signals, and that can also manage and send back encoder data. Each board is daisy chained together with an 8 channel connector: 2 channels for I2C, plus 2 each for +12V, +5V, and GND.</p>
<div class="side-by-side">
<div class="toleft">
<a href="/assets/images/blog6/assemb.jpg"><img src="/assets/images/blog6/assemb.jpg" /></a>
</div>
<div class="toright" style="margin-top: 50px;">
<a href="/assets/images/blog6/many.jpg"><img src="/assets/images/blog6/many.jpg" /></a>
</div>
<figcaption class="caption">Left: Single assembled custom PCB. Right:
Boards connected to motor controllers.
</figcaption>
</div>
<p>I realized early in the last iteration that <a href="https://en.wikipedia.org/wiki/Breadboard">breadboard</a> based electronics were not going to cut it.
Too many wires, and too hard to debug if something came loose.
That, and I am too lazy to cut wires to the correct lengths.
Perfboards are better. I tried to work with these, but concluded that soldering and working with them was time consuming and not very fun.
After building one board, I gave up and decided to dive into the world of custom PCBs! I learned <a href="https://www.autodesk.com/products/eagle/overview">Eagle</a>, made schematics in it, and sent them to <a href="https://oshpark.com/">oshpark</a> for fabrication. Turnaround time was < 1 week which is amazing – would highly recommend this process! The circuits themselves are simple – really just an Arduino shield, some connectors, a few resistors, and some LEDs for debugging. For simplicity, I chose not to place the circuit for the motor driver, nor the Arduino, on my boards. I didn’t want to source components, am not equipped to do the required volume of surface mount soldering, and am not all that confident in my electrical engineering. As a result, the full electronics are a bit messy, requiring many connections between the motor controller and the Arduino board. Next iteration, I plan on reconsidering this, as I spent the majority of the time making connectors.</p>
<div class="side-by-side">
<div class="toleft">
<a href="/assets/images/blog6/schm.png"><img src="/assets/images/blog6/schm.png" /></a>
</div>
<div class="toright">
<a href="/assets/images/blog6/layout.png"><img src="/assets/images/blog6/layout.png" /></a>
</div>
<figcaption class="caption">Schematic (left) and layout (right) in Eagle.
</figcaption>
</div>
<div class="side-by-side">
<div class="toleft" style="margin-top: 10px;">
<a href="/assets/images/blog6/breadboard.jpg"><img src="/assets/images/blog6/breadboard.jpg" /></a>
</div>
<div class="toright">
<a href="/assets/images/blog6/perf.jpg"><img src="/assets/images/blog6/perf.jpg" /></a>
</div>
<figcaption class="caption">Left: Horrible mess of a breadboard. Right:
Annoying soldering required for perf boards.
</figcaption>
</div>
<p>On the software side, I am still sending inputs from an Xbox controller over SLIP to the master board. The master board now sends packets over I2C to each slave board. The slave boards listen for these, then set the appropriate pins to control the motor controller.</p>
<p>As of now, I have 2 axes more or less assembled with electronics. The current plan is to 3D print and assemble a third axis, finish up controls, get more sensor data out, and figure out how to control this thing!</p>lukemetzRobot Arm V2: Axes 2 and 32018-05-26T12:00:00+00:002018-05-26T12:00:00+00:00https://lukemetz.github.io/mechanical-v2-axis2<hr />
<p>Axis 2, and to a lesser extent Axis 3, were by far the weakest parts of the <a href="/project-log-matcha-making-robot-arm/">previous design</a>. The 25kg-cm servos just didn’t cut it – they couldn’t move anything as more and more weight was added. To solve this problem, I decided to use the <a href="https://www.servocity.com/12-rpm-hd-premium-planetary-gear-motor-w-encoder">highest torque (geared) motor I could find</a> – up to 584 kg-cm, or a little more than 20x the previous power capability! One negative, though, is that these motors require a very different form factor, so a complete redesign was needed. This post describes two such iterations: the first for Axis 2, and then modifications of that design for Axis 3.</p>
<p><img src="/assets/images/blog5/motors.jpg" alt="Final" /></p>
<figcaption class="caption">Motor Upgrade! Left: the old 25kg-cm servo. Right: New 584kg-cm motor.
</figcaption>
<h2 id="axis-2">Axis 2</h2>
<p>I decided to keep the same general design as the <a href="/project-log-matcha-making-robot-arm/">first version</a> but with a bunch of modifications. First, as before, I replaced the spur gears with herringbone gears to mitigate backlash. Next, I replaced the M8 threaded rod with a 3D printed shaft with two beefy 6008 bearings on either end for support. Finally, the mount needed to be bigger to accommodate a much larger motor. The motor itself is mounted to this, bolted into the PLA. As in the first axis, I used a clamping hub on the drive shaft so that the drive gear could be attached.</p>
<p>Another annoying aspect of the previous design was assembly. Instead of finicky bolts installed vertically through tight spaces, I decided to make a slot so the upper axis now fits around the lower, secured in place with a few bolts. Additionally, I am trying to use the same simple attachment mechanism across the arm to make things a little more modular.</p>
<p><img src="/assets/images/blog5/axis1.jpg" alt="Final" /></p>
<figcaption class="caption">Full assembly of Axis 2.
</figcaption>
<h2 id="axis-3">Axis 3</h2>
<p>In the first arm design, I simply printed out multiple copies of the first axis and stacked them. I hoped to do the same thing here but decided instead to improve the design. In general, Axis 1 used waaaay too much plastic, in terms of absolute size as well as the thickness of components. In this axis, I trimmed as much plastic as I could and thinned parts that didn’t need to be too strong. I also swapped the 6008 bearings for the lighter 6005 bearings. With this decreased size, the base piece needed to be two pieces so that I could get the motor installed correctly. While much smaller, it seems to hold up in terms of strength.</p>
<div class="side-by-side">
<div class="toleft" style="margin-top: 60px;">
<img src="/assets/images/blog5/layout.jpg" />
</div>
<div class="toright">
<img src="/assets/images/blog5/axis2.jpg" />
</div>
<figcaption class="caption">Right: All components before assembly. Left: Fully assembled Axis 3.
</figcaption>
</div>
<h2 id="results">Results</h2>
<p>Not surprisingly, these motors are incredibly strong. While testing, I accidentally ran a motor too long, which resulted in a collision with plastic. Instead of stalling, as I would have expected, the motor kept turning and broke the plastic! It’s good that the PLA is now probably the weakest component of this whole thing!</p>
<p><img src="/assets/images/blog5/full.jpg" alt="Final" /></p>
<figcaption class="caption">Fully assembled arm with 3 axes so far.
</figcaption>
<h2 id="next-up">Next up</h2>
<p>For the well-being of the 3D printed parts, it’s clear I cannot keep running these motors manually with wires connected to a power supply. There needs to be some sort of auto stop. I have been busy scoping out sensors and such, and hope to have an update on that, as well as the rest of the electronics, some point soon.</p>
<p>Thanks for dropping by. Your thoughts always welcome!</p>lukemetzRobot Arm Mechanical Design V2: Base2018-04-15T12:00:00+00:002018-04-15T12:00:00+00:00https://lukemetz.github.io/mechanical-v2-base<hr />
<p>It’s time for another iteration on the mechanical side of things. The old design had a number of issues. Foremost, at 20 kg-cm torque, the motors were incredibly underpowered. As a result, the base joint could not lift the remaining joints. Secondly, the lack of cable management on the old design meant that the base was not free to rotate fully without tangling. In addition, the use of spur gears meant that there was a lot of play / backlash in the whole system. Finally, assembly was a pain – the old design took far too long to put together. As such, I have been busy designing this last month, and have just now started to print some of the resulting pieces. This post will give an overview of the newly designed first axis!</p>
<p><img src="/assets/images/blog4/final.jpg" alt="Final" /></p>
<figcaption class="caption">Final assembly of the base.
</figcaption>
<h2 id="slip-ring">Slip Ring</h2>
<p>To manage wiring, I decided to try to make use of a <a href="https://en.wikipedia.org/wiki/Slip_ring">slip ring</a>. Slip rings are devices that enable electrical power transfer through rotation. Sadly, I couldn’t seem to find exactly what I wanted – a low diameter, high current ring that didn’t cost a fortune. I ended up going with <a href="https://www.amazon.com/gp/product/B01KBRV96U/ref=oh_aui_search_detailpage?ie=UTF8&psc=1">this one</a> with the expectation that I will solder many of the wires together to increase power transfer. It is not small in diameter though, and to be effective the wires must pass through the center of rotation. As such, I needed an essentially hollow shaft.</p>
<p><img src="/assets/images/blog4/slip.jpg" alt="Final" /></p>
<figcaption class="caption">The slip ring allows for continuous rotation of the base without getting wires tangled.
</figcaption>
<h2 id="hollow-3d-printed-shaft">Hollow 3D Printed Shaft</h2>
<p>Inspired by <a href="https://www.ytanaka-works.com/">Ttanaka’s arm</a> and the <a href="http://thorrobot.org/">Thor arm</a>, I decided to try to make use of 3D printed axles / shafts. To fit around the slip ring, the inner diameter of this axle needs to be ~22 mm, which is quite large. Because of this size, I ended up using massive bearings to hold everything in place. For axial load I am using 51109 thrust bearings. To keep the shaft aligned, I am using 6008 deep groove ball bearings. Both of these are massively overkill in their max specs, but they fit the sizes that I need. I debated using just the 6008 for both axial and off-axial load, but was unsure if the 3D printed parts would be strong enough, as they would only be able to contact the race of the bearing (a ring only a few mm wide).</p>
<p>The actual bearing assembly is quite similar to the design from my <a href="http://lukemetz.com/project-log-matcha-making-robot-arm/">first post</a>. I sandwiched a 3D printed part between 2 thrust bearings, and then clamped the other ends of the thrust bearings together.</p>
<p><img src="/assets/images/blog4/shaft.jpg" alt="Final" /></p>
<figcaption class="caption">The 3D printed hollow shaft is printed in two pieces. Left is upper; right is lower.
</figcaption>
<p><img src="/assets/images/blog4/bearings.jpg" alt="Final" /></p>
<figcaption class="caption">The 6008 and 51109 bearings used.
</figcaption>
<h2 id="herringbone-gears">Herringbone Gears</h2>
<p>For gears, I swapped out the original spur gears with <a href="https://en.wikipedia.org/wiki/Herringbone_gear">herringbone gears</a>. These gears have much smoother motion as compared to spur gears, which results in less backlash. One of my favorite features about OnShape, the CAD program I am using, is FeatureScript. It exposes a programming language to design custom features or use other people’s. So instead of painstakingly cadding these gears, I used Aaron Griffith’s wonderful <a href="https://cad.onshape.com/documents/9ad0b046fa03032e4fc613ac/w/d6c4307218c918d50121e0ec/e/37ecc28aff8c1a0b615fcda4">FeatureScript feature</a> to create these gears! I have been using FeatureScript as much as I can in this project and plan to write a post specifically on this language in the future!</p>
<p>I had a lot of difficulty thinking about how to attach this gear to the shaft while keeping the diameter small enough to fit the bearings. My final solution was to make the entire shaft out of 2 pieces and 3D print the gear together with the second half of the shaft! This ended up being quite strong and is easy to assemble.</p>
<h2 id="new-motors">New Motors</h2>
<p>Instead of servo motors as before, I chose to use geared planetary gear motors. This first axis does not need to be all that strong, so I went with <a href="https://www.servocity.com/26-rpm-premium-planetary-gear-motor-w-encoder">one of these</a>. I opted for a motor with an encoder as I figured it can’t hurt to have an extra signal for control. I am using an off-the-shelf <a href="https://www.servocity.com/22mm-bore-clamping-hub-d">motor clamp</a>, and <a href="https://www.servocity.com/770-clamping-hubs#348=96">mounting hub</a> to connect to the drive gear. I 3D printed a little motor holding mount to put it at the right height and keep it vertical.</p>
<p><img src="/assets/images/blog4/motorgear.jpg" alt="Final" /></p>
<figcaption class="caption">Pieces of the motor mount.
</figcaption>
<h2 id="frame">Frame</h2>
<p>Everything is held in place by yet more plastic! I used a lot of struts to ensure that there will be no movement in the final frame. Additionally, I added little feet, 3D printed separately, with felt pads, as all this plastic has been scraping up my tables. I expect I will need to swap these feet out for something I can clamp down, as the center of mass of the arm will easily be able to extend outside of this footprint.</p>
<div class="side-by-side">
<div class="toleft">
<a href="/assets/images/blog4/base1.jpg"> <img class="image" src="/assets/images/blog4/base1.jpg" alt="Alt Text" /> </a>
</div>
<div class="toright">
<a href="/assets/images/blog4/base2.jpg"> <img class="image" src="/assets/images/blog4/base2.jpg" alt="Alt Text" /> </a>
</div>
</div>
<div class="side-by-side">
<div class="toleft">
<a href="/assets/images/blog4/base3.jpg"> <img class="image" src="/assets/images/blog4/base3.jpg" alt="Alt Text" /> </a>
</div>
<div class="toright">
<a href="/assets/images/blog4/base4.jpg"> <img class="image" src="/assets/images/blog4/base4.jpg" alt="Alt Text" /> </a>
</div>
</div>
<h2 id="result">Result</h2>
<p>Everything seems to work! The motor spins, the 3D printed shaft seems strong enough, and I can put a lot of weight on it. See the video for this thing in action! There is still a decent amount of friction, though. It’s not enough to be a problem (I hope), but I am curious as to the source. The thrust bearing should be beefy enough to hold a lot more weight than what I have tested. Worst case, I can upgrade the motor. In general, I think I went a bit overboard on the stability of the thing. The first design was a little wobbly, so I thickened the plastic pretty much everywhere. This is okay for the base and Axis 1, but as weight starts to matter (Axis 2 and 3), I need to find a way to cut down on plastic.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/STCQnQSqVHU" frameborder="0" allow="autoplay; encrypted-media" allowfullscreen=""></iframe>
<p>Coming from a machine learning background, this type of design iteration seems like something that I should not have to be doing by hand. I should probably invest time into learning a finite element analysis package and then manually making edits based on its results. Taking this further, I wish things like <a href="https://en.wikipedia.org/wiki/Topology_optimization">topology optimization</a> were more widely available in hobby grade products.</p>
<p>In general, it seems like it’s harder to express these 3D objects / designs in a transferable and modifiable way, as is done in software. JSON, images, and TensorFlow graphs, for example, are all data structures that lend themselves to reasoning. Parametric CAD models, on the other hand, are quite difficult, even with FeatureScript. To make meaningful edits, I need at least the visual “compiled” part. Meanwhile, all this is implicitly tied to some notion of “functionality”: how this part is going to be used, what needs to be strong, and what the maximum size of the item is to make things fit together. This coupling reminds me of some “modern” software libraries springing up that make this separation explicit. Things like <a href="https://www.tensorflow.org/">TensorFlow</a>, for example, separate graph construction from execution. Similarly, <a href="http://halide-lang.org/">Halide</a> does the same for algorithm design versus the execution schedule. Even web frameworks such as <a href="https://angular.io/">Angular</a> have some of this notion of separation of data and rendering! I am curious if any radically different workflow could be applied successfully to the mechanical design world.</p>
<p>Still so much to do! I am currently wrapping up the assembly and printing of the 2nd axis, and am thinking about designs for the 3rd and 4th. I am also wondering how the electronic system will function. The more I work on this project, the more I am amazed by the technology put into these commercial arms. It’s not as easy as it looks!</p>lukemetzDeep Learning to Control my Robotic Arm2018-02-25T12:10:00+00:002018-02-25T12:10:00+00:00https://lukemetz.github.io/deep-learning-controls<hr />
<p>This is the third installment in the chronicle of my attempt to build a robotic arm to make me tea. For the mechanical build, see <a href="/project-log-matcha-making-robot-arm">here</a>, and for the electrical and software groundwork for this post, see <a href="/project-log-electronics">here</a>.</p>
<p>This thing is borderline impossible to control with an Xbox controller. Not only are there too many joints, there is no notion of correcting for the forces of gravity. As such, my first plan of attack is to see if I can build some controls that make it easier to control the arm by hand – specifically, when I release the controls on my Xbox controller, I would like the robot to stop moving rather than come crashing down. From here, the hope is that I can then do point-to-point movement and “program” it to do tasks by running through a set of predefined states. This roadmap is designed to be sample efficient. Unlike Google, I do not have an <a href="https://research.googleblog.com/2016/03/deep-learning-for-robots-learning-from.html">army of these arms</a> at my disposal. This means that the more “traditional” deep reinforcement learning approaches (model-free control) are out of the question, as they are just too sample inefficient. It’s going to take a while to get to this point, but this is what I have so far.</p>
<p>In this post, I review the data collection and processing, discuss the forward dynamics model training, and finally address the model predictive control algorithm I employed and some initial results when applied in only 1 dimension. This post is almost entirely derived from <a href="http://bair.berkeley.edu/blog/2017/11/30/model-based-rl/">“Model-based Reinforcement Learning with Neural Network Dynamics”</a>, also available in <a href="https://arxiv.org/abs/1708.02596">paper form</a>. I highly recommend both.</p>
<h3 id="data-pre-processing">Data pre-processing:</h3>
<p>As mentioned before, I log all data that is sent (motor commands) and received (accelerometer and gyro readings). To collect some data, I ran the robot with random-ish movements (generated by me and an Xbox controller) for roughly 10 min. These are saved to ndjson files on disk. Importantly, data arrive at specific instants in time, at different frequencies, and are not aligned across streams. To correct for this, I resample all the data at a fixed rate. For this first test, I am resampling at 100 Hertz using the wonderful <a href="https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.interp.html">np.interp</a>.</p>
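<p>Concretely, the resampling is just linear interpolation onto a shared fixed-rate time grid – something like the following, where the names are illustrative:</p>
<pre><code class="language-python">import numpy as np

def resample(t, values, rate_hz=100.0):
    # Linearly interpolate an irregularly-sampled stream onto a
    # fixed-rate time grid.
    t_grid = np.arange(t[0], t[-1], 1.0 / rate_hz)
    return t_grid, np.interp(t_grid, t, values)

# e.g. put gyro readings and motor commands on the same 100 Hz clock:
# t, gyro = resample(gyro_times, gyro_values)
# _, cmds = resample(cmd_times, cmd_values)
</code></pre>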
<h3 id="forward-modeling">Forward Modeling:</h3>
<p>My forward dynamics model takes the current sensor readings (the state, in reinforcement learning speak) and the control commands (the actions), and predicts the sensor readings 1/100th of a second later. To complicate things, these commands do not take effect instantly. This is partially attributable to the software side of things, and partially to the mechanical system itself – a motor cannot simply go from off to full power instantaneously. As such, this model must take this hidden information into account when making predictions. Technically speaking, the system I am modeling is a <a href="https://en.wikipedia.org/wiki/Partially_observable_Markov_decision_process">POMDP</a>.</p>
<p>The model I chose to work with first is an LSTM, as it’s capable of modeling these hidden states naturally.</p>
<p>I used a model with a 64 unit hidden state, with a linear layer transforming the output into the predicted state deltas. As done in <a href="https://arxiv.org/abs/1708.02596">Nagabandi et al.</a>, instead of predicting the next state, the model predicts the difference from the current state to the next state. Additionally, to keep the predictions at a sane scale (motors don’t move all that much in 1/100th of a second), I normalized the differences to zero mean and unit variance by estimating the mean and variance of the deltas over the training data.</p>
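<p>In code, this delta-prediction update is roughly the following sketch, where <code class="language-plaintext highlighter-rouge">lstm_apply</code>, <code class="language-plaintext highlighter-rouge">delta_mean</code>, and <code class="language-plaintext highlighter-rouge">delta_std</code> are stand-in names for the LSTM forward pass and the statistics estimated from the training data:</p>
<pre><code class="language-python">def predict_next_state(params, state, action, hidden):
    # The LSTM outputs a normalized state delta; un-normalize it with
    # the training-set statistics and add it onto the current state.
    norm_delta, hidden = lstm_apply(params, state, action, hidden)
    next_state = state + norm_delta * delta_std + delta_mean
    return next_state, hidden
</code></pre>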
<p>The full update can be written as \(s_{t+1} = s_t + \sigma \, p(s_t, a_t; \theta) + \mu\), where \(p(\cdot)\) is a normalized sample, and \(\sigma\) and \(\mu\) are the normalizing standard deviation and mean of the training data respectively. Additionally, unlike the work from <a href="https://arxiv.org/abs/1708.02596">Nagabandi et al.</a>, I chose to use a stochastic model instead of a deterministic one. If you do the math, their model can also be written down exactly as optimizing log likelihood under a normal distribution with a fixed variance, but I figured I would just model it explicitly as such, and learn the normal distribution’s variance while I am at it.
Instead of the mean squared error loss, I minimize the negative log likelihood of the next state given the current information \((x_t, a_t)\) and past information \(h_t\).</p>
\[L = -\log p(x_{t+1} \mid x_t, h_t, a_t; \theta)\]
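<p>For a diagonal Gaussian with a learned log-variance, this loss is only a few lines. A sketch (written with Jax’s numpy here, though any framework works the same way):</p>
<pre><code class="language-python">import jax.numpy as jnp

def gaussian_nll(pred_mean, pred_log_var, target):
    # -log N(target; pred_mean, exp(pred_log_var)), summed over
    # state dimensions. pred_log_var is the learned log variance.
    return 0.5 * jnp.sum(pred_log_var
                         + (target - pred_mean) ** 2 / jnp.exp(pred_log_var)
                         + jnp.log(2 * jnp.pi))
</code></pre>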
<p>For training data, I take 10 second random slices and run them through the model (with <a href="https://machinelearningmastery.com/teacher-forcing-for-recurrent-neural-networks/">teacher forcing</a>) with a batch size of 32. Additionally, I have a “test set” – a set of data collected after turning the robot off and on – to get at least some measure of overfitting. I train with the Adam optimizer and early stop when the test loss stops improving (in this case, after 30 min).</p>
<p>To gain a quick sanity check of the model, I recreated a plot from <a href="https://arxiv.org/abs/1708.02596">Nagabandi et al.</a>, where they run the forward dynamics model for some amount of time and compare it to the ground truth data. Despite a seemingly horribly high loss, the predicted trajectories on test data appear reasonably good, at least good enough to test it for control.</p>
<p><img src="/assets/images/blog3/plot.png" alt="Plot" /></p>
<figcaption class="caption">The learned model seems to capture general trends, and responds to the control signal reasonably over a 10 second rollout. Top: Command signals. Bottom: The actual sensor readings (solid), as well as the predicted sensor readings (dashed) sampled from the model.
Each color represents a component of the sensor's angle/quaternion. Because it's in 1D, only 2 components are changing.
</figcaption>
<h3 id="model-predictive-control-policy">Model Predictive Control Policy:</h3>
<p>With the model done, I next turned to actually using it. A forward dynamics model alone is not enough to do control. One needs to convert this into some form of policy. For this, I used <a href="https://en.wikipedia.org/wiki/Model_predictive_control">model predictive control</a>. Despite the fancy Wikipedia page, I believe this is actually a rather simple procedure – first, “plan” somehow into the future using the current model, then perform only the first step of this plan. This will bring you to a new state in which another round of planning occurs.</p>
<p>As my first control task, I tried simply to return to a given state. To keep things simple, instead of “planning” I picked a few fixed action sets. In my case, setting the motor on for 3 seconds each at one of the following powers: 1.0, 0.5, 0.0, -0.5, -1.0. I used these “plans” and did rollouts under the model. The best plan was the action sequence that put the final state closest to the target position. This is totally not ideal, and will not even converge in a lot of cases, but it’s a start! I tried this on the robot, and it almost kinda works, but oscillates around the correct solution.</p>
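<p>A sketch of this fixed-plan procedure, with <code class="language-plaintext highlighter-rouge">model.step</code> as a hypothetical one-step rollout of the learned dynamics model (3 seconds at 100 Hertz = 300 steps):</p>
<pre><code class="language-python">import numpy as np

CANDIDATE_POWERS = [1.0, 0.5, 0.0, -0.5, -1.0]

def choose_plan(model, state, hidden, target, horizon=300):
    # Roll each constant-power plan through the learned model and return
    # the power whose predicted final state lands closest to the target.
    final_dists = []
    for power in CANDIDATE_POWERS:
        s, h = state, hidden
        for _ in range(horizon):
            s, h = model.step(s, power, h)  # hypothetical rollout API
        final_dists.append(np.linalg.norm(s - target))
    return CANDIDATE_POWERS[int(np.argmin(final_dists))]
</code></pre>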
<iframe width="560" height="315" src="https://www.youtube.com/embed/DMk6j4W8jpg" frameborder="0" allow="autoplay; encrypted-media" allowfullscreen=""></iframe>
<figcaption class="caption">Testing out the controller. Because I didn't have a switch around in my apartment, the motors are engaged and disengaged with a wire. Sadly, there is a slight oscillation. </figcaption>
<p>Why does it oscillate? I am not entirely sure, but I am fairly confident that it has to do with the control frequency and latency in the model. First, I am issuing the command AFTER the model has finished making predictions. This means that while the model is churning away, the previous command is still executing on the robot and modifying the current state. This type of time delay can (and seems to) lead to oscillations. To validate this, I increased the amount of compute, and the oscillations got bigger.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/dWsH1znXzy8" frameborder="0" allow="autoplay; encrypted-media" allowfullscreen=""></iframe>
<figcaption class="caption">When increasing the amount of compute per update, the oscillations grow bigger and never seem to stop unless the sensors are already in the correct location. </figcaption>
<h3 id="next-up">Next up:</h3>
<p>I have a few ideas as to how to remedy this time delay issue – I think it’s enough to set a fixed control frequency (say, 10 Hertz), and in the planning stage, plan as if the previous action was being executed for this amount of time. Doing this, however, will require another few weekends. I am going to put this effort on hold for now though and shift focus to a second, more powerful version of the mechanical design. Update soon!</p>lukemetzRobotic Arm Electronics and Firmware2018-02-19T12:00:00+00:002018-02-19T12:00:00+00:00https://lukemetz.github.io/project-log-electronics<hr />
<p>This is the second post in the documentation of my attempt to build a robotic arm to make me tea. For the mechanical build, see the <a href="/project-log-matcha-making-robot-arm/">first installment</a>. As usual, this is a learning process, so if something seems wrong or could be done better, please let me know in the comments or feel free to tweet me!</p>
<p>With the mechanical pieces in a semi-functioning state, my next task was to figure out how to drive and control this contraption. In this post, I provide a high level overview of the electronics and sensors I used and briefly touch on the firmware and communication which enable manual control – technically simple, yet incredibly difficult in practice. Below you can see the arm struggling with its underpowered motors.</p>
<iframe width="560" height="315" src="https://www.youtube.com/embed/tRS1kZs6feg" frameborder="0" allow="autoplay; encrypted-media" allowfullscreen=""></iframe>
<figcaption class="caption">Arm being controlled via xbox 360 controller.</figcaption>
<h3 id="motors-and-electronics">Motors and Electronics:</h3>
<p>For my main motors I acquired a few <a href="https://www.amazon.com/gp/product/B01N0XZOZU/ref=oh_aui_search_detailpage?ie=UTF8&psc=1">inexpensive servos</a>. I “hacked” these for continuous rotation by taking each one apart and removing the mechanical stops, the potentiometer, and all the control electronics. I wired the motor directly to the positive and negative ports of the connector, and left the signal wire disconnected.
In retrospect, and given that my current motors are a little underpowered, this was a poor choice (at least for the lower joints which need more power) and I will need to redesign to use a much more powerful motor :(.
Because I removed all the control boards, I still needed something to drive the motors in both directions at variable speed. For this, I am using a few motor controllers based on the <a href="https://www.amazon.com/gp/product/B06XR1YNH4/ref=oh_aui_search_detailpage?ie=UTF8&psc=1">L298N chip</a>. These little boards are incredibly effective for the price ($1.23 each from AliExpress).
I acquired 3 of them for the 5 motors in my build thus far, with 1 drive output left free for the eventual gripper.</p>
<p>Finally, I needed some way to sense the world.
For a while I debated using a camera, but decided against it as it would require a lot more effort on the control side of things.
Instead, I opted for a bunch of “Absolute Orientation” sensors (the BNO055 chip with a breakout board made by <a href="https://www.amazon.com/gp/product/B017PEIGIG/ref=oh_aui_search_detailpage?ie=UTF8&psc=1">Adafruit</a>). I figured a few of these, with their ability to sense gravity, should be enough to make predictions about the state of the arm and thus allow me to control its behavior.
These are great and very modular as they are not tied to the motor in any way. Additionally, nothing is stopping me from using these as well as a camera if I conclude that the latter is needed.
Because I don’t plan this far into the future, I 3D printed little sensor holders, and hot glued them to the main body of the arm. Additionally, wire routing is done in a similar way.</p>
<p><img src="/assets/images/blog2/sensor3_2.jpg" alt="Sensors" /></p>
<figcaption class="caption"> 3D printed sensor holders.
</figcaption>
<p>These sensor breakout boards communicate over I2C and, sadly, all share the same I2C address, which means that normally I can only use one at a time. To remedy this, I am using an <a href="https://www.amazon.com/gp/product/B015HJX33Y/ref=oh_aui_search_detailpage?ie=UTF8&psc=1">I2C multiplexer</a> which allows me to select the chip I seek to read and write from.</p>
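<p>If the mux is a TCA9548A-style part (as on the common breakout boards – an assumption on my part, since I never noted the exact chip), channel selection is just writing a single one-hot byte. Here is a host-side Python sketch using <code class="language-plaintext highlighter-rouge">smbus2</code>; note the address and bus number are also assumptions, and the Arduino version writes the same byte over Wire:</p>
<pre><code class="language-python">from smbus2 import SMBus

MUX_ADDR = 0x70  # the usual TCA9548A default address (an assumption)

def select_channel(bus, channel):
    # Writing a one-hot byte connects exactly one downstream I2C port.
    bus.write_byte(MUX_ADDR, 2 ** channel)

with SMBus(1) as bus:
    select_channel(bus, 2)  # subsequent reads hit the BNO055 on port 2
</code></pre>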
<p><img src="/assets/images/blog2/electronics_s.jpg" alt="Wires" /></p>
<figcaption class="caption"> All electronics haphazardly scattered next to the arm.
</figcaption>
<h3 id="firmware">Firmware:</h3>
<p>Everything on the arm is controlled from an old Arduino Diecimila. The board is programmed to listen on the Serial/UART connection for commands sent from my PC, as well as read from the sensors and send data back to the host. This is a fairly low level connection, and thus still needs some protocol layered on top. I figured a simple packet system would be straightforward and chose <a href="https://en.wikipedia.org/wiki/Serial_Line_Internet_Protocol">SLIP packets</a> for their simplicity and seemingly wide library support. Both the host and the Arduino read these packets and use them to communicate with each other.</p>
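<p>SLIP framing itself is tiny: escape the two reserved bytes, then terminate the frame with END. A Python sketch of the host-side encoder:</p>
<pre><code class="language-python">END, ESC, ESC_END, ESC_ESC = 0xC0, 0xDB, 0xDC, 0xDD

def slip_encode(payload: bytes) -> bytes:
    # Escape the two reserved bytes, then terminate the frame with END.
    out = bytearray()
    for b in payload:
        if b == END:
            out += bytes([ESC, ESC_END])
        elif b == ESC:
            out += bytes([ESC, ESC_ESC])
        else:
            out.append(b)
    out.append(END)
    return bytes(out)
</code></pre>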
<h3 id="pc-control">PC Control:</h3>
<p>At present, the host/PC control is a Python script that grabs controls from the user input (in my case an Xbox controller), packages it in a SLIP packet, and sends it over serial/UART to the Arduino. While performing this function, it periodically reads the sensor values. Additionally, I log both controls and sensor readings to an <a href="http://ndjson.org/">ndjson</a> file. This can be used for visualization and, as I will detail in a later post, to train the control models!</p>
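<p>The logging is one JSON object per line, which makes later loading trivial; roughly (names illustrative):</p>
<pre><code class="language-python">import json, time

def log_row(f, controls, sensors):
    # ndjson: one self-describing, timestamped record per line.
    f.write(json.dumps({"t": time.time(),
                        "controls": controls,
                        "sensors": sensors}) + "\n")
</code></pre>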
<p>Operation of the arm from the Xbox 360 controller is surprisingly hard, as the controls do not have any notion of gravity correction. For example, watch in this gif how, as the arm extends past the center of mass, it comes crashing down: the motor is no longer fighting against gravity, but assisting it. Still, I can actually move the arm around now, which is progress!</p>
<p><img src="/assets/images/blog2/plt.png" alt="Wires" class="smaller-image" /></p>
<figcaption class="caption"> Initial plot of some data while controlling
with an xbox. Only the first axis shown.
</figcaption>
<h3 id="next-up">Next up:</h3>
<p>The foregoing is seemingly enough to make the arm operate but there is no way I can perform any sort of delicate actions. Next up I plan on trying to remedy this with machine learning!</p>lukemetzProject Log: Matcha Making Robot Arm2018-01-23T22:10:00+00:002018-01-23T22:10:00+00:00https://lukemetz.github.io/project-log-matcha-making-robot-arm<hr />
<p><img src="/assets/images/projectlog1/full_bent.jpg" alt="Cross Section" class="smaller-image" /></p>
<p><a href="https://en.wikipedia.org/wiki/Matcha">Matcha</a>, a type of green tea, is great! Sadly, it requires a whole 2 min of <a href="http://http://matchasource.com/how-to-prepare-matcha-green-tea/">preparation</a>; I am far too lazy to make it as much as I would like. Naturally, I am trying to fix this dilemma in the most overcomplicated manner available – a home-built, 6-axis robotic arm! This is the first post in my quest to frothy green tea goodness.</p>
<p>Really though, this project is actually an excuse to learn more about robotics and to get back into making things. It’s clear that robotics is going to play a more prominent role in our lives (e.g. self-driving cars). My day job exposes me to the control side of things (ML / AI), but I know almost nothing about the rest of it (mechanical, electrical, systems, and so on). I don’t really know what I am doing in any of this, so please leave comments or questions on anything; I appreciate your input!
As for this first post, I plan to go over some of the mechanical work that I have already completed.</p>
<h3 id="overall-design-constraints">Overall design constraints:</h3>
<p>I want to strike a balance between specialized hardware and general purpose utility. Specialized hardware is easier to design and is cheaper, but less general. Cooking equipment has thus far lived in this realm – things like blenders, toasters, and bread makers. I plan to be someplace in the middle – a general base, 6-axis arm, with a specialized environment, tools, and controls.</p>
<p>I don’t have access to a shop, so it must be constructed in my apartment. My main tools are CAD (OnShape), a 3D printer (Lulzbot Taz 6) and an Amazon Prime account. I also like to keep costs down for obvious reasons.</p>
<h3 id="base--axis-1">Base – axis 1:</h3>
<p>The first axis needs to be quite strong – capable of high radial and axial loads – to handle the weight of the arm pushing down and the lever effect of the arm when extended. I settled on using a <a href="https://www.amazon.com/gp/product/B002BBOHMI/ref=oh_aui_search_detailpage?ie=UTF8&psc=1">thrust bearing</a>. Others have used these for arms, so this is nothing new. These bearings are capable of handling huge (100s of kg) axial force. I can leverage this, and help address the non-axial loads, by sandwiching my axle (an M8 threaded rod) between two rigid pieces of plastic.</p>
<p><img src="/assets/images/projectlog1/axis1_cross.png" alt="Cross Section" /></p>
<figcaption class="caption">Cross section showing thrust bearing. A 3D printed disk, attached to the threaded rod via 2 nuts, is sandwiched between 2 thrust bearings.</figcaption>
<p>To drive this, I secured a <a href="https://www.amazon.com/gp/product/B01N0XZOZU/ref=oh_aui_search_detailpage?ie=UTF8&psc=1">20KG servo</a> geared 2:1 to the 3D-printed base, and added a few more <a href="https://www.amazon.com/gp/product/B07211VH78/ref=oh_aui_search_detailpage?ie=UTF8&psc=1">skateboard bearings</a> around it to prevent off-axis movement. The resulting structure is reasonably strong, but there is still some play in the 3D printed parts. Good enough for now!</p>
<div class="side-by-side">
<div class="toleft">
<img class="image" src="https://lukemetz.github.io/assets/images/projectlog1/axis1.png" alt="Axis 1" />
</div>
<div class="toright">
<img class="image" src="https://lukemetz.github.io/assets/images/projectlog1/axis1.jpg" alt="Axis 1" />
</div>
</div>
<h3 id="axes-2---4">Axes 2 - 4:</h3>
<p>My main consideration was what to do with motors. I knew I didn’t want to deal with any of the more exotic actuation types such as pneumatic. In the interest of keeping things simple, I decided to place motors as close to the driving axis as possible and go with simple gearing as opposed to more complex cable / belt driven solutions. While researching I was astonished at the complexity of industrial arms. In particular the geek group has a few different teardown-like videos of two rather large KUKA arms that I <a href="https://www.youtube.com/watch?v=6YiPrytt_Ss">would</a> <a href="https://www.youtube.com/watch?v=EfmjhfN8D-Q">recommend</a>.
I tried to keep my pieces as small and lightweight as possible both to lower print times and to keep the weight down. I once again went with the same servo form factor for motors mostly due to ease of use. I used the same 20kg motor, but expect to need to swap out the first axis with a higher torque version. The axes are M8 threaded rod, held via friction to some 3D printed parts.</p>
<div class="side-by-side">
<div class="toleft">
<img class="image" src="https://lukemetz.github.io/assets/images/projectlog1/axis2.png" alt="Axis 2" />
</div>
<div class="toright">
<img class="image" src="https://lukemetz.github.io/assets/images/projectlog1/axis2.jpg" alt="Axis 2" />
</div>
</div>
<h3 id="axis-5">Axis 5:</h3>
<p>Not much design thought went into this. It needs to rotate, but doesn’t need to carry as much axial load, so no thrust bearing is needed. I used another <a href="https://www.amazon.com/gp/product/B01N0XZOZU/ref=oh_aui_search_detailpage?ie=UTF8&psc=1">20KG servo</a>, once again geared down to rotate a threaded rod. Skateboard bearings are used above and below to keep the axle in line.</p>
<div class="side-by-side">
<div class="toleft">
<img class="image" src="https://lukemetz.github.io/assets/images/projectlog1/axis5.png" alt="Axis 5" />
</div>
<div class="toright">
<img class="image" src="https://lukemetz.github.io/assets/images/projectlog1/axis5.jpg" alt="Axis 5" />
</div>
</div>
<h3 id="future-posts">Future Posts:</h3>
<p>I hope to update this blog periodically as the project progresses with interesting tidbits I find along the way.
As for next steps, gripper design is currently still up in the air. Additionally, I am in the process of putting together sensors and controls. My first pass taught me that this is going to be MUCH harder than I originally thought…. Hopefully more to come soon!</p>
<p><img src="/assets/images/projectlog1/full_upright.jpg" alt="Cross Section" class="much-smaller-image" /></p>lukemetzCardboard Quadcopter2012-12-12T22:10:00+00:002012-12-12T22:10:00+00:00https://lukemetz.github.io/cardboard-quadcopter<p><img src="/assets/images/cardboard_quad.jpg" alt="quadcopter" /></p>
<hr />
<p>A fun school project I did a while back. A small group of us built a
laser cut cardboard quadcopter capable of autonomous flight via an
onboard raspberry pi.</p>
<p>A full build process can be found at <a href="http://www.instructables.com/id/Autonomous-Cardboard-Rasberry-Pi-Controlled-Quad/">Instructables</a>.</p>lukemetz