import torch
import pickle
from dyn_sys import NoisyDynSys
from helpers import get_script_dir, set_seeds
from config import REL_PATH_DATA_FLDR, AugmType, AUG_TOK_EXTRA_POS



def gen_train_test_dyn_sys(noisy_dyn_sys, cfg, train_size, val_size, s_noise_var, m_noise_var, silent=False):
    '''
        Draw a training set and (optionally) a validation/test set from ``noisy_dyn_sys``.

        The noise variances are always assumed to be diagonal matrices, irrespective of
        how A and C are, and so they are represented as vectors. Either may be ``None``,
        in which case ``None`` is forwarded to ``get_data`` unchanged.

        Args:
            noisy_dyn_sys: object exposing ``get_data(...)``; it must return the 7-tuple
                ``(ys, As, Cs, x0, xs, Qw, Qv)``.
            cfg: config mapping; reads "seq_len", "A_type", "C_type", "single_sys", "is_diag".
            train_size: number of training sequences requested from ``get_data``.
            val_size: number of validation sequences; ``0`` skips generation entirely.
            s_noise_var: state-noise variances, indexed as ``[:val_size, :, :]`` for the
                validation draw, or ``None``.
            m_noise_var: measurement-noise variances, sliced the same way, or ``None``.
            silent: forwarded to ``get_data``.

        Returns:
            ``(train_data, test_data)``: dicts with keys "seqs", "trans", "meas", "x0",
            "Qw", "Qv", "states" and "is_diag". ``test_data`` is empty when ``val_size == 0``.
    '''
    # One extra element is needed so the sequence can be split into inputs and
    # next-token labels later on.
    padded_len = cfg["seq_len"] + 1

    ys, As, Cs, x0, xs, Qw, Qv = noisy_dyn_sys.get_data(
        train_size, padded_len, cfg["A_type"], cfg["C_type"],
        cfg["single_sys"], cfg["is_diag"], s_noise_var, m_noise_var, silent,
    )
    train_data = {
        "seqs": ys,
        "trans": As,
        "meas": Cs,
        "Qw": Qw,
        "Qv": Qv,
        "x0": x0,
        "states": xs,
        "is_diag": cfg["is_diag"],
    }

    if val_size == 0:
        return train_data, {}

    # The validation draw reuses the first ``val_size`` rows of the noise variances.
    # NOTE(review): if val_size > train_size the slice is shorter than val_size;
    # whether get_data tolerates that is not visible here.
    val_s_noise = None if s_noise_var is None else s_noise_var[:val_size, :, :]
    val_m_noise = None if m_noise_var is None else m_noise_var[:val_size, :, :]

    ys, As, Cs, x0, xs, Qw, Qv = noisy_dyn_sys.get_data(
        val_size, padded_len, cfg["A_type"], cfg["C_type"],
        cfg["single_sys"], cfg["is_diag"], val_s_noise, val_m_noise, silent,
    )
    test_data = {
        "seqs": ys,
        "trans": As,
        "meas": Cs,
        "x0": x0,
        "Qw": Qw,
        "Qv": Qv,
        "states": xs,
        "is_diag": cfg["is_diag"],
    }

    return train_data, test_data

def gen_test_data_diff_rand_seed(noisy_dyn_sys, cfg, rand_seeds, noise_var):
    '''
        Generate one test sequence per random seed and stack the results.

        For every seed in ``rand_seeds`` the global seeds are reset via ``set_seeds``
        and a single sequence is requested from ``noisy_dyn_sys.get_data``.

        Args:
            noisy_dyn_sys: object exposing ``get_data(...)`` returning the 7-tuple
                ``(ys, As, Cs, x0, xs, Qw, Qv)``.
            cfg: config mapping; reads "device", "seq_len", "state_dim", "A_type",
                "C_type", "single_sys" and "is_diag".
            rand_seeds: sequence of seeds, one generated sequence per entry.
            noise_var: scalar variance used for both state and measurement noise.

        Returns:
            dict with keys "seqs", "trans", "meas", "x0", "Qw", "Qv", "states", "is_diag".
    '''
    device = cfg["device"]
    state_dim = cfg["state_dim"]
    num_seeds = len(rand_seeds)
    padded_len = cfg["seq_len"] + 1  # one extra element for next-token prediction

    # NOTE(review): the measurement-noise vector is sized with "state_dim" here,
    # whereas get_dynsys_data sizes it with "output_dim" -- confirm which is intended.
    s_noise_var = noise_var * torch.ones(1, state_dim, 1, device=device)
    m_noise_var = noise_var * torch.ones(1, state_dim, 1, device=device)

    # Output buffers, one row per seed.
    # NOTE(review): every buffer uses state_dim as its trailing size; this presumably
    # relies on the shapes returned by get_data -- TODO confirm.
    seqs = torch.zeros(num_seeds, padded_len, state_dim, device=device)
    trans = torch.zeros(num_seeds, state_dim, device=device)
    meas = torch.zeros(num_seeds, state_dim, device=device)
    init_state = torch.zeros(state_dim, 1, device=device)
    states = torch.zeros(num_seeds, padded_len, state_dim, device=device)
    state_noise = torch.zeros(num_seeds, state_dim, device=device)
    meas_noise = torch.zeros(num_seeds, state_dim, device=device)

    for row, seed in enumerate(rand_seeds):
        set_seeds(seed)
        y, A, C, x0, x, Qw, Qv = noisy_dyn_sys.get_data(
            1, padded_len, cfg["A_type"], cfg["C_type"],
            cfg["single_sys"], cfg["is_diag"], s_noise_var, m_noise_var,
        )
        seqs[row] = y
        trans[row] = A
        meas[row] = C
        # NOTE(review): x0 is overwritten rather than stored per row, so only the
        # last seed's initial state is returned -- confirm this is intended.
        init_state = x0
        states[row] = x
        state_noise[row] = Qw
        meas_noise[row] = Qv

    return {
        "seqs": seqs,
        "trans": trans,
        "meas": meas,
        "x0": init_state,
        "Qw": state_noise,
        "Qv": meas_noise,
        "states": states,
        "is_diag": cfg["is_diag"],
    }


def create_tokens(orig_tokens, cfg, silent=False):
    '''
        Turn raw sequences into (inputs, labels) pairs, optionally with augmented tokens.

        Args:
            orig_tokens: tensor indexed as (batch, seq_len, dim).
            cfg: config mapping; reads "aug_type" and "device", plus "wdw_size" when an
                augmentation is requested.
            silent: when falsy, print a one-line summary of the augmentation.

        Returns:
            ``(inputs, labels)`` on ``cfg["device"]``:
              * ``aug_type is None``: plain next-token format, i.e. ``orig_tokens``
                without its last / first element along the sequence axis.
              * ``AugmType.TF_GD``: each token is a sliding window of
                ``window_size + 1`` consecutive values written at the start of the
                token; ``inputs`` has the last ``dim`` entries of the final token
                zeroed, ``labels`` is the unmasked copy.
              * ``AugmType.MESA_OPT``: each token is the same window, reversed and
                written at the end of the token; inputs/labels are the augmented
                sequence shifted by one position.

        Raises:
            AssertionError: ``aug_type`` is not a key of ``AUG_TOK_EXTRA_POS``.
            ValueError: the token dimension is not 1.
            TypeError: ``aug_type`` is registered but has no construction rule here.
    '''
    aug_type = cfg["aug_type"]
    # Read the device before branching: the no-augmentation path needs it too.
    # (It used to be assigned only further down, so that path raised UnboundLocalError.)
    device = cfg["device"]

    ######## If no aug, simply return the next token pred format for orig_tokens
    if aug_type is None:
        inputs = orig_tokens[:, :-1, :]
        labels = orig_tokens[:, 1:, :]

        return inputs.to(device), labels.to(device)

    assert aug_type in AUG_TOK_EXTRA_POS, "Augmentation type not supported yet."

    batch_sz = orig_tokens.shape[0]
    seq_len = orig_tokens.shape[1]
    dim = orig_tokens.shape[2]
    window_size = cfg["wdw_size"]

    if not dim == 1:
        raise ValueError("This method only currently supports y dimension to be 1. Extension is needed for higher dimensions.")

    # term1: dependence on past, i.e., x; term2: predicted variable, i.e. y; term3: artificially
    # added extra positions to aid with computation
    num_concats = window_size + 1 + AUG_TOK_EXTRA_POS[aug_type]
    aug_token_size = num_concats * dim

    if not silent:
        print("\n -- Original token dimension = {}, seq. len. = {}, window_size = {}, aug_type = {}, new token dimension = {}.\n"
                        .format(dim, seq_len, window_size, aug_type, aug_token_size))

    # One augmented token per full window, so the first `window_size` positions
    # do not get a token of their own (no leading block of zero tokens).
    aug_seq_len = seq_len - window_size
    flattened_tokens = orig_tokens.reshape(batch_sz, dim * seq_len)

    # Number of entries in a token occupied by the window itself (past values + current).
    window_span = (window_size + 1) * dim

    aug_tokens = torch.zeros(batch_sz, aug_seq_len, aug_token_size, device=device)
    if aug_type == AugmType.TF_GD:
        for i in range(aug_seq_len):
            start_ind = i * dim
            # Window goes at the front of the token; any extra positions stay zero.
            aug_tokens[:, i, :window_span] = flattened_tokens[:, start_ind : start_ind + window_span]

        inputs = aug_tokens.detach().clone()
        # Blank the trailing `dim` entries of the final token in the inputs only.
        inputs[:, -1, (-1) * dim : ] = 0.

        labels = aug_tokens.detach().clone()

    elif aug_type == AugmType.MESA_OPT:
        for i in range(aug_seq_len):
            start_ind = i * dim
            # Window is reversed and placed at the end of the token.
            aug_tokens[:, i, (-1) * window_span : ] = torch.flip(
                flattened_tokens[:, start_ind : start_ind + window_span], dims=(-1,))

        # Standard next-token shift over the augmented sequence.
        inputs = aug_tokens[:, :-1, :].detach().clone()
        labels = aug_tokens[:, 1:, :].detach().clone()

    else:
        raise TypeError("No such token augmentation type exists.")

    # These are already created on the correct device, no need for .to(device)
    return inputs, labels


def get_dynsys_data(from_saved, cfg, train_size, test_size, silent=False):
    '''
        Return ``(train_data, test_data)``, either freshly generated or loaded from disk.

        Args:
            from_saved: when truthy, load a previously pickled dataset instead of
                generating a new one.
            cfg: config mapping; reads "noise_var", "A_type", "state_dim" for the file
                name and, when generating, "output_dim", "device", "multi_rand_seeds"
                (plus "test_rand_seed_number" and "rand_seed" in multi-seed mode).
            train_size: number of training sequences (also part of the file name).
            test_size: number of test sequences for the single-seed test set.
            silent: suppress progress output.

        Returns:
            ``(train_data, test_data)`` dicts as produced by the generator helpers, or
            the first two entries of the pickled list.
    '''
    noise_var = cfg["noise_var"]
    data_fldr = get_script_dir() + REL_PATH_DATA_FLDR
    file_path = data_fldr + "/" + cfg["A_type"].name + "_dynsys_" + str(cfg["state_dim"])\
                          +"_"+str(train_size)+"_"+str(noise_var)+".pkl"

    if from_saved:
        if not silent:
            print("\n**** Loading data from path {} ...".format(file_path))
        # SECURITY: pickle.load can execute arbitrary code; only load files that
        # were produced locally by a trusted run.
        with open(file_path, 'rb') as f:
            # The file holds [train_data, test_data, cfg]; the stored cfg is not used.
            train_data, test_data, _saved_cfg = pickle.load(f)
        if not silent:
            print("\n**** Finished loading data.")
        return train_data, test_data

    noisy_dyn_sys = NoisyDynSys(cfg["state_dim"], cfg["output_dim"], cfg["device"])
    # Same scalar variance for every sequence and every coordinate.
    s_noise_var = noise_var * torch.ones(train_size, cfg["state_dim"], 1, device=cfg["device"])
    m_noise_var = noise_var * torch.ones(train_size, cfg["output_dim"], 1, device=cfg["device"])

    if not cfg["multi_rand_seeds"]:
        train_data, test_data = gen_train_test_dyn_sys(noisy_dyn_sys, cfg, train_size, test_size, s_noise_var, m_noise_var, silent)
    else:
        # Only the training split is kept; the test set is regenerated below with one
        # sequence per seed. `silent` is forwarded so this branch honours it as well.
        train_data, _ = gen_train_test_dyn_sys(noisy_dyn_sys, cfg, train_size, test_size, s_noise_var, m_noise_var, silent)
        rand_seed_0 = cfg["rand_seed"]
        # Seeds are spaced 100 apart, starting from the configured base seed.
        rand_seeds = [rand_seed_0 + 100 * i for i in range(cfg["test_rand_seed_number"])]
        test_data = gen_test_data_diff_rand_seed(noisy_dyn_sys, cfg, rand_seeds, noise_var)

    # NOTE: generated data is not written to `file_path` here; loading expects a file
    # containing the pickled list [train_data, test_data, cfg].
    return train_data, test_data


def extract_pred_and_label(outputs, true_labels, cfg):
    '''
        Select the part of the model output (and of the labels) that is scored.

        Args:
            outputs: model outputs, indexed as (batch, seq, token_dim) in the TF_GD case.
            true_labels: labels with the same layout as ``outputs``.
            cfg: config mapping; reads "aug_type" and, for TF_GD, "output_dim".

        Returns:
            ``(preds, label)``:
              * ``aug_type is None`` or ``AugmType.MESA_OPT``: both arguments are returned
                unchanged, since these are plain next-token prediction setups.
              * ``AugmType.TF_GD``: the last ``output_dim`` entries of the final token
                of each sequence.

        Raises:
            TypeError: ``aug_type`` is none of the above.
    '''
    aug_type = cfg["aug_type"]

    # No augmentation: create_tokens emits plain next-token inputs/labels for this
    # setting, so there is nothing to extract (previously this raised TypeError).
    if aug_type is None:
        return outputs, true_labels

    if aug_type == AugmType.TF_GD:
        # Get token as per Kwangjun's paper: the prediction sits in the trailing
        # entries of the last token.
        out_dim = cfg["output_dim"]
        preds = outputs[:, -1, (-1) * out_dim: ]
        label = true_labels[:, -1, (-1) * out_dim: ]
        return preds, label

    if aug_type == AugmType.MESA_OPT:
        # No change to the augmentation, as that is in next token pred scenario
        return outputs, true_labels

    raise TypeError("Specified augmentation type does not exist.")

