import argparse
import os.path

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader
from torch_geometric_temporal import A3TGCN2, DCRNN
from torch_geometric_temporal.nn.recurrent.agcrn import AVWGCN

from config.experiments import val_params
from data.DataPrepare import split_sequences, min_max_encode, get_extra_features, prepare_graph_datasets, min_max_decode


def add_args(_parser):
    """Register the GNN training command-line options on ``_parser``.

    Args:
        _parser: an ``argparse.ArgumentParser`` to extend in place.

    Returns:
        The same parser, so callers can write ``add_args(parser).parse_args()``.
    """
    _parser.add_argument("--cpu", action='store_true',
                         help="train on the CPU instead of letting Lightning pick an accelerator")
    _parser.add_argument("--epochs", default=1000, type=int,
                         help="maximum number of training epochs")
    _parser.add_argument("--seq_len", default=30, type=int,
                         help="input sequence length (time steps per sample)")
    _parser.add_argument("--batch_size", default=16, type=int,
                         help="training batch size")
    _parser.add_argument("--lr", default=0.001, type=float,
                         help="Adam learning rate")
    _parser.add_argument("--num_runs", default=3, type=int,
                         help="number of independent training runs (seeded 0..num_runs-1)")
    _parser.add_argument("--output_size", default=1, type=int,
                         help="output channels of the recurrent graph layer")
    _parser.add_argument("--global_node_features", default=1, type=int,
                         help="features per node projected from the global features (0 disables)")
    _parser.add_argument("--norm_visitors", action='store_true',
                         help="normalize visitor counts when preparing the datasets")
    _parser.add_argument("--loss_ratio_poi", type=float, default=0.9,
                         help="weight of the POI-node loss versus the other nodes during training")
    # Restrict to the loss names LitGNNModel understands so typos fail at parse time.
    _parser.add_argument("--loss", default='mse', type=str, choices=('mse', 'l1', 'huber'),
                         help="loss function")
    return _parser


class LitGNNModel(pl.LightningModule):
    """Lightning module training an ``A3TGCN2`` temporal graph network.

    ``__init__`` builds the train/validation datasets, the edge list and the
    min/max normalization bounds through ``prepare_graph_datasets``. In
    ``forward`` the global features are optionally projected onto
    ``global_node_features`` extra features per node and concatenated to the
    node features. The result is passed through ``A3TGCN2``, and a linear head
    maps each node's output to a single value.

    The last ``num_pois`` entries along the node axis are the POI nodes. The
    validation loss uses only those. The training loss mixes the POI loss and
    the remaining-node loss with weight ``loss_ratio_poi``.
    """

    # Supported values of the ``loss`` constructor argument.
    _LOSS_FUNCTIONS = {
        'mse': F.mse_loss,
        'l1': F.l1_loss,
        'huber': F.huber_loss,
    }

    def __init__(self, node_features, output_size,
                 seq_len=30, batch_size=16, lr=0.0001, global_node_features=2,
                 norm_visitors=False, loss_ratio_poi=.5,
                 prediction_file=None, loss='mse', **kwargs):
        """Build the datasets, graph buffers and network layers.

        Args:
            node_features: input features per node, before the projected
                global features are appended.
            output_size: output channels of the ``A3TGCN2`` layer.
            seq_len: input sequence length; also ``periods`` of ``A3TGCN2``.
            batch_size: training batch size (also passed to ``A3TGCN2``).
            lr: Adam learning rate.
            global_node_features: features per node projected from the global
                features; 0 disables the projection.
            norm_visitors: forwarded to ``prepare_graph_datasets``. When set,
                validation predictions are decoded with ``min_max_decode``
                before the loss is computed.
            loss_ratio_poi: weight of the POI loss in the training loss.
            prediction_file: optional path; validation predictions are appended
                to it as comma-separated rows.
            loss: one of ``'mse'``, ``'l1'``, ``'huber'``.
            **kwargs: not used by this class.

        Raises:
            ValueError: if ``loss`` is not a supported loss name.
        """
        super().__init__()
        self.prediction_file = prediction_file
        self.loss_ratio_poi = loss_ratio_poi
        self.batch_size = batch_size
        # save_hyperparameters() inspects this constructor's arguments, so it
        # must stay inside __init__ (and before any argument is reassigned).
        self.save_hyperparameters()
        self.global_node_features = global_node_features
        self.lr = lr
        self.norm_visitors = norm_visitors

        try:
            self.loss = self._LOSS_FUNCTIONS[loss]
        except KeyError:
            raise ValueError(f"Unknown Loss function selected: {loss!r} "
                             f"(expected one of {sorted(self._LOSS_FUNCTIONS)})") from None

        if self.prediction_file:
            prediction_dir = os.path.dirname(self.prediction_file)
            # A bare file name has no directory part, and os.makedirs('') raises.
            if prediction_dir:
                os.makedirs(prediction_dir, exist_ok=True)

        graph_data = prepare_graph_datasets(seq_len=seq_len, normalize_visitors=norm_visitors)
        self.train_dataset, self.val_dataset, edges, self.num_pois, mins, maxs = graph_data
        # Keep the normalization bounds as floats. An int64 cast would truncate
        # non-integer bounds and distort min_max_decode.
        self.register_buffer('mins', torch.tensor(mins, dtype=torch.float32))
        self.register_buffer('maxs', torch.tensor(maxs, dtype=torch.float32))
        self.register_buffer('edge_index', torch.from_numpy(edges[['u', 'v']].values.T).long())
        self.register_buffer('edge_weights', torch.from_numpy(edges['length'].values).float())

        if global_node_features:
            # Maps the global feature vector to global_node_features values per node.
            num_global_features = self.train_dataset.tensors[1].shape[-1]
            num_nodes = self.train_dataset.tensors[0].shape[1]
            self.global_to_node = torch.nn.Linear(num_global_features, global_node_features * num_nodes)
        self.recurrent = A3TGCN2(node_features + global_node_features, output_size, periods=seq_len,
                                 batch_size=batch_size, cached=True, improved=True)
        self.linear = torch.nn.Linear(output_size, 1)

    def on_train_start(self):
        """Print the hyperparameters and log them with placeholder metrics.

        The placeholder ``loss/train`` and ``loss/val`` values register those
        metrics for the hparams view. This is needed when the logger is created
        with ``default_hp_metric=False``.
        """
        print(self.hparams)
        self.logger.log_hyperparams(self.hparams, metrics={"loss/train": 0, "loss/val": 0})

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, batch):
        """Return per-node predictions for ``batch = (x, glob, y)``; ``y`` is unused.

        When ``global_node_features`` is set, ``glob`` is projected and
        reshaped to ``(batch, nodes, global_node_features, time)``, then
        concatenated to ``x`` along dim 2. This assumes ``x`` is laid out as
        (batch, nodes, features, time) — TODO confirm against the dataset.
        """
        x, glob, _ = batch
        if self.global_node_features:
            glob = self.global_to_node(glob)  # last dim: global_node_features * nodes
            glob = glob.reshape(glob.shape[0], glob.shape[1],
                                glob.shape[2] // self.global_node_features, self.global_node_features)
            # (B, T, N, g) -> (B, N, g, T); same as transpose(1, 2).transpose(2, 3).
            glob = glob.permute(0, 2, 3, 1)
            x = torch.cat((x, glob), dim=2)
        hidden = F.relu(self.recurrent(x, self.edge_index, self.edge_weights))
        return self.linear(hidden)

    def step(self, batch, batch_idx, val=False):
        """Compute the loss for one batch.

        Returns:
            ``(loss, out_poi)``, where ``out_poi`` holds the POI-node
            predictions. During validation with ``norm_visitors`` set, they are
            decoded with ``min_max_decode``.
        """
        out = self.forward(batch)
        target = batch[2]

        # POI loss: the last num_pois nodes.
        out_poi = out[..., -self.num_pois:, -1]
        if self.norm_visitors and val:
            # NOTE(review): only the prediction is decoded. This assumes the
            # validation targets are already in the original scale — confirm.
            out_poi = min_max_decode(out_poi, self.mins, self.maxs)
        loss = self.loss(out_poi, target[..., -self.num_pois:])

        if not val:
            # During training, also fit the remaining (non-POI) nodes.
            loss_nodes = self.loss(out[..., :-self.num_pois, -1], target[..., :-self.num_pois])
            loss = loss * self.loss_ratio_poi + (1 - self.loss_ratio_poi) * loss_nodes
        return loss, out_poi

    def training_step(self, train_batch, batch_idx):
        """Log and return the blended POI/node training loss."""
        loss, _ = self.step(train_batch, batch_idx)
        self.log("loss/train", loss)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Log the POI validation loss; optionally append predictions to ``prediction_file``."""
        loss, out = self.step(val_batch, batch_idx, val=True)
        if self.prediction_file:
            with open(self.prediction_file, 'a') as csvfile:
                np.savetxt(csvfile, out.detach().cpu().numpy(), delimiter=',')
        self.log("loss/val", loss)
        return loss

    def train_dataloader(self):
        """Shuffled training loader using ``self.batch_size``."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=4)

    def val_dataloader(self):
        """Validation loader using the batch size from ``val_params``."""
        return DataLoader(self.val_dataset, batch_size=val_params['batch_size'], num_workers=4)


def train_gnn(num_runs=3,
              epochs=1000,
              cpu=False,
              seq_len=30,
              batch_size=16,
              lr=0.0001,
              norm_visitors=True,
              out_folder=None,
              output_size=3,
              global_node_features=1,
              loss='mse',
              loss_ratio_poi=0.5,
              **kwargs):
    """Train ``LitGNNModel`` ``num_runs`` times, seeding run ``i`` with ``i``.

    Each run logs to TensorBoard under ``out_folder`` (default ``'logs_gnn'``).
    The log name is the recurrent layer's class name, and the version string
    encodes seq_len, global_node_features, normalization and the run index.
    Training uses early stopping on ``loss/val``. Extra keyword arguments are
    accepted and ignored.

    Returns:
        The value of ``trainer.validate`` for the last run only, or an empty
        dict when ``num_runs`` is 0.
    """
    log_root = 'logs_gnn' if out_folder is None else out_folder

    if not (cpu or torch.cuda.is_available()):
        print("Warning! CUDA not available")

    results = {}
    for run in range(num_runs):
        print(f"Training GNN {run + 1}/{num_runs}")
        pl.seed_everything(run)

        model = LitGNNModel(
            node_features=1,
            output_size=output_size,
            seq_len=seq_len,
            batch_size=batch_size,
            lr=lr,
            global_node_features=global_node_features,
            norm_visitors=norm_visitors,
            loss=loss,
            loss_ratio_poi=loss_ratio_poi,
        )

        # NOTE(review): Lightning counts patience in validation checks. With
        # validation every 30 epochs, this tolerates ~900 epochs without
        # improvement — confirm that is intended.
        stopper = EarlyStopping(monitor='loss/val', min_delta=0.00, patience=30,
                                verbose=False, mode='min')

        norm_tag = '_norm' if norm_visitors else ''
        tb_logger = TensorBoardLogger(
            log_root,
            default_hp_metric=False,
            name=type(model.recurrent).__name__,
            version=f"seq{seq_len}_glob{global_node_features}{norm_tag}_{run}",
        )
        trainer = pl.Trainer(
            max_epochs=epochs,
            detect_anomaly=True,
            auto_lr_find=False,
            check_val_every_n_epoch=30,
            devices=1,
            accelerator='cpu' if cpu else 'auto',
            callbacks=[stopper],
            logger=tb_logger,
        )
        # Presumably a no-op, since no tuning option is enabled — confirm.
        trainer.tune(model)
        trainer.fit(model)
        # Reassigned every run, so only the final run's metrics survive.
        results = trainer.validate(model)

    return results


if __name__ == "__main__":
    # Parse the CLI options and run training.
    parser = argparse.ArgumentParser()
    args = add_args(parser).parse_args()
    # Forward every parsed option, so --output_size and --loss take effect
    # instead of silently falling back to train_gnn's own defaults.
    train_gnn(num_runs=args.num_runs,
              epochs=args.epochs,
              cpu=args.cpu,
              seq_len=args.seq_len,
              batch_size=args.batch_size,
              lr=args.lr,
              norm_visitors=args.norm_visitors,
              output_size=args.output_size,
              global_node_features=args.global_node_features,
              loss=args.loss,
              loss_ratio_poi=args.loss_ratio_poi)
