import pandas as pd
from torch import nn
from torch.utils.data import Subset
import matplotlib.pyplot as plt
from lstnn.dataset import get_dataset, PuzzleDataset
from lstnn.model import FFN, LSTM_combined, LSTM
import lstnn.transformer_main as transformer_main
from lstnn.seed import set_global_seed
import lstnn
import numpy as np
import torch
import time
import argparse
import os
import warnings
import lstnn.fmri.fmri_dataset as fmri_dataset
import lstnn.fmri.fmri_transformer as fmri_transformer
from torchmetrics import MeanSquaredError
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torch")



# Cap the number of CPU threads torch uses for intra-op parallelism.
torch.set_num_threads(8)

# Command-line interface. Help strings derive their defaults via %(default)s so
# they cannot drift out of sync with the actual default values.
parser = argparse.ArgumentParser(
    './main.py',
    description='Masked-token pretraining of a transformer on fMRI data',
)
parser.add_argument('--model_label', type=str, default='fMRI-Transformer',
                    help='model label used in the results directory name (str) | default: %(default)s')
parser.add_argument('--seeds', nargs='+', type=int, required=True,
                    help='<Required> list of seeds (int); one model is trained per seed')
parser.add_argument('--attnheads', type=int, default=4,
                    help='attention heads per layer (int) | default: %(default)s')
parser.add_argument('--layers', type=int, default=1,
                    help='number of transformer layers/blocks (int) | default: %(default)s')
parser.add_argument('--pe', type=str, default='1dpe',
                    help='positional encoding key, e.g. 1dpe, 2dpe, rope, learn (str) | default: %(default)s')
parser.add_argument('--pe_init', type=float, default=1.0,
                    help='initialization SD for learnable positional encodings; '
                         'only valid with --pe learn (float) | default: %(default)s')
parser.add_argument('--embedding_dim', type=int, default=64,
                    help='embedding dimensionality (int) | default: %(default)s')
parser.add_argument('--mask', type=float, default=0.15,
                    help='fraction of tokens to mask (float) | default: %(default)s')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='learning rate (float) | default: %(default)s')
parser.add_argument('--device', type=str, default='cuda',
                    help='torch device, e.g. cuda, mps, cpu (str) | default: %(default)s')
parser.add_argument('--n_iterations', type=int, default=10000,
                    help='final training iteration per seed (int) | default: %(default)s')
parser.add_argument('--dropout', type=float, default=0.0,
                    help='dropout probability (float) | default: %(default)s')
parser.add_argument('--wdecay', type=float, default=0.0,
                    help='weight decay, used by AdamW only (float) | default: %(default)s')
parser.add_argument('--optimizer', type=str, default='AdamW', choices=['AdamW', 'sgd'],
                    help='optimizer (str) | default: %(default)s')


def _checkpoint_path(resultdir, seed, iteration):
    """Return the model checkpoint file path for ``seed`` at ``iteration``.

    ``resultdir`` is expected to end with a path separator.
    """
    return f"{resultdir}s-{seed}_i-{iteration}.pt"


def _find_resume_iteration(resultdir, seed, checkpoint_freq):
    """Return the newest checkpoint iteration in the unbroken chain 0, f, 2f, ...

    Checks ``_checkpoint_path(resultdir, seed, k * checkpoint_freq)`` for
    k = 0, 1, 2, ... until a file is missing. Returns None when not even the
    iteration-0 checkpoint exists.
    """
    newest = None
    candidate = 0
    while os.path.isfile(_checkpoint_path(resultdir, seed, candidate)):
        newest = candidate
        candidate += checkpoint_freq
    return newest


def _next_batch(dataloader, data_iter):
    """Fetch the next batch, restarting from ``dataloader`` once ``data_iter`` is exhausted.

    Returns ``(batch, data_iter)`` because the iterator may have been replaced.
    Only ``StopIteration`` triggers a restart; any other error propagates.
    """
    try:
        return next(data_iter), data_iter
    except StopIteration:
        data_iter = iter(dataloader)
        return next(data_iter), data_iter


def _mask_batch(batch, mask, device):
    """Zero out a random subset of entries of ``batch`` for masked pretraining.

    An entry is masked when its uniform draw is <= ``mask``, i.e. with
    probability ~``mask``. Returns ``(masked_batch, masked_positions)``; both
    live on ``device`` and ``masked_positions`` is True where the input was zeroed.
    """
    # Draw on the default (CPU) generator, as the original code did, then move to device.
    keep = (torch.rand(batch.shape) > mask).to(device)
    return batch * keep, ~keep


def run(args):
    """Run masked-token pretraining of the fMRI transformer for every seed in ``args.seeds``.

    For each seed a freshly initialized model is trained for iterations
    0..``args.n_iterations`` (inclusive). Train and validation losses (MSE on
    the masked tokens) are written every 100 iterations to ``df_<seed>.csv``
    in the results directory, and model weights are saved every 5000 iterations
    as ``s-<seed>_i-<iteration>.pt``. If a chain of checkpoints for a seed
    already exists, training resumes right after the newest one.

    Args:
        args: namespace produced by the module-level ``parser``.

    Raises:
        ValueError: if ``--pe_init`` is changed for an encoding other than
            ``learn``, or ``--pe`` / ``--optimizer`` is not a known option.
    """
    datadir = os.path.expanduser('~/projects/hcp_preproc/')
    attnheads = args.attnheads
    nblocks = args.layers
    hidden_size = args.embedding_dim
    learning_rate = args.lr
    mask = args.mask
    seeds = args.seeds
    device = torch.device(args.device)
    pe = args.pe
    pe_init = args.pe_init
    n_iterations = args.n_iterations
    dropout = args.dropout
    wdecay = args.wdecay
    model_label = args.model_label
    optim = args.optimizer

    # CLI key -> positional_encoding name understood by fmri_transformer.Transformer
    pe_dict = {'1dpe': 'absolute',
               '2dpe': 'absolute2d',
               'rope': 'rope',
               'rope2': 'rope2',
               'shaw': 'relative',
               'nope': 'nope',
               'cnope': 'cnope',
               'scnope': 'scnope',
               'rndpe': 'rndpe',
               'rnd2': 'rnd2',
               'learn0': 'learn0',
               'learn': 'learn',
               'clearn': 'clearn',
               'learnu': 'learnu',
               'brain': 'brain'
               }

    # Validate options up front, before any data is loaded or directories are created.
    if pe_init != 1.0 and pe != 'learn':
        # altering pe_init only makes sense with pe=='learn'
        raise ValueError(f"--pe_init={pe_init} is only meaningful with --pe learn (got --pe {pe})")
    if pe not in pe_dict:
        raise ValueError(f"Unknown positional encoding {pe!r}; expected one of {sorted(pe_dict)}")
    if optim not in ('sgd', 'AdamW'):
        raise ValueError(f"Unknown optimizer {optim!r}; expected 'AdamW' or 'sgd'")

    pestr = f"learn-{pe_init}" if pe == 'learn' else pe
    optim_str = 'sgd_' if optim == 'sgd' else ''

    # fixed parameters
    checkpoint_freq = 5000  # iterations between model checkpoints
    log_freq = 100  # iterations between train/validation loss logs
    max_tokens = 360  # number of brain regions in Glasser et al. 2016 atlas
    input_dim = 1  # features per token; also the reconstruction output size
    train_batch_size = 128
    valid_batch_size = 128
    loss_columns = ['Iteration', 'Loss', 'Condition', 'Seed']

    # create results directory (the trailing '/' is relied on by _checkpoint_path)
    resultdir = f"/dccstor/synagi_compgen/LSTNN/results/fmri/" \
                f"model-{model_label}_{optim_str}" \
                f"mask-{mask}_" \
                f"pe-{pestr}_" \
                f"nl-{nblocks}_" \
                f"do-{dropout}_" \
                f"wd-{wdecay}_" \
                f"at-{attnheads}_" \
                f"hs-{hidden_size}_" \
                f"lr-{learning_rate}/"
    os.makedirs(resultdir, exist_ok=True)

    #### Load in training data
    print('Loading in training data')
    train_dataset = fmri_dataset.fMRIDataset(
        datadir=datadir,
        subjectset=1  # training data cohort
        )
    train_datadict = train_dataset.data2dict(zscore=True)
    #### Load in validation data
    print('Loading in validation data')
    valid_dataset = fmri_dataset.fMRIDataset(
        datadir=datadir,
        subjectset=2  # validation data cohort
        )
    valid_datadict = valid_dataset.data2dict(zscore=True)

    for seed in seeds:
        set_global_seed(seed)

        model = fmri_transformer.Transformer(input_dim=input_dim,
                    output_dim=input_dim,  # masked pretraining so outputdim is inputdim
                    max_tokens=max_tokens,
                    nhead=attnheads,
                    nblocks=nblocks,
                    embedding_dim=hidden_size,
                    dropout=dropout,
                    positional_encoding=pe_dict[pe],
                    pe_init=pe_init)
        model.to(device)

        # define the loss and optimiser ('sgd' or 'AdamW', validated above)
        if optim == 'sgd':
            optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
        else:
            optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=wdecay)
        loss_fn = nn.MSELoss()

        # create the data loaders (same construction order as before, so the
        # seeded RNG stream is consumed identically)
        train_datasampler = fmri_dataset.DatasetSampler(train_datadict, train_dataset.subjects)
        train_dataloader = torch.utils.data.DataLoader(train_datasampler, batch_size=train_batch_size, shuffle=True)
        train_iter = iter(train_dataloader)

        valid_datasampler = fmri_dataset.DatasetSampler(valid_datadict, valid_dataset.subjects)
        valid_dataloader = torch.utils.data.DataLoader(valid_datasampler, batch_size=valid_batch_size, shuffle=True)
        valid_iter = iter(valid_dataloader)

        records_csv = os.path.join(resultdir, f"df_{seed}.csv")
        loss_records = []
        iteration = 0

        # Resume from the newest existing checkpoint, if any. NOTE: only model
        # weights are checkpointed; optimizer state restarts from scratch.
        resume_from = _find_resume_iteration(resultdir, seed, checkpoint_freq)
        if resume_from is not None:
            if resume_from >= n_iterations:
                print(_checkpoint_path(resultdir, seed, resume_from),
                      'checkpoint exists and n_iterations is reached, skipping seed', seed)
                continue
            previous = pd.read_csv(records_csv)
            # Only keep logged rows up to the checkpoint being resumed from.
            loss_records = previous.loc[previous.Iteration <= resume_from].to_dict('records')
            model.load_state_dict(torch.load(_checkpoint_path(resultdir, seed, resume_from), map_location=device))
            iteration = resume_from + 1
            print(_checkpoint_path(resultdir, seed, resume_from),
                  'checkpoint exists, skipping to iteration', iteration)

        # train until n_iterations (inclusive) is reached
        start = time.time()
        while iteration <= n_iterations:
            batch, train_iter = _next_batch(train_dataloader, train_iter)
            batch = batch.to(device)

            # masked training: only the masked tokens contribute to the loss
            masked_batch, masked_positions = _mask_batch(batch, mask, device)
            out = model(masked_batch)
            loss = loss_fn(out[masked_positions], batch[masked_positions])

            # Backpropagation
            optimizer.zero_grad()  # clear previous gradients
            loss.backward()        # compute gradients
            optimizer.step()       # update weights

            if iteration % log_freq == 0:
                # evaluate on one masked validation batch
                valid_batch, valid_iter = _next_batch(valid_dataloader, valid_iter)
                valid_batch = valid_batch.to(device)
                masked_valid_batch, valid_masked_positions = _mask_batch(valid_batch, mask, device)
                model.eval()
                with torch.no_grad():
                    valid_out = model(masked_valid_batch)
                    # Score the same quantity as the training loss (masked tokens
                    # only) so Train and Validation rows are comparable.
                    valid_loss = loss_fn(valid_out[valid_masked_positions],
                                         valid_batch[valid_masked_positions])
                model.train()

                train_loss_value = loss.detach().item()
                valid_loss_value = valid_loss.item()
                end = time.time()
                print('Iteration ', iteration,
                      ': Train loss = ', np.round(train_loss_value, 7),
                      ', Val. loss = ', np.round(valid_loss_value, 7),
                      ', time = ', np.round(end - start, 3)
                      )
                loss_records.append({'Iteration': iteration, 'Loss': train_loss_value,
                                     'Condition': 'Train', 'Seed': seed})
                loss_records.append({'Iteration': iteration, 'Loss': valid_loss_value,
                                     'Condition': 'Validation', 'Seed': seed})
                pd.DataFrame(loss_records, columns=loss_columns).to_csv(records_csv, index=False)

            if iteration % checkpoint_freq == 0:
                # save model weights
                checkpoint_path = _checkpoint_path(resultdir, seed, iteration)
                print(checkpoint_path)
                torch.save(model.state_dict(), checkpoint_path)

            iteration += 1

if __name__ == '__main__':
    # Script entry point: parse the command-line flags and train one model per requested seed.
    run(parser.parse_args())

