import os
import torch
import argparse
from irregular_sampled_datasets import (
    PersonData,
    ETSMnistData,
    XORData,
    Walker2dImitationData,
)
from pytorch_lightning.callbacks import Callback
import torch.utils.data as data
from torch_node_cell import mmRNN, IrregularSequenceLearner
from lipschitz_rnn import LipschitzRNN
from cornn import coRNN
import pytorch_lightning as pl
import numpy as np
from copy import deepcopy

# Command-line options as (flag, default, type). A type of ``None`` leaves the
# value as the raw string, which is argparse's own default behaviour.
_CLI_OPTIONS = (
    ("--dataset", "person", None),
    ("--solver", "fixed_euler", None),
    ("--model", "mmrnn", None),
    ("--opt", "adam", None),
    ("--size", 64, int),
    ("--batch_size", 64, int),
    ("--epochs", 100, int),
    ("--lr", 0.01, float),
    ("--cornn_gamma", 2.7, float),
    ("--cornn_epsilon", 4.7, float),
    ("--tau", 1.0, float),
    ("--lip_beta", 0.8, float),
    ("--lip_gamma", 0.01, float),
    ("--grad_clip", 1.0, float),
    ("--gpus", 0, int),
)

parser = argparse.ArgumentParser()
for _flag, _default, _type in _CLI_OPTIONS:
    parser.add_argument(_flag, default=_default, type=_type)
del _flag, _default, _type
args = parser.parse_args()


def load_dataset(args):
    """Load ``args.dataset`` and wrap its splits in ``DataLoader`` objects.

    Args:
        args: parsed command-line namespace; ``dataset`` and ``batch_size``
            are read.

    Returns:
        Tuple ``(trainloader, validloader, testloader, in_features,
        num_classes, return_sequences, is_regression)``. ``validloader`` is
        ``None`` for every dataset except ``"walker2d"``. Each batch is
        ``(x, t, y)``; the ``"et_mnist"``, ``"xor_dense"`` and ``"xor"``
        datasets append a mask tensor as a fourth element.

    Raises:
        ValueError: if ``args.dataset`` is not one of ``"person"``,
            ``"walker2d"``, ``"et_mnist"``, ``"xor_dense"`` or ``"xor"``.
    """

    def _loader(split, shuffle):
        # All splits share the same batch size and worker count.
        return data.DataLoader(
            split, batch_size=args.batch_size, shuffle=shuffle, num_workers=4
        )

    name = args.dataset
    valid_loader = None
    regression = False

    if name == "person":
        src = PersonData()
        x_train = torch.Tensor(src.train_x)
        y_train = torch.LongTensor(src.train_y)
        t_train = torch.Tensor(src.train_t)
        x_test = torch.Tensor(src.test_x)
        y_test = torch.LongTensor(src.test_y)
        t_test = torch.Tensor(src.test_t)
        train_set = data.TensorDataset(x_train, t_train, y_train)
        test_set = data.TensorDataset(x_test, t_test, y_test)
        sequences_out = True
        n_out = int(torch.max(y_train).item() + 1)
    elif name == "walker2d":
        src = Walker2dImitationData()
        x_train = torch.Tensor(src.train_x)
        y_train = torch.Tensor(src.train_y)
        t_train = torch.Tensor(src.train_times)
        x_valid = torch.Tensor(src.valid_x)
        y_valid = torch.Tensor(src.valid_y)
        t_valid = torch.Tensor(src.valid_times)
        x_test = torch.Tensor(src.test_x)
        y_test = torch.Tensor(src.test_y)
        t_test = torch.Tensor(src.test_times)
        train_set = data.TensorDataset(x_train, t_train, y_train)
        valid_set = data.TensorDataset(x_valid, t_valid, y_valid)
        test_set = data.TensorDataset(x_test, t_test, y_test)
        sequences_out = True
        # Regression target: the output width is taken from the input width.
        n_out = x_train.size(-1)
        regression = True
        valid_loader = _loader(valid_set, shuffle=False)
    else:
        # Event-based datasets share one tensor layout; only construction differs.
        event_sources = {
            "et_mnist": lambda: ETSMnistData(time_major=False),
            "xor_dense": lambda: XORData(
                time_major=False, event_based=False, pad_size=32
            ),
            "xor": lambda: XORData(time_major=False, event_based=True, pad_size=32),
        }
        if name not in event_sources:
            raise ValueError("Unknown dataset '{}'".format(name))
        src = event_sources[name]()
        sequences_out = False
        x_train = torch.Tensor(src.train_events)
        y_train = torch.LongTensor(src.train_y)
        t_train = torch.Tensor(src.train_elapsed)
        m_train = torch.Tensor(src.train_mask)
        x_test = torch.Tensor(src.test_events)
        y_test = torch.LongTensor(src.test_y)
        t_test = torch.Tensor(src.test_elapsed)
        m_test = torch.Tensor(src.test_mask)
        train_set = data.TensorDataset(x_train, t_train, y_train, m_train)
        test_set = data.TensorDataset(x_test, t_test, y_test, m_test)
        n_out = int(torch.max(y_train).item() + 1)

    train_loader = _loader(train_set, shuffle=True)
    test_loader = _loader(test_set, shuffle=False)

    return (
        train_loader,
        valid_loader,
        test_loader,
        x_train.size(-1),
        n_out,
        sequences_out,
        regression,
    )


(
    trainloader,
    validloader,
    testloader,
    in_features,
    num_classes,
    return_sequences,
    is_regression,
) = load_dataset(args)

# Every model constructor receives the same leading positional arguments:
# (in_features, args.size, num_classes).
if args.model == "mmrnn":
    rnn_model = mmRNN(
        in_features,
        args.size,
        num_classes,
        return_sequences=return_sequences,
        solver_type=args.solver,
    )
elif args.model == "cornn":
    rnn_model = coRNN(
        in_features,
        args.size,
        num_classes,
        gamma=args.cornn_gamma,
        epsilon=args.cornn_epsilon,
        tau=args.tau,
        return_sequences=return_sequences,
    )
elif args.model == "lipschitz":
    rnn_model = LipschitzRNN(
        in_features,
        args.size,
        num_classes,
        return_sequences=return_sequences,
        tau=args.tau,
        beta=args.lip_beta,
        gamma=args.lip_gamma,
        # Fixed hparams taken from pMNIST of the original code repo
        pi=0.0,
        init_std=1,
        alpha=1,
    )
else:
    # Format the message like the dataset error in load_dataset; passing two
    # positional args to ValueError would render the message as a tuple.
    raise ValueError("Unknown model '{}'".format(args.model))


class MyPrintingCallback(Callback):
    """A callback that logs if the model can fit the training data
    by aggregating the accuracies on the training batches over each epoch.

    At the end of every epoch whose mean training accuracy beats the best so
    far, a CPU snapshot of the module's ``state_dict`` is kept so that
    ``restore_from_best`` can reload it later. Training is stopped once the
    mean epoch accuracy reaches 1.0.

    Attributes:
        epoch_log: mean training-batch accuracy of each finished epoch.
        step_log: accuracies of the training batches of the current epoch.
    """

    def __init__(self) -> None:
        super().__init__()
        self.epoch_log = []
        self.step_log = []
        self._best_acc = -1
        # CPU snapshot of the best state_dict, or None before the first epoch.
        self._params_backup = None

    def on_train_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
    ):
        # NOTE(review): this nested indexing depends on the structure that the
        # installed pytorch_lightning version builds from the learner's
        # training_step output; confirm when upgrading either.
        acc = outputs[0][0]["extra"]["acc"].item()
        self.step_log.append(acc)

    def on_train_epoch_end(self, trainer, pl_module, outputs):
        acc = np.mean(self.step_log)
        if acc > self._best_acc:
            self._best_acc = acc
            # state_dict() returns tensors sharing storage with the live
            # parameters, and .cpu() is a no-op for tensors already on the CPU.
            # Without an explicit copy, later optimizer steps would silently
            # overwrite the "best" snapshot. deepcopy also keeps the
            # state_dict's _metadata, which load_state_dict consults.
            backup = deepcopy(pl_module.state_dict())
            for k, v in backup.items():
                backup[k] = v.cpu()
            self._params_backup = backup
            print(f"Backup created, acc {acc:0.3f}")
        if acc >= 1.0:
            print("Early stopping")
            trainer.should_stop = True
        self.step_log = []
        self.epoch_log.append(acc)

    def restore_from_best(self, pl_module):
        """Load the best snapshot recorded during training into ``pl_module``.

        Raises:
            RuntimeError: if no snapshot was recorded (no epoch finished).
        """
        if self._params_backup is None:
            raise RuntimeError(
                "No parameter backup available; no training epoch has finished"
            )
        print("Restoring backup ...")
        # load_state_dict copies values into the existing parameters (handling
        # the CPU -> device transfer), so the stored backup stays untouched.
        pl_module.load_state_dict(self._params_backup)


# The XOR tasks get a callback that tracks training accuracy, snapshots the
# best weights and stops early once the training set is fit.
callbacks = []
if args.dataset in ["xor_dense", "xor"]:
    callbacks = [MyPrintingCallback()]


learn = IrregularSequenceLearner(
    rnn_model, lr=args.lr, opt=args.opt, is_regression=is_regression
)
trainer = pl.Trainer(
    max_epochs=args.epochs,
    gradient_clip_val=args.grad_clip,
    gpus=args.gpus,
    progress_bar_refresh_rate=30,
    callbacks=callbacks,
)
trainer.fit(learn, trainloader, val_dataloaders=validloader)

if args.dataset in ["xor_dense", "xor"]:
    # Sanity check: re-evaluating the restored weights on the training set
    # should reproduce the best training accuracy seen during fitting.
    callbacks[0].restore_from_best(learn)
    sanity_results = trainer.test(learn, trainloader)
    print(
        f"Restore train accuracy: {sanity_results[0]['val_acc']}, should match best train acc"
    )

# Regression runs are scored by loss, classification runs by accuracy.
# NOTE(review): the "val_*" keys come from the learner's test-time logging;
# confirm if IrregularSequenceLearner changes.
results = trainer.test(learn, testloader)
score = results[0]["val_loss"] if is_regression else results[0]["val_acc"]
base_path = "results/{}".format(args.dataset)
os.makedirs(base_path, exist_ok=True)
with open("{}/pt_{}_{}.csv".format(base_path, args.model, args.size), "a") as f:
    f.write("{:06f}\n".format(score))
# No trailing newline here: the record line below supplies it, so XOR runs no
# longer leave a blank line in hparams.txt.
info = ""
if callbacks:
    info = "(best train_acc={:0.6f})".format(np.max(callbacks[0].epoch_log))
    print(info)
with open("{}/hparams.txt".format(base_path), "a") as f:
    f.write("{}: {:06f} {}\n".format(str(args), score, info))