from typing import Any, List

import torch
from pytorch_lightning import LightningModule
import time

import os

from src import utils

import sys
# Prepend "<directory of this file>/../../src" to the module search path.
# NOTE(review): presumably this is what lets the `models` package imported
# below resolve as a top-level package -- confirm against the repository
# layout. The insert must stay ahead of that import.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'src'))

from models.networks import DistanceNetwork

# Module-level logger obtained from the project's `utils.get_logger` helper.
logger = utils.get_logger(__name__)

class PCMetricModule(LightningModule):
    """Pytorch Lightning module for learning to approximate wasserstein distance between point clouds.

    A ``DistanceNetwork`` is built from the constructor hyper-parameters and
    fitted with a mean-squared-error loss against the reference distance
    carried by every batch. Batches are mappings with the keys ``'source'``,
    ``'target'`` and ``'dist'`` (``ext_validation_step`` additionally reads
    ``'grads'``).

    A LightningModule organizes your PyTorch code into 5 sections:
        - Computations (init).
        - Train loop (training_step)
        - Validation loop (validation_step)
        - Test loop (test_step)
        - Optimizers (configure_optimizers)

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """

    def __init__(
        self,
        input_dim=2,
        lr=0.001,
        weight_decay=0.0005,
        latent_dim=16,
        mlp_dims=(32, 64, 32, 16),
        mlp_activation=torch.nn.ReLU,
        *args,
        **kwargs
    ):
        """Store the hyper-parameters and build the distance network.

        Args:
            input_dim: Forwarded to ``DistanceNetwork``.
            lr: Learning rate used by ``configure_optimizers``.
            weight_decay: Weight decay used by ``configure_optimizers``.
            latent_dim: Forwarded to ``DistanceNetwork``.
            mlp_dims: Forwarded to ``DistanceNetwork`` as a fresh ``list``.
            mlp_activation: Forwarded to ``DistanceNetwork``.
            *args: Accepted for call compatibility; not used here.
            **kwargs: Extra hyper-parameters, stored in ``self.hparams`` and
                forwarded to ``DistanceNetwork`` with the named ones.
        """
        super().__init__()

        # The default is an immutable tuple so that instances never share one
        # mutable default object; the network still receives a list, exactly
        # as it did with the former list default.
        if mlp_dims is not None:
            mlp_dims = list(mlp_dims)

        # this line allows to access init params with 'self.hparams' attribute
        # it also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)
        # NOTE: every stored hyper-parameter (including lr / weight_decay and
        # any extra kwargs) is passed on to the network constructor.
        self.model = DistanceNetwork(**self.hparams)

    def forward(self, source, target):
        """Return the network output for ``(source, target)`` with dim 1 squeezed."""
        return self.model(source, target).squeeze(1)

    def on_train_start(self):
        """Hook run when training starts; intentionally a no-op."""
        pass

    def step(self, batch: Any):
        """Placeholder kept for interface compatibility; does nothing.

        None of the hooks in this class call it.
        """
        pass

    # ------------------------------------------------------------------ helpers

    def _shared_step(self, batch: Any):
        """Run the forward pass and the MSE loss shared by train/val/test."""
        source, target, dist = batch['source'], batch['target'], batch['dist']
        preds = self(source, target)
        loss = torch.nn.functional.mse_loss(preds, dist)
        return {"source": source, "target": target, "dist": dist, "preds": preds, "loss": loss}

    def _log_loss(self, stage: str, loss):
        """Log ``loss`` as ``<stage>/loss`` and ``<stage>/mse_emd``."""
        for metric in ('loss', 'mse_emd'):
            self.log(f'{stage}/{metric}', loss, on_step=True, on_epoch=True, prog_bar=True)

    @staticmethod
    def _sync_if_cuda(tensor):
        """Block until queued CUDA work on ``tensor``'s device has finished.

        CUDA kernels are launched asynchronously, so a wall-clock reading is
        only meaningful after a synchronisation. No-op for non-CUDA tensors.
        """
        if tensor.is_cuda:
            torch.cuda.synchronize(tensor.device)

    # -------------------------------------------------------------------- hooks

    def training_step(self, batch: Any, batch_idx: int):
        """Compute and log the training loss for one batch.

        Returns:
            Dict with the keys ``source``, ``target``, ``dist``, ``preds``
            and ``loss``.
        """
        out = self._shared_step(batch)
        self._log_loss('train', out["loss"])
        return out

    def training_epoch_end(self, outputs: List[Any]):
        """Epoch-end hook; intentionally a no-op.

        ``outputs`` is a list of dicts returned from ``training_step()``.
        NOTE(review): merely overriding this hook may make Lightning 1.x keep
        every ``training_step`` output (inputs included) until the epoch ends
        -- confirm, and consider dropping the override if memory matters.
        """
        pass

    def validation_step(self, batch: Any, batch_idx: int):
        """Compute and log the validation loss for one batch.

        Returns:
            Dict with the keys ``source``, ``target``, ``dist`` and ``preds``.
        """
        out = self._shared_step(batch)
        self._log_loss('val', out.pop("loss"))
        return out

    def ext_validation_step(self, batch: Any, num_points: int, batch_idx: int):
        """Extended validation: loss, forward-pass timing and gradient agreement.

        On top of the usual validation metrics this measures the wall-clock
        time of one forward + loss evaluation and compares the gradient of the
        mean prediction w.r.t. the target against ``batch['grads']``.

        Args:
            batch: Mapping with ``'source'``, ``'target'``, ``'dist'`` and
                ``'grads'``.
            num_points: Not used by this method.
            batch_idx: Not used by this method.

        Returns:
            Dict with ``dist``, ``preds`` (detached predictions of the
            gradient pass), ``cos_sim`` (cosine similarity along dim 2 between
            reference and predicted gradients) and ``time`` (elapsed seconds
            as a one-element tensor).
        """
        source, target, dist, true_grad = batch['source'], batch['target'], batch['dist'], batch['grads']

        # Timed section: synchronise first so pending GPU work is not billed
        # to the forward pass, and again before reading the clock.
        self._sync_if_cuda(target)
        start_time = time.time_ns()
        timed_preds = self(source, target)
        loss = torch.nn.functional.mse_loss(timed_preds, dist)
        self._sync_if_cuda(target)
        end_time = time.time_ns()

        # log val metrics
        self._log_loss('val', loss)

        with torch.set_grad_enabled(True):
            # Differentiate w.r.t. a detached copy so the batch tensor itself
            # is left untouched.
            temp_target = target.detach().clone()
            temp_target.requires_grad = True

            preds = self(source, temp_target)
            # torch.autograd.grad returns only the requested input gradient;
            # unlike ``backward()`` it does not accumulate into the ``.grad``
            # of the model parameters during evaluation.
            (pred_grad,) = torch.autograd.grad(preds.mean(), temp_target)

        cos_sim = torch.nn.functional.cosine_similarity(true_grad, pred_grad, dim=2, eps=1e-08)

        return {
            "dist": dist,
            "preds": preds.detach(),
            "cos_sim": cos_sim,
            "time": torch.tensor([(end_time - start_time) / 1e9]),
        }

    def test_step(self, batch: Any, batch_idx: int):
        """Compute and log the test loss for one batch.

        Returns:
            Dict with the keys ``source``, ``target``, ``dist`` and ``preds``.
        """
        out = self._shared_step(batch)
        self._log_loss('test', out.pop("loss"))
        return out

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Returns an Adam optimizer over all module parameters, configured with
        ``self.hparams.lr`` and ``self.hparams.weight_decay``.

        See examples here:
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        return torch.optim.Adam(
            params=self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
        )
