import torch
from tqdm import tqdm
import  matplotlib.pyplot as plt
from model import *
from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
import numpy as np
# Run on the CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
#################################### Function definitions #########################################

def set_random_seed(seed):
    """Seed every random number generator the pipeline may use.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch's
    CPU generator and all CUDA generators, so repeated runs with the same
    seed draw the same random numbers.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random  # local import: keeps the file's top-level imports unchanged

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every visible GPU, not only the current one, and
    # is a silent no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)



def make_data_dict(data_path, fold=0, args=None):
    """Return the dict for one fold of a pickled ``.npy`` dataset.

    The file must contain a pickled dict with a ``"data"`` entry indexable by
    ``fold``. The fold's dict is augmented in place with the dataset-wide
    entries ``"CONTI_2_DISCT_dicts"``, ``"DISCT_2_CONTI_dicts"`` and
    ``"ndims"`` copied from the top level.

    Note: ``allow_pickle=True`` can execute arbitrary code; only load trusted
    files.

    Args:
        data_path: Path to the ``.npy`` file.
        fold: Index of the fold to extract (default 0).
        args: Unused; kept for call-site compatibility.

    Returns:
        The fold's dict with the shared entries attached.
    """
    archive = np.load(data_path, allow_pickle=True).item()
    fold_dict = archive["data"][fold]
    for shared_key in ("CONTI_2_DISCT_dicts", "DISCT_2_CONTI_dicts", "ndims"):
        fold_dict[shared_key] = archive[shared_key]
    return fold_dict



def get_ind_time(data_dict, key1, key2):
    """Split two 2-D arrays from ``data_dict`` into index and time tensors.

    ``data_dict[key1]`` is converted to a float32 tensor and
    ``data_dict[key2]`` to an int64 (``torch.long``) tensor, both moved to
    ``device``. For each tensor, the first two columns are the index part
    and the last column is the time part.

    Returns:
        ``(ind1, ind2, t1, t2)``: indices from key1, indices from key2,
        times from key1, times from key2.
    """
    x_float = torch.tensor(data_dict[key1], dtype=torch.float32).to(device)
    float_ind, float_time = x_float[:, :2], x_float[:, -1]
    x_long = torch.tensor(data_dict[key2], dtype=torch.long).to(device)
    long_ind, long_time = x_long[:, :2], x_long[:, -1]
    return float_ind, long_ind, float_time, long_time




def Write_excel(list0, list1, list2, output_path):
    """Write per-epoch metrics to an Excel file.

    Args:
        list0: Epoch numbers (column ``"Epoch"``).
        list1: RMSE values (column ``'RMSE'``).
        list2: MAE values (column ``'MAE'``).
        output_path: Destination ``.xlsx`` path; the DataFrame index is
            written as the first column.
    """
    column_names = ("Epoch", "RMSE", "MAE")
    frame = pd.DataFrame(dict(zip(column_names, (list0, list1, list2))))
    frame.to_excel(output_path, index=True)



def load_data(data_path, flag, fold):
    """Load the train or test split of one fold as tensors on ``device``.

    The ``.npy`` file must hold a pickled dict whose ``"data"`` entry is
    indexable by ``fold``. For the ``*_ind_conti`` (float32) and ``*_ind``
    (int64) arrays, column 3 is split off as the time column and columns 0-2
    are kept as the index part.

    Note: ``allow_pickle=True`` can execute arbitrary code; only load trusted
    files.

    Args:
        data_path: Path to the ``.npy`` dataset.
        flag: Truthy to load the training split (``tr_*`` keys plus the
            ``*_uni`` arrays), falsy to load the test split (``te_*`` keys).
        fold: Fold index into ``full_data["data"]``.

    Returns:
        Training: ``(tr_ind_conti, tr_ind, tr_time_conti, tr_time_ind, tr_y,
        time_uni, u_ind_uni, v_ind_uni, w_ind_uni)``.
        Test: ``(te_ind_conti, te_ind, te_time_conti, te_time_ind, te_y)``.
    """
    fold_data = np.load(data_path, allow_pickle=True).item()["data"][fold]
    # Both splits share one layout; only the key prefix differs.
    prefix = "tr" if flag else "te"

    def as_float(key):
        # Convert one stored array to a float32 tensor on the target device.
        return torch.tensor(fold_data[key], dtype=torch.float32).to(device)

    ind_conti = as_float(prefix + "_ind_conti")
    ind = torch.tensor(fold_data[prefix + "_ind"], dtype=torch.long).to(device)

    # Column 3 is the time column; columns 0-2 are the index part.
    time_conti = ind_conti[:, 3]
    time_ind = ind[:, 3]
    ind_conti = ind_conti[:, :3]
    ind = ind[:, :3]

    y = as_float(prefix + "_y")
    if not flag:
        return ind_conti, ind, time_conti, time_ind, y

    # The training split also returns the *_uni arrays stored in the fold.
    uniques = [as_float(key) for key in ("time_uni", "u_ind_uni", "v_ind_uni", "w_ind_uni")]
    return (ind_conti, ind, time_conti, time_ind, y, *uniques)

def visual_temp2(tr_ind, tr_time_ind, y, ind):
    """Plot the values of ``y`` recorded at a single index tuple over time.

    Rows of ``tr_ind`` that equal ``ind`` element-wise are selected. Their
    times (from ``tr_time_ind``) and values (from ``y``) are sorted by time
    and drawn as a line with ``x`` markers, then the figure is shown.

    NOTE(review): the comparison/indexing is NumPy-based, so the inputs are
    presumably NumPy arrays — confirm before passing torch tensors.
    """
    row_hits = np.nonzero(np.all(tr_ind == np.array(ind), axis=1))[0]
    times = tr_time_ind[row_hits]
    values = y[row_hits]

    order = np.argsort(np.squeeze(times))
    plt.plot(times[order], values[order], marker="x")
    plt.xlabel("time")
    plt.ylabel("temperature")
    plt.show()
def normalize_data(tr_y, te_y):
    """Standardise train and test targets using training-set statistics only.

    Both splits are shifted by ``tr_y.mean()`` and divided by ``tr_y.std()``,
    so no test information enters the scaling. For torch tensors ``.std()``
    is the unbiased (Bessel-corrected) estimator by default.

    Returns:
        ``(normalised_train, normalised_test, mean, std)``.
    """
    centre = tr_y.mean()
    spread = tr_y.std()

    def standardise(values):
        return (values - centre) / spread

    return standardise(tr_y), standardise(te_y), centre, spread


def loss_fn(pred, gt):
    """Root-mean-squared error between ``pred`` and ``gt`` (mean reduction)."""
    mean_sq_err = torch.nn.functional.mse_loss(pred, gt)
    return mean_sq_err.sqrt()


def loss_fn2(pred, gt):
    """Mean absolute error between ``pred`` and ``gt`` (mean reduction)."""
    return torch.nn.functional.l1_loss(pred, gt)

# Small constant, presumably for numerical stability. NOTE(review): no
# function in the visible part of this file references it — confirm it is
# still needed.
epsilon = 1.0e-7
def train(model, train_loader, optimizer, loss_fn, epoch):
    """Run one training epoch and return the mean data-fit loss.

    Each batch is optimised on ``loss_fn(prediction, target)`` plus the
    model's KL terms (FARD), blended by ``c = sigmoid((epoch - 200) / 10)``.
    This is the cold start of the automatic rank determination: early epochs
    weight ``kl_loss[1]`` (``c`` near 0), and later epochs weight the
    per-sample term ``kl_loss[0] / N`` (``c`` near 1).

    Args:
        model: Callable returning ``(prediction, nfe_forward, kl_loss)`` for
            ``(index_batch, time_batch)``.
        train_loader: Iterable of ``(index, time, target)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        loss_fn: Data-fit loss (RMSE in this script).
        epoch: Current epoch number; drives the KL weighting schedule.

    Returns:
        ``(mean_loss, nfe_forward)``: the batch-average of the data-fit loss
        (excluding KL terms) and the ``nfe_forward`` reported by the model
        for the last batch.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    # The schedule depends only on the epoch, so build it once per call
    # rather than once per batch.
    c = torch.sigmoid(torch.tensor([(epoch - 200) / 10], dtype=torch.float32)).to(device)
    loss_list = []
    nfe_forward = None
    for train_ind_batch, tr_time_ind_batch, train_y_batch in tqdm(train_loader):
        optimizer.zero_grad()
        output_batch, nfe_forward, kl_loss = model(train_ind_batch, tr_time_ind_batch)
        N = train_ind_batch.shape[0]
        loss = loss_fn(output_batch, train_y_batch.squeeze())
        # With FARD; for the ablation without it use: loss_all = loss
        loss_all = loss + c * (kl_loss[0] / N) + (1 - c) * kl_loss[1]

        loss_all.backward()
        optimizer.step()
        loss_list.append(loss.item())

    if not loss_list:
        raise ValueError("train_loader yielded no batches")
    return np.mean(loss_list), nfe_forward



def evaluating(model, test_loader, loss_fn, loss_fn2, scale=None):
    """Evaluate ``model`` on ``test_loader`` and return de-normalised metrics.

    Predictions for all batches are concatenated and compared with the
    targets yielded by the same loader. Both metrics are multiplied by
    ``scale`` (the module-level ``data_std`` when ``scale`` is None) and
    truncated, not rounded, to four decimals.

    The model's train/eval mode is left unchanged; ``model.eval()`` is not
    called here, matching the original code where it was commented out.

    Args:
        model: Callable returning ``(prediction, nfe_forward, _)`` for
            ``(index_batch, time_batch)``.
        test_loader: Iterable of ``(index, time, target)`` batches.
        loss_fn: RMSE-style loss ``(pred, target) -> tensor``.
        loss_fn2: MAE-style loss ``(pred, target) -> tensor``.
        scale: Optional multiplier that undoes target normalisation;
            defaults to the global ``data_std``.

    Returns:
        ``(rmse, mae)`` as Python floats.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    print("evaluating....")
    preds, targets = [], []
    for test_ind_batch, te_time_ind_batch, test_y_batch in test_loader:
        output_batch, _, _ = model(test_ind_batch, te_time_ind_batch)
        # Detach so each batch's autograd graph is freed right away instead of
        # being kept alive by the concatenated result.
        preds.append(output_batch.detach())
        # Score against the loader's own labels rather than a global tensor.
        targets.append(test_y_batch)
    if not preds:
        raise ValueError("test_loader yielded no batches")

    result = torch.cat(preds, dim=0)
    labels = torch.cat(targets, dim=0).squeeze()
    std = data_std if scale is None else scale
    rmse = loss_fn(result, labels) * std
    mae = loss_fn2(result, labels) * std
    return int(rmse.item() * 10000) / 10000, int(mae.item() * 10000) / 10000

def evaluating_and_save(model, test_loader):
    """Predict on ``test_loader`` and save de-normalised predictions to .mat.

    Writes ``./output/Catte_<data_name>_result.mat`` containing one variable,
    ``result = data_mean + data_std * prediction``. It reads the module-level
    globals ``data_name``, ``data_mean`` and ``data_std``, which the
    ``__main__`` block defines. The ``./output`` directory is created if it
    is missing.
    """
    import os
    import scipy.io as scio  # not imported at file level in the visible code

    print("evaluating....")
    preds = []
    for test_ind_batch, te_time_ind_batch, _ in test_loader:
        output_batch, _, _ = model(test_ind_batch, te_time_ind_batch)
        # Detach per batch so autograd graphs are not retained across batches.
        preds.append(output_batch.detach())
    result = torch.cat(preds, dim=0) if preds else torch.Tensor([]).to(device)

    # savemat does not create parent directories; fail early instead of
    # crashing only after a long training run.
    os.makedirs("./output", exist_ok=True)
    scio.savemat(r"./output/Catte_"+data_name+"_result.mat", {"result": (data_mean+data_std*result).cpu().numpy()})
   


def train_parallel(model, train_loader, test_loader, optimizer, loss_fn, loss_fn2, max_iter,
                   eval_after=1500, eval_every=100, log_every=20):
    """Train for ``max_iter`` epochs with periodic logging, evaluation and saving.

    - Evaluates once before training (printed as epoch -1).
    - Every ``log_every`` epochs, records ``nfe_forward[1]`` and
      ``nfe_forward[2]`` and prints the training loss.
    - On epochs greater than ``eval_after`` that are multiples of
      ``eval_every``, evaluates on ``test_loader``. Whenever the RMSE beats the
      best seen so far, it calls ``evaluating_and_save``.

    Args:
        model, train_loader, test_loader, optimizer, loss_fn, loss_fn2:
            Forwarded to ``train`` / ``evaluating``.
        max_iter: Number of epochs to run.
        eval_after: Evaluate only after this epoch (default 1500).
        eval_every: Evaluation period in epochs (default 100).
        log_every: Logging/recording period in epochs (default 20).

    Returns:
        dict with ``"epoch"``, ``"uv_sum"`` and ``"gamma"`` (one entry per
        recorded epoch) and ``"best_rmse"`` (``inf`` if never evaluated).
    """
    # Was 10: RMSE is reported in de-normalised units and can exceed 10,
    # which silently disabled saving the best predictions.
    rmse_min = float("inf")
    history = {"epoch": [], "uv_sum": [], "gamma": []}

    rmse, mae = evaluating(model, test_loader, loss_fn, loss_fn2)
    print("Epoach:-1", "Evaluating RMSE:", rmse, " MAE:", mae)
    for epoch in range(max_iter):
        loss, nfe_forward = train(model, train_loader, optimizer, loss_fn, epoch)
        is_log_epoch = epoch % log_every == 0

        if is_log_epoch:  # record
            history["epoch"].append(epoch)
            history["uv_sum"].append(nfe_forward[1])
            history["gamma"].append(nfe_forward[2])
        if epoch > eval_after and epoch % eval_every == 0:
            rmse, mae = evaluating(model, test_loader, loss_fn, loss_fn2)
            if rmse < rmse_min:
                rmse_min = rmse
                evaluating_and_save(model, test_loader)
            print("Epoach", epoch, "Evaluating RMSE:", rmse, " MAE:", mae)
        if is_log_epoch:
            print("epoch:", epoch, "RMSE loss:", loss, "nfe_forward", nfe_forward)

    return {**history, "best_rmse": rmse_min}






if __name__ == "__main__":
    set_random_seed(231)
    # Functions above read some of these names as module-level globals
    # (e.g. data_name, data_mean, data_std), so keep them at module scope.
    data_name = r"traffic"
    data_path = r"./data_set/traffic_5x20x16_m3_fold5.npy"

    (tr_ind_conti, tr_ind, tr_time_conti, tr_time_ind, tr_y,
     time_uni, u_ind_uni, v_ind_uni, w_ind_uni) = load_data(data_path, flag=1, fold=0)
    te_ind_conti, te_ind, te_time_conti, te_time_ind, te_y = load_data(data_path, flag=0, fold=0)
    # Standardise targets with training-set statistics only.
    tr_y, te_y, data_mean, data_std = normalize_data(tr_y, te_y)
    # NOTE(review): j and r are hyper-parameters of Catte_4D (presumably
    # rank-related) — confirm against the model definition.
    j, r = 10, 10
    learning_rate = 5e-4
    model = Catte_4D(j, r, time_uni, u_ind_uni, v_ind_uni, w_ind_uni).to(device)

    # Use torch.optim explicitly; `optim` is never imported in this file and
    # would only exist if `from model import *` happened to re-export it.
    optimizer = torch.optim.AdamW(model.parameters(), learning_rate)

    train_dataset = TensorDataset(tr_ind, tr_time_ind, tr_y)
    test_dataset = TensorDataset(te_ind, te_time_ind, te_y)

    train_loader = DataLoader(train_dataset, batch_size=4096, shuffle=True)
    # No shuffling so test predictions stay aligned with te_y order.
    test_loader = DataLoader(test_dataset, batch_size=4096, shuffle=False)
    train_parallel(model, train_loader, test_loader, optimizer, loss_fn, loss_fn2, max_iter=2001)


