import os
import types
import torch
import random
import numpy as np
from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
import numpy as np
from transformers.models.t5.modeling_t5 import T5DenseActDense
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.char as nac
import matplotlib.pyplot as plt
import argparse

# Name of the pretrained checkpoint shared by the tokenizer and the config.
_PRETRAINED_NAME = 't5-base'

# Tokenizer used to encode every prompt fed to the model below.
tokenizer = T5Tokenizer.from_pretrained(_PRETRAINED_NAME)
# Configuration of the same checkpoint.
config = T5Config.from_pretrained(_PRETRAINED_NAME)

# Dataset whose 'validation' split is scored in the evaluation loops.
sst2 = load_dataset('sst2')

# Values of k (experts kept active per token) swept by the evaluation loops.
expert_counts = [20, 40, 60, 80, 96]

def change_forward(model, k=20):
    """Patch the feed-forward layers of *model* so they run with top-k expert masking.

    For every encoder and decoder block, a cluster assignment is loaded from
    ``results/t5-base/param_split/<wi parameter name>`` and turned into a
    ``(num_clusters, intermediate_size)`` membership matrix. The layer's
    ``forward`` is then replaced (the previous one is kept as ``forward_old``)
    by a version that, per token, keeps only the intermediate activations
    belonging to the ``k`` highest-scoring clusters and zeroes the rest.

    Args:
        model: Model exposing ``encoder.block`` and ``decoder.block``, whose
            feed-forward sub-layers (``layer[1]`` / ``layer[2]``) hold a
            ``DenseReluDense`` module of type ``T5DenseActDense``.
        k: Number of clusters (experts) kept active for each token.

    Raises:
        TypeError: If a feed-forward module is not a ``T5DenseActDense``.
    """

    def _forward(ffn_self, hidden_states):
        # Replacement forward: wi -> activation -> expert mask -> dropout -> wo.
        hidden_states = ffn_self.wi(hidden_states)
        hidden_states = ffn_self.act(hidden_states)

        if ffn_self.patterns is not None:
            top_k = ffn_self.k
            bsz, seq_len, hidden_size = hidden_states.shape
            flat_states = hidden_states.reshape(-1, hidden_size)
            # Score of a cluster = sum of the activations of its member neurons.
            score = torch.matmul(flat_states, ffn_self.patterns.transpose(0, 1))
            labels = torch.topk(score, k=top_k, dim=-1)[1].view(bsz, seq_len, top_k)
            # Summing the membership rows of the selected clusters gives, per
            # token, a mask that is non-zero exactly on the neurons to keep.
            cur_mask = torch.nn.functional.embedding(labels, ffn_self.patterns).sum(-2)
            # Out-of-place masking: the activation output is never mutated
            # in-place, so no defensive clone of it is required.
            hidden_states = hidden_states.masked_fill(cur_mask == 0, 0)

        hidden_states = ffn_self.dropout(hidden_states)
        hidden_states = ffn_self.wo(hidden_states)
        return hidden_states

    def modify_ffn(ffn, path):
        # Attach the cluster membership matrix and k to `ffn`, then swap its forward.
        if not isinstance(ffn, T5DenseActDense):
            raise TypeError(
                'expected T5DenseActDense, got {}'.format(type(ffn).__name__)
            )
        # NOTE(review): torch.load unpickles the file; only point this at trusted files.
        # Assumes the file holds a flat sequence with one integer cluster id per
        # intermediate neuron - TODO confirm against the script that writes it.
        labels = np.asarray(torch.load(path))
        cluster_num = int(labels.max()) + 1
        # Row i is True where the neuron belongs to cluster i.
        membership = labels[None, :] == np.arange(cluster_num)[:, None]
        # Follow the module's own device instead of assuming CUDA.
        ffn.patterns = torch.as_tensor(
            membership,
            dtype=torch.get_default_dtype(),
            device=ffn.wi.weight.device,
        )
        ffn.k = k
        ffn.forward_old = ffn.forward
        ffn.forward = types.MethodType(_forward, ffn)

    # The feed-forward sub-layer sits at index 1 in encoder blocks and at
    # index 2 in decoder blocks; the split file is named after its wi weight.
    for stack_name, ffn_index in (('encoder', 1), ('decoder', 2)):
        for layer_idx, layer in enumerate(getattr(model, stack_name).block):
            ffn = layer.layer[ffn_index].DenseReluDense
            path = os.path.join(
                'results/t5-base',
                'param_split',
                '{}.block.{}.layer.{}.DenseReluDense.wi.weight'.format(
                    stack_name, layer_idx, ffn_index
                ),
            )
            modify_ffn(ffn, path)

# Command-line interface: a single required integer option, --seed.
parser = argparse.ArgumentParser(description='T5-base MoEfication')
parser.add_argument(
    '--seed',
    type=int,
    required=True,
    help="Seed used to initialize p-rng",
)
args = parser.parse_args()

# Seed applied to the torch, random and numpy generators before each evaluation pass.
SEED = args.seed

# Probability passed as aug_p to the word-swap ("exchange") augmenter.
noise_level = 0.1

# Vocabulary ids whose first-step logits are compared to produce a prediction:
# a larger logit for _LABEL_1_TOKEN_ID is read as label 1, otherwise label 0.
# NOTE(review): presumably the T5 ids of "positive" / "negative" - confirm.
_LABEL_1_TOKEN_ID = 1465
_LABEL_0_TOKEN_ID = 2841


def _build_augmenter(noise_type):
    """Return the nlpaug augmenter for *noise_type* ("spelling" or "exchange").

    Raises:
        ValueError: If *noise_type* is not one of the supported names.
    """
    if noise_type == "spelling":
        return nac.KeyboardAug(aug_char_min=1, aug_char_max=1, aug_word_min=2, aug_word_max=2)
    if noise_type == "exchange":
        return naw.RandomWordAug(aug_p=noise_level, aug_min=1, aug_max=2, action="swap")
    raise ValueError("Augmenter not supported")


def _noisy_accuracy(noise_type, k):
    """Return validation accuracy on noised inputs with *k* active experts.

    A fresh 't5-base' model is loaded and patched with ``change_forward`` for
    every call, all random generators are re-seeded with ``SEED``, and each
    validation sentence is augmented before being scored.

    Args:
        noise_type: "spelling" or "exchange"; selects the augmenter.
        k: Number of experts kept active per token.

    Returns:
        Fraction of validation instances whose predicted label matches.
    """
    augmenter = _build_augmenter(noise_type)
    sst2_dev = sst2['validation']

    model = T5ForConditionalGeneration.from_pretrained('t5-base').cuda()
    change_forward(model, k)
    model.eval()

    # Seed after model construction so the augmentation stream is the same
    # for every k and noise type.
    torch.manual_seed(SEED)
    random.seed(SEED)
    np.random.seed(SEED)
    torch.backends.cudnn.deterministic = True

    # The decoder prompt never changes, so tokenize it once per pass.
    dec_input_ids = tokenizer("<extra_id_0>", return_tensors="pt").input_ids.cuda()[:, :1]

    pred = []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for instance in sst2_dev:
            noisy_sentence = augmenter.augment(instance['sentence'])[0]
            input_ids = tokenizer("sst2 sentence: " + noisy_sentence, return_tensors="pt").input_ids.cuda()
            output = model(input_ids=input_ids, labels=dec_input_ids)
            label_1_logit = output.logits[:, 0, _LABEL_1_TOKEN_ID].item()
            label_0_logit = output.logits[:, 0, _LABEL_0_TOKEN_ID].item()
            pred.append(int(label_1_logit > label_0_logit) == instance['label'])
    return (sum(pred) * 1.0) / len(pred)


# Word-swap noise sweep over all expert counts.
noisy_exchange_acc = {}
for k in expert_counts:
    noisy_exchange_acc[k] = _noisy_accuracy("exchange", k)
    print("Acc", noisy_exchange_acc[k], 'k', k)

# Keyboard-typo noise sweep over all expert counts.
noisy_spelling_acc = {}
for k in expert_counts:
    noisy_spelling_acc[k] = _noisy_accuracy("spelling", k)
    print("Acc", noisy_spelling_acc[k], 'k', k)

# Persist the {k: accuracy} dictionaries, tagged with the seed.
torch.save(noisy_spelling_acc, "noisy_spelling_t5_sst2_acc_vanilla_"+str(SEED)+".pt")
torch.save(noisy_exchange_acc, "noisy_exchange_t5_sst2_acc_vanilla_"+str(SEED)+".pt")
