from typing import List
import os
import re
import copy
import json
import random
import glob
from tqdm import tqdm

from template import TEMPLATES


def load_dataset(datadir: str, datatype: str):
    """Load one dataset split from ``<datadir>/<datatype>.json``.

    Args:
        datadir: Directory containing the split files.
        datatype: Split name (the file stem), e.g. ``"train"``.

    Returns:
        The parsed JSON content of the file, unchanged. Callers in this
        module iterate it as a list of dialogue records.
    """
    path = os.path.join(datadir, f"{datatype}.json")
    # JSON text is UTF-8 by specification; don't depend on the locale default.
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def build_dataset(dataset: List):
    """Flatten raw dialogue records into per-dialogue summaries.

    For every record, locate the first turn whose ``share_photo`` flag is
    truthy. Record the messages and speakers that precede it, together with
    the full sequence of non-empty messages and their speakers.

    Args:
        dataset: Raw dialogue records. Each must provide ``dialogue`` (a list
            of turns with ``message``, ``user_id`` and ``share_photo`` keys),
            ``photo_description``, ``photo_url`` and ``photo_id``.

    Returns:
        One dict per input record, in input order, with keys ``dialog_idx``,
        ``image_share_turn_idx``, ``context``, ``speaker``,
        ``image_share_turn_speaker``, ``photo_description``, ``photo_url``,
        ``photo_id``, ``all_context`` and ``all_speaker``.

    Raises:
        ValueError: If a dialogue has no turn with ``share_photo`` set.
    """
    results = []
    for d_idx, instance in enumerate(tqdm(dataset, total=len(dataset))):
        dialog = instance['dialogue']

        context = []
        speaker = []
        # Reset for every dialogue. Otherwise a dialogue without a
        # photo-sharing turn would silently reuse the previous dialogue's index.
        t_j = None
        for j, ele in enumerate(dialog):
            if ele['share_photo']:
                t_j = j
                break
            context.append(ele["message"])
            speaker.append(ele["user_id"])
        if t_j is None:
            raise ValueError(
                f"dialogue {d_idx} has no turn with 'share_photo' set"
            )

        # These cover the whole dialogue but drop empty messages. Positions in
        # them therefore need not line up with ``t_j``, which indexes ``dialog``.
        all_context = [ele["message"] for ele in dialog if ele["message"] != ""]
        all_speaker = [ele["user_id"] for ele in dialog if ele["message"] != ""]

        results.append({
            'dialog_idx': d_idx,
            'image_share_turn_idx': t_j,
            'context': context,
            'speaker': speaker,
            'image_share_turn_speaker': dialog[t_j]['user_id'],
            'photo_description': instance['photo_description'],
            'photo_url': instance['photo_url'],
            'photo_id': instance['photo_id'],
            "all_context": all_context,
            "all_speaker": all_speaker,
        })
    return results

def load_SSN_name_list(top_k: int = 1000, names_dir: str = './gpt3_test/names'):
    """Load the ``top_k`` most frequent names from SSN name files.

    This follows the SODA paper. Every ``*.txt`` file in ``names_dir`` is read
    as comma-separated ``name,sex,count`` lines, and blank lines are skipped.

    Args:
        top_k: Number of names to keep, ranked by count (highest first).
        names_dir: Directory containing the name ``.txt`` files.

    Returns:
        A list of ``(name, [count, sex])`` tuples sorted by count in
        descending order, truncated to ``top_k`` entries.
    """
    all_names = {}
    # glob order is filesystem-dependent. Sort it so that the "last file wins"
    # overwrite below, and tie ordering in the final sort, are reproducible.
    for path in sorted(glob.glob(os.path.join(names_dir, '*.txt'))):
        with open(path, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line:  # tolerate blank / trailing lines
                    continue
                name, sex, count = line.split(',')
                # NOTE(review): a name seen again (in a later file, or with the
                # other sex) overwrites the earlier entry; counts are not summed.
                all_names[name] = [int(count), sex]

    sorted_names = sorted(all_names.items(), key=lambda x: x[1][0], reverse=True)[:top_k]
    sex = [ele[1][1] for ele in sorted_names]

    print('# of Female names: {}'.format(sex.count('F')))
    print('# of Male names: {}'.format(sex.count('M')))

    return sorted_names

def prepare_task1_prompt_dataset(datadir: str, datatype: str, template_type: str):
    """Build task-1 prompts from one dataset split.

    The two speakers of each dialogue get two distinct names sampled at random
    from the SSN name list. The dialogue is rendered as ``"<name>: <utterance>"``
    lines and substituted into ``TEMPLATES[template_type]`` together with both
    speaker names.

    Args:
        datadir: Directory containing the split files.
        datatype: Split name (file stem) passed to ``load_dataset``.
        template_type: Key into ``TEMPLATES``. The template is formatted with
            ``spk1``, ``spk2`` and ``dialogue``.

    Returns:
        Deep copies of the ``build_dataset`` instances, each extended with
        ``prompt_input`` and ``image_share_turn_speaker_name``.
    """
    dataset = load_dataset(datadir, datatype)
    dataset = build_dataset(dataset)
    name_list = load_SSN_name_list()
    template = TEMPLATES[template_type]

    prompts = []
    for instance in tqdm(dataset, total=len(dataset)):
        context = instance["all_context"]
        speaker = instance["all_speaker"]
        share_speaker = instance["image_share_turn_speaker"]
        assert len(context) == len(speaker)

        # Speaker ids are mapped to names by position 0 / 1. Any other id
        # makes the lookups below raise KeyError.
        sampled_names = dict(enumerate(random.sample(name_list, 2)))

        # The whole (non-empty) dialogue is rendered, including turns after
        # the photo is shared. Each name_list entry is (name, [count, sex]).
        flatten_dialog = "\n".join(
            "{}: {}".format(sampled_names[spk][0], utter.replace(' <TURN> ', " "))
            for utter, spk in zip(context, speaker)
        )
        prompt = template.format(
            spk1=sampled_names[0][0],
            spk2=sampled_names[1][0],
            dialogue=flatten_dialog,
        )

        copied_instance = copy.deepcopy(instance)
        copied_instance["prompt_input"] = prompt
        copied_instance["image_share_turn_speaker_name"] = sampled_names[share_speaker][0]

        prompts.append(copied_instance)

    print("Total # of prompts: {}".format(len(prompts)))
    return prompts

def prepare_grammar_prompt_dataset(datadir: str, datatype: str, template_type: str):
    """Build grammar-check prompts from one dataset split.

    This works like the task-1 preparation: speakers get two random SSN names
    and the whole dialogue is rendered as ``"<name>: <utterance>"`` lines. The
    template, however, is formatted with ``dialogue`` only, not the speaker
    names.

    Args:
        datadir: Directory containing the split files.
        datatype: Split name (file stem) passed to ``load_dataset``.
        template_type: Key into ``TEMPLATES``. The template must only
            reference the ``dialogue`` placeholder.

    Returns:
        Deep copies of the ``build_dataset`` instances, each extended with
        ``prompt_input`` and ``image_share_turn_speaker_name``.
    """
    dataset = load_dataset(datadir, datatype)
    dataset = build_dataset(dataset)
    name_list = load_SSN_name_list()
    template = TEMPLATES[template_type]

    prompts = []
    for instance in tqdm(dataset, total=len(dataset)):
        context = instance["all_context"]
        speaker = instance["all_speaker"]
        share_speaker = instance["image_share_turn_speaker"]
        assert len(context) == len(speaker)

        # Speaker ids are mapped to names by position 0 / 1. Any other id
        # makes the lookups below raise KeyError.
        sampled_names = dict(enumerate(random.sample(name_list, 2)))

        # The whole (non-empty) dialogue is rendered, including turns after
        # the photo is shared. Each name_list entry is (name, [count, sex]).
        flatten_dialog = "\n".join(
            "{}: {}".format(sampled_names[spk][0], utter.replace(' <TURN> ', " "))
            for utter, spk in zip(context, speaker)
        )
        prompt = template.format(dialogue=flatten_dialog)

        copied_instance = copy.deepcopy(instance)
        copied_instance["prompt_input"] = prompt
        copied_instance["image_share_turn_speaker_name"] = sampled_names[share_speaker][0]

        prompts.append(copied_instance)

    print("Total # of prompts: {}".format(len(prompts)))
    return prompts

def prepare_task2_prompt_dataset(model_name: str, rounding_step: int, task1_dataset: List, template_type: str):
    """Build task-2 prompts from task-1 outputs.

    Each task-1 annotation is checked against the turn just before the gold
    photo-sharing turn (``image_share_turn_idx - 1``). For every annotation
    whose ``turn_index`` matches, the dialogue is rendered up to and including
    that turn. A ``"<speaker>: [Sharing Image]"`` line is appended, and the
    result fills ``TEMPLATES[template_type]``.

    Args:
        model_name: Model identifier. Names containing ``"text"`` read
            ``annotated_parsed_results``; others read
            ``f"{rounding_step}_task1_annotated_parsed_results"``.
        rounding_step: Used only to build the annotation key for
            non-``text`` models.
        task1_dataset: Task-1 instances. Each must carry ``prompt_input``,
            ``image_share_turn_idx``, ``image_share_turn_speaker``,
            ``image_share_turn_speaker_name``, ``all_context``,
            ``all_speaker`` and the annotation key above.
        template_type: Key into ``TEMPLATES``. The template is formatted with
            ``spk1``, ``spk2`` and ``dialogue``.

    Returns:
        One deep-copied instance per matching annotation, extended with
        ``task1_annotated_result``, ``task2_prompt_input`` and
        ``task2_speakers``.

    Raises:
        ValueError: If the two speaker names cannot be parsed from an
            instance's ``prompt_input``.
    """
    template = TEMPLATES[template_type]
    pattern = re.compile(r'The following is a dialogue between (?P<spk1>.*?) and (?P<spk2>.*?)\..*')

    prompts = []
    for instance in tqdm(task1_dataset, total=len(task1_dataset)):
        golden_turn_index = instance["image_share_turn_idx"] - 1

        match = pattern.search(instance["prompt_input"])
        if match is None:
            raise ValueError(
                "could not parse speaker names from prompt_input: "
                f"{instance['prompt_input'][:200]!r}"
            )
        spk1 = match.group("spk1")
        spk2 = match.group("spk2")

        # Map speaker ids to names. The "other" id is derived as 1 - id, which
        # presumes ids 0 / 1.
        image_share_turn_speaker = instance["image_share_turn_speaker"]
        image_share_turn_speaker_name = instance["image_share_turn_speaker_name"]
        other_speaker = abs(1 - image_share_turn_speaker)
        other_speaker_name = spk2 if spk1 == image_share_turn_speaker_name else spk1
        speaker_dict = {
            image_share_turn_speaker: image_share_turn_speaker_name,
            other_speaker: other_speaker_name,
        }

        if 'text' in model_name:
            annotated_parsed = instance["annotated_parsed_results"]
        else:
            annotated_parsed = instance[f"{rounding_step}_task1_annotated_parsed_results"]

        all_context = instance['all_context']
        all_speaker = instance["all_speaker"]
        for ele in annotated_parsed:
            turn_index = ele['turn_index']
            turn_speaker = ele["speaker"]
            if golden_turn_index != turn_index:
                continue

            # NOTE(review): if turn_index is never reached (negative, or past
            # the end of all_context), no "[Sharing Image]" line is appended.
            tmp = []
            for t, (ctx, spk) in enumerate(zip(all_context, all_speaker)):
                tmp.append("{}: {}".format(speaker_dict[spk], ctx))
                if t == turn_index:
                    tmp.append("{}: [Sharing Image]".format(turn_speaker))
                    break
            tmp = '\n'.join(tmp)

            p_spk2 = spk2 if spk1 == turn_speaker else spk1

            prompt = template.format(
                spk1=turn_speaker,
                spk2=p_spk2,
                dialogue=tmp,
            )

            copied_instance = copy.deepcopy(instance)
            copied_instance["task1_annotated_result"] = ele
            copied_instance["task2_prompt_input"] = prompt
            copied_instance["task2_speakers"] = [turn_speaker, p_spk2]
            prompts.append(copied_instance)

    print("Total # of prompts: {} (before: {})".format(len(prompts), len(task1_dataset)))
    return prompts

def prepare_task2_prompt_dataset_for_photochat(datadir: str, datatype: str, template_type: str):
    """Build task-2 prompts directly from a PhotoChat-style dataset split.

    Speakers get two random SSN names. The dialogue is rendered as
    ``"<name>: <utterance>"`` lines up to the photo-sharing position, where a
    ``"<sharer>: [Sharing Image]"`` line is placed and rendering stops.

    Args:
        datadir: Directory containing the split files.
        datatype: Split name (file stem) passed to ``load_dataset``.
        template_type: Key into ``TEMPLATES``. The template is formatted with
            ``spk1``, ``spk2`` and ``dialogue``.

    Returns:
        Deep copies of the ``build_dataset`` instances, each extended with
        ``task2_prompt_input`` and ``image_share_turn_speaker_name``.

    Raises:
        ValueError: If no ``[Sharing Image]`` line could be placed, i.e. the
            share index lies beyond the end of the rendered context.
    """
    dataset = load_dataset(datadir, datatype)
    dataset = build_dataset(dataset)
    name_list = load_SSN_name_list()
    template = TEMPLATES[template_type]

    prompts = []
    for instance in tqdm(dataset, total=len(dataset)):
        context = instance["all_context"]
        speaker = instance["all_speaker"]
        share_speaker = instance["image_share_turn_speaker"]
        # NOTE(review): share_turn_idx indexes the raw dialogue, while j below
        # indexes the empty-message-filtered context. The two line up only if
        # no empty message precedes the shared photo.
        share_turn_idx = instance["image_share_turn_idx"]
        assert len(context) == len(speaker)

        # Speaker ids are mapped to names by position 0 / 1.
        sampled_names = dict(enumerate(random.sample(name_list, 2)))

        flatten_dialog = []
        for j, (utter, spk) in enumerate(zip(context, speaker)):
            if j == share_turn_idx:
                flatten_dialog.append(
                    "{}: {}".format(sampled_names[share_speaker][0], "[Sharing Image]")
                )
                break
            flatten_dialog.append(
                "{}: {}".format(sampled_names[spk][0], utter.replace(' <TURN> ', " "))
            )

        # The photo is shared right after the last rendered turn.
        if share_turn_idx == len(context):
            flatten_dialog.append(
                "{}: {}".format(sampled_names[share_speaker][0], "[Sharing Image]")
            )

        flatten_dialog = "\n".join(flatten_dialog)

        # Use an explicit raise rather than `assert False`, so the check
        # survives `python -O`.
        if '[Sharing Image]' not in flatten_dialog:
            raise ValueError(
                "no '[Sharing Image]' marker placed (share_turn_idx={}, "
                "len(context)={}):\n{}".format(share_turn_idx, len(context), flatten_dialog)
            )

        prompt = template.format(
            spk1=sampled_names[0][0],
            spk2=sampled_names[1][0],
            dialogue=flatten_dialog,
        )

        copied_instance = copy.deepcopy(instance)
        copied_instance["task2_prompt_input"] = prompt
        copied_instance["image_share_turn_speaker_name"] = sampled_names[share_speaker][0]
        prompts.append(copied_instance)

    print("Total # of prompts: {}".format(len(prompts)))
    return prompts
