import torch
from kalman_filter import KalmanFilter
from scipy.optimize import minimize_scalar
import os
import pickle
from tqdm import tqdm 

def get_stable_state_kf(data):
    '''
    Run the covariance-only Kalman recursion long enough to reach a steady state.

    Written with the diagonal case in mind. The system description is read from
    `data` (keys "trans", "meas", "Qw", "Qv", "is_diag") and forwarded to
    KalmanFilter.run_P_only together with an all-ones initial covariance.

    Returns whatever KalmanFilter.run_P_only returns. trained_transf_vs_gd_step
    unpacks it as (Ks, _) and uses Ks[:, -1, :], so the first element is
    presumably the per-step Kalman gains (TODO confirm in kalman_filter).
    '''
    # The gain and P are assumed to be stationary after this many steps; under the
    # right conditions the convergence is exponential.
    n_steps = 2000

    trans = data["trans"]
    meas = data["meas"]
    noise_w = data["Qw"]
    noise_v = data["Qv"]
    diag_flag = data["is_diag"]
    p_init = torch.ones_like(trans)  # initial covariance: all ones, shaped like A

    return KalmanFilter.run_P_only(diag_flag, n_steps, p_init, trans, meas, noise_w, noise_v)

def context_loss_func(ys, G_curr, G_opt):
    '''
    Context loss between the current and the optimal linear map over a sequence.

        loss = 1 / (2 * seq_len) * || G_curr Y^T - G_opt Y^T ||_F^2

    ys:     [..., seq_len, dim]
    G_curr: [..., out_dim, dim]
    G_opt:  [..., out_dim, dim]

    The squared error is summed over the last two dims only, so any leading
    batch dims survive in the result.
    # TODO: need to decide here if I do for last element only or more
    '''
    ys_cols = ys.transpose(-2, -1)  # [..., dim, seq_len]
    diff = torch.matmul(G_curr, ys_cols) - torch.matmul(G_opt, ys_cols)
    total = (diff ** 2).sum(dim=(-1, -2))
    n = ys.shape[-2]
    return 1 / (2 * n) * total

def gd_step(ys, eta):
    '''
    Apply one gradient-descent-style update to every element of a sequence.

    With the Gram matrix S = sum_j y_j y_j^T, each element is updated as

        y_j <- y_j - eta / seq_len * S y_j

    This is the GD step for problem formulation 1, where the current prediction
    depends on the previous element only (history of 1). The step is applied to
    all y_j, although in a transformer it would only be appropriate for the last
    element in the sequence (due to the causal mask).

    # TODO: need to decide here if I do for last element only or more
    # TODO: somewhere around in this script i am missing scalings of the loss

    ys:  [seq_len, dim]
    eta: scalar step size (Python/NumPy float or 0-d tensor), one per batch
         element at the call sites.
    Returns a new [seq_len, dim] tensor with the same dtype as ys.
    '''
    seq_len = ys.shape[-2]
    if seq_len == 0:
        # Nothing to update; this also avoids dividing eta by zero below.
        return torch.zeros_like(ys)

    conditioning = torch.matmul(ys.transpose(-2, -1), ys)  # [dim, dim], symmetric
    # Row j of ys @ S equals (S @ y_j)^T because S is symmetric, so one matmul
    # yields every per-element step instead of looping over j.
    steps = torch.matmul(ys, conditioning)  # [seq_len, dim]
    # Cast back so the result dtype matches ys, as when writing into zeros_like(ys).
    return (ys - eta / seq_len * steps).to(ys.dtype)

def distance_and_cos(tf_ys, gd_ys):
    '''
    Average Euclidean distance and cosine similarity between two predictions.

    Both metrics are computed along the last dim and then averaged over all
    remaining entries (e.g. the batch). The distance is absolute, not normalized.

    Returns (avg_dist, avg_cos) as 0-d tensors.
    '''
    per_item_dist = (tf_ys - gd_ys).norm(dim=-1)
    per_item_cos = torch.nn.CosineSimilarity(dim=-1)(tf_ys, gd_ys)
    return per_item_dist.mean(), per_item_cos.mean()


def get_best_lrs(ys, G_opt, G_0):
    '''
    Line-search the GD step size separately for every batch element.

    For batch element b, eta is searched in [0, sum(ys[b] ** 2)] with scipy's
    bounded scalar minimizer to minimize
        context_loss_func(gd_step(ys[b], eta), G_0[b], G_opt[b]).

    ys:    [batch_sz, seq_len, dim]
    G_opt: [batch_sz, out_dim, dim]
    G_0:   [batch_sz, out_dim, dim]
    Returns a default-dtype CPU tensor of shape [batch_sz] holding the chosen etas.
    '''
    # TODO: probably all of these need to be converted to numpy arrays for the line search for the eta
    n_batch = ys.shape[0]
    etas = torch.zeros(n_batch)

    print("Getting best LRs for gd steps ... ")
    for idx in tqdm(range(n_batch)):
        seq = ys[idx, :, :]
        g_start = G_0[idx, :, :]
        g_target = G_opt[idx, :, :]

        # Bind this iteration's tensors as defaults so the objective never late-binds.
        def objective(eta, seq=seq, g_start=g_start, g_target=g_target):
            return context_loss_func(gd_step(seq, eta), g_start, g_target).detach().cpu().item()

        upper = torch.sum(seq ** 2).item()
        result = minimize_scalar(objective, bounds=(0., upper), method='bounded')
        etas[idx] = result.x

    return etas


def _build_history_ys(ys, p):
    '''
    Stack every window of p consecutive measurements into one vector.

    ys: [batch, seq_len, dim_y] -> [batch, seq_len - p + 1, p * dim_y], where row i
    is [y_i | y_{i+1} | ... | y_{i+p-1}] (oldest measurement first).
    '''
    batch_sz, seq_len_y, dim_y = ys.shape[0], ys.shape[-2], ys.shape[-1]
    flat = ys.reshape(batch_sz, seq_len_y * dim_y)
    # Windows of p*dim_y values that advance one measurement (dim_y values) at a
    # time; clone() so the result owns its memory instead of viewing ys.
    return flat.unfold(1, p * dim_y, dim_y).clone()


def _build_history_G_opt(is_diag, As, Cs, Ks, p, dim_y, device):
    '''
    Build the optimal history-p map G_opt = C [ (A-KC)^{p-1} K | ... | (A-KC) K | K ].

    The last dim_y columns (newest measurement) get K, the block before them
    (A - KC) K, and so on. Diagonal systems pass A, C, K as per-coordinate
    vectors (presumably [batch, dim] -- TODO confirm); full systems pass matrices.
    Returns [batch, dim_y, p * dim_y].
    '''
    batch_sz = Ks.shape[0]
    # Match Ks' dtype so the matmuls below never mix float32 with other dtypes.
    kern = torch.zeros(batch_sz, dim_y, p * dim_y, dtype=Ks.dtype, device=device)
    if is_diag:
        KCs = torch.mul(Ks, Cs)
        A_KC_pow = torch.ones_like(Ks)
    else:
        KCs = torch.matmul(Ks, Cs)
        ##TODO this will have to be modified in case we'll ever have dim_x diff dim_y
        A_KC_pow = torch.eye(dim_y, dtype=Ks.dtype, device=device).unsqueeze(0).repeat(batch_sz, 1, 1)

    for i in range(p):
        if is_diag:
            # (A - KC)^i K per coordinate, laid out as a diagonal matrix.
            block = torch.diag_embed(torch.mul(A_KC_pow, Ks))
            A_KC_pow = torch.mul(As - KCs, A_KC_pow)
        else:
            block = torch.matmul(A_KC_pow, Ks)
            A_KC_pow = torch.matmul(As - KCs, A_KC_pow)
        kern[:, :, dim_y * (p - (i + 1)): dim_y * (p - i)] = block

    full_dim_Cs = torch.diag_embed(Cs) if is_diag else Cs
    return torch.matmul(full_dim_Cs, kern)


def get_gd_steps_history_p(is_diag, ys, As, Cs, Ks, p):
    '''
    GD-step predictions using a history of p measurements.

    Steps:
    1. Augment ys into sliding windows of length p.
    2. Build the matching optimal map G_opt.
    3. Line-search one step size per batch element with get_best_lrs, starting
       from G_0 = 0.
    4. Apply gd_step to the augmented sequence.

    ys: [batch, seq_len, dim_y]; seq_len must be at least p.
    Returns [batch, seq_len - p + 1, dim_y]: the last dim_y entries of the
    updated augmented vectors.
    Raises ValueError if a sequence is shorter than p.
    '''
    assert p >= 1, "History length has to be at least 1."
    batch_sz = ys.shape[0]
    seq_len_y = ys.shape[-2]
    dim_y = ys.shape[-1]
    if seq_len_y < p:
        raise ValueError(f"Need at least p={p} measurements per sequence, got {seq_len_y}.")

    aug_ys = _build_history_ys(ys, p)
    G_opt = _build_history_G_opt(is_diag, As, Cs, Ks, p, dim_y, ys.device)
    G_0 = torch.zeros_like(G_opt)

    best_etas = get_best_lrs(aug_ys, G_opt, G_0)
    gd_preds = torch.zeros_like(aug_ys)
    for b in range(batch_sz):
        # NOTE(review): only the last row of the updated sequence is kept and it is
        # broadcast to every position, so gd_preds[b, i] is identical for all i.
        # Callers in this file read only position -1; confirm whether the full
        # gd_step output was intended here.
        gd_preds[b, :, :] = gd_step(aug_ys[b, :, :], best_etas[b])[-1, :]

    # Keep only the last section of the last dim, i.e. the newest measurement's slot.
    return gd_preds[:, :, -dim_y:]


def trained_transf_vs_gd_step(data, model, p=1):
    '''
    Compare the trained transformer's last prediction with the GD-step prediction.

    data:  dict with keys "seqs", "meas", "trans", "is_diag" plus the keys read by
           get_stable_state_kf. Only the diagonal case is supported.
    model: callable whose first return value holds the predictions, indexed
           below as [batch, seq, dim].
    p:     history length forwarded to get_gd_steps_history_p (which requires
           it). Defaults to 1, the single-element history gd_step is written for.

    Returns (avg_dist, avg_cos) from distance_and_cos at the last sequence position.
    Raises NotImplementedError for non-diagonal systems.
    '''
    # TODO: need to figure out what ys to take, prob just the prefix
    if not data["is_diag"]:
        raise NotImplementedError("trained_transf_vs_gd_step only supports the diagonal case.")

    ys = data["seqs"][:, :-1, :]
    Cs = data["meas"]
    As = data["trans"]
    is_diag = data["is_diag"]
    Ks, _ = get_stable_state_kf(data)

    # Use the gain from the final (assumed steady-state) step.
    gd_preds = get_gd_steps_history_p(is_diag, ys, As, Cs, Ks[:, -1, :], p)
    tf_preds, _ = model(ys)

    return distance_and_cos(tf_preds[:, -1, :], gd_preds[:, -1, :])


def learned_tf_gd_output(input, target, is_diag, As, Cs, Ks, p):
    '''
    Line-search a step size per (batch element, position) and apply one GD update.

    For batch element i and position j, let Y = input[i, :j+1]. A step size eta
    is searched in [0, ||input[i, j]||^2] (scipy bounded method, xatol=1e-3) to
    minimize
        sum( (G_0 Y^T - gd_updated_data(input, target, eta, G_0, i, j)) ** 2 ).
    The output at position j is then
        input[i, j] - eta / (j + 1) * (Y^T Y) input[i, j].

    input:  [batch, seq_len, dim_y]
    target: indexed like input. According to gd_updated_data's docstring, it
            holds the next measurement for every row of input (TODO confirm
            against the caller).
    is_diag, As, Cs, Ks: not used by this computation; kept so existing callers
            keep working.
    p:      history length; only validated (must be >= 1).

    Returns a tensor shaped like input.
    '''
    assert p >= 1, "History length has to be at least 1."
    batch_size = input.shape[0]
    seq_len = input.shape[1]
    dim_y = input.shape[-1]

    # Initial gain for the line search. It is all zeros, so its terms vanish; it
    # stays explicit so the loss mirrors the formula in gd_updated_data. It is
    # built with input's dtype/device so the matmuls below never mix them.
    G_0 = torch.zeros(batch_size, dim_y, dim_y, dtype=input.dtype, device=input.device)
    output = torch.zeros_like(input)

    for i in range(batch_size):
        G_0_i = G_0[i, :, :]
        for j in range(seq_len):
            prefix = input[i, :j + 1, :]  # [j+1, dim_y]
            # Optimize eta based on G * y_i -> y_i+1
            baseline = torch.matmul(G_0_i, prefix.transpose(-2, -1))  # [dim_y, j+1]

            # Bind loop variables as defaults so the closure cannot late-bind.
            def gd_output_loss(eta, i=i, j=j, G_0_i=G_0_i, baseline=baseline):
                updated = gd_updated_data(input, target, eta, G_0_i, i, j)
                return torch.sum((baseline - updated) ** 2).detach().cpu().item()

            max_eta = torch.sum(input[i, j, :] ** 2).item()
            res = minimize_scalar(gd_output_loss, bounds=(0., max_eta),
                                  method='bounded', options={'xatol': 1e-3})
            eta_star = res.x

            conditioning = torch.matmul(prefix.transpose(-2, -1), prefix)
            step = torch.matmul(conditioning, input[i, j, :])
            output[i, j, :] = input[i, j, :] - eta_star / (j + 1) * step

    return output


def gd_updated_data(input, target, eta, G_0, batch_ind, seq_ind):
    '''
    One GD-style update of the next-step targets for a single batch element.

    Let X = input[batch_ind, :seq_ind+1, :]  (rows y_0 .. y_seq_ind) and
        Y = target[batch_ind, :seq_ind+1, :] (rows y_1 .. y_seq_ind+1). Then

        sumyjyjT   = X^T X = y_0 y_0^T + ... + y_seq_ind y_seq_ind^T
        sumyjp1yjT = Y^T X = y_1 y_0^T + ... + y_seq_ind+1 y_seq_ind^T
        update     = (G_0 sumyjyjT - sumyjp1yjT) X^T
        returned   = Y^T - eta * update            # [dim, seq_ind + 1]

    The update is subtracted from the targets Y^T, not from the inputs.
    '''
    rows_in = input[batch_ind, :seq_ind + 1, :]
    rows_next = target[batch_ind, :seq_ind + 1, :]
    cols_in = rows_in.transpose(-2, -1)
    cols_next = rows_next.transpose(-2, -1)

    gram = torch.matmul(cols_in, rows_in)  # sumyjyjT
    cross = torch.matmul(cols_next, rows_in)  # sumyjp1yjT
    update = torch.matmul(torch.matmul(G_0, gram) - cross, cols_in)
    return cols_next - eta * update



if __name__ == '__main__':

    EXP_RESULTS_FOLDER = os.path.join(os.getcwd(), "exp_results")
    SAVED_DATA = os.path.join(os.getcwd(), "saved_models")

    # The checkpoint stores a whole model object (it is called like a function in
    # trained_transf_vs_gd_step), not just tensors. torch>=2.6 defaults to
    # weights_only=True, which refuses to unpickle such objects, so opt out
    # explicitly (the weights_only argument itself needs torch>=1.13).
    # NOTE: this, like pickle.load below, can execute code stored in the file;
    # only load trusted, locally produced files.
    (cfg, model, results) = torch.load(os.path.join(SAVED_DATA, 'noisy_transformer.pt'), weights_only=False)

    with open(os.path.join(SAVED_DATA, "dynsys_5_200000_0.01.pkl"), 'rb') as f:
        [train_data, test_data] = pickle.load(f)

    avg_dist, avg_cos = trained_transf_vs_gd_step(test_data, model)

    print(avg_dist, avg_cos)
    # plot_losses(EXP_RESULTS_FOLDER + "/", [avg_dist], ["gd vs transf"], extra_title_info=str(datetime.datetime.now()), x_label='k (iterations)', y_label=r"$\|y^{transf} - y^{GD}\|$")
    # plot_losses(EXP_RESULTS_FOLDER + "/", [avg_cos], ["cos gd vs transf"], extra_title_info=str(datetime.datetime.now()), x_label='k (iterations)', y_label=r"$\cos(y^{transf}, y^{GD})$", y_scale='linear')