import json
from torch.utils.data import Dataset
import random
import numpy as np
from utils import EditBatchSampler, dict_to, build_distr_matrix
import torch
import sys

# Personality traits in id order. ``per2id`` is derived from this list, so a
# trait's id is always its position in ``per_list``.
per_list = ["extraversion", "agreeableness", "neuroticism"]

per2id = {trait: trait_id for trait_id, trait in enumerate(per_list)}

class PersonalityDataset(Dataset):
    """Dataset of per-topic opinion completions grouped by personality trait.

    ``data_path`` must point to a JSON list of samples. Based on how samples
    are read in this class, each sample is a dict with:

    * ``"ent"`` -- the topic string;
    * one key per entry of ``per_list`` (e.g. ``"extraversion"``), each mapping
      to a list of completion strings;
    * optionally ``"target_per"`` -- a trait name that fixes the edit target
      (used when generating on the test set).

    ``config`` is read for ``data.hard_neg``, ``data.hard_neg_prob``,
    ``single_batch``, ``seed``, ``model_name`` and ``device``.
    """

    def __init__(
        self,
        tokenizer,
        data_path,
        config,
        max_length=96,
    ):
        super().__init__()
        self.tok = tokenizer
        self.config = config
        self.max_length = max_length
        # Trait name chosen by the most recent sample_completions() call.
        self.trait = None

        # JSON is UTF-8 by specification; don't depend on the platform locale.
        with open(data_path, encoding="utf-8") as f:
            self.data = json.load(f)

        # Question templates; "{}" is filled with the topic (sample["ent"]).
        self.templates = [
            "What do you think of {}?",
            "What do you feel about {}?",
            "How do you view {}?",
        ] + [
            "What is your " + position + " {}?"
            for position in (
                "opinion of",
                "stance on",
                "position on",
                "attitude about",
                "view on",
                "take on",
                "impression of",
                "assessment of",
                "judgment of",
                "sentiment of",
            )
        ]

        self.loc_distr_matrix, self.loc_idx_matrix = None, None
        # Hard-negative locality matrices are only built for training files.
        # str() lets pathlib.Path arguments work as well as plain strings.
        if self.config.data.hard_neg and "train" in str(data_path):
            edit_qs = [sample["ent"] for sample in self.data]
            self.loc_distr_matrix, self.loc_idx_matrix = build_distr_matrix(edit_qs, config)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

    def sample_completions(self, idx, do_sample=True):
        """Build the prompt/completion lists for dataset item ``idx``.

        The edit ("inner") trait is ``sample["target_per"]`` when present (kept
        fixed for test-set generation); otherwise it is drawn uniformly at
        random. As a side effect, ``self.trait`` is set to that trait's name.

        ``do_sample`` is accepted for interface compatibility only: there is
        no separate sampling variant. The "outer" pairs are always the
        sample's completions for the edit trait. (Previously ``do_sample=True``
        left the outer lists unbound and raised ``UnboundLocalError``.)

        Returns:
            dict with ``ent``, ``inner_prompt``/``inner_comp`` (one-element
            lists describing the edit), ``inner_per`` (trait id),
            ``outer_prompt``/``outer_comp`` (questions and completions for the
            edit trait) and ``all_prompt``/``all_per``/``all_comp`` (every
            completion of every trait, in ``per_list`` order).
        """
        sample = self[idx]
        ent = sample["ent"]

        if "target_per" in sample:
            # Test-set generation: keep the target personality fixed.
            inner_per = per2id[sample["target_per"]]
        else:
            inner_per = random.choice(range(len(per_list)))

        inner_per_text = per_list[inner_per]
        self.trait = inner_per_text

        # NOTE: "Personailty" is misspelled, but the string is part of the
        # model input, so it is kept verbatim.
        inner_comp = ["Target Personailty: " + inner_per_text + "\n"]
        inner_prompt = ["Topic: " + ent + "\n"]

        # The "outer" pairs are the completions for the edit trait itself. The
        # naming is kept so every model path consumes the same dict layout.
        outer_comp = sample[inner_per_text]
        outer_temp = random.choices(self.templates, k=len(outer_comp))
        outer_prompt = [t.format(ent) for t in outer_temp]

        all_per, all_comp = [], []
        for per_id, per in enumerate(per_list):
            all_per += [per_id] * len(sample[per])
            all_comp += sample[per]

        all_temp = random.choices(self.templates, k=len(all_per))
        all_prompt = [t.format(ent) for t in all_temp]

        return {
            "ent": ent,
            "inner_prompt": inner_prompt,
            "inner_comp": inner_comp,
            "inner_per": inner_per,
            "outer_prompt": outer_prompt,
            "outer_comp": outer_comp,
            "all_prompt": all_prompt,
            "all_per": all_per,
            "all_comp": all_comp,
        }

    def get_edit_labels(self, ids, prompts=None):
        """Return a copy of ``ids`` with padding positions replaced by -100.

        -100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
        ``prompts`` is unused and kept for interface compatibility.
        """
        labels = ids.clone()
        labels[labels == self.tok.pad_token_id] = -100
        return labels

    @staticmethod
    def _flatten(batch, key):
        """Concatenate the list stored under ``key`` across all samples."""
        return [item for sample in batch for item in sample[key]]

    def _tokenize_fields(self, fields):
        """Tokenize each text list in ``fields`` with shared padding/truncation.

        Every tokenizer output ``k`` of field ``name`` is stored as
        ``f"{name}_{k}"`` (e.g. ``inner_q_input_ids``).
        """
        batches = {}
        for name, texts in fields.items():
            encoded = self.tok(
                texts,
                return_tensors="pt",
                padding=True,
                max_length=self.max_length,
                truncation=True,
            )
            for key, value in encoded.items():
                batches[f"{name}_{key}"] = value
        return batches

    @staticmethod
    def _attach_meta(batches, batch):
        """Add trait ids, the raw samples and ``pos_pairs`` to ``batches``.

        ``inner_per`` repeats each sample's edit trait once per entry of its
        ``all_per``, so it lines up element-wise with ``all_per``.
        ``pos_pairs`` holds ``[row, sample_index]`` for every ``all_*`` row.
        """
        batches["all_per"] = [p for b in batch for p in b["all_per"]]
        batches["inner_per"] = [b["inner_per"] for b in batch for _ in b["all_per"]]
        batches["raw"] = batch

        pos_pairs = []
        for sample_idx, b in enumerate(batch):
            for _ in b["all_prompt"]:
                pos_pairs.append([len(pos_pairs), sample_idx])
        batches["pos_pairs"] = torch.LongTensor(pos_pairs)
        return batches

    @staticmethod
    def _question_mask(input_ids, sep_id):
        """Return per-row boolean masks that are False before the first ``sep_id``.

        ``input_ids`` is a 2-D tensor of token ids. Each row gets its own list,
        so marking one row never affects another.

        NOTE(review): the original comment said the ``</s>`` separator is
        masked as well, but only positions strictly before it are set to
        False (the separator stays True). Confirm which is intended.

        Raises:
            ValueError: if a row contains no ``sep_id`` (for example, because
                truncation to ``max_length`` cut it off).
        """
        masks = []
        for row_idx, row in enumerate(input_ids.tolist()):
            try:
                sep_idx = row.index(sep_id)
            except ValueError:
                raise ValueError(
                    f"separator token id {sep_id} not found in row {row_idx}; "
                    "it may have been truncated by max_length"
                ) from None
            masks.append([False] * sep_idx + [True] * (len(row) - sep_idx))
        return masks

    def collate_fn(self, batch):
        """Tokenize ``sample_completions`` outputs for encoder-decoder models.

        Returns tokenizer outputs for the fields ``inner_q``/``inner_a``,
        ``outer_q``/``outer_a`` and ``all_q``/``all_a`` (keyed
        ``{field}_{tokenizer_key}``), plus the metadata from ``_attach_meta``.
        """
        batches = self._tokenize_fields({
            "inner_q": self._flatten(batch, "inner_prompt"),
            "inner_a": self._flatten(batch, "inner_comp"),
            "outer_q": self._flatten(batch, "outer_prompt"),
            "outer_a": self._flatten(batch, "outer_comp"),
            "all_q": self._flatten(batch, "all_prompt"),
            "all_a": self._flatten(batch, "all_comp"),
        })
        return self._attach_meta(batches, batch)

    def collate_fn_gpt(self, batch):
        """Tokenize ``sample_completions`` outputs for decoder-only (GPT) models.

        Each question/answer pair is joined into one string with a ``</s>``
        separator, and ``{field}_q_mask`` marks the tokens from the separator
        onward. Returns tokenizer outputs for ``inner_qa``, ``outer_qa``,
        ``all_qa`` and ``outer_q``, the three ``*_qa_q_mask`` lists, and the
        metadata from ``_attach_meta``.
        """
        outer_prompt = self._flatten(batch, "outer_prompt")
        outer_comp = self._flatten(batch, "outer_comp")
        all_prompt = self._flatten(batch, "all_prompt")
        all_comp = self._flatten(batch, "all_comp")

        # NOTE: outer_qa uses "\n </s>" and all_qa uses " \n </s>". Both are kept
        # verbatim because they are part of the model input.
        outer_qa = ["Question: " + q + "\n </s> Answer: " + a for q, a in zip(outer_prompt, outer_comp)]
        all_qa = ["Question: " + q + " \n </s> Answer: " + a for q, a in zip(all_prompt, all_comp)]

        # Edit descriptor (topic + target personality) followed by one example pair.
        # NOTE(review): every row reuses outer_qa[0], the first pair of the
        # *first* sample. With batch sizes > 1, later samples' descriptors are
        # paired with another sample's example. Confirm intent.
        inner_qa = [
            f"{q}  {a} " + outer_qa[0]
            for q, a in zip(self._flatten(batch, "inner_prompt"), self._flatten(batch, "inner_comp"))
        ]

        try:
            batches = self._tokenize_fields({
                "inner_qa": inner_qa,
                "outer_qa": outer_qa,
                "all_qa": all_qa,
                "outer_q": outer_prompt,
            })
        except Exception:
            # Dump the offending inputs for debugging, then propagate the error.
            # Exiting with status 0 here would report success and hide the failure.
            print("inner_qa:", inner_qa)
            print("outer_qa:", outer_qa)
            print("all_qa:", all_qa)
            raise

        sep_id = self.tok.convert_tokens_to_ids("</s>")
        for key in ("inner_qa", "outer_qa", "all_qa"):
            batches[key + "_q_mask"] = self._question_mask(batches[f"{key}_input_ids"], sep_id)

        return self._attach_meta(batches, batch)

    def _seq2seq_inputs(self, toks, q_key, a_key):
        """Build encoder-decoder model inputs from the ``q_key``/``a_key`` tokenized fields."""
        return {
            "input_ids": toks[f"{q_key}_input_ids"],
            "attention_mask": toks[f"{q_key}_attention_mask"],
            "labels": self.get_edit_labels(toks[f"{a_key}_input_ids"]),
            "decoder_input_ids": toks[f"{a_key}_input_ids"],
            "decoder_attention_mask": toks[f"{a_key}_attention_mask"],
        }

    def _causal_inputs(self, toks, key, with_q_mask=False):
        """Build decoder-only model inputs from the ``key`` tokenized field."""
        inputs = {
            "input_ids": toks[f"{key}_input_ids"],
            "attention_mask": toks[f"{key}_attention_mask"],
            "labels": self.get_edit_labels(toks[f"{key}_input_ids"]),
        }
        if with_q_mask:
            inputs["q_mask"] = toks[f"{key}_q_mask"]
        return inputs

    def edit_generator(self, batch_size, n=None, do_sample=False):
        """Yield edit batches forever.

        Args:
            batch_size: forwarded to ``EditBatchSampler.sample``.
            n: number of dataset items the sampler draws from (default: all).
            do_sample: forwarded to ``sample_completions``.

        Yields:
            dict with ``edit_inner``/``cond`` (edit descriptor inputs),
            ``inner_q``, ``edit_outer``, ``loc``, ``outer_per``, ``inner_per``,
            ``trait`` (the last edit sample's trait name) and ``pos_pairs``,
            passed through ``dict_to(..., config.device)``.
        """
        if n is None:
            n = len(self)

        sampler = EditBatchSampler(
            n,
            memorize_mode=self.config.single_batch,
            loc_disjoint=True,
            seed=self.config.seed,
            hard_neg=self.config.data.hard_neg,
            hard_neg_prob=self.config.data.hard_neg_prob,
            loc_distr_matrix=self.loc_distr_matrix,
            loc_idx_matrix=self.loc_idx_matrix,
        )

        use_gpt = "gpt" in self.config.model_name
        collate = self.collate_fn_gpt if use_gpt else self.collate_fn

        while True:
            edit_idxs, loc_idxs = sampler.sample(batch_size)

            edit_samples = [self.sample_completions(idx, do_sample) for idx in edit_idxs]
            # Capture the edit trait now. Sampling the locality items below
            # overwrites self.trait with a locality sample's trait.
            trait = per_list[edit_samples[-1]["inner_per"]] if edit_samples else None
            edit_toks = collate(edit_samples)
            loc_toks = collate([self.sample_completions(idx, do_sample) for idx in loc_idxs])

            if use_gpt:
                edit_inner = self._causal_inputs(edit_toks, "inner_qa")
                edit_inner_q = self._causal_inputs(edit_toks, "outer_q")
                edit_outer = self._causal_inputs(edit_toks, "all_qa", with_q_mask=True)
                loc = self._causal_inputs(loc_toks, "all_qa", with_q_mask=True)
            else:
                edit_inner = self._seq2seq_inputs(edit_toks, "inner_q", "inner_a")
                # NOTE(review): seq2seq analogue of the GPT "inner_q" entry
                # (previously undefined here, which raised NameError). Confirm
                # this is what consumers expect.
                edit_inner_q = self._seq2seq_inputs(edit_toks, "outer_q", "outer_a")
                edit_outer = self._seq2seq_inputs(edit_toks, "all_q", "all_a")
                loc = self._seq2seq_inputs(loc_toks, "inner_q", "inner_a")

            batch = {
                "edit_inner": edit_inner,
                "edit_outer": edit_outer,
                "outer_per": edit_toks["all_per"],
                "inner_per": edit_toks["inner_per"],
                "trait": trait,
                "inner_q": edit_inner_q,
                "loc": loc,
                "cond": edit_inner,
                "pos_pairs": edit_toks["pos_pairs"],
            }

            yield dict_to(batch, self.config.device)

if __name__ == "__main__":
    # Library module: it only defines PersonalityDataset; nothing runs as a script.
    pass
