from typing import Callable, Any
import numpy as np
from torch import Tensor
from torch.optim import Optimizer
import random
import torch
import torch.nn as nn
from datasets.data_utils import evaluate_detect
from libs.ood.utils import evaluate_detection
from libs.distance_utils import (calculate_diffusion_distances_alpha,
                                 calculate_graph_laplacian,
                                 calculate_resistance_distances, 
                                 calculate_matern_covariance,
                                 construct_kernel_matrix,
                                 calculate_pairwise_shortest_distances,
                                 calculate_exponential_covariance)


class OODDetectionExp:
    """Training/evaluation driver for one model on an in-distribution dataset
    plus OOD train/test datasets.

    The constructor stores its collaborators, moves the model to
    ``cfg["device"]``, builds the graph Laplacian of ``dataset_ind`` and, for
    the ``"gspde"`` model, precomputes the quantity selected by
    ``cfg["covariance_structure"]``. :meth:`run` then trains for
    ``cfg["epochs"]`` epochs, evaluates after every epoch and returns the
    reporter's statistics.

    Config keys read from ``cfg``: ``device``, ``epochs``, ``seed``, ``name``,
    ``display_step``; for ``name == "gpn"`` also ``GPN_warmup``; for
    ``name == "gspde"`` also ``diffusion_alpha``, ``alpha``, ``nu``, ``kappa``,
    ``covariance_structure``, ``covariance_kernel`` and ``use_cholesky``.
    """

    # (covariance_structure, offending covariance_kernel value, error message).
    # A configuration matching the first two fields is rejected by
    # ``_validate_covariance_config`` with the given message.
    _COVARIANCE_KERNEL_RULES = (
        ("laplacian", True,
         "Laplacian kernel is not used for exponential kernel"),
        ("laplacian_squared", True,
         "Laplacian squared kernel is not used for exponential kernel"),
        ("resistance_distances", False,
         "Resistance distances needs to be used for exponential kernel"),
        ("diffusion_distances", False,
         "Diffusion distances needs to be used for exponential kernel"),
        ("matern", True,
         "Matern kernel is not used for exponential kernel"),
        ("exponential", True,
         "Exponential kernel is not used for exponential kernel"),
    )

    def __init__(
        self,
        cfg: dict,
        cfg_model: dict,
        model: nn.Module,
        criterion: Callable[[Tensor, Tensor], Tensor],
        eval_func: Callable[[Tensor, Tensor], Tensor],
        optimizer: Optimizer,
        warmup_optimizer: Optimizer = None,
        reporter: Any = None,
        dataset_ind: Any = None,
        dataset_ood_tr: Any = None,
        dataset_ood_te: Any = None
    ):
        """Store collaborators and precompute graph quantities.

        Args:
            cfg: Experiment configuration (see the class docstring for keys).
            cfg_model: Model configuration, forwarded untouched to
                ``model.loss_compute`` and to the evaluation function.
            model: Model to train; must provide ``loss_compute`` (and, for
                ``"gspde"``, ``set_covariance_matrix``).
            criterion: Loss callable forwarded to the model and evaluator.
            eval_func: Metric callable forwarded to the evaluator.
            optimizer: Optimizer used for regular training steps.
            warmup_optimizer: Optimizer used for the first
                ``cfg["GPN_warmup"]`` epochs when ``cfg["name"] == "gpn"``.
            reporter: Object providing ``add_result``, ``report`` and
                ``print_statistics``; used unconditionally by :meth:`run`.
            dataset_ind: In-distribution dataset. Despite the ``None``
                default, its ``edge_index`` and ``num_nodes`` are read
                unconditionally here.
            dataset_ood_tr: OOD dataset passed to ``model.loss_compute``.
            dataset_ood_te: OOD dataset passed to the evaluator.
        """
        self.cfg = cfg
        self.cfg_model = cfg_model
        self.model = model
        self.criterion = criterion
        self.eval_func = eval_func
        self.optimizer = optimizer
        self.warmup_optimizer = warmup_optimizer
        self.device = cfg["device"]
        self.num_epochs = cfg["epochs"]
        self.reporter = reporter
        self.model.to(self.device)
        self.dataset_ind = dataset_ind
        self.dataset_ood_tr = dataset_ood_tr
        self.dataset_ood_te = dataset_ood_te
        self.seed = cfg["seed"]

        is_gspde = self.cfg["name"] == "gspde"

        # Hyperparameters of the gspde covariance constructions.
        if is_gspde:
            self.diffusion_alpha = self.cfg["diffusion_alpha"]
            self.alpha = self.cfg["alpha"]
            self.nu = self.cfg["nu"]
            self.kappa = self.cfg["kappa"]

        # Computed for every model type, not only gspde.
        edge_index = self.dataset_ind.edge_index
        num_nodes = self.dataset_ind.num_nodes
        self.laplacian = calculate_graph_laplacian(edge_index, num_nodes)

        # Precompute only the quantity the configured structure needs.
        if is_gspde:
            structure = self.cfg["covariance_structure"]
            if structure == "resistance_distances":
                self.resistance_distances = calculate_resistance_distances(edge_index, num_nodes)
            elif structure == "shortest_distances":
                self.shortest_distances = calculate_pairwise_shortest_distances(edge_index, num_nodes)
            elif structure == "diffusion_distances":
                self.diffusion_distances = calculate_diffusion_distances_alpha(
                    edge_index, num_nodes, alpha=self.diffusion_alpha)
            elif structure == "matern":
                self.matern_distances = calculate_matern_covariance(self.laplacian, self.nu, self.kappa)
            elif structure == "exponential":
                self.exponential_distances = calculate_exponential_covariance(self.laplacian, self.kappa)

        self.fix_seed(self.seed)

    def fix_seed(self, seed: int):
        """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

        Args:
            seed: Seed applied to ``random``, ``numpy.random`` and ``torch``
                (CPU and all CUDA devices).
        """
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        # Cover every visible GPU, not just the current one; this is a no-op
        # when CUDA is unavailable.
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        # Autotuned algorithm selection can differ between runs, which would
        # defeat the deterministic flag set above.
        torch.backends.cudnn.benchmark = False

    def reset_model(self):
        """Re-initialise the model by calling its ``reset_parameters``."""
        self.model.reset_parameters()

    def _validate_covariance_config(self):
        """Raise ``ValueError`` for inconsistent gspde covariance settings.

        Each rule pairs a ``covariance_structure`` with the
        ``covariance_kernel`` value it must not be combined with.
        """
        structure = self.cfg["covariance_structure"]
        kernel_flag = self.cfg["covariance_kernel"]
        for name, offending_flag, message in self._COVARIANCE_KERNEL_RULES:
            # Equality (not identity/truthiness) is kept on purpose so that
            # non-bool config values behave exactly as they always did.
            if structure == name and kernel_flag == offending_flag:
                raise ValueError(message)

    def _diagonal_jitter(self, eps: float = 1e-6):
        """Return ``eps * I`` sized, typed and placed like ``self.laplacian``."""
        size = self.laplacian.shape[0]
        identity = torch.eye(size, dtype=self.laplacian.dtype, device=self.laplacian.device)
        return identity * eps

    def _install_covariance_matrix(self):
        """Build the gspde covariance matrix and hand it to the model.

        Raises:
            ValueError: If ``cfg["covariance_structure"]`` has no branch for
                the configured ``cfg["covariance_kernel"]`` value.
        """
        structure = self.cfg["covariance_structure"]
        kernel_flag = self.cfg["covariance_kernel"]

        if kernel_flag == False:  # noqa: E712 - see _validate_covariance_config
            # Matrices used directly as the covariance. The small diagonal
            # term is added because the covariance must be positive definite.
            if structure == "laplacian":
                covariance = self.laplacian + self._diagonal_jitter()
            elif structure == "laplacian_squared":
                covariance = self.laplacian @ self.laplacian.T + self._diagonal_jitter()
            elif structure == "matern":
                # Validation only allows "matern" without the kernel, so this
                # is the branch that must consume the precomputed matrix.
                covariance = self.matern_distances
            elif structure == "exponential":
                covariance = self.exponential_distances
            else:
                raise ValueError(f"Unknown covariance structure: {self.cfg['covariance_structure']}")
        elif kernel_flag == True:  # noqa: E712
            # Distance matrices are turned into a kernel matrix first.
            # NOTE(review): "shortest_distances" is precomputed in __init__
            # but has no branch here, so it ends in the ValueError below -
            # confirm whether it should be supported.
            if structure == "resistance_distances":
                covariance = construct_kernel_matrix(self.resistance_distances, self.alpha)
            elif structure == "diffusion_distances":
                covariance = construct_kernel_matrix(self.diffusion_distances, self.alpha)
            else:
                raise ValueError(f"Unknown covariance structure: {self.cfg['covariance_structure']}")
        else:
            # Neither True nor False: nothing is installed.
            return

        self.model.set_covariance_matrix(covariance)

    def _train_step(self, epoch: int):
        """Run one optimisation step and return the training loss tensor.

        The warm-up optimizer is used only for ``"gpn"`` during the first
        ``cfg["GPN_warmup"]`` epochs; every other case uses ``self.optimizer``.
        """
        in_warmup = self.cfg["name"] == "gpn" and epoch < self.cfg["GPN_warmup"]
        active_optimizer = self.warmup_optimizer if in_warmup else self.optimizer

        active_optimizer.zero_grad()
        loss = self.model.loss_compute(self.dataset_ind, self.dataset_ood_tr,
                                       self.criterion, self.device, self.cfg_model)
        loss.backward()
        active_optimizer.step()
        return loss

    def run(self):
        """Train for ``self.num_epochs`` epochs, evaluating after each one.

        Returns:
            The first element of ``self.reporter.print_statistics()``.

        Raises:
            ValueError: For ``"gspde"`` when the covariance configuration is
                inconsistent or names an unsupported structure.
        """
        self.model.train()
        self.model.to(self.device)

        if self.cfg["name"] == "gspde":
            self.model.use_cholesky = self.cfg["use_cholesky"]
            self._validate_covariance_config()
            self._install_covariance_matrix()

        for epoch in range(self.num_epochs):
            self.model.train()
            loss = self._train_step(epoch)

            # evaluate
            # TODO: get rid of gspde option and use evaluate_detect only.
            if self.cfg["name"] in ("gspde", "gnsd"):
                acc_res, loss_res, ood_res = evaluate_detection(
                    self.model, self.dataset_ind, self.dataset_ood_te,
                    self.criterion, self.eval_func, self.cfg_model, self.device)
            else:
                # NOTE(review): device is hard-coded to "cpu" here rather than
                # self.device - confirm this is intentional.
                acc_res, loss_res, ood_res = evaluate_detect(
                    self.model, self.dataset_ind, self.dataset_ood_te,
                    self.criterion, self.eval_func, self.cfg_model, "cpu")
            self.reporter.add_result(loss_res + acc_res + ood_res)

            if epoch % self.cfg["display_step"] == 0:
                train_loss = loss.item()
                summary = (
                    f'Epoch: {epoch:03d},'
                    f'Train Loss: {train_loss:.4f}, '
                    f'Valid Loss: {loss_res[1]:.4f}, '
                    f'Test Loss: {loss_res[2]:.4f}, '
                    f'AUROC: {100 * ood_res[0]:.2f}%, '
                    f'AUPR_in: {100 * ood_res[1]:.2f}%, '
                    f'AUPR_out: {100 * ood_res[2]:.2f}%, '
                    f'FPR95: {100 * ood_res[3]:.2f}%, '
                    f'Detection Acc: {100 * ood_res[4]:.2f}%, '
                    f"Train Score: {100 * acc_res[0]:.2f}%, "
                    f"Valid Score: {100 * acc_res[1]:.2f}%, "
                    f"Test Score: {100 * acc_res[2]:.2f}%"
                )
                metrics = {
                    "epoch": epoch,
                    "Train Loss": train_loss,
                    "Valid Loss": loss_res[1],
                    "Test Loss": loss_res[2],
                    "AUROC": 100 * ood_res[0],
                    "AUPR_in": 100 * ood_res[1],
                    "AUPR_out": 100 * ood_res[2],
                    "FPR95": 100 * ood_res[3],
                    "Detection Acc": 100 * ood_res[4],
                    "Train Score": 100 * acc_res[0],
                    "Valid Score": 100 * acc_res[1],
                    "Test Score": 100 * acc_res[2],
                }
                print(summary)
                self.reporter.report(metrics)

        report_dict, _ = self.reporter.print_statistics()
        return report_dict
        
