from typing import Callable, Any
import numpy as np
from torch import Tensor
from torch.optim import Optimizer
import random
import torch
import torch.nn as nn
import time
from datasets.data_utils import evaluate_detect
from libs.ood.utils import evaluate_detection
from libs.distance_utils import (calculate_diffusion_distances_alpha,
                                 calculate_graph_laplacian,
                                 calculate_resistance_distances, 
                                 calculate_matern_covariance,
                                 construct_kernel_matrix,
                                 calculate_pairwise_shortest_distances,
                                 calculate_exponential_covariance)
from libs.rewire_utils import (add_high_corr_edges_optimized,
                               remove_low_corr_edges_optimized,
                               edge_corr_percentile,
                               compute_LI_edge)

class OODDetectionExp:
    """Train a model and track out-of-distribution (OOD) detection metrics.

    The experiment moves ``model`` to ``cfg["device"]``, builds the graph
    Laplacian of ``dataset_ind`` and, when ``cfg["name"] == "gspde"``,
    precomputes the covariance structure named by
    ``cfg["covariance_structure"]``. ``run`` then pushes that structure into
    the model, trains for ``cfg["epochs"]`` epochs and evaluates after every
    epoch.

    Config keys read from ``cfg``: ``device``, ``epochs``, ``seed``, ``name``,
    ``display_step``; optionally ``order`` (default 15) and ``profiling``
    (default False); for ``name == "gpn"`` also ``GPN_warmup``; for
    ``name == "gspde"`` also ``diffusion_alpha``, ``alpha``, ``nu``,
    ``kappa``, ``covariance_structure``, ``covariance_kernel``,
    ``use_cholesky`` and optionally ``polynomial_type`` (default "rational").

    ``cfg_model["optimizer"]`` may carry ``grad_clip`` (enables gradient-norm
    clipping) and ``clip_value`` (max norm, default 1.0).
    """

    # Structures rejected by ``run`` when cfg["covariance_kernel"] == True.
    # The messages are runtime strings and are kept exactly as they were.
    _KERNEL_FREE_ERRORS = {
        "laplacian": "Laplacian kernel is not used for exponential kernel",
        "laplacian_squared": "Laplacian squared kernel is not used for exponential kernel",
        "matern": "Matern kernel is not used for exponential kernel",
        "exponential": "Exponential kernel is not used for exponential kernel",
    }
    # Structures rejected by ``run`` when cfg["covariance_kernel"] == False.
    _KERNEL_ONLY_ERRORS = {
        "resistance_distances": "Resistance distances needs to be used for exponential kernel",
        "diffusion_distances": "Diffusion distances needs to be used for exponential kernel",
    }

    def __init__(
        self,
        cfg: dict,
        cfg_model: dict,
        model: nn.Module,
        criterion: Callable[[Tensor, Tensor], Tensor],
        eval_func: Callable[[Tensor, Tensor], Tensor],
        optimizer: Optimizer,
        warmup_optimizer: Optimizer = None,
        reporter: Any = None,
        dataset_ind: Any = None,
        dataset_ood_tr: Any = None,
        dataset_ood_te: Any = None
    ):
        """Store the experiment components and precompute graph quantities.

        Args:
            cfg: Experiment configuration (see the class docstring for keys).
            cfg_model: Model configuration; forwarded to the model's
                ``loss_compute`` and to the evaluation function.
            model: Model to train. It must provide ``loss_compute`` and, for
                ``cfg["name"] == "gspde"``, ``set_cholesky`` and
                ``set_covariance_matrix``.
            criterion: Loss function forwarded to ``loss_compute`` and to the
                evaluation function.
            eval_func: Metric function forwarded to the evaluation function.
            optimizer: Optimizer used for regular training steps.
            warmup_optimizer: Optimizer used while
                ``epoch < cfg["GPN_warmup"]`` when ``cfg["name"] == "gpn"``.
            reporter: Object exposing ``add_result``, ``report`` and
                ``print_statistics``.
            dataset_ind: In-distribution dataset; its ``edge_index`` and
                ``num_nodes`` are read here, so it must not be ``None``
                despite the default.
            dataset_ood_tr: OOD dataset passed to ``loss_compute``.
            dataset_ood_te: OOD dataset passed to the evaluation function.
        """
        self.cfg = cfg
        self.cfg_model = cfg_model
        self.model = model
        self.criterion = criterion
        self.eval_func = eval_func
        self.optimizer = optimizer
        self.warmup_optimizer = warmup_optimizer
        self.device = cfg["device"]
        self.num_epochs = cfg["epochs"]
        self.reporter = reporter
        self.model.to(self.device)
        self.dataset_ind = dataset_ind
        self.dataset_ood_tr = dataset_ood_tr
        self.dataset_ood_te = dataset_ood_te
        self.seed = cfg["seed"]
        self.order = cfg.get("order", 15)
        self.profiling = cfg.get("profiling", False)  # Add profiling flag

        # Get gradient clipping parameters from optimizer config
        optimizer_cfg = self.cfg_model.get("optimizer", {}) if "optimizer" in self.cfg_model else {}
        self.use_grad_clip = False
        self.clip_value = 1.0
        if "grad_clip" in optimizer_cfg:
            self.use_grad_clip = optimizer_cfg["grad_clip"]
            if "clip_value" in optimizer_cfg:
                self.clip_value = optimizer_cfg["clip_value"]

        # for the SDE matern kernel
        if self.cfg["name"] == "gspde":
            self.diffusion_alpha = self.cfg["diffusion_alpha"]
            self.alpha = self.cfg["alpha"]
            self.nu = self.cfg["nu"]
            self.kappa = self.cfg["kappa"]
            self.polynomial_type = self.cfg.get("polynomial_type", "rational")  # Default to rational if not specified

        self.laplacian = calculate_graph_laplacian(self.dataset_ind.edge_index, self.dataset_ind.num_nodes)

        # for the covariance structure
        if self.cfg["name"] == "gspde":
            self._precompute_covariance(self.cfg["covariance_structure"])

        self.fix_seed(self.seed)

        # for the rewiring case
        # if self.cfg["rewiring"] == True:
        #     self.nu = self.cfg["nu"]
        #     self.kappa = self.cfg["kappa"]
        #     self.polynomial_type = self.cfg.get("polynomial_type", "chebyshev")
        #     self.order = self.cfg.get("order", 30)
        #     original_LI_edge = compute_LI_edge(self.dataset_ind.edge_index, self.dataset_ind.y)
            
        #     if self.profiling:
        #         start_time = time.time()
        #     self.matern_kernel = calculate_matern_covariance(self.laplacian,
        #                                                    self.nu, 
        #                                                    self.kappa, 
        #                                                    polynomial_type=self.polynomial_type, 
        #                                                    order=self.order)
        #     if self.profiling:
        #         preprocess_time = time.time() - start_time
        #         print(f"Preprocessing time: {preprocess_time:.4f} seconds")
            
        #     percentile_low = self.cfg["percentile_low"]
        #     percentile_high = self.cfg["percentile_high"]
        #     threshold_low = edge_corr_percentile(self.dataset_ind.edge_index, self.matern_kernel, percentile_low)
        #     threshold_high = edge_corr_percentile(self.dataset_ind.edge_index, self.matern_kernel, percentile_high)
        #     rewired_edges = remove_low_corr_edges_optimized(self.dataset_ind.edge_index, self.matern_kernel, threshold_low)
        #     rewired_edges = add_high_corr_edges_optimized(rewired_edges, self.matern_kernel, threshold_high)
        #     print(f"Before rewiring, n_edges = {self.dataset_ind.edge_index.shape[1]}, after rewiring, n_edges = {rewired_edges.shape[1]}")
        #     self.dataset_ind.edge_index = rewired_edges
        #     self.dataset_ood_tr.edge_index = rewired_edges
        #     self.dataset_ood_te.edge_index = rewired_edges
        #     new_LI_edge = compute_LI_edge(rewired_edges, self.dataset_ind.y)
        #     print(f"Before rewiring, LI = {original_LI_edge}, after rewiring, LI = {new_LI_edge}")
        #     print(f"After rewiring, multiplier for LI = {new_LI_edge/original_LI_edge}")

    def _timed_preprocess(self, compute: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
        """Call ``compute(*args, **kwargs)``; print its duration when profiling."""
        if not self.profiling:
            return compute(*args, **kwargs)
        # perf_counter is monotonic, so the measurement cannot be skewed by
        # wall-clock adjustments.
        started = time.perf_counter()
        result = compute(*args, **kwargs)
        preprocess_time = time.perf_counter() - started
        print(f"Preprocessing time: {preprocess_time:.4f} seconds")
        return result

    def _precompute_covariance(self, structure: Any) -> None:
        """Precompute and store the quantity backing ``structure``.

        Structures without a precomputation step (e.g. "laplacian") are a
        no-op here; ``run`` derives them from ``self.laplacian``.
        """
        edge_index = self.dataset_ind.edge_index
        num_nodes = self.dataset_ind.num_nodes

        if structure == "resistance_distances":
            self.resistance_distances = self._timed_preprocess(
                calculate_resistance_distances, edge_index, num_nodes)
        elif structure == "shortest_distances":
            # NOTE(review): this is computed but `run` has no branch that
            # consumes it, so `run` raises "Unknown covariance structure".
            self.shortest_distances = self._timed_preprocess(
                calculate_pairwise_shortest_distances, edge_index, num_nodes)
        elif structure == "diffusion_distances":
            self.diffusion_distances = self._timed_preprocess(
                calculate_diffusion_distances_alpha, edge_index, num_nodes,
                alpha=self.diffusion_alpha)
        elif structure == "matern":
            self.matern_kernel = self._timed_preprocess(
                calculate_matern_covariance, self.laplacian, self.nu, self.kappa,
                polynomial_type=self.polynomial_type, order=self.order)
        elif structure == "exponential":
            self.exponential_kernel = self._timed_preprocess(
                calculate_exponential_covariance, self.laplacian, self.kappa)

    def fix_seed(self, seed: int):
        """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        # Seed every CUDA device, not only the current one.
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True

    def reset_model(self):
        """Re-initialise the model through its own ``reset_parameters``."""
        self.model.reset_parameters()

    def _validate_covariance_config(self, structure: Any) -> None:
        """Raise ``ValueError`` for structure / ``covariance_kernel`` mismatches."""
        # `== True` / `== False` (not truthiness) is deliberate: values such
        # as None match neither side, exactly as before.
        if structure in self._KERNEL_FREE_ERRORS and self.cfg["covariance_kernel"] == True:
            raise ValueError(self._KERNEL_FREE_ERRORS[structure])
        if structure in self._KERNEL_ONLY_ERRORS and self.cfg["covariance_kernel"] == False:
            raise ValueError(self._KERNEL_ONLY_ERRORS[structure])

    def _jitter(self) -> Tensor:
        """Return ``1e-6 * I`` sized like the Laplacian, on the Laplacian's device."""
        # Creating the identity on the default device fails as soon as the
        # Laplacian lives elsewhere (e.g. on a GPU).
        device = getattr(self.laplacian, "device", None)
        return torch.eye(self.laplacian.shape[0], device=device) * 1e-6

    def _plain_covariance(self, structure: Any) -> Any:
        """Return the covariance matrix used when ``covariance_kernel`` is false."""
        # covariance structure must be positive definite
        if structure == "laplacian":
            return self.laplacian + self._jitter()
        if structure == "laplacian_squared":
            return self.laplacian @ self.laplacian.T + self._jitter()
        if structure == "diffusion_distances":
            # NOTE(review): currently unreachable, because
            # _validate_covariance_config rejects this combination first.
            return self.diffusion_distances
        if structure == "matern":
            return self.matern_kernel
        if structure == "exponential":
            return self.exponential_kernel
        raise ValueError(f"Unknown covariance structure: {structure}")

    def _kernel_covariance(self, structure: Any) -> Any:
        """Return the kernel matrix used when ``covariance_kernel`` is true."""
        if structure == "resistance_distances":
            return construct_kernel_matrix(self.resistance_distances, self.alpha)
        if structure == "diffusion_distances":
            return construct_kernel_matrix(self.diffusion_distances, self.alpha)
        raise ValueError(f"Unknown covariance structure: {structure}")

    def _configure_gspde(self) -> None:
        """Push the Cholesky flag and covariance matrix into a "gspde" model."""
        # set cholesky factor
        self.model.set_cholesky(self.cfg["use_cholesky"])

        # check covariance structure validity
        structure = self.cfg["covariance_structure"]
        self._validate_covariance_config(structure)

        kernel_flag = self.cfg["covariance_kernel"]
        if kernel_flag == False:
            self.model.set_covariance_matrix(self._plain_covariance(structure))
        elif kernel_flag == True:
            self.model.set_covariance_matrix(self._kernel_covariance(structure))

    def _optimizer_for_epoch(self, epoch: int) -> Optimizer:
        """Pick the warmup optimizer during GPN warmup, else the main one."""
        if self.cfg["name"] == "gpn" and epoch < self.cfg["GPN_warmup"]:
            if self.warmup_optimizer is None:
                raise ValueError(
                    "warmup_optimizer is required while epoch < cfg['GPN_warmup'] "
                    "for the 'gpn' model"
                )
            return self.warmup_optimizer
        return self.optimizer

    def _compute_loss(self) -> Any:
        """Ask the model for its training loss on the ID / OOD-train datasets."""
        return self.model.loss_compute(self.dataset_ind, self.dataset_ood_tr,
                                       self.criterion, self.device, self.cfg_model)

    def _train_step(self, optimizer: Optimizer) -> Any:
        """Run one optimisation step with ``optimizer`` and return the loss."""
        optimizer.zero_grad()
        loss = self._compute_loss()
        loss.backward()
        if self.use_grad_clip:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_value)
        optimizer.step()
        return loss

    def _evaluate(self) -> Any:
        """Evaluate the model and return ``(acc_res, loss_res, ood_res)``."""
        # TODO: get rid of gspde option and use evaluate_detect only.
        if self.cfg["name"] in ("gspde", "gnsd"):
            evaluate = evaluate_detection
        else:
            evaluate = evaluate_detect
        return evaluate(self.model, self.dataset_ind, self.dataset_ood_te,
                        self.criterion, self.eval_func, self.cfg_model, self.device)

    def _report_epoch(self, epoch: int, loss: Any, acc_res: Any, loss_res: Any, ood_res: Any) -> None:
        """Print a one-line summary for ``epoch`` and send metrics to the reporter."""
        train_loss = loss.item()
        summary = (
            f'Epoch: {epoch:03d},'
            f'Train Loss: {train_loss:.4f}, '
            f'Valid Loss: {loss_res[1]:.4f}, '
            f'Test Loss: {loss_res[2]:.4f}, '
            f'AUROC: {100 * ood_res[0]:.2f}%, '
            f'AUPR_in: {100 * ood_res[1]:.2f}%, '
            f'AUPR_out: {100 * ood_res[2]:.2f}%, '
            f'FPR95: {100 * ood_res[3]:.2f}%, '
            f'Detection Acc: {100 * ood_res[4]:.2f}%, '
            f"Train Score: {100 * acc_res[0]:.2f}%, "
            f"Valid Score: {100 * acc_res[1]:.2f}%, "
            f"Test Score: {100 * acc_res[2]:.2f}%"
        )
        metrics = {
            "epoch": epoch,
            "Train Loss": train_loss,
            "Valid Loss": loss_res[1],
            "Test Loss": loss_res[2],
            "AUROC": 100 * ood_res[0],
            "AUPR_in": 100 * ood_res[1],
            "AUPR_out": 100 * ood_res[2],
            "FPR95": 100 * ood_res[3],
            "Detection Acc": 100 * ood_res[4],
            "Train Score": 100 * acc_res[0],
            "Valid Score": 100 * acc_res[1],
            "Test Score": 100 * acc_res[2],
        }
        print(summary)
        self.reporter.report(metrics)

    def run(self):
        """Configure the model, then train and evaluate it.

        Returns:
            The first element of ``reporter.print_statistics()`` after
            training, or ``{"profiling_completed": True}`` when profiling is
            enabled (one loss computation is timed and training is skipped).

        Raises:
            ValueError: If the "gspde" covariance configuration is
                inconsistent or names an unsupported structure, or if the
                "gpn" warmup phase is reached without a ``warmup_optimizer``.
        """
        self.model.train()
        self.model.to(self.device)

        if self.cfg["name"] == "gspde":
            self._configure_gspde()

        # If profiling is enabled, measure one iteration time and exit
        if self.profiling:
            self.model.train()
            started = time.perf_counter()
            self._compute_loss()
            iteration_time = time.perf_counter() - started
            print(f"Training time (one iteration): {iteration_time:.4f} seconds")

            # Return early with a dummy result
            return {"profiling_completed": True}

        # Regular training loop if not profiling
        for epoch in range(self.num_epochs):
            self.model.train()
            loss = self._train_step(self._optimizer_for_epoch(epoch))

            # evaluate
            acc_res, loss_res, ood_res = self._evaluate()
            self.reporter.add_result(loss_res + acc_res + ood_res)

            if epoch % self.cfg["display_step"] == 0:
                self._report_epoch(epoch, loss, acc_res, loss_res, ood_res)

        report_dict, _ = self.reporter.print_statistics()
        return report_dict
        
