from typing import Any, List

import torch
from pytorch_lightning import LightningModule
import time

import os

from src import utils

import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'src'))

from utils.distance import hungarian_batched
from torchmetrics import Accuracy

# Module-level logger obtained from the project's `utils.get_logger` helper.
logger = utils.get_logger(__name__)

class PCMetricModule(LightningModule):
    """Lightning module that evaluates exact Hungarian matching between point clouds.

    ``self.model`` is ``hungarian_batched`` (an algorithm, not a learned network), so
    the training hooks are intentional no-ops. ``ext_validation_step`` and
    ``test_step`` compare the predicted source->target and target->source matchings
    against the ground truth in the batch and log accuracy, bipartiteness (mutual
    consistency of the two matchings) and bipartite-correctness ratios.

    Keyword hyperparameters read by this class:
        num_classes: ``num_classes`` of every ``Accuracy`` metric (default 200). Matching
            indices are used as class labels, so this must exceed every point index.
        lr, weight_decay: read only by ``configure_optimizers``.

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """

    # Default ``num_classes`` for the matching-accuracy metrics (previously hard-coded).
    DEFAULT_NUM_CLASSES = 200

    def __init__(self, *args, **kwargs):
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # it also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)
        self.model = hungarian_batched

        # Matching indices are treated as class labels, so this must be larger than
        # any point index that can appear in a matching.
        num_classes = kwargs.get("num_classes", self.DEFAULT_NUM_CLASSES)

        def make_accuracy():
            return Accuracy(task="multiclass", num_classes=num_classes, average="micro")

        self.accuracy_t2s_train = make_accuracy()
        self.accuracy_s2t_train = make_accuracy()
        self.accuracy_train = make_accuracy()

        self.accuracy_t2s_val = make_accuracy()
        self.accuracy_s2t_val = make_accuracy()
        self.accuracy_val = make_accuracy()

        self.accuracy_t2s_test = make_accuracy()
        self.accuracy_s2t_test = make_accuracy()
        self.accuracy_test = make_accuracy()

    def forward(self, source, target):
        """Return the distances computed by ``hungarian_batched`` (without matchings)."""
        preds = self.model(source, target, return_matching=False)
        return preds

    def on_train_start(self):
        # Nothing to reset: there is no learned model and no best-metric tracker.
        pass

    def step(self, batch: Any):
        """Intentionally a no-op; kept for interface compatibility."""
        pass

    def training_step(self, batch: Any, batch_idx: int):
        """No-op: ``hungarian_batched`` has nothing to train."""
        pass

    def training_epoch_end(self, outputs: List[Any]):
        # `outputs` is a list of dicts returned from `training_step()`
        pass

    def validation_step(self, batch: Any, batch_idx: int):
        """No-op; validation is driven externally through ``ext_validation_step``."""
        pass

    @staticmethod
    def _matchings_from_assignments(assignments: Any, device: torch.device):
        """Convert per-sample assignments into (s2t, t2s) index tensors on ``device``.

        NOTE(review): assumes each assignment is a ``(row_ind, col_ind)`` pair of a square
        problem (as returned by scipy's ``linear_sum_assignment``), so ``col_ind`` is the
        source->target permutation. Confirm against ``hungarian_batched``.
        """
        # Equivalent to stacking ``match[1]`` of every sample: shape (B, N).
        s2t = torch.tensor(assignments)[:, 1].to(device)
        # The sort indices of a permutation are its inverse, i.e. the target->source map.
        t2s = torch.sort(s2t, dim=1)[1]
        return s2t, t2s

    @staticmethod
    def _bipartite_masks(s2t_pred, t2s_pred, matching, num_points: int):
        """Return ``(bipartiteness, bipartiteness_correctness)`` boolean masks of shape (B, N, N).

        ``bipartiteness[b, i, j]`` is True when source i -> target j and target j -> source i
        are both predicted. The correctness mask also requires (i, j) to be in ``matching``.
        """
        s2t_onehot = torch.nn.functional.one_hot(s2t_pred, num_classes=num_points)
        t2s_onehot = torch.nn.functional.one_hot(t2s_pred, num_classes=num_points)

        bipartiteness = torch.logical_and(s2t_onehot == t2s_onehot.permute(0, 2, 1), s2t_onehot == 1)
        # Explicit width: inferring it from ``matching.max()`` could mismatch the predictions.
        true_onehot = torch.nn.functional.one_hot(matching, num_classes=num_points)
        bipartiteness_correctness = torch.logical_and(bipartiteness, s2t_onehot == true_onehot)
        return bipartiteness, bipartiteness_correctness

    def _log_matching_metrics(self, stage: str, s2t_pred, t2s_pred, matching, matching_t2s, num_points: int):
        """Update and log the accuracy/bipartiteness metrics of ``stage`` ('val' or 'test').

        Predictions are passed to ``Accuracy`` as index tensors rather than one-hot scores.
        The results are identical, and it no longer forces ``num_points == num_classes``.

        Returns:
            The ``(bipartiteness, bipartiteness_correctness)`` masks from ``_bipartite_masks``.
        """
        # Metric attributes follow the ``accuracy[_t2s|_s2t]_<stage>`` naming from ``__init__``.
        acc_t2s = getattr(self, f"accuracy_t2s_{stage}")
        acc_s2t = getattr(self, f"accuracy_s2t_{stage}")
        acc_both = getattr(self, f"accuracy_{stage}")
        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True)

        acc_t2s(t2s_pred, matching_t2s)
        self.log(f"{stage}/acc_t2s", acc_t2s, **log_kwargs)

        acc_s2t(s2t_pred, matching)
        self.log(f"{stage}/acc_s2t", acc_s2t, **log_kwargs)

        acc_both(torch.cat([s2t_pred, t2s_pred], dim=0), torch.cat([matching, matching_t2s], dim=0))
        self.log(f"{stage}/acc_both", acc_both, **log_kwargs)

        bipartiteness, bipartiteness_correctness = self._bipartite_masks(s2t_pred, t2s_pred, matching, num_points)
        num_matched = matching.shape[1]

        bipartiteness_ratio = (bipartiteness.sum((1, 2)) / num_matched).mean()
        self.log(f"{stage}/bipartiteness", bipartiteness_ratio, **log_kwargs)

        bipartiteness_correctness_ratio = (bipartiteness_correctness.sum((1, 2)) / num_matched).mean()
        self.log(f"{stage}/bipartite_correct", bipartiteness_correctness_ratio, **log_kwargs)

        return bipartiteness, bipartiteness_correctness

    def ext_validation_step(self, batch: Any, num_points: int, batch_idx: int):
        """Run Hungarian matching on ``batch``, log 'val/...' metrics and return per-sample results.

        Args:
            batch: mapping with keys 'source', 'target', 'dist', 'matching',
                'matching_t2s' and 'grads' (the reference gradient w.r.t. target).
            num_points: width of the one-hot encodings used for the bipartiteness checks.
            batch_idx: unused; kept for the Lightning-style signature.

        Returns:
            dict with 'dist', 'preds' (detached), 'bipartitness' and
            'bipartiteness_correctness' (per-sample counts), 'cos_sim' (flattened cosine
            similarity between reference and computed gradients) and 'time' (seconds spent
            in the matching call).
        """
        source, target, dist = batch['source'], batch['target'], batch['dist']
        matching, matching_t2s, true_grad = batch['matching'], batch['matching_t2s'], batch['grads']

        # perf_counter_ns is monotonic, unlike time_ns, so it is safe for measuring durations.
        start_time = time.perf_counter_ns()
        _, assignments = self.model(source, target, return_matching=True)
        end_time = time.perf_counter_ns()

        s2t_pred, t2s_pred = self._matchings_from_assignments(assignments, self.device)
        bipartiteness, bipartiteness_correctness = self._log_matching_metrics(
            "val", s2t_pred, t2s_pred, matching, matching_t2s, num_points
        )

        # Validation normally runs under no_grad; gradients w.r.t. target are needed here.
        with torch.set_grad_enabled(True):
            temp_target = target.detach().clone().requires_grad_(True)
            preds = self(source, temp_target)
            preds.mean().backward()
            pred_grad = temp_target.grad

        # NOTE(review): dim=2 assumes gradients shaped (B, N, D) — confirm with the dataset.
        cos_sim = torch.nn.functional.cosine_similarity(true_grad, pred_grad, dim=2, eps=1e-08)

        return {
            "dist": dist,
            # Detached so the returned outputs do not keep autograd history alive.
            "preds": preds.detach(),
            # Key spelling kept as-is: downstream consumers may read "bipartitness".
            "bipartitness": bipartiteness.sum((1, 2)),
            "bipartiteness_correctness": bipartiteness_correctness.sum((1, 2)),
            "cos_sim": cos_sim.reshape(-1),
            "time": torch.tensor([(end_time - start_time) / 1e9]),
        }

    def test_step(self, batch: Any, batch_idx: int):
        """Run Hungarian matching on ``batch`` and log 'test/...' metrics plus the distance MSE.

        The matchings are derived from ``assignments`` exactly as in ``ext_validation_step``.
        The test-split Accuracy metrics are used; the train metrics were previously
        updated and logged here by mistake.
        """
        source, target, dist, matching, matching_t2s = (
            batch['source'], batch['target'], batch['dist'], batch['matching'], batch['matching_t2s']
        )

        preds, assignments = self.model(source, target, return_matching=True)
        s2t_pred, t2s_pred = self._matchings_from_assignments(assignments, self.device)
        loss = torch.nn.functional.mse_loss(preds, dist)

        self._log_matching_metrics("test", s2t_pred, t2s_pred, matching, matching_t2s, matching.shape[1])

        self.log('test/loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('test/mse_emd', loss, on_step=True, on_epoch=True, prog_bar=True)

        return {"source": source, "target": target, "dist": dist, "preds": preds}

    def configure_optimizers(self):
        """Return an Adam optimizer over the module's parameters.

        NOTE(review): ``self.model`` is a plain function and the Accuracy metrics hold no
        parameters, so ``self.parameters()`` looks empty. Adam would then reject it if
        this were ever called (e.g. by ``Trainer.fit``). Confirm before training.
        """
        return torch.optim.Adam(
            params=self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
        )
