Introduction to Gradient Descent: A Beginner's Guide
Gradient Descent is one of the most fundamental algorithms in machine learning and deep learning. Whether you're training a linear regression model or a deep neural network, gradient descent helps you minimize the error. In this post, we'll explore what gradient descent is, how it works, and how to implement it in Python.
What is Gradient Descent?
At its core, Gradient Descent is an optimization algorithm used to minimize the loss function of a model. The loss function measures how far off your predictions are from the actual results. Gradient Descent updates the parameters (weights and biases) of your model to reduce this loss.
You can think of it like hiking down a hill: the goal is to reach the lowest point (where the loss is smallest), and each step goes downhill in the direction of steepest descent, i.e., opposite to the gradient.
The Math Behind Gradient Descent
The algorithm follows this general update rule:
θ = θ - α * ∇J(θ)
Where:
- θ: Parameters (weights) of the model
- α: Learning rate (how big your steps are)
- ∇J(θ): Gradient of the loss function with respect to the parameters
The gradient points in the direction of steepest increase of the loss, so we step in the opposite direction (hence the minus sign) to reduce it.
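To make the update rule concrete, take the quadratic loss J(θ) = θ^2 used in the example below, whose gradient is ∇J(θ) = 2θ. Starting at θ = 10 with a learning rate of α = 0.1, one update gives:

θ = 10 - 0.1 * (2 * 10) = 8

Each step multiplies θ by 0.8, so the parameter shrinks steadily toward the minimum at θ = 0.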
A Simple Example in Python
Let’s walk through a simple example: minimizing the quadratic function y = x^2.
Step 1: Import Libraries
import numpy as np
import matplotlib.pyplot as plt
Step 2: Define the Function and Its Derivative
def function(x):
    return x**2  # f(x) = x^2

def gradient(x):
    return 2*x  # df/dx = 2x
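A quick way to sanity-check a hand-derived gradient is a central finite-difference approximation. This small helper is just an illustration, not part of the original walkthrough:

def numerical_gradient(f, x, h=1e-6):
    # Central-difference approximation of df/dx
    return (f(x + h) - f(x - h)) / (2 * h)

print(gradient(3.0), numerical_gradient(function, 3.0))  # both should print ≈ 6.0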
Step 3: Implement Gradient Descent
def gradient_descent(starting_point, learning_rate, n_iterations):
    x = starting_point
    history = [x]  # record every point visited so we can plot the path later
    for _ in range(n_iterations):
        grad = gradient(x)            # direction of steepest ascent at x
        x = x - learning_rate * grad  # step in the opposite direction
        history.append(x)
    return x, history
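One common refinement, not used in the rest of this walkthrough, is to stop early once the gradient becomes negligibly small. Here is a minimal sketch of such a variant, reusing the gradient function from Step 2 (the name gradient_descent_with_tolerance and the tol value are illustrative):

def gradient_descent_with_tolerance(starting_point, learning_rate, n_iterations, tol=1e-8):
    x = starting_point
    history = [x]
    for _ in range(n_iterations):
        grad = gradient(x)
        if abs(grad) < tol:  # gradient is effectively zero: stop early
            break
        x = x - learning_rate * grad
        history.append(x)
    return x, history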
Step 4: Run and Visualize
final_x, trajectory = gradient_descent(starting_point=10, learning_rate=0.1, n_iterations=50)
# Plotting
x_vals = np.linspace(-10, 10, 400)
y_vals = function(x_vals)
plt.plot(x_vals, y_vals, label='y = x^2')
plt.scatter(trajectory, [function(x) for x in trajectory], color='red', s=10, label='Gradient Descent Path')
plt.title("Gradient Descent Minimizing x^2")
plt.xlabel("x")
plt.ylabel("y")
plt.legend()
plt.grid(True)
plt.show()
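Because each update multiplies x by (1 - 2 * 0.1) = 0.8, after 50 iterations the result is roughly 10 * 0.8^50 ≈ 0.00014, very close to the true minimum at x = 0. A quick print confirms it:

print(f"Final x after 50 iterations: {final_x:.6f}")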
A Few Things to Note
- Learning Rate: If it’s too small, convergence is slow; if it’s too large, the updates can overshoot the minimum or diverge (illustrated in the sketch after this list).
- Local Minima: For complex, non-convex functions, gradient descent can get stuck in a local minimum rather than the global one.
- Saddle Points: Points where the gradient is zero but which are not minima; gradient descent can stall near them.
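To see the learning-rate point concretely, here is a small sketch that reuses the gradient_descent function from Step 3 on y = x^2 (the specific rates are only illustrative):

# Compare a too-small, a reasonable, and a too-large learning rate on y = x^2.
# With learning_rate = 1.1 each step multiplies x by (1 - 2 * 1.1) = -1.2,
# so the iterates oscillate with growing magnitude instead of converging.
for lr in [0.01, 0.1, 1.1]:
    final_x, _ = gradient_descent(starting_point=10, learning_rate=lr, n_iterations=50)
    print(f"learning_rate={lr}: final x = {final_x:.4g}")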
There are several variants of gradient descent used in practice:
- Batch Gradient Descent: Computes the gradient using the entire dataset.
- Stochastic Gradient Descent (SGD): Updates parameters for each training example.
- Mini-Batch Gradient Descent: Uses small batches (e.g., 32 or 64 examples) for updates; see the sketch after this list.
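As a rough illustration of the mini-batch variant, here is a minimal sketch that fits a simple linear regression with a mean-squared-error loss. The function name mini_batch_gd and the synthetic data are purely illustrative, not from the original post:

import numpy as np

def mini_batch_gd(X, y, learning_rate=0.1, n_epochs=100, batch_size=32):
    n_samples, n_features = X.shape
    w = np.zeros(n_features)  # weights
    b = 0.0                   # bias
    for _ in range(n_epochs):
        indices = np.random.permutation(n_samples)  # reshuffle every epoch
        for start in range(0, n_samples, batch_size):
            batch = indices[start:start + batch_size]
            Xb, yb = X[batch], y[batch]
            error = Xb @ w + b - yb                 # predictions minus targets
            grad_w = 2 * Xb.T @ error / len(batch)  # MSE gradient w.r.t. w
            grad_b = 2 * error.mean()               # MSE gradient w.r.t. b
            w = w - learning_rate * grad_w
            b = b - learning_rate * grad_b
    return w, b

# Illustrative synthetic data: y ≈ 3x + 2 with a little noise
X = np.random.rand(200, 1)
y = 3 * X[:, 0] + 2 + 0.1 * np.random.randn(200)
w, b = mini_batch_gd(X, y)
print(w, b)  # should land close to [3.] and 2.0

Setting batch_size to the full dataset recovers batch gradient descent, while batch_size=1 gives plain SGD, so this one parameter interpolates between the two extremes.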
Final Thoughts
Gradient Descent is the workhorse behind training most machine learning models. Understanding how it works is essential for debugging, tuning, and improving your models.