| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566 |
- import tensorflow as tf
- from matplotlib import pyplot as plt
- import numpy as np
- from Model import MyModel
-
# --- Configuration ------------------------------------------------------------
# NOTE(review): LEN_TRAIN and LEN_SEQ are not referenced anywhere in this
# script; kept only in case something else imports them.
LEN_TRAIN, LEN_SEQ = 5000, 100
# Fraction of the series held out at the end for validation
# (see start_pred below).
PRED = 0.02
# Forwarded to MyModel; presumably a hidden-layer width -- confirm in Model.py.
HIDDEN = 128

model = MyModel(HIDDEN)
-
# --- Data preparation ---------------------------------------------------------
dataset = np.load('dataset.npy')

# Relative change between consecutive samples: (x[t] - x[t-1]) / x[t-1].
# np.roll wraps the last sample to the front, so the first "previous" value is
# overwritten with dataset[0]; this makes data[0] == 0.
datasetp = np.roll(dataset, 1)
datasetp[0] = dataset[0]
data = (dataset - datasetp) / datasetp

# Rescale so the full range of `data` spans exactly 2 (values are NOT
# re-centred, so negative changes stay negative). np.max/np.min reduce over
# the whole array in C; the builtin max/min iterate element by element in
# Python and raise on arrays with more than one dimension.
data_range = np.max(data) - np.min(data)
if data_range == 0:
    raise ValueError("dataset.npy has no variation; cannot rescale relative changes")
data = data * 2 / data_range

# Time axis in years (presumably one sample per day -- confirm), shifted so
# the most recent sample sits at 0.
annee = np.arange(len(data)) / 365
annee = annee - annee[-1]

# Index of the first held-out sample: the last PRED fraction is validation.
start_pred = int(len(data) * (1 - PRED))
print(len(data))
print(start_pred)

# Preview the train / validation split before training.
plt.figure(1)
plt.plot(annee[:start_pred], data[:start_pred], label="apprentissage")
plt.plot(annee[start_pred:], data[start_pred:], label="validation")
plt.legend()
plt.show()
-
# A single training sequence: the input at step t is paired with the target at
# step t + 1. Wrapping in a list adds the leading batch axis, and the trailing
# [:, :, np.newaxis] adds a feature axis (for 1-D data the result is
# (1, start_pred - 1, 1)).
X_train = np.array([data[:start_pred - 1]])[:, :, np.newaxis]
Y_train = np.array([data[1:start_pred]])[:, :, np.newaxis]
-
-
# --- Training -----------------------------------------------------------------
# Regression objective: the targets are rescaled relative changes and take
# negative as well as positive values. Binary cross-entropy assumes targets in
# [0, 1], so it does not fit here; mean squared error does.
# NOTE(review): if MyModel ends in a sigmoid it still cannot emit negative
# values -- check its output activation in Model.py.
model.compile(optimizer='adam',
              loss='mse',
              metrics=['mae'])

# Clear any previous log directory without spawning a shell.
# NOTE(review): nothing in this script writes to log_dir (no callback is passed
# to fit) -- confirm whether this cleanup is still needed.
import shutil
shutil.rmtree("log_dir", ignore_errors=True)

model.fit(x=X_train, y=Y_train, epochs=30)
-
# --- Autoregressive forecast over the validation period -----------------------
# Seed with every known value up to start_pred. Seeding from X_train would drop
# the observed data[start_pred - 1] and replace it with a forecast. Each
# iteration feeds the whole sequence to the model and appends the model's
# output at the last step, until the sequence reaches len(data).
Pred = np.array([data[:start_pred]])[:, :, np.newaxis]
while len(Pred[0]) < len(data):
    print(len(data) - len(Pred[0]))
    next_step = model.predict(Pred)[0][-1]
    Pred = np.concatenate((Pred, np.array([[next_step]])), 1)

# Undo the earlier rescaling. `data` has already been rescaled at this point
# (its range is exactly 2), so `max(data) - min(data)` would turn this step
# into a no-op. Recompute the original range from the raw relative changes.
raw_change = (dataset - datasetp) / datasetp
Pred = Pred / 2 * (np.max(raw_change) - np.min(raw_change))
Pred = np.squeeze(Pred)

# Rebuild absolute values from the forecast relative changes:
# x[t] = x[t-1] * change[t] + x[t-1]. Values before start_pred stay as
# observed. astype() returns a float copy, so forecasts are not truncated if
# dataset.npy holds integers.
data_Pred = dataset.astype(np.float64)
for i in range(start_pred, len(data_Pred)):
    data_Pred[i] = data_Pred[i - 1] * Pred[i] + data_Pred[i - 1]
-
# --- Results ------------------------------------------------------------------
# Observed series with the reconstructed forecast overlaid on one axes.
plt.figure(2)
plt.plot(annee[:start_pred], dataset[:start_pred], label="apprentissage")
plt.plot(annee[start_pred:], dataset[start_pred:], label="validation")
plt.plot(annee, data_Pred, label="prediction")
plt.legend()

# Same data split across two stacked axes: observed (top), forecast (bottom).
# plt.legend() only targets the current (last-created) axes, so each subplot
# needs its own legend call; otherwise the top panel gets no legend.
fig, axs = plt.subplots(2)
axs[0].plot(annee[:start_pred], dataset[:start_pred], label="apprentissage")
axs[0].plot(annee[start_pred:], dataset[start_pred:], label="validation")
axs[0].legend()
axs[1].plot(annee, data_Pred, label="prediction")
axs[1].legend()
plt.show()
|