import json
import linecache
import os
import subprocess
import pickle

import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from transformers import AutoTokenizer

import wandb
from time import time
  
def timer_func(func):
    """Decorate ``func`` so every call prints how long it took.

    The wrapped function's return value is passed through unchanged, and its
    metadata (``__name__``, ``__doc__``, ``__wrapped__`` ...) is preserved via
    ``functools.wraps`` so introspection and stacked decorators keep working.
    """
    import functools
    from time import perf_counter

    @functools.wraps(func)
    def wrap_func(*args, **kwargs):
        # perf_counter is monotonic and high-resolution, unlike time(), so the
        # measured duration cannot go negative if the system clock is adjusted.
        t1 = perf_counter()
        result = func(*args, **kwargs)
        t2 = perf_counter()
        print(f'Function {func.__name__!r} executed in {(t2-t1):.4f}s')
        return result
    return wrap_func

class KnowledgeGraph(object):
    """Knowledge Graph of data.

    A plain container; as built by ``T5Dataset.with_knowledge_graph``:

    Attributes:
        wikidata_ids: node identifiers; a node's position in this list is its
            node index.
        ent_pos: per input token, the node index it mentions (-1 = none);
            ``None`` for gold graphs.
        edge_index: list of ``[head_idx, tail_idx]`` node-index pairs.
        edge_attr: relation label of each edge, aligned with ``edge_index``.
        local_indicator: per node, 1 if it was mentioned in the dialogue and
            0 if it was only added from a triplet; ``None`` for gold graphs.

    Callers may also attach extra attributes (e.g. ``label``) afterwards.
    """

    def __init__(self,
                 wikidata_ids,
                 ent_pos,
                 edge_index,
                 edge_attr,
                 local_indicator):
        # Nodes and their alignment with the input token sequence.
        self.wikidata_ids = wikidata_ids
        self.ent_pos = ent_pos
        self.local_indicator = local_indicator
        # Edges and the relation carried by each edge.
        self.edge_index = edge_index
        self.edge_attr = edge_attr
        # self.label_ent_pos = label_ent_pos

def is_whitespace(c):
    """Return True if the single character ``c`` separates words.

    Only space, tab, CR, LF and U+202F (narrow no-break space) qualify;
    other Unicode spaces such as U+00A0 are deliberately not included.
    """
    return c in (" ", "\t", "\r", "\n") or ord(c) == 0x202F

def custom_char_to_token(start_char, end_char, text, tokenizer):
    """Map an inclusive character span of ``text`` to an inclusive sub-token span.

    Fallback used when the tokenizer output has no usable ``char_to_token``.
    ``text`` is split into words on ``is_whitespace`` characters, and each word
    is tokenized separately with ``tokenizer.tokenize``. The span then runs
    from the first sub-token of the word containing ``start_char`` to the last
    sub-token of the word containing ``end_char``.

    Args:
        start_char: index of the first character of the span in ``text``.
        end_char: index of the last character of the span in ``text``.
        text: the full input string.
        tokenizer: object exposing ``tokenize(str) -> list``.

    Returns:
        ``(tok_start_position, tok_end_position)``, inclusive indices into the
        concatenation of the per-word sub-token lists.
    """
    # Split into whitespace-delimited words and record, for every character,
    # the index of the word it belongs to (whitespace maps to the preceding
    # word, or -1 before the first word).
    doc_tokens = []
    char_to_word_offset = []
    prev_is_whitespace = True
    for c in text:
        if is_whitespace(c):
            prev_is_whitespace = True
        else:
            if prev_is_whitespace:
                doc_tokens.append(c)
            else:
                doc_tokens[-1] += c
            prev_is_whitespace = False
        char_to_word_offset.append(len(doc_tokens) - 1)

    start_position = char_to_word_offset[start_char]
    end_position = char_to_word_offset[end_char]

    # Offset of each word's first sub-token; only the running count of
    # sub-tokens is needed, not the sub-tokens themselves.
    orig_to_tok_index = []
    num_sub_tokens = 0
    for token in doc_tokens:
        orig_to_tok_index.append(num_sub_tokens)
        num_sub_tokens += len(tokenizer.tokenize(token))

    tok_start_position = orig_to_tok_index[start_position]
    if end_position < len(doc_tokens) - 1:
        # Last sub-token of the end word = one before the next word's first.
        tok_end_position = orig_to_tok_index[end_position + 1] - 1
    else:
        tok_end_position = num_sub_tokens - 1

    return tok_start_position, tok_end_position

def char_to_token_idx(text, tokens, entity, length, prefix, tokenizer, opendialkg=False):
    """Convert an entity's character offsets into (shifted) token indices.

    The entity offsets are relative to the text after ``prefix``, so they are
    shifted by ``len(prefix)`` first. For non-OpenDialKG data, ``entity["end"]``
    is treated as exclusive: it is moved back by one and then past any trailing
    whitespace (or out-of-range positions). For OpenDialKG it is used unchanged.

    Token indices come from ``tokens.char_to_token``. If that raises
    ValueError, ``custom_char_to_token`` computes them instead. Each index that
    is not None is offset by ``length``.

    Returns:
        ``(start_idx, end_idx)``; either element may be None when the
        character does not map to a token.
    """
    offset = len(prefix)
    first_char = entity["start"] + offset
    if not opendialkg:
        last_char = entity["end"] - 1 + offset
        # Step back over out-of-range and whitespace characters.
        while last_char >= len(text) or is_whitespace(text[last_char]):
            last_char -= 1
    else:
        last_char = entity["end"] + offset

    try:
        bounds = (tokens.char_to_token(first_char), tokens.char_to_token(last_char))
    except ValueError:
        bounds = custom_char_to_token(first_char, last_char, text, tokenizer)

    return tuple(None if idx is None else idx + length for idx in bounds)

class T5Dataset(Dataset):
    """Dialogue dataset read lazily, line by line, from a JSON-lines file.

    Each line of ``jsonl_file`` is one JSON example, fetched on demand with
    ``linecache``. ``__getitem__`` returns either plain token-id tensors (when
    ``args.knowledge`` contains ``'text'`` or ``'none'``) or token ids plus a
    ``KnowledgeGraph`` built from the example's entities and triplets.

    Args:
        jsonl_file: path to the JSONL split file.
        args: config object; this class reads ``max_length``,
            ``max_decode_step``, ``tokenizer``, ``lm_type``, ``label_map``,
            ``knowledge`` and ``domain`` from it.
    """

    def __init__(self, jsonl_file, args):
        self.args = args
        self.is_train = 'train' in jsonl_file

        self.max_length = args.max_length
        self.max_decode_step = args.max_decode_step
        self.tokenizer = args.tokenizer
        self.file_name = jsonl_file
        # Count lines in Python instead of shelling out to ``wc -l``. This
        # avoids shell quoting problems with unusual paths, and a final line
        # without a trailing newline is still counted.
        self.total_size = self._count_lines(jsonl_file)

        if args.lm_type == 't5':
            # Textual task prefixes are only used for T5.
            self.apprentice_prefix = "apprentice: "
            self.wizard_prefix = "wizard: "
            self.knowledge_prefix = "knowledge: "
            self.prefix = "dialogue: "
            self.topic_prefix = "topic: "
        else:
            self.apprentice_prefix = ""
            self.wizard_prefix = ""
            self.knowledge_prefix = ""
            self.prefix = ""
            # Defined for every lm_type so attribute access never fails.
            self.topic_prefix = ""

        self.label_map = args.label_map

    @staticmethod
    def _count_lines(path):
        """Return the number of lines in ``path`` (an unterminated last line counts)."""
        with open(path, 'rb') as f:
            return sum(1 for _ in f)

    def _read_example(self, index):
        """Parse line ``index`` (0-based) of the backing JSONL file."""
        line = linecache.getline(self.file_name, index + 1)
        return json.loads(line)

    def _bos_eos_ids(self):
        """Return 1-element long tensors holding the decoder BOS and EOS ids.

        For T5 the pad token serves as BOS and ``</s>`` as EOS.
        """
        if self.args.lm_type == 't5':
            bos = self.tokenizer.pad_token_id
        else:
            bos = self.tokenizer.bos_token_id
        return (torch.tensor([bos], dtype=torch.long),
                torch.tensor([self.tokenizer.eos_token_id], dtype=torch.long))

    def _encode(self, text, max_length):
        """Tokenize ``text`` (truncated to ``max_length``) into a 1-D id tensor."""
        return self.tokenizer.encode(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length).squeeze(0)

    def without_knowledge_graph(self, index):
        """Build the text-only example at ``index``.

        Returns:
            ``(dialog_history_ids, src_response_ids, trg_response_ids,
            knowledge_ids)``, all 1-D tensors. ``src``/``trg`` are the
            BOS+response+EOS sequence shifted against each other by one
            position (decoder input vs. target).
        """
        json_dict = self._read_example(index)
        bos_id, eos_id = self._bos_eos_ids()

        dialog_history = json_dict["history"]
        prefixed_dialog_history = self.prefix + ' '.join(dialog_history)
        checked_knowledge = self.knowledge_prefix + json_dict['checked_knowledge']

        assert len(prefixed_dialog_history) > 0
        assert len(checked_knowledge) > 0

        # Encode the joined, prefixed strings (not the raw list of turns).
        dialog_history_ids = self._encode(prefixed_dialog_history, self.max_length)
        knowledge_ids = self._encode(checked_knowledge, self.max_length)

        response = json_dict["label"]
        assert len(response) > 0
        response_ids = self._encode(response, self.max_decode_step)
        response_ids = torch.cat([bos_id, response_ids, eos_id], dim=0)
        src_response_ids = response_ids[:-1]
        trg_response_ids = response_ids[1:]

        return (dialog_history_ids, src_response_ids, trg_response_ids, knowledge_ids)

    def with_knowledge_graph(self, index):
        """Build the example at ``index`` together with its knowledge graph.

        Entity mentions in the dialogue history become graph nodes, aligned
        to input tokens through ``ent_pos``. Triplet endpoints not already
        present are appended as extra nodes (``local_indicator`` 0).

        Returns:
            ``(user_ids, wizard_ids, knowledge_graph, gold_knowledge_graph)``;
            the last element is None when the example has no
            ``"gold_triplets"`` key.
        """
        def process_entities(text, tokens, entities, input_ids, wikidata_ids, ent_pos, prefix='', is_label=False):
            """Write, for every token an entity covers, its node index into ``ent_pos``.

            New entity ids are appended to ``wikidata_ids``. With ``is_label``
            the covered tokens are marked with 1 instead of a node index.
            """
            for entity in entities:
                if entity["start"] < 0 or entity["end"] < 0:
                    continue
                start_idx, end_idx = char_to_token_idx(text, tokens, entity, len(input_ids),
                                                       prefix=prefix,
                                                       tokenizer=self.tokenizer,
                                                       opendialkg=self.args.domain == "opendialkg")
                if start_idx is None or end_idx is None:
                    continue
                if start_idx >= self.max_length or end_idx >= self.max_length:
                    continue

                wikidata_id = entity["id"]
                if wikidata_id in wikidata_ids:
                    node_idx = wikidata_ids.index(wikidata_id)
                else:
                    node_idx = len(wikidata_ids)
                    wikidata_ids.append(wikidata_id)

                marker = 1 if is_label else node_idx
                for tok_idx in range(start_idx, end_idx + 1):
                    ent_pos[tok_idx] = marker

            return wikidata_ids, ent_pos

        json_dict = self._read_example(index)
        bos_id, _ = self._bos_eos_ids()

        user_input = self.prefix + ' '.join(json_dict["history"])

        tokens = self.tokenizer(user_input, add_special_tokens=False)
        entities = json_dict["entities"]
        ent_pos = [-1] * len(tokens.input_ids)
        wikidata_ids, ent_pos = process_entities(user_input, tokens, entities, [], [], ent_pos, self.prefix)
        if self.args.lm_type != "t5":
            # Leading slot for the start-of-sequence token.
            ent_pos = [-1] + ent_pos
        # Truncate to the encoder length and add a trailing non-entity slot,
        # presumably for the EOS token appended by ``tokenizer.encode``.
        ent_pos = ent_pos[:self.max_length - 1]
        ent_pos.append(-1)

        # Re-number the nodes that survived truncation as 0..k-1.
        local_wikidata_ids = []
        global_to_local = {}
        for pos in set(ent_pos):
            if pos < 0:
                continue
            global_to_local[pos] = len(local_wikidata_ids)
            local_wikidata_ids.append(wikidata_ids[pos])
        ent_pos = [global_to_local.get(pos, -1) for pos in ent_pos]
        assert max(ent_pos) + 1 == len(local_wikidata_ids)
        wikidata_ids = local_wikidata_ids

        # Processing label (response)
        wizard_output = json_dict["label"]

        assert len(user_input) > 0
        assert len(wizard_output) > 0
        user_ids = self._encode(user_input, self.max_length)
        wizard_ids = self._encode(wizard_output, self.max_decode_step)

        if self.args.lm_type == "t5":
            wizard_ids = torch.cat([bos_id, wizard_ids], dim=0)

        # Extend the graph with triplets; dict/set lookups keep this linear.
        edge_index = []
        edge_attr = []
        labels = []
        local_indicator = [1] * len(wikidata_ids)
        node_to_idx = {node: i for i, node in enumerate(wikidata_ids)}
        seen_edges = set()

        triplets = json_dict["triplets"]
        pseudo_labels = json_dict.get("triplets_label", [0] * len(triplets))
        for triplet, label in zip(triplets, pseudo_labels):
            for node in (triplet['h'], triplet['t']):
                if node not in node_to_idx:
                    node_to_idx[node] = len(wikidata_ids)
                    wikidata_ids.append(node)
                    local_indicator.append(0)

            edge = (node_to_idx[triplet['h']], node_to_idx[triplet['t']])
            if edge in seen_edges:
                continue
            seen_edges.add(edge)
            edge_index.append(list(edge))
            # Per the original note, label_map presumably reserves 0 for
            # "no relation".
            edge_attr.append(self.label_map[triplet['r']])
            labels.append(label)

        knowledge_graph = KnowledgeGraph(
            wikidata_ids,
            ent_pos,
            edge_index,
            edge_attr,
            local_indicator,
        )
        knowledge_graph.label = labels

        if "gold_triplets" not in json_dict:
            return user_ids, wizard_ids, knowledge_graph, None

        # Build the gold graph from its own node set.
        gold_nodes = []
        gold_node_to_idx = {}
        gold_edge_index = []
        gold_edge_attr = []
        seen_gold_edges = set()
        for triplet in json_dict["gold_triplets"]:
            for node in (triplet['h'], triplet['t']):
                if node not in gold_node_to_idx:
                    gold_node_to_idx[node] = len(gold_nodes)
                    gold_nodes.append(node)

            edge = (gold_node_to_idx[triplet['h']], gold_node_to_idx[triplet['t']])
            if edge in seen_gold_edges:
                continue
            seen_gold_edges.add(edge)
            gold_edge_index.append(list(edge))
            gold_edge_attr.append(self.label_map[triplet['r']])

        gold_knowledge_graph = KnowledgeGraph(
            gold_nodes,
            None,
            gold_edge_index,
            gold_edge_attr,
            None,
        )
        return user_ids, wizard_ids, knowledge_graph, gold_knowledge_graph

    def __getitem__(self, index):
        """Dispatch to the text-only or knowledge-graph builder based on ``args.knowledge``."""
        if 'text' in self.args.knowledge or\
           'none' in self.args.knowledge:
            return self.without_knowledge_graph(index)
        else:
            return self.with_knowledge_graph(index)

    def __len__(self):
        return self.total_size

class Dialprocessor(object):
    """Resolve split file names for a domain and wrap them as ``T5Dataset``s.

    Only the ``"opendialkg"`` domain is supported; any other raises
    ``NotImplementedError``. Construction also stores the dev/test file names
    on ``args`` as ``args.dev_file`` / ``args.test_file``.
    """

    def __init__(self, args):
        if args.domain != "opendialkg":
            raise NotImplementedError()

        self.train_file, self.dev_file, self.test_file, self.toy_file = (
            "train.jsonl", "valid.jsonl", "test.jsonl", "toy.jsonl")

        self.args = args
        args.dev_file = self.dev_file
        args.test_file = self.test_file

    def _load_split(self, data_dir, file_name):
        """Announce which split file is used and build a dataset over it."""
        print(f"DataProcessor: {file_name}")
        return T5Dataset(os.path.join(data_dir, file_name), args=self.args)

    def get_train_examples(self, data_dir):
        return self._load_split(data_dir, self.train_file)

    def get_dev_examples(self, data_dir):
        return self._load_split(data_dir, self.dev_file)

    def get_test_examples(self, data_dir):
        return self._load_split(data_dir, self.test_file)

def load_raw_dataset(args, fold):
    """Load every example of a split as a list of parsed JSON objects.

    Args:
        args: object whose ``data_dir`` attribute holds the split files.
        fold: one of ``"train"``, ``"dev"`` or ``"test"``.

    Returns:
        list with one parsed JSON value per line, in file order.

    Raises:
        ValueError: if ``fold`` is not a known split name. (Previously this
            surfaced as an UnboundLocalError on ``filename``.)
    """
    split_files = {
        "train": "train.jsonl",
        "dev": "valid.jsonl",
        "test": "test.jsonl",
    }
    filename = split_files.get(fold)
    if filename is None:
        raise ValueError(
            f"Unknown fold {fold!r}; expected one of {sorted(split_files)}")

    datafile = os.path.join(args.data_dir, filename)
    with open(datafile, 'r', encoding='utf-8') as f:
        return [json.loads(data) for data in f]

class Profiler(object):
    """Write human-readable per-example prediction reports.

    Loads the entity and relation codebooks from ``args.data_dir`` and builds
    inverse lookups. ``write_profile`` uses them to turn numeric edge labels
    back into relation names.
    """

    def __init__(self, args):
        def load_pickle(file_name):
            # NOTE(review): pickle can execute arbitrary code on load; data_dir
            # must only ever contain trusted files.
            with open(os.path.join(args.data_dir, file_name), 'rb') as handle:
                return pickle.load(handle)

        def invert(mapping):
            return dict((value, key) for key, value in mapping.items())

        self.entity_codebook = load_pickle("entity_codebook.pkl")
        self.reverse_entity_codebook = invert(self.entity_codebook)

        self.relation_codebook = load_pickle("relation_codebook.pkl")
        self.reverse_relation_codebook = invert(self.relation_codebook)
        self.tokenizer = args.tokenizer
        self.reverse_label_map = invert(args.label_map)

    def write_profile(self,
                      profile_fw,
                      data,
                      new_input_ids,
                      score,
                      pred_response_token,
                      graph_inputs,
                      batch_idx):
        """Append a plain-text report for one example to ``profile_fw`` and flush.

        The report contains the dialogue history, ground-truth response, fact
        score, decoded model input, prediction and the number of triplets. The
        triplets themselves are omitted because they are too long. It also
        lists the graph edges whose endpoints both belong to batch element
        ``batch_idx`` (skipped when ``graph_inputs`` is None).
        """
        headline = f"Episode {data['episode_id']}, Turn {data['turn_id']}"
        history = "HISTORY ==================\n" + '\n'.join(data['history'])
        response = "GT RESPONSE ================\n" + data["label"]
        preds = "PREDICTIONS =================\n" + pred_response_token.strip()
        num_triplets = len(data["triplets"])
        knowledges = ("KNOWLEDGE ====================\n"
                      f"# Knowledges: {num_triplets}\n")
        decoded_history = self.tokenizer.decode(new_input_ids.cpu(),
                                                skip_special_tokens=True,
                                                clean_up_tokenization_spaces=False)
        new_history = ("Selected FACT + HISTORY ============\n" + decoded_history).strip()
        score_text = f"FACT score: {score.item():.4f}"

        # Graph section: one "head<TAB>relation<TAB>tail" line per kept edge.
        graph_lines = []
        if graph_inputs is not None:
            nodes = graph_inputs["nodes"]
            edge_index = graph_inputs["edge_index"]
            edge_attr = graph_inputs["edge_attr"]
            graph_batch = graph_inputs["batch"]

            for edge, attr in zip(edge_index, edge_attr):
                if not (graph_batch[edge[0]] == batch_idx and graph_batch[edge[1]] == batch_idx):
                    continue
                head = nodes[edge[0]]
                tail = nodes[edge[1]]
                relation = self.reverse_relation_codebook[self.reverse_label_map[attr]]
                graph_lines.append(f"{head}\t{relation}\t{tail}\n")
        textual_graphs = "GRAPH =======================\n" + "".join(graph_lines)

        sections = (
            headline + '\n',
            history + '\n',
            response + '\n',
            score_text + '\n',
            new_history + '\n',
            preds + '\n',
            knowledges,
            textual_graphs + '\n',
        )
        for section in sections:
            profile_fw.write(section)
        profile_fw.flush()