Vous ne pouvez pas sélectionner plus de 25 sujets. Les noms de sujets doivent commencer par une lettre ou un nombre, peuvent contenir des tirets ('-') et peuvent comporter jusqu'à 35 caractères.

training.py 2.0KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. import tensorflow as tf
  2. from matplotlib import pyplot as plt
  3. import numpy as np
  4. from Model import MyModel
  5. LEN_SEQ = 64
  6. PRED = 0.05
  7. HIDDEN = 128
  8. model = MyModel(HIDDEN)
  9. dataset = np.load('dataset.npy')
  10. scale = (np.max(dataset) - np.min(dataset))
  11. data = dataset/scale
  12. shift = np.min(data)
  13. data = data - shift
  14. annee = np.array(list(range(len(data))))/365
  15. annee = annee - annee[-1]
  16. start_pred = int(len(data)*(1-PRED))
  17. print(len(data))
  18. print(start_pred)
  19. plt.figure(1)
  20. plt.plot(annee[:start_pred],data[:start_pred], label="apprentissage")
  21. plt.plot(annee[start_pred:],data[start_pred:], label="validation")
  22. plt.legend()
  23. plt.show()
  24. X_train_tot = [data[0:start_pred-1]]
  25. Y_train_tot = [data[1:start_pred]]
  26. X_train_tot = np.expand_dims(np.array(X_train_tot),2)
  27. Y_train_tot = np.expand_dims(np.array(Y_train_tot),2)
  28. X_train = X_train_tot[:,:LEN_SEQ,:]
  29. Y_train = Y_train_tot[:,:LEN_SEQ,:]
  30. for i in range(len(X_train_tot[0]) - LEN_SEQ) :
  31. X_train = np.concatenate((X_train, X_train_tot[:,i:i+LEN_SEQ,:]),0)
  32. Y_train = np.concatenate((Y_train, Y_train_tot[:,i:i+LEN_SEQ,:]),0)
  33. print(X_train_tot.shape)
  34. print(X_train.shape)
  35. model.compile(optimizer='adam',
  36. loss='binary_crossentropy',
  37. metrics=['binary_crossentropy'])
  38. import os
  39. os.system("rm -rf log_dir")
  40. model.fit(x=X_train, y=Y_train, batch_size=16, epochs=5, shuffle=True)
  41. Pred = X_train_tot.copy()
  42. while len(Pred[0]) < len(data) :
  43. print(len(data) - len(Pred[0]))
  44. Pred = np.concatenate((Pred, np.array([[model.predict(Pred)[0][-1]]])),1)
  45. Pred = Pred + shift
  46. Pred = Pred * scale
  47. data_Pred = np.squeeze(Pred)
  48. plt.figure(2)
  49. plt.plot(annee[:start_pred],dataset[:start_pred], label="apprentissage")
  50. plt.plot(annee[start_pred:],dataset[start_pred:], label="validation")
  51. plt.plot(annee, data_Pred, label="prediction")
  52. plt.legend()
  53. fig, axs = plt.subplots(2)
  54. axs[0].plot(annee[:start_pred],dataset[:start_pred], label="apprentissage")
  55. axs[0].plot(annee[start_pred:],dataset[start_pred:], label="validation")
  56. axs[1].plot(annee, data_Pred, label="prediction")
  57. plt.legend()
  58. plt.show()