from src.engine import Engine
from src.torchkge_evaluation import LinkPredictionEvaluator
from src.utils import count_parameters, calculate_distribution, kl_divergence, plot_distributions
import json
import os
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import timeit
from sklearn.metrics import precision_recall_curve, f1_score
import torch
from scipy.stats import wasserstein_distance
from sklearn.preprocessing import MinMaxScaler

class ExperimentRunner:
    def __init__(self, config):
        # Force use of pretrained model for experiments
        config['use_pretrained_model'] = True
        self.engine = Engine(config)
        self.config = config
        self.model = self.engine.model
        self.dataset = self.engine.dataset
        self.evaluator = LinkPredictionEvaluator(config, self.engine.device, self.model, self.dataset, self.config['test_batch_size'], split='test')
        self.scores = None
        self.labels = None

    def number_of_parameters(self):
        total_params = count_parameters(self.model)
        for _, submodule in self.model.named_children():
            total_params += count_parameters(submodule)

        result = "{:,.2f} MILLION".format(total_params / 1000000)
        self._write_and_print("number_of_parameters", result)
        return total_params

    @torch.no_grad()
    def link_prediction(self):
        metrics = self.evaluator.get_link_prediction_metrics()
        self._write_and_print("link_prediction_metrics", json.dumps(metrics, indent=2))

    @torch.no_grad()
    def precision_at_k(self, scores, labels):
        self._calculate_at_k("precision", scores, labels)

    @torch.no_grad()
    def recall_at_k(self, scores, labels):
        self._calculate_at_k("recall", scores, labels)

    @torch.no_grad()
    def mean_average_precision(self, scores, labels):
        self._calculate_mean_average_precision(scores, labels)

    @torch.no_grad()
    def optimal_threshold(self, scores, labels):
        self._calculate_optimal_threshold(scores, labels)

    @torch.no_grad()
    def _calculate_at_k(self, metric_type, scores, labels):
        total_positives = np.sum(labels)
        print("total_positives", total_positives)
        top_k = [int(total_positives/8), int(total_positives/4), int(total_positives/2), int(total_positives)]

        results = []
        for k in top_k:
            indices = np.argpartition(scores, -k)[-k:]
            topk_labels = labels[indices]
            topk_true_positives = np.sum(topk_labels)
      
            if metric_type == "precision":
                metric_value = topk_true_positives / k
                result = f"{metric_type.upper()} for top {k}: {metric_value:.4f}"
            else:  # recall
                metric_value = topk_true_positives / total_positives
                result = f"{metric_type.capitalize()}@{k}: {metric_value:.4f}"

            results.append(result)

        self._write_and_print(metric_type, "\n".join(results))

    @torch.no_grad()
    def _calculate_mean_average_precision(self, scores, labels):
        sorted_indices = np.argsort(-scores)
        sorted_labels = labels[sorted_indices]

        precisions = []
        relevant_count = 0

        for k, label in enumerate(sorted_labels):
            if label == 1:
                relevant_count += 1
                precision_at_k = relevant_count / (k + 1)
                precisions.append(precision_at_k)

        average_precision = np.mean(precisions) if relevant_count > 0 else 0.0

        result = f"Mean Average Precision: {average_precision:.4f}"
        self._write_and_print("mean_average_precision", result)

        return average_precision
    @torch.no_grad()
    def _calculate_optimal_threshold(self, scores, labels):
        precision, recall, thresholds = precision_recall_curve(labels, scores)
        print("precision", precision)
        print("recall", recall)

        f1_scores = np.zeros_like(precision)
        for i in range(len(precision)):
            if precision[i] + recall[i] > 0:
                f1_scores[i] = 2 * (precision[i] * recall[i]) / (precision[i] + recall[i])
            else:
                f1_scores[i] = 0.0

        best_threshold_idx = np.argmax(f1_scores)
        best_threshold = thresholds[best_threshold_idx]
        best_f1 = f1_scores[best_threshold_idx]

        predicted_labels = (scores >= best_threshold).astype(int)
        true_positives = np.sum((predicted_labels == 1) & (labels == 1))
        false_positives = np.sum((predicted_labels == 1) & (labels == 0))
        false_negatives = np.sum((predicted_labels == 0) & (labels == 1))

        results = [
            f"Optimal Threshold: {best_threshold:.4f}",
            f"Best F1 Score: {best_f1:.4f}",
            f"Precision at optimal threshold: {precision[best_threshold_idx]:.4f}",
            f"Recall at optimal threshold: {recall[best_threshold_idx]:.4f}",
            f"True Positives: {true_positives}",
            f"False Positives: {false_positives}",
            f"False Negatives: {false_negatives}",
            f"Total facts classified as true: {true_positives + false_positives}",
            f"Total true facts: {true_positives + false_negatives}"
        ]

        self._write_and_print("optimal_threshold", "\n".join(results))

    @torch.no_grad()
    def get_scores_and_labels(self):
        if self.scores is None or self.labels is None:
            self.scores, self.labels = self.evaluator.get_scores_and_labels_test_facts()
        return self.scores, self.labels

    def run_experiment(self):
        if self.config['experiment'] == "all":
            self.run_all_experiments()
        else:
            self.run_single_experiment(self.config['experiment'])

    def run_all_experiments(self):
        self.number_of_parameters()
        self.link_prediction()
        scores, labels = self.get_scores_and_labels()
        self.precision_at_k(scores, labels)
        self.recall_at_k(scores, labels)
        self.mean_average_precision(scores, labels)
        self.optimal_threshold(scores, labels)
        self.subject_distribution_shift()

    def run_single_experiment(self, experiment):
        match experiment:
            case "link-prediction":
                self.link_prediction()
            case "number-of-parameters":
                self.number_of_parameters()
            case "precision-at-k":
                scores, labels = self.get_scores_and_labels()
                self.precision_at_k(scores, labels)
            case "recall-at-k":
                scores, labels = self.get_scores_and_labels()
                self.recall_at_k(scores, labels)
            case "mean-average-precision":
                scores, labels = self.get_scores_and_labels()
                self.mean_average_precision(scores, labels)
            case "optimal-threshold":
                scores, labels = self.get_scores_and_labels()
                self.optimal_threshold(scores, labels)
            case "subject-distribution-shift":
                self.subject_distribution_shift()
            case "plot-kde-test-set":
                self.plot_kde(scores, labels)
            case "probabilistic-threshold-adjustment":
                scores, labels = self.get_scores_and_labels()
                self.probabilistic_threshold_adjustment(scores, labels)
            case _:
                raise ValueError("Experiment Unknown")



    @torch.no_grad()
    def probabilistic_threshold_adjustment(self, scores, labels):
        # Calculate initial optimal threshold
        precision, recall, thresholds = precision_recall_curve(labels, scores)
        f1_scores = 2 * (precision * recall) / (precision + recall + 1e-8)
        best_threshold_idx = np.argmax(f1_scores)
        initial_best_threshold = thresholds[best_threshold_idx]
        initial_best_f1 = f1_scores[best_threshold_idx]

        # Sort scores and labels
        sorted_indices = np.argsort(-scores)
        sorted_scores = scores[sorted_indices]
        sorted_labels = labels[sorted_indices]

        # Relabel facts above threshold as true
        adjusted_labels = sorted_labels.copy()
        facts_relabeled = 0
        for i in range(len(sorted_scores)):
            if sorted_scores[i] >= initial_best_threshold:
                if adjusted_labels[i] == 0:
                    adjusted_labels[i] = 1
                    facts_relabeled += 1
            else:
                break  # Stop when we reach scores below the threshold

        # Recalculate metrics using adjusted labels
        new_precision, new_recall, new_thresholds = precision_recall_curve(adjusted_labels, sorted_scores)
        new_f1_scores = 2 * (new_precision * new_recall) / (new_precision + new_recall + 1e-8)
        new_best_threshold_idx = np.argmax(new_f1_scores)
        new_best_threshold = new_thresholds[new_best_threshold_idx]
        new_best_f1 = new_f1_scores[new_best_threshold_idx]

        results = [
            f"Initial Optimal Threshold: {initial_best_threshold:.4f}",
            f"Initial Best F1 Score: {initial_best_f1:.4f}",
            f"Number of facts relabeled as true: {facts_relabeled}",
            f"New Optimal Threshold after adjustment: {new_best_threshold:.4f}",
            f"New Best F1 Score after adjustment: {new_best_f1:.4f}",
            f"New Precision at optimal threshold: {new_precision[new_best_threshold_idx]:.4f}",
            f"New Recall at optimal threshold: {new_recall[new_best_threshold_idx]:.4f}",
        ]

        self._write_and_print("probabilistic_threshold_adjustment", "\n".join(results))

        return adjusted_labels, sorted_scores


    def _write_and_print(self, experiment_type, content):
        results_path = f"./results/{self.config['dataset']['class']}/{self.config['model_type']}_{experiment_type}.txt"
        print(content)
        os.makedirs(os.path.dirname(results_path), exist_ok=True)
        with open(results_path, "w") as f:
            f.write(content)

    @torch.no_grad()
    def subject_distribution_shift(self):
        train_subjects = self.dataset.kg_train.head_idx.numpy()
        test_subjects = self.dataset.kg_test.head_idx.numpy()

        train_dist = calculate_distribution(train_subjects)
        test_dist = calculate_distribution(test_subjects)

        kl_div = kl_divergence(train_dist, test_dist)
        w_distance = wasserstein_distance(list(train_dist.keys()), list(test_dist.keys()),
                                          list(train_dist.values()), list(test_dist.values()))

        plot_path = os.path.join("./", 'subject_distribution_shift.png')
        plot_distributions(train_dist, test_dist, plot_path)

        results = {
            "KL_Divergence": kl_div,
            "Wasserstein_Distance": w_distance
        }

        self._write_and_print("subject_distribution_shift", json.dumps(results, indent=2))
        return results
