
import json
import os.path
import random

import jsonlines
import numpy as np
import pandas as pd
from datasets import Dataset, DatasetDict, concatenate_datasets
from datasets import load_dataset, concatenate_datasets
from tqdm import tqdm
from transformers import AutoTokenizer
import re
import json
from collections import defaultdict
from tqdm import tqdm
from tabulate import tabulate
from transformers import (
    CONFIG_MAPPING,
    MODEL_MAPPING,
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    SchedulerType,
    get_scheduler,
)
import pprint
from termcolor import colored
import glob

# Root directory that holds the raw corpora and receives the generated
# per-domain ``*.jsonl`` files.
moz_path = '/sdb/zke4/data_dialogue'

# Module-level tokenizer handed to ``get_data_loaders`` at the bottom of the
# file; the fast tokenizer implementation is requested.
tokenizer = AutoTokenizer.from_pretrained('facebook/bart-large', use_fast=True)

# Shared pretty-printer used to display sample records.
pp = pprint.PrettyPrinter(indent=4)

# System dialogue acts that are kept when serialising a system turn into its
# "API-OUT" string; actions with any other act are dropped.
allowed_ACT_list = [
    "INFORM",
    "CONFIRM",
    "OFFER",
    "NOTIFY_SUCCESS",
    "NOTIFY_FAILURE",
    "INFORM_COUNT",
]


def get_data_loaders(tokenizer, test=False):
    """Build the per-domain SGD NLG examples and dump them as JSON-lines files.

    Despite its name, this function builds no data loaders and returns
    nothing: it loads the single-domain SGD dialogues through
    ``get_datasets``, converts every domain of every split into NLG examples
    with ``get_NLG_from_dial``, prints a size table, and writes one
    ``<moz_path>/<domain>/<split>.jsonl`` file per domain and split (the
    "dev" split is stored as ``validation.jsonl``).

    Args:
        tokenizer: Forwarded to ``get_NLG_from_dial``.
        test: When true, only the "test" split is converted and written.
            (The original code raised ``KeyError`` on this path because it
            still indexed the "train" and "dev" splits.)

    Returns:
        None.

    Raises:
        KeyError: If one of the hard-coded SGD domains is missing from a
            split that is being written.
    """
    aggregate = get_datasets(dataset_list=['SGD'], setting='single', verbose=False, develop=False)

    splits = ["test"] if test else ["train", "dev", "test"]
    datasets = {split: {} for split in splits}

    for split in splits:
        for tasks_id, task in aggregate["BYDOMAIN"][split].items():
            datasets[split][tasks_id] = get_NLG_from_dial(task, tasks_id, tokenizer)

    # Tasks that are present in every split that was actually built.
    common_task_id = list(set.intersection(*(set(datasets[split]) for split in splits)))

    ### LOGGING SOME INFORMATION ABOUT THE TASKS
    print(f"Tasks: {common_task_id}")
    print(f"Num of Tasks {len(common_task_id)}")
    sizes = defaultdict(lambda: defaultdict(str))
    for split in splits:
        for task_id, dataset_task in datasets[split].items():
            sizes[task_id][split] = len(dataset_task)
    # Splits that were not built show up as empty cells (defaultdict(str)).
    table = [
        {"dom": dom_name, "train": split_len["train"], "dev": split_len["dev"], "test": split_len["test"]}
        for dom_name, split_len in sizes.items()
    ]
    print(tabulate(table, headers="keys"))

    # Single source of truth for the exported domains: both the task key
    # ("['sgd_x']", i.e. the str() of a one-element list) and the output
    # directory ("<moz_path>/sgd_x/") are derived from the same name, so the
    # two can no longer drift apart.
    domain_names = [
        "sgd_services", "sgd_flights", "sgd_buses", "sgd_ridesharing", "sgd_rentalcars",
        "sgd_homes", "sgd_music", "sgd_events", "sgd_banks", "sgd_hotels",
        "sgd_calendar", "sgd_media", "sgd_movies", "sgd_restaurants", "sgd_alarm",
        "sgd_weather", "sgd_travel", "sgd_payment", "sgd_trains",
    ]
    file_stem = {"dev": "validation"}  # every other split keeps its own name
    for name in domain_names:
        task_key = str([name])
        out_dir = os.path.join(moz_path, name)
        # The original assumed the directory already existed and failed with
        # FileNotFoundError on a fresh data root.
        os.makedirs(out_dir, exist_ok=True)
        for split in splits:
            out_path = os.path.join(out_dir, file_stem.get(split, split) + '.jsonl')
            with jsonlines.open(out_path, 'w') as f_json:
                for instance in datasets[split][task_key]:
                    f_json.write(instance)



    # train_loaders = {}
    # valid_loaders = {}
    # train_datasets = {}
    # val_datasets = {}
    # if(args.continual):
    #     if(not test):
    #         for task_id, dataset_task in datasets["train"].items():
    #             if(task_id in common_task_id):
    #                 train_loaders[task_id] = DataLoader(DatasetTrain(dataset_task), batch_size=args.train_batch_size, shuffle=True,collate_fn=partial(collate_fn_, tokenizer=tokenizer))
    #                 train_datasets[task_id] = dataset_task
    #         for task_id, dataset_task in datasets["dev"].items():
    #             if(task_id in common_task_id):
    #                 valid_loaders[task_id] = DataLoader(DatasetTrain(dataset_task), batch_size=args.valid_batch_size, shuffle=False,collate_fn=partial(collate_fn_, tokenizer=tokenizer))
    #                 val_datasets[task_id] = dataset_task
    # elif(args.multi):
    #     if(not test):
    #         dataset_train = []
    #         for task_id, dataset_task in datasets["train"].items():
    #             if(task_id in common_task_id):
    #                 dataset_train += dataset_task
    #         train_loaders = DataLoader(DatasetTrain(dataset_train), batch_size=args.train_batch_size, shuffle=True,collate_fn=partial(collate_fn_, tokenizer=tokenizer))
    #
    #         dataset_dev = []
    #         for task_id, dataset_task in datasets["dev"].items():
    #             if(task_id in common_task_id):
    #                 dataset_dev += dataset_task
    #         valid_loaders = DataLoader(DatasetTrain(dataset_dev), batch_size=args.valid_batch_size, shuffle=False,collate_fn=partial(collate_fn_, tokenizer=tokenizer))
    #
    # temp_list = []
    # for task_id, dataset_task in datasets["test"].items():
    #     if(task_id in common_task_id):
    #         temp_list.append(dataset_task)
    # test_datasets = sum(temp_list,[])
    # test_loaders = DataLoader(DatasetTrain(sum(temp_list,[])), batch_size=args.valid_batch_size, shuffle=False,collate_fn=partial(collate_fn_, tokenizer=tokenizer))
    #
    # ### THIS IS JUST FOR CHECKING DUPLICATE DIALOGUES
    # testing_dict = defaultdict(list)
    # for idx_b, batch in tqdm(enumerate(test_loaders),total=len(test_loaders)):
    #     for (d_id, t_id, ta_id) in zip(batch["dial_id"],batch["turn_id"],batch["task_id"]):
    #         if(f'{d_id}_{t_id}_{ta_id}' not in testing_dict):
    #             testing_dict[f'{d_id}_{t_id}_{ta_id}'].append(1)
    #         else:
    #             print(f'{d_id}_{t_id}_{ta_id}')

    # return train_loaders, valid_loaders, test_loaders, (train_datasets,val_datasets,test_datasets)


def get_datasets(dataset_list=['SGD'],setting="single",verbose=False,develop=False):
    """Load the requested corpora and return them whole and grouped by domain.

    Args:
        dataset_list: Container of corpus names; only "SGD" is handled.
        setting: "single" or "multi", forwarded to ``split_by_domain``.
        verbose: When true, prints a few sample dialogues and the collected
            intent names, waiting for Enter (``input()``) after each dump.
        develop: Forwarded to ``preprocessSGD``.

    Returns:
        ``{"TOTAL": {"train", "dev", "test"}, "BYDOMAIN": {"train", "dev",
        "test"}}`` where the "TOTAL" values are lists of dialogues and the
        "BYDOMAIN" values are the results of ``split_by_domain``.
    """
    summary_rows = []
    train, dev, test = [], [], []

    if "SGD" in dataset_list:
        print("LOAD SGD")
        sgd_train, sgd_dev, sgd_test = preprocessSGD(develop=develop)
        if verbose:
            print_sample(sgd_train, 2)
            input()
        n_domain, n_intent, n_turns, _ = get_domains_slots(sgd_train)
        summary_rows.append({
            "Name": "SGD",
            "Trn": len(sgd_train),
            "Val": len(sgd_dev),
            "Tst": len(sgd_test),
            "Dom": n_domain,
            "Int": n_intent,
            "Tur": n_turns,
        })
        train.extend(sgd_train)
        dev.extend(sgd_dev)
        test.extend(sgd_test)

    # Statistics before filtering. The fourth value is the set of intent
    # names that appear in the API turns of the training dialogues.
    n_domain, n_intent, n_turns, services = get_domains_slots(train)
    summary_rows.append({
        "Name": "TOT",
        "Trn": len(train),
        "Val": len(dev),
        "Tst": len(test),
        "Dom": n_domain,
        "Int": n_intent,
        "Tur": n_turns,
    })

    # Drop test/dev dialogues that call an intent never seen in training.
    test = filter_services(test, services)
    dev = filter_services(dev, services)

    # Recomputed on the (unchanged) training set; the second "TOT" row
    # therefore differs from the first only in the Val/Tst sizes.
    n_domain, n_intent, n_turns, services = get_domains_slots(train)
    if verbose:
        for intent_name in services:
            print(intent_name)
        input()
    summary_rows.append({
        "Name": "TOT",
        "Trn": len(train),
        "Val": len(dev),
        "Tst": len(test),
        "Dom": n_domain,
        "Int": n_intent,
        "Tur": n_turns,
    })
    print(tabulate(summary_rows, headers="keys"))

    total = {"train": train, "dev": dev, "test": test}
    by_domain = {}
    for split_name in ("train", "dev", "test"):
        by_domain[split_name] = split_by_domain(total[split_name], setting)

    return {"TOTAL": total, "BYDOMAIN": by_domain}


def split_by_domain(data,setting):
    """Group dialogues by the string form of their sorted service list.

    Args:
        data: Iterable of dialogue dicts, each with a "services" list.
        setting: "single" keeps only dialogues with exactly one service;
            "multi" keeps every dialogue.

    Returns:
        A plain ``dict`` mapping ``str(sorted(services))`` to the list of
        matching dialogues. Keys are ordered by number of services; keys
        with the same count keep first-seen order (the sort is stable).

    Raises:
        SystemExit: With code 1, after printing a usage hint, when
            ``setting`` is neither "single" nor "multi".
    """
    if setting not in ("single", "multi"):
        print("choose a setting: --setting single or --setting multi")
        # Same effect as the former ``exit(1)``, without relying on the
        # ``exit`` helper that the ``site`` module injects.
        raise SystemExit(1)

    single_only = setting == "single"
    data_by_domain = defaultdict(list)
    # Number of services behind each key, recorded while grouping so the
    # final sort does not have to eval() the key strings back into lists.
    n_services = {}
    for dial in data:
        if single_only and len(dial["services"]) != 1:
            continue
        services = sorted(dial["services"])
        key = str(services)
        data_by_domain[key].append(dial)
        n_services[key] = len(services)

    print("SINGLE DOMAIN:" if single_only else "ALL DOMAIN:", len(data_by_domain.keys()))
    return dict(sorted(data_by_domain.items(), key=lambda item: n_services[item[0]]))



def filter_services(data,serv):
    """Keep the dialogues whose API turns only mention entries found in ``serv``.

    Args:
        data: Iterable of dialogue dicts with a "dialogue" list of turns.
        serv: Container of allowed names, tested with ``in``.

    Returns:
        A new list with the dialogues for which every entry of every API
        turn's "service" list is in ``serv``. Dialogues without API turns
        are kept.
    """
    kept = []
    for dialogue in data:
        unknown = [
            name
            for turn in dialogue['dialogue']
            if turn["spk"] == "API"
            for name in turn["service"]
            if name not in serv
        ]
        if not unknown:
            kept.append(dialogue)
    return kept


def print_sample(data,num):
    """Print up to ``num`` randomly chosen dialogues, colour-coded by speaker.

    The previous loop broke only after the dialogue with index ``num`` had
    been printed, so it showed ``num + 1`` dialogues, and it shuffled the
    whole dataset to do so. Exactly ``min(num, len(data))`` dialogues are
    now drawn (none when ``num`` is not positive).

    Args:
        data: Sequence of dialogue dicts with "id", "services" and
            "dialogue" keys.
        num: Maximum number of dialogues to print.
    """
    color_map = {"USER":"blue","SYSTEM":"magenta","API":"red","API-OUT":"green"}
    n_samples = max(0, min(num, len(data)))
    for dial in random.sample(data, n_samples):
        print(f'ID:{dial["id"]}')
        print(f'Services:{dial["services"]}')
        for turn in dial['dialogue']:
            print(colored(f'{turn["spk"]}:',color_map[turn["spk"]])+f' {turn["utt"]}')


def get_domains_slots(data):
    """Collect domain, intent and turn-count statistics and print a summary.

    Args:
        data: Iterable of dialogue dicts with "services" and "dialogue" keys.

    Returns:
        A tuple ``(n_domains, n_intents, avg_turns, intents)``: the number of
        distinct "services" entries, the number of distinct entries in the
        "service" lists of API turns, the mean number of USER/SYSTEM turns
        per dialogue, and the set of those API-turn entries.
    """
    domain_names = set()
    intent_names = set()
    spoken_turn_counts = []

    for dialogue in data:
        domain_names.update(dialogue["services"])

        turns = dialogue['dialogue']
        spoken_turn_counts.append(
            sum(1 for turn in turns if turn["spk"] in ("USER", "SYSTEM"))
        )

        for turn in turns:
            if turn["spk"] != "API":
                continue
            for name in turn["service"]:
                # Debugging aid: a name containing a space or made of a
                # single character is shown together with its turn, and the
                # run pauses until Enter is pressed.
                if " " in name or len(name) == 1:
                    print(name)
                    print(turn)
                    input()
                intent_names.add(name)

    average_turns = np.mean(spoken_turn_counts)
    print("Domain", len(domain_names))
    print("Intent", len(intent_names))
    print("Avg. Turns", average_turns)
    return len(domain_names), len(intent_names), average_turns, intent_names


def get_NLG_from_dial(data,task_id,tokenizer):
    """Turn dialogues into NLG examples (API-OUT string -> system reply).

    Each dialogue is scanned with a "latest API-OUT" buffer that starts as
    the literal ``"API-OUT: "``. An API-OUT turn replaces the buffer with
    its stripped utterance. A SYSTEM turn that is not the first turn of the
    dialogue and has a non-empty utterance emits one example when the
    buffer is non-empty, and then clears the buffer, so every API-OUT
    string is paired with at most one reply.

    Changes from the previous version: the never-read ``utt_len``,
    ``hist_len`` and ``plain_history`` locals are gone, one sample is picked
    directly instead of shuffling the whole list, and the trailing
    ``input()`` pause only happens when stdin is a terminal (it used to
    block every call and raise ``EOFError`` in non-interactive runs).

    Args:
        data: Iterable of dialogue dicts; each turn of "dialogue" needs
            "id", "spk" and "utt", and SYSTEM turns also "dataset" and
            "turn_id".
        task_id: Value stored under "task_id" in every example.
        tokenizer: Object exposing ``eos_token``, appended to each reply.

    Returns:
        List of dicts with keys "history", "reply", "history_reply", "spk",
        "dataset", "dial_id", "turn_id" and "task_id".
    """
    import sys  # local import: `sys` is not imported at the top of this file

    dialogues = []
    for dial in data:
        latest_API_OUT = "API-OUT: "
        for idx_t, t in enumerate(dial['dialogue']):
            ## DUPLICATE DIALOGUE
            if f'{t["id"]}' == "dlg-ff2b8de2-467d-4917-be13-1529765752e9":
                continue
            if t['spk'] == "API-OUT":
                latest_API_OUT = t['utt'].strip()
            elif t['spk'] == "SYSTEM" and idx_t != 0 and t["utt"] != "":
                if latest_API_OUT != "":
                    reply = f'{t["utt"].strip()} {tokenizer.eos_token}'
                    dialogues.append({"history": latest_API_OUT,
                                      "reply": reply,
                                      "history_reply": latest_API_OUT + '[SOS]' + reply,
                                      "spk": t["spk"],
                                      "dataset": t["dataset"],
                                      "dial_id": t["id"],
                                      "turn_id": t["turn_id"],
                                      "task_id": task_id})
                # Consumed (or deliberately skipped): a later system turn
                # without a fresh API-OUT produces no example.
                latest_API_OUT = ""

    # Show one random example for a quick visual sanity check.
    if dialogues:
        pprint.pprint(random.choice(dialogues), indent=4)
    print()
    # Keep the "press Enter to continue" pause for interactive sessions only.
    if sys.stdin is not None and sys.stdin.isatty():
        input()
    return dialogues

def preprocessSGD_(split, develop=False):
    """Read one raw SGD split and flatten each dialogue into a list of turns.

    Every ``*.json`` file under
    ``<moz_path>/dstc8-schema-guided-dialogue/<split>/`` whose path does not
    contain "schema.json" is parsed. A USER turn yields two entries: the
    utterance itself and an "API" entry whose "utt" serialises the dialogue
    state as ``Intent(slot="value",...) `` (requested slots get the value
    "?"). Any other turn yields an "API-OUT" entry that serialises the
    allowed system acts of the first frame as ``ACT(slot="value",...) ``,
    followed by the utterance itself.

    Fixes over the previous version:
      * the system-turn ``serv`` list is created before it is filled. It
        used to be appended to first and reset afterwards, which wrote into
        the list already stored as "service_dom" of the preceding API turn,
        left the API-OUT "service" always empty, and raised ``NameError``
        when a dialogue opened with a system turn;
      * each file is closed after reading (``with`` block, UTF-8);
      * ``develop`` stops after file index 1 even if that file is the
        schema (previously the check sat inside the non-schema branch and
        could be skipped, loading the whole split);
      * the unused ``dst_prev`` local is gone.

    Args:
        split: Name of the sub-directory to read ("train", "dev", "test").
        develop: When true, stop after the first two globbed files.

    Returns:
        List of dicts with keys "id", "services" (raw service names
        prefixed with "SGD_"), "dataset" and "dialogue" (list of turns).
    """
    data = []
    files = glob.glob(moz_path+f"/dstc8-schema-guided-dialogue/{split}/*.json")
    for i_f, f in tqdm(enumerate(files),total=len(files)):
        if "schema.json" not in f:
            with open(f, encoding="utf-8") as json_file:
                dialogue = json.load(json_file)
            for d in dialogue:
                dial = {"id":d["dialogue_id"], "services": ["SGD_"+s for s in d["services"]],"dataset":"SGD"}
                turns = []
                for t_idx, t in enumerate(d["turns"]):
                    if t["speaker"] == "USER":
                        turns.append({"dataset":"SGD","id":d["dialogue_id"],"turn_id":t_idx,"spk":t["speaker"],"utt":t["utterance"]})
                        dst_api = get_dict(t["frames"])
                        serv = []
                        intent_list = set()
                        for frame in t["frames"]:
                            serv.append(remove_numbers_from_string(frame["service"]))
                            # Requested slots have no value yet: mark them "?".
                            for slot in frame['state']['requested_slots']:
                                dst_api[frame['state']['active_intent']][slot] = "?"
                        api_calls = []
                        for intent, slots in dst_api.items():
                            call_args = []
                            for s, v in slots.items():
                                if s != 'none':
                                    # Parentheses are stripped from values so
                                    # they cannot clash with the call syntax.
                                    value = v.replace("(", "").replace(")", "")
                                    call_args.append(f'{s}="{value}"')
                                    intent_list.add(remove_numbers_from_string(intent))
                            api_calls.append(intent + "(" + ",".join(call_args) + ") ")
                        str_API = "".join(api_calls)
                        turns.append({"dataset":"SGD","id":d["dialogue_id"],"turn_id":t_idx,"spk":"API","utt":str_API,"service":list(intent_list),"service_dom":serv})
                    else:
                        group_by_act = defaultdict(list)
                        # Fresh list for this turn; one entry per action, as
                        # in the original loop.
                        serv = []
                        first_frame = t["frames"][0]
                        for a in first_frame["actions"]:
                            serv.append(first_frame["service"])
                            if a["act"] in allowed_ACT_list:
                                group_by_act[a["act"]].append([a["slot"], a["values"]])
                        act_calls = []
                        for act, slot_values in group_by_act.items():
                            call_args = []
                            for arg, val in slot_values:
                                if len(val) > 0:
                                    # Only the first value is used; double
                                    # quotes would break the quoting.
                                    first_value = val[0].replace('"', "'")
                                    call_args.append(f'{arg}="{first_value}"')
                            act_calls.append(act + "(" + ",".join(call_args) + ") ")
                        str_ACT = "".join(act_calls)

                        turns.append({"dataset":"SGD","id":d["dialogue_id"],"turn_id":t_idx,"spk":"API-OUT","utt":str_ACT,"service":serv})
                        turns.append({"dataset":"SGD","id":d["dialogue_id"],"turn_id":t_idx,"spk":t["speaker"],"utt":t["utterance"]})
                dial["dialogue"] = turns
                data.append(dial)
        if develop and i_f >= 1:
            break

    return data

def rename_service_dialogue(dial_split,name):
    """Overwrite the "services" field of every dialogue with the parsed ``name``.

    Args:
        dial_split: Iterable of dialogue dicts; they are modified in place.
        name: String form of a Python literal, e.g. ``"['sgd_hotels']"``.

    Returns:
        A new list holding the same (mutated) dialogue objects.

    Raises:
        ValueError, SyntaxError: If ``name`` is not a valid Python literal
            (only raised when there is at least one dialogue to rename).
    """
    import ast  # local import: `ast` is not imported at the top of this file

    new_dial = []
    for dial in dial_split:
        # literal_eval parses literals only, unlike the former eval(), and is
        # called per dialogue so each one gets its own list object.
        dial["services"] = ast.literal_eval(name)
        new_dial.append(dial)
    return new_dial

def preprocessSGD(develop=False):
    """Re-split the single-service SGD dialogues 70/10/20 per service.

    The raw "train", "dev" and "test" folders are pooled. Dialogues that
    involve exactly one service are grouped by that service, each group is
    cut in file order (no shuffling) into 70% train, 10% validation and 20%
    test, and the pieces of services that share a normalised name (see
    ``remove_numbers_from_string``) are concatenated. Every returned
    dialogue has its "services" replaced by the normalised one-element list.

    Compared with the previous version this avoids the ``eval`` round trip
    on dictionary keys, uses plain list slicing instead of ``np.split`` on a
    list of dicts, no longer rebinds ``data`` inside the loop, and drops the
    ``table`` list that was built but never used.

    Args:
        develop: Forwarded to ``preprocessSGD_``.

    Returns:
        Tuple ``(train_data, valid_data, test_data)`` of dialogue lists.
    """
    data = preprocessSGD_("train", develop=develop)
    data += preprocessSGD_("dev", develop=develop)
    data += preprocessSGD_("test", develop=develop)

    # Multi-service dialogues are discarded here.
    dialogues_by_service = defaultdict(list)
    for dial in data:
        if len(dial["services"]) == 1:
            dialogues_by_service[dial["services"][0]].append(dial)

    # Key: str() of the one-element normalised service list, the format
    # rename_service_dialogue expects.
    splits_by_domain = defaultdict(list)
    for service, dialogues in dialogues_by_service.items():
        train_end = int(len(dialogues) * 0.7)
        dev_end = int(len(dialogues) * 0.8)
        domain = str([remove_numbers_from_string(service)])
        splits_by_domain[domain].append(
            (dialogues[:train_end], dialogues[train_end:dev_end], dialogues[dev_end:])
        )

    train_data, valid_data, test_data = [], [], []
    for domain, list_of_split in splits_by_domain.items():
        for tr, va, te in list_of_split:
            train_data += rename_service_dialogue(tr, domain)
            valid_data += rename_service_dialogue(va, domain)
            test_data += rename_service_dialogue(te, domain)

    return train_data,valid_data,test_data


def get_domains(goal):
    """Return the "MWOZ_"-prefixed names of the non-empty entries of ``goal``.

    The "message" and "topic" keys are never reported.
    """
    return [
        "MWOZ_" + name
        for name, content in goal.items()
        if len(content) != 0 and name != "message" and name != "topic"
    ]

def loadCSV(split):
    """Read ``<split>ListFile.txt`` of MultiWOZ 2.1 and return its lines.

    Newline characters are removed from every line; nothing else is
    stripped or filtered.
    """
    list_file = moz_path+f"/multiwoz/data/MultiWOZ_2.1/{split}ListFile.txt"
    with open(list_file) as handle:
        return [line.replace("\n","") for line in handle]

def get_value_dst(DST):
    """Collect the filled slots of a dialogue state, grouped by its top-level key.

    Args:
        DST: Mapping whose values each hold a "semi" and a "book" mapping of
            slot name to value.

    Returns:
        ``defaultdict(list)`` mapping each top-level key that has at least
        one non-empty slot to ``[slot, value]`` pairs, "semi" slots first,
        then "book" slots. The "booked" entry of "book" is always skipped.
    """
    active = defaultdict(list)
    for domain, state in DST.items():
        filled = [
            [slot, value]
            for slot, value in state['semi'].items()
            if len(value) != 0
        ]
        filled += [
            [slot, value]
            for slot, value in state['book'].items()
            if len(value) != 0 and slot != "booked"
        ]
        # Append one by one so a domain without filled slots gets no key.
        for pair in filled:
            active[domain].append(pair)
    return active

def remove_numbers_from_string(s):
    """Normalise a service or intent name.

    Every ``_<digits>`` suffix is removed, the result is lower-cased, and
    the exact names "hotels", "restaurants", "flights" and "movies" are
    mapped to their singular form.

    The former findall-then-replace loop could damage names where one
    suffix is a prefix of another (``"a_1_12"`` became ``"a2"`` because
    replacing ``"_1"`` also ate the start of ``"_12"``); a single
    ``re.sub`` removes each match exactly once.

    Args:
        s: Name such as ``"Hotels_1"``.

    Returns:
        The normalised lower-case name, e.g. ``"hotel"``.
    """
    singular = {
        "hotels": "hotel",
        "restaurants": "restaurant",
        "flights": "flight",
        "movies": "movie",
    }
    s = re.sub(r'_\d+', '', s).lower()
    return singular.get(s, s)

def get_dict(DST):
    """Map each active intent to its slots, keeping the first value of each slot.

    Args:
        DST: Iterable of frames, each with
            ``frame["state"]["active_intent"]`` and
            ``frame["state"]["slot_values"]`` (slot name -> list of values).

    Returns:
        Nested ``defaultdict`` ``{intent: {slot: first_value}}``. Frames
        whose active intent is "NONE" are ignored, and an intent without
        slot values gets no entry. Missing slots read back as "".
    """
    slots_by_intent = defaultdict(lambda: defaultdict(str))
    for frame in DST:
        state = frame["state"]
        intent = state["active_intent"]
        if intent == "NONE":
            continue
        for slot, values in state['slot_values'].items():
            slots_by_intent[intent][slot] = values[0]
    return slots_by_intent


# Script entry point: this call sits at module level, so the whole
# preprocessing pipeline runs whenever the file is executed or imported.
get_data_loaders(tokenizer)
