"""
This file contains functions useful for training and validation, including the training loop,
the validation loop, and data generation.
"""

from torch_geometric.loader import DataLoader
import numpy as np
import torch
import utils
import os
import pickle as pkl

def _build_graph_spectra(data, num_sims, knn, weight):
    """Build kNN edges and graph-Fourier eigenpairs for the first ``num_sims`` simulations.

    Returns ``(edges, vecs, vals, weights)`` as parallel per-simulation lists.
    ``weights`` stays empty unless ``weight`` is non-empty, in which case the
    edge weights come from ``utils.get_HK_weights`` applied to ``data['diffusion']``.
    """
    edges, vecs, vals, weights = [], [], [], []
    print('Making {} edges and eigen vectors'.format(num_sims))
    for i in range(num_sims):
        print(i)
        coords = data['x'][i]
        edge_index = utils.get_knn_edge_index(coords, knn)
        edges.append(edge_index)
        if weight != '':
            edge_weights = utils.get_HK_weights(data['diffusion'][i], coords, edge_index)
            weights.append(edge_weights)
            eig_vals, eig_vecs = utils.get_weighted_sparse_gft_eig_vec(edge_index, edge_weights)
        else:
            eig_vals, eig_vecs = utils.get_gft_eig_vec(edge_index)
        vecs.append(eig_vecs)
        vals.append(eig_vals)
    return edges, vecs, vals, weights


def _write_cached_pickle(directory, name, obj):
    """Pickle ``obj`` to ``<directory>/<name>.pkl``."""
    with open(os.path.join(directory, name + '.pkl'), 'wb') as fp:
        pkl.dump(obj, fp)


def _read_cached_pickle(directory, name):
    """Load and return the object pickled in ``<directory>/<name>.pkl``.

    NOTE: ``pickle.load`` can execute arbitrary code; only read caches written
    by this code, never files obtained from an untrusted source.
    """
    with open(os.path.join(directory, name + '.pkl'), 'rb') as fp:
        return pkl.load(fp)


def get_data_loader(data_path, batch_size, knn, modes, downsample=1, weight='', fibers=False, bd_conditions=None, withDiff=None, num_samps=None):
    """Build a shuffled ``DataLoader`` over the simulations stored in ``data_path``.

    The raw variables are read with ``utils.read_pickle`` and optionally
    downsampled in time. For each simulation, a kNN edge index and the
    eigenvalues/eigenvectors returned by the ``utils`` GFT helpers are then
    obtained. These are cached under ``<data_path>/knn=<knn><weight>/HK`` and
    reused on later calls.

    Args:
        data_path: directory holding the pickled simulation data.
        batch_size: graphs per batch; ``None`` puts the whole dataset in one batch.
        knn: number of neighbours passed to ``utils.get_knn_edge_index``.
        modes: forwarded to ``utils.generate_torchgeom_dataset``.
        downsample: keep every ``downsample``-th column (axis 1) of ``t`` and ``u``.
        weight: cache-directory suffix; when non-empty, the graphs are weighted
            via ``utils.get_HK_weights`` using the ``diffusion`` variable.
        fibers: if truthy, also read the ``fibers`` variable.
        bd_conditions: if truthy, also read ``bcs_dicts``.
        withDiff: ``'withDiff'`` to also read ``diffusion`` for the dataset.
        num_samps: number of simulations to use; ``None`` uses ``len(data['x'])``.

    Returns:
        A ``torch_geometric.loader.DataLoader`` with ``shuffle=True``.
    """
    eigs_path = os.path.join(data_path, 'knn=' + str(knn) + weight, 'HK')
    have_eigs = os.path.isdir(eigs_path)

    vars_to_read = ['t', 'x', 'u']
    if bd_conditions:
        vars_to_read.append('bcs_dicts')
    # 'diffusion' is needed as a dataset input (withDiff). When weighted spectra
    # still have to be computed, it is also needed for the edge weights. Reading
    # it only for withDiff would leave data['diffusion'] missing in that case.
    if withDiff == 'withDiff' or (weight != '' and not have_eigs):
        vars_to_read.append('diffusion')
    if fibers:
        vars_to_read.append('fibers')

    data = utils.read_pickle(vars_to_read, data_path)
    if downsample > 1:
        # Keep every `downsample`-th entry along axis 1. Fancy indexing yields a
        # copy, so the full-resolution arrays can be freed.
        # NOTE(review): earlier notes mention scaling t by 1/500 and u by 1/100
        # for heart data; no scaling is applied here.
        for key in ('t', 'u'):
            data[key] = data[key][:, np.arange(0, data[key].shape[1], downsample)]

    num_sims = len(data['x']) if num_samps is None else num_samps

    # NOTE: the cache is keyed only on knn and weight. An existing cache is
    # reused as-is, even if it was built for a different number of simulations.
    if have_eigs:
        edges = _read_cached_pickle(eigs_path, 'edges')
        vecs = _read_cached_pickle(eigs_path, 'vecs')
        vals = _read_cached_pickle(eigs_path, 'vals')
    else:
        edges, vecs, vals, weights = _build_graph_spectra(data, num_sims, knn, weight)
        # The directory is created only after the computation has succeeded,
        # so a failure does not leave an empty dir that would look like a cache.
        os.makedirs(eigs_path, exist_ok=True)
        _write_cached_pickle(eigs_path, 'edges', edges)
        _write_cached_pickle(eigs_path, 'vecs', vecs)
        _write_cached_pickle(eigs_path, 'vals', vals)
        if weight != '':
            _write_cached_pickle(eigs_path, 'weights', weights)

    dataset = utils.generate_torchgeom_dataset(data, num_sims, bd_conditions, withDiff, fibers, edges, vecs, vals, modes)
    if batch_size is None:
        batch_size = len(dataset)

    return DataLoader(dataset, batch_size=batch_size, shuffle=True)


def train_fn(engine, batch, model, optimizer, bd_conditions, device, loss_criterion, withDiff=None, max_t=None):
    """Run one optimisation step on ``batch`` and return the loss.

    Args:
        engine: not used in this function (kept for the caller's signature).
        batch: object exposing ``pos``, ``fibers``, ``redfor``, ``redvals``,
            ``x`` (initial state), ``t`` (time points), ``y`` (target) and,
            when ``withDiff == 'withDiff'``, ``diffusion``.
        model: callable as ``model(y0, t)`` with an ``odefunc.update_params(dict)`` hook.
        optimizer: optimizer stepping the model's parameters.
        bd_conditions: not used in this function.
        device: device every input tensor is moved to.
        loss_criterion: ``loss_criterion(prediction, target)`` -> scalar tensor.
        withDiff: ``'withDiff'`` to also pass ``batch.diffusion`` to the ODE function.
        max_t: keep only the first ``max_t`` time points; ``None`` keeps all.

    Returns:
        The loss as a Python float (``loss.item()``).
    """
    model.train()
    optimizer.zero_grad()

    params_dict = {
        'pos': batch.pos.to(device),
        'fibers': batch.fibers.to(device),
        'redfor': batch.redfor.to(device),
        'redvals': batch.redvals.to(device),
    }
    if withDiff == 'withDiff':
        params_dict['diffusion'] = batch.diffusion.to(device)
    model.odefunc.update_params(params_dict)

    y0 = batch.x.to(device)
    t = batch.t[0:max_t].to(device)
    y_pd = model(y0, t)
    # Trim the target to the same max_t horizon as `t` so it lines up with the
    # prediction. This assumes dim 0 of both is time, i.e. the model returns one
    # state per entry of `t`; confirm against the model.
    y_gt = batch.y.transpose(0, 1)[0:max_t].to(device)

    loss = loss_criterion(y_pd, y_gt)
    loss.backward()
    optimizer.step()

    return loss.item()


def validation_fn(engine, batch, model, bd_conditions, device, withDiff=None, max_t=None):
    """Evaluate ``model`` on ``batch`` without gradient tracking.

    Args:
        engine: not used in this function (kept for the caller's signature).
        batch: object exposing ``pos``, ``fibers``, ``redfor``, ``redvals``,
            ``x`` (initial state), ``t`` (time points), ``y`` (target) and,
            when ``withDiff == 'withDiff'``, ``diffusion``.
        model: callable as ``model(y0, t)`` with an ``odefunc.update_params(dict)`` hook.
        bd_conditions: not used in this function.
        device: device every input tensor is moved to.
        withDiff: ``'withDiff'`` to also pass ``batch.diffusion`` to the ODE function.
        max_t: keep only the first ``max_t`` time points; ``None`` keeps all.

    Returns:
        ``(y_pd, y_gt)``: the model prediction and the transposed target, both
        restricted to the first ``max_t`` time points.
    """
    model.eval()
    with torch.no_grad():
        params_dict = {
            'pos': batch.pos.to(device),
            'fibers': batch.fibers.to(device),
            'redfor': batch.redfor.to(device),
            'redvals': batch.redvals.to(device),
        }
        if withDiff == 'withDiff':
            params_dict['diffusion'] = batch.diffusion.to(device)
        model.odefunc.update_params(params_dict)

        y0 = batch.x.to(device)
        t = batch.t[0:max_t].to(device)
        y_pd = model(y0, t)
        # Trim the target to the same max_t horizon as `t` so it lines up with
        # the prediction. This assumes dim 0 of both is time; confirm against
        # the model.
        y_gt = batch.y.transpose(0, 1)[0:max_t].to(device)
        return y_pd, y_gt
