import os, time, json, shutil, socket, random
from tqdm import tqdm
import sys
import argparse
import yaml
import numpy as np
import pandas as pd
import sqlite3
import torch
import torch.nn as nn
from torch.nn.utils import _stateless
import functools
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

from Critical_Parameters import get_all_initializations
from model_train import FullyConnected, model
from scipy.interpolate import interp1d

# import functorch
# from functorch import vmap, jacrev, make_functional

# disable warning spam
# functorch._C._set_vmap_fallback_warning_enabled(False)

# Filesystem anchors: everything below is resolved relative to the directory
# that contains this script (symlinks resolved), not the current working dir.
FILENAME = os.path.realpath(__file__)
CDIR, _ = os.path.split(FILENAME)
# Default download/cache location for the torchvision datasets.
DATADIR = os.path.join(CDIR, 'data')
# Default YAML file holding the training hyperparameters.
config_path = os.path.join(CDIR, 'config', 'train.yaml')


# def ntk(module: nn.Module, input1: torch.Tensor, input2: torch.Tensor,parameters: dict[str, nn.Parameter] = None,
#     compute='full') -> torch.Tensor:
def ntk(module, input1, input2, parameters=None, compute='full'):
    """Compute the empirical neural tangent kernel between two input batches.

    For every parameter tensor the Jacobian of the module output with respect
    to that tensor is evaluated on ``input1`` and on ``input2``; the two
    Jacobians are contracted over the (flattened) parameter axis and the
    per-parameter contributions are summed.

    Args:
        module: the ``nn.Module`` to differentiate. It is called functionally,
            i.e. with the tensors in ``parameters`` substituted for its own.
        input1: first batch, passed to the module as its single argument.
            The module output is expected to be 2-D, ``(N, C)``.
        input2: second batch, output expected to be ``(M, C)``.
        parameters: optional mapping ``name -> tensor`` to differentiate with
            respect to. Defaults to ``dict(module.named_parameters())``.
        compute: which part of the kernel to return:
            ``'full'``     -> shape ``(N, M, C, C)``;
            ``'trace'``    -> shape ``(N, M)``, summed over the output index;
            ``'diagonal'`` -> shape ``(N, M, C)``, same output index on both sides.

    Returns:
        torch.Tensor with the shape listed above for the chosen ``compute``.

    Raises:
        ValueError: if ``compute`` is not one of the three supported modes.
    """
    # Index legend: N/M = sample in input1/input2, a/b = output unit,
    # f = flattened parameter entry.
    contractions = {
        'full': 'Naf,Mbf->NMab',
        'trace': 'Naf,Maf->NM',
        'diagonal': 'Naf,Maf->NMa',
    }
    if compute not in contractions:
        raise ValueError(compute)
    einsum_expr = contractions[compute]

    if parameters is None:
        parameters = dict(module.named_parameters())
    keys, values = zip(*parameters.items())

    def flat_jacobians(_input):
        """Jacobian of module(_input) w.r.t. each parameter, flattened to (batch, C, F)."""

        def forward(*params):
            # Run the module with `params` standing in for its registered parameters.
            return _stateless.functional_call(module, dict(zip(keys, params)), _input)  # (batch, C)

        per_param = torch.autograd.functional.jacobian(forward, values, vectorize=True)
        # Each entry is (batch, C, *param.shape); merge the parameter axes into one.
        return [j.flatten(2) for j in per_param]

    jac1 = flat_jacobians(input1)
    # When both arguments are the very same tensor the second Jacobian would be
    # identical, so reuse it instead of paying for another forward/backward sweep.
    jac2 = jac1 if input2 is input1 else flat_jacobians(input2)

    per_param_kernels = [torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)]
    return torch.stack(per_param_kernels).sum(0)


# def empirical_ntk(net, x1, x2, compute='full'):
#     fnet, params = make_functional(net)
#     fnet_single = lambda params, x: fnet(params, x.unsqueeze(0)).squeeze(0)
#     # Compute J(x1)
#     jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
#     jac1 = [j.flatten(2) for j in jac1]
#
#     # Compute J(x2)
#     jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
#     jac2 = [j.flatten(2) for j in jac2]
#
#     # Compute J(x1) @ J(x2).T
#     einsum_expr = None
#     if compute == 'full':
#         einsum_expr = 'Naf,Mbf->NMab'
#     elif compute == 'trace':
#         einsum_expr = 'Naf,Maf->NM'
#     elif compute == 'diagonal':
#         einsum_expr = 'Naf,Maf->NMa'
#     else:
#         assert False
#
#     result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])
#     result = result.sum(0)
#     return result


def count_parameters(model, verbose=False):
    """Return how many scalar entries the trainable parameters of ``model`` hold.

    Args:
        model: any object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, e.g. an ``nn.Module``.
        verbose: accepted for call compatibility; currently has no effect.

    Returns:
        Sum of ``numel()`` over all parameters whose ``requires_grad`` is true.
    """
    return sum(
        tensor.numel()
        for _, tensor in model.named_parameters()
        if tensor.requires_grad
    )


def parse_args():
    """Build the command line interface and parse ``sys.argv``.

    Most numeric options default to a negative sentinel (``-1`` / ``-1.0``);
    the caller treats that as "not given" and falls back to the YAML config.

    Returns:
        argparse.Namespace holding one attribute per option listed below.
    """
    # One entry per option, in the order they are shown by --help.
    # Flags that are conceptually booleans are declared as ints (0 = False,
    # 1 = True) because the boolean parser was causing trouble.
    option_specs = (
        ('--config', dict(default='algebraiclu', type=str,
                          help='where the hyperparameters are stored, yaml file')),
        ('--activation_control', dict(
            default=-1.0, type=float,
            help='activation control parameter, if positive will be used instead of the config value')),
        ('--seed', dict(
            default=-1, type=int,
            help='seed for RNG, if positive will be used instead of the one from the config file')),
        ('--stats_name', dict(default='metric', type=str, help='name of table for stats db')),
        # The device is deliberately "cpu" by default rather than auto-detected.
        ('--device', dict(default="cpu", type=str, help='cuda/cpu')),
        ('--data_folder', dict(default=DATADIR, type=str, help='data folder')),
        ('--lr', dict(default=-1.0, type=float,
                      help='learning rate, if positive will be used instead of config')),
        ('--momentum', dict(default=-1.0, type=float, help='momentum')),
        ('--weight_decay', dict(default=-1.0, type=float, help='weight_decay')),
        ('--depth', dict(default=-1, type=int, help='network depth')),
        ('--width', dict(default=-1, type=int, help='network width')),
        ('--epochs', dict(default=-1, type=int, help='training epochs')),
        ('--default_init', dict(default=0, type=int, help='default initialization')),
        ('--critical', dict(default=0, type=int, help='critical initialization')),
        ('--renormalize', dict(default=0, type=int, help='data renormalization')),
        ('--task_name', dict(default='cifar10', type=str, help='task name',
                             choices=['mnist', 'cifar10'])),
        ('--comments', dict(default='NTK_minibatch', type=str,
                            help='Extra comments with influence on the code')),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()


def train(train_dataloader, net, loss_f, optimizer, device, args):
    """Run one pass of gradient updates over ``train_dataloader``.

    Every batch is flattened to ``(batch_size, -1)``, moved to ``device`` and
    shifted by its own maximum before being fed to ``net``. The loss of each
    step is appended (as a NumPy value) to ``net.performance``.

    Args:
        train_dataloader: iterable yielding ``(batch, label)`` pairs.
        net: the network being trained; must expose a ``performance`` list.
        loss_f: callable ``loss_f(prediction, label)`` returning a scalar tensor.
        optimizer: optimizer holding the parameters of ``net``.
        device: device the data (and predictions) are moved to.
        args: parsed command line namespace; not used by this function.
    """
    for inputs, targets in tqdm(train_dataloader):
        n_samples = inputs.shape[0]
        flat = inputs.reshape([n_samples, -1]).to(device)
        # Subtract the largest entry of the batch, so the maximum becomes zero.
        shifted = flat - torch.max(flat)
        targets = targets.to(device)

        prediction = net(shifted).to(device)
        loss = loss_f(prediction, targets)
        # Log the loss value before the parameters are updated.
        net.performance.append(loss.data.cpu().numpy())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def test(test_dataloader, net, loss_f, device, args):
    """Evaluate ``net`` on ``test_dataloader`` and log the loss of every batch.

    The network is switched to eval mode (and left there) and gradients are
    disabled. Unless ``'conv'`` occurs in ``args.comments`` each batch is
    flattened to ``(batch_size, -1)``; in all cases it is moved to ``device``
    and shifted by its own maximum before the forward pass. Each batch loss is
    appended (as a NumPy value) to ``net.test_performance``.

    Args:
        test_dataloader: iterable yielding ``(batch, label)`` pairs.
        net: the network to evaluate; must expose a ``test_performance`` list.
        loss_f: callable ``loss_f(prediction, label)`` returning a scalar tensor.
        device: device on which the evaluation runs.
        args: parsed command line namespace; only ``args.comments`` is read.
    """
    flatten_inputs = 'conv' not in args.comments

    net.eval()
    with torch.no_grad():
        for batch, label in test_dataloader:
            # Always move the batch to the target device. Previously this only
            # happened inside the flattening branch, so inputs stayed on the
            # CPU when 'conv' was requested.
            batch = batch.to(device)
            if flatten_inputs:
                batch = batch.reshape([batch.shape[0], -1])
            # Same max-shift as applied to the training batches.
            batch = batch - torch.max(batch)

            pred = net(batch).to(device)
            test_loss = loss_f(pred, label.to(device))
            net.test_performance.append(test_loss.detach().cpu().numpy())


def _load_config(path):
    """Parse the YAML file at ``path`` and return its content.

    A YAML syntax error is printed and re-raised, so the run stops at the real
    cause instead of failing later on an undefined config.
    """
    with open(path, "r") as stream:
        try:
            return yaml.safe_load(stream)
        except yaml.YAMLError as exc:
            print(exc)
            raise


def _select_device(requested):
    """Turn the ``--device`` string into a ``torch.device``.

    Any ``cuda*`` request falls back to the CPU when CUDA is unavailable;
    every other string is handed to ``torch.device`` as is.
    """
    if requested.startswith("cuda"):
        torch.cuda.empty_cache()
        return torch.device(requested if torch.cuda.is_available() else "cpu")
    return torch.device(requested)


def _first_indices_per_class(targets, n_classes, per_class):
    """Return the indices of the first ``per_class`` samples of each class, grouped by class."""
    labels = torch.as_tensor(targets).cpu().numpy()
    return np.concatenate([np.flatnonzero(labels == c)[:per_class] for c in range(n_classes)])


def main():
    """Train a fully connected network on MNIST or CIFAR-10 and archive the run.

    Steps: parse the command line, load the YAML config, resolve the
    hyperparameters (command line values override the config where given),
    build the network and the datasets, then for every epoch optionally record
    the empirical NTK on one sample per class, train and evaluate. Finally the
    arguments, elapsed time and hostname are written to ``results.txt`` inside
    the experiment directory, which is then zipped.
    """
    time_start = time.perf_counter()

    args = parse_args()
    print(json.dumps(vars(args), sort_keys=True, indent=4))

    # --config is either the name of a known activation (then the default
    # config file is used) or the path of a YAML file.
    if args.config in ['relu', 'gelu', 'swish', 'gumbellu', 'gudermanlu', 'algebraiclu']:
        config_path = os.path.join(CDIR, 'config', 'train.yaml')
        activation = args.config
    else:
        config_path = args.config
        activation = None

    config = _load_config(config_path)
    save_dir = config['save_dir']

    time_string = time.strftime("%Y-%m-%d--%H-%M-%S-", time.localtime())
    # Drawn before the RNGs are seeded below, so it is not tied to the seed.
    random_string = ''.join([str(r) for r in np.random.choice(10, 4)])
    # NOTE(review): the directory is named after the activation in the config
    # file, which can differ from the one selected through --config. Confirm.
    experiment_dir = os.path.join(save_dir, time_string + random_string +
                                  '-' + config['ARCHITECTURE']['activation'] +
                                  '-' + args.task_name)
    os.makedirs(experiment_dir, exist_ok=True)

    # Get cpu or gpu device for training.
    device = _select_device(args.device)
    print(f"Using {device} device")

    # architecture parameters
    architecture = config['ARCHITECTURE']

    width = architecture['width']
    if args.width > 0:
        width = args.width
    # NOTE(review): this unconditionally overrides both the config value and
    # --width with the flattened input size. Kept as is; confirm it is intended.
    width = 784 if args.task_name == 'mnist' else 1024

    depth = architecture['depth']
    if args.depth > 0:
        depth = int(args.depth)
    input_dim = 784 if args.task_name == 'mnist' else 1024
    variance_list = architecture['variance']
    classes = config['classes']

    if activation is None:
        activation = architecture['activation']

    control = architecture['activation_control']
    if args.activation_control > 0.0:
        # The command line value is given in tenths.
        control = float(args.activation_control) / 10.0
    args.activation_control = control  # store the resolved value with the results
    power = architecture['activation_power']

    # Negative command line values mean "use the config".
    epochs = config['epochs'] if args.epochs < 0.0 else args.epochs
    lr = config['lr'] if args.lr < 0.0 else args.lr
    weight_decay = config['weight_decay'] if args.weight_decay < 0.0 else args.weight_decay
    momentum = config['momentum'] if args.momentum < 0.0 else args.momentum

    # sampling parameters
    seed = config['SAMPLING']['seed']
    if args.seed > 0:
        seed = args.seed
    # NOTE(review): with the default --seed (-1) the config seed is replaced
    # by 0, so the config value is only used for --seed 0 or below -1. Confirm.
    if args.seed == -1:
        seed = 0

    # set random seed for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # variances for weights and biases
    Cw = variance_list[0]
    Cb = variance_list[1]

    Kstar = 1
    if activation in ['gelu', 'swish', 'gumbellu', 'gudermanlu', 'algebraiclu']:
        df = get_all_initializations()
        row = df.loc[df['act'] == activation]
        Kstar = control ** 2 * row['kstar'].values[0]

        if args.critical == 1:
            Cw = row['cw'].values[0]
            Cb = control ** 2 * row['cb'].values[0]

    elif activation == 'relu' and 'relukstar' in args.comments:
        df = get_all_initializations()
        Kstar = 0.1 ** 2 * df.loc[df['act'] == 'swish']['kstar'].values[0]

    if 'rescaledlr' in args.comments:
        lr /= Kstar

    print(f"Cw = {Cw}", f"Cb = {Cb}")
    # Only store one network at a time
    torch.cuda.empty_cache()

    net = model(width, depth, input_dim, classes, activation, Cw, Cb, control, power, args.default_init, args)
    net.to(device)
    net.name = (f"{activation}_CW_{Cw}_Cb_{Cb}_width_{width}_depth_{depth}"
                f"_control_{control}_seed_{seed}_{args.task_name}")
    net.args = vars(args)

    # Input pipeline: tensor conversion, optional dataset normalisation,
    # (CIFAR-10 only) grayscale conversion, optional rescaling that divides
    # the data by 1 / sqrt(Kstar).
    list_transforms = [transforms.ToTensor()]
    if args.task_name == 'mnist':
        if 'nonormalize' not in args.comments:
            list_transforms.append(transforms.Normalize((0.1307,), (0.3081,)))
        dataset_cls = torchvision.datasets.MNIST
    elif args.task_name == 'cifar10':
        if 'nonormalize' not in args.comments:
            list_transforms.append(transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)))
        list_transforms.append(transforms.Grayscale())
        dataset_cls = torchvision.datasets.CIFAR10
    else:
        raise NotImplementedError

    if args.renormalize == 1:
        list_transforms.append(transforms.Normalize(0, 1.0 / np.sqrt(Kstar + 1e-16)))
    transform = transforms.Compose(list_transforms)

    trainset = dataset_cls(root=args.data_folder, train=True, download=True, transform=transform)
    testset = dataset_cls(root=args.data_folder, train=False, download=True, transform=transform)

    n_classes = int(np.amax(torch.as_tensor(trainset.targets).cpu().numpy())) + 1

    # Keep a handle on the full training set: the NTK probe below needs its
    # `.targets`, which the Subset built in full-batch mode does not expose.
    trainset_fnet = trainset
    if 'minibatch' not in args.comments:
        # Full-batch mode: class-balanced subsets processed as a single batch.
        n_train_samples = 2000
        n_test_samples = 2000

        train_idx = _first_indices_per_class(trainset.targets, n_classes, n_train_samples // n_classes)
        trainset = torch.utils.data.Subset(trainset, indices=train_idx)

        # Per-class count for the test subset is derived from n_test_samples
        # (it used to be computed from n_train_samples).
        test_idx = _first_indices_per_class(testset.targets, n_classes, n_test_samples // n_classes)
        testset = torch.utils.data.Subset(testset, indices=test_idx)

        batch_size = n_train_samples
        test_batch_size = n_test_samples
        shuffle = False
    else:
        batch_size = 2048
        test_batch_size = batch_size
        shuffle = True

    train_dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=shuffle)
    test_dataloader = DataLoader(testset, batch_size=test_batch_size, shuffle=shuffle)

    # NTK probe: the first training sample of every class, in class order.
    probe_idx = _first_indices_per_class(trainset_fnet.targets, n_classes, 1)
    subset_train_fnet = torch.utils.data.Subset(trainset_fnet, indices=probe_idx)
    fnet_dataloader = DataLoader(subset_train_fnet, batch_size=10, shuffle=False)

    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    loss_f = torch.nn.CrossEntropyLoss()

    for epoch in range(epochs):
        net.epoch = epoch
        print("epoch", epoch)

        if 'NTK' in args.comments:
            for batch, label in fnet_dataloader:
                # The probe batch must live on the same device as the network.
                batch = batch.reshape([batch.shape[0], -1]).to(device)

                kernel = abs(ntk(net, batch, batch, compute='trace').detach().cpu().numpy())
                NTK_diag = kernel.trace()
                NTK_off_diag = kernel.sum() - NTK_diag
                net.NTK.append([NTK_diag, NTK_off_diag])
        else:
            net.NTK.append([])

        net.train()
        train(train_dataloader, net, loss_f, optimizer, device, args)
        test(test_dataloader, net, loss_f, device, args)

    # NOTE(review): presumably the project's model overrides state_dict to save
    # into this directory; the stock nn.Module.state_dict would not accept a
    # path here. Verify against model_train.
    net.state_dict(experiment_dir)
    time_elapsed = (time.perf_counter() - time_start)
    print(f'All done, in {time_elapsed}s')

    results = {}
    results.update(vars(args))
    results.update(time_elapsed=time_elapsed)
    results.update(hostname=socket.gethostname())
    string_result = json.dumps(results, indent=4)
    path = os.path.join(experiment_dir, 'results.txt')
    with open(path, "w") as f:
        f.write(string_result)

    shutil.make_archive(experiment_dir, 'zip', experiment_dir)


# Entry point: run the experiment only when executed as a script.
if __name__ == '__main__':
    main()
