import numpy as np
import random
import os
import h5py
from npeet import entropy_estimators as ee
from scipy.stats import entropy as stats_entropy

from collections import defaultdict

from tango import Step
from tango.common import Registrable, Lazy

import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import Callback, LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.profiler import PyTorchProfiler

from .deep_recon import LightningModule
from .dataset import DataModule, BrainDataModule
from .model import ReconstructionTransformer
from .utils import expand_task_and_dataset, extract_dataset_name_and_task_name, disbased_weight
from scipy.linalg import solve
from scipy.io import savemat

import math
from tqdm import tqdm, trange
import json
import coolname

from typing import Optional, Dict


class TemperatureUpdater(Callback):
    """Geometrically anneal ``pl_module.tau`` from ``tau_max`` towards ``tau_min``.

    At the start of epoch ``e`` the temperature is
    ``tau_max * (tau_min / tau_max) ** (e / trainer.max_epochs)``. Because ``e``
    never reaches ``max_epochs``, the last epoch runs slightly above ``tau_min``.
    """

    def __init__(self, tau_min, tau_max) -> None:
        self.tau_min, self.tau_max = tau_min, tau_max

    def setup(self, trainer, pl_module, stage):
        # Begin at the hottest temperature before any epoch has started.
        pl_module.tau = self.tau_max

    def on_train_epoch_start(self, trainer, pl_module):
        progress = trainer.current_epoch / trainer.max_epochs
        log_ratio = math.log(self.tau_min / self.tau_max)
        pl_module.tau = self.tau_max * math.exp(progress * log_ratio)


class LightningTrainer(pl.Trainer, Registrable):
    """Registrable alias of ``pl.Trainer`` so Tango configs can build trainers.

    Steps receive it as ``Lazy[LightningTrainer]`` and call ``construct`` with
    run-specific arguments. The only implementation is ``pl.Trainer`` itself,
    registered under the name "default" right after the class body.
    """
    default_implementation = "default"
LightningTrainer.register("default")(LightningTrainer)


@Step.register("logger")
class Logger(Step):
    """Create a Weights & Biases logger plus a fresh checkpoint directory.

    The returned ``WandbLogger`` carries an extra ``output_path`` attribute
    pointing at ``<results>/ckpts/<run_name>``. If that directory already
    exists, the first free one of ``<run_name>-v2``, ``<run_name>-v3``, ... is
    used instead. The chosen directory is created before returning.
    """

    CACHEABLE = False

    def run(
        self,
        group_name: Optional[str] = None,
        run_name: Optional[str] = None,
    ):
        """Build the logger.

        Args:
            group_name: W&B group for the run (may be ``None``).
            run_name: W&B run name; a random 3-word slug is generated when
                ``None``.

        Returns:
            The configured ``WandbLogger`` with ``output_path`` set.
        """
        if run_name is None:
            # Reseed the global RNG from OS entropy, presumably so the slug
            # differs between runs even if a fixed seed was set earlier.
            random.seed(None)
            run_name = coolname.generate_slug(3)
        print(f"Run name: {run_name}")
        dir_path = os.path.dirname(os.path.realpath(__file__))
        result_path = os.path.join(dir_path, "results")
        logger = WandbLogger(
            project="brain-hyperedge-IB",
            save_dir=result_path,
            group=group_name,
            name=run_name,
        )

        # Pick the first unused directory: <name>, <name>-v2, <name>-v3, ...
        # The suffix is always built from the un-suffixed base so versions do
        # not pile up (e.g. "-v2-v2").
        base_path = os.path.join(logger.save_dir, "ckpts", logger._name).rstrip("/")
        output_path = base_path
        version_id = 2
        while os.path.exists(output_path):
            output_path = f"{base_path}-v{version_id}"
            version_id += 1
        os.makedirs(output_path, exist_ok=False)
        logger.output_path = output_path

        return logger


@Step.register("compute_entropy")
class ComputeEntropy(Step):
    """Per-node histogram entropy of the predict-split signals, grouped by task.

    For each task, all time steps of all samples are pooled per node. The
    pooled values are histogrammed over 99 equal-width bins on [-2.5, 2.5], and
    the Shannon entropy (natural log, via ``scipy.stats.entropy``) of each
    node's histogram is returned.
    """

    CACHEABLE = True
    VERSION = "0004"

    def run(
        self,
        data_module: DataModule
    ) -> Dict[str, np.ndarray]:
        """Return ``{task: array of per-node entropies}``.

        Values outside [-2.5, 2.5] fall outside the histogram and are ignored.
        """
        data_module.setup("predict")
        loaders = data_module.predict_dataloader()

        # task -> list of slices, one per time step of every batch.
        all_data = defaultdict(list)
        for loader in loaders:
            for batch in loader:
                x = batch['x']
                seq_len = x.size(-1)
                if seq_len == 0:
                    continue
                # The task label does not depend on the time step: read it once.
                # NOTE(review): only the first sample's task is consulted, so each
                # batch is assumed to contain a single task — confirm.
                task = batch["meta"][0]["task"]
                all_data[task].extend(x[..., start] for start in range(seq_len))

        # 100 edges -> 99 bins; shared by every task.
        bins = np.linspace(-2.5, 2.5, 100)
        entropy_all_task = {}
        for task, slices in all_data.items():
            # NOTE(review): assumes each slice is (batch, n_nodes), so after the
            # concat + transpose the leading axis indexes nodes — confirm.
            x = torch.cat(slices, dim=0).transpose(0, 1)
            entropy_all_task[task] = np.asarray([
                stats_entropy(np.histogram(node_values, bins=bins, density=True)[0])
                for node_values in x.numpy()
            ])
        return entropy_all_task

@Step.register("train")
class RegressTrain(Step):
    """Train (or reuse) a reconstruction model and return the run name."""

    CACHEABLE = False

    def run(
        self,
        load_name: Optional[str],
        seed: int,
        trainer: Lazy[LightningTrainer],
        model: Lazy[LightningModule],
        data_module: DataModule,
        logger: WandbLogger,
        load_from_init: bool = False,
    ) -> str:
        """Seed everything, build model and trainer, then fit and predict.

        Args:
            load_name: If given, training is skipped and this name is returned
                instead of the current run name.
            seed: Global seed passed to ``pl.seed_everything``.
            trainer: Lazily-constructed trainer; logger, callbacks, gradient
                clipping and determinism are filled in here.
            model: Lazily-constructed module; receives ``mat_save_path``.
            data_module: Data used for ``fit`` and ``predict``.
            logger: Logger produced by the ``logger`` step (must carry
                ``output_path``).
            load_from_init: Load ``<save_dir>/ckpts/init.ckpt`` into the model
                before training.

        Returns:
            ``load_name`` when given, otherwise ``logger._name``.
        """
        pl.seed_everything(seed, workers=True)

        # Presumably where the model writes its hyperedge matrices for this run.
        model = model.construct(mat_save_path=os.path.join(logger.save_dir, "hyperedges", logger._name))
        logger.watch(model)

        if load_from_init:
            model.load_state_dict(torch.load(os.path.join(logger.save_dir, "ckpts", "init.ckpt")))

        # Keep the two best epochs by validation loss plus the last one, in the
        # per-run directory chosen by the logger step.
        checkpoint_callback = ModelCheckpoint(dirpath=logger.output_path,
                                              filename="epoch_{epoch}",
                                              save_top_k=2,
                                              monitor="val/loss",
                                              mode='min',
                                              save_last=True)

        trainer = trainer.construct(
            logger=logger,
            gradient_clip_val=0.5,
            callbacks=[checkpoint_callback, LearningRateMonitor()],
            deterministic=True
        )

        if load_name is not None:
            return load_name

        trainer.fit(
            model,
            data_module,
        )
        trainer.predict(model, data_module)
        return logger._name

@Step.register("knn")
class KNearestNeighbors(Step):
    """Build Pearson k-NN hyperedges per subject and export them to ``.mat``.

    Node ``i`` is grouped with the ``k`` other nodes of highest Pearson
    correlation to it (from the subject's ``pearson`` matrix). ``filt`` then
    optionally prunes that group:

    - ``"lasso"`` keeps neighbours whose Lasso coefficient (alpha 0.01, no
      intercept) for reconstructing node ``i`` is not within 1e-2 of zero.
    - ``"mi"`` keeps neighbours whose estimated mutual information with node
      ``i`` exceeds 0.2.
    - ``"none"`` keeps all ``k``.

    Row ``i`` of the subject's (N, N) matrix receives the ``disbased_weight``
    values at the kept neighbours' columns and at column ``i``.
    """

    CACHEABLE = False

    def run(
        self,
        k: int,
        filt: str,
        data_path: str,
        logger: WandbLogger,
    ):
        """Write one ``.mat`` file per dataset and return the run name.

        Each file is ``<save_dir>/hyperedges/<run>_<dataset>_<task>.mat`` with
        ``rest_1_mats`` of shape (2 * N * N, n_subjects) — the flattened
        matrices stacked on top of their negation — and ``PMAT_CR`` holding the
        subjects' ``y`` values as a column. Subjects without ``y`` are skipped,
        as are datasets with no such subject.

        Raises:
            NotImplementedError: if ``filt`` is not "lasso", "mi" or "none".
        """
        # sklearn is only needed by this step, so import it lazily.
        from sklearn.linear_model import Lasso
        from sklearn.feature_selection import mutual_info_regression

        # Fail fast rather than after opening and reading the first subject.
        if filt not in ("lasso", "mi", "none"):
            raise NotImplementedError(f"Unknown filter: {filt!r}")

        tolerance = 1e-2    # Lasso coefficients within this of 0 count as zero
        mi_threshold = 0.2  # minimum mutual information to keep a neighbour

        mat_save_path = os.path.join(logger.save_dir, "hyperedges", logger._name)
        for dataset in expand_task_and_dataset(data_path):
            dataset_name, task_name = extract_dataset_name_and_task_name(dataset)
            X = []
            Y = []
            # Context manager so the HDF5 handle is released even on error.
            with h5py.File(dataset, "r") as file:
                for key in file.keys():
                    data = file[key]
                    try:
                        y = data["y"][()]
                    except KeyError:
                        # Unlabelled subject: nothing to regress against.
                        continue
                    pearson = data["pearson"][()]
                    sigma = np.median(data["euclidean_dis"][()])
                    x = data["x"][()]
                    num_nodes = x.shape[0]
                    hyperedges = np.zeros((num_nodes, num_nodes))
                    for i in range(num_nodes):
                        current_node = x[i]
                        # Copy so ``pearson`` is not modified, and use -inf so
                        # node i can never be selected as its own neighbour
                        # (0 would still rank above negative correlations).
                        correlations = pearson[i].copy()
                        correlations[i] = -np.inf
                        knn = np.argsort(correlations)[-k:]

                        if filt == "lasso":
                            lasso = Lasso(alpha=0.01, fit_intercept=False)
                            lasso.fit(x[knn].T, current_node)
                            knn = knn[~np.isclose(lasso.coef_, 0, atol=tolerance)]
                        elif filt == "mi":
                            mi = mutual_info_regression(x[knn].T, current_node)
                            knn = knn[mi > mi_threshold]

                        # NOTE(review): weights are computed from the kept
                        # neighbours only but written to those neighbours *and*
                        # node i, so disbased_weight presumably returns
                        # len(knn) + 1 values (or a scalar) — confirm.
                        members = np.append(knn, i)
                        hyperedges[i][members] = disbased_weight(current_node, x[knn], sigma)

                    X.append(hyperedges)
                    Y.append(y)

            if not X:
                print(f"No labelled subjects in {dataset}; skipping.")
                continue

            X = np.array(X)
            X = X.reshape(X.shape[0], -1).T  # (N * N, n_subjects)
            X = np.concatenate([X, -X], axis=0)
            Y = np.array(Y).reshape(-1, 1)

            os.makedirs(os.path.dirname(mat_save_path), exist_ok=True)
            savemat(mat_save_path + f"_{dataset_name}_{task_name}.mat", dict(rest_1_mats=X, PMAT_CR=Y))

        return logger._name


@Step.register("l2-hypergraph")
class L2Hypergraph(Step):
    """Per-subject node weights from an ADMM self-representation model.

    For each labelled subject, the node signals ``x`` are passed to an ADMM
    solver that fits ``X @ P`` to ``X`` up to a soft-thresholded residual
    ``E``, with a zero-diagonal copy ``C`` of ``P``. Each node's weight is the
    row sum of ``C.T @ C``.
    """

    CACHEABLE = False

    def run(
        self,
        data_path: str,
        logger: WandbLogger,
    ):
        """Write one ``.mat`` file per dataset and return the run name.

        Each file is ``<save_dir>/hyperedges/<run>_<dataset>_<task>.mat`` with
        ``rest_1_mats`` of shape (N, n_subjects) and ``PMAT_CR`` holding the
        subjects' ``y`` values as a column. Subjects without ``y`` are skipped,
        as are datasets with no such subject.
        """
        def tau_mu(Z, mu):
            """Element-wise soft-thresholding: shrink ``|Z|`` by ``mu`` towards 0."""
            return np.sign(Z) * np.maximum(np.abs(Z) - mu, 0)

        def admm_optimization(X, Q, lambda1, lambda2, rho, mu, epsilon=0.001, max_iterations=1000):
            """Run ADMM and return ``(P, C, E)``.

            ``N`` is the number of rows of the input ``X`` (it is transposed on
            entry). Iteration stops after ``max_iterations`` or once both
            constraint residuals (``P.T @ 1 = 1`` and ``P = C``) and the change
            in ``P`` and ``E`` since the previous iterate are all <= ``epsilon``
            in the infinity norm.
            """
            X = X.T
            N = X.shape[1]

            P = np.zeros((N, N))
            C = np.zeros((N, N))
            E = np.zeros(X.shape)
            M1 = np.zeros((N, 1))
            M2 = np.zeros((N, N))

            I = np.identity(N)
            ones = np.ones((N, 1))

            # Both linear systems have constant left-hand sides: build them once.
            left_P = X.T @ X + rho * I + rho * ones @ ones.T
            # NOTE(review): Q is (N, N) with a zero diagonal, so np.diag(Q**2) is
            # a zero vector and this term (broadcast along rows) has no effect.
            # A per-entry penalty on Q * C would need an element-wise C-update;
            # confirm the intended formulation before changing it.
            left_C = (rho + lambda1) * I + lambda2 * np.diag(Q**2)

            for _ in range(max_iterations):
                P_prev, E_prev = P, E

                # P-Update
                right_P = X.T @ (X - E) + rho * (ones @ ones.T + C) - ones @ M1.T - M2
                P = solve(left_P, right_P)

                # C-Update (diagonal forced to zero)
                C_star = solve(left_C, rho * P + M2)
                C = C_star - np.diag(np.diag(C_star))

                # E-Update
                E = tau_mu(X @ P - X, mu/rho)

                # Dual (gradient ascent) update
                M1 = M1 + rho * (P.T @ ones - ones)
                M2 = M2 + rho * (P - C)

                # Convergence: feasibility residuals plus iterate changes. The
                # latter must compare against the previous iterate, otherwise
                # they are identically zero.
                conditions = [
                    np.linalg.norm(P.T @ ones - ones, np.inf),
                    np.linalg.norm(P - C, np.inf),
                    np.linalg.norm(P - P_prev, np.inf),
                    np.linalg.norm(E - E_prev, np.inf),
                ]
                if all(condition <= epsilon for condition in conditions):
                    break
            return P, C, E

        def compute_Q_matrix(X):
            """Return Q with q_ij = exp(d_ij) / sum_{k != i} exp(d_ik), q_ii = 0.

            ``d_ij`` is the Euclidean distance between rows ``i`` and ``j`` of X.
            """
            pairwise_distances = np.linalg.norm(X[:, np.newaxis] - X, axis=2)
            # Subtract each row's maximum before exponentiating: the ratio is
            # unchanged, but exp can no longer overflow for large distances.
            shifted = pairwise_distances - pairwise_distances.max(axis=1, keepdims=True)
            exp_distances = np.exp(shifted)

            denominator = np.sum(exp_distances, axis=1, keepdims=True) - np.diagonal(exp_distances)[:, np.newaxis]
            Q = exp_distances / denominator
            np.fill_diagonal(Q, 0)
            return Q

        lambda1 = lambda2 = 0.1
        mu = 1
        rho = 0.01
        mat_save_path = os.path.join(logger.save_dir, "hyperedges", logger._name)
        for dataset in expand_task_and_dataset(data_path):
            dataset_name, task_name = extract_dataset_name_and_task_name(dataset)
            hyperedges = []
            scores = []

            # Context manager so the HDF5 handle is released even on error.
            with h5py.File(dataset, "r") as file:
                for key in tqdm(file.keys()):
                    data = file[key]
                    try:
                        y = data['y'][()]
                    except KeyError:
                        # Unlabelled subject: nothing to regress against.
                        continue
                    X = data['x'][()]

                    Q = compute_Q_matrix(X)
                    P, C, E = admm_optimization(X, Q, lambda1, lambda2, rho, mu)
                    S = C.T @ C
                    hyperedges.append(S.sum(1))
                    scores.append(y)

            if not hyperedges:
                print(f"No labelled subjects in {dataset}; skipping.")
                continue

            X = np.array(hyperedges)
            X = X.reshape(X.shape[0], -1).T  # (N, n_subjects)
            Y = np.array(scores).reshape(-1, 1)

            os.makedirs(os.path.dirname(mat_save_path), exist_ok=True)
            savemat(mat_save_path + f"_{dataset_name}_{task_name}.mat", dict(rest_1_mats=X, PMAT_CR=Y))

        return logger._name


if __name__ == "__main__":
    # This module only defines Tango steps; they take config-provided inputs
    # (data modules, loggers, lazy trainers) and have no standalone entry point.
    raise SystemExit(
        "This module defines Tango steps; run them through a Tango config, "
        "e.g. `tango run <config.jsonnet>`."
    )
    