Du kannst nicht mehr als 25 Themen auswählen. Themen müssen entweder mit einem Buchstaben oder einer Ziffer beginnen. Sie können Bindestriche („-“) enthalten und bis zu 35 Zeichen lang sein.

training.py 1.8KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. import tensorflow as tf
  2. from matplotlib import pyplot as plt
  3. import numpy as np
  4. from Model import MyModel
  5. LEN_TRAIN = 5000
  6. LEN_SEQ = 100
  7. PRED = 0.02
  8. HIDDEN = 128
  9. model = MyModel(HIDDEN)
  10. dataset = np.load('dataset.npy')
  11. datasetp = np.roll(dataset, 1)
  12. datasetp[0] = dataset[0]
  13. data = (dataset - datasetp)/datasetp
  14. data = data*2/(max(data) - min(data))
  15. annee = np.array(list(range(len(data))))/365
  16. annee = annee - annee[-1]
  17. start_pred = int(len(data)*(1-PRED))
  18. print(len(data))
  19. print(start_pred)
  20. plt.figure(1)
  21. plt.plot(annee[:start_pred],data[:start_pred], label="apprentissage")
  22. plt.plot(annee[start_pred:],data[start_pred:], label="validation")
  23. plt.legend()
  24. plt.show()
  25. X_train = [data[0:start_pred-1]]
  26. Y_train = [data[1:start_pred]]
  27. X_train = np.expand_dims(np.array(X_train),2)
  28. Y_train = np.expand_dims(np.array(Y_train),2)
  29. model.compile(optimizer='adam',
  30. loss='binary_crossentropy',
  31. metrics=['binary_crossentropy'])
  32. import os
  33. os.system("rm -rf log_dir")
  34. model.fit(x=X_train, y=Y_train, epochs=30)
  35. Pred = X_train.copy()
  36. while len(Pred[0]) < len(data) :
  37. print(len(data) - len(Pred[0]))
  38. Pred = np.concatenate((Pred, np.array([[model.predict(Pred)[0][-1]]])),1)
  39. Pred = Pred/2*(max(data) - min(data))
  40. data_Pred = dataset.copy()
  41. Pred = np.squeeze(Pred)
  42. for i in range(start_pred,len(data_Pred)) :
  43. data_Pred[i] = data_Pred[i-1]*Pred[i] + data_Pred[i-1]
  44. plt.figure(2)
  45. plt.plot(annee[:start_pred],dataset[:start_pred], label="apprentissage")
  46. plt.plot(annee[start_pred:],dataset[start_pred:], label="validation")
  47. plt.plot(annee, data_Pred, label="prediction")
  48. plt.legend()
  49. fig, axs = plt.subplots(2)
  50. axs[0].plot(annee[:start_pred],dataset[:start_pred], label="apprentissage")
  51. axs[0].plot(annee[start_pred:],dataset[start_pred:], label="validation")
  52. axs[1].plot(annee, data_Pred, label="prediction")
  53. plt.legend()
  54. plt.show()