from __future__ import annotations
import numpy as np
from collections import defaultdict
from torchkge.utils.datasets import load_fb15k237, load_wn18rr
from torchkge.data_structures import KnowledgeGraph
import torch.nn.functional as F
import torch
import pandas as pd
from tqdm.autonotebook import tqdm
from torch.utils.data import DataLoader
from ogb.linkproppred import LinkPropPredDataset
from collections import defaultdict
from typing import Optional, List
from torchkge.data_structures import KnowledgeGraph
import numpy as np
from scipy.stats import entropy
from collections import Counter
import matplotlib.pyplot as plt

from src.models import GenerativeModel
def get_relation_frequencies(kg_train: KnowledgeGraph) -> torch.Tensor:
    """Count how many facts of ``kg_train`` use each relation.

    Parameters
    ----------
    kg_train: KnowledgeGraph
        Graph whose first ``n_facts`` entries of ``relations`` are tallied.

    Returns
    -------
    torch.Tensor
        ``torch.int`` (int32) tensor of length ``kg_train.n_rel``; entry ``r``
        is the number of facts whose relation id is ``r``.
    """
    observed = Counter(
        int(rel_id) for rel_id in kg_train.relations[:kg_train.n_facts].tolist()
    )

    counts = torch.zeros(kg_train.n_rel, dtype=torch.int)
    for rel_id, occurrences in observed.items():
        # `+=` (not `=`) so ids that alias the same slot (e.g. -1 and n_rel-1)
        # accumulate exactly like per-fact increments would.
        counts[rel_id] += occurrences
    return counts

def get_prior_frequencies(kg_train: KnowledgeGraph) -> torch.Tensor:
    """Count how often each entity appears in a fact, as head or as tail.

    A self-loop fact (head == tail) contributes two to that entity.

    Parameters
    ----------
    kg_train: KnowledgeGraph
        Graph whose first ``n_facts`` head/tail indices are tallied.

    Returns
    -------
    torch.Tensor
        ``torch.int`` (int32) tensor of length ``kg_train.n_ent``.
    """
    n_facts = kg_train.n_facts
    heads = kg_train.head_idx[:n_facts].tolist()
    tails = kg_train.tail_idx[:n_facts].tolist()

    entity_counts = torch.zeros(kg_train.n_ent, dtype=torch.int)
    for head_id, tail_id in zip(heads, tails):
        # Both endpoints of a fact count towards the entity's prior.
        entity_counts[int(head_id)] += 1
        entity_counts[int(tail_id)] += 1
    return entity_counts
def get_avg_indegree_per_relation(kg_train: KnowledgeGraph) -> dict[int, float]:
    """Compute, for each relation, the average in-degree of its tail entities.

    For relation ``r`` this is ``(#facts with relation r) / (#distinct tails
    of r)``: the mean number of incoming ``r``-edges per entity that receives
    at least one ``r``-edge.

    Parameters
    ----------
    kg_train: KnowledgeGraph
        Graph whose first ``n_facts`` relations / tail indices are used.

    Returns
    -------
    dict[int, float]
        Mapping from relation id to its average in-degree. Only relations that
        occur in ``kg_train`` are present.
    """
    fact_counts: dict[int, int] = defaultdict(int)
    distinct_tails: dict[int, set[int]] = defaultdict(set)

    n_facts = kg_train.n_facts
    relations = kg_train.relations[:n_facts].tolist()
    tails = kg_train.tail_idx[:n_facts].tolist()
    for relation, tail in zip(relations, tails):
        relation, tail = int(relation), int(tail)
        fact_counts[relation] += 1
        # The in-degree denominator is the set of *tail entities* reached by
        # this relation, not the relation id itself.
        distinct_tails[relation].add(tail)

    # Keyed by the actual relation id (not by insertion order).
    return {
        relation: count / len(distinct_tails[relation])
        for relation, count in fact_counts.items()
    }


def get_avg_indegree_per_node(kg_train: KnowledgeGraph) -> torch.Tensor:
    """Return the mean in-degree over all ``kg_train.n_ent`` entities.

    Every fact adds one incoming edge to its tail entity; entities with no
    incoming edge count as in-degree 0 in the average.

    Parameters
    ----------
    kg_train: KnowledgeGraph
        Graph whose first ``n_facts`` tail indices are used.

    Returns
    -------
    torch.Tensor
        0-d floating-point tensor (integer sum divided by ``n_ent``).
    """
    tails = kg_train.tail_idx[:kg_train.n_facts].long()
    # One vectorised histogram replaces a per-fact Python loop of scalar
    # tensor updates; its total is the number of incoming edges.
    in_degrees = torch.bincount(tails, minlength=kg_train.n_ent)
    return torch.sum(in_degrees) / kg_train.n_ent


def count_parameters(model: GenerativeModel, *, trainable_only: bool = False) -> int:
    """Return the number of scalar parameters in ``model``.

    Parameters
    ----------
    model: GenerativeModel
        Any object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
    trainable_only: bool, default False
        When True, count only parameters with ``requires_grad`` set.

    Returns
    -------
    int
        Total element count (``numel``) over the selected parameters.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


def get_dict_of_tails_and_heads(kg_list: List[KnowledgeGraph],
                                existing_tails: Optional[dict] = None,
                                existing_heads: Optional[dict] = None) -> tuple[dict, dict]:
    """Index known tails per (head, relation) and known heads per (tail, relation).

    Parameters
    ----------
    kg_list: list of KnowledgeGraph
        Graphs whose first ``n_facts`` facts are indexed.
    existing_tails, existing_heads: dict, optional
        Mappings to extend in place. Either a plain ``dict`` or a
        ``defaultdict(set)`` is accepted; new keys are created as needed.
        When omitted, fresh ``defaultdict(set)`` instances are used, so
        lookups of unseen keys on the returned mappings yield an empty set.

    Returns
    -------
    (dict_of_tails, dict_of_heads)
        ``dict_of_tails[(h, r)]`` is the set of tails ``t`` with ``(h, r, t)``;
        ``dict_of_heads[(t, r)]`` is the set of heads ``h`` with ``(h, r, t)``.
    """
    dict_of_tails = defaultdict(set) if existing_tails is None else existing_tails
    dict_of_heads = defaultdict(set) if existing_heads is None else existing_heads

    for kg in kg_list:
        n_facts = kg.n_facts
        # Convert each column once instead of calling .item() per element.
        heads = kg.head_idx[:n_facts].tolist()
        relations = kg.relations[:n_facts].tolist()
        tails = kg.tail_idx[:n_facts].tolist()
        for head, relation, tail in zip(heads, relations, tails):
            # setdefault works for plain dicts as well as defaultdicts; plain
            # `d[key].add(...)` raises KeyError on a plain dict.
            dict_of_tails.setdefault((head, relation), set()).add(tail)
            dict_of_heads.setdefault((tail, relation), set()).add(head)

    return dict_of_tails, dict_of_heads

def get_true_targets_batch(scores, dictionary, key1, key2):
    """Build a mask marking every known true target for each row of ``scores``.

    Parameters
    ----------
    scores: torch.Tensor
        Only its shape, dtype and device are used; row ``i`` corresponds to
        the query ``(key1[i], key2[i])``. Presumably (batch, n_entities) —
        confirm with callers.
    dictionary: Mapping
        Maps ``(key1, key2)`` tuples to a collection of target column
        indices (plain dict or defaultdict). It is not modified.
    key1, key2: torch.Tensor
        1-D tensors with at least ``scores.shape[0]`` elements.

    Returns
    -------
    torch.Tensor
        Same shape/dtype/device as ``scores``: 1 at known targets, 0 elsewhere.
    """
    labels = torch.zeros_like(scores)
    b_size = scores.shape[0]
    queries = zip(key1[:b_size].tolist(), key2[:b_size].tolist())
    for row, query in enumerate(queries):
        # .get() does not insert empty sets into a defaultdict for unseen
        # queries, and no copy is needed because the targets are only read.
        true_targets = dictionary.get(query)
        if not true_targets:
            continue
        target_idx = torch.tensor(list(true_targets), dtype=torch.long, device=scores.device)
        labels[row, target_idx] = 1
    return labels


def subset_to_dataframe(subset):
    """Convert a dataset of triples into a ``from``/``rel``/``to`` DataFrame.

    Each item is read positionally: element 0 becomes ``from``, element 1
    becomes ``to`` and element 2 becomes ``rel``. Items are fetched one at a
    time through a default ``DataLoader`` (batch size 1), so each must be a
    tensor-like triple that the default collate can stack.
    """
    # Drop the leading batch dimension added by the DataLoader.
    triples = [batch.squeeze(0) for batch in DataLoader(subset)]
    return pd.DataFrame({
        'from': [triple[0].item() for triple in triples],
        'rel': [triple[2].item() for triple in triples],
        'to': [triple[1].item() for triple in triples],
    })
def load_nbf_mapping(dataset_name, data_path):
    """Load FB15k-237 or WN18RR through torchdrug and repackage it for torchkge.

    The dataset is read with torchdrug (``datasets.FB15k237`` or
    ``datasets.WN18RR``), and entity/relation ids are taken from its
    ``entity_vocab`` / ``relation_vocab``. The train/valid/test subsets
    returned by ``dataset.split()`` are concatenated into one torchkge
    ``KnowledgeGraph``, which is then split back using the three subset
    sizes.

    Parameters
    ----------
    dataset_name: str
        ``"fb15k237"`` or ``"wn18rr"``.
    data_path: str
        Directory passed to the torchdrug dataset constructor.

    Returns
    -------
    The result of ``KnowledgeGraph.split_kg(sizes=(n_train, n_valid, n_test))``;
    callers treat it as ``(kg_train, kg_val, kg_test)``.

    Raises
    ------
    ValueError
        If ``dataset_name`` is not supported (checked before importing torchdrug).
    """
    if dataset_name not in ("fb15k237", "wn18rr"):
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # Local import: torchdrug is only required when this loader is used.
    from torchdrug import datasets

    if dataset_name == "fb15k237":
        dataset = datasets.FB15k237(data_path)
    else:
        dataset = datasets.WN18RR(data_path)

    train_subset, valid_subset, test_subset = dataset.split()
    train_df = subset_to_dataframe(train_subset)
    valid_df = subset_to_dataframe(valid_subset)
    test_df = subset_to_dataframe(test_subset)
    # Keep train/valid/test order: split_kg below receives the sizes in this
    # same order.
    df = pd.concat([train_df, valid_df, test_df], ignore_index=True)

    entity_dict = {string: int(i) for i, string in enumerate(dataset.entity_vocab)}
    relation_dict = {string: int(i) for i, string in enumerate(dataset.relation_vocab)}
    kg_nbf = {
        'heads': torch.tensor(df['from'].values),
        'relations': torch.tensor(df['rel'].values),
        'tails': torch.tensor(df['to'].values),
    }
    kg = KnowledgeGraph(kg=kg_nbf, ent2ix=entity_dict, rel2ix=relation_dict)

    return kg.split_kg(sizes=(len(train_df), len(valid_df), len(test_df)))

def load_dataset(config) -> tuple[KnowledgeGraph, KnowledgeGraph, KnowledgeGraph]:
    """Load the (train, valid, test) knowledge graphs described by ``config``.

    Reads ``config['dataset']['class']`` (case-insensitive; ``"fb15k237"`` or
    ``"wn18rr"``), ``config['model_type']`` and ``config['dataset']['path']``.
    When ``model_type == 'nbf'`` the graphs come from ``load_nbf_mapping``;
    otherwise from torchkge's ``load_fb15k237`` / ``load_wn18rr``.

    Raises
    ------
    ValueError
        If the dataset class is not supported. ``ValueError`` subclasses
        ``Exception``, so existing ``except Exception`` handlers still apply.
    """

    def assert_three(result: tuple[KnowledgeGraph, KnowledgeGraph, KnowledgeGraph] | tuple[KnowledgeGraph, KnowledgeGraph]
                     ) -> tuple[KnowledgeGraph, KnowledgeGraph, KnowledgeGraph]:
        # Narrows the loaders' 2-or-3-tuple return type to the 3-split case.
        assert len(result) == 3
        return result

    dataset_class = config['dataset']['class'].lower()
    if dataset_class not in ("fb15k237", "wn18rr"):
        raise ValueError(f"Dataset unknown: {dataset_class}")

    use_nbf = config['model_type'] == 'nbf'
    data_path = config['dataset']['path']
    if use_nbf:
        return assert_three(load_nbf_mapping(dataset_class, data_path))
    if dataset_class == "fb15k237":
        return assert_three(load_fb15k237(data_home=data_path))
    return assert_three(load_wn18rr(data_home=data_path))

def _index_triples_across_type(triples, entity_dict, type_dict):
    """Shift per-type local entity ids of one OGB split into global ids, in place.

    For each side (``head``/``tail``) this adds ``entity_dict[type][0]`` to
    the ids (and to the ``<side>_neg`` rows when present) and stores the
    integer type id in ``<side>_type_idx``. Returns the same ``triples`` dict.

    Raises
    ------
    KeyError
        If a triple's entity type is missing from ``type_dict``.
    """
    for side in ('head', 'tail'):
        side_types = np.asarray(triples[f'{side}_type'])
        unknown = set(np.unique(side_types).tolist()) - set(type_dict)
        if unknown:
            raise KeyError(f"Unknown {side} entity type(s): {sorted(unknown)}")

        type_idx = np.zeros_like(triples[side])
        neg_key = f'{side}_neg'
        # One boolean mask per entity type (few types) instead of a Python
        # loop over every triple.
        for type_name, type_id in type_dict.items():
            mask = side_types == type_name
            offset = entity_dict[type_name][0]
            type_idx[mask] = type_id
            triples[side][mask] += offset
            if neg_key in triples:
                triples[neg_key][mask] += offset
        triples[f'{side}_type_idx'] = type_idx
    return triples


def _stack_int64(*parts):
    """Concatenate 1-D arrays (as single columns) and 2-D blocks column-wise as int64."""
    columns = [
        part if part.ndim == 2 else part.reshape(-1, 1)
        for part in map(np.asarray, parts)
    ]
    return np.concatenate(columns, axis=1).astype(np.int64, copy=True)


def preprocess_ogbl_dataset(name: str, path: str, ds_out_path: str):
    """Load an OGB link-prediction KG and build integer triple arrays.

    Only ``ogbl-biokg`` is supported: its entities are typed, so per-type ids
    are shifted into a single global id space (types laid out in
    ``num_nodes_dict`` order) before the arrays are assembled.

    Parameters
    ----------
    name: str
        OGB dataset name; must be ``"ogbl-biokg"``.
    path: str
        Root directory passed to ``LinkPropPredDataset``.
    ds_out_path: str
        Currently unused; kept for API compatibility.

    Returns
    -------
    (triples, n_entities, n_relations)
        ``triples`` maps ``'train'``/``'valid'``/``'test'`` to int64 arrays.
        Train columns: ``[head, relation, tail, head_type, tail_type]``.
        Valid/test columns: ``[head, relation, tail, head_neg..., tail_neg...,
        head_type, tail_type]``.

    Raises
    ------
    ValueError
        If ``name`` is unsupported (raised before any download), or a training
        entity id falls outside ``[0, n_entities)``.
    """
    if name != 'ogbl-biokg':
        # Entity offsets, n_entities and the extra type columns are only
        # defined for the typed ogbl-biokg graph.
        raise ValueError(f"Unsupported OGBL dataset: {name} (only 'ogbl-biokg' is supported)")

    print("Preprocessing OGBL dataset...")
    dataset = LinkPropPredDataset(name, root=path)
    split_edge = dataset.get_edge_split()
    train_triples, valid_triples, test_triples = split_edge["train"], split_edge["valid"], split_edge["test"]

    num_nodes_dict = dataset[0]['num_nodes_dict']
    type_dict, entity_dict = {}, {}
    cur_idx = 0
    for type_id, (type_name, n_nodes) in enumerate(num_nodes_dict.items()):
        type_dict[type_name] = type_id
        # Half-open global id range [start, end) for this entity type.
        entity_dict[type_name] = (cur_idx, cur_idx + n_nodes)
        cur_idx += n_nodes
    n_entities = cur_idx

    print('Indexing triples across different entity types ...')
    for split in (train_triples, valid_triples, test_triples):
        _index_triples_across_type(split, entity_dict, type_dict)

    n_relations = int(np.max(train_triples['relation'])) + 1
    for side in ('head', 'tail'):
        # Valid global ids are 0 .. n_entities - 1.
        if train_triples[side].max() >= n_entities:
            raise ValueError(f"Training {side} id exceeds the {n_entities} known entities")

    print(f"{n_entities} entities and {n_relations} relations")

    train_keys = ('head', 'relation', 'tail', 'head_type_idx', 'tail_type_idx')
    eval_keys = ('head', 'relation', 'tail', 'head_neg', 'tail_neg',
                 'head_type_idx', 'tail_type_idx')
    triples = {
        'train': _stack_int64(*(train_triples[k] for k in train_keys)),
        'valid': _stack_int64(*(valid_triples[k] for k in eval_keys)),
        'test': _stack_int64(*(test_triples[k] for k in eval_keys)),
    }

    return triples, n_entities, n_relations



def calculate_distribution(subjects):
    """Return the empirical probability of each distinct value in ``subjects``.

    Keys appear in first-seen order. Example: ``[1, 1, 2]`` gives
    ``{1: 2/3, 2: 1/3}``; an empty input gives ``{}``.
    """
    tally = Counter(subjects)
    n_observations = sum(tally.values())
    distribution = {}
    for value, occurrences in tally.items():
        distribution[value] = occurrences / n_observations
    return distribution

def kl_divergence(p, q):
    """Return KL(p || q) for two discrete distributions given as dicts.

    Both mappings are aligned on the union of their keys. A key missing from
    ``p`` gets weight 0; a key missing from ``q`` gets ``1e-10`` so the
    divergence stays finite. Each weight vector is renormalised to sum to 1
    before ``scipy.stats.entropy`` computes the divergence with its default
    base.
    """
    support = set(p.keys()) | set(q.keys())

    def _aligned(dist, fill):
        # Same iteration over `support` for both vectors keeps them aligned.
        weights = np.array([dist.get(key, fill) for key in support])
        return weights / np.sum(weights)

    return entropy(_aligned(p, 0), qk=_aligned(q, 1e-10))

def plot_distributions(train_dist, test_dist, save_path, top_k=100):
    """Save a grouped bar chart of the subjects whose train/test probabilities differ most.

    Parameters
    ----------
    train_dist, test_dist: dict
        Mappings from subject id to probability; missing keys count as 0.
    save_path: str or path-like
        Destination passed to ``Figure.savefig``.
    top_k: int, default 100
        Number of subjects to plot, ordered by largest absolute difference.
    """
    all_keys = set(train_dist.keys()) | set(test_dist.keys())

    # Absolute train/test difference per subject, largest first.
    diff = {k: abs(train_dist.get(k, 0) - test_dist.get(k, 0)) for k in all_keys}
    sorted_keys = sorted(diff, key=diff.get, reverse=True)[:top_k]

    train_values = [train_dist.get(k, 0) for k in sorted_keys]
    test_values = [test_dist.get(k, 0) for k in sorted_keys]
    positions = np.arange(len(sorted_keys))

    fig, ax = plt.subplots(figsize=(15, 8))
    try:
        # Offset the two series by ±0.2 so bars sit side by side.
        ax.bar(positions - 0.2, train_values, width=0.4, alpha=0.6, label='Train')
        ax.bar(positions + 0.2, test_values, width=0.4, alpha=0.6, label='Test')
        ax.set_xlabel(f'Top {top_k} Subject Entity IDs with Largest Differences')
        ax.set_ylabel('Probability')
        ax.set_title(f'Distribution of Top {top_k} Subjects with Largest Differences in Train vs Test')
        ax.legend()
        ax.set_xticks(positions)
        ax.set_xticklabels(sorted_keys, rotation=90)
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
