import functools
import inspect
from typing import Dict

import fire
import numpy as np
import torch
import torch.nn.functional as F
import transformers
from datasets import load_dataset
from peft import PeftModel
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
# Seed torch's global RNG for reproducibility.
torch.manual_seed(10)

# Minimum number of seconds between tqdm progress-bar refreshes (`mininterval`).
TQDM_MIN_INTER = 5

def prepare_inputs(tokenized_text, device):
    """Move every tensor in a tokenizer output onto ``device``.

    Args:
        tokenized_text: Mapping of input names (e.g. ``input_ids``,
            ``attention_mask``) to tensors.
        device: Target device passed to ``Tensor.to``.

    Returns:
        A new dict with the same keys whose values live on ``device``.
    """
    # NOTE(review): explicit position ids (see ``get_position_ids``) are
    # deliberately not passed to the model, so the model presumably derives
    # its own. For left-padded batches that shifts the absolute positions of
    # real tokens -- harmless for rotary embeddings (relative offsets only),
    # but worth confirming for architectures with absolute positions.
    return {k: v.to(device) for k, v in tokenized_text.items()}

def get_position_ids(attention_mask):
    """Derive per-token position ids from a 0/1 attention mask.

    Unmasked tokens are numbered 0, 1, 2, ... in the order they appear;
    masked (``== 0``) slots get the placeholder value 1.
    """
    running_count = torch.cumsum(attention_mask.long(), dim=-1)
    is_masked = attention_mask == 0
    return (running_count - 1).masked_fill(is_masked, 1)

def prepare_decoder_only_inputs(prompts, targets, tokenizer, device):
    """Build a batch of ``prompt + target`` sequences for a decoder-only LM.

    Prompts are left-padded and targets right-padded, then concatenated along
    the sequence axis, so every prompt ends at the same column and every
    target starts right after it.

    Args:
        prompts: List of prompt strings (tokenized with the tokenizer's
            default special tokens).
        targets: List of target strings aligned with ``prompts`` (tokenized
            without special tokens).
        tokenizer: Hugging Face tokenizer with a pad token set.
        device: Device the returned tensors are moved to.

    Returns:
        ``(inputs, labels)``: ``inputs`` is the dict of concatenated tokenizer
        outputs on ``device``. ``labels`` is a 0/1 tensor shaped like
        ``inputs["input_ids"]``, equal to 1 exactly on non-padding target
        tokens.
    """
    original_padding_side = tokenizer.padding_side
    try:
        tokenizer.padding_side = "left"
        prompt_inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=False)
        tokenizer.padding_side = "right"
        target_inputs = tokenizer(targets, return_tensors="pt", padding=True, truncation=False, add_special_tokens=False)
    finally:
        # Don't leak the padding-side change into the caller's tokenizer.
        tokenizer.padding_side = original_padding_side

    inputs = {k: torch.cat([prompt_inputs[k], target_inputs[k]], dim=1) for k in prompt_inputs}
    inputs = prepare_inputs(inputs, device)

    # The attention mask is already 0 on padding, so keeping only the target
    # columns gives exactly the tokens to score. Do not compare these 0/1
    # values against pad_token_id: when pad_token_id == 1 that zeroes every label.
    prompt_len = prompt_inputs["input_ids"].shape[1]
    labels = inputs["attention_mask"].clone()
    labels[:, :prompt_len] = 0
    return inputs, labels

def get_logprobs(logits, input_ids, attention_mask, **kwargs):
    """Return per-token log-probabilities of ``input_ids`` under ``logits``.

    Entry ``t`` of the result is the log-probability of
    ``input_ids[:, t + 1]`` read from ``logits[:, t]``. Entries where
    ``attention_mask[:, t + 1]`` is 0 are set to 0, so summing over the last
    axis scores only the selected tokens.

    Args:
        logits: Model logits indexed as ``(batch, seq, vocab)``.
        input_ids: Token ids, ``(batch, seq)``.
        attention_mask: 0/1 (or bool) mask, ``(batch, seq)``, marking the
            tokens to score. Callers here pass target-token labels.
        **kwargs: Ignored; accepted for call-site compatibility.

    Returns:
        Float32 tensor of shape ``(batch, seq - 1)``.
    """
    # Drop the last position before the softmax, since it predicts nothing
    # inside input_ids. Compute in float32 because log_softmax and the later
    # sums lose noticeable precision in fp16/bf16.
    logprobs = F.log_softmax(logits[:, :-1], dim=-1, dtype=torch.float32)
    logprobs = torch.gather(logprobs, -1, input_ids[:, 1:, None]).squeeze(-1)
    # Select rather than multiply: a -inf log-prob at a masked position would
    # otherwise become -inf * 0 = NaN.
    keep = attention_mask[:, 1:].bool()
    return logprobs.masked_fill(~keep, 0.0)

def get_logprobs_list(model, tokenizer, questions, answers, bsz):
    """Score each ``questions[k] + answers[k]`` pair with ``model``.

    Args:
        model: Causal LM whose forward output has a ``.logits`` attribute.
        tokenizer: Tokenizer used to build each batch.
        questions: Sequence of prompt strings (numpy array or list).
        answers: Sequence of answer strings aligned with ``questions``.
        bsz: Number of pairs per forward pass.

    Returns:
        List with one summed answer log-probability (numpy float32 scalar)
        per pair, in input order.
    """
    output_logprobs = []
    # Stepping by bsz covers the final partial batch and never yields an
    # empty batch. An empty batch would reach the tokenizer/model whenever
    # len(questions) is a multiple of bsz.
    for start in tqdm(range(0, len(questions), bsz), mininterval=TQDM_MIN_INTER):
        q_batch = [str(q) for q in questions[start:start + bsz]]
        a_batch = [str(a) for a in answers[start:start + bsz]]
        inputs, masks = prepare_decoder_only_inputs(q_batch, a_batch, tokenizer, model.device)
        with torch.no_grad():
            logits = model(**inputs).logits
            # Upcast before summing so the per-answer total isn't accumulated
            # in half precision.
            logprobs = get_logprobs(logits, inputs['input_ids'], masks).float().sum(-1)
        output_logprobs.extend(logprobs.cpu().numpy())
    return output_logprobs

def calc_mc1_score(output_logprobs, answers, labels):
    """Compute MC1 accuracy, both raw and length-normalized.

    ``output_logprobs`` and ``answers`` are flat, with one entry per answer
    choice. ``labels`` holds one 0/1 list per question, and the first 1 in
    each list marks the correct choice. The choices of consecutive questions
    sit back to back in the flat sequences.

    A question counts as correct when its highest-scoring choice is the
    labelled one. ``acc_norm`` ranks choices by log-probability divided by
    the answer string's character length.

    Returns:
        ``{'acc': ..., 'acc_norm': ...}``: means over all questions.
    """
    hits, hits_norm = [], []
    offset = 0
    for choice_labels in tqdm(labels):
        stop = offset + len(choice_labels)
        scores = output_logprobs[offset:stop]
        char_lens = np.array([float(len(text)) for text in answers[offset:stop]])
        gold = choice_labels.index(1)
        hits.append(np.argmax(scores) == gold)
        hits_norm.append(np.argmax(scores / char_lens) == gold)
        offset = stop
    return {'acc': np.mean(hits), 'acc_norm': np.mean(hits_norm)}


def calc_one_m2m3(one_question_logprobs, one_question_labels):
    """Compute the MC2 and MC3 scores for a single question.

    Args:
        one_question_logprobs: Summed log-probability of each answer choice.
        one_question_labels: Matching labels; ``> 0`` marks a true answer.

    Returns:
        ``(mc2, mc3)`` as Python floats:

        * mc2: probability mass on the true answers, normalized over all
          choices. It is 0.0 if there are no choices or the best score is
          not finite.
        * mc3: fraction of true answers scoring strictly above the best
          false answer. It is 0.0 if there are no true answers.
    """
    scores_true, scores_false = [], []
    for logprob, lb in zip(one_question_logprobs, one_question_labels):
        (scores_true if lb > 0 else scores_false).append(logprob)

    # Use float64 and shift by the max score. exp() of a raw summed log-prob
    # underflows to 0 for long answers (below about -103 in float32, -745 in
    # float64), which made mc2 collapse to 0. The ratio is unchanged by the shift.
    true_arr = np.asarray(scores_true, dtype=np.float64)
    false_arr = np.asarray(scores_false, dtype=np.float64)
    all_arr = np.concatenate([true_arr, false_arr])

    mc2 = 0.0
    if all_arr.size:
        shift = all_arr.max()
        if np.isfinite(shift):
            mass_true = np.exp(true_arr - shift).sum()
            # The max element contributes exp(0) = 1, so mass_all >= 1.
            mass_all = mass_true + np.exp(false_arr - shift).sum()
            mc2 = float(mass_true / mass_all)

    max_false = false_arr.max() if false_arr.size else -np.inf
    mc3 = float(np.mean(true_arr > max_false)) if true_arr.size else 0.0
    return mc2, mc3

def calc_mc2_mc3_score(output_logprobs, labels):
    """Average the per-question MC2 and MC3 scores over the whole dataset.

    ``output_logprobs`` is flat, with one entry per answer choice. ``labels``
    holds one label list per question, and each question consumes the next
    ``len(labels[q])`` entries of ``output_logprobs``.

    Returns:
        ``{'mc2': mean MC2, 'mc3': mean MC3}``.
    """
    per_question = []
    cursor = 0
    for question_labels in tqdm(labels):
        width = len(question_labels)
        per_question.append(
            calc_one_m2m3(output_logprobs[cursor:cursor + width], question_labels)
        )
        cursor += width
    mc2_values = [pair[0] for pair in per_question]
    mc3_values = [pair[1] for pair in per_question]
    return {'mc2': np.mean(mc2_values), 'mc3': np.mean(mc3_values)}


def load_tqa_sentences(dataset, user_tag, assistant_tag, targets="mc1_targets"):
    """Flatten TruthfulQA multiple-choice records into prompt/answer pairs.

    Every answer choice of every record becomes one pair. The prompt is
    ``"<user_tag> <question> "`` and the answer is ``"<assistant_tag> <choice>"``.

    Args:
        dataset: Iterable of records, each with a ``'question'`` string and a
            ``targets`` entry holding ``'choices'`` and ``'labels'`` lists.
        user_tag: Prefix for every prompt.
        assistant_tag: Prefix for every answer.
        targets: Which target field to read, e.g. ``"mc1_targets"`` or
            ``"mc2_targets"``.

    Returns:
        ``(questions, answers, labels)``: two flat numpy arrays with one entry
        per choice, and a list with one label list per record.
    """
    questions, answers, labels = [], [], []
    for record in dataset:
        prompt = f'{user_tag} ' + record['question'] + ' '
        target = record[targets]
        n_choices = len(target['labels'])
        questions.extend([prompt] * n_choices)
        answers.extend(f'{assistant_tag} ' + target['choices'][k] for k in range(n_choices))
        labels.append(target['labels'])
    return np.array(questions), np.array(answers), labels


def test_tqa_prob(model, tokenizer, user_tag, assistant_tag, batch_size=32):
    """Evaluate ``model`` on the TruthfulQA multiple-choice validation split.

    Loads ``truthful_qa`` / ``multiple_choice`` with ``datasets``. It scores
    every answer choice of both the MC1 and the MC2 target sets, then
    reduces the scores to metrics.

    Returns:
        ``{"mc1": ..., "mc2": ..., "mc3": ...}``.
    """
    print('Loading dataset...')
    dataset = load_dataset('truthful_qa', 'multiple_choice')['validation']
    mc1_questions, mc1_answers, mc1_labels = load_tqa_sentences(dataset, user_tag, assistant_tag)
    mc2_questions, mc2_answers, mc2_labels = load_tqa_sentences(dataset, user_tag, assistant_tag, targets='mc2_targets')

    model.eval()
    print('Evaluating TQA accuracy...')
    metrics = {}
    with torch.no_grad():
        mc1_logprobs = get_logprobs_list(model, tokenizer, mc1_questions, mc1_answers, bsz=batch_size)
        metrics["mc1"] = calc_mc1_score(mc1_logprobs, mc1_answers, mc1_labels)['acc']

        mc2_logprobs = get_logprobs_list(model, tokenizer, mc2_questions, mc2_answers, bsz=batch_size)
        mc2_mc3 = calc_mc2_mc3_score(mc2_logprobs, mc2_labels)
        metrics["mc2"] = mc2_mc3["mc2"]
        metrics["mc3"] = mc2_mc3["mc3"]

    return metrics

def load_fixed_model(base_model_path, ft_fixed_path, device):
    """Load a float16 base causal LM and attach the PEFT adapter at ``ft_fixed_path``.

    Returns:
        The ``PeftModel`` wrapping the base model.
    """
    return PeftModel.from_pretrained(
        AutoModelForCausalLM.from_pretrained(
            base_model_path,
            device_map=device,
            torch_dtype=torch.float16,
            trust_remote_code=True,
        ),
        ft_fixed_path,
    )

def test_run(base_model_path='./models/Llama-2-7b-chat-hf', device="auto"):
    """Smoke-test the TQA evaluation on a local checkpoint loaded in float16.

    Args:
        base_model_path: Path or hub id of the model and its tokenizer.
        device: ``device_map`` forwarded to ``from_pretrained``.

    Returns:
        The metrics dict from ``test_tqa_prob``, which is also printed.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Batched tokenization needs a pad token; reuse EOS when there is none.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(base_model_path, device_map=device, torch_dtype=torch.float16, trust_remote_code=True)

    metrics = test_tqa_prob(model, tokenizer, user_tag="[INST]", assistant_tag="[/INST]")
    print("===TQA Eval results===")
    print(metrics)
    return metrics


def log_args(func):
    """Decorator that prints the bound arguments of every call to ``func``.

    Defaults are filled in before printing, so every parameter is listed.
    ``functools.wraps`` keeps ``func``'s name, docstring and ``__wrapped__``
    on the wrapper. ``inspect.signature`` follows ``__wrapped__``, so
    introspection-based CLIs such as ``fire.Fire`` still see the real
    parameter list.
    """
    sig = inspect.signature(func)  # computed once, not on every call

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound_args = sig.bind(*args, **kwargs)
        bound_args.apply_defaults()
        print("Parameters Value：")
        for name, value in bound_args.arguments.items():
            print(f" {name}: {value}")
        print("----------------------------------------\n")
        return func(*args, **kwargs)

    return wrapper

@log_args
def main(base_model_path, peft=None, device="auto"):
    """CLI entry point: evaluate a model, optionally PEFT-adapted, on TruthfulQA MC.

    Args:
        base_model_path: Path or hub id of the base model and its tokenizer.
        peft: Optional path of a PEFT adapter to load on top of the base model.
        device: ``device_map`` forwarded to ``from_pretrained``.
    """
    tokenizer = AutoTokenizer.from_pretrained(base_model_path, trust_remote_code=True)
    existing_pad = tokenizer.pad_token
    # Fall back to EOS as the pad token so batched tokenization can pad.
    tokenizer.pad_token = existing_pad if existing_pad is not None else tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(base_model_path, device_map=device, torch_dtype=torch.bfloat16, trust_remote_code=True)
    if peft is not None:
        model = PeftModel.from_pretrained(model, peft)

    metrics = test_tqa_prob(model, tokenizer, user_tag="[INST]", assistant_tag="[/INST]")
    print("===TQA Eval results===")
    print(metrics)
    # Nothing is returned on purpose: fire.Fire prints a command's return value,
    # which would show the metrics a second time.


if __name__ == "__main__":
    # Expose `main` as a command-line interface via python-fire.
    # For a quick local smoke test, call `test_run()` here instead.
    fire.Fire(component=main)