import torch
from torch.nn.modules import padding
from tqdm import tqdm
import numpy as np
import os
import argparse
from datetime import datetime
import sys
import time
import logging
import copy

sys.path.append('./')
sys.path.append('../')
import logging
from src.tlp_rnn_fusion.rnn_dataset_torch import HomogeneousMNIST, SSTDatasetPT, ClassificationDataset
from src.tlp_rnn_fusion.rnn_models import RNNWithDecoder, RNNWithEncoderDecoder, LSTMWithDecoder, LSTMWithEncoderDecoder
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms
from src.tlp_rnn_fusion import embedding
from src.tlp_rnn_fusion import split_mnist


# Seed for the NumPy shuffle that fixes the train/validation(/test) index split
# in get_dataloaders; it is independent of the --seed CLI option (torch seed).
RANDOM_SEED = 543

# GloVe embedding built by get_dataloaders for the text datasets (SSTPT,
# AG_NEWS, DBpedia) and read back by Trainer; stays None for image datasets.
GLOVE_EMBEDDING = None


def _load_glove(exp_args):
    """Build the GloVe embedding object for the text datasets.

    Returns a ``CompactGloveEmbedding`` (built with the dataset name and the
    training data path) when ``--use_compact_embedding`` is set, otherwise a
    full ``GloveEmbedding`` loaded from ``exp_args.glove_path``.
    """
    if not exp_args.use_compact_embedding:
        return embedding.GloveEmbedding(exp_args.glove_path)
    return embedding.CompactGloveEmbedding(exp_args.glove_path, dataset_name=exp_args.dataset_name,
                                           data_path=exp_args.train_data_path)


def _build_datasets(exp_args):
    """Create the ``(train_dataset, test_dataset)`` pair named by ``exp_args.dataset_name``.

    For the text datasets the GloVe embedding is also published in the
    module-level ``GLOVE_EMBEDDING`` so that ``Trainer`` can reuse it.

    Raises:
        NotImplementedError: for an unrecognised dataset name.
    """
    global GLOVE_EMBEDDING
    name = exp_args.dataset_name

    if name == "MyMNIST":
        transform = None
        if exp_args.normalize:
            transform = transforms.Compose([transforms.Normalize((0.1307,), (0.3081,))])
        return (HomogeneousMNIST(exp_args.train_data_path, transform=transform),
                HomogeneousMNIST(exp_args.test_data_path, transform=transform))

    if name == "MNISTNorm":
        transform = transforms.Compose([transforms.ToTensor(),
                                        transforms.Normalize((0.1307,), (0.3081,))])
        train_set = datasets.MNIST(root=exp_args.train_data_path, download=True, train=True, transform=transform)
        test_set = datasets.MNIST(root=exp_args.test_data_path, train=False, download=True, transform=transform)
        return train_set, test_set

    if name == "SplitMNIST":
        steps = [transforms.ToTensor()]
        if exp_args.normalize:
            steps.append(transforms.Normalize((0.1307,), (0.3081,)))
        transform = transforms.Compose(steps)
        # NOTE(review): root is hard-coded to './data' (train/test_data_path are ignored here).
        split_kwargs = dict(root='./data', nsplits=exp_args.nsplits, split_index=exp_args.split_index,
                            download=True, transform=transform, scale_factor=exp_args.ds_scale_factor)
        return (split_mnist.SplitMNIST(train=True, **split_kwargs),
                split_mnist.SplitMNIST(train=False, **split_kwargs))

    if name == "SSTPT":
        GLOVE_EMBEDDING = _load_glove(exp_args)
        return (SSTDatasetPT(exp_args.train_data_path, glove_embedding=GLOVE_EMBEDDING, train=True),
                SSTDatasetPT(exp_args.test_data_path, glove_embedding=GLOVE_EMBEDDING, train=False))

    if name in ['AG_NEWS', 'DBpedia']:
        GLOVE_EMBEDDING = _load_glove(exp_args)
        is_lstm = exp_args.model_name == 'lstm'
        max_len = {'AG_NEWS': 160 if is_lstm else 60,
                   'DBpedia': 100 if is_lstm else 60}[name]
        # NOTE(review): the test split is also read from train_data_path and
        # selected via tag='test'; confirm this is intended.
        return (ClassificationDataset(exp_args.train_data_path, glove_embedding=GLOVE_EMBEDDING,
                                      tag='train', max_seq_len=max_len),
                ClassificationDataset(exp_args.train_data_path, glove_embedding=GLOVE_EMBEDDING,
                                      tag='test', max_seq_len=max_len))

    raise NotImplementedError


def get_dataloaders(exp_args):
    """Build the ``(trainloader, valloader, testloader)`` triple for the configured dataset.

    A fixed-seed (``RANDOM_SEED``) shuffle of the training set holds out 10%
    for validation. For ``SSTPT`` a further 10% (of the full training size) is
    carved out of the remainder as the test split; every other dataset uses
    its own test set, iterated in order.

    Raises:
        NotImplementedError: for an unrecognised ``exp_args.dataset_name``.
    """
    train_dataset, test_dataset = _build_datasets(exp_args)
    batch_size = exp_args.batch_size

    np.random.seed(RANDOM_SEED)
    total = len(train_dataset)
    order = list(range(total))
    np.random.shuffle(order)
    n_val = int(np.floor(0.1 * total))
    val_indices, train_indices = order[:n_val], order[n_val:]
    valloader = DataLoader(train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(val_indices))

    if exp_args.dataset_name == "SSTPT":
        n_test = int(np.floor(0.1 * total))
        test_indices, train_indices = train_indices[:n_test], train_indices[n_test:]
        testloader = DataLoader(train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(test_indices))
    else:
        testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    trainloader = DataLoader(train_dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_indices))
    return trainloader, valloader, testloader


def get_optimizer(exp_args, model):
    """Create the optimizer named by ``exp_args.optimizer`` over ``model``'s parameters.

    Both ``'Adam'`` and ``'SGD'`` use ``exp_args.learning_rate`` and
    ``exp_args.weight_decay``; SGD additionally uses ``exp_args.momentum``.

    Raises:
        NotImplementedError: for any other optimizer name.
    """
    choice = exp_args.optimizer
    print('Optimizer is {}'.format(choice))
    if choice not in ('Adam', 'SGD'):
        raise NotImplementedError

    hyper = dict(lr=exp_args.learning_rate, weight_decay=exp_args.weight_decay)
    if choice == 'SGD':
        hyper['momentum'] = exp_args.momentum
        return torch.optim.SGD(model.parameters(), **hyper)
    return torch.optim.Adam(model.parameters(), **hyper)


class Trainer:
    """Train and evaluate a sequence classifier, checkpointing the best model.

    The prediction for each sequence is taken from the model output at the
    last time step (``logits[:, -1, :]``) and scored with log-softmax + NLL.
    """

    def __init__(self, exp_args, model, trainloader, valloader, testloader,
                 initialize_embedding=None):
        """Store loaders/model, build the loss and the optimizer.

        Args:
            exp_args: parsed CLI namespace.
            model: the network to train/evaluate.
            trainloader, valloader, testloader: iterables of ``(x, y)`` batches.
            initialize_embedding: whether the encoder embedding may be
                overwritten with GloVe weights. ``None`` (default) defers
                entirely to ``exp_args.initialize_embedding``; an explicit
                ``False`` (``main()`` passes ``not --evaluate``) prevents
                clobbering weights loaded from a checkpoint.
        """
        self.model = model
        self.trainloader = trainloader
        self.valloader = valloader
        self.testloader = testloader
        self.exp_args = exp_args

        log_softmax = torch.nn.LogSoftmax(dim=1)
        nll = torch.nn.NLLLoss(reduction='mean')

        def loss_func(last_output, y):
            # log-softmax followed by NLL is cross-entropy over dim 1 (classes).
            return nll(log_softmax(last_output), y)

        self.loss_fn = loss_func

        # Built by get_dataloaders for the text datasets; None otherwise.
        self.glove_embedding = GLOVE_EMBEDDING

        allow_init = True if initialize_embedding is None else bool(initialize_embedding)
        if (allow_init and exp_args.initialize_embedding
                and isinstance(self.model, (RNNWithEncoderDecoder, LSTMWithEncoderDecoder))):
            print("Initializing weights using glove embeddings!")
            self.model.encoder.weight.data = copy.deepcopy(self.glove_embedding.embedding.weight.data)

        self.optimizer = get_optimizer(exp_args, self.model)
        # Text datasets whose inputs are token ids needing a GloVe lookup.
        self.ds_names = ["SSTPT", 'AG_NEWS', 'DBpedia']

    def train_epoch(self, epoch, tag='train'):
        """Run one pass over the loader selected by ``tag`` and return accuracy.

        Args:
            epoch: epoch number, used only in the progress print.
            tag: ``'train'`` (weights are updated), ``'val'`` or ``'test'``
                (forward pass under ``torch.no_grad``).

        Returns:
            Accuracy in percent as a 0-dim tensor; a zero tensor when the
            loader yields no samples.

        Raises:
            NotImplementedError: for an unknown ``tag``.
        """
        if tag == 'train':
            dataloader = self.trainloader
            self.model.train()
        elif tag == 'val':
            dataloader = self.valloader
            self.model.eval()
        elif tag == 'test':
            dataloader = self.testloader
            self.model.eval()
        else:
            raise NotImplementedError

        is_train = tag == 'train'
        correct = 0
        num_samples = 0
        for x_batched, y_batched in dataloader:
            if self.exp_args.dataset_name in ["MNISTNorm", "SplitMNIST"]:
                # Drop only the channel dim (assumes images batch as (B, 1, H, W) —
                # TODO confirm for SplitMNIST). A bare squeeze() would also drop
                # the batch dim when the last batch has a single sample.
                x_batched = torch.squeeze(x_batched, 1)
            x_batched = x_batched.to(self.exp_args.device)
            y_batched = y_batched.to(self.exp_args.device)

            if (self.exp_args.dataset_name in self.ds_names
                    and not isinstance(self.model, (LSTMWithEncoderDecoder, RNNWithEncoderDecoder))):
                # Models without an encoder receive pre-embedded GloVe vectors.
                x_batched = self.glove_embedding.get_batch_embedding(x_batched)

            if is_train:
                logits = self.model(x_batched)
            else:
                with torch.no_grad():
                    logits = self.model(x_batched)

            # Classify from the output at the final time step.
            last_logits = logits[:, -1, :]

            num_samples += y_batched.size(0)
            prediction = torch.argmax(last_logits.detach(), dim=1)
            correct += torch.sum(y_batched == prediction)

            if is_train:
                # The loss tensor is not accumulated across batches: doing so
                # kept every batch's autograd graph alive for the whole epoch.
                batch_loss = self.loss_fn(last_logits, y_batched)
                self.optimizer.zero_grad()
                batch_loss.backward()
                self.optimizer.step()

        if num_samples == 0:
            accuracy = torch.tensor(0.0)
        else:
            accuracy = 100.0 * correct / num_samples
        print("Epoch:", epoch, ",", tag, ", accuracy:", accuracy)
        return accuracy

    def _save_checkpoint(self, filename, epoch, val_acc, test_acc):
        """Save weights, model config and accuracies to ``model_save_path/filename``; return the path."""
        save_path = os.path.join(self.exp_args.model_save_path, filename)
        torch.save({'epoch': epoch,
                    'val_acc': val_acc,
                    'test_acc': test_acc,
                    'model_state_dict': self.model.state_dict(),
                    'config': self.model.get_model_config()},
                   save_path)
        return save_path

    def train(self):
        """Train for ``exp_args.num_epochs`` epochs and write checkpoints and history.

        Under ``exp_args.model_save_path`` this writes ``best_val_acc_model.pth``
        whenever validation accuracy improves, ``final_model.pth`` after the
        last epoch, and ``all_epoch_val_test.csv`` with one
        ``epoch,val_acc,test_acc`` row per epoch (best effort).
        """
        print('Starting train.')
        best_val_acc = 0
        best_test_acc = 0
        val_acc = self.train_epoch(0, tag='val')  # baseline before any update
        test_acc = 0

        all_epoch_val_test = []

        for epoch in range(1, self.exp_args.num_epochs + 1):
            print('Epoch {}/{}'.format(epoch, self.exp_args.num_epochs))
            self.train_epoch(epoch, tag='train')
            val_acc = self.train_epoch(epoch, tag='val')
            test_acc = self.train_epoch(epoch, tag='test')

            all_epoch_val_test.append([epoch, val_acc.item(), test_acc.item()])

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_test_acc = test_acc
                self._save_checkpoint('best_val_acc_model.pth', epoch, val_acc, test_acc)

        save_path = self._save_checkpoint('final_model.pth', self.exp_args.num_epochs, val_acc, test_acc)

        print('Training finished.')
        print('Model saved at {}'.format(save_path))
        # Report the best-validation epoch (and its test accuracy), not the last epoch.
        print('Best acc val:{}, test:{}'.format(best_val_acc, best_test_acc))
        csv_path = os.path.join(self.exp_args.model_save_path, "all_epoch_val_test.csv")
        try:
            np.savetxt(csv_path, all_epoch_val_test, delimiter=',')
            print("all_epoch_val_test was saved successfully")
        except Exception as err:  # best effort: training results are already checkpointed
            print("all_epoch_val_test not saved successfully:", repr(err))

    def evaluate(self, save_results=False):
        """Report validation and test accuracy of the current model.

        When ``save_results`` is true, the two accuracies are also logged via
        ``logging.basicConfig`` to ``<exp_args.model_load_path>.log``.
        """
        print('Starting evaluation')
        val_acc = self.train_epoch(epoch=0, tag='val')
        test_acc = self.train_epoch(epoch=0, tag='test')
        print('Validation acc:{}, Test acc:{}'.format(val_acc, test_acc))
        if save_results:
            log_path = self.exp_args.model_load_path + '.log'
            logging.basicConfig(filename=log_path, filemode='w', level=logging.INFO)
            logging.info('Validation acc:{}'.format(val_acc))
            logging.info('Test acc:{}'.format(test_acc))


def _build_arg_parser():
    """Return the argparse parser describing every command-line option of this script."""
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_name', type=str, default='rnn', choices=['rnn', 'lstm'],
                        help='name of models to fuse')
    parser.add_argument('--dataset_name', type=str, default='MNIST',
                        help='name of the dataset to use')
    parser.add_argument('--train_data_path', type=str, default='train.csv',
                        help='dataset path')
    parser.add_argument('--test_data_path', type=str, default='train.csv',
                        help='dataset path')

    parser.add_argument('--train_data_path_x', type=str, default='train.csv',
                        help='dataset path')
    parser.add_argument('--test_data_path_x', type=str, default='train.csv',
                        help='dataset path')
    parser.add_argument('--train_data_path_y', type=str, default='train.csv',
                        help='dataset path')
    parser.add_argument('--test_data_path_y', type=str, default='train.csv',
                        help='dataset path')

    parser.add_argument('--num_epochs', type=int, default=2,
                        help='Number of iterations for trainer')
    parser.add_argument('--learning_rate', type=float, default=1e-3,
                        help='Learning Rate')
    parser.add_argument('--batch_size', type=int, default=10,
                        help='batch size for training')
    parser.add_argument('--device', type=str, default='cuda',
                        help='Which device for Pytorch to use')
    parser.add_argument('--optimizer', type=str, default='SGD', )

    parser.add_argument('--evaluate', default=False, action='store_true')
    parser.add_argument('--pretrained', action='store_true',
                        help='check whether to load a previously trained model or to initialize a new model')
    parser.add_argument('--model_load_path', type=str, default='',
                        help='path to load a model')
    parser.add_argument('--model_save_path', type=str, default='',
                        help='path to save a model')

    # arguments for simple RNN models
    parser.add_argument('--vocab_size', type=int, default=None,
                        help='vocabulary size')
    parser.add_argument('--embed_dim', type=int, default=None,
                        help='embedding dimension of an RNN model')
    parser.add_argument('--hidden_dims', type=str, default=None,
                        help='list of hidden dimensions')
    parser.add_argument('--input_dim', type=int, default=None)

    parser.add_argument('--hidden_activations', type=str, default=None,
                        help='list of hidden activations')
    parser.add_argument('--encoder', default=False, action='store_true')

    parser.add_argument('--bias', action='store_true',
                        help='bias')

    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--momentum', type=float, default=0)

    parser.add_argument('--normalize', default=False, action='store_true')

    parser.add_argument('--nsplits', type=int, default=1,
                        help='Number of splits of the dataset')
    parser.add_argument('--split_index', type=int, default=1,
                        help='The current index of split dataset used!')
    parser.add_argument('--ds_scale_factor', type=float, default=1.0,
                        help='To understand effect of ds scaling')
    parser.add_argument("--seed", default=24601, type=int)
    parser.add_argument('--glove_path', type=str, default='None')

    parser.add_argument('--donot_use_embedding', default=False, action='store_true')
    parser.add_argument('--use_compact_embedding', default=False, action='store_true')
    parser.add_argument('--initialize_embedding', default=False, action='store_true')
    return parser


def _build_model(exp_args, hidden_dims, hidden_activations):
    """Instantiate the network selected by ``--model_name`` and ``--encoder``.

    Raises:
        NotImplementedError: if ``model_name`` is neither ``'rnn'`` nor ``'lstm'``.
    """
    if exp_args.model_name not in ('rnn', 'lstm'):
        raise NotImplementedError(exp_args.model_name)
    is_rnn = exp_args.model_name == 'rnn'

    if not exp_args.encoder:
        decoder_cls = RNNWithDecoder if is_rnn else LSTMWithDecoder
        return decoder_cls(output_dim=exp_args.vocab_size, embed_dim=exp_args.embed_dim, hidden_dims=hidden_dims,
                           hidden_activations=hidden_activations, bias=exp_args.bias)

    # Hard-coded padding index for the compact SST embedding (presumably the
    # pad-token row of that vocabulary — TODO confirm).
    padding_idx = None
    if exp_args.use_compact_embedding and exp_args.dataset_name == 'SSTPT':
        padding_idx = 14309
    common = dict(output_dim=exp_args.vocab_size, input_dim=exp_args.input_dim,
                  embed_dim=exp_args.embed_dim, hidden_dims=hidden_dims,
                  hidden_activations=hidden_activations, bias=exp_args.bias,
                  padding_idx=padding_idx)
    if is_rnn:
        return RNNWithEncoderDecoder(use_embedding=not exp_args.donot_use_embedding, **common)
    # NOTE(review): the LSTM variant is not passed use_embedding, so
    # --donot_use_embedding only affects the RNN model.
    return LSTMWithEncoderDecoder(**common)


def main():
    """Command-line entry point: parse args, build loaders and model, then train or evaluate."""
    parser = _build_arg_parser()
    exp_args = parser.parse_args(sys.argv[1:])
    torch.manual_seed(exp_args.seed)  # Setting the seed for reproducible results
    print("Seed:{}".format(exp_args.seed))
    if exp_args.hidden_dims is None:
        parser.error('--hidden_dims is required, e.g. --hidden_dims [128,128]')
    hidden_dims = [int(i) for i in exp_args.hidden_dims.strip('[]').split(',')]
    hidden_activations = (None if exp_args.hidden_activations is None
                          else exp_args.hidden_activations.strip('[]').split(','))

    # Set up dataset loaders
    trainloader, valloader, testloader = get_dataloaders(exp_args)

    # Set up model
    print("exp_args.encoder", exp_args.encoder)
    model = _build_model(exp_args, hidden_dims, hidden_activations)
    print(model.get_model_config())

    if exp_args.pretrained:
        if os.path.isfile(exp_args.model_load_path):
            try:
                # Load on CPU first: the model is moved to exp_args.device below,
                # and this lets CUDA-saved checkpoints load on CPU-only hosts.
                checkpoint = torch.load(exp_args.model_load_path, map_location='cpu')
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f'----------finish loading model from path: {exp_args.model_load_path}----------')
            except Exception as err:  # best effort: continue with fresh weights
                print("Failed to load model, Unexpected error:", repr(err))
        else:
            print("Pretrained model not found at {!r}; using freshly initialised weights."
                  .format(exp_args.model_load_path))

    model.to(exp_args.device)

    # Set up trainer; GloVe initialisation is suppressed in --evaluate mode so
    # it cannot overwrite weights loaded from a checkpoint.
    trainer = Trainer(
        exp_args=exp_args,
        model=model,
        trainloader=trainloader,
        valloader=valloader,
        testloader=testloader,
        initialize_embedding=not exp_args.evaluate
    )

    if exp_args.evaluate:
        trainer.evaluate(save_results=True)
    else:
        trainer.train()


if __name__ == "__main__":
    # Script entry point: parse CLI arguments, then train or evaluate.
    main()
