Generating an overfitting illustration with code

I've often seen illustrations of the over-/under-fitting phenomenon, and I've had to make such illustrations on occasion. Rather than drawing one quick-and-dirty in Inkscape yet again, I decided to code it.

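Each time you re-run the plotting cell, the random generator moves on and you get a slightly different plot, so you will probably want to run it a few times until you get one you like. (The seed set in the first cell only makes a full, top-to-bottom run of the notebook reproducible.)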
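If you want to be able to regenerate the exact plot you end up picking, one option is to give each attempt its own seed and note the one you like. This is my own variation, not what the notebook below does: it assumes NumPy's newer `np.random.default_rng` instead of the legacy `np.random.seed`, and `make_data` is a hypothetical helper that mirrors the data recipe used below.

```python
import numpy as np

def make_data(seed, x_noise=0.2, y_noise=0.3, x_min=-2, x_max=2, n_test=8):
    """Hypothetical helper: same recipe as the notebook, driven by an explicit seed."""
    rng = np.random.default_rng(seed)
    n_train = x_max - x_min + 1
    # Training points: one per integer in [x_min, x_max], jittered along x
    train_x = np.arange(x_min, x_max + 1) + x_noise * rng.normal(size=n_train)
    # Test points: uniform on [x_min - 1, x_max + 1)
    test_x = rng.uniform(x_min - 1, x_max + 1, size=n_test)
    # Targets follow y = x**2 plus Gaussian noise
    train_y = train_x ** 2 + y_noise * rng.normal(size=n_train)
    test_y = test_x ** 2 + y_noise * rng.normal(size=n_test)
    return train_x, train_y, test_x, test_y

# Try a few seeds; whichever one gives the plot you like can be reused later
for seed in range(3):
    train_x, train_y, test_x, test_y = make_data(seed)
    print(seed, np.round(train_x, 2))
```

Calling `make_data` twice with the same seed returns identical arrays, so the seed alone is enough to reproduce a given figure.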

In [5]:
import numpy as np
import matplotlib.pyplot as plt
import warnings

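np.random.seed(1)  # makes a full top-to-bottom run of the notebook reproducible
# polyfit warns when the degree is too high for the number of points.
# RankWarning moved to np.exceptions in NumPy 1.25 and was removed from the
# top-level namespace in NumPy 2.0; this lookup works with both.
RankWarning = getattr(np, "exceptions", np).RankWarning
warnings.simplefilter('ignore', RankWarning)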
%matplotlib inline
In [6]:
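x_noise = 0.2  # jitter on the training x positions
y_noise = 0.3  # noise added to y = x**2
x_min, x_max = -2, 2

# One training point per integer in [x_min, x_max], slightly jittered along x
train_x = [elem + x_noise * np.random.normal() for elem in range(x_min, x_max + 1)]
# Test points drawn uniformly on [x_min - 1, x_max + 1)
test_x = (x_max - x_min + 2) * np.random.rand(8) + x_min - 1
train_y = [elem**2 + y_noise * np.random.normal() for elem in train_x]
test_y = [elem**2 + y_noise * np.random.normal() for elem in test_x]

# Degree len(train_x) + 1 is more than 5 points can determine: the curve goes
# through every training point but is otherwise unconstrained, so it swings
# away from the trend, especially outside the training range (this is what
# triggers the RankWarning silenced above)
p = np.poly1d(np.polyfit(train_x, train_y, 1 + len(train_x)))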

fig, ax = plt.subplots(figsize=(12, 8))

ax.scatter(train_x, train_y, color="blue", edgecolor="k", s=80, alpha=0.5, label="Training set")
ax.scatter(test_x, test_y, color="green", edgecolor="k", s=80, alpha=0.5, label="Testing set")
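x_ = np.linspace(x_min - 1, x_max + 1, 100)
# Underlying function y = x**2, labelled "Best fit"
ax.plot(x_, x_ ** 2, color="blue", alpha=0.5, label="Best fit")
# Polynomial that passes through every training point
ax.plot(x_, p(x_), color="red", alpha=0.5, label="Overfitted model")
# Constant model predicting the mean of the training targets
ax.hlines(np.mean(train_y), x_min - 1, x_max + 1, color="orange", alpha=0.5, label="Underfitted model")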
# Axis names placed in the figure corners rather than with set_xlabel/set_ylabel
plt.figtext(0.1, 0.9, '$y$')
plt.figtext(0.9, 0.1, '$x$')
# Minimalist look: hide the top and right spines, and all ticks and tick labels
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.set_ylim(-1, (x_max + 1) ** 2 + 0.5)
ax.set_xlim(x_min - 1.5, x_max + 1.5)
plt.tick_params(left=False, right=False, labelleft=False,
                labelbottom=False, bottom=False)
ax.legend(loc=0, frameon=False);  # loc=0 means "best"; the ";" hides the return value
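The picture can also be backed up with numbers. The sketch below is my own addition, not part of the original notebook: it regenerates data with the same recipe but with NumPy's `default_rng` (so the values will not match the plot above) and compares the training and test mean squared errors of the three curves. `mse` is a hypothetical helper.

```python
import warnings
import numpy as np

rng = np.random.default_rng(1)
x_noise, y_noise = 0.2, 0.3
x_min, x_max = -2, 2

# Same data recipe as the plotting cell
train_x = np.arange(x_min, x_max + 1) + x_noise * rng.normal(size=x_max - x_min + 1)
test_x = rng.uniform(x_min - 1, x_max + 1, size=8)
train_y = train_x ** 2 + y_noise * rng.normal(size=train_x.size)
test_y = test_x ** 2 + y_noise * rng.normal(size=test_x.size)

def mse(y_true, y_pred):
    """Mean squared error between two sequences of equal length."""
    return float(np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2))

# Degree 6 with only 5 points: polyfit warns that the fit is rank-deficient
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    overfit = np.poly1d(np.polyfit(train_x, train_y, 1 + len(train_x)))
underfit_value = np.mean(train_y)

models = {
    "true function": lambda x: x ** 2,
    "overfitted": overfit,
    "underfitted": lambda x: np.full_like(x, underfit_value),
}
for name, f in models.items():
    print(f"{name:>14}: train MSE = {mse(train_y, f(train_x)):.3f}, "
          f"test MSE = {mse(test_y, f(test_x)):.3f}")
```

You should typically see the overfitted polynomial with a training error of essentially zero but a clearly larger test error, while the constant model does poorly on both.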

And voilà, a nice plot that you can save as a PNG image, or even as an SVG if you want to fiddle with it a little bit more!
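A minimal sketch of the saving step with Matplotlib's `Figure.savefig`. The file names and the `dpi` value are placeholders of my choosing, and the small figure below only stands in for the full one so the snippet runs on its own; in the notebook, you would call `savefig` on the `fig` from the plotting cell.

```python
import matplotlib.pyplot as plt
import numpy as np

# Keep text as real text in the SVG (by default glyphs are converted to paths),
# which makes the labels editable in Inkscape
plt.rcParams["svg.fonttype"] = "none"

# Stand-in figure; in the notebook, reuse `fig` from the plotting cell instead
fig, ax = plt.subplots(figsize=(12, 8))
x_ = np.linspace(-3, 3, 100)
ax.plot(x_, x_ ** 2, color="blue", alpha=0.5, label="Best fit")
ax.legend(loc=0, frameon=False)

# PNG is a raster image; dpi sets its resolution
fig.savefig("overfitting.png", dpi=150, bbox_inches="tight")
# SVG is a vector image that Inkscape can open and edit
fig.savefig("overfitting.svg", bbox_inches="tight")
plt.close(fig)
```

`bbox_inches="tight"` trims the empty margins around the axes; drop it if you prefer to keep the exact `figsize` canvas.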

The code is available as a GitHub gist here.

By @Clément Chastagnol in
Tags: #python, #stats