Browse Source

snake achieve cifar as well as relu

Emilien
emilien 4 years ago
parent
commit
b2b9b96282
1 changed files with 4 additions and 3 deletions
  1. 4
    3
      code/resnet18/resnet18_snake.py

+ 4
- 3
code/resnet18/resnet18_snake.py View File

@@ -33,7 +33,8 @@ def snake(x):
## Chargement et normalisation des données
resnet18 = tf.keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = resnet18.load_data()

train_images = train_images.astype('float32')
test_images = test_images.astype('float32')

from sklearn.model_selection import train_test_split
train_images, val_images, train_labels, val_labels = train_test_split(train_images,train_labels, test_size = 0.2,shuffle = True)
@@ -98,8 +99,8 @@ model.compile(sgd, loss='categorical_crossentropy', metrics=['accuracy'])

history = model.fit(train_images,
train_labels,
batch_size=64,
epochs=4,
batch_size=256,
epochs=50,
validation_data=(val_images, val_labels),
)


Loading…
Cancel
Save