import logging
import os
import re

import torch
from datasets import load_from_disk
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_eos_str(name):
    """Return the end-of-turn marker string for the model family named in *name*.

    Matching is a case-insensitive substring test performed in a fixed order
    (llama, phi, mistral, gemma); the first family found in *name* wins.

    Raises:
        ValueError: if *name* contains none of the known family names.
    """
    lowered = name.lower()
    # Ordered pairs: order matters when a name happens to contain two families.
    family_markers = (
        ('llama', "<|eot_id|>"),
        ('phi', "<|end|>"),
        ('mistral', "[/INST]"),
        ('gemma', "<end_of_turn>"),
    )
    for family, marker in family_markers:
        if family in lowered:
            return marker
    raise ValueError(f'no eos_string for model {name}')


class scaled_module(torch.nn.Module):
    """Wrap a layer and rescale its hidden-state output over a token span.

    On a multi-token forward pass (``x.shape[1] > 1``), the hidden states
    produced by the wrapped layer at positions ``[init_len, tep_len)`` are
    divided by a per-position scale:

    * for every wrapped layer except the last (``id_ < total - 1``) the scale
      is ``1 + noise_scale * U[0, 1)``, drawn fresh from the global torch RNG
      on every call;
    * for the last wrapped layer the scale is ``1 / final_scale``, so the span
      is multiplied by ``final_scale``.

    Passes with a single position along dim 1 are returned unscaled.

    Args:
        module: the layer to wrap. Its first output (or its only output, if it
            returns a bare tensor) is the hidden-state tensor that is rescaled.
        id_: index of this layer among the wrapped layers.
        total: number of wrapped layers.
        init_len: first position (inclusive) of the rescaled span.
        tep_len: end position (exclusive) of the rescaled span.
        name: model name used by :meth:`get_final_scale`.

    Raises:
        ValueError: if *name* matches no known scale setting.
    """

    def __init__(self, module, id_, total, init_len, tep_len, name='llama'):
        super().__init__()
        self.module = module
        self.id = id_
        self.total = total
        self.init_len = init_len
        self.tep_len = tep_len
        # Placeholder values; get_final_scale overwrites both or raises.
        self.final_scale = 1
        self.noise_scale = 0.5
        self.get_final_scale(name)

    def get_final_scale(self, name):
        """Set ``final_scale`` and ``noise_scale`` from the model *name*.

        Case-insensitive substring match. ``'70b'`` is tested before the
        family names, so any 70B model gets the 70B setting.

        Raises:
            ValueError: if *name* matches no known setting.
        """
        if '70b' in name.lower():
            self.final_scale = 2.0
            self.noise_scale = 0.425
        elif 'llama' in name.lower():
            self.final_scale = 1.0
            self.noise_scale = 0.2
        elif 'phi' in name.lower():
            self.final_scale = 2.0
            self.noise_scale = 0.5
        elif 'mistral' in name.lower():
            self.final_scale = 1.5
            self.noise_scale = 0.3
        elif 'gemma' in name.lower():
            self.final_scale = 1.5
            self.noise_scale = 0.35
        else:
            raise ValueError(f'no model scale setting for model {name}')

    def get_scales(self, x):
        """Return divisors for *x* with shape ``(x.shape[0], x.shape[1], 1)``.

        Assumes *x* is laid out as (batch, seq, hidden) — TODO confirm for
        every wrapped architecture. The divisors match *x*'s dtype/device and
        broadcast over the last dimension.
        """
        scales = torch.ones_like(x[:, :, 0])
        if x.shape[1] > 1:
            span = slice(self.init_len, self.tep_len)
            if self.id < self.total - 1:
                # Intermediate layers: random per-position perturbation.
                scales[:, span] = 1 + self.noise_scale * torch.rand_like(scales[:, span])
            else:
                # Last wrapped layer: deterministic amplification by final_scale.
                scales[:, span] = 1 / self.final_scale
        return scales[:, :, None]

    def forward(self, x, *args, **kwargs):
        """Run the wrapped layer and rescale its hidden-state output.

        Positional and keyword arguments are forwarded unchanged, so callers
        that pass layer arguments positionally also work.

        Returns:
            A tensor if the wrapped layer returned a tensor. Otherwise a tuple
            whose first element is rescaled and whose other elements pass
            through untouched.
        """
        out = self.module(x, *args, **kwargs)
        if isinstance(out, torch.Tensor):
            # Some layer implementations may return the hidden states directly
            # rather than a tuple; iterating such a tensor would split it
            # along the batch dimension.
            return out / self.get_scales(out)
        out = list(out)
        hidden = out[0]
        # Out-of-place so the tensor produced by the wrapped layer is not mutated.
        out[0] = hidden / self.get_scales(hidden)
        return tuple(out)


class LLM_for_BBH:
    """Few-shot chain-of-thought evaluator for BIG-Bench Hard (BBH) subsets.

    The model is prompted with ``prompt_pre``, an intro line, the few-shot
    examples (see :meth:`enable_nshots_prompt`), the question and an
    "A: Let's think step by step." cue. A completion counts as correct when it
    contains ``"So the answer is <target>"`` with no word character directly
    after the target.

    With ``privacy=True`` the script wraps the first ``TOTAL_B`` decoder
    layers in ``scaled_module`` via :meth:`model_privacy_augment`. Those
    wrappers rescale hidden states of prompt positions ``[INIT_LEN, TEP_LEN)``,
    roughly the intro and examples that follow ``prompt_pre``.
    :meth:`remove_privacy_blocks` restores the original layers.
    """

    # Number of leading decoder layers wrapped by model_privacy_augment().
    TOTAL_B = 10

    def __init__(self, model, tokenizer, model_name='', privacy=False):
        """Store model and tokenizer and precompute the prompt prefix.

        Args:
            model: causal LM exposing ``generate`` and ``model.model.layers``.
            tokenizer: tokenizer matching *model*.
            model_name: selects the end-of-turn string (and the scale settings
                used when layers are wrapped).
            privacy: if True, :meth:`enable_nshots_prompt` tracks ``TEP_LEN``.

        Raises:
            ValueError: from ``get_eos_str`` for an unknown model family.
        """
        self.model = model
        self.model.eval()
        self.tokenizer = tokenizer
        self.model_name = model_name
        self.enable_privacy = privacy
        self.eot_str = get_eos_str(model_name)
        # Stop on either the tokenizer's EOS or the model's end-of-turn token.
        # convert_tokens_to_ids may return None for an unknown token on
        # tokenizers without an unk token; a None id would break generate().
        self.terminators = [
            tok_id
            for tok_id in (
                self.tokenizer.eos_token_id,
                self.tokenizer.convert_tokens_to_ids(self.eot_str),
            )
            if tok_id is not None
        ]
        self.prompt_pre = f"You are a helpful and concise assistant. You are required to answer the given question. The final answer should be given with 'So the answer is ' followed by the correct answer, LIKE SO 'So the answer is True. {self.eot_str}', OR 'So the answer is No. {self.eot_str}', OR 'So the answer is (D). {self.eot_str}'.\n\n"
        # Token count of the fixed prefix (including any special tokens the
        # tokenizer adds) = start of the rescaled span.
        self.INIT_LEN = self.tokenizer([self.prompt_pre], return_tensors="pt").input_ids.shape[1]
        self.prompt_ = ''
        self.TEP_LEN = 1

    def model_privacy_augment(self):
        """Wrap the first ``TOTAL_B`` decoder layers in ``scaled_module``.

        Layers that are already wrapped are not wrapped again; only their
        ``tep_len`` is refreshed.
        """
        print(f"**ENABLE PRIVACY BLOCK, total {self.TEP_LEN} tokens are preserved**")
        layers = self.model.model.layers
        for i in range(self.TOTAL_B):
            layer = layers[i]
            if isinstance(layer, scaled_module):
                layer.tep_len = self.TEP_LEN
                continue
            layers[i] = scaled_module(layer, i, self.TOTAL_B, self.INIT_LEN, self.TEP_LEN, self.model_name)

    def remove_privacy_blocks(self):
        """Restore the original layers for any of the first ``TOTAL_B`` that are wrapped."""
        print("**REMOVE PRIVACY BLOCK**")
        layers = self.model.model.layers
        for i in range(self.TOTAL_B):
            layer = layers[i]
            if isinstance(layer, scaled_module):
                layers[i] = layer.module

    def enable_nshots_prompt(self, example_datasets, nshots=3):
        """Build ``prompt_`` from ``prompt_pre`` and *nshots* examples.

        When privacy is enabled, ``TEP_LEN`` becomes the token length of the
        full few-shot prefix, and any currently wrapped layers are updated to
        that span.
        """
        self.prompt_ = self.prompt_pre
        self.prompt_ += "Here are some examples about the interactions between question Q and assistant A:\n\n"
        self.prompt_ += self.examples_add_terminator(example_datasets, nshots=nshots)
        if self.enable_privacy:
            self.TEP_LEN = self.tokenizer([self.prompt_], return_tensors="pt").input_ids.shape[1]
            # Only touch wrapped layers; setting tep_len on a raw decoder
            # layer would merely attach an unused attribute to it.
            for layer in self.model.model.layers[:self.TOTAL_B]:
                if isinstance(layer, scaled_module):
                    layer.tep_len = self.TEP_LEN

    def examples_add_terminator(self, example, nshots):
        """Return the first *nshots* examples of *example*, each ending in ``eot_str``.

        *example* is split on blank lines. The first section (presumably the
        task description) is kept without a terminator. Each of the next
        *nshots* sections gets ``" <eot_str>"`` appended. Every kept section
        is followed by a blank line.

        Raises:
            ValueError: if *nshots* is negative or *example* has fewer than
                ``nshots + 1`` sections.
        """
        sections = example.split('\n\n')
        if nshots < 0 or len(sections) < nshots + 1:
            raise ValueError(
                f'need {nshots + 1} blank-line separated sections '
                f'(header + {nshots} examples), got {len(sections)}'
            )
        header, shots = sections[0], sections[1:nshots + 1]
        return header + '\n\n' + ''.join(f'{shot} {self.eot_str}\n\n' for shot in shots)

    def get_all_prompt(self, prompt):
        """Return the few-shot prefix followed by *prompt*."""
        return self.prompt_ + prompt

    def generate(self, prompt):
        """Generate up to 512 new tokens for *prompt* and return the decoded text.

        Inputs are placed on ``self.model.device``. Generation stops at any
        id in ``terminators``.
        """
        prompt = self.get_all_prompt(prompt)
        prompt_input = f"{prompt}\nA: Let's think step by step.\n"
        model_inputs = self.tokenizer([prompt_input], return_tensors="pt").to(self.model.device)
        with torch.no_grad():
            generated_ids = self.model.generate(**model_inputs, max_new_tokens=512, eos_token_id=self.terminators)
        # Drop the echoed prompt tokens, keeping only the continuation.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        return self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    @staticmethod
    def _answer_matches(result, target):
        """Return True if *result* contains ``So the answer is <target>`` as a whole answer.

        The negative lookahead rejects a target that is only a prefix of a
        longer token, e.g. target ``1`` against ``So the answer is 12``.
        """
        pattern = re.escape(f"So the answer is {target}") + r"(?!\w)"
        return re.search(pattern, result) is not None

    def evaluate(self, dataset, sub='unknown'):
        """Score every example in *dataset* and log the subset accuracy.

        Each example is read via ``data['input']`` and ``data['target']``.
        An answer without any ``So the answer is`` phrase counts as a fail.

        Returns:
            ``(accuracy, fails, acc_count, total)``. Accuracy is 0.0 for an
            empty dataset.
        """
        total = len(dataset)
        acc_count = 0
        fails = 0
        pbar = tqdm(range(total))
        for i in pbar:
            data = dataset[i]
            result = self.generate('Q: ' + data['input'])
            if self._answer_matches(result, data['target']):
                acc_count += 1
            elif 'So the answer is' not in result:
                fails += 1
            pbar.set_postfix_str(f"acc: {round(acc_count/(i+1),2)} | fails: {fails}")
        accuracy = acc_count / total if total else 0.0
        # The 'bbh' logger is the one the script attaches its file handler to;
        # fetching it here avoids depending on a global defined only under __main__.
        logging.getLogger('bbh').info(f"task: {sub} | accuracy: {accuracy} | fails: {fails}")
        return accuracy, fails, acc_count, total


if __name__ == "__main__":
    # Local checkpoints; pick one via model_path below.
    phi_path = "../pretrained_models/Phi-3-medium-128k-instruct"
    gemma_path = "../pretrained_models/gemma-2-9b-it"
    llama_path = "../pretrained_models/Meta-Llama-3-8B-Instruct"
    mistral_path = "../pretrained_models/Mistral-7B-Instruct-v0.3"
    llama70_path = "../pretrained_models/Meta-Llama-3-70B-Instruct-AWQ"

    model_path = gemma_path
    enable_privacy = True

    bbh_root = '../dataset_save_disk/bbh/'
    log_dir = './bbh_eva'

    # Module-level logger; the evaluator writes per-task results to 'bbh'.
    log = logging.getLogger('bbh')
    log.setLevel(level=logging.DEBUG)
    model_name = model_path.split('/')[-1]
    log_name = model_name.split('-')[0].lower()
    file_name = f'bbh_{log_name}_privacy.log' if enable_privacy else f'bbh_{log_name}.log'
    # FileHandler does not create missing directories.
    os.makedirs(log_dir, exist_ok=True)
    handler = logging.FileHandler(os.path.join(log_dir, file_name), encoding='utf-8', mode='w')
    handler.setLevel(logging.INFO)
    log.addHandler(handler)

    model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto', torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.generation_config.pad_token_id = tokenizer.eos_token_id

    llm_bbh = LLM_for_BBH(model=model, tokenizer=tokenizer, model_name=model_name, privacy=enable_privacy)

    all_sub = sorted(os.listdir(bbh_root))
    acc_count, total = 0, 0
    print(f'start evaluating {model_name}!')
    for i, sub in enumerate(all_sub):
        print(f"--------No. {i+1}/{len(all_sub)} | {sub}--------")
        dataset = load_from_disk(os.path.join(bbh_root, sub))
        # The few-shot CoT examples are the second '-----'-delimited section.
        with open(f'bbh-cot-prompts/{sub}.txt', 'r', encoding='utf-8') as f:
            examples = f.read().split('\n-----\n')[1]
        llm_bbh.enable_nshots_prompt(example_datasets=examples, nshots=3)
        if enable_privacy:
            llm_bbh.model_privacy_augment()
        acc, fails, acc_count_, total_ = llm_bbh.evaluate(dataset["test"], sub=sub)
        acc_count += acc_count_
        total += total_
        overall = acc_count / total if total else 0.0
        print(f"task: {sub} | accuracy: {acc} | fails: {fails}")
        print(f"total accuracy: {overall} | acc_count: {acc_count} | total test: {total}")
        if enable_privacy:
            llm_bbh.remove_privacy_blocks()

    overall = acc_count / total if total else 0.0
    log.info(f"total accuracy: {overall} | acc_count: {acc_count} | total test: {total}")
    logging.shutdown()
