import argparse

import pytorch_lightning as pl
import torch
from pytorch_lightning.loggers import TensorBoardLogger
from torch.nn import functional as F
from torch.utils import data
from torch.utils.data import DataLoader

import rnn_models
from config.experiments import val_params
from data.DataPrepare import prepare_torch_datasets, min_max_decode


def add_args(_parser):
    """Register this script's command-line options on ``_parser`` and return it.

    ``--loss`` is restricted to the names ``LitRNNModel`` understands so that a
    typo is rejected by argparse immediately, instead of after dataset loading.
    """
    _parser.add_argument("--model", type=str, default='LSTM', help="Model to use")
    _parser.add_argument("--cpu", action='store_true', help="Force training on the CPU")
    _parser.add_argument("--size", default=64, type=int, help="Hidden size of the recurrent cell")
    _parser.add_argument("--epochs", default=300, type=int, help="Maximum number of training epochs")
    _parser.add_argument("--seq_len", default=30, type=int, help="Input sequence length")
    _parser.add_argument("--batch_size", default=16, type=int, help="Training batch size")
    _parser.add_argument("--lr", default=0.001, type=float, help="Adam learning rate")
    # Keep in sync with the loss names accepted by LitRNNModel.
    _parser.add_argument("--loss", default='l1', type=str, choices=('mse', 'l1', 'huber'),
                         help="Training/validation loss")
    _parser.add_argument("--num_runs", default=3, type=int, help="Number of independently seeded runs")
    _parser.add_argument("--norm_visitors", action='store_true',
                         help="Normalize visitor counts (loss is computed on decoded values)")
    _parser.add_argument("--extra_features", action='store_true', help="Include extra input features")
    _parser.add_argument("--use_near", action='store_true', help="Passed through as use_near to the dataset")
    return _parser


class LitRNNModel(pl.LightningModule):
    """Lightning module wrapping a recurrent cell plus a linear output layer.

    The module builds its own train/eval datasets with ``prepare_torch_datasets``
    and serves them from ``train_dataloader`` / ``val_dataloader``.

    Args:
        model_cls: Cell class, constructed as ``model_cls(units=model_size, input_size=...)``.
        model_size: Hidden size passed to the cell; also the input width of the output layer.
        seq_len: Sequence length forwarded to ``prepare_torch_datasets``.
        loss: One of ``'mse'``, ``'l1'`` or ``'huber'``.
        batch_size: Training batch size (validation uses ``val_params['batch_size']``).
        lr: Adam learning rate.
        norm_visitors: Forwarded as ``normalize_visitors``; when True, predictions and
            labels are passed through ``min_max_decode`` before the loss is computed.
        extra_features: Forwarded to ``prepare_torch_datasets``.
        use_near: Forwarded to ``prepare_torch_datasets``.

    Raises:
        ValueError: If ``loss`` is not one of the supported names.
    """

    # Supported values of the ``loss`` constructor argument.
    _LOSSES = {
        'mse': F.mse_loss,
        'l1': F.l1_loss,
        'huber': F.huber_loss,
    }

    def __init__(self, model_cls, model_size=64, seq_len=30, loss='mse',
                 batch_size=16, lr=0.001, norm_visitors=False, extra_features=True, use_near=False):
        super().__init__()
        self.save_hyperparameters()
        # Resolve the loss before the (potentially slow) dataset preparation so an
        # invalid name fails immediately.
        if loss not in self._LOSSES:
            raise ValueError("Unknown Loss function selected")
        self.loss = self._LOSSES[loss]

        self.lr = lr
        self.norm_visitors = norm_visitors
        self.batch_size = batch_size
        self.train_dataset, self.eval_dataset, mins, maxs = prepare_torch_datasets(
            seq_len=seq_len,
            normalize_visitors=norm_visitors,
            extra_features=extra_features,
            use_near=use_near)
        # Buffers move with the module across devices and are stored in checkpoints.
        # NOTE(review): the int64 cast truncates any fractional min/max values —
        # confirm the values returned by prepare_torch_datasets are integral.
        self.register_buffer('mins', torch.tensor(mins).long())
        self.register_buffer('maxs', torch.tensor(maxs).long())

        # Input and output widths are taken from the last dim of the first training sample.
        x, y = self.train_dataset[0]
        self.cell = model_cls(units=model_size, input_size=x.shape[-1])
        self.output_layer = torch.nn.Linear(model_size, y.shape[-1])

    def on_train_start(self):
        """Print the hyperparameters and register them with the logger, if any.

        Zero placeholder values for ``loss/train`` and ``loss/val`` are passed
        along with the hyperparameters.
        """
        print(self.hparams)
        if self.logger is not None:
            self.logger.log_hyperparams(self.hparams, metrics={"loss/train": 0, "loss/val": 0})

    def forward(self, x):
        """Run the cell and project its output through the linear output layer.

        When the cell returns a tuple, only its first element is used.
        """
        out = self.cell(x)
        if isinstance(out, tuple):
            out = out[0]
        return self.output_layer(out)

    def _compute_loss(self, pred, target):
        """Apply the configured loss, decoding both sides first when ``norm_visitors`` is set."""
        if self.norm_visitors:
            pred = min_max_decode(pred, self.mins, self.maxs)
            target = min_max_decode(target, self.mins, self.maxs)
        return self.loss(pred, target)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss over the full model output."""
        x, y = batch
        loss = self._compute_loss(self(x), y)
        self.log("loss/train", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss on the last index of dim 1 of the output."""
        x, y = batch
        # assumes output is (batch, seq, features), so [:, -1] is the final time step — TODO confirm
        loss = self._compute_loss(self(x)[:, -1], y)
        self.log("loss/val", loss)
        return loss

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        """Unshuffled loader over the evaluation dataset (batch size from ``val_params``)."""
        return DataLoader(self.eval_dataset, batch_size=val_params['batch_size'], num_workers=4)

    def configure_optimizers(self):
        """Create an Adam optimizer over all parameters.

        ``self.lr`` is used when truthy; otherwise a ``learning_rate`` attribute is
        used if one has been set on the module.

        Raises:
            ValueError: If no learning rate is available.
        """
        lr = self.lr or getattr(self, 'learning_rate', None)
        if lr is None:
            raise ValueError("No learning rate set: pass `lr` or set `learning_rate`.")
        return [torch.optim.Adam(self.parameters(), lr=lr)]


def train(model=None,
          num_runs=3,
          epochs=300,
          cpu=False,
          loss='mse',
          model_size=64,
          seq_len=30,
          batch_size=16,
          lr=0.001,
          norm_visitors=False,
          extra_features=False,
          out_folder='logs',
          use_near=False,
          **kwargs):
    """Train ``num_runs`` independently seeded copies of a model from ``rnn_models``.

    Run ``i`` is seeded with ``i``, fitted for up to ``epochs`` epochs and then
    validated once more. TensorBoard logs are written under
    ``out_folder/<model class name>/<version>``. The version encodes the loss,
    model size, each enabled feature flag and the run index, so that differing
    configurations do not share a log directory.

    Args:
        model: Name of a class defined in ``rnn_models`` (e.g. ``'LSTM'``).
        num_runs: Number of seeded runs.
        epochs: Maximum epochs per run.
        cpu: Force the CPU accelerator; otherwise Lightning picks one automatically.
        loss, model_size, seq_len, batch_size, lr, norm_visitors, extra_features,
            use_near: Forwarded to ``LitRNNModel``.
        out_folder: Root directory for TensorBoard logs.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        A list with one entry per run: the value returned by ``trainer.validate``.

    Raises:
        Exception: If ``model`` is not the name of an attribute of ``rnn_models``.
    """
    if not cpu and not torch.cuda.is_available():
        print("Warning! CUDA not available")

    # hasattr() raises TypeError for non-string names (including the None default).
    if not isinstance(model, str) or not hasattr(rnn_models, model):
        raise Exception("ERROR: Unknown model type '{}'".format(model))
    model_cls = getattr(rnn_models, model)

    results = []
    for i in range(num_runs):
        print("Training " + model_cls.__name__ + f" {i + 1}/{num_runs}")
        pl.seed_everything(i)
        lit_model = LitRNNModel(model_cls,
                                model_size=model_size,
                                seq_len=seq_len,
                                loss=loss,
                                batch_size=batch_size,
                                lr=lr,
                                norm_visitors=norm_visitors,
                                extra_features=extra_features,
                                use_near=use_near)
        # Each flag gets its own suffix, keyed on that same flag, so runs with
        # different configurations never overwrite each other's logs.
        version = (f"{loss}_{model_size}"
                   + ('_features' if extra_features else '')
                   + ('_norm' if norm_visitors else '')
                   + ('_near' if use_near else '')
                   + f"_{i}")
        trainer = pl.Trainer(max_epochs=epochs, detect_anomaly=True,
                             auto_lr_find=False,
                             check_val_every_n_epoch=30,
                             devices=1,
                             accelerator='cpu' if cpu else 'auto',
                             logger=TensorBoardLogger(out_folder,
                                                      name=model_cls.__name__,
                                                      version=version))
        trainer.tune(lit_model)
        trainer.fit(lit_model)
        results.append(trainer.validate(lit_model))

    return results


if __name__ == "__main__":
    # Parse the command line and forward every option to train(). --use_near was
    # previously parsed but silently dropped.
    parser = argparse.ArgumentParser()
    args = add_args(parser).parse_args()
    train(args.model, args.num_runs, args.epochs, args.cpu, args.loss,
          model_size=args.size,
          seq_len=args.seq_len,
          batch_size=args.batch_size,
          lr=args.lr,
          norm_visitors=args.norm_visitors,
          extra_features=args.extra_features,
          use_near=args.use_near)
