import os
import types
import torch
import random
import numpy as np
from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
import numpy as np
from transformers.models.t5.modeling_t5 import T5DenseActDense
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.char as nac
import matplotlib.pyplot as plt
import argparse

# Checkpoint name shared by the tokenizer and the config below.
_PRETRAINED_NAME = 't5-base'

# Tokenizer and configuration are both loaded from the same checkpoint name.
tokenizer = T5Tokenizer.from_pretrained(_PRETRAINED_NAME)
config = T5Config.from_pretrained(_PRETRAINED_NAME)

# All splits of the 'sst2' dataset; the evaluation below reads 'validation'.
sst2 = load_dataset('sst2')

# Values of k (experts kept per feed-forward layer) to evaluate.
expert_counts = [20, 40, 60, 80, 96]

def change_forward(model, k=20, split_dir=os.path.join('results/t5-base', 'param_split')):
    """Patch every feed-forward block of a T5 model to keep only its top-k experts.

    For each encoder block (``layer[1]``) and decoder block (``layer[2]``) the
    ``DenseReluDense`` module gets three new attributes and a replaced
    ``forward``:

    * ``patterns`` -- float32 tensor of shape ``(num_experts, num_neurons)``;
      row ``i`` is the 0/1 indicator of the neurons assigned to expert ``i``.
    * ``k`` -- number of experts kept per token.
    * ``forward_old`` -- the module's original ``forward`` (kept for restoring).

    Args:
        model: model exposing ``encoder.block`` and ``decoder.block`` whose
            feed-forward modules are ``T5DenseActDense`` instances.
        k: number of experts kept per token in every patched layer.
        split_dir: directory holding one file per layer, named
            ``<encoder|decoder>.block.<idx>.layer.<1|2>.DenseReluDense.wi.weight``.
            Each file is read with ``torch.load`` and must contain one integer
            expert label per intermediate neuron (its length has to match the
            output width of ``wi``, otherwise the matmul in the patched
            forward fails).

    Raises:
        TypeError: if a feed-forward module is not a ``T5DenseActDense``.
    """

    def _forward(ffn_self, hidden_states):
        """Replacement forward: wi -> act -> expert masking -> dropout -> wo."""
        hidden_states = ffn_self.wi(hidden_states)
        hidden_states = ffn_self.act(hidden_states)

        if ffn_self.patterns is not None:
            # "golden" selection: each expert is scored by the sum of the
            # activations of its own neurons, and only the neurons of the
            # top-k scoring experts are kept for every token.
            top_k = ffn_self.k
            bsz, seq_len, hidden_size = hidden_states.shape
            # Read-only flattened view; no copy is needed because the masking
            # below is out-of-place.
            flat = hidden_states.reshape(-1, hidden_size)
            score = torch.matmul(flat, ffn_self.patterns.transpose(0, 1))
            labels = torch.topk(score, k=top_k, dim=-1)[1].view(bsz, seq_len, top_k)
            # Summing the indicator rows of the selected experts gives a
            # per-neuron mask that is non-zero exactly for the kept neurons.
            cur_mask = torch.nn.functional.embedding(labels, ffn_self.patterns).sum(-2)
            # Out-of-place so the activation output is never modified in place
            # (an in-place write would break autograd if gradients were needed).
            hidden_states = hidden_states.masked_fill(cur_mask == 0, 0)

        hidden_states = ffn_self.dropout(hidden_states)
        hidden_states = ffn_self.wo(hidden_states)
        return hidden_states

    def modify_ffn(ffn, path):
        """Attach the expert patterns stored at ``path`` to ``ffn`` and patch its forward."""
        if not isinstance(ffn, T5DenseActDense):
            raise TypeError(
                'expected a T5DenseActDense module, got {}'.format(type(ffn).__name__)
            )
        labels = np.asarray(torch.load(path))
        cluster_num = int(labels.max()) + 1
        # Row i marks the neurons whose label is i (built in one vectorized
        # comparison instead of a Python loop over experts).
        patterns = labels[None, :] == np.arange(cluster_num)[:, None]
        # Keep the patterns on the same device as the layer's weights rather
        # than assuming CUDA.
        ffn.patterns = torch.from_numpy(patterns).to(
            device=ffn.wi.weight.device, dtype=torch.float32
        )
        ffn.k = k
        # Preserve the genuine original forward even if the model is patched
        # more than once (e.g. to change k).
        if not hasattr(ffn, 'forward_old'):
            ffn.forward_old = ffn.forward
        ffn.forward = types.MethodType(_forward, ffn)

    # (stack name, blocks, index of the feed-forward sub-layer inside a block)
    stacks = (
        ('encoder', model.encoder.block, 1),
        ('decoder', model.decoder.block, 2),
    )
    for stack_name, blocks, ffn_idx in stacks:
        for layer_idx, layer in enumerate(blocks):
            ffn = layer.layer[ffn_idx].DenseReluDense
            path = os.path.join(
                split_dir,
                '{}.block.{}.layer.{}.DenseReluDense.wi.weight'.format(stack_name, layer_idx, ffn_idx),
            )
            modify_ffn(ffn, path)

# Evaluate SST-2 validation accuracy of t5-base when every feed-forward layer
# keeps only its top-k experts, for each k in expert_counts. Results are stored
# as {k: accuracy} and saved with torch.save at the end.
baseline_acc = {}
expert_counts = [20,40,60,80,96]

# Vocabulary ids whose first-step logits are compared: the prediction is
# label 1 when the logit of LABEL1_TOKEN_ID is the larger one, else label 0.
# NOTE(review): presumably the T5 ids of "positive" / "negative" -- confirm
# against the tokenizer.
LABEL1_TOKEN_ID = 1465
LABEL0_TOKEN_ID = 2841

sst2_dev = sst2['validation']
# Loop-invariant: the single-token label sequence passed as `labels` to the
# model (only the first token id of the tokenized string is kept).
dec_input_ids = tokenizer("<extra_id_0>", return_tensors="pt").input_ids.cuda()[:, :1]

for k in expert_counts:
    # A fresh model is loaded for every k because change_forward patches the
    # model's feed-forward modules in place.
    model = T5ForConditionalGeneration.from_pretrained('t5-base').cuda()
    change_forward(model, k)
    model.eval()

    pred = []
    # Pure inference: disable autograd so no graph is built per example.
    with torch.no_grad():
        for instance in sst2_dev:
            input_ids = tokenizer("sst2 sentence: "+instance['sentence'], return_tensors="pt").input_ids.cuda()
            output = model(input_ids=input_ids, labels=dec_input_ids)
            # Logits of the first decoding step for the (single) example.
            first_step = output.logits[0, 0]
            predicted = int(first_step[LABEL1_TOKEN_ID].item() > first_step[LABEL0_TOKEN_ID].item())
            pred.append(predicted == instance['label'])

    acc = (sum(pred)*1.0)/len(pred)
    baseline_acc[k] = acc
    print("Acc", acc, 'k', k)

torch.save(baseline_acc, "baseline_sa_t5_sst2_acc_vanilla.pt")