from typing import Any, List

import torch
from pytorch_lightning import LightningModule
import time

import os

from src import utils

import sys
# Prepend "<directory of this file>/../../src" to the module search path so the
# bare `models` import below can be resolved. This must run before that import.
# NOTE(review): presumably `models/networks.py` lives in that `src` directory — confirm.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'src'))

from models.networks import DistanceNetworkTransformer

# Module-level logger obtained from the project's `utils.get_logger` helper.
logger = utils.get_logger(__name__)

class PCMetricModule(LightningModule):
    """Pytorch Lightning module for learning to approximate wasserstein distance between point clouds.

    A ``DistanceNetworkTransformer`` is called with a source and a target input
    and its output is regressed onto ``batch['dist']`` with a mean-squared-error
    loss. Batches are mappings with the keys ``'source'``, ``'target'`` and
    ``'dist'``.

    A LightningModule organizes your PyTorch code into 5 sections:
        - Computations (init).
        - Train loop (training_step)
        - Validation loop (validation_step)
        - Test loop (test_step)
        - Optimizers (configure_optimizers)

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """

    def __init__(
        self,
        lr=0.001,
        weight_decay=0.0005,
        input_dim=2,
        latent_dim=16,
        dim_model=64,
        dim_keys=64,
        dim_hidden=64,
        nb_heads=4,
        nb_blocks=4,
        dropout=0.,
        mlp_dims=None,
        mlp_activation=torch.nn.ReLU,
        *args,
        **kwargs
    ):
        """Store the hyperparameters and build the distance network.

        Args:
            lr: Learning rate handed to Adam in ``configure_optimizers``.
            weight_decay: Weight decay handed to Adam in ``configure_optimizers``.
            mlp_dims: Defaults to ``[32, 64, 32, 16]`` when ``None``.
            input_dim, latent_dim, dim_model, dim_keys, dim_hidden, nb_heads,
            nb_blocks, dropout, mlp_activation: Not interpreted here; see
                ``DistanceNetworkTransformer`` for their meaning.

        Everything stored in ``self.hparams`` (``lr`` and ``weight_decay``
        included) is forwarded as keyword arguments to
        ``DistanceNetworkTransformer``.
        """
        super().__init__()

        # Build the default list per instance instead of sharing one list object
        # through the signature. This has to happen before `save_hyperparameters`
        # so that the resolved value (not `None`) is what gets stored.
        if mlp_dims is None:
            mlp_dims = [32, 64, 32, 16]

        # this line allows to access init params with 'self.hparams' attribute
        # it also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)
        self.model = DistanceNetworkTransformer(**self.hparams)

    def forward(self, source, target):
        """Return the network output for ``(source, target)`` with dimension 1 squeezed."""
        return self.model(source, target).squeeze(1)

    def on_train_start(self):
        """No-op hook: this module keeps no state that has to be reset before training."""
        pass

    def step(self, batch: Any):
        """Run the network on one batch and compute the MSE against ``batch['dist']``.

        Shared by the training, validation and test steps.

        Args:
            batch: Mapping with the keys ``'source'``, ``'target'`` and ``'dist'``.

        Returns:
            A dict with the keys ``'source'``, ``'target'``, ``'dist'``,
            ``'preds'`` and ``'loss'`` (in that order).
        """
        source, target, dist = batch['source'], batch['target'], batch['dist']

        preds = self(source, target)
        loss = torch.nn.functional.mse_loss(preds, dist)

        return {"source": source, "target": target, "dist": dist, "preds": preds, "loss": loss}

    def _log_loss(self, stage: str, loss) -> None:
        """Log ``loss`` under ``<stage>/loss`` and ``<stage>/mse_emd``, per step and per epoch."""
        for metric in ('loss', 'mse_emd'):
            self.log(f'{stage}/{metric}', loss, on_step=True, on_epoch=True, prog_bar=True)

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training loss; return inputs, predictions and ``'loss'``."""
        outputs = self.step(batch)
        self._log_loss('train', outputs['loss'])
        return outputs

    def training_epoch_end(self, outputs: List[Any]):
        """No-op. ``outputs`` is the list of dicts returned from ``training_step()``."""
        pass

    def validation_step(self, batch: Any, batch_idx: int):
        """Compute and log the validation loss; return inputs and predictions (no ``'loss'``)."""
        outputs = self.step(batch)
        self._log_loss('val', outputs.pop('loss'))
        return outputs

    def validation_epoch_end(self, outputs: List[Any]):
        """No-op. ``outputs`` is the list of dicts returned from ``validation_step()``."""
        pass

    def test_step(self, batch: Any, batch_idx: int):
        """Compute and log the test loss; return inputs and predictions (no ``'loss'``)."""
        outputs = self.step(batch)
        self._log_loss('test', outputs.pop('loss'))
        return outputs

    def test_epoch_end(self, outputs: List[Any]):
        """No-op. ``outputs`` is the list of dicts returned from ``test_step()``.

        Kept as a no-op, like the training and validation counterparts, so that a
        test run is not aborted after all test steps have already been executed.
        """
        pass

    def on_epoch_end(self):
        """No-op hook: there are no metric objects to reset at the end of an epoch."""
        pass

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Returns a single Adam optimizer over all parameters of this module, using
        ``lr`` and ``weight_decay`` from ``self.hparams``.

        See examples here:
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        return torch.optim.Adam(
            params=self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
        )
