import os
import torch

from sklearn.metrics import average_precision_score
import time
import numpy as np
import tqdm
from utils.model_utils import gen_map_from_scores, gen_map_from_embeds, nanl_fast_inference

class EarlyStoppingModule(object):
    """
    Module to keep track of validation score across epochs.

    Stop training once the number of consecutive epochs without improvement
    exceeds ``patience``. Checkpoints are written under ``save_dir`` into the
    ``initialModels``, ``latestModels`` and ``bestValidationModels``
    sub-directories, each as a file named ``task_name``.

    Args:
        save_dir: Root directory for the checkpoint sub-directories.
        task_name: File name used for every checkpoint.
        patience: Number of non-improving epochs tolerated; once
            ``num_bad_epochs`` exceeds it, ``should_stop_now`` is set.
        delta: Minimum summed improvement over ``best_scores`` (see
            :meth:`diff`) that counts as a new best.
        logger: Logger for progress messages; when ``None`` a standard
            ``logging`` logger for this module is used.
    """

    def __init__(
        self, save_dir=".", task_name="TASK", patience=100, delta=0.005, logger=None
    ):
        if logger is None:
            # The save/load helpers log unconditionally, so without a fallback
            # the default ``logger=None`` crashes with AttributeError.
            import logging

            logger = logging.getLogger(__name__)
        self.save_dir = save_dir
        self.task_name = task_name
        self.patience = patience
        self.delta = delta
        self.logger = logger
        self.create_dirs()
        self.best_scores = None
        self.num_bad_epochs = 0
        self.should_stop_now = False

    def create_dirs(self):
        """Create the checkpoint sub-directories and record their file paths."""
        self.initial_model_path = self._checkpoint_path("initialModels")
        self.latest_model_path = self._checkpoint_path("latestModels")
        self.best_model_path = self._checkpoint_path("bestValidationModels")

    def _checkpoint_path(self, sub_dir):
        """Ensure ``save_dir/sub_dir`` exists and return the ``task_name`` file path inside it."""
        directory = os.path.join(self.save_dir, sub_dir)
        # exist_ok avoids the isdir()/makedirs() check-then-create race.
        os.makedirs(directory, exist_ok=True)
        return os.path.join(directory, self.task_name)

    @staticmethod
    def _save(state_dict, path):
        """Serialize ``state_dict`` to ``path`` with ``torch.save``, always closing the file."""
        with open(path, mode="wb") as output:
            torch.save(state_dict, output)

    def save_initial_state(self, state_dict):
        """Save the pre-training checkpoint; refuses to overwrite an existing one."""
        assert not os.path.exists(self.initial_model_path), f"{self.initial_model_path} already exists. Do you mean to resume"
        self.logger.info(f"saving initial model to {self.initial_model_path}")
        self._save(state_dict, self.initial_model_path)

    def save_latest_model(self, state_dict):
        """Overwrite the latest checkpoint (called on every :meth:`check`)."""
        self._save(state_dict, self.latest_model_path)

    def load_latest_model(self, map_location=None):
        """
        Load the latest checkpoint and restore the early-stopping state from it.

        Args:
            map_location: Forwarded to ``torch.load``; the default ``None``
                keeps tensors where they were saved.

        Returns:
            The loaded checkpoint dict.
        """
        assert os.path.exists(self.latest_model_path), f"{self.latest_model_path} does not exist. Do you mean to start fresh"

        self.logger.info(f"loading latest trained model from {self.latest_model_path}")
        checkpoint = torch.load(self.latest_model_path, map_location=map_location)
        self.patience = checkpoint["patience"]
        self.best_scores = checkpoint["best_scores"]
        self.num_bad_epochs = checkpoint["num_bad_epochs"]
        self.should_stop_now = checkpoint["should_stop_now"]
        return checkpoint

    def save_best_model(self, state_dict):
        """Overwrite the best-validation checkpoint."""
        self.logger.info(f"saving best validated model to {self.best_model_path}")
        self._save(state_dict, self.best_model_path)

    def load_best_model(self, device="cuda"):
        """Load and return the best-validation checkpoint, mapped onto ``device``."""
        self.logger.info(f"loading best validated model from {self.best_model_path}")
        checkpoint = torch.load(self.best_model_path, map_location=device)
        return checkpoint

    def diff(self, curr_scores):
        """
        Return the sum of element-wise differences ``curr - best``.

        :meth:`check` treats a result of at least ``delta`` as an improvement.
        Note that ``zip`` silently truncates if the score lists differ in length.
        """
        return sum(cs - bs for cs, bs in zip(curr_scores, self.best_scores))

    def _record_best(self, curr_scores, state_dict):
        """Adopt ``curr_scores`` as the new best, annotate ``state_dict`` and save it as best."""
        self.best_scores = curr_scores
        # Copy each current metric into its "best_*" slot, but only for slots
        # the caller put in the state dict.
        for best_key, curr_key in (
            ("best_val_ap", "val_ap_score"),
            ("best_val_map", "val_map_score"),
            ("best_neg_val_loss", "neg_val_loss"),
        ):
            if best_key in state_dict:
                state_dict[best_key] = state_dict[curr_key]
        state_dict["best_scores"] = self.best_scores
        state_dict["num_bad_epochs"] = self.num_bad_epochs
        self.save_best_model(state_dict)

    def check(self, curr_scores, state_dict):
        """
        Update the early-stopping state with this epoch's validation scores.

        The first call always records a best. Later calls record a new best
        (and reset ``num_bad_epochs``) when :meth:`diff` is at least ``delta``.
        Otherwise they count a bad epoch and set ``should_stop_now`` once
        ``num_bad_epochs`` exceeds ``patience``.

        Args:
            curr_scores: Sequence of validation scores, where higher is better.
            state_dict: Checkpoint dict; it is mutated in place and saved.

        Returns:
            ``state_dict``, annotated with the current early-stopping state.
        """
        if self.best_scores is None:
            self._record_best(curr_scores, state_dict)
        elif self.diff(curr_scores) >= self.delta:
            self.num_bad_epochs = 0
            self._record_best(curr_scores, state_dict)
        else:
            self.num_bad_epochs += 1
            if self.num_bad_epochs > self.patience:
                self.should_stop_now = True
        # load_latest_model reads these four keys back; write the post-update
        # values so that resuming restores the correct tracker state in every branch.
        state_dict["best_scores"] = self.best_scores
        state_dict["num_bad_epochs"] = self.num_bad_epochs
        state_dict["should_stop_now"] = self.should_stop_now
        state_dict["patience"] = self.patience
        self.save_latest_model(state_dict)
        return state_dict
    
def pairwise_ranking_loss(pred_pos, pred_neg, margin):
    """
    Hinge ranking loss averaged over every (positive, negative) pair.

    For positive row ``i`` and negative row ``j`` the pair loss is
    ``max(0, margin + pred_neg[j] - pred_pos[i])``, so a positive is
    penalised unless it scores at least ``margin`` above the negative.

    Args:
        pred_pos: 2-D tensor of positive scores, shape (num_pos, dim).
        pred_neg: 2-D tensor of negative scores, shape (num_neg, dim).
        margin: Required score gap between positives and negatives.

    Returns:
        Tensor of shape (dim,): the pair losses averaged over both pair axes.
    """
    # Tuple-unpacking the shapes rejects non-2-D inputs with ValueError.
    _num_pos, _dim = pred_pos.shape
    _num_neg, _ = pred_neg.shape

    # Broadcast to (num_pos, num_neg, dim): positives on axis 0, negatives on axis 1.
    pair_gaps = margin + pred_neg[None, :, :] - pred_pos[:, None, :]
    return torch.relu(pair_gaps).mean(dim=(0, 1))


# def evaluate_model(model, dataset, return_running_time=False):
#     model.eval()

#     # Compute global statistics
#     pos_pairs, neg_pairs = dataset.pos_pairs, dataset.neg_pairs
#     average_precision_score = compute_average_precision(model, pos_pairs, neg_pairs, dataset)

#     # Compute per-query statistics
#     num_query_graphs = len(dataset.query_graphs)
#     per_query_avg_prec = []
    
#     total_running_time = 0
#     total_batches = 0

#     for query_idx in range(num_query_graphs):
#         print(query_idx)
#         pos_pairs_for_query = list(filter(lambda pair: pair[0] == query_idx, pos_pairs))
#         neg_pairs_for_query = list(filter(lambda pair: pair[0] == query_idx, neg_pairs))
#         print("filtering done")
#         if len(pos_pairs_for_query) > 0 and len(neg_pairs_for_query) > 0:
#             average_precision, running_time, batches = compute_average_precision(
#                 model, pos_pairs_for_query, neg_pairs_for_query, dataset, return_running_time=True
#             )
#             total_running_time += running_time
#             total_batches += batches

#             per_query_avg_prec.append(average_precision)
#     mean_average_precision = np.mean(per_query_avg_prec)
    
#     if return_running_time:
#         return average_precision_score, mean_average_precision, total_running_time / total_batches
#     else:
#         return average_precision_score, mean_average_precision


# def evaluate_model_fast(model, dataset):
#     model.eval()

#     num_query_graphs = len(dataset.query_graphs)
#     # Compute per-query statistics
#     per_query_avg_prec = []

#     per_query_preds = []
#     per_query_avg_prec  = []

#     for query_idx in tqdm.tqdm(range(num_query_graphs)):
#     # for query_idx in range(num_query_graphs):
#         all_pairs = dataset.all_pairs[query_idx*dataset.num_corpus_graphs:(query_idx+1)*dataset.num_corpus_graphs]
#         all_labels = dataset.all_gt[query_idx*dataset.num_corpus_graphs:(query_idx+1)*dataset.num_corpus_graphs]

#         predictions = []
#         num_batches = dataset.create_eval_batches(all_pairs)
#         for batch_idx in range(num_batches):
#             batch_graphs, batch_graph_node_sizes, batch_graph_edge_sizes, _ = dataset.fetch_batch_by_id(batch_idx)       
#             model_output = model(batch_graphs, batch_graph_node_sizes, batch_graph_edge_sizes)
#             predictions.append(model_output.data)

#         all_predictions = torch.cat(predictions, dim=0).cpu()
#         average_precision = average_precision_score(all_labels, all_predictions)

#         per_query_preds.append(all_predictions)
#         per_query_avg_prec.append(average_precision)

#     mean_average_precision = np.mean(per_query_avg_prec)
#     average_precision = average_precision_score(dataset.all_gt, torch.cat(per_query_preds, dim=0))

#     return average_precision, mean_average_precision


def evaluate_model_faster(conf, model, dataset):
    """
    Score every query against every corpus graph and return ``(ap, map)``.

    NANL / NANLSIG models are scored through ``nanl_fast_inference``. Every
    other model is run once over all query graphs and once over all corpus
    graphs, and the two outputs are handed to ``gen_map_from_embeds``.

    The model is put in eval mode (not restored afterwards). ``model.fetch_embed``
    is set to ``True`` for the duration of inference and is always reset to
    ``False``, even if inference raises.

    Args:
        conf: Config object; reads ``conf.model.name`` and, for non-NANL
            models, ``conf.model.scoring_function``.
        model: Model under evaluation.
        dataset: Evaluation dataset; ``dataset.all_gt`` supplies the labels.

    Returns:
        Tuple ``(ap_score, map_score)``: global average precision and the
        value returned by the map helper.
    """
    model.eval()
    with torch.no_grad():
        model.fetch_embed = True
        try:
            if conf.model.name in ("NANL", "NANLSIG"):
                # NANLSIG differs only by requesting the sigmoid variant.
                extra_kwargs = {"use_sig": True} if conf.model.name == "NANLSIG" else {}
                all_scores = nanl_fast_inference(conf, model, dataset, **extra_kwargs)
                map_score = gen_map_from_scores(all_scores, dataset.all_gt)
                ap_score = average_precision_score(dataset.all_gt, all_scores.reshape(-1).detach().cpu())
            else:
                # NOTE(review): presumably fetch_embed makes the model return
                # per-graph embeddings rather than pair scores — confirm.
                q_embeds = model(dataset._pack_batch_1d(dataset.query_graphs), dataset.query_graph_node_sizes, dataset.query_graph_edge_sizes)
                c_embeds = model(dataset._pack_batch_1d(dataset.corpus_graphs), dataset.corpus_graph_node_sizes, dataset.corpus_graph_edge_sizes)
                if conf.model.scoring_function == "sighinge":
                    map_score, all_preds = gen_map_from_embeds(conf, q_embeds, c_embeds, dataset.all_gt, model.sigmoid_a, model.sigmoid_b)
                else:
                    map_score, all_preds = gen_map_from_embeds(conf, q_embeds, c_embeds, dataset.all_gt)
                ap_score = average_precision_score(dataset.all_gt, all_preds)
        finally:
            # Restore normal (pairwise-score) mode even if inference failed.
            model.fetch_embed = False

    return ap_score, map_score
         




def compute_average_precision(model, pos_pairs, neg_pairs, dataset, return_pred_and_labels=False, return_running_time=False):
    """
    Run ``model`` over positive and negative pairs and compute average precision.

    Positives are labelled 1 and negatives 0. Labels are built in
    ``pos_pairs + neg_pairs`` order, so this assumes that
    ``dataset.create_eval_batches`` / ``fetch_batch_by_id`` yield predictions
    in that same order (TODO confirm against the dataset implementation).

    Args:
        model: Callable taking ``(graphs, node_sizes, edge_sizes)``.
        pos_pairs: List of positive pairs.
        neg_pairs: List of negative pairs.
        dataset: Provides ``create_eval_batches`` and ``fetch_batch_by_id``.
        return_pred_and_labels: Also return the label and prediction tensors.
        return_running_time: Also return the total forward time in seconds
            and the number of batches. Mutually exclusive with
            ``return_pred_and_labels``.

    Returns:
        ``ap``, ``(ap, labels, predictions)`` or ``(ap, total_time, num_batches)``
        depending on the flags.
    """
    assert not (return_running_time and return_pred_and_labels)
    all_pairs = pos_pairs + neg_pairs
    num_pos_pairs, num_neg_pairs = len(pos_pairs), len(neg_pairs)

    total_running_time = 0.0
    total_batches = 0
    predictions = []
    num_batches = dataset.create_eval_batches(all_pairs)
    # Only ``.data`` of each output is kept, so building an autograd graph
    # would only waste memory and time.
    with torch.no_grad():
        for batch_idx in range(num_batches):
            batch_graphs, batch_graph_node_sizes, batch_graph_edge_sizes, _ = dataset.fetch_batch_by_id(batch_idx)

            # perf_counter is monotonic and high-resolution, unlike time.time.
            # NOTE(review): on CUDA, kernels run asynchronously, so this may
            # measure launch time rather than compute time unless synchronized.
            start_time = time.perf_counter()
            model_output = model(batch_graphs, batch_graph_node_sizes, batch_graph_edge_sizes)
            total_running_time += time.perf_counter() - start_time
            total_batches += 1

            predictions.append(model_output.data)
    all_predictions = torch.cat(predictions, dim=0)
    all_labels = torch.cat([torch.ones(num_pos_pairs), torch.zeros(num_neg_pairs)])

    average_precision = average_precision_score(all_labels, all_predictions.cpu())
    if return_pred_and_labels:
        return average_precision, all_labels, all_predictions
    if return_running_time:
        return average_precision, total_running_time, total_batches
    return average_precision