from typing import Any, List

import torch
from pytorch_lightning import LightningModule
import time

import os

import ot

from torchmetrics import Accuracy

class PCMetricModule(LightningModule):
    """Pytorch Lightning module for learning to approximate wasserstein distance between point clouds.

    A LightningModule organizes your PyTorch code into 5 sections:
        - Computations (init).
        - Train loop (training_step)
        - Validation loop (validation_step)
        - Test loop (test_step)
        - Optimizers (configure_optimizers)

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """

    def __init__(self, niters, *args, **kwargs):
        """Set up the Sinkhorn solver handle and the matching-accuracy metrics.

        Args:
            niters: Iteration budget, later passed to the solver as ``numItermax``.
            *args: Accepted for signature compatibility; not used directly here.
            **kwargs: Not used directly here. ``configure_optimizers`` reads
                ``lr`` and ``weight_decay`` from ``self.hparams``, so they are
                presumably expected to arrive through these keyword arguments.
        """
        super().__init__()

        # Makes the constructor arguments reachable through ``self.hparams`` and
        # stores them in checkpoints, without sending them to the logger.
        self.save_hyperparameters(logger=False)

        self.niters = niters
        # The "model" is POT's Sinkhorn solver, a plain function with no parameters.
        self.model = ot.sinkhorn

        # One micro-averaged multiclass accuracy per stage and matching direction
        # (target->source, source->target, both directions pooled). Each point
        # index is treated as a class, so ``num_classes`` (fixed at 200) has to
        # cover the number of points per cloud.
        for stage in ("train", "val", "test"):
            for prefix in ("accuracy_t2s", "accuracy_s2t", "accuracy"):
                metric = Accuracy(task="multiclass", num_classes=200, average='micro')
                setattr(self, f"{prefix}_{stage}", metric)

    def forward(self, source, target, reg=0.1):
        """Estimate the transport cost between two point clouds with Sinkhorn.

        ``self.model`` is ``ot.sinkhorn``, which takes the two marginals, a cost
        matrix and a regularisation strength. It cannot be called directly on
        the raw point clouds, so the cost matrix and uniform marginals are
        built here first.

        Args:
            source: Source point cloud; assumed to be a 2-D ``(n, d)`` tensor
                (squeeze any batch dimension before calling).
            target: Target point cloud; assumed to be a 2-D ``(m, d)`` tensor.
            reg: Entropic regularisation strength passed to the solver.

        Returns:
            A 0-d tensor: the cost of the Sinkhorn transport plan, multiplied
            by the number of source points.
        """
        num_src, num_tgt = source.shape[0], target.shape[0]

        # metric='euclidean' yields plain (non-squared) Euclidean distances.
        cost = ot.dist(source, target, metric='euclidean')

        # Uniform mass on every point of each cloud.
        a = torch.ones(num_src, device=source.device) / num_src
        b = torch.ones(num_tgt, device=target.device) / num_tgt

        # stopThr=-inf disables early stopping, so exactly ``self.niters``
        # iterations are run.
        plan = self.model(a, b, cost, reg=reg, method='sinkhorn',
                          numItermax=self.niters, stopThr=-torch.inf)

        return torch.sum(cost * plan) * num_src

    def on_train_start(self):
        """Hook for the start of training; deliberately does nothing.

        Lightning runs validation sanity checks before training begins, which
        can pollute "best so far" trackers. This module defines no such
        tracker, so there is nothing to reset.
        """
        return None

    def step(self, batch: Any):
        """Placeholder for a shared per-batch computation.

        Args:
            batch: A batch from the dataloader; ignored.

        Returns:
            Always ``None``.
        """
        return None

    def training_step(self, batch: Any, batch_idx: int):
        """Training hook; a no-op that returns ``None``.

        Args:
            batch: A batch from the training dataloader; ignored.
            batch_idx: Index of the batch within the epoch; ignored.
        """
        return None

    def training_epoch_end(self, outputs: List[Any]):
        """End-of-training-epoch hook; a no-op that returns ``None``.

        Args:
            outputs: The values collected from ``training_step()`` over the
                epoch; ignored.
        """
        return None

    def validation_step(self, batch: Any, batch_idx: int):
        """Validation hook; a no-op that returns ``None``.

        The evaluation logic of this module lives in ``ext_validation_step``.

        Args:
            batch: A batch from the validation dataloader; ignored.
            batch_idx: Index of the batch; ignored.
        """
        return None

    def ext_validation_step(self, batch: Any, num_points: int, batch_idx: int):
        """Evaluate the Sinkhorn baseline on one pair of point clouds.

        Solves entropic OT between ``batch['source']`` and ``batch['target']``,
        then logs matching accuracy, mutual-match ("bipartiteness") ratios and
        returns distance, gradient-agreement and timing information.

        Assumptions (not verifiable from this file -- confirm against the
        datamodule): the batch holds a single pair of clouds with a leading
        batch dimension of one, both clouds contain ``num_points`` points,
        ``matching`` / ``matching_t2s`` are integer index tensors of shape
        ``(1, num_points)`` and ``grads`` has shape ``(1, num_points, d)``.

        Args:
            batch: Dict with keys ``source``, ``target``, ``dist``,
                ``matching``, ``matching_t2s`` and ``grads``.
            num_points: Number of points per cloud; used as the one-hot width.
            batch_idx: Index of the batch; unused.

        Returns:
            Dict with ``dist`` (passed through), ``preds`` (Sinkhorn cost,
            shape ``(1,)``, detached), ``bipartitness`` and
            ``bipartiteness_correctness`` (per-sample counts), ``cos_sim``
            (flattened per-point cosine similarity between ``grads`` and the
            gradient of the Sinkhorn cost w.r.t. the target) and ``time``
            (solver wall time in seconds).
        """
        one_hot = torch.nn.functional.one_hot
        reg = 0.1  # entropic regularisation strength handed to the solver

        # squeeze() drops singleton dimensions (e.g. a batch dimension of one)
        # so the clouds can be passed to ot.dist.
        source, target = batch['source'].squeeze(), batch['target'].squeeze()
        dist, true_grad = batch['dist'], batch['grads']
        matching, matching_t2s = batch['matching'], batch['matching_t2s']
        num_src, num_tgt = source.shape[0], target.shape[0]

        # ---- timed section: cost matrix, Sinkhorn solve, hard matchings ----
        # CUDA kernels run asynchronously; synchronise so the wall-clock
        # interval covers the work that was actually launched inside it.
        if source.is_cuda:
            torch.cuda.synchronize(source.device)
        start_time = time.time_ns()

        # metric='euclidean' yields plain (non-squared) Euclidean distances.
        cost = ot.dist(source, target, metric='euclidean')
        a = torch.ones(num_src, device=source.device) / num_src
        b = torch.ones(num_tgt, device=target.device) / num_tgt
        # stopThr=-inf disables early stopping: exactly ``self.niters``
        # iterations run. With log=True the solver returns (plan, log_dict).
        plan = self.model(a, b, cost, log=True, reg=reg, method='sinkhorn',
                          numItermax=self.niters, stopThr=-torch.inf)[0]
        preds = torch.sum(cost * plan) * num_src
        # Hard assignments: best target per source (rows) and best source per
        # target (columns).
        idx_s2t = torch.argmax(plan, dim=1)
        idx_t2s = torch.argmax(plan, dim=0)

        if source.is_cuda:
            torch.cuda.synchronize(source.device)
        end_time = time.time_ns()

        # Detached so the returned value never keeps an autograd graph alive.
        preds = preds.unsqueeze(0).detach()

        # ---- matching accuracy ----
        # Integer one-hot encodings of shape (1, n_points, num_points).
        onehot_s2t = one_hot(idx_s2t.unsqueeze(0), num_classes=num_points)
        onehot_t2s = one_hot(idx_t2s.unsqueeze(0), num_classes=num_points)
        # torchmetrics expects class scores laid out as (batch, classes, points).
        scores_s2t = onehot_s2t.float().permute(0, 2, 1)
        scores_t2s = onehot_t2s.float().permute(0, 2, 1)

        self.accuracy_t2s_val(scores_t2s, matching_t2s)
        self.log('val/acc_t2s', self.accuracy_t2s_val, on_step=True, on_epoch=True, prog_bar=True)

        self.accuracy_s2t_val(scores_s2t, matching)
        self.log('val/acc_s2t', self.accuracy_s2t_val, on_step=True, on_epoch=True, prog_bar=True)

        self.accuracy_val(torch.cat([scores_s2t, scores_t2s], dim=0),
                          torch.cat([matching, matching_t2s], dim=0))
        self.log('val/acc_both', self.accuracy_val, on_step=True, on_epoch=True, prog_bar=True)

        # ---- bipartiteness: source i picks target j AND target j picks source i ----
        bipartiteness = torch.logical_and(onehot_s2t == onehot_t2s.permute(0, 2, 1),
                                          onehot_s2t == 1)
        bipartiteness_ratio = (bipartiteness.sum((1, 2)) / matching.shape[1]).mean()
        self.log('val/bipartiteness', bipartiteness_ratio, on_step=True, on_epoch=True, prog_bar=True)

        # Bipartiteness + correctness: mutual matches that also agree with the
        # ground truth. num_classes is given explicitly; otherwise one_hot
        # infers the width from the largest index present, which can be
        # narrower than num_points and then fails to broadcast.
        bipartiteness_correctness = torch.logical_and(
            bipartiteness, onehot_s2t == one_hot(matching, num_classes=num_points))
        bipartiteness_correctness_ratio = (bipartiteness_correctness.sum((1, 2)) / matching.shape[1]).mean()
        self.log('val/bipartite_correct', bipartiteness_correctness_ratio, on_step=True, on_epoch=True, prog_bar=True)

        # ---- gradient of the Sinkhorn cost with respect to the target ----
        # Validation normally runs with autograd off, so switch it on locally.
        with torch.set_grad_enabled(True):
            temp_target = target.detach().clone()
            temp_target.requires_grad = True

            grad_cost = ot.dist(source, temp_target, metric='euclidean')
            # The plan is held constant (detached), so gradients flow through
            # the cost matrix only. temp_target holds the same values as
            # target, so the plan solved above is reused rather than solved
            # a second time.
            loss = (torch.sum(grad_cost * plan.detach()) * num_src).unsqueeze(0).mean()
            loss.backward()
            pred_grad = temp_target.grad

            cos_sim = torch.nn.functional.cosine_similarity(true_grad, pred_grad, dim=2, eps=1e-08)

        # NOTE: the key "bipartitness" is misspelled but is kept as-is because
        # consumers of this dict may rely on it.
        return {"dist": dist, "preds": preds,
                "bipartitness": bipartiteness.sum((1, 2)),
                "bipartiteness_correctness": bipartiteness_correctness.sum((1, 2)),
                "cos_sim": cos_sim.reshape(-1),
                "time": torch.tensor([(end_time - start_time) / 1e9])}

    def test_step(self, batch: Any, batch_idx: int):
        """Evaluate the Sinkhorn baseline on one test pair and log test metrics.

        ``self.model`` is ``ot.sinkhorn``, which takes marginals, a cost matrix
        and a regularisation strength. The cost matrix, uniform marginals and
        hard matchings are therefore built here, in the same way as in
        ``ext_validation_step``.

        Assumptions (confirm against the datamodule): the batch holds a single
        pair of clouds with a leading batch dimension of one, and
        ``matching`` / ``matching_t2s`` are integer index tensors of shape
        ``(1, n_points)``.

        Args:
            batch: Dict with keys ``source``, ``target``, ``dist``,
                ``matching`` and ``matching_t2s``.
            batch_idx: Index of the batch; unused.

        Returns:
            Dict with the original ``source``, ``target`` and ``dist`` entries
            plus ``preds``, the Sinkhorn cost of shape ``(1,)``.
        """
        one_hot = torch.nn.functional.one_hot
        reg = 0.1  # same regularisation strength as ext_validation_step

        source, target, dist, matching, matching_t2s = batch['source'], batch['target'], batch['dist'], batch['matching'], batch['matching_t2s']

        # squeeze() drops singleton dimensions (e.g. a batch dimension of one).
        source_pts, target_pts = source.squeeze(), target.squeeze()
        num_src, num_tgt = source_pts.shape[0], target_pts.shape[0]

        # metric='euclidean' yields plain (non-squared) Euclidean distances.
        cost = ot.dist(source_pts, target_pts, metric='euclidean')
        a = torch.ones(num_src, device=source_pts.device) / num_src
        b = torch.ones(num_tgt, device=target_pts.device) / num_tgt
        # stopThr=-inf disables early stopping; with log=True the solver
        # returns (plan, log_dict).
        plan = self.model(a, b, cost, log=True, reg=reg, method='sinkhorn',
                          numItermax=self.niters, stopThr=-torch.inf)[0]

        preds = (torch.sum(cost * plan) * num_src).unsqueeze(0)
        matching_s2t_pred = torch.argmax(plan, dim=1).unsqueeze(0)
        matching_t2s_pred = torch.argmax(plan, dim=0).unsqueeze(0)
        loss = torch.nn.functional.mse_loss(preds, dist)

        # Accumulate into the dedicated *_test metrics so test results are not
        # mixed into the training metric state.
        self.accuracy_t2s_test(matching_t2s_pred, matching_t2s)
        self.log('test/acc_t2s', self.accuracy_t2s_test, on_step=True, on_epoch=True, prog_bar=True)

        self.accuracy_s2t_test(matching_s2t_pred, matching)
        self.log('test/acc_s2t', self.accuracy_s2t_test, on_step=True, on_epoch=True, prog_bar=True)

        self.accuracy_test(torch.cat([matching_s2t_pred, matching_t2s_pred], dim=0), torch.cat([matching, matching_t2s], dim=0))
        self.log('test/acc_both', self.accuracy_test, on_step=True, on_epoch=True, prog_bar=True)

        # One-hot widths are given explicitly; inferring them from the largest
        # index present can yield mismatched shapes.
        matching_expanded = one_hot(matching_s2t_pred, num_classes=num_tgt)      # (1, n_src, n_tgt)
        matching_t2s_expanded = one_hot(matching_t2s_pred, num_classes=num_src)  # (1, n_tgt, n_src)

        # Mutual matches: source i picks target j AND target j picks source i.
        bipartiteness = torch.logical_and((matching_expanded == matching_t2s_expanded.permute(0, 2, 1)), (matching_expanded == 1))
        bipartiteness_ratio = (bipartiteness.sum((1, 2)) / matching.shape[1]).mean()
        self.log('test/bipartiteness', bipartiteness_ratio, on_step=True, on_epoch=True, prog_bar=True)

        # Bipartiteness + correctness: mutual matches that agree with the ground truth.
        bipartiteness_correctness = torch.logical_and(bipartiteness, (matching_expanded == one_hot(matching, num_classes=num_tgt)))
        bipartiteness_correctness_ratio = (bipartiteness_correctness.sum((1, 2)) / matching.shape[1]).mean()
        self.log('test/bipartite_correct', bipartiteness_correctness_ratio, on_step=True, on_epoch=True, prog_bar=True)

        # Log test losses (both keys carry the same MSE on the distance).
        self.log('test/loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('test/mse_emd', loss, on_step=True, on_epoch=True, prog_bar=True)

        return {"source": source, "target": target, "dist": dist, "preds": preds}

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        This module wraps a parameter-free solver, so it usually has nothing to
        optimise. In that case ``None`` is returned (Lightning then runs
        without an optimizer), because ``torch.optim.Adam`` raises
        ``ValueError`` on an empty parameter list.

        See examples here:
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers

        Returns:
            An Adam optimizer over ``self.parameters()`` configured from
            ``self.hparams.lr`` and ``self.hparams.weight_decay``, or ``None``
            when the module has no parameters.
        """
        params = list(self.parameters())
        if not params:
            return None
        return torch.optim.Adam(
            params=params, lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
        )
