import torch
import torch.nn as nn
from transformers import AutoModel
import torch.nn.functional as F
from torch.autograd import grad
from transformers import BertConfig, BertModel, RobertaConfig, RobertaModel, ElectraConfig, ElectraModel
from tqdm import tqdm


class PretrainModel(nn.Module):
    """Transformer encoder with a linear classification head.

    The encoder is either loaded with pretrained weights (``AutoModel``) or
    built from the checkpoint's configuration with random weights. A sentence
    vector is pooled from the encoder outputs and projected by ``self.fc`` to
    ``num_class`` logits.

    Args:
        model_path: Hugging Face model id or local checkpoint directory.
        num_class: Number of logits produced by ``self.fc``.
        random_init: If True, build the encoder from the checkpoint config only
            (random weights). Supported model types: bert, roberta, electra.
        pooling_type: One of ``'cls'``, ``'last_avg'``, ``'first_last_avg'``.
        prior_model: Optional module whose parameters are snapshotted into
            ``self.w0_dict`` (parameter name -> detached copy). Required by
            ``compute_information_bp_fast``.

    Raises:
        ValueError: If ``pooling_type`` is unknown, or ``random_init`` is
            requested for an unsupported model type.
    """

    _POOLING_TYPES = ('cls', 'last_avg', 'first_last_avg')
    _RANDOM_INIT_MODEL_TYPES = ('bert', 'roberta', 'electra')

    def __init__(self, model_path, num_class, random_init=False, pooling_type='cls', prior_model=None):
        super().__init__()

        # Validate cheap arguments before spending time building/loading the encoder.
        if pooling_type not in self._POOLING_TYPES:
            raise ValueError('The value of pooling_type can only be: cls, last_avg or first_last_avg!')

        self.model_path = model_path
        # Registered before self.fc so named_parameters() lists encoder params first.
        self.model = self._build_encoder(model_path, random_init)
        self.hidden_size = self.model.config.hidden_size

        self.pooling_type = pooling_type
        self.num_class = num_class
        self.fc = nn.Linear(self.hidden_size, self.num_class)

        if prior_model is not None:
            # Snapshot reference weights w0. Detach first so the copy is not
            # recorded by autograd.
            self.w0_dict = {name: param.detach().clone()
                            for name, param in prior_model.named_parameters()}
            print("done get prior weights...")

    @classmethod
    def _build_encoder(cls, model_path, random_init):
        """Return the transformer encoder for ``model_path``.

        With ``random_init`` the architecture is chosen from the checkpoint's
        ``model_type`` rather than from substrings of the path. ``'bert'`` is a
        substring of ``'roberta'``, so path matching would pick the wrong class.
        """
        if not random_init:
            return AutoModel.from_pretrained(model_path)

        # Local import: AutoConfig is only needed on the random-init path.
        from transformers import AutoConfig

        configuration = AutoConfig.from_pretrained(model_path)
        if configuration.model_type not in cls._RANDOM_INIT_MODEL_TYPES:
            raise ValueError('The model:{} is not support now!'.format(model_path))
        # from_config builds the architecture for model_type without loading weights.
        return AutoModel.from_config(configuration)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Encode a batch and return logits from ``self.fc``.

        Assumes encoder hidden states are (batch, seq_len, hidden) — the usual
        Hugging Face layout; the result is (batch, num_class).
        """
        output = self.model(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            output_hidden_states=True)

        if self.pooling_type == 'cls':
            # Last-layer hidden state of the first token.
            final = output.last_hidden_state[:, 0]
        elif self.pooling_type == 'last_avg':
            # NOTE: plain mean over the sequence axis. attention_mask is not
            # applied, so padding positions are included in the average.
            final = output.last_hidden_state.mean(dim=1)
        elif self.pooling_type == 'first_last_avg':
            # Per the transformers docs, hidden_states[0] is the embedding
            # output, so [1] is the first encoder layer.
            first_avg = output.hidden_states[1].mean(dim=1)
            last_avg = output.hidden_states[-1].mean(dim=1)
            final = (first_avg + last_avg) / 2
        else:
            raise ValueError('The value of pooling_type can only be: cls, last_avg or first_last_avg!')

        return self.fc(final)

    def compute_information_bp_fast(self, dataloader, criterion):
        """Score each weight tensor by the squared first-order loss change.

        For every trainable parameter whose name contains ``'weight'``, the
        score is ``(sum((w - w0) * g)) ** 2``. Here ``w0`` is the snapshot taken
        from ``prior_model`` and ``g`` is the gradient of ``criterion``,
        averaged over the batches of ``dataloader``.

        The module's train/eval mode is not changed here. If the module is in
        training mode, dropout stays active.

        Args:
            dataloader: Iterable of ``(input_ids, attention_mask,
                token_type_ids, label)`` batches. ``token_type_ids`` may be None.
            criterion: Callable ``criterion(logits, labels)`` returning a scalar
                loss.

        Returns:
            dict mapping parameter name -> float score. Parameters that received
            no gradient are printed and omitted.

        Raises:
            ValueError: If the model was built without ``prior_model``.
            KeyError: If a scored weight has no entry in the prior weights.
        """
        w0_dict = getattr(self, 'w0_dict', None)
        if w0_dict is None:
            raise ValueError('compute_information_bp_fast needs prior weights: '
                             'construct PretrainModel with prior_model=...')

        device = self.model.device
        # torch.autograd.grad raises for tensors that do not require grad, so
        # frozen parameters are excluded up front.
        named_params = [(name, p) for name, p in self.named_parameters() if p.requires_grad]
        param_keys = [name for name, _ in named_params]
        params = [p for _, p in named_params]
        params_by_name = dict(named_params)

        # Fail before the (potentially long) pass over the data.
        missing = [name for name in param_keys if 'weight' in name and name not in w0_dict]
        if missing:
            raise KeyError('prior weights missing for: {}'.format(missing))

        grad_sums = dict.fromkeys(param_keys)
        num_batches = 0
        for token, mask, token_type_ids, label in tqdm(dataloader):
            if token_type_ids is not None:
                token_type_ids = token_type_ids.to(device)
            pred = self.forward(input_ids=token.to(device), attention_mask=mask.to(device),
                                token_type_ids=token_type_ids)
            loss = criterion(pred, label.to(device))

            gradients = grad(loss, params, allow_unused=True)
            num_batches += 1

            for name, gw in zip(param_keys, gradients):
                if gw is None:  # parameter not used in this loss
                    continue
                flat = gw.flatten()
                if grad_sums[name] is None:
                    grad_sums[name] = flat
                else:
                    grad_sums[name] += flat

        for name, total in grad_sums.items():
            if total is None:
                print(name, total)

        info_dict = {}
        # delta_w and the score are plain numbers; no autograd graph is needed.
        with torch.no_grad():
            for name, total in grad_sums.items():
                if 'weight' not in name or total is None:
                    continue
                mean_grad = total / num_batches
                w = params_by_name[name]
                # Prior weights are plain attributes (not buffers), so they may
                # sit on a different device than the live parameter.
                delta_w = w - w0_dict[name].to(w.device)
                info_dict[name] = ((delta_w.flatten() * mean_grad).sum() ** 2).item()

        return info_dict
