| @@ -33,7 +33,8 @@ def snake(x): | |||
| ## Chargement et normalisation des données | |||
| resnet18 = tf.keras.datasets.cifar10 | |||
| (train_images, train_labels), (test_images, test_labels) = resnet18.load_data() | |||
| train_images = train_images.astype('float32') | |||
| test_images = test_images.astype('float32') | |||
| from sklearn.model_selection import train_test_split | |||
| train_images, val_images, train_labels, val_labels = train_test_split(train_images,train_labels, test_size = 0.2,shuffle = True) | |||
| @@ -98,8 +99,8 @@ model.compile(sgd, loss='categorical_crossentropy', metrics=['accuracy']) | |||
| history = model.fit(train_images, | |||
| train_labels, | |||
| batch_size=64, | |||
| epochs=4, | |||
| batch_size=256, | |||
| epochs=50, | |||
| validation_data=(val_images, val_labels), | |||
| ) | |||