import logging
import os

import pyfiglet
import torch
from scipy.stats import kendalltau

    
def evaluate(model, sampler):
    """Run ``model`` over every batch of ``sampler`` and score its clique-size predictions.

    For each corpus batch the model is called with the sampler's packed query
    graphs plus the batch's corpus graphs. For each row of the model output
    ``out``, the predicted clique size is ``k + 2``. Here ``k`` is the first
    position where ``out[:, k] - out[:, k + 1]`` exceeds ``model.delta``. If no
    position qualifies, the boolean row is all zeros, ``argmax`` returns 0, and
    the prediction is 2.

    Args:
        model: callable network exposing ``eval()`` and a ``delta`` threshold.
            It is switched to eval mode and left in eval mode on return.
        sampler: batch provider exposing ``create_batches``,
            ``fetch_batched_data_by_id`` and the packed query-graph attributes
            passed to the model below.

    Returns:
        tuple ``(ratio, mse, rankcorr, mae, acc)``:
            ratio    -- tensor, mean of ``round(pred) / target``
            mse      -- float, mean squared error between targets and predictions
            rankcorr -- Kendall's tau between predictions and targets
            mae      -- float, mean absolute error between targets and predictions
            acc      -- tensor, fraction of rows where ``round(pred) == target``

    Raises:
        ValueError: if the sampler produces no batches.
    """
    model.eval()

    pred = []
    targets = []

    n_batches = sampler.create_batches(shuffle=False)
    # Pure inference: disable autograd so no computation graph is kept per batch.
    with torch.no_grad():
        for i in range(n_batches):
            (
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                batch_target,
                corpus_batch_adj,
            ) = sampler.fetch_batched_data_by_id(i)
            out = model(
                sampler.packed_query_graphs,
                sampler.query_graph_node_sizes,
                sampler.query_graph_edge_sizes,
                sampler.query_adj_list,
                corpus_batch_data,
                corpus_batch_data_node_sizes,
                corpus_batch_data_edge_sizes,
                corpus_batch_adj,
            )
            # True where the score drops by more than delta from position k to k + 1.
            drops = (out[:, :-1] - out[:, 1:]) > model.delta
            pred_clique_sizes = drops.long().argmax(-1) + 2
            pred.append(pred_clique_sizes.detach())
            targets.append(batch_target)

    if not pred:
        raise ValueError("evaluate() received a sampler with no batches")

    all_target = torch.cat(targets, dim=0)
    # Predictions come from argmax and are integer-typed. The regression losses
    # and the division below need floating-point operands on one device.
    if not all_target.is_floating_point():
        all_target = all_target.float()
    all_pred = torch.cat(pred, dim=0).to(device=all_target.device, dtype=all_target.dtype)

    ratio = (torch.round(all_pred) / all_target).mean()
    mse = torch.nn.functional.mse_loss(all_target, all_pred, reduction="mean").item()
    mae = torch.nn.functional.l1_loss(all_target, all_pred, reduction="mean").item()
    rankcorr = kendalltau(all_pred.cpu().tolist(), all_target.cpu().tolist())[0]
    acc = (all_target == torch.round(all_pred)).sum() / len(all_target)

    return ratio, mse, rankcorr, mae, acc



class EarlyStoppingModule(object):
    """
    Module to keep track of validation score across epochs.
    Stop training once the number of non-improving epochs exceeds patience.

    Scores are sequences compared element-wise. An epoch improves when the sum
    of ``current - best`` over the elements is at least ``delta``, so higher
    scores are better. Checkpoints are written to
    ``<save_dir>/{initialModels,latestModels,bestValidationModels}/<task_name>``.
    """

    def __init__(
        self, save_dir=".", task_name="TASK", patience=100, delta=0.0001, logger=None
    ):
        self.save_dir = save_dir
        self.task_name = task_name
        self.patience = patience
        self.delta = delta
        # The save/load methods always log. Fall back to a module-level logger
        # so the default ``logger=None`` does not make them crash.
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.create_dirs()
        self.best_scores = None
        self.num_bad_epochs = 0
        self.should_stop_now = False

    def _checkpoint_path(self, sub_dir):
        """Create ``save_dir/sub_dir`` if missing and return ``save_dir/sub_dir/task_name``."""
        directory = os.path.join(self.save_dir, sub_dir)
        # exist_ok avoids the race of a separate isdir() check followed by makedirs().
        os.makedirs(directory, exist_ok=True)
        return os.path.join(directory, self.task_name)

    @staticmethod
    def _write_checkpoint(path, payload):
        """``torch.save`` ``payload`` to ``path``; the file is closed even if saving fails."""
        with open(path, mode="wb") as output:
            torch.save(payload, output)

    def create_dirs(self):
        """Create the checkpoint directories and set the initial/latest/best checkpoint paths."""
        self.initial_model_path = self._checkpoint_path("initialModels")
        self.latest_model_path = self._checkpoint_path("latestModels")
        self.best_model_path = self._checkpoint_path("bestValidationModels")

    def save_initial_model(self, model):
        """Save ``model``'s state dict as the initial (pre-training) checkpoint."""
        self.logger.info(f"saving initial model to {self.initial_model_path}")
        self._write_checkpoint(
            self.initial_model_path,
            {
                "model_state_dict": model.state_dict(),
            },
        )

    def save_latest_model(self, model, epoch, optimizer):
        """Save model, optimizer and early-stopping state so training can be resumed."""
        self._write_checkpoint(
            self.latest_model_path,
            {
                "model_state_dict": model.state_dict(),
                "epoch": epoch,
                "patience": self.patience,
                "best_scores": self.best_scores,
                "num_bad_epochs": self.num_bad_epochs,
                "should_stop_now": self.should_stop_now,
                "optim_state_dict": optimizer.state_dict(),
            },
        )

    def load_latest_model(self, map_location=None):
        """Load the latest checkpoint and restore the early-stopping state from it.

        Args:
            map_location: forwarded to ``torch.load``. For example, pass ``"cpu"``
                to load a checkpoint that was saved on a GPU.

        Returns:
            The checkpoint dict, or ``None`` if no latest checkpoint exists.
        """
        if not os.path.exists(self.latest_model_path):
            return None

        self.logger.info(f"loading latest trained model from {self.latest_model_path}")
        checkpoint = torch.load(self.latest_model_path, map_location=map_location)
        self.patience = checkpoint["patience"]
        self.best_scores = checkpoint["best_scores"]
        self.num_bad_epochs = checkpoint["num_bad_epochs"]
        self.should_stop_now = checkpoint["should_stop_now"]
        return checkpoint

    def save_best_model(self, model, epoch):
        """Save ``model`` as the best checkpoint, recording the epoch it came from."""
        self.logger.info(f"saving best validated model to {self.best_model_path}")
        self._write_checkpoint(
            self.best_model_path,
            {
                "model_state_dict": model.state_dict(),
                "epoch": epoch,
            },
        )

    def load_best_model(self, map_location=None):
        """Return the best checkpoint dict; ``map_location`` is forwarded to ``torch.load``."""
        self.logger.info(f"loading best validated model from {self.best_model_path}")
        return torch.load(self.best_model_path, map_location=map_location)

    def diff(self, curr_scores):
        """Return the sum of element-wise ``current - best`` score differences."""
        return sum(cs - bs for cs, bs in zip(curr_scores, self.best_scores))

    def check(self, curr_scores, model, epoch, optimizer):
        """Record one epoch's validation scores and report whether to stop.

        The first call, or any call whose ``diff`` is at least ``delta``, updates
        the best scores and writes the best checkpoint. Any other call counts as
        a bad epoch, and stopping is triggered once the bad-epoch count exceeds
        ``patience``. The latest checkpoint is written on every call.

        Returns:
            ``should_stop_now``. Once this becomes True it stays True.
        """
        if self.best_scores is None:
            self.best_scores = curr_scores
            self.save_best_model(model, epoch)
        elif self.diff(curr_scores) >= self.delta:
            self.num_bad_epochs = 0
            self.best_scores = curr_scores
            self.save_best_model(model, epoch)
        else:
            self.num_bad_epochs += 1
            if self.num_bad_epochs > self.patience:
                self.should_stop_now = True
        self.save_latest_model(model, epoch, optimizer)
        return self.should_stop_now
    

class DualEarlyStoppingModuleWithIsoStabilizationFollowedByFF(object):
    """
    Two-phase early stopping driven by ``iso_scores`` and ``ff_scores``.

    Scores are sequences. A score improves when the sum of element-wise
    ``current - best`` differences is at least ``delta``, so higher is better.

    * Before stabilization, an epoch is good when ``iso_scores`` improve on the
      global best iso scores. Both the ``_iso`` and ``_ff`` best checkpoints are
      then overwritten.
    * The module becomes "stabilized" once the bad-epoch counter exceeds
      ``stabilization_patience`` (``int(stabilization_factor * patience)``).
      From then on, iso improvements are no longer tracked. An epoch is good
      when ``ff_scores`` improve while ``iso_value_check`` still passes, and
      only the ``_ff`` best checkpoint is overwritten.
    * Training should stop once the bad-epoch counter exceeds ``patience``.
    """

    def __init__(
        self,
        save_dir=".",
        task_name="TASK",
        patience=100,
        delta=0.0001,
        logger=None,
        stabilization_factor=0.5,
        iso_scores_threshold=2,
    ):
        self.save_dir = save_dir
        self.task_name = task_name
        self.patience = patience
        self.delta = delta
        # The save/load methods always log. Fall back to a module-level logger
        # so the default ``logger=None`` does not make them crash.
        self.logger = logger if logger is not None else logging.getLogger(__name__)
        self.create_dirs()
        self.best_iso_scores = None
        self.global_best_iso_scores = None
        self.best_ff_scores = None
        self.num_bad_epochs = 0
        self.should_stop_now = False
        self.stabilization_patience = int(stabilization_factor * self.patience)
        self.iso_scores_threshold = iso_scores_threshold
        self.has_stabilized = False

    def _checkpoint_path(self, sub_dir):
        """Create ``save_dir/sub_dir`` if missing and return ``save_dir/sub_dir/task_name``."""
        directory = os.path.join(self.save_dir, sub_dir)
        # exist_ok avoids the race of a separate isdir() check followed by makedirs().
        os.makedirs(directory, exist_ok=True)
        return os.path.join(directory, self.task_name)

    @staticmethod
    def _write_checkpoint(path, payload):
        """``torch.save`` ``payload`` to ``path``; the file is closed even if saving fails."""
        with open(path, mode="wb") as output:
            torch.save(payload, output)

    def create_dirs(self):
        """Create the checkpoint directories and set the initial/latest/best checkpoint paths."""
        self.initial_model_path = self._checkpoint_path("initialModels")
        self.latest_model_path = self._checkpoint_path("latestModels")
        self.best_model_path = self._checkpoint_path("bestValidationModels")

    def save_initial_model(self, model):
        """Save ``model``'s state dict plus the current score bookkeeping as the initial checkpoint."""
        self.logger.info(f"saving initial model to {self.initial_model_path}")
        self._write_checkpoint(
            self.initial_model_path,
            {
                "model_state_dict": model.state_dict(),
                "best_iso_scores": self.best_iso_scores,
                "best_ff_scores": self.best_ff_scores,
                "best_global_iso_scores": self.global_best_iso_scores,
                "num_bad_epochs": self.num_bad_epochs,
            },
        )

    def save_latest_model(self, model, epoch, optimizer):
        """Save model, optimizer and early-stopping state so training can be resumed."""
        self._write_checkpoint(
            self.latest_model_path,
            {
                "model_state_dict": model.state_dict(),
                "epoch": epoch,
                "patience": self.patience,
                "num_bad_epochs": self.num_bad_epochs,
                "should_stop_now": self.should_stop_now,
                "optim_state_dict": optimizer.state_dict(),
                "best_iso_scores": self.best_iso_scores,
                "best_ff_scores": self.best_ff_scores,
                "best_global_iso_scores": self.global_best_iso_scores,
                # Without this, a resumed run silently falls back to the iso phase.
                "has_stabilized": self.has_stabilized,
            },
        )

    def load_latest_model(self, map_location=None):
        """Load the latest checkpoint and restore the early-stopping state from it.

        Args:
            map_location: forwarded to ``torch.load``. For example, pass ``"cpu"``
                to load a checkpoint that was saved on a GPU.

        Returns:
            The checkpoint dict, or ``None`` if no latest checkpoint exists.
        """
        if not os.path.exists(self.latest_model_path):
            return None

        self.logger.info(f"loading latest trained model from {self.latest_model_path}")
        checkpoint = torch.load(self.latest_model_path, map_location=map_location)
        self.patience = checkpoint["patience"]
        self.best_iso_scores = checkpoint["best_iso_scores"]
        self.best_ff_scores = checkpoint["best_ff_scores"]
        self.global_best_iso_scores = checkpoint["best_global_iso_scores"]
        self.num_bad_epochs = checkpoint["num_bad_epochs"]
        self.should_stop_now = checkpoint["should_stop_now"]
        # Older checkpoints predate this key; keep the current value for them.
        self.has_stabilized = checkpoint.get("has_stabilized", self.has_stabilized)
        return checkpoint

    def save_best_model(self, model, epoch, model_save_type='iso'):
        """Save ``model`` as the best checkpoint of kind ``model_save_type`` ('iso' or 'ff')."""
        self.logger.info(f"saving best validated model [type::{model_save_type}] to {self.best_model_path}_{model_save_type}")
        self._write_checkpoint(
            f'{self.best_model_path}_{model_save_type}',
            {
                "model_state_dict": model.state_dict(),
                "epoch": epoch,
                "best_iso_scores": self.best_iso_scores,
                "best_ff_scores": self.best_ff_scores,
                "best_global_iso_scores": self.global_best_iso_scores,
                "num_bad_epochs": self.num_bad_epochs,
            },
        )

    def load_best_model(self, model_load_type='ff', map_location=None):
        """Return the best checkpoint dict of kind ``model_load_type`` ('iso' or 'ff')."""
        self.logger.info(f"loading best validated model from {self.best_model_path}_{model_load_type}")
        return torch.load(f'{self.best_model_path}_{model_load_type}', map_location=map_location)

    def diff(self, curr_scores, best_scores):
        """Return the sum of element-wise ``current - best`` score differences."""
        return sum(cs - bs for cs, bs in zip(curr_scores, best_scores))

    def iso_value_check(self, iso_scores):
        """Return whether ``iso_scores`` are still acceptable relative to the global best.

        ``iso_scores`` must hold exactly one strictly negative value; this is
        checked with ``assert``. The check passes when that value is at least
        ``iso_scores_threshold * global_best_iso_scores[0]``. If the global
        best is also negative (assumed, not checked here) and the threshold is
        above 1, the current score may be up to ``iso_scores_threshold`` times
        more negative than the global best.
        """
        assert iso_scores[0] < 0 and len(iso_scores) == 1
        return iso_scores[0] >= (self.iso_scores_threshold * self.global_best_iso_scores[0])

    def dual_check(self, iso_scores, ff_scores, model, epoch, optimizer):
        """
        Check the iso_scores and ff_scores to determine if the model should be saved as the best model.

        Parameters:
        - iso_scores: The iso_scores of the current epoch.
        - ff_scores: The ff_scores of the current epoch.
        - model: The model being trained.
        - epoch: The current epoch number.
        - optimizer: The optimizer used for training.

        Returns:
        - should_stop_now: A boolean indicating whether training should stop.
          The latest checkpoint is written on every call.
        """
        if self.best_iso_scores is None:
            # First call: the current scores become the reference and both best checkpoints are written.
            self.best_iso_scores = iso_scores
            self.best_ff_scores = ff_scores
            self.global_best_iso_scores = iso_scores
            self.save_best_model(model, epoch, 'iso')
            self.save_best_model(model, epoch, 'ff')

        elif not self.has_stabilized and self.diff(iso_scores, self.global_best_iso_scores) >= self.delta:
            # Iso phase: iso_scores beat the global best by at least delta.
            self.num_bad_epochs = 0
            self.best_iso_scores = iso_scores
            self.best_ff_scores = ff_scores
            self.global_best_iso_scores = iso_scores
            self.save_best_model(model, epoch, 'iso')
            self.save_best_model(model, epoch, 'ff')

        elif (
            self.num_bad_epochs < self.patience
            and self.has_stabilized
            and self.iso_value_check(iso_scores)
            and self.diff(ff_scores, self.best_ff_scores) >= self.delta
        ):
            # FF phase: ff_scores improved while iso_scores stay within the allowed
            # threshold. Only the 'ff' checkpoint is refreshed; the global iso best
            # is left unchanged.
            self.num_bad_epochs = 0
            self.best_iso_scores = iso_scores
            self.best_ff_scores = ff_scores
            self.save_best_model(model, epoch, 'ff')

        else:
            self.num_bad_epochs += 1
            if not self.has_stabilized and self.num_bad_epochs > self.stabilization_patience:
                # Switch permanently from the iso phase to the ff phase.
                self.has_stabilized = True
                banner = pyfiglet.figlet_format("Stabilized", font="slant", justify="center")
                self.logger.info(banner)
            if self.num_bad_epochs > self.patience:
                self.should_stop_now = True
        self.save_latest_model(model, epoch, optimizer)
        return self.should_stop_now
