import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from models.validation_likelihood_tuning import get_autotuned_predictions_data

def GWN(A, scale):
    """Return a copy of ``A`` with i.i.d. zero-mean Gaussian noise added.

    Parameters
    ----------
    A : array-like
        Values to perturb (e.g. a NumPy array or pandas Series/DataFrame).
        ``A`` itself is not modified.
    scale : float
        Standard deviation of the noise (must be non-negative).

    Returns
    -------
    ``A + noise``, where ``noise`` has one independent draw per element. For
    NumPy and pandas inputs the result keeps the input's type (and, for pandas,
    its index); integer dtypes are promoted to float by the addition.
    """
    # np.shape (not len) so multi-dimensional inputs get one draw per element;
    # for 1-D input the draws are identical to size=len(A). ``+`` already
    # allocates a new object, so no explicit copy is needed.
    return A + np.random.normal(0, scale, np.shape(A))
        
def _scale_at(x, positions, factors):
    """Return a copy of ``x`` with the elements at ``positions`` multiplied by ``1 + factors``.

    Positions are applied positionally (``.iloc`` for pandas objects), so a
    Series with a non-integer index such as a DatetimeIndex behaves like a
    plain NumPy array. ``x`` itself is not modified.
    """
    out = x.copy()
    if len(positions) == 0:
        return out
    multipliers = 1 + np.asarray(factors, dtype=float)
    if hasattr(out, 'iloc'):
        # to_numpy() so the assignment is by position, not by index alignment.
        out.iloc[positions] = out.iloc[positions].to_numpy() * multipliers
    else:
        out[positions] = out[positions] * multipliers
    return out


def TSA(train, test, hypers, num_samples, model, scale, mean, std, z_score_flag, tau, epsilon):
    """Greedily select ``tau`` training positions whose perturbation most moves the forecast.

    Each of the ``tau`` rounds works as follows. Every candidate position not
    selected yet is tried with a multiplicative perturbation of ``+epsilon`` and
    of ``-epsilon``, applied on top of the perturbations chosen in earlier
    rounds. Each perturbed series is forecast with
    ``get_autotuned_predictions_data``. It is scored by the mean absolute
    difference between its ``'median'`` prediction and the prediction for the
    unperturbed series. The position/sign whose score is strictly largest is
    kept for the round.

    Candidate positions are the ``len(test) // 2`` consecutive positions
    starting at ``len(train) - ceil(len(test) / 2)``, i.e. near the end of
    ``train``. Positions outside ``train`` are skipped.

    Parameters
    ----------
    train : numpy array or pandas Series
        Training series; not modified.
    test
        Its length and ``index`` are used here; it is also forwarded to the
        model (presumably a pandas Series -- confirm with callers).
    hypers, num_samples, model
        Forwarded unchanged to ``get_autotuned_predictions_data``.
    scale
        Unused; kept for interface compatibility.
    mean, std
        Used to z-score ``train`` when ``z_score_flag`` is true.
    z_score_flag : bool
        If true, the search runs on ``(train - mean) / std``; otherwise on
        ``train + 1e-11``.
    tau : int
        Number of greedy rounds (one position is appended per round).
    epsilon : float
        Relative perturbation magnitude.

    Returns
    -------
    noise
        Copy of ``train`` with ``noise[s[k]] *= 1 + f[k]`` applied positionally
        to the raw (untransformed) values.
    s : list[int]
        Selected position for each round.
    f : list
        Signed factor for each entry of ``s``: ``epsilon``, ``-epsilon`` or ``0``.
    r
        Best score found in the last round (``0`` when ``tau`` is 0).

    Notes
    -----
    When a round finds no candidate scoring above 0 (no candidates left, or
    every candidate ties between the two signs), position 0 with factor 0 is
    appended as a no-op placeholder. The selections, directions and final
    score are printed.
    """
    n_train = len(train)
    n_test = len(test)

    # The transformed series does not depend on the selections, so build it once.
    if z_score_flag:
        base = (train - mean)/std
    else:
        base = train + 0.00000000001

    def predict(series):
        """Median forecast for ``series`` as a Series indexed like ``test``."""
        pred_dict = get_autotuned_predictions_data(series, test, hypers, num_samples, model, verbose=False, parallel=False)
        return pd.Series(pred_dict['median'], index=test.index)

    # Integer positions, so the ``pos in s`` check below is exact. A float
    # offset such as ``i + (l - h/2)`` never equals an int when len(test) is
    # odd, which let already-selected points be tried again.
    start = n_train - (n_test + 1) // 2
    candidates = [pos for pos in range(start, start + n_test // 2) if 0 <= pos < n_train]

    s = []
    f = []
    r = 0  # bound up front so tau == 0 does not raise NameError at the print below
    baseline = None  # forecast of the unperturbed series, computed on first use
    for _ in range(tau):
        index = 0
        r = 0
        temp = 0
        for pos in candidates:
            if pos in s:
                continue
            if baseline is None:
                # Its input never changes, so one forecast serves as the common
                # reference for every candidate instead of re-running the model.
                baseline = predict(base)

            s_temp = s + [pos]
            dis_1 = (predict(_scale_at(base, s_temp, f + [epsilon])) - baseline).abs().mean()
            dis_2 = (predict(_scale_at(base, s_temp, f + [-1*epsilon])) - baseline).abs().mean()

            # Keep the sign with the larger effect, and only if it beats the round's best.
            if dis_1 > dis_2:
                if dis_1 > r:
                    r = dis_1
                    index = pos
                    temp = epsilon
            elif dis_2 > dis_1:
                if dis_2 > r:
                    r = dis_2
                    index = pos
                    temp = -1*epsilon
        s.append(index)
        f.append(temp)

    print('s is :',s)
    print('direction is :', f)
    print('prediction error is:', r)
    noise = _scale_at(train, s, f)

    return noise,s,f,r