from datasets import load_dataset, disable_caching, Dataset
from .instructions import map_gsm8k, map_summeval, map_ace05
from .ee_template import eve_template_generator
import json
import random
import os


def get_dataset(dataset_name, instruction_format, num_shots, split=None, score_type=None):
    """Load one of the supported datasets and attach prompts to every example.

    Dataset caching is disabled (``disable_caching()``) before anything is
    loaded. A sample prompt is printed for each successfully built dataset.

    Args:
        dataset_name: One of ``"gsm8k"``, ``"summeval"`` or ``"ace05-en"``.
        instruction_format: Forwarded unchanged to the per-dataset mapping
            function (``map_gsm8k`` / ``map_summeval`` / ``map_ace05``).
        num_shots: Number of few-shot examples, forwarded to the mapping
            function.
        split: Only read by the ``"ace05-en"`` branch, where it must be
            ``"train"``, ``"dev"`` or ``"test"``.
        score_type: Only read by the ``"summeval"`` branch, where it is
            forwarded to ``map_summeval``.

    Returns:
        For ``"gsm8k"``: the object returned by ``load_dataset`` with every
        split mapped. For ``"summeval"``: the mapped ``train`` split. For
        ``"ace05-en"``: the mapped dataset reduced to the ``full_prompt`` and
        ``label`` columns.

    Raises:
        ValueError: If ``dataset_name`` is not supported, or if ``split`` is
            invalid for ``"ace05-en"``.
    """
    disable_caching()

    if dataset_name == "gsm8k":
        dataset = load_dataset("openai/gsm8k", "main")
        # NOTE(review): the path is an empty string here; presumably a
        # placeholder for the few-shot example file -- confirm before use.
        data_file = ""
        few_shot_pool = Gsm8kFewShotPool(data_file)

        # Dedicated loop name so the caller's `split` argument is not shadowed.
        for split_name in dataset.keys():
            dataset[split_name] = dataset[split_name].map(
                lambda x: map_gsm8k(x, instruction_format, few_shot_pool=few_shot_pool, num_shots=num_shots)
            )
        print(f'*** PROMPT EXAMPLE ***\n{dataset["test"][0]["full_prompt"]}\n*** *** *** *** *** ***')
        return dataset

    if dataset_name == "summeval":
        # NOTE(review): `data_files` is an empty string here; presumably a
        # placeholder for the SummEval JSON path -- confirm before use.
        dataset = load_dataset("json", data_files="", split="train")
        dataset = dataset.map(lambda x: map_summeval(x, instruction_format, score_type, num_shots=num_shots))
        print(f'*** PROMPT EXAMPLE ***\n{dataset[0]["full_prompt"]}\n*** *** *** *** *** ***')
        return dataset

    if dataset_name == "ace05-en":
        # Few-shot examples are drawn from a split other than the one being
        # built: train for dev/test, dev for train.
        if split in ('test', 'dev'):
            pool_split = "train"
        elif split == 'train':
            pool_split = "dev"
        else:
            raise ValueError(f"Invalid split: {split}")
        dataset = build_ace05_dataset(split)
        few_shot_pool = build_ace05_few_shot_pool(pool_split)
        dataset = dataset.map(lambda x: map_ace05(x, instruction_format, few_shot_pool=few_shot_pool, num_shots=num_shots))
        # `remove_columns` is documented to take a str or a list of str, so
        # pass a list instead of a set.
        kept_columns = {"full_prompt", "label"}
        dataset = dataset.remove_columns(
            [name for name in dataset.column_names if name not in kept_columns]
        )
        print(f'*** PROMPT EXAMPLE ***\n{dataset[0]["full_prompt"]}\n*** *** *** *** *** ***')
        return dataset

    # Fail loudly instead of silently returning None for an unknown name.
    raise ValueError(f"Invalid dataset name: {dataset_name}")


class Gsm8kFewShotPool:
    """Pool of worked examples parsed from a plain-text file.

    The file content is split on the literal marker ``"Question:"``. For each
    non-blank chunk, the first line becomes the question and everything after
    the first newline becomes the answer text.
    """

    def __init__(self, data_file):
        """Read ``data_file`` (UTF-8) and fill ``self.cot_examples``.

        Each entry is a dict with the keys ``"question"`` and ``"cot_answer"``;
        ``"cot_answer"`` is the empty string when a chunk has only one line.
        """
        self.cot_examples = []
        with open(data_file, 'r', encoding='utf-8') as handle:
            raw_text = handle.read()

        for chunk in raw_text.split("Question:"):
            body = chunk.strip()
            if not body:
                # Text before the first marker, or a blank chunk.
                continue
            # partition yields an empty tail when the chunk has no newline.
            first_line, _, remainder = body.partition("\n")
            self.cot_examples.append(
                {"question": first_line.strip(), "cot_answer": remainder.strip()}
            )
        print(f'*** Loaded {len(self.cot_examples)} examples from {data_file}')

    def get_few_shot_examples(self, num_shots):
        """Return ``num_shots`` randomly sampled examples as one prompt string.

        Sampling uses ``random.sample`` on the pool, so examples are drawn
        without replacement. Blocks are numbered from 1 and each ends with a
        blank line.
        """
        chosen = random.sample(self.cot_examples, num_shots)
        return "".join(
            f'Example {number}:\nQuestion: {shot["question"]}\n{shot["cot_answer"]}\n\n'
            for number, shot in enumerate(chosen, start=1)
        )
    

def build_ace05_dataset(split):
    """Build a ``Dataset`` of processed ACE05 instances from a JSON-lines file.

    Every line of the file is parsed as JSON, turned into instances by
    ``load_EAE_data``, post-processed by ``prepare_data_for_EAE_CD`` and then
    converted column-wise with ``Dataset.from_dict``. Instance and type counts
    are printed along the way.

    NOTE(review): ``split`` is not referenced in the body and the file path is
    an empty string; presumably the path was meant to depend on ``split`` --
    confirm.
    """
    file = ""
    with open(file, 'r', encoding='utf-8') as fp:
        raw_lines = fp.readlines()
    parsed = [json.loads(raw) for raw in raw_lines]

    instances, type_set = load_EAE_data(parsed)
    dataset = list(prepare_data_for_EAE_CD(instances, "ace05-en"))

    n_triggers = len(type_set["trigger"])
    n_roles = len(type_set["role"])
    print('Loaded {} EAE instances ({} trigger types and {} role types) from {}'.format(
        len(dataset), n_triggers, n_roles, file))
    print("There are {} trigger types and {} role types in total".format(
        n_triggers, n_roles))

    # Row dicts -> column lists; the first row defines the column set.
    dataset_dict = {col: [row[col] for row in dataset] for col in dataset[0].keys()}
    return Dataset.from_dict(dataset_dict)


def load_EAE_data(data, add_extra_info_fn=None, config=None):
    """Expand documents into one instance per event mention.

    Args:
        data: Iterable of dicts providing ``entity_mentions``,
            ``event_mentions``, ``doc_id``, ``wnd_id``, ``tokens`` and
            ``text``. Each document's ``event_mentions`` list is sorted in
            place by trigger start.
        add_extra_info_fn: Optional callable invoked as
            ``add_extra_info_fn(instances, data, config)``; its return value
            replaces the instance list.
        config: Passed through to ``add_extra_info_fn`` only.

    Returns:
        ``(instances, type_set)``. Each instance carries ``trigger`` as
        ``(start, end, event_type, text)`` and ``arguments`` as a list of
        ``(entity_start, entity_end, role, text)`` sorted by the two offsets.
        ``type_set`` maps ``"trigger"`` and ``"role"`` to the sets of event
        types and role names seen.
    """
    instances = []
    for doc in data:
        entity_list = doc['entity_mentions']
        mentions = doc['event_mentions']
        # Sorts the caller's list in place (kept from the original behaviour).
        mentions.sort(key=lambda m: m['trigger']['start'])

        entity_by_id = {}
        for entity in entity_list:
            entity_by_id[entity['id']] = entity

        for mention in mentions:
            trig = mention['trigger']
            trigger = (trig['start'], trig['end'], mention['event_type'], trig['text'])

            unsorted_args = []
            for arg in mention['arguments']:
                # Argument offsets come from the referenced entity mention.
                target = entity_by_id[arg['entity_id']]
                unsorted_args.append((target['start'], target['end'], arg['role'], arg['text']))
            arguments = sorted(unsorted_args, key=lambda a: (a[0], a[1]))

            instances.append({
                "doc_id": doc["doc_id"],
                "wnd_id": doc["wnd_id"],
                "tokens": doc["tokens"],
                "text": doc["text"],
                "trigger": trigger,
                "arguments": arguments,
            })

    # Type inventories are taken before any extra-info hook runs.
    type_set = {
        "trigger": {inst['trigger'][2] for inst in instances},
        "role": {arg[2] for inst in instances for arg in inst["arguments"]},
    }
    if add_extra_info_fn is not None:
        instances = add_extra_info_fn(instances, data, config)
    return instances, type_set


def prepare_data_for_EAE_CD(data, dataset_name):
    """Run every instance through ``eve_template_generator`` and collect results.

    For each instance the trigger is reduced to its first three fields, every
    argument is paired with that trigger, and the generator's
    ``get_processed_data()`` result is extended with ``doc_id``, ``wnd_id``,
    ``tokens``, and stringified copies of ``trigger`` and ``arguments``.

    Args:
        data: Iterable of instance dicts with ``trigger``, ``arguments``,
            ``tokens``, ``doc_id`` and ``wnd_id``.
        dataset_name: Forwarded as the first argument of
            ``eve_template_generator``.

    Returns:
        List with one processed record per input instance, in input order.
    """
    input_style = ["event_type", "event_type_sent", "triggers", "template", "na_token"]
    output_style = ["argument:sentence"]

    def _process(instance):
        # Convert one instance into its processed record.
        full_trigger = instance["trigger"]
        short_trigger = (full_trigger[0], full_trigger[1], full_trigger[2])
        paired_args = []
        for role in instance["arguments"]:
            paired_args.append((short_trigger, (role[0], role[1], role[2])))

        template = eve_template_generator(
            dataset_name, instance["tokens"], [short_trigger], paired_args,
            input_style, output_style, None, False,
        )
        record = template.get_processed_data()
        record["doc_id"] = instance["doc_id"]
        record["wnd_id"] = instance["wnd_id"]
        record["tokens"] = instance["tokens"]
        # Every element is stored as a string.
        record["trigger"] = tuple(str(field) for field in instance["trigger"])
        record["arguments"] = [
            tuple(str(field) for field in arg) for arg in instance["arguments"]
        ]
        return record

    return [_process(dt) for dt in data]


class Ace05FewShotPool:
    """Few-shot example pool indexed by each example's ``event_type``."""

    def __init__(self, dataset):
        """Keep ``dataset`` and group its examples by ``event_type``.

        ``self.types_to_examples`` maps each event type to the list of its
        examples, in iteration order.
        """
        self.dataset = dataset
        self.types_to_examples = {}
        for record in dataset:
            self.types_to_examples.setdefault(record["event_type"], []).append(record)

    def get_few_shot_examples(self, num_shots, event_type):
        """Return ``num_shots`` sampled examples of ``event_type`` as prompt text.

        Examples are drawn with ``random.sample`` from the group for
        ``event_type``. Each block shows the passage with the trigger text
        wrapped in ``[t] ... [/t]`` (only where it appears between spaces),
        followed by one ``role: text`` line per valid role; roles with no
        extracted argument get an empty value.
        """
        candidates = self.types_to_examples[event_type]
        sampled = random.sample(candidates, num_shots)

        blocks = []
        for number, shot in enumerate(sampled, start=1):
            trigger_text = shot["gold_event"]["trigger text"]
            tagged = shot["passage"].replace(f' {trigger_text} ', f' [t] {trigger_text} [/t] ')
            pieces = [f'Example {number}:\nText: {tagged}\n']

            # Later arguments with the same role overwrite earlier ones.
            role_to_text = {arg[2]: arg[3] for arg in shot["arguments"]}
            for role in shot["info_dict"]["valid roles"]:
                pieces.append(
                    f'{role}: {role_to_text[role]}\n' if role in role_to_text else f'{role}: \n'
                )
            pieces.append('\n')
            blocks.append("".join(pieces))
        return "".join(blocks)


def build_ace05_few_shot_pool(pool_split):
    """Build an ``Ace05FewShotPool`` from a JSON-lines file.

    Lines are parsed as JSON, converted by ``load_EAE_data`` and
    ``prepare_data_for_EAE_CD``, reshaped into columns for
    ``Dataset.from_dict`` and wrapped in ``Ace05FewShotPool``.

    NOTE(review): ``pool_split`` is not referenced in the body and the file
    path is an empty string; presumably the path was meant to depend on
    ``pool_split`` -- confirm.
    """
    file = ""
    with open(file, 'r', encoding='utf-8') as fp:
        raw_lines = fp.readlines()
    parsed = [json.loads(raw) for raw in raw_lines]

    # The type inventory returned alongside the instances is not needed here.
    instances, _ = load_EAE_data(parsed)
    records = list(prepare_data_for_EAE_CD(instances, "ace05-en"))

    # Row dicts -> column lists; the first row defines the column set.
    pool_dict = {col: [row[col] for row in records] for col in records[0].keys()}
    return Ace05FewShotPool(Dataset.from_dict(pool_dict))