If you're even vaguely familiar with the world of AI, you'll know of Andrej Karpathy: an AI researcher, co-founder of OpenAI, and director of AI at Tesla for a number of years. Karpathy has taken a step back from driving the charge recently, focusing more of his attention on Eureka Labs, an education company he has started.
To that latter endeavour, Karpathy is brilliant at distilling complex topics into digestible formats for those without strong research backgrounds. Since 2023 he has uploaded numerous YouTube videos on neural networks (NNs) and LLMs, all worth watching for anyone curious about these topics.
In early February, he announced microgpt: a GitHub Gist of only 200 lines of Python (and not the mangled kind that is hard to read) that trains a GPT-like network. It's a great testament to Karpathy's foundational knowledge, and extremely welcoming to the uninitiated, offering a simple entry point into a world of much greater complexity.
I've spent a number of hours since its release familiarising myself with the code and diving deeper into the concepts. Pulling back the layers is interesting and satisfying, especially realising how simple the underlying mathematics is. I am no AI researcher, nor someone with a particularly maths-heavy background. I did my fair share of calculus at school, at university, and continually where the need arises during work or free time, and to me the experience has been great.
I encourage anyone with an interest in this domain to explore it further, along with the supporting content. Whilst I think the content is very approachable, some parts do assume prior knowledge or can feel daunting. As such, I wanted to use this entry to pose and answer some key points: either questions I found myself asking, or questions I'd expect learners to ask (or that would be good to ask) as they work through the content. They are not in any particular order, and may not make complete sense in isolation, but I'd encourage you to have a cursory glance now and return frequently as you review the content!
- why we need a loss function
- why gradients matter
- what a derivative is
- data inputs vs weights
- why LLMs are so 'good'
"What is a neural network trying to do?"
Conceptually, a neural network is attempting to make a prediction of some output given some set of inputs. What we want that output to be obviously depends on the context, essentially what we desire for the network to be good at predicting.
You can imagine, even if we don't know what happens within the network, that providing some inputs that are very similar to what we want the network to output would be beneficial. But why would we think that is the case?
Since we know what data we are pushing in, it is possible for us to know what we might expect out. With such knowledge, we can inform the network when it produces an output (using some method we don't yet know) how 'far' off it was. This is the loss function. Generally, this isn't a yes or no question, since we're predicting numbers it is possible to measure how far off the prediction was. The further away, the greater the loss.
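As a concrete sketch of this idea (the function name and numbers here are my own for illustration, not taken from microgpt), a common choice is mean squared error, where larger mistakes produce a larger loss:

```python
# A minimal loss function sketch: mean squared error.
# Names and numbers are illustrative, not taken from microgpt.

def mse_loss(predictions, targets):
    # Average of the squared differences: the further off, the larger the loss.
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)

# A close prediction yields a small loss...
print(mse_loss([2.9, 1.1], [3.0, 1.0]))  # 0.01
# ...while a far-off prediction yields a large one.
print(mse_loss([0.0, 5.0], [3.0, 1.0]))  # 12.5
```

Squaring the difference also means the loss is always positive, so being too high and being too low are both penalised.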
When it is said that a NN is being 'trained', this just means the network is given many different inputs and makes predictions, with some measure of how far off each prediction was from the actual answer. Armed with that heuristic, it is possible to home in on more accurate predictions.
"What is a derivative?"
Before even looking at NNs, just know that they are powered by derivatives. We're straight into calculus, but not the scary kind!
A derivative just answers: "If I adjust the input to some function, how does the output change?"
y = 3x
x is our input and y is our output. If x changes by 1 then y changes by 3, because of the multiplication in this example. So the derivative is 3. It's how sensitive the output is to a change in the input.
If we were to plot y = 3x on a chart, it would be a straight line. Intuitively, this tells us that the change in the output is predictable, and hence the gradient (rate of change) will be too. Of course, not every function we plot will be a straight line, so calculating its sensitivity to change is a little more difficult, but not impossible; there are many rules to help with the calculations.
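We can check the derivative of y = 3x numerically: nudge the input slightly and see how much the output moves. This is a purely illustrative sketch, not code from microgpt:

```python
# Numerically estimating the derivative of y = 3x:
# nudge x by a tiny amount h and measure how much y moves.

def f(x):
    return 3 * x

h = 1e-6   # a tiny nudge
x = 5.0
derivative = (f(x + h) - f(x)) / h
print(derivative)  # ~3.0, regardless of x, since the line is straight
```

For a straight line the answer is the same everywhere; for a curve, this same nudge-and-measure trick gives a different answer at each point.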
"How do derivatives relate to loss?"
These NNs produce their output (prediction) from the inputs. The network is a collection of interconnected nodes (neurons), with the inputs represented as nodes at the start of the network and the output node at the end. Each node will:
- Take the inputs (either the raw data or if deeper in the network the output from a prior node)
- Multiply them by its weights
- Add a bias
- Put the result through an activation function
- Pass the result to subsequent nodes
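The steps above can be sketched for a single node in a few lines. I've picked tanh as the activation function here for illustration; treat that specific choice as an assumption rather than microgpt's exact recipe:

```python
import math

# One neuron, following the steps above: weighted sum of inputs,
# plus a bias, passed through an activation function.
# The tanh activation is an illustrative choice, not microgpt's exact recipe.

def neuron(inputs, weights, bias):
    # Multiply each input by its weight and sum them, then add the bias...
    z = sum(w * x for w, x in zip(weights, inputs)) + bias
    # ...and squash the result through the activation function.
    return math.tanh(z)

out = neuron([1.0, -2.0], [0.5, 0.25], 0.1)
print(out)  # a single number, passed on to nodes in the next layer
```

A full layer is just many of these neurons reading the same inputs, each with its own weights and bias.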
Since we have the loss once a final output is determined, we know how far off the prediction was. As we also know what function we just ran our data through, we can calculate the derivative. However, going beyond our simple example, our input(s) passed through many functions before producing a result. These can be represented as nested functions:
```
L = f(g(h(x)))   # L = loss
                 # f = loss function
                 # g(), h() = nodes
```
That's where the chain rule applies: it lets us calculate the derivative at each 'step' of the overall function and multiply them together. Combining all of these across the network gives us the gradient, which, exactly like the derivative, tells us the direction in which the loss increases, so that we can move in the opposite direction to get closer to the actual answer.
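Here's a hand-worked sketch of the chain rule on a nesting like L = f(g(h(x))). The functions are made up for illustration; the point is that the overall derivative is the product of the local ones:

```python
# Chain rule sketch with simple, made-up functions:
#   h(x) = 2x      -> dh/dx = 2
#   g(u) = u + 1   -> dg/du = 1
#   f(v) = v**2    -> df/dv = 2v
# dL/dx = df/dv * dg/du * dh/dx

x = 3.0
h_out = 2 * x        # 6.0
g_out = h_out + 1    # 7.0
L = g_out ** 2       # 49.0

# Multiply the local derivatives together (the chain rule):
dL_dx = (2 * g_out) * 1 * 2
print(dL_dx)  # 28.0

# Numerical check: nudge x slightly and watch how L moves.
eps = 1e-6
L_nudged = (2 * (x + eps) + 1) ** 2
print((L_nudged - L) / eps)  # ~28.0
```

Backpropagation is exactly this bookkeeping done automatically, step by step, backwards through every node in the network.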
"Why do Neural networks have multiple inputs?"
You will rarely see a NN with a single input except in the most basic cases. Because NNs are a series of interconnected nodes, even two- or three-layer networks with a handful of inputs quickly explode in their number of edges (remember that a single node can take input from multiple other nodes) and, subsequently, in the complexity of the calculations.
It isn't immediately obvious without an example why a NN would accept multiple inputs at a single point. A great example of this was a Multi-Layer Perceptron (MLP) I created at university. The aim of the project was: given a dataset on penguins containing species, island, height, and sex, can you predict from the raw data which species a penguin might be? You can probably see where this is going, but by passing multiple data points about each penguin to a network, it can learn richer relationships in the data that allow for more accurate prediction. Two species might have very similar heights but differ in weight, and a NN trained on this nuance is likely to perform better.
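To make that concrete, here's a toy sketch with invented features and hand-picked weights (not the real penguin dataset or my actual university model): two penguins share the same height, so height alone can't separate them, but a weighted combination of height and weight can:

```python
import math

# Toy sketch: two penguin features (height, weight), invented numbers.
# A weighted combination separates the classes where height alone cannot.

def score(features, weights, bias):
    z = sum(w * f for w, f in zip(weights, features)) + bias
    return 1 / (1 + math.exp(-z))  # sigmoid: squashes to a 0..1 'probability'

weights = [0.2, -1.5]  # learned during training in practice; hand-picked here
bias = -4.0

# Two penguins with identical height but different weight:
print(score([50.0, 3.0], weights, bias))  # high -> leans towards one species
print(score([50.0, 8.0], weights, bias))  # low  -> leans towards the other
```

With only the height feature (50.0 in both cases), the two penguins would be indistinguishable; adding the second input is what makes the separation possible.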
"Why are LLMs so 'good'?"
This is really the big question in technology at the current time, and one to which nobody, myself included, really has a definitive answer. If you asked any sane person whether a system that can predict the next token in a sequence could write functional code, solve puzzles, and exhibit some form of reasoning, they'd say no.
From a bird's-eye view it doesn't make sense, but as has become apparent, successfully crafting a multi-billion-parameter network and feeding it the corpus of available human text allows it to learn patterns. And hidden within those patterns is logic: logic that appears as intelligence because it reflects back exactly what we have learnt and applied hundreds and thousands of times.
Fundamentally, the LLM is doing what any human would do: given some input and what I already know, what do I need to do next? Give a human something foreign to them and they can be just as flabbergasted as an LLM. They'll also hallucinate, sometimes helpfully and sometimes not. This holds right up to the current edge: give a complex multi-step task to an LLM and watch it begin to fall apart, losing track of what was done or misremembering details. These are tasks a human would perform much better on.
Is that a computational limit? Or perhaps something cognitive that a network can't represent? I don't know, but I'm excited to find out.