# -*- coding: utf-8 -*-
"""
Copyright TorchKGE developers
@author: Armand Boschin <aboschin@enst.fr>
"""

from typing import Optional, Union, cast

import torch
from torchkge.exceptions import NotYetEvaluatedError
from torchkge.utils import filter_scores, get_rank 
from tqdm.autonotebook import tqdm
import numpy as np
import os
import pickle

from src.utils import  get_true_targets_batch
from src.dataset import KGEDataset, OgbKGEDatasetNBF, OgbKGEDataset

class LinkPredictionEvaluator:
    def __init__(self, config, device, model, dataset: Union[KGEDataset, OgbKGEDatasetNBF, OgbKGEDataset], batch_size: int, split: str = 'test'):
        """Prepare an evaluator for one split of ``dataset``.

        Parameters
        ----------
        config:
            Nested mapping. ``config['dataset']['path']`` is read here; other
            methods of this class also read ``config['model_type']`` and
            ``config['dataset']['class']``.
        device:
            Device on which the rank buffers are allocated and to which
            evaluation batches are moved.
        model:
            Model to evaluate. It is switched to eval mode.
        dataset:
            Dataset object exposing ``kg_test`` / ``kg_val`` and
            ``get_dataloader(batch_size=..., split=...)``.
        batch_size: int
            Batch size passed to ``dataset.get_dataloader``; it is also used
            to compute each batch's offset into the rank buffers.
        split: str
            Either ``'test'`` or ``'valid'``.

        Raises
        ------
        ValueError
            If ``split`` is neither ``'test'`` nor ``'valid'``.
        """
        # Validate first so that an invalid split has no side effects
        # (e.g. switching the model to eval mode).
        if split not in ('test', 'valid'):
            raise ValueError(f"Invalid split: {split}. Must be 'test' or 'valid'.")

        self.config = config
        self.device = device
        self.model = model
        self.model.eval()
        self.dataset = dataset
        self.split = split
        self.batch_size = batch_size

        # NOTE(review): the "ogbl_biokg" sub-directory is hard-coded, so the
        # score cache lands there whatever dataset is evaluated - confirm
        # this is intended.
        self.processed_dir = os.path.join(config['dataset']['path'], "ogbl_biokg", 'processed')

        self.kg_eval = dataset.kg_test if split == 'test' else dataset.kg_val

        self.dataloader = self.dataset.get_dataloader(batch_size=batch_size, split=split)

        # One slot per fact served by the dataloader.
        self.n_facts = len(self.dataloader.dataset)

        # Rank buffers, allocated directly with their final dtype and device.
        self.rank_true_heads = torch.zeros(self.n_facts, dtype=torch.long, device=device)
        self.rank_true_tails = torch.zeros(self.n_facts, dtype=torch.long, device=device)
        self.filt_rank_true_heads = torch.zeros(self.n_facts, dtype=torch.long, device=device)
        self.filt_rank_true_tails = torch.zeros(self.n_facts, dtype=torch.long, device=device)
        # Buffer written by get_scores_and_ranks (score of the true tail of
        # each fact); it was previously never allocated.
        self.probabilities = torch.zeros(self.n_facts, dtype=torch.float, device=device)

        self.evaluated = False
        self.is_typed_dataset = isinstance(dataset, (OgbKGEDataset, OgbKGEDatasetNBF))
        # Triples already emitted by filter_duplicate_triples.
        self.seen_triples = set()


    def filter_duplicate_triples(self, h_idx, r_idx, t_idx, filt_scores, y_batch, is_tail_prediction):
        # Move data to CPU for filtering
        h_idx_cpu = h_idx.cpu().numpy()
        r_idx_cpu = r_idx.cpu().numpy()
        t_idx_cpu = t_idx.cpu().numpy()
        filt_scores_cpu = filt_scores.cpu().numpy()
        y_batch_cpu = y_batch.cpu().numpy()

        y_pred = []
        y_true = []

        for j in range(len(h_idx_cpu)):
            h, r, t = h_idx_cpu[j], r_idx_cpu[j], t_idx_cpu[j]

            valid_mask = filt_scores_cpu[j] != float('-inf')
            candidates = np.arange(self.dataset.n_entities)[valid_mask]
            
            new_triples = set((h, r, c) if is_tail_prediction else (c, r, t) for c in candidates) - self.seen_triples
            self.seen_triples.update(new_triples)
            
            indices = [c for _, _, c in new_triples] if is_tail_prediction else [c for c, _, _ in new_triples]
        
            scores = filt_scores_cpu[j][indices]
            labels = y_batch_cpu[j][indices]
        
            y_pred.extend(scores)
            y_true.extend(labels)

        return y_pred, y_true

    def get_scores_and_labels_test_facts(self):
        """Score every candidate of every evaluation fact and collect labels.

        Dispatches to ``get_scores_and_labels_test_facts_ogb_nbf`` when
        ``config['model_type']`` is ``'nbf'`` and ``config['dataset']['class']``
        is ``'ogblbiokg'`` (case-insensitive). Otherwise, for each batch, tail
        and head scores are computed with the model, filtered, ranked (filling
        the rank buffers) and de-duplicated with ``filter_duplicate_triples``.

        The result is cached as a pickle in ``self.processed_dir``. The cache
        file name depends only on ``config['model_type']``. When the cache is
        hit, the rank buffers are NOT filled and ``self.evaluated`` is left
        unchanged.

        Returns
        -------
        y_pred: numpy.ndarray
            float32 scores of the de-duplicated candidate triples.
        y_true: numpy.ndarray
            int64 labels aligned with ``y_pred``.
        """
        if self.config['model_type'] == 'nbf' and self.config['dataset']['class'].lower() == 'ogblbiokg':
            return self.get_scores_and_labels_test_facts_ogb_nbf()

        cache_file = os.path.join(self.processed_dir, f'{self.config["model_type"]}_scores_labels.pkl')

        if os.path.exists(cache_file):
            print("Loading cached scores and labels...")
            # NOTE: pickle must only be loaded from a trusted cache directory.
            with open(cache_file, 'rb') as f:
                return pickle.load(f)

        y_pred = []
        y_true = []
        batch_size = self.batch_size
        # Accept both torch.device objects and plain strings such as "cuda:0".
        use_cuda = torch.device(self.device).type == "cuda"

        # Pure inference: skip graph construction so that scores can be turned
        # into numpy arrays and memory use stays flat.
        with torch.no_grad():
            for i, batch in tqdm(enumerate(self.dataloader), total=len(self.dataloader)):
                h_idx, r_idx, t_idx = batch[0].to(self.device), batch[1].to(self.device), batch[2].to(self.device)
                start_idx = i * batch_size
                end_idx = min((i + 1) * batch_size, self.n_facts)

                # Tail prediction
                y_score_t = self.model.inference_tail_prediction(h_idx, r_idx, t_idx)
                if self.is_typed_dataset:
                    filt_scores_t = self.dataset.filter_invalid_triples(y_score_t, t_idx, 'tail', self.split, start_idx, end_idx)
                else:
                    filt_scores_t = filter_scores(y_score_t, self.dataset.dict_of_tails_train_val, h_idx, r_idx, None)

                # presumably a per-candidate label matrix built from the test
                # facts; verify against src.utils.get_true_targets_batch
                y_batch_t = get_true_targets_batch(filt_scores_t, self.dataset.dict_of_tails_test, h_idx, r_idx)

                self.rank_true_tails[start_idx:end_idx] = get_rank(y_score_t, t_idx).detach()
                self.filt_rank_true_tails[start_idx:end_idx] = get_rank(filt_scores_t, t_idx).detach()

                # Head prediction
                y_score_h = self.model.inference_head_prediction(h_idx, r_idx, t_idx)
                if self.is_typed_dataset:
                    filt_scores_h = self.dataset.filter_invalid_triples(y_score_h, h_idx, 'head', self.split, start_idx, end_idx)
                else:
                    filt_scores_h = filter_scores(y_score_h, self.dataset.dict_of_heads_train_val, t_idx, r_idx, None)

                y_batch_h = get_true_targets_batch(filt_scores_h, self.dataset.dict_of_heads_test, t_idx, r_idx)

                self.rank_true_heads[start_idx:end_idx] = get_rank(y_score_h, h_idx).detach()
                self.filt_rank_true_heads[start_idx:end_idx] = get_rank(filt_scores_h, h_idx).detach()

                # Keep each candidate triple only once (tail side first).
                tail_pred, tail_true = self.filter_duplicate_triples(h_idx, r_idx, t_idx,
                                                                     filt_scores_t,
                                                                     y_batch_t,
                                                                     is_tail_prediction=True)
                y_pred.extend(tail_pred)
                y_true.extend(tail_true)

                head_pred, head_true = self.filter_duplicate_triples(h_idx, r_idx, t_idx,
                                                                     filt_scores_h,
                                                                     y_batch_h,
                                                                     is_tail_prediction=False)
                y_pred.extend(head_pred)
                y_true.extend(head_true)

        y_pred = np.array(y_pred, dtype=np.float32)
        y_true = np.array(y_true, dtype=np.int64)

        self.evaluated = True

        if use_cuda:
            self.rank_true_heads = self.rank_true_heads.cpu()
            self.rank_true_tails = self.rank_true_tails.cpu()
            self.filt_rank_true_heads = self.filt_rank_true_heads.cpu()
            self.filt_rank_true_tails = self.filt_rank_true_tails.cpu()

        # Make sure the cache directory exists so the computed results are
        # not lost to a FileNotFoundError at the very end.
        os.makedirs(self.processed_dir, exist_ok=True)
        with open(cache_file, 'wb') as f:
            pickle.dump((y_pred, y_true), f)

        print("Scores and labels computed and cached.")
        return y_pred, y_true
        
    def get_scores_and_labels_test_facts_ogb_nbf(self):
        """NBF variant of ``get_scores_and_labels_test_facts``.

        Instead of scoring every entity, only the candidates left valid by
        filtering (score different from ``-inf``) are scored, through
        ``self.model.nbf_model.predict``. Triples are passed to ``predict`` as
        ``(head, tail, relation)`` columns, shaped
        ``(rows, candidates_per_row, 3)``.

        The result is cached as a pickle in ``self.processed_dir``. When the
        cache is hit, the rank buffers are not filled and ``self.evaluated``
        is left unchanged.

        Returns
        -------
        y_pred: numpy.ndarray
            float32 scores of the de-duplicated candidate triples.
        y_true: numpy.ndarray
            int64 labels aligned with ``y_pred``.

        Raises
        ------
        ValueError
            If the rows of a batch do not all have the same number of valid
            candidates (they could not be stacked into one tensor).
        """
        y_pred = []
        y_true = []
        batch_size = self.batch_size
        cache_file = os.path.join(self.processed_dir, f'{self.config["model_type"]}_scores_labels.pkl')

        if os.path.exists(cache_file):
            print("Loading cached scores and labels...")
            # NOTE: pickle must only be loaded from a trusted cache directory.
            with open(cache_file, 'rb') as f:
                return pickle.load(f)

        n_entities = self.dataset.n_entities
        # Accept both torch.device objects and plain strings such as "cuda:0".
        use_cuda = torch.device(self.device).type == "cuda"

        def score_valid_candidates(filt_scores, heads, tails, relations):
            """Write NBF scores into the non-``-inf`` cells of ``filt_scores``."""
            n_rows = filt_scores.size(0)
            if n_rows == 0:
                return
            valid_mask = filt_scores != float('-inf')
            per_row = valid_mask.sum(dim=1)
            if not bool((per_row == per_row[0]).all()):
                raise ValueError(
                    "Every row must have the same number of valid candidates "
                    "to be scored as a single batch.")
            triples = torch.stack(
                [heads[valid_mask], tails[valid_mask], relations[valid_mask]], dim=1)
            # Shape derived from the actual batch instead of the former
            # hard-coded (batch_size, 501, 3), which failed on the last,
            # smaller batch.
            scores = self.model.nbf_model.predict(
                triples.reshape(n_rows, int(per_row[0]), 3))
            filt_scores[valid_mask] = scores.reshape(-1)

        with torch.no_grad():
            for i, batch in tqdm(enumerate(self.dataloader), total=len(self.dataloader)):
                h_idx, r_idx, t_idx = batch[0].to(self.device), batch[1].to(self.device), batch[2].to(self.device)
                start_idx = i * batch_size
                end_idx = min((i + 1) * batch_size, self.n_facts)
                n_rows = len(h_idx)

                candidates = torch.arange(n_entities, device=self.device).expand(n_rows, -1)
                h_repeat = h_idx.unsqueeze(1).expand(-1, n_entities)
                r_repeat = r_idx.unsqueeze(1).expand(-1, n_entities)
                t_repeat = t_idx.unsqueeze(1).expand(-1, n_entities)

                # Tail prediction: start from zeros, let filtering mark the
                # excluded candidates, then score the remaining ones.
                y_score_t = torch.zeros(size=(n_rows, n_entities), device=self.device)
                if self.is_typed_dataset:
                    filt_scores_t = self.dataset.filter_invalid_triples(y_score_t, t_idx, 'tail', self.split, start_idx, end_idx)
                else:
                    filt_scores_t = filter_scores(y_score_t, self.dataset.dict_of_tails_train_val, h_idx, r_idx, None)

                y_batch_t = get_true_targets_batch(filt_scores_t, self.dataset.dict_of_tails_test, h_idx, r_idx)
                score_valid_candidates(filt_scores_t, h_repeat, candidates, r_repeat)

                # Head prediction
                y_score_h = torch.zeros(size=(n_rows, n_entities), device=self.device)
                if self.is_typed_dataset:
                    filt_scores_h = self.dataset.filter_invalid_triples(y_score_h, h_idx, 'head', self.split, start_idx, end_idx)
                else:
                    filt_scores_h = filter_scores(y_score_h, self.dataset.dict_of_heads_train_val, t_idx, r_idx, None)

                y_batch_h = get_true_targets_batch(filt_scores_h, self.dataset.dict_of_heads_test, t_idx, r_idx)
                score_valid_candidates(filt_scores_h, candidates, t_repeat, r_repeat)

                # NOTE(review): the unfiltered ranks are computed from
                # y_score_*, which is all zeros unless the filtering helper
                # works in place - confirm that these values are meaningful.
                self.rank_true_heads[start_idx:end_idx] = get_rank(y_score_h, h_idx).detach()
                self.filt_rank_true_heads[start_idx:end_idx] = get_rank(filt_scores_h, h_idx).detach()
                self.rank_true_tails[start_idx:end_idx] = get_rank(y_score_t, t_idx).detach()
                self.filt_rank_true_tails[start_idx:end_idx] = get_rank(filt_scores_t, t_idx).detach()

                # Keep each candidate triple only once (tail side first).
                tail_pred, tail_true = self.filter_duplicate_triples(h_idx, r_idx, t_idx,
                                                                     filt_scores_t,
                                                                     y_batch_t,
                                                                     is_tail_prediction=True)
                y_pred.extend(tail_pred)
                y_true.extend(tail_true)

                head_pred, head_true = self.filter_duplicate_triples(h_idx, r_idx, t_idx,
                                                                     filt_scores_h,
                                                                     y_batch_h,
                                                                     is_tail_prediction=False)
                y_pred.extend(head_pred)
                y_true.extend(head_true)

        y_pred = np.array(y_pred, dtype=np.float32)
        y_true = np.array(y_true, dtype=np.int64)

        self.evaluated = True

        # Same post-processing as the non-NBF path.
        if use_cuda:
            self.rank_true_heads = self.rank_true_heads.cpu()
            self.rank_true_tails = self.rank_true_tails.cpu()
            self.filt_rank_true_heads = self.filt_rank_true_heads.cpu()
            self.filt_rank_true_tails = self.filt_rank_true_tails.cpu()

        # Make sure the cache directory exists before writing the results.
        os.makedirs(self.processed_dir, exist_ok=True)
        with open(cache_file, 'wb') as f:
            pickle.dump((y_pred, y_true), f)

        print("Scores and labels computed and cached.")
        return y_pred, y_true

    def get_scores_and_ranks(self):
        """Return, per fact, the score of the true tail and its filtered rank.

        Lets callers compare the global score of each true triple against its
        local (per-query) ranking. Only tail prediction is run: the tail rank
        buffers are filled, the head rank buffers are left untouched and
        ``self.evaluated`` is not modified.

        Returns
        -------
        scores: numpy.ndarray
            Model score of the true tail of every fact.
        filt_ranks: numpy.ndarray
            Filtered rank of the true tail of every fact.
        """
        batch_size = self.batch_size
        # Allocate the score buffer here so the method does not depend on an
        # attribute created elsewhere.
        self.probabilities = torch.zeros(self.n_facts, dtype=torch.float, device=self.device)

        with torch.no_grad():
            for batch_idx, batch in tqdm(enumerate(self.dataloader), total=len(self.dataloader),
                                         unit='batch',
                                         desc='Link prediction evaluation'):
                # Move the indices to the evaluation device, as the other
                # evaluation loops of this class do.
                h_idx, r_idx, t_idx = batch[0].to(self.device), batch[1].to(self.device), batch[2].to(self.device)

                start_idx = batch_idx * batch_size
                end_idx = min((batch_idx + 1) * batch_size, self.n_facts)

                scores = self.model.inference_tail_prediction(h_idx, r_idx, t_idx)
                filt_scores = filter_scores(
                    scores, self.kg_eval.dict_of_tails, h_idx, r_idx, t_idx)

                self.rank_true_tails[start_idx:end_idx] = get_rank(scores, t_idx).detach()
                self.filt_rank_true_tails[start_idx:end_idx] = get_rank(
                    filt_scores, t_idx).detach()
                # Score of the true tail: column t_idx[j] of row j.
                self.probabilities[start_idx:end_idx] = scores.gather(1, t_idx.unsqueeze(1)).squeeze(1).detach()

        # .cpu() is required before .numpy() when the buffers live on a GPU.
        return self.probabilities.cpu().numpy(), self.filt_rank_true_tails.cpu().numpy()

    def evaluate(self, verbose: bool = True) -> None:
        """Compute raw and filtered ranks of the true head and tail of each fact.

        Fills ``rank_true_heads``, ``rank_true_tails``,
        ``filt_rank_true_heads`` and ``filt_rank_true_tails`` and sets
        ``self.evaluated`` to ``True``. When the evaluation device is a CUDA
        device, the rank buffers are moved to the CPU afterwards.

        Parameters
        ----------
        verbose: bool
            When ``True`` (default), print the device flag and display a
            progress bar.
        """
        # torch.device("cuda") == "cuda" is False, so compare device types;
        # torch.device() accepts both strings and torch.device objects.
        use_cuda = torch.device(self.device).type == "cuda"
        if verbose:
            print("USE CUDA", use_cuda)

        batch_size = self.batch_size

        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            for i, batch in tqdm(enumerate(self.dataloader), total=len(self.dataloader),
                                 disable=(not verbose)):
                h_idx, r_idx, t_idx = batch[0].to(self.device), batch[1].to(self.device), batch[2].to(self.device)
                start_idx = i * batch_size
                end_idx = min((i + 1) * batch_size, self.n_facts)

                # Tail prediction
                y_score_t = self.model.inference_tail_prediction(h_idx, r_idx, t_idx)
                if self.is_typed_dataset:
                    filt_scores_t = self.dataset.filter_invalid_triples(y_score_t, t_idx, 'tail', self.split, start_idx, end_idx)
                else:
                    filt_scores_t = filter_scores(y_score_t, self.dataset.dict_of_tails, h_idx, r_idx, t_idx)
                self.rank_true_tails[start_idx:end_idx] = get_rank(y_score_t, t_idx).detach()
                self.filt_rank_true_tails[start_idx:end_idx] = get_rank(filt_scores_t, t_idx).detach()

                # Head prediction
                y_score_h = self.model.inference_head_prediction(h_idx, r_idx, t_idx)
                if self.is_typed_dataset:
                    filt_scores_h = self.dataset.filter_invalid_triples(y_score_h, h_idx, 'head', self.split, start_idx, end_idx)
                else:
                    filt_scores_h = filter_scores(y_score_h, self.dataset.dict_of_heads, t_idx, r_idx, h_idx)
                self.rank_true_heads[start_idx:end_idx] = get_rank(y_score_h, h_idx).detach()
                self.filt_rank_true_heads[start_idx:end_idx] = get_rank(filt_scores_h, h_idx).detach()

        self.evaluated = True

        if use_cuda:
            self.rank_true_heads = self.rank_true_heads.cpu()
            self.rank_true_tails = self.rank_true_tails.cpu()
            self.filt_rank_true_heads = self.filt_rank_true_heads.cpu()
            self.filt_rank_true_tails = self.filt_rank_true_tails.cpu()

    def mean_rank(self):
        """

        Returns
        -------
        mean_rank: float
            Mean rank of the true entity when replacing alternatively head
            and tail in any fact of the dataset.
        filt_mean_rank: float
            Filtered mean rank of the true entity when replacing
            alternatively head and tail in any fact of the dataset.

        """
        if not self.evaluated:
            raise NotYetEvaluatedError('Evaluator not evaluated call '
                                       'LinkPredictionEvaluator.evaluate')
        sum_ = (self.rank_true_heads.float().mean() +
              self.rank_true_tails.float().mean()).item()
        filt_sum = (self.filt_rank_true_heads.float().mean() +
                  self.filt_rank_true_tails.float().mean()).item()
        return sum_ / 2, filt_sum / 2

    def mean_rank_head(self):
        """

        Returns
        -------
        mean_rank: float
            Mean rank of the true entity when replacing alternatively head
            and tail in any fact of the dataset.
        filt_mean_rank: float
            Filtered mean rank of the true entity when replacing
            alternatively head and tail in any fact of the dataset.

        """
        if not self.evaluated:
            raise NotYetEvaluatedError('Evaluator not evaluated call '
                                       'LinkPredictionEvaluator.evaluate')
        return self.rank_true_heads.float().mean().item(), self.filt_rank_true_heads.float().mean().item()

    def mean_rank_tail(self):
        """

        Returns
        -------
        mean_rank: float
            Mean rank of the true entity when replacing alternatively head
            and tail in any fact of the dataset.
        filt_mean_rank: float
            Filtered mean rank of the true entity when replacing
            alternatively head and tail in any fact of the dataset.

        """
        if not self.evaluated:
            raise NotYetEvaluatedError('Evaluator not evaluated call '
                                       'LinkPredictionEvaluator.evaluate')
        return self.rank_true_tails.float().mean().item(), self.filt_rank_true_tails.float().mean().item()

    def hit_at_k_heads(self, k: int = 10) -> tuple[float, float]:
        if not self.evaluated:
            raise NotYetEvaluatedError('Evaluator not evaluated call '
                                       'LinkPredictionEvaluator.evaluate')
        head_hit = (self.rank_true_heads <= k).float().mean()
        filt_head_hit = (self.filt_rank_true_heads <= k).float().mean()

        return head_hit.item(), filt_head_hit.item()

    def hit_at_k_tails(self, k: int = 10) -> tuple[float, float]:
        if not self.evaluated:
            raise NotYetEvaluatedError('Evaluator not evaluated call '
                                       'LinkPredictionEvaluator.evaluate')
        tail_hit = (self.rank_true_tails <= k).float().mean()
        filt_tail_hit = (self.filt_rank_true_tails <= k).float().mean()

        return tail_hit.item(), filt_tail_hit.item()

    def hit_at_k(self, k: int = 10):
        """

        Parameters
        ----------
        k: int
            Hit@k is the number of entities that show up in the top k that
            give facts present in the dataset.

        Returns
        -------
        avg_hitatk: float
            Average of hit@k for head and tail replacement.
        filt_avg_hitatk: float
            Filtered average of hit@k for head and tail replacement.

        """
        if not self.evaluated:
            raise NotYetEvaluatedError('Evaluator not evaluated call '
                                       'LinkPredictionEvaluator.evaluate')

        head_hit, filt_head_hit = self.hit_at_k_heads(k=k)
        tail_hit, filt_tail_hit = self.hit_at_k_tails(k=k)

        return (head_hit + tail_hit) / 2, (filt_head_hit + filt_tail_hit) / 2

    def hit_at_k_tail(self, k: int = 10):
        """

        Parameters
        ----------
        k: int
            Hit@k is the number of entities that show up in the top k that
            give facts present in the dataset.

        Returns
        -------
        avg_hitatk: float
            Average of hit@k for head and tail replacement.
        filt_avg_hitatk: float
            Filtered average of hit@k for head and tail replacement.

        """
        if not self.evaluated:
            raise NotYetEvaluatedError('Evaluator not evaluated call '
                                       'LinkPredictionEvaluator.evaluate')

        tail_hit, filt_tail_hit = self.hit_at_k_tails(k=k)

        return tail_hit, filt_tail_hit

    def mrr(self) -> tuple[float, float]:
        """

        Returns
        -------
        avg_mrr: float
            Average of mean recovery rank for head and tail replacement.
        filt_avg_mrr: float
            Filtered average of mean recovery rank for head and tail
            replacement.

        """
        if not self.evaluated:
            raise NotYetEvaluatedError('Evaluator not evaluated call '
                                       'LinkPredictionEvaluator.evaluate')
        head_mrr = (self.rank_true_heads.float()**(-1)).mean()
        tail_mrr = (self.rank_true_tails.float()**(-1)).mean()
        filt_head_mrr = (self.filt_rank_true_heads.float()**(-1)).mean()
        filt_tail_mrr = (self.filt_rank_true_tails.float()**(-1)).mean()

        return ((head_mrr + tail_mrr).item() / 2,
                (filt_head_mrr + filt_tail_mrr).item() / 2)

    def mrr_head(self) -> tuple[float, float]:
        """


        """
        if not self.evaluated:
            raise NotYetEvaluatedError('Evaluator not evaluated call '
                                       'LinkPredictionEvaluator.evaluate')
        head_mrr = (self.rank_true_heads.float()**(-1)).mean().item()
        filt_head_mrr = (self.filt_rank_true_heads.float()**(-1)).mean().item()

        return head_mrr, filt_head_mrr

    def mrr_tail(self) -> tuple[float, float]:
        """

        Returns
        -------

        """
        if not self.evaluated:
            raise NotYetEvaluatedError('Evaluator not evaluated call '
                                       'LinkPredictionEvaluator.evaluate')
        tail_mrr = (self.rank_true_tails.float()**(-1)).mean().item()
        filt_tail_mrr = (self.filt_rank_true_tails.float()**(-1)).mean().item()

        return tail_mrr, filt_tail_mrr

    def print_results(self, k: Optional[Union[int, list[int]]] = 10, n_digits: int = 3) -> None:
        """

        Parameters
        ----------
        k: int or list
            k (or list of k) such that hit@k will be printed.
        n_digits: int
            Number of digits to be printed for hit@k and MRR.
        """

        if isinstance(k, int):
            print('Hit@{} : {} \t\t Filt. Hit@{} : {}'.format(
                k, round(self.hit_at_k(k=k)[0], n_digits),
                k, round(self.hit_at_k(k=k)[1], n_digits)))
        elif isinstance(k, list):
            for i in k:
                print('Hit@{} : {} \t\t Filt. Hit@{} : {}'.format(
                    i, round(self.hit_at_k(k=i)[0], n_digits),
                    i, round(self.hit_at_k(k=i)[1], n_digits)))
        else:
            raise AssertionError(
                f"Paramter k with value{k} is not a integer or list of integers")

        print('Mean Rank : {} \t Filt. Mean Rank : {}'.format(
            int(self.mean_rank()[0]), int(self.mean_rank()[1])))
        print('MRR : {} \t\t Filt. MRR : {}'.format(
            round(self.mrr()[0], n_digits), round(self.mrr()[1], n_digits)))


    def get_link_prediction_metrics(self):
        """Run ``evaluate`` and return the link-prediction metrics as a dict.

        The dict holds the raw and filtered values of MRR, mean rank, head
        MRR, tail MRR and hit@k for k in 1, 3, 5 and 10.
        """
        self.evaluate()

        raw_mrr, filt_mrr = self.mrr()
        raw_mean_rank, filt_mean_rank = self.mean_rank()
        hits_1 = self.hit_at_k(k=1)
        hits_3 = self.hit_at_k(k=3)
        hits_5 = self.hit_at_k(k=5)
        hits_10 = self.hit_at_k(k=10)
        head_mrr = self.mrr_head()
        tail_mrr = self.mrr_tail()

        # Keys are inserted in the same order as before.
        metrics = {}
        metrics['mrr'] = raw_mrr
        metrics['mean_rank'] = raw_mean_rank
        metrics['mrr_head'] = head_mrr[0]
        metrics['mrr_tail'] = tail_mrr[0]
        metrics['filtered_mean_rank'] = filt_mean_rank
        metrics['filtered_mrr'] = filt_mrr
        metrics['hit_k_1'], metrics['hit_k_1_filtered'] = hits_1
        metrics['hit_k_3'], metrics['hit_k_3_filtered'] = hits_3
        metrics['hit_k_5'], metrics['hit_k_5_filtered'] = hits_5
        metrics['hit_k_10'], metrics['hit_k_10_filtered'] = hits_10
        metrics['mrr_tail_filtered'] = tail_mrr[1]
        metrics['mrr_head_filtered'] = head_mrr[1]
        return metrics
