"""
Lightning models.
"""
from typing import Callable, Tuple
from functools import partial

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch_geometric.data import Data


from .metrics import LOSS_REGISTRY, ACC_REGISTRY
# Define the model wrapper class.
# Supports training, testing, and loading the model.


class LightningWrapper(pl.LightningModule):
    """Lightning module that trains and evaluates a wrapped ``nn.Module``.

    The wrapped model is called directly on a ``torch_geometric`` ``Data``
    batch and targets are read from ``batch.y``. Loss and accuracy functions
    are looked up by name in ``LOSS_REGISTRY`` / ``ACC_REGISTRY``.

    Args:
        model: The network to wrap.
        config: Hyper-parameter dict. Required keys: ``reshuffle_every_n_epochs``,
            ``lr``, ``wd``, ``gamma``, ``cooldown``, ``patience``, ``threshold``,
            ``loss``, ``acc``. Optional key: ``forces`` (default ``False``); when
            truthy, validation/test run with gradients enabled because forces
            must be computed with grad enabled.

    Raises:
        ValueError: If ``config['loss']`` or ``config['acc']`` resolves to
            ``None`` in its registry.
    """

    def __init__(self, model: nn.Module, config: dict):
        super().__init__()

        self.model = model
        # Not read anywhere in this class; presumably consumed by the
        # training script / data module — verify against callers.
        self.reshuffle_every_n_epochs = config['reshuffle_every_n_epochs']
        # AdamW hyper-parameters (see configure_optimizers).
        self.lr = config['lr']
        self.wd = config['wd']
        # ReduceLROnPlateau hyper-parameters (gamma is used as its `factor`).
        self.gamma = config['gamma']
        self.cooldown = config['cooldown']
        self.patience = config['patience']
        self.threshold = config['threshold']

        self.loss_fn = LOSS_REGISTRY.get(config['loss'])
        self.acc_fn = ACC_REGISTRY.get(config['acc'])
        # With dict-style `.get`, an unknown name yields None, which would only
        # surface later as "'NoneType' object is not callable" in the first
        # step. Fail fast with the offending name instead.
        if self.loss_fn is None:
            raise ValueError(f"unknown loss function {config['loss']!r}")
        if self.acc_fn is None:
            raise ValueError(f"unknown accuracy function {config['acc']!r}")

        # When truthy, eval_model runs the model under torch.enable_grad.
        self.forces = config.get('forces', False)

    def configure_optimizers(self) -> Tuple:
        """Build AdamW plus a ReduceLROnPlateau scheduler on ``train_loss``.

        Returns:
            ``([optimizer], scheduler_config)``. The scheduler is stepped once
            per epoch using the logged ``train_loss`` metric.
        """
        optimizer = optim.AdamW(
            params=self.model.parameters(),
            lr=self.lr,
            weight_decay=self.wd,
            eps=1e-07,
            amsgrad=True
        )

        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                  factor=self.gamma,
                                                                  cooldown=self.cooldown,
                                                                  patience=self.patience,
                                                                  threshold=self.threshold,
                                                                  )
        lr_scheduler_config = {
            "scheduler": lr_scheduler,
            "interval": "epoch",
            "monitor": "train_loss"
        }

        return [optimizer], lr_scheduler_config

    def forward(self, data: Data) -> torch.Tensor:
        """Run the wrapped model on ``data``."""
        return self.model(data)

    def training_step(self, batch: Data, batch_idx: int) -> torch.Tensor:
        """Compute, log and return the training loss for one batch.

        Logs ``train_loss`` and ``train_acc`` both per step and per epoch,
        weighted by ``batch.y.size(0)``.
        """
        model = self.model
        # Explicitly (re-)enter train mode before the forward pass.
        model.train()
        label = batch.y
        outs = model(batch)
        loss = self.loss_fn(outs, label)
        acc = self.acc_fn(outs, label)

        self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True,
                 batch_size=label.size(0), logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, sync_dist=True,
                 batch_size=label.size(0), logger=True)
        return loss

    def eval_model(self, batch: Data, mode: str) -> torch.Tensor:
        """Evaluate one batch and log ``{mode}_loss`` / ``{mode}_acc``.

        Args:
            batch: Evaluation batch; ``batch.y`` holds the targets.
            mode: Metric-name prefix (``'val'`` or ``'test'`` in this class).

        Returns:
            The loss for this batch.

        Raises:
            RuntimeError: If ``forces`` is set but evaluation is running under
                ``torch.inference_mode`` (e.g. ``Trainer(inference_mode=True)``).
                In that mode ``torch.enable_grad`` flips the grad flag but
                autograd still does not record operations.
        """
        model = self.model
        model.eval()
        label = batch.y
        if self.forces and torch.is_inference_mode_enabled():
            raise RuntimeError(
                'forces should be computed with grad enabled, but evaluation is '
                'running under torch.inference_mode; construct the Trainer with '
                'inference_mode=False')
        # Forces need autograd during evaluation; otherwise skip graph building.
        grad_ctx = torch.enable_grad if self.forces else torch.no_grad
        with grad_ctx():
            outs = model(batch)
            loss = self.loss_fn(outs, label)
            acc = self.acc_fn(outs, label)
        self.log(f'{mode}_loss', loss, on_step=True, on_epoch=True, sync_dist=True,
                 batch_size=label.size(0), logger=True)
        self.log(f'{mode}_acc', acc, on_step=True, on_epoch=True, sync_dist=True,
                 batch_size=label.size(0), logger=True)
        return loss

    def test_step(self, batch: Data, batch_idx: int) -> torch.Tensor:
        """Evaluate a test batch (metrics prefixed with ``test``)."""
        return self.eval_model(batch=batch, mode='test')

    def validation_step(self, batch: Data, batch_idx: int) -> torch.Tensor:
        """Evaluate a validation batch (metrics prefixed with ``val``)."""
        return self.eval_model(batch=batch, mode='val')

    def compute_metric(self, trainer: pl.Trainer, test_loader: DataLoader, track: str) -> float:
        """Run ``trainer.test`` on ``test_loader`` and return one epoch metric.

        Args:
            trainer: Trainer used to run the test loop.
            test_loader: Data to evaluate.
            track: Metric suffix, ``'loss'`` or ``'acc'``. The ``_epoch`` key
                exists because eval_model logs with both ``on_step`` and
                ``on_epoch`` enabled.

        Returns:
            The value of ``test_{track}_epoch`` from the first result dict.
        """
        return trainer.test(self, dataloaders=test_loader, verbose=False)[0][f'test_{track}_epoch']

    def on_train_epoch_end(self):
        """Log the first param group's learning rate at each epoch end."""
        lr = self.optimizers().param_groups[0]["lr"]
        self.log("lr", lr)

    def configure_gradient_clipping(self, optimizer, gradient_clip_val=None, gradient_clip_algorithm=None):
        """Clip the wrapped model's gradients before each optimizer step.

        Lightning passes the Trainer's clipping settings here; either may be
        ``None`` when not configured. An unset algorithm is treated as
        ``"norm"``, and a missing or zero clip value disables clipping.

        Raises:
            ValueError: For an algorithm other than ``"norm"`` or ``"value"``.
        """
        if not gradient_clip_val:
            return
        params = self.model.parameters()
        if gradient_clip_algorithm is None or gradient_clip_algorithm == "norm":
            torch.nn.utils.clip_grad_norm_(params, gradient_clip_val)
        elif gradient_clip_algorithm == "value":
            torch.nn.utils.clip_grad_value_(params, gradient_clip_val)
        else:
            raise ValueError(
                f"unsupported gradient_clip_algorithm {gradient_clip_algorithm!r}")
