"""
This file contains functions for training and validation, including the training step,
the validation step, and data-loader generation.
"""

from torch_geometric.loader import DataLoader
import numpy as np
import torch
import utils
import os
import pickle as pkl
from torch_scatter import scatter_add

def _build_graph_cache(data, num_sims, knn, weight):
    """Compute kNN edges and graph eigen-decompositions for the first ``num_sims`` simulations."""
    cache = {'edges': [], 'vecs': [], 'vals': [], 'adj': [], 'lap': [], 'weights': []}

    print('Making {} edges and eigen vectors'.format(num_sims))
    for i in range(num_sims):
        print(i)
        edge_index = utils.get_knn_edge_index(data['x'][i], knn)
        cache['edges'].append(edge_index)
        if weight != '':
            # Weighted graph: needs data['diffusion'], which is only requested from
            # utils.read_pickle/new_res when withDiff == 'withDiff' (verify in caller).
            edge_weights = utils.get_HK_weights(data['diffusion'][i], data['x'][i], edge_index)
            cache['weights'].append(edge_weights)
            eig_vals, eig_vecs, adj_i, lap_i = utils.get_weighted_sparse_gft_eig_vec(edge_index, edge_weights)
            cache['adj'].append(adj_i)
            cache['lap'].append(lap_i)
        else:
            eig_vals, eig_vecs = utils.get_gft_eig_vec(edge_index)
        cache['vecs'].append(eig_vecs)
        cache['vals'].append(eig_vals)
    return cache


def _save_graph_cache(eigs_path, cache, weight):
    """Pickle the computed graph quantities as ``<name>.pkl`` files under ``eigs_path``."""
    os.makedirs(eigs_path, exist_ok=True)
    names = ['edges', 'vecs', 'vals', 'adj', 'lap']
    if weight != '':
        names.append('weights')
    for name in names:
        with open(os.path.join(eigs_path, name + '.pkl'), 'wb') as fp:
            pkl.dump(cache[name], fp)


def _load_graph_cache(eigs_path, weight):
    """Load graph quantities previously written by ``_save_graph_cache``.

    NOTE: uses pickle; only point this at cache directories this code created.
    """
    names = ['edges', 'vecs', 'vals', 'adj']
    if weight != '':
        names.append('weights')
    cache = {}
    for name in names:
        with open(os.path.join(eigs_path, name + '.pkl'), 'rb') as fp:
            cache[name] = pkl.load(fp)
    if weight == '':
        cache['weights'] = 1
    return cache


def get_data_loader(data_path, batch_size, knn, modes, downsample=1, weight='', fibers=False, withDiff=None, num_samps=None, newres=None):
    """Build a shuffled PyTorch Geometric ``DataLoader`` over the simulations in ``data_path``.

    Graph quantities (kNN edges, eigenvalues/eigenvectors, adjacency and, for
    weighted graphs, edge weights) are cached in
    ``<data_path>[/<newres>]/knn=<knn><weight>/HK``. They are computed and
    written on the first call and loaded from disk on later calls.

    Args:
        data_path: directory holding the pickled simulation data read via ``utils``.
        batch_size: loader batch size; ``None`` puts the whole dataset in one batch.
        knn: number of neighbours used to build the kNN graph.
        modes: forwarded to ``utils.generate_torchgeom_dataset``.
        downsample: unused in this function.
        weight: cache-name suffix; ``''`` selects unweighted graphs, anything
            else selects heat-kernel-weighted graphs (``utils.get_HK_weights``).
        fibers: if truthy, the ``'fibers'`` variable is also read.
        withDiff: if ``'withDiff'``, the ``'diffusion'`` variable is also read.
        num_samps: number of simulations to use; ``None`` uses ``len(data['x'])``.
        newres: if not None, data comes from ``utils.new_res`` and the cache is
            stored under ``<data_path>/<newres>``.

    Returns:
        ``DataLoader(dataset, batch_size=batch_size, shuffle=True)``.
    """
    vars_to_read = ['t', 'x', 'u']
    if withDiff == 'withDiff':
        vars_to_read.append('diffusion')
    if fibers:
        vars_to_read.append('fibers')

    # Read the data exactly once: either at a new resolution or as stored.
    if newres is not None:
        data = utils.new_res(vars_to_read, data_path, newres)
        data_path = os.path.join(data_path, str(newres))
    else:
        data = utils.read_pickle(vars_to_read, data_path)

    num_sims = len(data['x']) if num_samps is None else num_samps

    eigs_path = os.path.join(data_path, 'knn=' + str(knn) + weight, 'HK')
    if os.path.isdir(eigs_path):
        cache = _load_graph_cache(eigs_path, weight)
    else:
        cache = _build_graph_cache(data, num_sims, knn, weight)
        _save_graph_cache(eigs_path, cache, weight)
    # NOTE(review): for unweighted graphs the freshly computed path passes
    # weights=[] while the cached path passes weights=1; confirm which one
    # utils.generate_torchgeom_dataset expects.

    dataset = utils.generate_torchgeom_dataset(
        data, num_sims, withDiff, fibers,
        cache['edges'], cache['vecs'], cache['vals'], cache['weights'],
        modes, grad=True,
    )
    if batch_size is None:
        batch_size = len(dataset)

    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def weighted_graph_grad_old(sig, edges, edge_weights):
    """Legacy variant: per-node sum of absolute weighted edge differences.

    For each edge ``e`` computes
    ``sqrt(w_e) * |sig[:, edges[1][e]] - sig[:, edges[0][e]]|`` and sums the
    result onto the source node ``edges[0][e]`` with ``scatter_add``.

    For a 2-D ``sig`` of shape ``(B, N)`` the result has shape
    ``(max(edges[0]) + 1, B, 1)``.
    """
    diff = torch.sqrt(edge_weights).squeeze() * torch.abs(sig[:, edges[1]] - sig[:, edges[0]])
    # torch.as_tensor avoids the copy-construct warning torch.tensor() emits when
    # edges is already a tensor; scatter_add needs an int64 index on the same
    # device as the data it scatters.
    index = torch.as_tensor(edges[0], dtype=torch.long, device=diff.device)
    return scatter_add(diff.T, index, dim=0).unsqueeze(dim=-1)

def weighted_graph_grad(sig, edges, edge_weights):
    """Weighted graph gradient of ``sig`` along every edge.

    Edge ``e`` contributes ``sqrt(w_e) * (sig[:, edges[1][e]] - sig[:, edges[0][e]])``.

    Args:
        sig: tensor indexed as ``sig[:, node]``; for a 2-D ``sig`` of shape
            ``(B, N)`` the result has shape ``(E, B)``.
        edges: pair of node-index sequences ``edges[0]`` (tail) and ``edges[1]`` (head).
        edge_weights: per-edge weights; size-1 dims are squeezed and the square
            root is taken, so they should be non-negative.

    Returns:
        The weighted differences transposed with ``.T``.
    """
    tail_vals = sig[:, edges[0]]
    head_vals = sig[:, edges[1]]
    scale = torch.sqrt(edge_weights).squeeze()
    weighted = scale * (head_vals - tail_vals)
    return weighted.T

class rel_h1(torch.nn.Module):
    """Relative discrete H1 loss on a graph.

    Sums the relative error of the values and the relative error of their
    weighted graph gradients (both Euclidean norms along dim 0), then averages.
    """

    def __init__(self):
        super().__init__()

    def forward(self, u_pd, u, grad_u, edges, edge_weights):
        """Return ``mean(|u_pd - u| / |u| + |grad(u_pd) - grad_u| / |grad_u|)``.

        ``grad(u_pd)`` is ``weighted_graph_grad(u_pd.squeeze(), edges, edge_weights)``.
        No epsilon is added, so an all-zero ``u`` or ``grad_u`` slice yields inf/nan.
        """
        value_err = torch.linalg.norm(u_pd - u, dim=0) / torch.linalg.norm(u, dim=0)
        pred_grad = weighted_graph_grad(u_pd.squeeze(), edges, edge_weights)
        grad_err = torch.linalg.norm(pred_grad - grad_u, dim=0) / torch.linalg.norm(grad_u, dim=0)
        return (value_err + grad_err).mean()

def train_fn(engine, batch, model, optimizer, bd_conditions, device, loss_criterion, withDiff=None, max_t=None):
    """Run one optimisation step on ``batch`` and return a monitoring loss.

    The parameters are updated with the relative H1 loss (``rel_h1``); the value
    returned for logging is ``loss_criterion(y_pd, y_gt)`` on the same prediction.

    Args:
        engine: unused in this function.
        batch: graph batch providing ``pos``, ``fibers``, ``redfor``, ``redvals``,
            ``x``, ``t``, ``y``, ``grad_y``, ``edge_index``, ``edge_weights`` and,
            when ``withDiff == 'withDiff'``, ``diffusion``.
        model: module whose ``odefunc.update_params`` receives the per-batch
            parameters and which is called as ``model(y0, t)``.
        optimizer: optimiser stepped once per call.
        bd_conditions: unused in this function.
        device: target device for all tensors.
        loss_criterion: callable ``(y_pd, y_gt) -> scalar tensor`` used only for reporting.
        withDiff: if ``'withDiff'``, ``batch.diffusion`` is passed to the model.
        max_t: number of leading time points of ``batch.t`` to use; ``None`` uses all.

    Returns:
        ``loss_criterion(y_pd, y_gt).item()`` as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    params_dict = {
        'pos': batch.pos.to(device),
        'fibers': batch.fibers.to(device),
        'redfor': batch.redfor.to(device),
        'redvals': batch.redvals.to(device),
    }
    if withDiff == 'withDiff':
        params_dict['diffusion'] = batch.diffusion.to(device)
    model.odefunc.update_params(params_dict)

    y0 = batch.x.to(device)
    t = batch.t[0:max_t].to(device)
    y_pd = model(y0, t)
    y_gt = batch.y.transpose(0, 1).to(device)
    grad_y_gt = batch.grad_y.to(device)

    # rel_h1 computes the gradient of y_pd itself, so no separate gradient pass is needed here.
    loss = rel_h1()(y_pd, y_gt, grad_y_gt, batch.edge_index.to(device), batch.edge_weights.to(device))
    loss.backward()
    optimizer.step()

    # Reporting only: avoid building an autograd graph for the metric.
    with torch.no_grad():
        return loss_criterion(y_pd, y_gt).item()


def validation_fn(engine, batch, model, bd_conditions, device, withDiff=None, max_t=None):
    """Evaluate ``model`` on ``batch`` without gradients.

    Args:
        engine: unused in this function.
        batch: graph batch providing ``pos``, ``fibers``, ``redfor``, ``redvals``,
            ``x``, ``t``, ``y`` and, when ``withDiff == 'withDiff'``, ``diffusion``.
        model: module whose ``odefunc.update_params`` receives the per-batch
            parameters and which is called as ``model(y0, t)``.
        bd_conditions: unused in this function.
        device: target device for all tensors.
        withDiff: if ``'withDiff'``, ``batch.diffusion`` is passed to the model.
        max_t: number of leading time points of ``batch.t`` to use; ``None`` uses all.

    Returns:
        ``(y_pd, y_gt)``: the model prediction and ``batch.y`` with its first two
        dimensions swapped, both on ``device``.
    """
    model.eval()
    with torch.no_grad():
        # Same parameter set as train_fn; 'redfor' must come from the batch.
        params_dict = {
            'pos': batch.pos.to(device),
            'fibers': batch.fibers.to(device),
            'redfor': batch.redfor.to(device),
            'redvals': batch.redvals.to(device),
        }
        if withDiff == 'withDiff':
            params_dict['diffusion'] = batch.diffusion.to(device)

        model.odefunc.update_params(params_dict)
        y0 = batch.x.to(device)
        t = batch.t[0:max_t].to(device)
        y_pd = model(y0, t)
        y_gt = batch.y.transpose(0, 1).to(device)
        return y_pd, y_gt
