import json
import os, sys
import re
import time 
from tqdm import tqdm
import argparse
from pathlib import Path
from typing import Tuple
import pandas as pd
from datasets import load_dataset
from utils import * 
from transformers import AutoTokenizer, AutoModelForCausalLM
import random
from torch import tensor
from utils import gather_log_probs, mask_hf_labels, masked_mean


# Question templates used to ask the model about a topic; each "{}" is filled
# with the topic name when prompts are built.
templates = [
    "What do you think of {}?",
    "What do you feel about {}?",
    "How do you view {}?",
] + [
    f"What is your {noun} {{}}?"
    for noun in (
        "opinion of",
        "stance on",
        "position on",
        "attitude about",
        "view on",
        "take on",
        "impression of",
        "assessment of",
        "judgment of",
        "sentiment of",
    )
]


# Personality traits under evaluation. A trait's position in this list is the
# integer id used for `target_per` and for the per-text trait ids.
per_list = ["extraversion", "agreeableness", "neuroticism"]

def _read_text(path):
    """Return the whole contents of the UTF-8 text file at ``path`` (handle is closed)."""
    with open(path, "r", encoding="utf-8") as handle:
        return handle.read()


def _edit_labels(input_ids, pad_token_id):
    """Return a copy of ``input_ids`` with every padding position replaced by -100."""
    return input_ids.masked_fill(input_ids == pad_token_id, -100)


def _answer_mask(input_ids, sep_token_id):
    """Return a bool mask that is False for every position before the first separator.

    Each row is handled independently; positions from the first ``sep_token_id``
    onwards (the separator itself included) are True.
    NOTE(review): an earlier inline comment suggested the separator should be
    masked out as well, but only positions strictly before it are masked —
    confirm which is intended.

    Raises:
        ValueError: if some row contains no separator (e.g. it was truncated away).
    """
    is_sep = input_ids == sep_token_id
    if not bool(is_sep.any(dim=-1).all()):
        raise ValueError(
            f"every tokenized prompt must contain the separator token (id {sep_token_id})"
        )
    # argmax returns the index of the first maximal value, i.e. the first separator.
    first_sep = is_sep.int().argmax(dim=-1, keepdim=True)
    positions = torch.arange(input_ids.shape[-1], device=input_ids.device)
    # Build a fresh (batch, seq) tensor; rows never share storage.
    return positions.unsqueeze(0) >= first_sep


def format_data(tokenizer, method_type="prompt", device="cuda"):
    """Build the tokenized evaluation examples for personality editing.

    Reads ``./data/personality/test.json`` (a list of records, each with an
    ``"ent"`` topic and, for every trait in ``per_list``, a list of answer
    texts) and the prompt templates under ``./prompt_lib``. For each record it
    draws a random target trait and a random *other* record. It then builds four
    prompt batches with one row per answer text:

    * ``pre``        -- plain prompt: question about the topic + answer.
    * ``edit``       -- editing prompt (target trait, topic) + the same question/answer.
    * ``outer_pre``  -- plain prompt about the other record's topic and answers.
    * ``outer_edit`` -- editing prompt for this record's topic, asked about the
      other record's topic and answers.

    Args:
        tokenizer: Hugging Face tokenizer. Every tokenized prompt must contain
            the id that ``tokenizer.convert_tokens_to_ids("</s>")`` returns.
        method_type: ``"prompt"`` uses the bare instructions. Any other value
            also appends the demonstration files.
        device: device the returned tensors are moved to (default ``"cuda"``).

    Returns:
        A list with one dict per record. Each dict holds:

        * ``target_per`` -- an index into ``per_list``.
        * ``target_per_text`` and ``topic``.
        * the four batches under ``pre_prompt`` / ``edit_prompt`` /
          ``outer_pre_prompt`` / ``outer_edit_prompt``. Each batch has
          ``input_ids``, ``attention_mask``, ``labels`` (padding set to -100)
          and a boolean ``q_mask`` that is False before the first ``"</s>"``.
        * ``same_per_mask`` -- True for rows whose answer belongs to the target trait.

    Raises:
        ValueError: in any of these cases:

        * the data holds exactly one record, so no other record exists;
        * a record has more answer texts than there are ``templates``;
        * a prompt lacks the ``"</s>"`` separator after truncation to 512 tokens.
    """
    with open("./data/personality/test.json", "r", encoding="utf-8") as handle:
        data = json.load(handle)

    pre_case_format = _read_text("./prompt_lib/pre_full_case.txt")
    edit_case_format = _read_text("./prompt_lib/edit_full_case.txt")
    pre_prompt = _read_text("./prompt_lib/pre_instruction.txt")
    edit_prompt = _read_text("./prompt_lib/edit_instruction.txt")
    if method_type != "prompt":
        # Non-"prompt" methods put in-context demonstrations after the instruction.
        pre_prompt += _read_text("./prompt_lib/pre_demo.txt")
        edit_prompt += _read_text("./prompt_lib/edit_demo.txt")

    sep_token_id = tokenizer.convert_tokens_to_ids("</s>")

    def tokenize(texts):
        """Tokenize one prompt batch and attach labels and the answer mask."""
        toks = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            max_length=512,
            truncation=True,
        )
        input_ids = toks["input_ids"]
        return {
            "input_ids": input_ids.to(device),
            "attention_mask": toks["attention_mask"].to(device),
            "labels": _edit_labels(input_ids, tokenizer.pad_token_id).to(device),
            "q_mask": _answer_mask(input_ids, sep_token_id).to(device),
        }

    examples = []
    for edit_idx, mention in tqdm(enumerate(data), desc="describe data", total=len(data)):
        target_per = random.choice(range(len(per_list)))
        target_per_text = per_list[target_per]
        topic = mention["ent"]

        # Pick an index other than edit_idx uniformly, without rebuilding the
        # O(len(data)) candidate list on every iteration.
        outer_idx = random.randrange(len(data) - 1)
        if outer_idx >= edit_idx:
            outer_idx += 1
        outer_mention = data[outer_idx]
        # Bug fix: this used to read mention["ent"], so the "outer" prompts asked
        # about the edited topic itself instead of the sampled unrelated one.
        outer_topic = outer_mention["ent"]

        all_per_idxs, all_per_texts, outer_per_texts = [], [], []
        for idx, per in enumerate(per_list):
            all_per_idxs += [idx] * len(mention[per])
            all_per_texts += mention[per]
            outer_per_texts += outer_mention[per]

        # One distinct question template per answer text. zip() below truncates
        # the outer batches to len(questions) if the other record has more texts.
        questions = random.sample(templates, len(all_per_texts))

        pre_inputs = [
            pre_prompt + pre_case_format.format(q.format(topic), text)
            for q, text in zip(questions, all_per_texts)
        ]
        edit_inputs = [
            edit_prompt + edit_case_format.format(target_per_text, topic, q.format(topic), text)
            for q, text in zip(questions, all_per_texts)
        ]
        outer_pre_inputs = [
            pre_prompt + pre_case_format.format(q.format(outer_topic), text)
            for q, text in zip(questions, outer_per_texts)
        ]
        outer_edit_inputs = [
            edit_prompt + edit_case_format.format(target_per_text, topic, q.format(outer_topic), text)
            for q, text in zip(questions, outer_per_texts)
        ]

        examples.append({
            "target_per": target_per,
            "target_per_text": target_per_text,
            "topic": topic,
            "pre_prompt": tokenize(pre_inputs),
            "edit_prompt": tokenize(edit_inputs),
            "outer_pre_prompt": tokenize(outer_pre_inputs),
            "outer_edit_prompt": tokenize(outer_edit_inputs),
            # True where the answer text belongs to the target trait.
            "same_per_mask": torch.tensor(
                [idx == target_per for idx in all_per_idxs], device=device
            ),
        })

    return examples


def es_per(example, pre_logits, edit_logits):
    """Compute the personality edit-success score for one example.

    The "positive" rows are the answer tokens (``q_mask``) of rows whose text
    belongs to the target trait (``same_per_mask``); the remaining rows are
    "negative". Using mean token scores from ``gather_log_probs``, the function
    computes:

    * ``z_per   = sigmoid(mean_pos_edit - mean_neg_edit)``
    * ``z_topic = min(1, exp(mean_pos_edit - mean_pos_pre))``
    * ``acc_per = z_per * z_topic``

    Args:
        example: one entry produced by ``format_data``.
        pre_logits: model logits for ``example["pre_prompt"]``.
        edit_logits: model logits for ``example["edit_prompt"]``.

    Returns:
        A dict with keys ``acc_per``, ``z_per``, ``z_topic``, ``z_topic_raw``,
        ``correct_probs`` (``mean_pos_edit``) and ``wrong_probs``
        (``mean_neg_edit``). All values are tensors; ``z_topic`` no longer
        switches to a Python int when it saturates at 1.
    """
    with torch.no_grad():
        pre, edit = example["pre_prompt"], example["edit_prompt"]
        same_per_mask = example["same_per_mask"].unsqueeze(-1)

        _, pre_targ = mask_hf_labels(pre["labels"])
        _, edit_targ = mask_hf_labels(edit["labels"])

        pre_pos_mask = same_per_mask * pre["q_mask"]
        edit_pos_mask = same_per_mask * edit["q_mask"]
        edit_neg_mask = (~same_per_mask) * edit["q_mask"]

        # NOTE(review): logits and targets are paired position-by-position. If
        # gather_log_probs does not apply the causal one-token shift itself,
        # each logit row scores its own input token rather than the next one —
        # confirm in utils.
        pre_token_log_probs = gather_log_probs(pre_logits, pre_targ)
        edit_token_log_probs = gather_log_probs(edit_logits, edit_targ)

        mean_pos_pre = masked_mean(pre_token_log_probs, pre_pos_mask)
        mean_pos_edit = masked_mean(edit_token_log_probs, edit_pos_mask)
        mean_neg_edit = masked_mean(edit_token_log_probs, edit_neg_mask)

        z_per = (mean_pos_edit - mean_neg_edit).sigmoid()
        z_topic_raw = (mean_pos_edit - mean_pos_pre).exp()
        # clamp keeps a tensor result; min(1, tensor) returned the int 1 when saturated.
        z_topic = z_topic_raw.clamp(max=1)

        score = z_per * z_topic
        return {
            "acc_per": score,
            "z_per": z_per,
            "z_topic": z_topic,
            "z_topic_raw": z_topic_raw,
            "correct_probs": mean_pos_edit,
            "wrong_probs": mean_neg_edit,
        }

def kl_loc_loss(pre, post, mask=None):
    """Return the KL divergence KL(pre || post) between two sets of logits.

    Two input layouts are supported:

    * Sequence logits (3-D, ``(batch, seq, vocab)`` with vocab > 1):
      ``post`` is aligned with ``pre`` on its *last* ``pre.shape[1]`` positions.
      The per-position KL between the softmax distributions is averaged over
      the positions where ``mask`` is true. ``mask`` is required and must
      reshape to ``batch * seq`` elements. An all-false mask yields NaN (0/0).
    * Binary logits (non-3-D with a trailing dimension of 1): returns the mean
      Bernoulli KL between ``sigmoid(pre)`` and ``sigmoid(post)``.

    Any other layout raises ``NotImplementedError``.

    Raises:
        ValueError: if sequence logits are given without a ``mask``.
        NotImplementedError: for unsupported input layouts.
    """
    pre = pre.to(torch.float32).contiguous()
    post = post.to(torch.float32).contiguous()

    sequence = pre.dim() == 3
    if sequence:
        # Only sequence logits have a position axis to align on. Slicing
        # unconditionally broke 2-D (binary) inputs with an IndexError.
        post = post[:, -pre.shape[1]:, :]

    pre_ = pre.reshape(-1, pre.shape[-1])
    post_ = post.reshape(pre_.shape)

    if not sequence:
        if pre_.shape[-1] == 1:  # No masking needed for binary classification
            logsigmoid = torch.nn.functional.logsigmoid
            return (pre_.sigmoid() * (logsigmoid(pre_) - logsigmoid(post_))).mean() + (
                (-pre_).sigmoid() * (logsigmoid(-pre_) - logsigmoid(-post_))
            ).mean()
    elif pre_.shape[-1] > 1:  # Sequences of predictions; masking needed
        if mask is None:
            raise ValueError("kl_loc_loss requires a mask for sequence logits")
        mask_ = mask.reshape(pre_.shape[0]).to(pre_.dtype)
        kl = (pre_.softmax(-1) * (pre_.log_softmax(-1) - post_.log_softmax(-1))).sum(-1)
        return (kl * mask_).sum() / mask_.sum()

    raise NotImplementedError


def _forward_logits(model, inputs):
    """Run ``model`` on one tokenized prompt batch and return its logits.

    Labels are not passed because only the logits are used; passing them would
    only add a loss computation.
    """
    return model(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
    )["logits"]


def _mean_of(values):
    """Return the arithmetic mean of scalar tensors/numbers as a float (NaN if empty)."""
    if not values:
        return float("nan")
    # float() works for 0-d tensors on any device as well as for plain numbers.
    return sum(float(value) for value in values) / len(values)


def main(ckpt_dir: str, output_dir=None, method_type=None, prefix=None):
    """Score a checkpoint on the personality-editing data and write ES/DD to JSON.

    For every example from ``format_data`` it runs four forward passes (pre,
    edit, outer_pre, outer_edit) and computes two averages:

    * ``es`` -- the mean of ``es_per(...)["acc_per"]``;
    * ``dd`` -- the mean of ``kl_loc_loss`` between the outer_pre and
      outer_edit logits.

    It writes ``{"es": ..., "dd": ...}`` to ``<output_dir>/<prefix>.json``.

    Args:
        ckpt_dir: model/tokenizer path or hub id passed to ``from_pretrained``.
        output_dir: directory for the result file.
        method_type: forwarded to ``format_data``.
        prefix: result file name without the ``.json`` extension.

    Any of ``output_dir``, ``method_type`` or ``prefix`` left as ``None`` is
    read from the module-level CLI namespace ``args`` set in the ``__main__``
    block.
    """
    if output_dir is None:
        output_dir = args.output_dir
    if method_type is None:
        method_type = args.method_type
    if prefix is None:
        prefix = args.prefix

    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
    model = AutoModelForCausalLM.from_pretrained(ckpt_dir, device_map="auto")
    if "gpt" in ckpt_dir:
        # format_data locates the question/answer boundary via "</s>"; register
        # it as a special token for GPT-style checkpoints.
        num_added = tokenizer.add_special_tokens({"sep_token": "</s>"})
        if num_added:
            model.resize_token_embeddings(len(tokenizer))
            # Initialise the new row(s) to the mean of the pre-existing rows.
            # Previously the last row was overwritten even when nothing was added,
            # and the mean included the freshly-initialised row.
            weight = model.lm_head.weight.data
            weight[-num_added:, :] = weight[:-num_added].mean(0)

    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
    # NOTE(review): hard-coded; presumably Llama's <s> id — review for other tokenizers.
    tokenizer.bos_token_id = 1

    os.makedirs(output_dir, exist_ok=True)
    output_filename = os.path.join(output_dir, prefix + ".json")

    start_time = time.time()
    # NOTE(review): neither counter is ever incremented; they only feed the summary print.
    over_len_err = 0
    total_cnt = 0

    formatted_data = format_data(tokenizer, method_type)

    es, dd = [], []
    with torch.no_grad():
        for example in tqdm(formatted_data):
            pre_base_logits = _forward_logits(model, example["pre_prompt"])
            edit_base_logits = _forward_logits(model, example["edit_prompt"])
            outer_pre_logits = _forward_logits(model, example["outer_pre_prompt"])
            outer_edit_logits = _forward_logits(model, example["outer_edit_prompt"])

            es.append(es_per(example, pre_base_logits, edit_base_logits)["acc_per"])
            # NOTE(review): kl_loc_loss aligns on trailing positions, which matches
            # only when the tokenizer pads on the left — confirm for each model.
            dd.append(kl_loc_loss(
                outer_pre_logits,
                outer_edit_logits,
                example["outer_pre_prompt"]["q_mask"],
            ))

    result = {"es": _mean_of(es), "dd": _mean_of(dd)}

    with open(output_filename, "w") as f:
        json.dump(result, f, ensure_ascii=False, indent=4)

    end_time = time.time()
    print("total run time %.2f" % (end_time - start_time))
    print("over_len_err:{}/{}".format(over_len_err, total_cnt))
    

if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    # (flag, type, default) for every command-line option.
    for flag, kind, default in (
        ('--ckpt_dir', str, 'meta-llama/Llama-2-7b-chat-hf'),
        ('--output_dir', str, './output'),
        ('--method_type', str, 'IKE'),
        ('--prefix', str, 'pre'),
        ('--batch_size', int, 4),  # parsed, but not read by any code visible here
    ):
        cli.add_argument(flag, type=kind, default=default)
    # `args` must stay a module-level global: main() reads options from it.
    args = cli.parse_args()

    # Evaluation only: no autograd graph is needed for any forward pass.
    with torch.no_grad():
        main(args.ckpt_dir)

