{ "cells": [ { "cell_type": "markdown", "id": "5cd10ad1", "metadata": {}, "source": [ "# 01: Linear Regression\n", "\n", "In this notebook we'll cover:\n", "\n", "- Linear regression\n", "- Loss functions\n", "- Gradient descent\n", "- Linear regression as a neural network\n", "- PyTorch\n", "- Start to consider how neural networks fit non-linear functions" ] }, { "cell_type": "code", "execution_count": 2, "id": "7dd45616", "metadata": {}, "outputs": [], "source": [ "import itertools\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "import torch\n", "from torch import nn" ] }, { "cell_type": "markdown", "id": "4c05da31", "metadata": {}, "source": [ "## Dataset\n", "\n", "I have painstakingly collected 100 days of data on the average happiness (out of 10) of Turing staff in the office based on the number of oranges 🍊 and cups of coffee ☕️ available at the start of the day:" ] }, { "cell_type": "code", "execution_count": 3, "id": "633c934e", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
|  | oranges | coffee | happy |
|---|---|---|---|
| 0 | 10 | 528 | 5.832712 |
| 1 | 12 | 509 | 4.648788 |
| 2 | 15 | 628 | 4.286937 |
| 3 | 14 | 389 | 5.863064 |
| 4 | 13 | 216 | 3.109686 |
| ... | ... | ... | ... |
| 95 | 11 | 293 | 2.525590 |
| 96 | 15 | 750 | 6.802309 |
| 97 | 17 | 701 | 6.446739 |
| 98 | 11 | 832 | 7.072423 |
| 99 | 19 | 355 | 4.444550 |

100 rows × 3 columns

|  | oranges | coffee | happy |
|---|---|---|---|
| count | 1.000000e+02 | 1.000000e+02 | 1.000000e+02 |
| mean | -6.661338e-17 | 6.217249e-17 | -5.817569e-16 |
| std | 1.000000e+00 | 1.000000e+00 | 1.000000e+00 |
| min | -3.094917e+00 | -2.163980e+00 | -2.121838e+00 |
| 25% | -6.171814e-01 | -6.119499e-01 | -6.178392e-01 |
| 50% | 5.856466e-02 | -1.085261e-01 | 8.677973e-02 |
| 75% | 7.343107e-01 | 6.999679e-01 | 6.344933e-01 |
| max | 2.311051e+00 | 2.550109e+00 | 2.220604e+00 |
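The summary statistics (count, mean, std, quartiles, min, max) are consistent with each column having been z-scored: subtract the column mean, then divide by the column standard deviation. The means are zero up to floating-point error and every std is 1. The notebook's own normalization cell is not part of this excerpt, so the sketch below is only an illustration under that assumption, and its names (`sample`, `standardized`, `pop_standardized`) are hypothetical. It uses just the ten rows visible in the data table, since the other 90 are elided, so its statistics will not match the summary. It also shows why an std of exactly 1.000000 points to the sample standard deviation: `DataFrame.describe()` always reports the ddof=1 std, so columns scaled by the population std (ddof=0, which scikit-learn's `StandardScaler` uses) would show `sqrt(n / (n - 1))` instead, about 1.005 for n = 100.

```python
import numpy as np
import pandas as pd

# Only the ten rows visible in the data table (rows 0-4 and 95-99);
# the remaining 90 rows are elided there, so this is an illustrative sample.
sample = pd.DataFrame(
    {
        'oranges': [10, 12, 15, 14, 13, 11, 15, 17, 11, 19],
        'coffee': [528, 509, 628, 389, 216, 293, 750, 701, 832, 355],
        'happy': [5.832712, 4.648788, 4.286937, 5.863064, 3.109686,
                  2.525590, 6.802309, 6.446739, 7.072423, 4.444550],
    },
    index=[0, 1, 2, 3, 4, 95, 96, 97, 98, 99],
)

# z-score each column: subtract its mean, divide by its standard deviation.
# DataFrame.std defaults to ddof=1 (the sample standard deviation).
standardized = (sample - sample.mean()) / sample.std()

# describe() on the standardized frame: means ~0, stds exactly 1.
summary = standardized.describe()
print(summary.loc[['mean', 'std']])

# Scaling by the population std (ddof=0) instead. describe() and .std()
# still use ddof=1, so each column's reported std becomes sqrt(n / (n - 1)).
n = len(sample)
pop_standardized = (sample - sample.mean()) / sample.std(ddof=0)
print(pop_standardized.std().round(6))
print(np.sqrt(n / (n - 1)))
```

With all 100 rows the same calculation gives `sqrt(100 / 99)`, roughly 1.005, which is why the exact 1.000000 in the summary suggests pandas' default ddof=1 was used.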