import torch
import lightning as L
from lightning.pytorch.cli import OptimizerCallable, LRSchedulerCallable


# ModelCallable = Callable[[Iterable], L.LightningModule]


class SequenceRegression(L.LightningModule):
    def __init__(self, network: torch.nn.Module,
                 optimizer: OptimizerCallable,
                 scheduler: LRSchedulerCallable,
                 scheduler_interval: str = 'epoch',
                 scheduler_frequency: int = 1,
                 scheduler_monitor: str = 'train_loss'):
        """ Pytorch Lightning Task that trains a sequence model to regress a target sequence from an input sequence
        using a mean-squared-error loss. The DataLoader that is used with the class needs to have the following properties:
        (1) A single batch should unpack as (x, y)
        (2) x must be a 3-dimensional tensor; its first dimension is used as the batch size and its second as the sequence length
        (3) y must be broadcast-compatible with the network output, since the loss is mean((y - network(x)) ** 2)

        :param network (torch.nn.Module): network to train. The network must have the following properties:
                                          (1) forward/call function takes the input tensor x as its only positional argument
                                          (2) the output must be broadcast-compatible with y
                                          (3) (optional) an `initialize(batch_size=..., device=...)` function that is called before
                                              every forward pass of a step, e.g. when hidden states need to be initialized in a Vanilla RNN
                                          (4) for `test_step` only: the model is called on a single time step (B, 1, d) and its output is fed
                                              back as the next input, so the output of one step must be usable as the input of the next
        :param optimizer (OptimizerCallable): callable that receives the parameters and returns the optimizer
        :param scheduler (LRSchedulerCallable): callable that receives the optimizer and returns the learning rate scheduler
        :param scheduler_interval (str, optional): `interval` entry of the lr_scheduler config. Defaults to 'epoch'.
        :param scheduler_frequency (int, optional): `frequency` entry of the lr_scheduler config. Defaults to 1.
        :param scheduler_monitor (str, optional): `monitor` entry of the lr_scheduler config. Defaults to 'train_loss'.
        """
        # The parent constructor runs first so that the nn.Module machinery exists before any attribute is assigned.
        super().__init__()
        # Exports the hyperparameters to a YAML file, and create "self.hparams" namespace.
        # Must be called directly from __init__ because it reads the constructor arguments from the calling frame.
        self.save_hyperparameters()

        self.scheduler_interval = scheduler_interval
        self.scheduler_frequency = scheduler_frequency
        self.scheduler_monitor = scheduler_monitor

        # `optimizer` and `scheduler` are factories; they are kept as such and only invoked in `configure_optimizers`
        self.optimizer = optimizer
        self.scheduler = scheduler

        # NOTE: this switches anomaly detection on globally (for the whole process), not only for this module.
        # It helps in debugging error messages but slows down the backward pass.
        torch.autograd.set_detect_anomaly(True)

        # self.model.double()  # make sure model is in double precision
        self.model = network.to(self.device)

        # bookkeeping for the best validation accuracy so far (not updated by any method of this class)
        self.val_accuracies = []
        self.best_val_accuracy = 0.0
        # self.model.log(self.logger.experiment)

    def configure_optimizers(self):
        """ Builds the optimizer and the learning rate scheduler from the factories given to the constructor.

        The factories stored in `self.optimizer` / `self.scheduler` are left untouched, so this method
        can safely be called more than once (e.g. when the module is fitted a second time).

        :return (dict): optimizer together with the lr_scheduler config (scheduler, interval, frequency, monitor)
        """
        optimizer = self.optimizer(self.parameters())
        scheduler = self.scheduler(optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": self.scheduler_interval,
                "frequency": self.scheduler_frequency,
                "monitor": self.scheduler_monitor,
            },
        }

    def forward(self, batch):
        """ Runs the wrapped network on `batch` and returns its output. Does not call `initialize` on the network. """
        return self.model(batch)

    def _reset_model_state(self, batch_size):
        """ Calls the optional `initialize` hook of the network, if the network defines one. """
        if hasattr(self.model, "initialize"):
            self.model.initialize(batch_size=batch_size, device=self.device)

    def _teacher_forced_mse(self, batch):
        """ Feeds the full input sequence to the network and returns the mean squared error against the target. """
        x, y = batch
        # the unpacking also enforces that x is 3-dimensional
        batch_size, _, _ = x.shape

        self._reset_model_state(batch_size)
        y_pred = self.model(x)
        return torch.mean((y - y_pred) ** 2)

    def training_step(self, batch, batch_idx):
        """ Computes the MSE loss of one training batch and logs it as "train_loss".

        :param batch: tuple (x, y)
        :param batch_idx (int): index of the batch (unused)
        :return (torch.Tensor): scalar loss
        """
        loss = self._teacher_forced_mse(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss  # Return tensor to call ".backward" on

    def validation_step(self, batch, batch_idx):
        """ Computes the MSE loss of one validation batch and logs it as "validation_loss".

        :param batch: tuple (x, y)
        :param batch_idx (int): index of the batch (unused)
        """
        loss = self._teacher_forced_mse(batch)
        self.log("validation_loss", loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        """ Generates the prediction autoregressively and logs its MSE against the target as "test_loss".

        Only the first time step of x is given to the network; afterwards the prediction of step t
        is used as the input of step t + 1.

        :param batch: tuple (x, y)
        :param batch_idx (int): index of the batch (unused)
        """
        x, y = batch
        batch_size, seq_length, _ = x.shape

        # `initialize` is optional for the network, exactly as in the training and validation steps
        self._reset_model_state(batch_size)

        y_pred = torch.zeros_like(y)

        # autoregressive generation
        cur_x = x[:, 0:1, :].clone()
        for t in range(seq_length):
            y_pred[:, t:t + 1, :] = self.model(cur_x)
            # clone so that the next input does not alias the prediction buffer
            cur_x = y_pred[:, t:t + 1, :].clone()

        loss = torch.mean((y - y_pred) ** 2)
        self.log("test_loss", loss, prog_bar=True)
