import os
import pickle
import random

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

from models.layers.edge_conv import Model_HyperDrop
from utils import timer_func
from transformers import T5EncoderModel
from transformers.models.bart.modeling_bart import BartEncoder
from torch_geometric.utils import to_dense_batch

from torch_scatter import scatter
from transformers import logging

def build_lower_to_upper(args):
    """Map each lower-cased entity name to its original-cased spelling.

    Reads ``opendialkg_entities.txt`` from ``args.data_dir`` (one entity per
    line), strips surrounding whitespace from every line and returns a dict
    keyed by the lower-cased name. When two lines share a lower-cased form,
    the later line wins.
    """
    entity_path = os.path.join(args.data_dir, "opendialkg_entities.txt")
    with open(entity_path, 'r') as handle:
        names = [line.strip() for line in handle]
    return {name.lower(): name for name in names}

def build_entity_map(args):
    """Build a mapping from memory index to original-cased entity text.

    Loads ``entity_codebook.pkl`` from ``args.data_dir``. Each codebook key is
    looked up in the lower-case -> original-case table produced by
    ``build_lower_to_upper``; each codebook value is translated through
    ``args.wikidata_to_memory_map`` to obtain the key of the returned dict.

    Raises:
        KeyError: if a codebook key is missing from the entity list, or a
            codebook value is missing from ``args.wikidata_to_memory_map``.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # codebooks that come from a trusted source.
    with open(os.path.join(args.data_dir, "entity_codebook.pkl"), 'rb') as f:
        entity_codebook = pickle.load(f)
    lower_to_upper = build_lower_to_upper(args)

    memory_to_text = {}
    for lower_name, code in entity_codebook.items():
        entity_text = lower_to_upper[lower_name]
        memory_to_text[args.wikidata_to_memory_map[code]] = entity_text
    return memory_to_text

def build_relation_map(args):
    """Build both directions of the relation-text <-> memory-label mapping.

    Loads ``relation_codebook.pkl`` from ``args.data_dir`` (relation text ->
    codebook id) and translates every codebook id through ``args.label_map``.

    Returns:
        A ``(memory_to_relation, relation_to_memory)`` tuple: the first maps
        ``args.label_map[id]`` to the relation text, the second maps the
        relation text to ``args.label_map[id]``.

    Raises:
        KeyError: if a codebook id is missing from ``args.label_map``.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only load
    # codebooks that come from a trusted source.
    with open(os.path.join(args.data_dir, "relation_codebook.pkl"), 'rb') as f:
        relation_codebook = pickle.load(f)

    memory_to_relation = {}
    relation_to_memory = {}
    for relation_text, code in relation_codebook.items():
        label = args.label_map[code]
        memory_to_relation[label] = relation_text
        relation_to_memory[relation_text] = label
    return memory_to_relation, relation_to_memory

class KnowledgeSampler(nn.Module):
    """Sample knowledge-graph edges and prepend their entity text to the input.

    The module encodes the input with a frozen language-model encoder, scores
    and samples edges with ``Model_HyperDrop``, tokenizes the entities touched
    by the sampled edges and concatenates those tokens in front of the
    original ``input_ids``.
    """

    def __init__(self, args, entity_embed):
        """Create the edge scorer, the frozen encoder and the lookup tables.

        Reads ``lm_type``, ``device``, ``hidden_size``, ``knowledge_length``
        and ``tokenizer`` from ``args`` (plus whatever the ``build_*`` helpers
        and ``Model_HyperDrop`` read).

        Raises:
            ValueError: if ``args.lm_type`` is neither ``"t5"`` nor ``"bart"``.
        """
        super().__init__()
        self.edge_score_gnn = Model_HyperDrop(args, entity_embed, num_convs=1)
        if args.lm_type == "t5":
            self.encoder_model = T5EncoderModel.from_pretrained("t5-small")
        elif args.lm_type == "bart":
            self.encoder_model = BartEncoder.from_pretrained("facebook/bart-base")
        else:
            # Fail here with a clear message instead of an AttributeError on
            # the missing ``encoder_model`` a few lines below.
            raise ValueError(
                f"Unsupported lm_type {args.lm_type!r}; expected 't5' or 'bart'"
            )
        self.encoder_model.to(args.device)
        # NOTE(review): calling ``.train()`` on this module later switches the
        # encoder back to training mode as well; confirm that is intended.
        self.encoder_model.eval()

        self.token_scorer = nn.Sequential(
            nn.Linear(args.hidden_size, args.hidden_size),
            nn.ReLU(),
            nn.Linear(args.hidden_size, 1)
        )

        self.lower_to_upper = build_lower_to_upper(args)
        self.memory_to_entity = build_entity_map(args)
        self.memory_to_relation, self.relation_to_memory = build_relation_map(args)

        self.knowledge_length = args.knowledge_length
        self.tokenizer = args.tokenizer
        self.device = args.device

        # The encoder is used as a fixed feature extractor.
        for p in self.encoder_model.parameters():
            p.requires_grad = False

        # Running totals used by print_average_len().
        self.total_num = 0
        self.total_knowledge_len = 0
        self.total_input_len = 0

        self.args = args

    def print_average_len(self):
        """Print the mean padded knowledge/text length over append_text calls."""
        if self.total_num > 0:
            print(f"Average Knowledge Length: {self.total_knowledge_len / self.total_num}")
            print(f"Average Text Length: {self.total_input_len / self.total_num}")

    def reverse_relation(
        self,
        relation: int,
    ):
        """Return the memory label of the inverse of ``relation``.

        A relation whose text starts with ``"~"`` is inverted by dropping that
        prefix; any other relation is inverted by adding it. Returns ``None``
        when the inverse text has no entry in ``relation_to_memory``.

        Raises:
            KeyError: if ``relation`` is not a key of ``memory_to_relation``.
        """
        relation_text = self.memory_to_relation[relation]
        # Only a leading "~" marks an inverse relation; since the marker is
        # removed by slicing off the first character, test the prefix rather
        # than containment anywhere in the text.
        if relation_text.startswith("~"):
            reverse_relation_text = relation_text[1:]
        else:
            reverse_relation_text = "~" + relation_text

        return self.relation_to_memory.get(reverse_relation_text)

    def _sample_random_relations(self, count):
        """Draw ``count`` random relation labels for the augmented graph.

        Labels are distinct whenever enough relations exist; otherwise they
        are drawn with replacement so the result always has ``count`` items.
        """
        # random.sample requires a sequence (dict views raise TypeError on
        # Python 3.11+), so materialize the keys first.
        population = list(self.memory_to_relation)
        if count <= len(population):
            return random.sample(population, count)
        return random.choices(population, k=count)

    def append_text(
        self,
        input_ids,
        attention_mask,
        nodes,
        edge_index,
        edge_attr,
        top_edges_samples,
        unique_batch_list, # Unique batch idx where edge exists
        edge_batch,
        probs_samples,
        scores_samples,
        k,
    ):
        """Prepend the text of the sampled facts to every input sequence.

        For each batch element and each of the sampled edge sets, the entities
        touched by the selected edges are tokenized and become a "knowledge"
        prefix. Batch elements without a graph get ``k`` empty prefixes with a
        score of -10000 and a probability of 0.

        Args:
            input_ids, attention_mask: tokenized text; the first dimension is
                the batch size.
            nodes: per-node memory ids, indexed by the entries of
                ``edge_index``.
            edge_index: edge endpoints; transposed here so that row ``e``
                yields ``(head, tail)`` (assumes shape ``(2, num_edges)`` —
                TODO confirm against caller).
            edge_attr: per-edge relation label.
            top_edges_samples, probs_samples, scores_samples: parallel
                iterables, one entry per sample, holding selected edge
                indices and their probabilities / scores.
            unique_batch_list: batch indices that own at least one edge.
            edge_batch: batch index of every edge.
            k: number of samples per batch element.

        Returns:
            ``(input_ids, attention_mask, probs, scores, kfm_graph,
            augment_kfm_graph)``. The two graph dicts share nodes, edges,
            batch assignment and mention positions; the augmented one carries
            randomly drawn relation labels instead of the real ones.
        """
        _edge_index = edge_index.t()
        batch_size = input_ids.shape[0]

        knowledge_input_ids = []
        knowledge_attention_mask = []
        knowledge_scores = []
        knowledge_probs = []

        kfm_edge_index = []
        kfm_edge_attr = []
        kfm_nodes = []
        kfm_batch = []
        kfm_batch_index = 0
        kfm_mention_positions = []

        # For Graph Augmentation
        augment_kfm_edge_attr = []

        def append_empty_fact():
            _input_ids = [0]
            _attention_mask = [0]
            _mention_positions = [-1]

            knowledge_input_ids.append(_input_ids)
            knowledge_attention_mask.append(_attention_mask)
            kfm_mention_positions.append(_mention_positions)
            knowledge_scores.append(torch.tensor(-10000.0, device=self.device)) # Zero prob
            knowledge_probs.append(torch.tensor(0.0, device=self.device)) # Zero prob

        for i in range(batch_size):
            if i not in unique_batch_list: # If the batch has no graph
                for _ in range(k):
                    append_empty_fact()
                    kfm_batch_index += 1
                continue

            # One knowledge prefix per sampled edge set for the current batch element
            for top_edges, probs, scores in zip(top_edges_samples, probs_samples, scores_samples):
                # Which of the sampled edges belong to batch element i
                # (computed once, reused for edges, scores and probs).
                in_batch = edge_batch[top_edges] == i
                new_edge_idxes = top_edges[in_batch].tolist()
                # Node ids of this fact set continue after all nodes emitted so far.
                node_offset = len(kfm_nodes)

                # Collect the distinct endpoint memory ids, ordered by memory id
                memory_ids = set()
                for _idx in new_edge_idxes:
                    head, tail = _edge_index[_idx]
                    memory_ids.add(nodes[head].item())
                    memory_ids.add(nodes[tail].item())
                local_kfm_nodes = [self.memory_to_entity[_e] for _e in sorted(memory_ids)]

                # Append every node in front of the text; every token records
                # the id of the node it belongs to.
                _input_ids = []
                _mention_positions = []
                for _node_id, _node in enumerate(local_kfm_nodes):
                    _single_input_ids = self.tokenizer.convert_tokens_to_ids(
                        self.tokenizer.tokenize(_node)
                    )
                    _input_ids.extend(_single_input_ids)
                    _mention_positions.extend([_node_id + node_offset] * len(_single_input_ids))

                # Iterate over n edges to pack them into a single instance
                local_kfm_edge_index = []
                local_kfm_edge_attr = []
                for _idx in new_edge_idxes:
                    # Find corresponding ids on the newly constructed graph
                    head, tail = _edge_index[_idx]
                    head = self.memory_to_entity[nodes[head].item()]
                    tail = self.memory_to_entity[nodes[tail].item()]

                    # list.index always returns a position inside
                    # local_kfm_nodes, so the ids stay below
                    # node_offset + len(local_kfm_nodes).
                    head_node_id = local_kfm_nodes.index(head) + node_offset
                    tail_node_id = local_kfm_nodes.index(tail) + node_offset
                    relation = edge_attr[_idx].item()

                    local_kfm_edge_index.append((head_node_id, tail_node_id))
                    local_kfm_edge_attr.append(relation)

                    # Add the inverse edge when it is not present yet and the
                    # inverse relation is known.
                    if (tail_node_id, head_node_id) not in local_kfm_edge_index:
                        _rev_edge_attr = self.reverse_relation(relation)
                        if _rev_edge_attr is not None:
                            local_kfm_edge_index.append((tail_node_id, head_node_id))
                            local_kfm_edge_attr.append(_rev_edge_attr)

                # Hard Random: same edges, randomly drawn relation labels
                local_augment_kfm_edge_attr = self._sample_random_relations(len(local_kfm_edge_attr))

                kfm_edge_index.extend(local_kfm_edge_index)
                kfm_edge_attr.extend(local_kfm_edge_attr)
                augment_kfm_edge_attr.extend(local_augment_kfm_edge_attr)

                if len(_input_ids) > self.knowledge_length:
                    _input_ids = _input_ids[:self.knowledge_length]
                    # If some existing entities are truncated here, it might be error in scatter
                    _mention_positions = _mention_positions[:self.knowledge_length]
                _attention_mask = [1] * len(_input_ids)

                knowledge_input_ids.append(_input_ids)
                knowledge_attention_mask.append(_attention_mask)

                # Score of a fact set = sum of edge scores; probability = product.
                knowledge_scores.append(torch.sum(scores[in_batch]))
                knowledge_probs.append(torch.prod(probs[in_batch]))

                kfm_mention_positions.append(_mention_positions)
                kfm_batch.extend([kfm_batch_index] * len(local_kfm_nodes))
                kfm_nodes.extend(local_kfm_nodes)
                kfm_batch_index += 1

        def padding(tensor_list, padding_value=0):
            tensors = [torch.tensor(o, dtype=torch.long, device=self.device) for o in tensor_list]
            return pad_sequence(tensors, batch_first=True, padding_value=padding_value)

        new_input_ids = padding(knowledge_input_ids)
        new_attention_mask = padding(knowledge_attention_mask)
        kfm_mention_positions = padding(kfm_mention_positions, -1)

        if k > 1:
            # Repeat every input row k times so it lines up with its k prefixes.
            input_ids = input_ids.unsqueeze(1).expand(-1, k, -1)
            attention_mask = attention_mask.unsqueeze(1).expand(-1, k, -1)
            input_ids = input_ids.reshape(batch_size * k, -1)
            attention_mask = attention_mask.reshape(batch_size * k, -1)

        self.total_num += 1
        self.total_knowledge_len += new_input_ids.shape[1]
        self.total_input_len += input_ids.shape[1]

        input_ids = torch.cat([new_input_ids, input_ids], dim=-1)
        attention_mask = torch.cat([new_attention_mask, attention_mask], dim=-1)
        scores = torch.stack(knowledge_scores)
        probs = torch.stack(knowledge_probs)

        kfm_graph = {
            'nodes': kfm_nodes,
            'edge_index': kfm_edge_index,
            'edge_attr': kfm_edge_attr,
            'batch': kfm_batch,
            'mention_positions': kfm_mention_positions,
        }
        augment_kfm_graph = {
            'nodes': kfm_nodes,
            'edge_index': kfm_edge_index,
            'edge_attr': augment_kfm_edge_attr,
            'batch': kfm_batch,
            'mention_positions': kfm_mention_positions,
        }
        return input_ids, attention_mask, probs, scores, kfm_graph, augment_kfm_graph

    def encode_node(
        self,
        hidden_states,
        mention_positions,
        nodes,
        graph_batch,
        local_indicator,
    ):
        """Pool token hidden states into one embedding per mentioned node.

        Token states are mean-pooled by their ``mention_positions`` value and
        written into the rows of a zero matrix where ``local_indicator`` is
        set; all other rows stay zero. ``graph_batch`` is not used here.

        Returns:
            A tensor with ``nodes.shape[0]`` rows and the hidden size as its
            last dimension.
        """
        # Shift positions by one so that -1 lands in slot 0, which is dropped
        # right after pooling.
        scatter_states = scatter(hidden_states, mention_positions + 1, dim=1, reduce='mean')[:,1:,:]
        # Per row, keep as many pooled slots as the row's largest position + 1.
        scatter_mask = pad_sequence([torch.ones(num) for num in mention_positions.max(dim=-1).values+1], batch_first=True, padding_value=0).bool()
        scatter_states = scatter_states[scatter_mask] # Flattening

        assert (local_indicator == 1).sum() == len(scatter_states)
        all_states = torch.zeros(nodes.shape[0], scatter_states.shape[-1]).to(device=self.device)
        all_states[local_indicator.bool()] = scatter_states

        assert len(all_states) == len(nodes)
        return all_states

    def encode_sentence(
        self,
        hidden_states,
        attention_mask
    ):
        """Attention-pool token states into one embedding per sequence.

        ``token_scorer`` assigns a scalar to every token, positions where
        ``attention_mask`` is 0 are pushed down by -10000 before the softmax,
        and the resulting weights average ``hidden_states`` over dim 1.
        """
        # Drop only the trailing singleton produced by the scorer; a bare
        # squeeze() would also collapse a batch or sequence dimension of 1.
        score = self.token_scorer(hidden_states).squeeze(-1)
        attention_mask = (1.0 - attention_mask) * -10000.0
        probs = torch.softmax(score + attention_mask, dim=-1).unsqueeze(-1)
        embeddings = torch.sum(probs * hidden_states, dim=1)
        return embeddings

    def forward(
        self,
        input_ids,
        attention_mask,
        mention_positions,
        nodes,
        edge_index,
        edge_attr,
        graph_batch,
        local_indicator,
        k=None,
    ):
        """Encode the input, sample edges and build the knowledge-augmented input.

        Returns:
            ``(input_ids, attention_mask, probs, scores, kfm_graph,
            augment_kfm_graph, edge_batch, full_scores)`` where the first six
            items come from ``append_text`` and the last two from the edge
            scorer.

        NOTE(review): ``append_text`` uses ``range(k)`` and ``k > 1``, so the
        default ``k=None`` does not work there; callers presumably always
        pass ``k`` — confirm.
        """
        # detach(): the encoder output is treated as a constant feature.
        hidden_states = self.encoder_model(input_ids, attention_mask).last_hidden_state.detach()
        node_embeds = self.encode_node(hidden_states,
                                       mention_positions,
                                       nodes,
                                       graph_batch,
                                       local_indicator)
        # Pool the token states into one embedding per sequence
        embeddings = self.encode_sentence(hidden_states, attention_mask)

        # Sample edges
        ## top_edges: List[Tensor]
        ## probs: List[Tensor]
        ## scores: List[Tensor]
        top_edges, edge_batch, unique_batch_list, \
        probs, scores, full_scores = self.edge_score_gnn(
            node_embeds,
            nodes,
            edge_index,
            edge_attr,
            graph_batch,
            embeddings,
            k=k,
        )

        input_ids, attention_mask, probs, scores, kfm_graph, augment_kfm_graph = self.append_text(
            input_ids,
            attention_mask,
            nodes,
            edge_index,
            edge_attr,
            top_edges,
            unique_batch_list,
            edge_batch,
            probs,
            scores,
            k=k,
        )

        return input_ids, attention_mask, probs, scores, \
               kfm_graph, augment_kfm_graph, edge_batch, full_scores

    def expand_decoder_input(
        self,
        decoder_input_ids,
        decoder_attention_mask,
        k
    ):
        """Repeat every decoder row ``k`` times, keeping copies adjacent.

        The output has ``batch_size * k`` rows, laid out the same way
        ``append_text`` repeats the encoder inputs.
        """
        batch_size = decoder_input_ids.shape[0]
        decoder_input_ids = decoder_input_ids.unsqueeze(1).expand(-1, k, -1).reshape(batch_size * k, -1)
        decoder_attention_mask = decoder_attention_mask.unsqueeze(1).expand(-1, k, -1).reshape(batch_size * k, -1)
        return decoder_input_ids, decoder_attention_mask