import torch
import torch.nn as nn
import torch.nn.functional as F

from functools import reduce
from operator import mul, add
from sklearn import metrics
from graph_learning.module import ModuleConfig

def error_rate(probs, labels):
    """Combine labels with the negated top-1 probability of each row.

    For every row of ``probs`` the maximum along dim 1 is written into a
    zero tensor shaped like ``labels`` at the arg-max position.  The result
    keeps ``labels`` wherever it equals 1 and holds the negated scattered
    value everywhere else.

    Args:
        probs: scores reduced with ``max`` along dim 1.
        labels: tensor compared against 1; it provides the output shape
            and dtype (it must be scatter-compatible with ``probs``).

    Returns:
        Tensor shaped like ``labels``.
    """
    top_value, top_index = torch.max(probs, dim=1)
    scattered = torch.zeros_like(labels)
    scattered = scattered.scatter(1, top_index[:, None], top_value[:, None])
    is_positive = labels == 1
    return torch.where(is_positive, labels, scattered.neg())

def gmean(input_x, dim):
    """Geometric mean of ``input_x`` along ``dim``.

    Computed as ``exp(mean(log(x + 1e-12)))``; the small offset keeps the
    logarithm finite when an entry is exactly zero.
    """
    return (input_x + 1e-12).log().mean(dim=dim).exp()

@ModuleConfig.register("ce-loss", help="[Loss] cross entropy loss.")
class CELossModuleConfig(ModuleConfig):
    """Configuration registered under ``ce-loss``; its builder is ``CELoss``."""

    def __init__(self, args, context):
        super().__init__(args, context)

    @property
    def builder(self):
        """Class used to construct the loss module."""
        return CELoss

    @classmethod
    def define_parser(cls, parser):
        """Delegate to the base parser definition; no extra options are added."""
        super().define_parser(parser)

class CELoss(nn.Module):
    """Cross-entropy classification loss evaluated on a masked subset of rows."""

    def __init__(self):
        super().__init__()
        self.criteria = nn.CrossEntropyLoss()

    def forward(self, logits, labels, mask):
        """Compute the masked cross-entropy loss plus predictions for all rows.

        Args:
            logits: either a logits tensor, or a ``(logits, mask_valid)``
                tuple.  In the tuple form ``mask_valid`` is AND-ed into
                ``mask`` before the loss is computed.
            labels: target class indices, indexed with the same mask as
                ``logits``.
            mask: boolean selector for the rows that contribute to the loss.

        Returns:
            dict with
              * ``'loss'``: the cross-entropy over the selected rows,
              * ``'losses'``: ``{'cls_loss': <same tensor>}``,
              * ``'outputs'``: ``'probs'`` (softmax over dim 1) and
                ``'preds'`` (arg-max of ``probs``), both for *all* rows,
                not only the masked ones.
        """
        if isinstance(logits, tuple):
            logits, mask_valid = logits
            # Only rows that are both requested and valid are scored.
            mask = mask & mask_valid

        cls_loss = self.criteria(logits[mask], labels[mask])

        probs = F.softmax(logits, 1)
        preds = probs.argmax(1)

        return {'loss': cls_loss,
                'losses': {'cls_loss': cls_loss},
                'outputs': {'probs': probs,
                            'preds': preds}}

@ModuleConfig.register("dist-loss", help="[Loss] mean square loss.")
class DistLossModuleConfig(ModuleConfig):
    """Configuration registered under ``dist-loss``; its builder is ``DistLoss``."""

    def __init__(self, args, context):
        super().__init__(args, context)

    @property
    def builder(self):
        """Class used to construct the loss module."""
        return DistLoss

    @classmethod
    def define_parser(cls, parser):
        """Add the ``--mode`` option on top of the base parser definition."""
        super().define_parser(parser)
        mode_choices = ['dist', 'reach']
        parser.add_argument('--mode', choices=mode_choices)

class DistLoss(nn.Module):
    """Loss over per-column targets in which NaN marks a missing label.

    Two modes are supported:

    * ``'dist'``  - regression on the non-NaN labels; the optimised loss is
      the Smooth-L1 (Huber) loss, MSE/MAE are reported as metrics.
    * ``'reach'`` - binary classification of "label is present (not NaN)"
      using BCE-with-logits; accuracy and F1 are reported as metrics.

    Metrics with the ``_f`` / ``f`` suffix are restricted to the last column.
    """

    def __init__(self, mode):
        super().__init__()
        self.mse = nn.MSELoss()
        self.mae = nn.L1Loss()
        self.huber = nn.SmoothL1Loss()

        self.ce = nn.BCEWithLogitsLoss()
        self.mode = mode

    def forward(self, logits, labels, mask):
        """Compute the loss dictionary for the rows selected by ``mask``.

        Args:
            logits: predictions; may be a tensor, a list of tensors (stacked
                along dim 1), or a 2-tuple whose first element is either of
                those.  The second tuple element is ignored here.
            labels: targets; truncated to the first ``logits.size(1)``
                columns.  NaN entries mark missing labels.
            mask: selector applied to the first dimension of both tensors.

        Returns:
            A dict with ``'loss'``, ``'losses'`` and ``'outputs'`` entries,
            or ``None`` when ``self.mode`` is neither ``'dist'`` nor
            ``'reach'``.
        """
        if isinstance(logits, tuple):
            logits, _ = logits

        if isinstance(logits, list):
            logits = torch.stack(logits, 1)

        labels = labels[:, :logits.size(1)]

        logits_m = logits[mask]
        labels_m = labels[mask]

        # True where a label exists; the ``_f`` variant covers the last column.
        label_mask = ~labels_m.isnan()
        label_mask_f = ~labels_m[:, -1].isnan()

        if self.mode == 'dist':
            return self._dist_outputs(logits_m, labels_m,
                                      label_mask, label_mask_f)
        if self.mode == 'reach':
            return self._reach_outputs(logits_m, label_mask, label_mask_f)
        return None

    def _dist_outputs(self, logits_m, labels_m, label_mask, label_mask_f):
        """Regression losses/metrics on the entries that carry a label."""
        # Gather once; every criterion below consumes the same selection.
        pred, target = logits_m[label_mask], labels_m[label_mask]
        pred_f = logits_m[:, -1][label_mask_f]
        target_f = labels_m[:, -1][label_mask_f]

        mse = self.mse(pred, target)
        mae = self.mae(pred, target)
        huber = self.huber(pred, target)
        mse_f = self.mse(pred_f, target_f)
        mae_f = self.mae(pred_f, target_f)

        return {'loss': huber,
                'losses': {'reg_loss': huber},
                'outputs': {'metrics': {'mse': mse, 'mae': mae,
                                        'msef': mse_f, 'maef': mae_f}}}

    def _reach_outputs(self, logits_m, label_mask, label_mask_f):
        """Binary "label present" losses/metrics from the raw logits."""
        labels_reach = label_mask.float()
        labels_reach_f = label_mask_f.float()
        logits_f = logits_m[:, -1]

        def nan2zero(t):
            return torch.where(t.isnan(), torch.zeros_like(t), t)

        def to_flat_numpy(t):
            return t.long().detach().cpu().flatten().numpy()

        # NOTE(review): only the negative terms are NaN-guarded (BCE over an
        # empty selection is NaN); the positive terms are left as-is, so a
        # batch without any present label yields a NaN loss - confirm intent.
        reach_loss_pos = self.ce(logits_m[label_mask],
                                 labels_reach[label_mask])
        reach_loss_f_pos = self.ce(logits_f[label_mask_f],
                                   labels_reach_f[label_mask_f])
        reach_loss_neg = nan2zero(self.ce(logits_m[~label_mask],
                                          labels_reach[~label_mask]))
        reach_loss_f_neg = nan2zero(self.ce(logits_f[~label_mask_f],
                                            labels_reach_f[~label_mask_f]))

        # A positive logit is read as "present".
        match = ((logits_m > 0) == labels_reach)

        acc_reach = match.sum().float() / match.numel()
        acc_reach_f = match[:, -1].sum().float() / match[:, -1].numel()

        f1_reach = metrics.f1_score(to_flat_numpy(labels_reach),
                                    to_flat_numpy(logits_m > 0))
        f1_reach_f = metrics.f1_score(to_flat_numpy(labels_reach_f),
                                      to_flat_numpy(logits_f > 0))

        return {'loss': reach_loss_pos + reach_loss_neg,
                'losses': {'reach_loss_pos': reach_loss_pos,
                           'reach_loss_f_pos': reach_loss_f_pos,
                           'reach_loss_neg': reach_loss_neg,
                           'reach_loss_f_neg': reach_loss_f_neg},
                'outputs': {'metrics': {'acc_reach': acc_reach,
                                        'acc_reach_f': acc_reach_f,
                                        'f1_reach': f1_reach,
                                        'f1_reach_f': f1_reach_f}}}

@ModuleConfig.register('moe-loss',
                       help='[Loss] loss computing for MoE.')
class MoEModuleConfig(ModuleConfig):
    """Configuration registered under ``moe-loss``; its builder is ``MoELoss``."""

    def __init__(self, args, context):
        super().__init__(args, context)
        # Logger taken from the global context.
        self.logger = context.global_.logger

    @property
    def builder(self):
        """Class used to construct the loss module."""
        return MoELoss

    @classmethod
    def define_parser(cls, parser):
        """Add the mixture-of-experts loss options to ``parser``."""
        super().define_parser(parser)
        parser.add_argument('--stage', choices=['1', '2'],
                            help='1: pretrain stage, 2: fusion stage.')
        parser.add_argument('--pretrain-index',
                            help='which expert to pretrain')
        parser.add_argument('--mode', choices=['sum', 'prod'],
                            help='decision fusion strategy')
        parser.add_argument('--names', nargs='+',
                            help='name of experts')
        parser.add_argument('--losses', nargs='+',
                            help='loss computing for all experts, we assume using \'ce\' for all experts now.')
        parser.add_argument('--nwr', action='store_true',
                            help='no weight regularization')
        parser.add_argument('--al2', action='store_true',
                            help='use also each expert loss on stage 2')

class MoELoss(nn.Module):
    """Loss for a mixture of experts with a two-stage training schedule.

    * Stage ``'1'`` trains the experts: either all of them jointly
      (``pretrain_index is None``) or only the selected one.
    * Stage ``'2'`` trains the fused decision.  The per-expert losses are
      added only when ``al2`` is set, and a KL term pulling the gating
      weights towards each expert's (detached) probability of the true
      class is added unless ``nwr`` is set.

    ``mode`` selects the fusion rule: ``'sum'`` mixes expert probabilities,
    ``'prod'`` mixes expert log-probabilities.
    """

    def __init__(self, stage, pretrain_index, names, mode, losses, logger,
                 nwr, al2):
        super().__init__()
        self.stage = stage
        self.names = names
        self.mode = mode
        self.losses = losses
        self.logger = logger
        self.nwr = nwr
        self.al2 = al2

        # Accepts None or anything ``int()`` can parse (e.g. a string).
        self.pretrain_index = pretrain_index
        if self.pretrain_index is not None:
            self.pretrain_index = int(self.pretrain_index)

    @staticmethod
    def _harmonic_mean(values):
        """Harmonic mean of ``values``; 0 as soon as any value is 0."""
        if 0 in values:
            return 0
        return len(values) / sum(1 / v for v in values)

    def _modal_complement(self, preds_list, labels):
        """Measure how well the other experts cover each expert's failures.

        Returns ``(rows, score)``.  ``rows[i]`` is the table row for expert
        ``i``: its name, then for every expert the number of expert ``i``'s
        failures that expert gets right (the diagonal holds expert ``i``'s
        failure count instead).  ``score`` averages, over experts, the
        harmonic mean of
          * rescued failures / failures of expert ``i`` and
          * rescued failures / correct predictions of the other experts.
        A rate whose denominator is 0 counts as 0.
        """
        hits = [p == labels for p in preds_list]
        rows, scores = [], []
        for i, hit in enumerate(hits):
            failed_mask = ~hit
            failed_count = failed_mask.sum().item()
            rescued = [h[failed_mask].sum().item() for h in hits]

            others_rescued = sum(c for j, c in enumerate(rescued) if j != i)
            others_correct = sum(h.sum().item()
                                 for j, h in enumerate(hits) if j != i)
            # Guard the divisions: an expert without failures, or peers
            # without any correct prediction, would divide by zero.
            correct_rate = others_rescued / failed_count if failed_count else 0
            prec_rate = others_rescued / others_correct if others_correct else 0
            scores.append(self._harmonic_mean([correct_rate, prec_rate]))

            rescued[i] = failed_count
            rows.append([self.names[i], *rescued])
        return rows, sum(scores) / len(scores)

    def forward(self, inputs, labels, mask):
        """Compute the stage-dependent loss and fused predictions.

        Args:
            inputs: ``(logits, weights, mask_valid)``.  ``logits`` is a
                sequence with one tensor per expert; ``weights`` holds one
                gating column per expert; ``mask_valid`` is a boolean
                tensor with one column per expert.
            labels: class indices (scattered into a one-hot tensor).
            mask: boolean row selector for the loss terms.

        Returns:
            dict with ``'loss'``, ``'losses'`` (named terms) and
            ``'outputs'`` (``'probs'``, ``'preds'`` and, outside training
            mode, ``'metrics'``).
        """
        logits, weights, mask_valid = inputs
        eps = 1e-12
        weights = weights.clamp(eps, 1-eps)

        # 'ce' experts emit logits; any other entry is treated as already
        # being probabilities.
        logps = [F.log_softmax(l, -1) if t == 'ce' else (l+eps).log()
                 for l, t in zip(logits, self.losses)]

        if self.mode == 'prod':
            if self.stage == '1':
                if self.pretrain_index is None:
                    ce_collaborative = sum(logps) / weights.size(1)
                else:
                    ce_collaborative = logps[self.pretrain_index]
            elif self.stage == '2':
                ce_collaborative = sum([weights[:, i:i+1] * logp
                                        for i, logp in enumerate(logps)])

            # Second positional argument of F.normalize is ``p``: this is an
            # L1 normalisation along the default dim=1.
            probs = F.normalize(torch.exp(ce_collaborative), 1)
            preds = probs.argmax(1)
        elif self.mode == 'sum':
            if self.stage == '2':
                probs = sum([weights[:, i:i+1] * logp.exp()
                             for i, logp in enumerate(logps)])
            elif self.stage == '1':
                if self.pretrain_index is None:
                    probs = sum([logp.exp() for logp in logps]) / weights.size(1)
                else:
                    probs = logps[self.pretrain_index].exp()

            preds = probs.argmax(1)
            ce_collaborative = (probs+1e-12).log()

        # One-hot labels shaped like an expert's output.
        labels_t = torch.zeros_like(logits[0]).scatter(1, labels.unsqueeze(1), 1)

        # Log-probability each expert assigns to the true class, per row.
        logp_true = [torch.sum(logp * labels_t, 1) for logp in logps]

        # Target for the gating weights: normalised, detached probability of
        # the true class per expert, with invalid expert entries zeroed.
        py = torch.stack(logp_true, -1).exp().clamp(eps, 1-eps).detach()
        py[~mask_valid] = 0
        weights_aux = F.normalize(py, p=1, dim=1)

        outputs = {}

        expert_losses = [-torch.mean(ll[mask & mask_valid[:, i]])
                         for i, ll in enumerate(logp_true)]
        if self.stage == '1':
            if self.pretrain_index is None:
                loss = sum(expert_losses)
            else:
                loss = expert_losses[self.pretrain_index]
        elif self.stage == '2':
            if self.al2:
                loss = sum(expert_losses)
            else:
                loss = 0

        losses = {f'exp_{self.names[i]}': l
                  for i, l in enumerate(expert_losses)}

        if not self.training:
            self.logger.log('note', f'Total: {logps[0].shape[0]}')
            # Imported lazily: only the evaluation report needs it.
            from prettytable import PrettyTable
            preds_list = [logp.argmax(1) for logp in logps]
            rows, mc_score = self._modal_complement(preds_list, labels)

            modal_table = PrettyTable()
            modal_table.field_names = ['', *self.names]
            for row in rows:
                modal_table.add_row(row)

            self.logger.log('note', 'Modal complement:')
            self.logger.log('note', modal_table)
            if self.stage == '2':
                # Per expert: failures the fused prediction fixes, and
                # correct predictions the fused prediction breaks.
                correct_table = PrettyTable()
                correct_table.field_names = ['', 'correct', 'failed']
                for i, expert_preds in enumerate(preds_list):
                    failed_mask = (expert_preds != labels)
                    correct_mask = (expert_preds == labels)
                    adapt_counts = (preds[failed_mask] == labels[failed_mask]).sum().item()
                    adapt_f_counts = (preds[correct_mask] != labels[correct_mask]).sum().item()
                    correct_table.add_row([self.names[i], adapt_counts, adapt_f_counts])

                self.logger.log('note', 'MoE correction:')
                self.logger.log('note', correct_table)

            outputs['metrics'] = {'mc': mc_score}

        if self.stage == '2':
            loss_cls = - torch.mean(torch.sum(ce_collaborative * labels_t, 1)[mask])
            losses.update({'cls_final': loss_cls})
            loss = loss + loss_cls

            if not self.nwr:
                weights_m = weights[mask]
                weights_aux_m = weights_aux[mask]
                # NOTE(review): default reduction ('mean') averages over all
                # elements, not per row ('batchmean'); kept to preserve the
                # existing loss scale.
                loss_wr = F.kl_div(weights_m.log(), weights_aux_m)

                losses.update({'wr': loss_wr})
                loss = loss + loss_wr

        outputs.update({'probs': probs,
                        'preds': preds})

        return {'loss': loss, 'losses': losses,
                'outputs': outputs}
