"""
This file contains functions useful for training and validating, including the training loop
the validation loop, and data generation.
"""

from torch_geometric.loader import DataLoader
import numpy as np
import torch
import utils
import os
import pickle as pkl
from torch_scatter import scatter_add
from scipy import signal

def _dump_pkl(obj, path):
    """Pickle ``obj`` to ``path`` (overwriting any existing file)."""
    with open(path, 'wb') as fp:
        pkl.dump(obj, fp)


def _load_pkl(path):
    """Unpickle and return the object stored at ``path``.

    NOTE: unpickling can execute arbitrary code; only use this on cache files
    written by this module, never on untrusted data.
    """
    with open(path, 'rb') as fp:
        return pkl.load(fp)


def _read_pid_data(vars_to_read, data_path, pid, newres):
    """Load one pid's variables and return ``(data, root)``.

    ``root`` is the directory under which ``<pid>/knn=.../HK`` (the spectral
    cache) lives: ``data_path`` normally, ``<data_path>/<newres>`` when a new
    resolution is requested.
    """
    if newres is None:
        return utils.read_pickle(vars_to_read, os.path.join(data_path, pid)), data_path
    # The per-pid pickle is not read here: the previous code read it and then
    # immediately overwrote it with the new-resolution data.
    # NOTE(review): ``utils.new_res`` receives the dataset root, not the pid
    # directory (unchanged from before); confirm it returns per-pid data.
    return utils.new_res(vars_to_read, data_path, newres), os.path.join(data_path, str(newres))


def _build_graph_spectra(data, knn, weight, eigs_path):
    """Build kNN edges + GFT eigenpairs for mesh 0, cache them, and return them.

    Returns ``(edges, vecs, vals, weights)``. Each is a list with one entry per
    processed mesh. ``weights`` stays empty when ``weight == ''``.
    """
    edges, vecs, vals, adj, lap, weights = [], [], [], [], [], []

    print('Making {} edges and eigen vectors'.format(1))
    # Only the first mesh is processed; presumably every simulation of a pid
    # shares one mesh -- TODO confirm.
    for i in range(1):
        print(i)
        edges.append(utils.get_knn_edge_index(data['x'][i], knn))
        if weight != '':
            # HK weights need data['diffusion'], i.e. withDiff == 'withDiff'.
            weights.append(utils.get_HK_weights(data['diffusion'][i], data['x'][i], edges[i]))
            eig_vals, eig_vecs, adj_i, lap_i = utils.get_weighted_sparse_gft_eig_vec(edges[i], weights[i])
            adj.append(adj_i)
            lap.append(lap_i)
        else:
            eig_vals, eig_vecs = utils.get_gft_eig_vec(edges[i])
        vecs.append(eig_vecs)
        vals.append(eig_vals)

    # Create the cache directory once, after the loop. It used to be created
    # inside the loop, which raises FileExistsError for a second mesh.
    os.makedirs(eigs_path, exist_ok=True)
    for name, obj in (('edges', edges), ('vecs', vecs), ('vals', vals), ('adj', adj), ('lap', lap)):
        _dump_pkl(obj, os.path.join(eigs_path, name + '.pkl'))
    if weight != '':
        _dump_pkl(weights, os.path.join(eigs_path, 'weights.pkl'))
    return edges, vecs, vals, weights


def _load_graph_spectra(eigs_path, weight):
    """Load cached ``(edges, vecs, vals, weights)`` from ``eigs_path``.

    ``weights`` is ``1`` when ``weight == ''``.
    NOTE(review): a freshly built cache yields ``weights == []`` in that case
    instead; confirm which ``utils.generate_torchgeom_dataset`` expects.
    """
    edges = _load_pkl(os.path.join(eigs_path, 'edges.pkl'))
    vecs = _load_pkl(os.path.join(eigs_path, 'vecs.pkl'))
    vals = _load_pkl(os.path.join(eigs_path, 'vals.pkl'))
    weights = _load_pkl(os.path.join(eigs_path, 'weights.pkl')) if weight != '' else 1
    return edges, vecs, vals, weights


def _build_dataset(data_path, knn, modes, weight, fibers, withDiff, num_samps, newres, extra_vars=()):
    """Return the flat list of graph samples for every pid directory in ``data_path``."""
    vars_to_read = ['t', 'x', 'u', *extra_vars]
    if withDiff == 'withDiff':
        vars_to_read.append('diffusion')
    if fibers == True:  # noqa: E712 -- keep original semantics (True/1 only)
        vars_to_read.append('fibers')

    dataset = []
    # Sorted so the sample order (which matters for the non-shuffled test
    # loader) does not depend on filesystem listing order.
    for pid in sorted(os.listdir(data_path)):
        print(pid)
        # Per-pid root: data_path itself is never reassigned, so paths do not
        # accumulate '<newres>/<newres>/...' across iterations.
        data, root = _read_pid_data(vars_to_read, data_path, pid, newres)
        num_sims = len(data['u']) if num_samps is None else num_samps

        eigs_path = os.path.join(root, pid, 'knn=' + str(knn) + weight, 'HK')
        if os.path.isdir(eigs_path):
            edges, vecs, vals, weights = _load_graph_spectra(eigs_path, weight)
        else:
            edges, vecs, vals, weights = _build_graph_spectra(data, knn, weight, eigs_path)

        dataset.extend(utils.generate_torchgeom_dataset(
            data, num_sims, withDiff, fibers, edges, vecs, vals, weights, modes, grad=True, pid=pid))
    return dataset


def _make_loader(dataset, batch_size, shuffle):
    """Wrap ``dataset`` in a DataLoader; ``batch_size=None`` means one full batch."""
    if batch_size is None:
        batch_size = len(dataset)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


def get_data_loader(data_path, batch_size, knn, modes, downsample=1, weight='', fibers=False, withDiff=None, num_samps=None, newres=None):
    """
    Build a shuffling training DataLoader over every pid directory in ``data_path``.

    Reads 't', 'x', 'u' and 'UAC' (plus 'diffusion' / 'fibers' when requested).
    kNN edges and graph-Fourier eigenpairs are computed for the first mesh.
    They are cached under ``<root>/<pid>/knn=<knn><weight>/HK`` and reused on
    later calls.

    data_path  : root directory; each entry is treated as one pid.
    batch_size : loader batch size; ``None`` puts the whole dataset in one batch.
    knn        : number of neighbours of the kNN graph.
    modes      : forwarded to ``utils.generate_torchgeom_dataset``.
    downsample : currently ignored (the downsampling code is disabled).
    weight     : '' for an unweighted graph; otherwise HK edge weights are built
                 from ``data['diffusion']`` and the string is appended to the
                 cache directory name.
    fibers     : if True, also read 'fibers'.
    withDiff   : 'withDiff' to also read 'diffusion'.
    num_samps  : simulations per pid; ``None`` uses ``len(data['u'])``.
    newres     : if not None, data comes from ``utils.new_res`` and the cache
                 root becomes ``<data_path>/<newres>``.

    returns : torch_geometric ``DataLoader`` with ``shuffle=True``.
    """
    dataset = _build_dataset(data_path, knn, modes, weight, fibers, withDiff, num_samps, newres,
                             extra_vars=('UAC',))
    return _make_loader(dataset, batch_size, shuffle=True)


def get_test_data_loader(data_path, batch_size, knn, modes, downsample=1, weight='', fibers=False, withDiff=None, num_samps=None, newres=None):
    """
    Build a non-shuffling evaluation DataLoader over every pid directory in ``data_path``.

    Identical to ``get_data_loader`` except that 'UAC' is not read and the
    loader keeps a deterministic order (pids sorted, ``shuffle=False``).
    ``downsample`` is currently ignored.
    """
    dataset = _build_dataset(data_path, knn, modes, weight, fibers, withDiff, num_samps, newres)
    return _make_loader(dataset, batch_size, shuffle=False)
    
    
# def get_data_loader(data_path, batch_size, knn, modes, downsample=1, weight='', bd_conditions=None, withDiff=None):
#     vars_to_read = ['t', 'x', 'u']
#     if bd_conditions:
#         vars_to_read.append('bcs_dicts')
#
#     if withDiff == 'withDiff':
#         vars_to_read.append('diffusion')
#
#     data = utils.read_pickle(vars_to_read, data_path)
#     if downsample > 1:
#         data['t'] = data['t'][:, np.arange(0, data['t'].shape[1], downsample)]  # / 500 for heart
#         data['u'] = data['u'][:, np.arange(0, data['u'].shape[1], downsample)]  # / 100 for heart
#     num_sims = len(data['x'])
#     eigs_path = os.path.join(data_path, 'knn='+str(knn)+weight, 'HK')
#     have_eigs = os.path.isdir(eigs_path)
#     if have_eigs == False:
#         edges = list()
#         vecs = list()
#         vals = list()
#         # vecsx = list()
#         # vecsy = list()
#         weights = list()
#         # weightsx = list()
#         # weightsy = list()
#
#         print('Making {} edges and eigen vectors'.format(num_sims))
#         for i in range(num_sims):
#             print(i)
#             edges.append(utils.get_knn_edge_index(data['x'][i], knn))
#             if weight != '':
#                 tempWeights = utils.get_HK_weights(data['diffusion'][i], data['x'][i], edges[i])
#                 # weightsx.append(tempWeights[:, 0])
#                 # weightsy.append(tempWeights[:, 1])
#                 weights.append(tempWeights)
#                 # vecsx.append(utils.get_weighted_sparse_gft_eig_vec(edges[i], weightsx[i]))
#                 # vecsy.append(utils.get_weighted_sparse_gft_eig_vec(edges[i], weightsy[i]))
#                 qqq, www = utils.get_weighted_sparse_gft_eig_vec(edges[i], weights[i])
#                 vecs.append(www)
#                 vals.append(qqq)
#             else:
#                 qqq, www = utils.get_gft_eig_vec(edges[i])
#                 vecs.append(www)
#                 vals.append(qqq)
#             # if not weights:
#             #     wpass = None
#
#         os.makedirs(eigs_path)
#         with open(os.path.join(eigs_path, 'edges.pkl'), 'wb') as fp:
#             pkl.dump(edges, fp)
#         # with open(os.path.join(eigs_path, 'vecsx.pkl'), 'wb') as fp:
#         #     pkl.dump(vecsx, fp)
#         # with open(os.path.join(eigs_path, 'vecsy.pkl'), 'wb') as fp:
#         #     pkl.dump(vecsy, fp)
#         with open(os.path.join(eigs_path, 'vecs.pkl'), 'wb') as fp:
#             pkl.dump(vecs, fp)
#         with open(os.path.join(eigs_path, 'vals.pkl'), 'wb') as fp:
#             pkl.dump(vals, fp)
#         if weight != '':
#             with open(os.path.join(eigs_path, 'weights.pkl'), 'wb') as fp:
#                 pkl.dump(weights, fp)
#             # with open(os.path.join(eigs_path, 'weightsx.pkl'), 'wb') as fp:
#             #     pkl.dump(weightsx, fp)
#             # with open(os.path.join(eigs_path, 'weightsy.pkl'), 'wb') as fp:
#             #     pkl.dump(weightsy, fp)
#     else:
#         with open(os.path.join(eigs_path, 'edges.pkl'), 'rb') as fp:
#             edges = pkl.load(fp)
#         # with open(os.path.join(eigs_path, 'vecsx.pkl'), 'rb') as fp:
#         #     vecsx = pkl.load(fp)
#         # with open(os.path.join(eigs_path, 'vecsy.pkl'), 'rb') as fp:
#         #     vecsy = pkl.load(fp)
#         with open(os.path.join(eigs_path, 'vecs.pkl'), 'rb') as fp:
#             vecs = pkl.load(fp)
#         with open(os.path.join(eigs_path, 'vals.pkl'), 'rb') as fp:
#             vals = pkl.load(fp)
#         # with open(os.path.join(eigs_path, 'weights.pkl'), 'rb') as fp:
#         #     weight = pkl.load(fp)
#
#     dataset = utils.generate_torchgeom_dataset(data, num_sims, bd_conditions, withDiff, edges, vecs, vals, modes)
#     if batch_size is None:
#         batch_size = len(dataset)
#     else:
#         batch_size = batch_size
#
#     loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
#
#     return loader

# def graph_grad_torch(sig, adj):
#     adj_norm = torch.tensor(torch.nan_to_num(adj/adj))
#     # h1 = (adj_norm.tranpose(0,1)*sig.tanspose(0,1))
#     # h1 = sig-h1
#     # h1 = torch.abs(h1*adj)
#     return torch.sum(torch.abs((adj.T * (sig.squeeze()-(adj_norm.T * sig.permute(1, 2, 0)).T).T).T), axis=1)

def lag(x,y):
    """
    Estimate the lag between two signals from their full cross-correlation.

    x    : 1st signal
    y    : 2nd signal

    returns
    lags : the lag (in samples) at which the cross-correlation of x with y
           peaks; positive when x is delayed relative to y. For 1-D inputs
           this is a single value, so per-node lags require calling this per node.
    """
    lag_axis = signal.correlation_lags(len(x), len(y), mode="full")
    xcorr = signal.correlate(x, y, mode="full")
    peak = np.argmax(xcorr)
    return lag_axis[peak]



def weighted_graph_grad_old(sig, edges, edge_weights):
    """
    Legacy node-wise weighted gradient magnitude.

    For every edge e (source ``edges[0][e]``, target ``edges[1][e]``) compute
    ``sqrt(edge_weights[e]) * |sig[:, target] - sig[:, source]|``. Then sum
    those values per source node with ``scatter_add``. The result gets a
    trailing singleton dimension.
    """
    per_edge = torch.sqrt(edge_weights).squeeze() * torch.abs(sig[:, edges[1]] - sig[:, edges[0]])
    # as_tensor reuses edges[0] when it is already a tensor. torch.tensor()
    # would copy it and emit a copy-construct UserWarning on every call.
    src = torch.as_tensor(edges[0])
    return scatter_add(per_edge.T, src, dim=0).unsqueeze(dim=-1)

def weighted_graph_grad(sig, edges, edge_weights):
    """
    Per-edge weighted differences of ``sig``.

    For each edge e with source ``edges[0][e]`` and target ``edges[1][e]``,
    returns ``sqrt(edge_weights[e]) * (sig[:, target] - sig[:, source])``.
    The result is transposed so the edge axis comes first.
    """
    src_idx, dst_idx = edges[0], edges[1]
    scale = torch.sqrt(edge_weights).squeeze()
    diffs = sig[:, dst_idx] - sig[:, src_idx]
    weighted = scale * diffs
    return weighted.T

def train_fn(engine, batch, model, optimizer, bd_conditions, device, loss_criterion, eigenpairs, withDiff=None, max_t=None):
    """
    Run one optimisation step on ``batch`` and return the scalar training loss.

    ``engine``, ``bd_conditions`` and ``withDiff`` are accepted to match the
    other step functions' signatures but are not used.

    The ODE function's parameters come from ``eigenpairs[batch.pid[0]]``.
    Indices 0, 1 and 2 map to 'redfor', 'redvals' and 'UAC'. This presumes
    every sample in the batch shares the first pid's mesh -- TODO confirm.

    The loss is the sum of two terms of the form
        mean( sqrt( sum_dim0 loss_criterion(pred, target) / sum_dim0 target**2 ) )
    The first term compares the states. The second compares their weighted
    edge differences (``weighted_graph_grad``) against ``batch.grad_y``.
    ``loss_criterion`` is presumably element-wise (e.g. ``reduction='none'``
    squared error), which makes each term a relative L2 error -- confirm.
    """
    model.train()
    optimizer.zero_grad()

    spectral = eigenpairs[batch.pid[0]]
    model.odefunc.update_params({
        'redfor': spectral[0],
        'redvals': spectral[1],
        'UAC': spectral[2],
    })

    y0 = batch.x.to(device)
    times = batch.t[0:max_t].to(device)
    y_pd = model(y0, times)
    y_gt = batch.y.transpose(0, 1).to(device)

    # Predicted edge gradients vs. the precomputed ground-truth ones.
    grad_y_pd = weighted_graph_grad(
        y_pd.squeeze(), batch.edge_index.to(device), batch.edge_weights.to(device)
    ).to(device)
    grad_y_gt = batch.grad_y.to(device)

    def relative_error(pred, target):
        numerator = torch.sum(loss_criterion(pred, target), dim=0)
        denominator = torch.sum(target ** 2, dim=0)
        return torch.mean(torch.sqrt(numerator / denominator))

    loss = relative_error(y_pd, y_gt) + relative_error(grad_y_pd, grad_y_gt)
    loss.backward()
    optimizer.step()
    return loss.item()


def validation_fn(engine, batch, model, bd_conditions, device, eigenpairs, withDiff=None, max_t=None):
    """
    Evaluate ``model`` on one batch with gradients disabled.

    ``engine``, ``bd_conditions`` and ``withDiff`` are accepted for signature
    compatibility but are not used. The ODE function is configured from
    ``eigenpairs[batch.pid[0]]`` (indices 0, 1, 2 -> 'redfor', 'redvals', 'UAC').

    returns
    (y_pd, y_gt) : the model's trajectory from ``batch.x`` over
                   ``batch.t[0:max_t]``, and ``batch.y`` with dims 0 and 1
                   swapped; both on ``device``.
    """
    model.eval()
    with torch.no_grad():
        spectral = eigenpairs[batch.pid[0]]
        model.odefunc.update_params({
            'redfor': spectral[0],
            'redvals': spectral[1],
            'UAC': spectral[2],
        })
        prediction = model(batch.x.to(device), batch.t[0:max_t].to(device))
        target = batch.y.transpose(0, 1).to(device)
    return prediction, target

def train_validation_fn(engine, batch, model, bd_conditions, device, eigenpairs, withDiff=None, max_t=None):
    """
    Evaluate ``model`` on a (training-set) batch without tracking gradients.

    Behaves exactly like ``validation_fn``. ``engine``, ``bd_conditions`` and
    ``withDiff`` are unused. The ODE function receives 'redfor', 'redvals' and
    'UAC' from entries 0, 1 and 2 of ``eigenpairs[batch.pid[0]]``.

    returns
    (y_pd, y_gt) : prediction over ``batch.t[0:max_t]`` and ``batch.y`` with
                   dims 0 and 1 swapped, both on ``device``.
    """
    model.eval()
    with torch.no_grad():
        redfor, redvals, uac = (eigenpairs[batch.pid[0]][k] for k in range(3))
        model.odefunc.update_params(dict(redfor=redfor, redvals=redvals, UAC=uac))

        initial_state = batch.x.to(device)
        time_grid = batch.t[0:max_t].to(device)
        y_pd = model(initial_state, time_grid)
        y_gt = batch.y.transpose(0, 1).to(device)
        return y_pd, y_gt
        
def higher_time_resolution_fn(engine, batch, model, bd_conditions, device, eigenpairs, withDiff=None, max_t=None, dt=.001):
    """
    Integrate ``model`` on a uniform time grid of spacing ``dt`` (gradients off).

    The grid spans ``batch.t[0:max_t]`` from its first to its last value, with
    ``round((t[-1] - t[0]) / dt) + 1`` points. ``engine``, ``bd_conditions``
    and ``withDiff`` are unused. The ODE function receives 'redfor', 'redvals'
    and 'UAC' from ``eigenpairs[batch.pid[0]]``, the same as the other step
    functions. 'UAC' was previously missing here.

    dt      : desired time step of the new grid (default 0.001, the value that
              used to be hard-coded).

    returns : the model's prediction from ``batch.x`` on the fine grid. No
              ground truth exists at this resolution.
    """
    model.eval()
    with torch.no_grad():
        spectral = eigenpairs[batch.pid[0]]
        model.odefunc.update_params({
            'redfor': spectral[0],
            'redvals': spectral[1],
            'UAC': spectral[2],
        })
        y0 = batch.x.to(device)
        t = batch.t[0:max_t]
        # round(), not int(): the float quotient is often just below an
        # integer (e.g. 0.3 / 0.001 == 299.99999999999994), and truncating it
        # drops a grid point and stretches the spacing beyond dt.
        n_intervals = round(float((t[-1] - t[0]) / dt))
        t = torch.linspace(t[0], t[-1], steps=n_intervals + 1).to(device)
        return model(y0, t)
