import pandas as pd
from torch import nn
from torch.utils.data import Subset
import matplotlib.pyplot as plt
from lstnn.dataset import get_dataset, PuzzleDataset
from lstnn.model import FFN, LSTM_combined, LSTM
from lstnn.seed import set_global_seed
from curricula import get_curriculum
import numpy as np
import torch
import math
import argparse

# Create parser for options
parser = argparse.ArgumentParser(
    description='Train a puzzle-solving network and save its weights and '
                'per-epoch metrics.')

# These options have no defaults: any option that is not supplied arrives
# as None and is forwarded to run_model unchanged.
parser.add_argument('--model_label',
                    type=str,
                    help="network type: 'FFN', 'LSTM' or 'LSTMcombined'")

parser.add_argument('--num_layers',
                    type=int,
                    help='number of recurrent layers (LSTM variants)')

# dropout is a probability, so it must be parsed as a float
# (type=int rejected values such as 0.2)
parser.add_argument('--dropout',
                    type=float,
                    help='dropout probability passed to the LSTM model')

parser.add_argument('--hidden_size',
                    type=int,
                    help='hidden size of the network')

parser.add_argument('--curriculum',
                    type=str,
                    help='curriculum name passed to get_curriculum')

parser.add_argument('--seed',
                    type=int,
                    default=None,
                    help='global random seed; omit for an unseeded run')

parser.add_argument('--device',
                    type=str,
                    default='cpu',
                    help="torch device string, e.g. 'cpu'")

parser.add_argument('--learning_rate',
                    type=float,
                    help='AdamW learning rate')

parser.add_argument('--training_acc_cutoff',
                    type=float,
                    help='training accuracy that must be exceeded')

parser.add_argument('--cutoff_length',
                    type=float,
                    help='consecutive epochs above the cutoff needed to stop')

parser.add_argument('--weight_init',
                    type=float,
                    help='scale of the normal weight re-initialisation; '
                         'omit to keep the model default initialisation')

# ---- global parameters ------------------------------------------------------

# Cap the size of torch's intra-op CPU thread pool.
torch.set_num_threads(8)

# Generated puzzle sets (binary, ternary and quaternary variants) used for
# training and for the held-out test split.
training_files = [
    '../data/nn/generated_puzzle_data_binary_dist80.csv',
    '../data/nn/generated_puzzle_data_ternary_dist80.csv',
    '../data/nn/generated_puzzle_data_quaternary_dist80.csv',
]

# Original puzzle set, used for validation.
validation_file = '../data/nn/puzzle_data_original.csv'

# Minibatch sizes; the evaluation set holds 108 puzzles in total, so the
# validation batch size covers it in a single batch.
train_batch_size, valid_batch_size = 36, 108

# Network input / output sizes.
input_size, output_size = 4 * 4 * 5, 4  # latin-square dims; 4 motor responses


def results_to_df(df, results, epoch, block):
    """Append one epoch's train/test/validation metrics to ``df``.

    For each phase in ``('train', 'test', 'validation')`` one row with
    condition ``'average'`` is added, holding the phase loss and its
    ``'total'`` accuracy. It is followed by one row for every other key of
    that phase's accuracy dict; those per-condition rows have a NaN loss.

    Args:
        df: DataFrame to extend (columns epoch, block, loss, accuracy,
            condition, phase).
        results: dict with, for each phase, ``'<phase>_loss'`` (a scalar) and
            ``'<phase>_acc'`` (a dict of accuracies that includes ``'total'``).
        epoch: epoch number written to every new row.
        block: block number written to every new row.

    Returns:
        A new DataFrame with ``df``'s rows followed by the new rows (the index
        is reset).
    """
    rows = []
    for phase in ['train', 'test', 'validation']:
        accuracies = results[phase + '_acc']

        # overall scores for the phase
        rows.append({'epoch': epoch,
                     'block': block,
                     'loss': results[phase + '_loss'],
                     'accuracy': accuracies['total'],
                     'condition': 'average',
                     'phase': phase})

        # conditions are read from the dict itself because the training
        # phase may only contain a subset of them
        for condition, accuracy in accuracies.items():
            if condition == 'total':
                continue
            rows.append({'epoch': epoch,
                         'block': block,
                         'loss': np.nan,
                         'accuracy': accuracy,
                         'condition': condition,
                         'phase': phase})

    # Collect all rows first and concatenate once. Growing the frame one row
    # at a time copies it on every append.
    new_rows = pd.DataFrame(rows, columns=['epoch', 'block', 'loss',
                                           'accuracy', 'condition', 'phase'])
    return pd.concat([df, new_rows], ignore_index=True)


def evaluate_model(dataloader, model, loss_fn, device='cpu'):
    """Evaluate ``model`` on every batch of ``dataloader`` without training it.

    The model is switched to eval mode (so layers such as dropout are
    disabled) for the duration of the call; its previous train/eval mode is
    restored afterwards. Gradients are not tracked. The model is not moved
    here, so it must already live on ``device``.

    Each batch is expected to be ``(features, one_hot_labels, index,
    conditions)``. A prediction counts as correct when the argmax of the
    model output matches the argmax of the label.

    Args:
        dataloader: DataLoader over a dataset exposing ``.conditions``, or a
            ``Subset`` of such a dataset.
        model: torch module mapping features to per-class scores.
        loss_fn: loss callable ``loss_fn(output, labels)``. The returned loss
            assumes it reduces by mean over the batch (as
            ``nn.CrossEntropyLoss`` does by default).
        device: device the features and labels are moved to.

    Returns:
        ``(avg_loss, avg_acc)``. ``avg_loss`` is the sample-weighted mean batch
        loss. ``avg_acc`` maps ``'total'`` and each lower-cased condition to the
        fraction of correctly answered samples. Values are numpy floats, and
        NaN where there were no samples.
    """
    # Conditions of all samples; this fixes which keys the accuracy dict
    # has, and their (sorted) order.
    dataset = dataloader.dataset
    if isinstance(dataset, torch.utils.data.Subset):
        conditions = np.asarray(dataset.dataset.conditions)[dataset.indices]
    else:
        conditions = dataset.conditions

    keys = ['total'] + [condition.lower() for condition in np.unique(conditions)]
    n_correct = dict.fromkeys(keys, 0)
    n_seen = dict.fromkeys(keys, 0)
    loss_sum = 0.0
    n_samples = 0

    was_training = model.training
    model.eval()  # disable dropout etc. while evaluating
    try:
        with torch.no_grad():
            for batch in dataloader:
                features = batch[0].to(device)
                labels = batch[1].to(device)

                out = model(features)
                batch_size = labels.shape[0]

                # Weight each batch's mean loss by its size, so a short last
                # batch does not skew the average.
                loss_sum += loss_fn(out, labels).item() * batch_size
                n_samples += batch_size

                correct = (torch.argmax(out, dim=1)
                           == torch.argmax(labels, dim=1)).cpu().numpy()
                n_correct['total'] += int(correct.sum())
                n_seen['total'] += batch_size

                # Accumulate per-condition counts rather than per-batch means.
                # A batch with one sample of a condition must not weigh as much
                # as a batch with many.
                batch_conditions = np.asarray(batch[3])
                for condition in np.unique(batch_conditions):
                    mask = batch_conditions == condition
                    key = condition.lower()
                    n_correct[key] += int(correct[mask].sum())
                    n_seen[key] += int(mask.sum())
    finally:
        model.train(was_training)

    # numpy floats are returned, which callers can .copy()
    avg_test_acc = {
        key: np.float64(n_correct[key] / n_seen[key]) if n_seen[key]
        else np.float64(np.nan)
        for key in keys
    }
    avg_test_loss = (np.float64(loss_sum / n_samples) if n_samples
                     else np.float64(np.nan))
    return avg_test_loss, avg_test_acc


def train(dataloader, model, loss_fn, optimizer, device='cpu'):
    """Run one training epoch over ``dataloader`` and report its metrics.

    The model is moved to ``device`` and put in training mode. For every
    batch ``(features, one_hot_labels, index, conditions)``, one optimizer
    step is taken. Accuracy compares the argmax of the (pre-update) output
    with the argmax of the label.

    Args:
        dataloader: DataLoader over a dataset exposing ``.conditions``, or a
            ``Subset`` of such a dataset.
        model: torch module to train (updated in place).
        loss_fn: loss callable ``loss_fn(output, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device to train on.

    Returns:
        ``({'train_loss': ..., 'train_acc': {...}}, model)``. Loss and
        accuracies are means of the per-batch values. ``'train_acc'`` has a
        ``'total'`` entry plus one per lower-cased condition.
    """
    device = torch.device(device)
    model.to(device)
    # make sure dropout etc. are active even if a caller left eval mode on
    model.train()

    # conditions of all samples, used to set up the per-condition keys
    dataset = dataloader.dataset
    if isinstance(dataset, torch.utils.data.Subset):
        conditions = np.asarray(dataset.dataset.conditions)[dataset.indices]
    else:
        # previously checked against an undefined ``src`` module (NameError)
        conditions = dataset.conditions

    batch_losses = []
    batch_acc = {'total': []}
    for condition in np.unique(conditions):
        batch_acc[condition.lower()] = []

    for batch in dataloader:
        features = batch[0].to(device)
        labels = batch[1].to(device)

        # forward pass and loss
        out = model(features)
        loss = loss_fn(out, labels)

        # backpropagation
        optimizer.zero_grad()  # clear previous gradients
        loss.backward()        # compute gradients
        optimizer.step()       # update weights

        correct = torch.argmax(out, dim=1) == torch.argmax(labels, dim=1)

        batch_losses.append(loss.item())
        batch_acc['total'].append(correct.sum().item() / labels.shape[0])

        # per-condition accuracy within this batch
        correct_np = correct.cpu().numpy()
        batch_conditions = np.asarray(batch[3])
        for condition in np.unique(batch_conditions):
            batch_acc[condition.lower()].append(
                np.mean(correct_np[batch_conditions == condition]))

    avg_train_acc = {key: np.mean(values) for key, values in batch_acc.items()}
    avg_train_loss = np.mean(batch_losses)

    return {'train_loss': avg_train_loss, 'train_acc': avg_train_acc}, model


def _build_model(model_label, num_layers, dropout, hidden_size, device):
    """Instantiate the network named by ``model_label``.

    Raises:
        ValueError: if ``model_label`` is not 'FFN', 'LSTM' or 'LSTMcombined'.
    """
    if model_label == 'FFN':
        return FFN(hidden_size=hidden_size,
                   input_dim=input_size,
                   output_dim=output_size)
    if model_label == 'LSTM':
        return LSTM(hidden_size=hidden_size,
                    num_layers=num_layers,
                    dropout=dropout,
                    bidirectional=True,
                    device=device)
    if model_label == 'LSTMcombined':
        # NOTE(review): dropout is not forwarded to LSTM_combined (unchanged
        # from the original code); confirm whether it should be.
        return LSTM_combined(hidden_size=hidden_size,
                             num_layers=num_layers,
                             bidirectional=True,
                             device=device)
    raise ValueError(f"Unknown model_label {model_label!r}; expected "
                     f"'FFN', 'LSTM' or 'LSTMcombined'")


def _init_weights(model, weight_init):
    """Re-draw every parameter whose name contains 'weight' from N(0, stdv).

    ``stdv = weight_init / sqrt(model.hidden_size)``. For parameters whose name
    contains 'lstm', when ``model.bidirectional`` is True, ``hidden_size / 2``
    is used instead. Parameters without 'weight' in the name (e.g. biases)
    are left as constructed.
    """
    for name, param in model.named_parameters():
        if 'weight' not in name:
            continue
        if 'lstm' in name and model.bidirectional is True:
            stdv = weight_init / math.sqrt(model.hidden_size / 2)
        else:
            stdv = weight_init / math.sqrt(model.hidden_size)
        nn.init.normal_(param, 0, stdv)


def run_model(model_label, num_layers=1, dropout=0.0,
              hidden_size=160, curriculum='All', learning_rate=0.01,
              training_acc_cutoff=0.99, cutoff_length=10, weight_init=1.0,
              seed=None, device='cpu', max_epochs=None):
    """Train a model until the training-accuracy criterion is met, then save it.

    Training stops once the mean training accuracy exceeds
    ``training_acc_cutoff`` for ``cutoff_length`` consecutive epochs, or
    after ``max_epochs`` epochs if that is given. After every epoch the
    model is evaluated on the curriculum's test split and on the validation
    file, and all metrics are appended to a DataFrame.

    Args:
        model_label: 'FFN', 'LSTM' or 'LSTMcombined'.
        num_layers: number of recurrent layers (LSTM variants only).
        dropout: dropout passed to the 'LSTM' model.
        hidden_size: hidden size of the network.
        curriculum: curriculum name passed to ``get_curriculum``.
        learning_rate: AdamW learning rate.
        training_acc_cutoff: training accuracy that must be exceeded.
        cutoff_length: consecutive epochs above the cutoff needed to stop.
        weight_init: scale of the normal re-initialisation of weight tensors;
            ``None`` keeps the constructor's initialisation.
        seed: global seed, set before the model is built; ``None`` means
            unseeded.
        device: torch device string.
        max_epochs: optional hard cap on the number of epochs. ``None`` (the
            default) trains until the criterion is met.

    Returns:
        ``(model, df)``: the trained model and the per-epoch metrics. Both are
        also written to ``../results_new/<run name>.pt`` and ``.csv``.

    Raises:
        ValueError: if ``model_label`` is not recognised.
    """
    import os

    # Seed BEFORE building the model, so that constructor-time initialisation
    # is reproducible. This includes the biases, which _init_weights does not
    # touch.
    if seed is not None:
        set_global_seed(seed)

    model = _build_model(model_label, num_layers, dropout, hidden_size, device)
    if weight_init is not None:
        _init_weights(model, weight_init)

    # define the loss and optimiser
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    loss_fn = nn.CrossEntropyLoss()

    df = pd.DataFrame(columns=['epoch', 'block', 'loss',
                               'accuracy', 'condition', 'phase'])

    # create the datasets and data loaders; the training files are loaded
    # once and shared between the train and test subsets
    train_index, test_index, _train_df = get_curriculum(
        curriculum, training_files)
    training_data = get_dataset(training_files)

    train_dataloader = torch.utils.data.DataLoader(
        Subset(training_data, train_index),
        batch_size=train_batch_size, shuffle=True)

    # Evaluation does not need shuffling, and an unshuffled loader does not
    # draw from the global RNG that training shuffles rely on.
    test_dataloader = torch.utils.data.DataLoader(
        Subset(training_data, test_index),
        batch_size=train_batch_size, shuffle=False)

    valid_dataloader = torch.utils.data.DataLoader(
        get_dataset(validation_file),
        batch_size=valid_batch_size, shuffle=False)

    # train until the criterion is reached
    epoch = 0
    cutoff_satisfied = 0
    while cutoff_satisfied < cutoff_length:
        if max_epochs is not None and epoch >= max_epochs:
            print(f'Stopping after {epoch} epochs without meeting the '
                  f'training accuracy criterion')
            break

        # training
        results, model = train(train_dataloader,
                               model,
                               loss_fn,
                               optimizer,
                               device=device)

        # count consecutive epochs above the cutoff; any miss resets the count
        mean_train_correct = results['train_acc']['total']
        if mean_train_correct > training_acc_cutoff:
            cutoff_satisfied += 1
        else:
            cutoff_satisfied = 0

        # testing
        results['test_loss'], results['test_acc'] = evaluate_model(
            test_dataloader, model, loss_fn, device)
        mean_test_correct = results['test_acc']['total']

        # validating
        results['validation_loss'], results['validation_acc'] = evaluate_model(
            valid_dataloader, model, loss_fn, device)
        mean_valid_correct = results['validation_acc']['total']

        print('Epoch ', epoch,
              ': Train acc = ', np.round(mean_train_correct, 3),
              ', Test acc = ', np.round(mean_test_correct, 3),
              ', Val. acc = ', np.round(mean_valid_correct, 3)
              )

        # update results df
        df = results_to_df(df, results, epoch, block=0)

        epoch += 1

    # save outputs
    out = f"../results_new/model-{model_label}_" \
          f"nl-{num_layers}_" \
          f"do-{dropout}_" \
          f"hs-{hidden_size}_" \
          f"curr-{curriculum}_" \
          f"lr-{learning_rate}_" \
          f"co-{training_acc_cutoff}_" \
          f"col-{cutoff_length}_" \
          f"wi-{weight_init}_" \
          f"s-{seed}"
    # create the results directory so saving cannot fail after a long run
    os.makedirs(os.path.dirname(out), exist_ok=True)
    torch.save(model.state_dict(), out + '.pt')
    df.to_csv(out + '.csv', index=False)
    return model, df


if __name__ == '__main__':
    # Parse the command-line options declared on ``parser``.
    args = parser.parse_args()

    # Train, evaluate and save the configured model.
    run_model(args.model_label,
              num_layers=args.num_layers,
              dropout=args.dropout,
              hidden_size=args.hidden_size,
              curriculum=args.curriculum,
              learning_rate=args.learning_rate,
              training_acc_cutoff=args.training_acc_cutoff,
              cutoff_length=args.cutoff_length,
              weight_init=args.weight_init,
              seed=args.seed,
              device=args.device)
