| @@ -1,6 +1,7 @@ | |||
| # Réseau inspiré de http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf | |||
| from keras.callbacks import History | |||
| from resnet18 import ResNet18 | |||
| import tensorflow as tf | |||
| import numpy as np | |||
| @@ -12,7 +13,7 @@ import os | |||
| def displayConvFilers(model, layer_name, num_filter=4, filter_size=(3,3), num_channel=0, fig_size=(2,2)): | |||
| layer_dict = dict([(layer.name, layer) for layer in mnist_model.layers]) | |||
| layer_dict = dict([(layer.name, layer) for layer in model.layers]) | |||
| weight, biais = layer_dict[layer_name].get_weights() | |||
| print(weight.shape) | |||
| @@ -34,23 +35,48 @@ def snake(x): | |||
| ## Chargement et normalisation des données | |||
| resnet18 = tf.keras.datasets.cifar10 | |||
| (train_images, train_labels), (test_images, test_labels) = resnet18.load_data() | |||
| val_images = train_images[40000:] | |||
| val_labels = train_labels[40000:] | |||
| train_images = train_images[:40000] | |||
| train_labels = train_labels[:40000] | |||
| train_images = train_images / 255.0 | |||
| val_images = val_images /255.0 | |||
| test_images = test_images / 255.0 | |||
| # POUR LES CNN : On rajoute une dimension pour spécifier qu'il s'agit d'imgages en NdG | |||
| train_images = train_images.reshape(50000,32,32,3) | |||
| train_images = train_images.reshape(40000,32,32,3) | |||
| val_images = val_images.reshape(10000,32,32,3) | |||
| test_images = test_images.reshape(10000,32,32,3) | |||
| # One hot encoding | |||
| train_labels = tf.keras.utils.to_categorical(train_labels) | |||
| val_labels = tf.keras.utils.to_categorical(val_labels) | |||
| test_labels = tf.keras.utils.to_categorical(test_labels) | |||
| filter_size_conv1 = (3,3) | |||
| model = ResNet18(10) | |||
| model.build(input_shape = (None,32,32,3)) | |||
| filter_size_conv1 = (5,5) | |||
| ''' | |||
| ## Définition de l'architecture du modèle | |||
| model = tf.keras.models.Sequential() | |||
| # Expliquez à quoi correspondent les valeurs numériques qui définissent les couches du réseau | |||
| model.add(tf.keras.layers.Conv2D(filters=6,kernel_size=filter_size_conv1,padding="same", activation=snake, input_shape=(32, 32, 3))) | |||
| model.add(tf.keras.layers.AveragePooling2D()) | |||
| model.add(tf.keras.layers.Conv2D(filters=16,kernel_size=(5,5),padding="valid", activation=snake)) | |||
| model.add(tf.keras.layers.AveragePooling2D()) | |||
| model.add(tf.keras.layers.Flatten()) | |||
| model.add(tf.keras.layers.Dense(120 , activation=snake)) | |||
| #mnist_model.add(tf.keras.layers.Dropout(0.5)) | |||
| model.add(tf.keras.layers.Dense(84 , activation=snake)) | |||
| model.add(tf.keras.layers.Dense(10 , activation='softmax')) | |||
| ''' | |||
| # expliquer le nombre de paramètre de ce réseau | |||
| print(model.summary()) | |||
| @@ -61,26 +87,33 @@ sgd = tf.keras.optimizers.Adam() | |||
| model.compile(sgd, loss='categorical_crossentropy', metrics=['accuracy']) | |||
| # On visualise avant l'entrainement | |||
| ''' | |||
| displayConvFilers(mnist_model, 'conv2d', | |||
| num_filter=6, | |||
| filter_size=filter_size_conv1, | |||
| num_channel=0, | |||
| fig_size=(2,3) | |||
| ) | |||
| ''' | |||
| model.fit(train_images, | |||
| history = model.fit(train_images, | |||
| train_labels, | |||
| batch_size=64, | |||
| epochs=4 | |||
| epochs=4, | |||
| validation_data=(val_images, val_labels), | |||
| ) | |||
| ## Evaluation du modèle | |||
| test_loss, test_acc = model.evaluate(test_images, test_labels) | |||
| print('Test accuracy:', test_acc) | |||
| fig, axs = plt.subplots(2, 1, figsize=(15,15)) | |||
| axs[0].plot(history.history['loss']) | |||
| axs[0].plot(history.history['val_loss']) | |||
| axs[0].title.set_text('Training Loss vs Validation Loss') | |||
| axs[0].legend(['Train', 'Val']) | |||
| axs[1].plot(history.history['accuracy']) | |||
| axs[1].plot(history.history['val_accuracy']) | |||
| axs[1].title.set_text('Training Accuracy vs Validation Accuracy') | |||
| axs[1].legend(['Train', 'Val']) | |||
| plt.savefig('./resnet18snake.png') | |||
| ''' | |||
| displayConvFilers(mnist_model, | |||
| @@ -92,11 +125,3 @@ displayConvFilers(mnist_model, | |||
| ) | |||
| ''' | |||
| ''' | |||
| displayConvFilers(mnist_model, 'conv2d_1', | |||
| num_filter=16, | |||
| filter_size=(5,5), | |||
| num_channel=1, | |||
| fig_size=(4,4) | |||
| ) | |||
| ''' | |||