import math
import networkx as nx
import numpy as np
from dataclasses import dataclass
from omegaconf import II
from torch.nn.parameter import Parameter

import torch
from torch import nn
from fairseq import modules, metrics, utils
from fairseq.dataclass import FairseqDataclass
from fairseq.criterions import register_criterion, FairseqCriterion

@dataclass
class FusionConfig(FairseqDataclass):
    """Configuration dataclass for the fusion criterion."""
    # Declared through omegaconf's ``II("common.tpu")``, i.e. the value is an
    # interpolation of the ``common.tpu`` setting rather than an option that
    # is set on this criterion directly.
    tpu: int = II("common.tpu")


@register_criterion("fusion", dataclass=FusionConfig)
class FusionLoss(FairseqCriterion):
    """Multi-task criterion that dispatches on the task tag of each sample.

    Every sample is a ``(data, current)`` pair where ``current`` is one of
    ``'crd'``, ``'ppi'`` or ``'mlm'`` and selects both the model head that is
    called and the loss that is computed.
    """

    def __init__(self, cfg: FusionConfig, task):
        super().__init__(task)
        self.alphabet = task.alphabet
        self.tpu = cfg.tpu
        self.crd_criterion = nn.MSELoss(reduction='mean')

    def forward(self, model, sample, reduce=True):
        """Compute the loss for one sample.

        Args:
            model: object exposing ``crd_forward``, ``ppi_forward`` and
                ``mlm_forward``; each is called with the tokens and
                ``with_prompt_num=len(prompt_toks)``.
            sample: ``(data, current)`` pair; the layout of ``data`` depends
                on ``current`` (see the branches below).
            reduce: accepted for interface compatibility; not used here.

        Returns:
            ``(loss, batch_size, logging_output)``.
            NOTE(review): the second element is the batch size, not the
            per-task ``sample_size`` stored in ``logging_output`` -- confirm
            this is the intended normaliser for the trainer.

        Raises:
            ValueError: if ``current`` is not a known task tag.
        """
        data, current = sample
        padding_idx = self.alphabet.padding_idx

        if current == 'crd':
            tokens, coords, prompt_toks = data
            batch_size = len(tokens)
            token_size = tokens.ne(padding_idx).int().sum()
            sample_size = token_size

            logits = model.crd_forward(tokens, with_prompt_num=len(prompt_toks))

            # Accumulator lives on the same device as the inputs instead of a
            # hard-coded CUDA device, so CPU (and other backends) work too.
            loss = torch.zeros((), dtype=torch.float32, device=tokens.device)
            loss.requires_grad_()
            for idx, (logit, coord) in enumerate(zip(logits, coords)):
                # Keep only the first (non-padding count - 2) positions.
                # presumably the "- 2" drops two special tokens -- TODO confirm
                n_valid = tokens[idx].ne(padding_idx).int().sum() - 2
                logit = logit[:n_valid].float()
                coord = torch.as_tensor(
                    coord, dtype=logit.dtype, device=logit.device
                )

                # NOTE(review): both sets are centred along the last axis; a
                # textbook point-set alignment centres over the points axis.
                # Left unchanged -- confirm this is intended.
                logit_ = logit - logit.mean(dim=-1, keepdim=True)
                coord_ = coord - coord.mean(dim=-1, keepdim=True)

                # Rotation is estimated from the detached cross-covariance, so
                # gradients flow only through ``logit_`` in the final product.
                C = torch.matmul(logit_.t(), coord_).detach()
                V, _, W = torch.linalg.svd(C)
                # A negative determinant product means the orthogonal factor
                # contains a reflection; flip the last column to undo it.
                if (torch.det(V) * torch.det(W)) < 0.0:
                    V[:, -1] = V[:, -1] * (-1)
                U = torch.matmul(V, W)
                logit_ = torch.matmul(logit_, U)
                loss = loss + self.crd_criterion(logit_, coord_)

            logging_output = {
                "crd_loss": loss.detach().cpu(),
                "ntokens": token_size,
                "nsentences": batch_size,
                "sample_size": sample_size,
            }

        elif current == 'ppi':
            link_tokens, link_targets, prompt_toks = data
            batch_size = len(link_tokens)
            # Number of unordered pairs in the batch: n * (n - 1) / 2.
            sample_size = (batch_size * batch_size - batch_size) / 2
            result = model.ppi_forward(link_tokens, with_prompt_num=len(prompt_toks))
            loss = modules.cross_entropy(
                result.view(-1, result.size(-1)),
                link_targets.view(-1),
                reduction="sum",
                ignore_index=-1
            )
            logging_output = {
                "ppi_loss": loss.detach().cpu(),
                "ntokens": sample_size,
                "nsentences": batch_size,
                "sample_size": sample_size,
            }

        elif current == 'mlm':
            origin_tokens, masked_tokens, target_tokens, prompt_toks = data
            batch_size = origin_tokens.size(0)
            token_size = origin_tokens.ne(padding_idx).int().sum()
            sample_size = target_tokens.ne(padding_idx).int().sum()

            result = model.mlm_forward(masked_tokens, with_prompt_num=len(prompt_toks))
            loss = modules.cross_entropy(
                result.view(-1, result.size(-1)),
                target_tokens.view(-1),
                reduction="mean",
                ignore_index=padding_idx
            )
            logging_output = {
                "mlm_loss": loss.detach().cpu(),
                "ntokens": token_size,
                "nsentences": batch_size,
                "sample_size": sample_size,
            }

        else:
            # Previously an unknown tag surfaced as an UnboundLocalError on
            # ``loss``; fail with an explicit message instead.
            raise ValueError(
                f"Unknown task tag {current!r}; expected 'crd', 'ppi' or 'mlm'"
            )

        return loss, batch_size, logging_output

    @staticmethod
    def reduce_metrics(logging_outputs) -> None:
        """Aggregate logging outputs from data parallel training.

        Each per-task loss is averaged over the logging outputs that contain
        it (NaN entries are ignored). A task that is absent from every
        logging output contributes ``0.0`` so that the combined ``loss`` stays
        finite and the per-task meter receives weight 0.
        """

        def mean_or_zero(key):
            values = np.array(
                [float(log.get(key, np.nan)) for log in logging_outputs],
                dtype=np.float64,
            )
            values = values[~np.isnan(values)]
            return float(values.mean()) if values.size else 0.0

        crd_loss_sum = mean_or_zero("crd_loss")
        ppi_loss_sum = mean_or_zero("ppi_loss")
        mlm_loss_sum = mean_or_zero("mlm_loss")

        metrics.log_scalar(
            "loss", crd_loss_sum + ppi_loss_sum + mlm_loss_sum, round=3
        )
        metrics.log_scalar(
            "mlm_loss", mlm_loss_sum, weight=1 if mlm_loss_sum != 0 else 0, round=3
        )
        metrics.log_scalar(
            "crd_loss", crd_loss_sum, weight=1 if crd_loss_sum != 0 else 0, round=3
        )
        metrics.log_scalar(
            "ppi_loss", ppi_loss_sum, weight=1 if ppi_loss_sum != 0 else 0, round=3
        )

        metrics.log_derived(
            "ppl", lambda meters: utils.get_perplexity(meters["loss"].avg)
        )

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """
        Whether the logging outputs returned by `forward` can be summed
        across workers prior to calling `reduce_metrics`. Setting this
        to True improves distributed training speed.

        NOTE(review): `reduce_metrics` averages the per-task losses; if the
        outputs are summed across workers beforehand, the logged values would
        presumably scale with the number of workers -- confirm this is
        acceptable.
        """
        return True
