185 lines
4.4 KiB
Plaintext
185 lines
4.4 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Notebook d'entraînement d'une IA à résoudre le XOR\n",
|
|
"[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/git/http%3A%2F%2Fgitea.louisgallet.fr%2FCours-particulier%2FNotebooks/master?urlpath=%2Fdoc%2Ftree%2FNotebook_IA_Training.ipynb)\n",
|
|
"\n",
|
|
"La table de vérité du XOR est la suivante :\n",
|
|
"> x: entrée 1 ; y: entrée 2 ; r: sortie\n",
|
|
"\n",
|
|
"x = 0 ; y = 0 ; r = 0  \n",
|
|
"x = 1 ; y = 0 ; r = 1  \n",
|
|
"x = 0 ; y = 1 ; r = 1  \n",
|
|
"x = 1 ; y = 1 ; r = 0 "
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "xwVyblbIKSYS",
|
|
"trusted": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"!pip install numpy\n",
|
|
"!pip install matplotlib\n",
|
|
"import numpy as np\n",
|
|
"import matplotlib.pyplot as plt"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "wgVHhV0zKeOV",
|
|
"trusted": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"def sigmoid(x):\n",
|
|
" return 1 / (1 + np.exp(-x))\n",
|
|
"\n",
|
|
"def sigmoid_derivative(x):\n",
|
|
" return x * (1 - x)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "Bv03so9jKhZH",
|
|
"trusted": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Entrées et sorties du XOR\n",
|
|
"X = np.array([[0,0],[0,1],[1,0],[1,1]])\n",
|
|
"y = np.array([[0],[1],[1],[0]])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "-utexs_vKjws",
|
|
"trusted": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"np.random.seed(42)\n",
|
|
"\n",
|
|
"# 2 neurones en entrée, 4 neurones cachés, 1 neurone en sortie\n",
|
|
"w1 = np.random.randn(2, 4)\n",
|
|
"b1 = np.zeros((1, 4))\n",
|
|
"\n",
|
|
"w2 = np.random.randn(4, 1)\n",
|
|
"b2 = np.zeros((1, 1))\n",
|
|
"\n",
|
|
"# Pour enregistrer la perte au fil des époques\n",
|
|
"loss_history = []"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"id": "AGlpWtKdKmcZ",
|
|
"trusted": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"learning_rate = 0.1\n",
|
|
"epochs = 10000\n",
|
|
"\n",
|
|
"for epoch in range(epochs):\n",
|
|
" # Forward pass\n",
|
|
" z1 = np.dot(X, w1) + b1\n",
|
|
" a1 = sigmoid(z1)\n",
|
|
"\n",
|
|
" z2 = np.dot(a1, w2) + b2\n",
|
|
" a2 = sigmoid(z2)\n",
|
|
"\n",
|
|
" # Calcul de l'erreur\n",
|
|
" loss = np.mean((y - a2) ** 2)\n",
|
|
" loss_history.append(loss)\n",
|
|
"\n",
|
|
" # Backpropagation\n",
|
|
" d_a2 = (a2 - y)\n",
|
|
" d_z2 = d_a2 * sigmoid_derivative(a2)\n",
|
|
"\n",
|
|
" d_w2 = np.dot(a1.T, d_z2)\n",
|
|
" d_b2 = np.sum(d_z2, axis=0, keepdims=True)\n",
|
|
"\n",
|
|
" d_a1 = np.dot(d_z2, w2.T)\n",
|
|
" d_z1 = d_a1 * sigmoid_derivative(a1)\n",
|
|
"\n",
|
|
" d_w1 = np.dot(X.T, d_z1)\n",
|
|
" d_b1 = np.sum(d_z1, axis=0, keepdims=True)\n",
|
|
"\n",
|
|
" # Mise à jour des poids\n",
|
|
" w2 -= learning_rate * d_w2\n",
|
|
" b2 -= learning_rate * d_b2\n",
|
|
" w1 -= learning_rate * d_w1\n",
|
|
" b1 -= learning_rate * d_b1"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/",
|
|
"height": 474
|
|
},
|
|
"id": "m60PeuYnKoie",
|
|
"outputId": "d5e217cf-a0b9-49a8-a01b-d9575b7f797f",
|
|
"trusted": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"plt.plot(loss_history)\n",
|
|
"plt.title(\"Perte d'erreurs au fil des époques\")\n",
|
|
"plt.xlabel(\"Époques\")\n",
|
|
"plt.ylabel(\"Taux d'erreur moyen\")\n",
|
|
"plt.grid(True)\n",
|
|
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"colab": {
|
|
"base_uri": "https://localhost:8080/"
|
|
},
|
|
"id": "A2sN1_fHKrh8",
|
|
"outputId": "fa906657-126c-4e13-cb6e-c6e7afd9bfd7",
|
|
"trusted": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"print(\"Sortie finale après entraînement :\")\n",
|
|
"a1 = sigmoid(np.dot(X, w1) + b1)\n",
|
|
"a2 = sigmoid(np.dot(a1, w2) + b2)\n",
|
|
"for i in range(4):\n",
|
|
" print(f\"Entrée : {X[i]} → Prédit : {a2[i][0]:.4f} → Arrondi : {round(a2[i][0])}\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"colab": {
|
|
"provenance": []
|
|
},
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"name": "python3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|