from typing import Any, List

import torch
from pytorch_lightning import LightningModule
import time

import os
from torchmetrics import Accuracy

from src import utils

import sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), '..', '..', 'src'))
from utils.distance import hungarian, chamfer

from models.networks import MatchingTransformer

# Module-level logger created through the project's ``src.utils.get_logger`` helper,
# named after this module.
logger = utils.get_logger(__name__)

class PCMetricModule(LightningModule):
    """PyTorch Lightning module that learns to approximate the Wasserstein distance between point clouds.

    The wrapped ``MatchingTransformer`` returns a distance prediction plus two attention
    maps (source->target and target->source). The optimized loss is the cross-entropy
    between those attention maps and the ground-truth matchings; the MSE between the
    predicted and ground-truth distance is logged but not optimized.

    A LightningModule organizes your PyTorch code into 5 sections:
        - Computations (init).
        - Train loop (training_step)
        - Validation loop (validation_step)
        - Test loop (test_step)
        - Optimizers (configure_optimizers)

    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """

    # Size of the class axis used by the accuracy metrics and by the one-hot encoding
    # of predicted matchings. The attention maps are fed to the metrics with a point
    # axis as the class axis, so this presumably must equal the number of points per
    # cloud (TODO confirm against the datamodule).
    NUM_CLASSES = 200

    def __init__(
        self, lr=0.001, weight_decay=0.0005, input_dim=2, latent_dim=16, dim_model=64, dim_keys=64, dim_hidden=64, nb_heads=4, nb_blocks=4, dropout=0., *args, **kwargs):
        """Build the matching transformer and one accuracy metric per direction and stage.

        All arguments are stored with ``save_hyperparameters`` and the whole
        ``self.hparams`` mapping (including ``lr`` and ``weight_decay``, which
        ``configure_optimizers`` reads) is forwarded to ``MatchingTransformer``.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # it also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)
        # NOTE(review): every hparam is forwarded, so MatchingTransformer presumably
        # accepts **kwargs for the ones it does not use — confirm.
        self.model = MatchingTransformer(**self.hparams)

        def _accuracy():
            return Accuracy(task="multiclass", num_classes=self.NUM_CLASSES, average='micro')

        # Separate metric objects per stage so their running state never mixes.
        self.accuracy_t2s_train = _accuracy()
        self.accuracy_s2t_train = _accuracy()
        self.accuracy_train = _accuracy()

        self.accuracy_t2s_val = _accuracy()
        self.accuracy_s2t_val = _accuracy()
        self.accuracy_val = _accuracy()

        self.accuracy_t2s_test = _accuracy()
        self.accuracy_s2t_test = _accuracy()
        self.accuracy_test = _accuracy()

    def forward(self, source, target):
        """Return only the distance prediction of the wrapped model for ``(source, target)``."""
        preds, _, _ = self.model(source, target)
        return preds

    def on_train_start(self):
        # by default lightning executes validation step sanity checks before training starts,
        # so we need to make sure val_acc_best doesn't store accuracy from these checks
        # (no best-accuracy tracker is currently registered, so nothing to reset)
        pass

    def step(self, batch: Any):
        """Unused placeholder kept for interface compatibility.

        The logic shared by the train/val/test steps lives in ``_shared_step``.
        """
        pass

    # ------------------------------------------------------------------ helpers

    def _stage_metrics(self, stage):
        """Return the ``(t2s, s2t, both)`` accuracy metric objects registered for ``stage``."""
        return (
            getattr(self, f'accuracy_t2s_{stage}'),
            getattr(self, f'accuracy_s2t_{stage}'),
            getattr(self, f'accuracy_{stage}'),
        )

    @staticmethod
    def _sync_cuda(tensor):
        """Wait for queued CUDA work on ``tensor``'s device so timings are meaningful (no-op on CPU)."""
        if tensor.is_cuda:
            torch.cuda.synchronize(tensor.device)

    @staticmethod
    def _matching_loss(attn_s_t, attn_t_s, matching, matching_t2s, temperature=1):
        """Average of the two directional cross-entropies between attention maps and matchings.

        ``permute(0, 2, 1)`` moves the last attention axis to dim 1, the class axis that
        ``cross_entropy`` expects. The logits are divided by ``temperature``
        (1 = no scaling; dividing by 1 is exact, so results match an unscaled call).
        """
        ce = torch.nn.functional.cross_entropy
        loss_t2s = ce(attn_t_s.permute(0, 2, 1) / temperature, matching_t2s)
        loss_s2t = ce(attn_s_t.permute(0, 2, 1) / temperature, matching)
        return (loss_t2s + loss_s2t) / 2

    def _bipartite_matches(self, attn_s_t, attn_t_s, matching, num_classes=None):
        """Compare the argmax matchings predicted in both directions.

        Args:
            attn_s_t: source->target attention; argmax over dim 2 picks a target per source.
            attn_t_s: target->source attention; argmax over dim 2 picks a source per target.
            matching: ground-truth target index for every source point.
            num_classes: one-hot width; defaults to ``NUM_CLASSES``.

        Returns:
            ``(mutual, correct)`` boolean tensors of shape ``attn_s_t.shape[:2] + (num_classes,)``.
            ``mutual[b, i, j]`` is True when source i picks target j and target j picks
            source i. ``correct`` additionally requires j to be the ground-truth match of i.
        """
        n_cls = self.NUM_CLASSES if num_classes is None else num_classes
        one_hot = torch.nn.functional.one_hot

        pred_s2t = one_hot(attn_s_t.argmax(dim=2), num_classes=n_cls)
        pred_t2s = one_hot(attn_t_s.argmax(dim=2), num_classes=n_cls)
        mutual = torch.logical_and(pred_s2t == pred_t2s.permute(0, 2, 1), pred_s2t == 1)

        # num_classes must be explicit here: without it one_hot sizes the last axis as
        # matching.max() + 1, which only lines up with pred_s2t by coincidence.
        true_s2t = one_hot(matching, num_classes=n_cls)
        correct = torch.logical_and(mutual, pred_s2t == true_s2t)
        return mutual, correct

    def _log_step(self, stage, loss, mse_loss, attn_s_t, attn_t_s, matching, matching_t2s, num_classes=None):
        """Update and log loss, MSE, accuracies and bipartiteness ratios under ``<stage>/...``.

        Uses the accuracy metrics registered for ``stage``. Returns the ``(mutual, correct)``
        tensors from ``_bipartite_matches`` so callers can reuse them.
        """
        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True)

        self.log(f'{stage}/loss', loss, **log_kwargs)
        self.log(f'{stage}/mse_emd', mse_loss, **log_kwargs)

        acc_t2s, acc_s2t, acc_both = self._stage_metrics(stage)
        logits_s_t = attn_s_t.permute(0, 2, 1)
        logits_t_s = attn_t_s.permute(0, 2, 1)

        acc_t2s(logits_t_s, matching_t2s)
        self.log(f'{stage}/acc_t2s', acc_t2s, **log_kwargs)

        acc_s2t(logits_s_t, matching)
        self.log(f'{stage}/acc_s2t', acc_s2t, **log_kwargs)

        acc_both(torch.cat([logits_s_t, logits_t_s], dim=0), torch.cat([matching, matching_t2s], dim=0))
        self.log(f'{stage}/acc_both', acc_both, **log_kwargs)

        mutual, correct = self._bipartite_matches(attn_s_t, attn_t_s, matching, num_classes)
        # Normalize per sample by the number of matched (source) points, then average over the batch.
        num_points = matching.shape[1]
        self.log(f'{stage}/bipartiteness', (mutual.sum((1, 2)) / num_points).mean(), **log_kwargs)
        self.log(f'{stage}/bipartite_correct', (correct.sum((1, 2)) / num_points).mean(), **log_kwargs)
        return mutual, correct

    def _shared_step(self, batch, stage, temperature=1):
        """Forward ``batch``, compute the matching loss and MSE, and log all ``stage`` metrics.

        Returns ``(source, target, dist, preds, loss)``.
        """
        source, target, dist = batch['source'], batch['target'], batch['dist']
        matching, matching_t2s = batch['matching'], batch['matching_t2s']

        preds, attn_s_t, attn_t_s = self.model(source, target)
        loss = self._matching_loss(attn_s_t, attn_t_s, matching, matching_t2s, temperature)
        mse_loss = torch.nn.functional.mse_loss(preds, dist)

        self._log_step(stage, loss, mse_loss, attn_s_t, attn_t_s, matching, matching_t2s)
        return source, target, dist, preds, loss

    # ------------------------------------------------------------------ steps

    def training_step(self, batch: Any, batch_idx: int):
        """Optimize the symmetric matching cross-entropy; the MSE on the distance is only logged."""
        # temperature=1 keeps the softmax unscaled
        source, target, dist, preds, loss = self._shared_step(batch, 'train', temperature=1)
        return {"source": source, "target": target, "dist": dist, "preds": preds, "loss": loss}

    def validation_step(self, batch: Any, batch_idx: int):
        """Log validation metrics under ``val/...``."""
        source, target, dist, preds, _ = self._shared_step(batch, 'val')
        return {"source": source, "target": target, "dist": dist, "preds": preds}

    def ext_validation_step(self, batch: Any, num_points: int, batch_idx: int):
        """Extended validation that also times inference and checks input gradients.

        Logs the same ``val/...`` metrics as ``validation_step`` (one-hot width =
        ``num_points``). It also compares the gradient of the mean predicted distance
        w.r.t. the target points against ``batch['grads']``.

        Returns a dict with ``dist``, ``preds``, per-sample mutual/correct match counts,
        per-point gradient cosine similarity, and the forward time in seconds.
        """
        source, target, dist = batch['source'], batch['target'], batch['dist']
        matching, matching_t2s, true_grad = batch['matching'], batch['matching_t2s'], batch['grads']

        # Synchronize around the forward pass: CUDA kernels run asynchronously, so an
        # unsynchronized timer would only capture launch overhead.
        self._sync_cuda(source)
        start_time = time.perf_counter_ns()
        preds, attn_s_t, attn_t_s = self.model(source, target)
        self._sync_cuda(source)
        end_time = time.perf_counter_ns()

        loss = self._matching_loss(attn_s_t, attn_t_s, matching, matching_t2s)
        mse_loss = torch.nn.functional.mse_loss(preds, dist)
        mutual, correct = self._log_step(
            'val', loss, mse_loss, attn_s_t, attn_t_s, matching, matching_t2s, num_classes=num_points
        )

        # Gradient of the model's mean prediction w.r.t. the target points. batch['grads']
        # presumably holds the reference gradient from the exact solver (cf. the
        # `hungarian` import) — confirm. torch.autograd.grad differentiates only w.r.t.
        # temp_target, so, unlike loss.backward(), it does not accumulate .grad into the
        # model's parameters during evaluation.
        with torch.set_grad_enabled(True):
            temp_target = target.detach().clone().requires_grad_(True)
            preds = self(source, temp_target)
            (pred_grad,) = torch.autograd.grad(preds.mean(), temp_target)
            cos_sim = torch.nn.functional.cosine_similarity(true_grad, pred_grad, dim=2, eps=1e-08)

        return {"dist": dist, "preds": preds,
                "bipartitness": mutual.sum((1, 2)), "bipartiteness_correctness": correct.sum((1, 2)),
                "cos_sim": cos_sim.reshape(-1), "time": torch.tensor([(end_time - start_time) / 1e9])}

    def test_step(self, batch: Any, batch_idx: int):
        """Log test metrics under ``test/...``."""
        source, target, dist, preds, _ = self._shared_step(batch, 'test')
        return {"source": source, "target": target, "dist": dist, "preds": preds,}

    def on_epoch_end(self):
        # Intentionally a no-op: the torchmetrics objects are logged as metric objects,
        # so no manual reset is done here.
        pass

    def configure_optimizers(self):
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Uses Adam with ``self.hparams.lr`` and ``self.hparams.weight_decay``.

        See examples here:
            https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
        """
        return torch.optim.Adam(
            params=self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
        )
