import torch
import torch.nn as nn
import torch.nn.functional as F
from operator import itemgetter
from functional_forward_bert import functional_bert_for_classification
from transformers import PreTrainedModel
from scipy.spatial import distance
from utils_glue import linear_CKA


class MetaPatientDistillation(nn.Module):
    """Distillation losses between a teacher and a student classifier.

    ``forward`` runs both models on one batch and returns the task loss, a
    logit-level distillation loss, an intermediate-layer ("patient KD") loss
    and several teacher/student agreement diagnostics. ``s_prime_forward``
    evaluates only the task loss of a student given explicit parameters.
    """

    def __init__(self, t_config, s_config):
        super(MetaPatientDistillation, self).__init__()
        # Teacher / student model configs. s_config is forwarded to
        # functional_bert_for_classification in s_prime_forward.
        self.t_config = t_config
        self.s_config = s_config

    def forward(self, t_model, s_model, order, input_ids, token_type_ids, attention_mask, labels, args, teacher_grad):
        """Run teacher and student on one batch and compute the distillation losses.

        Args:
            t_model: teacher model. When ``teacher_grad`` is true it is called
                with ``labels`` and gradients enabled; otherwise it is called
                without labels under ``torch.no_grad()``.
            s_model: student model, always called with ``labels``.
            order: index selecting which teacher intermediate layers are matched
                against the student's (presumably one teacher layer id per
                student layer -- TODO confirm against caller).
            input_ids, token_type_ids, attention_mask, labels: batch tensors
                forwarded to both models. ``labels`` is used as a class index.
            args: namespace read for ``logits_mse``, ``temperature`` and ``beta``.
            teacher_grad: if true, the teacher's task loss is added to the
                returned ``train_loss`` and the teacher takes part in autograd.

        Returns:
            ``(train_loss, soft_loss, pkd_loss, probLoyalty, cka, agree_sum)``:

            - ``train_loss``: the student task loss, plus the teacher task loss
              when ``teacher_grad``.
            - ``soft_loss``: the logit-level distillation loss.
            - ``pkd_loss``: MSE between L2-normalised position-0 hidden states of
              the intermediate layers, or zeros when ``args.beta == 0``.
            - ``probLoyalty``: numpy array holding 1 minus the per-example
              Jensen-Shannon distance between the student and teacher softmax
              outputs.
            - ``cka``: ``linear_CKA`` of the two softmax matrices.
            - ``agree_sum``: fraction of examples (a Python float) on which the
              student and teacher argmax agree.
        """
        if teacher_grad:
            t_outputs = t_model(input_ids=input_ids,
                                token_type_ids=token_type_ids,
                                attention_mask=attention_mask,
                                labels=labels)
            # Called with labels: unpacked as (loss, logits, ..., hidden_states).
            train_loss_t, t_logits, t_features = t_outputs[0], t_outputs[1], t_outputs[-1]
        else:
            with torch.no_grad():
                t_outputs = t_model(input_ids=input_ids,
                                    token_type_ids=token_type_ids,
                                    attention_mask=attention_mask)
            # Called without labels there is no loss entry: (logits, ..., hidden_states).
            train_loss_t = 0
            t_logits, t_features = t_outputs[0], t_outputs[-1]

        s_outputs = s_model(input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask,
                            labels=labels)
        s_loss, s_logits, s_features = s_outputs[0], s_outputs[1], s_outputs[-1]

        # Out-of-place add, so the student's loss tensor (an autograd output)
        # is never mutated in place.
        train_loss = s_loss + train_loss_t

        # --- Agreement diagnostics (no gradients; computed on CPU) ---
        soft_logit_s = F.softmax(s_logits.detach().cpu(), dim=-1)
        soft_logit_t = F.softmax(t_logits.detach().cpu(), dim=-1)
        probs_s_np = soft_logit_s.numpy()
        probs_t_np = soft_logit_t.numpy()

        cka = linear_CKA(probs_s_np, probs_t_np)
        # jensenshannon reduces over axis 0, so transpose to get one distance per example.
        probLoyalty = 1 - distance.jensenshannon(probs_s_np.T, probs_t_np.T)

        agree = soft_logit_s.argmax(dim=1) == soft_logit_t.argmax(dim=1)
        agree_sum = agree.sum().item() / agree.size(0)

        # --- Logit-level distillation loss ---
        if args.logits_mse:
            # Corrected student logits: in every row, the student's OWN top-scoring
            # logit is swapped with the gold-label logit, so the label receives the
            # row's maximum. The copy is detached, so this loss only sends gradient
            # into t_logits (i.e. only when teacher_grad is set).
            s_logits_cor = self._swap_with_labels(s_logits, s_logits.argmax(dim=1), labels)
            soft_loss = F.mse_loss(F.softmax(t_logits, dim=-1), F.softmax(s_logits_cor, dim=-1))
        else:
            T = args.temperature
            soft_targets = F.softmax(t_logits / T, dim=-1)
            probs = F.softmax(s_logits / T, dim=-1)
            # Scaled by T^2, following the usual temperature-KD convention.
            soft_loss = F.mse_loss(soft_targets, probs) * T * T

        # --- Intermediate-layer (patient KD) loss ---
        if args.beta == 0:  # the caller gives it no weight, so skip the computation
            pkd_loss = torch.zeros_like(soft_loss)
        else:
            pkd_loss = self._pkd_loss(t_features, s_features, order)

        return train_loss, soft_loss, pkd_loss, probLoyalty, cka, agree_sum

    @staticmethod
    def _swap_with_labels(logits, idx, labels):
        """Return a detached copy of ``logits`` with entries ``idx[i]`` and ``labels[i]`` swapped in each row i.

        Vectorised and kept on ``logits``' device. When ``idx[i] == labels[i]``
        the row is unchanged.
        """
        corrected = logits.detach().clone()
        rows = torch.arange(corrected.size(0), device=corrected.device)
        idx = idx.reshape(-1).long().to(corrected.device)
        labels = labels.reshape(-1).long().to(corrected.device)
        # Advanced-index reads return copies, so both values are captured
        # before either write happens.
        at_idx = corrected[rows, idx]
        at_label = corrected[rows, labels]
        corrected[rows, idx] = at_label
        corrected[rows, labels] = at_idx
        return corrected

    @staticmethod
    def _pkd_loss(t_features, s_features, order):
        """MSE between L2-normalised position-0 hidden states of teacher and student intermediate layers.

        ``t_features`` / ``s_features`` are sequences of per-layer hidden states,
        each assumed to be (batch, seq_len, hidden). The first and last entries
        are dropped, the rest are stacked to (layers, batch, seq_len, hidden), and
        token 0 is kept. The teacher layers are then selected by ``order``. The
        sequence length is taken from the tensors themselves rather than a
        configured maximum.
        """
        t_first = torch.stack(tuple(t_features[1:-1]), dim=0)[:, :, 0]
        s_first = torch.stack(tuple(s_features[1:-1]), dim=0)[:, :, 0]
        t_first = t_first[order]
        # F.normalize clamps the norm by a small eps, avoiding NaN on zero vectors.
        t_first = F.normalize(t_first, dim=-1)
        s_first = F.normalize(s_first, dim=-1)
        return F.mse_loss(s_first, t_first, reduction="mean")

    def s_prime_forward(self, s_prime, input_ids, token_type_ids, attention_mask, labels, args):
        """Return the task loss of the student evaluated with parameters ``s_prime``.

        Runs ``functional_bert_for_classification`` with ``self.s_config`` and
        ``is_train=False``, and returns the first element of its output.
        ``args`` is accepted for signature compatibility and is not used.
        """
        s_outputs = functional_bert_for_classification(
            s_prime,
            self.s_config,
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
            labels=labels,
            is_train=False,
        )
        return s_outputs[0]
