import os
import math
import numpy as np
import torch
from pathlib import Path
from tqdm import tqdm
from torchkde import KernelDensity
from KDEpy import TreeKDE
from sklearn.model_selection import KFold
from margflow.model_utils import log_prob_gaussian_mixture, sample_gaussian_mixture
from numpy.typing import NDArray


class DatasetIdentifier:
    """Base class deriving a dataset's file-name suffix and on-disk folder from ``args``.

    Attributes set by ``__init__``:
        args: the configuration object, stored unchanged.
        dataset_suffix: ``d{x_dim}_n{n_samples}_e{epsilon:.2f}`` where
            ``n_samples`` is ``args.n_simulations`` when the first
            ``_``-separated token of ``args.dataset`` is ``"sbi"`` and
            ``args.n_samples_dataset`` otherwise.
        dataset_folder: ``Path(args.script_path) / "data/<args.dataset>"``.
    """

    def __init__(self, args):
        """Store ``args``, build the suffix and make sure the dataset folder exists.

        Raises:
            FileExistsError: if the dataset folder path exists but is not a directory.
        """
        super().__init__()
        self.args = args
        n_samples = (
            args.n_simulations if args.dataset.split("_")[0] == "sbi" else args.n_samples_dataset
        )
        self.dataset_suffix = f"d{args.x_dim:d}_n{n_samples:d}_e{args.epsilon:.2f}"
        self.dataset_folder = Path(args.script_path) / f"data/{args.dataset}"
        # exist_ok avoids the check-then-create race when several runs start
        # at once and all try to create the same folder.
        os.makedirs(self.dataset_folder, exist_ok=True)


# Dataset providing discrete samples
class DiscreteSampleDataset(DatasetIdentifier):
    """Dataset exposing drawn samples, with on-disk caching of noisy splits.

    Subclasses implement ``sample``. The dataset suffix inherited from the
    parent initializer is prefixed with ``"ML_"``.
    """

    def __init__(self, args):
        super().__init__(args)
        self.dataset_suffix = "ML_" + self.dataset_suffix

    def sample(self, num_samples: int, data_type: str = "train") -> torch.Tensor:
        """Draw ``num_samples`` samples for the given split; implemented by subclasses."""
        raise NotImplementedError

    def load_samples(
        self, split, n_samples, seed, overwrite=False
    ) -> "NDArray | tuple[NDArray, NDArray]":
        """Return the noisy samples of ``split``, reading or writing an ``.npz`` cache.

        If the cache file exists and ``overwrite`` is false, its ``samples``
        array (and ``context`` array, when stored) is returned. Otherwise
        ``self.sample`` is called, Gaussian noise with standard deviation
        ``args.epsilon`` is added to the samples (not to the context), and the
        result is saved before being returned.

        Args:
            split: split name; used in the file name and passed to ``sample``
                as ``data_type``.
            n_samples: number of samples requested from ``sample``.
            seed: only used to build the cache file name.
            overwrite: when true, ignore an existing cache file and rewrite it.

        Returns:
            ``samples`` alone, or ``(samples, context)`` when a context exists.
        """
        filename = f"{self.dataset_folder}/x_{split}{self.dataset_suffix}_seed{seed}.npz"
        if os.path.isfile(filename) and not overwrite:
            # The context manager closes the underlying zip file; arrays are
            # fully read on access, so they remain valid afterwards.
            with np.load(filename) as archive:
                samples = archive["samples"]
                # Explicit membership test instead of a bare ``except`` so that
                # real read errors are not mistaken for "no context stored".
                context = archive["context"] if "context" in archive.files else None
            return samples if context is None else (samples, context)

        drawn = self.sample(n_samples, data_type=split)
        context = None
        if isinstance(drawn, tuple):
            samples = drawn[0].detach().cpu().numpy()
            context = drawn[1].detach().cpu().numpy()
        else:
            samples = drawn.detach().cpu().numpy()
        noise = np.random.normal(loc=0.0, scale=1.0 * self.args.epsilon, size=samples.shape)
        samples = samples + noise
        if context is None:
            np.savez(filename, samples=samples)
            return samples
        np.savez(filename, samples=samples, context=context)
        return samples, context

    def load_dataset(self, overwrite=False):
        """Seed the NumPy and torch global RNGs and return ``(train, val, test)``.

        For datasets whose name starts with the ``sbi`` token, one set of
        ``args.n_simulations`` pairs is loaded and split by rows; each returned
        element is then a ``(theta, x)`` tuple. For all other datasets each
        split is loaded (or generated) independently via ``load_samples``.
        """
        np.random.seed(self.args.seed)
        torch.manual_seed(self.args.seed)
        if self.args.dataset.split("_")[0] == "sbi":
            val_fraction = 0.1
            # Requires ``load_samples`` to return a (samples, context) pair.
            theta, x = self.load_samples(
                "train",
                n_samples=self.args.n_simulations,
                seed=self.args.seed,
                overwrite=overwrite,
            )
            split = math.ceil(val_fraction * self.args.n_simulations)
            train_dataset = (theta[split:, :], x[split:])
            # NOTE(review): validation and test are the same first ``split``
            # rows — confirm this is intended.
            val_dataset = (theta[:split, :], x[:split])
            test_dataset = (theta[:split, :], x[:split])
        else:
            train_dataset = self.load_samples(
                "train",
                n_samples=self.args.n_samples_dataset,
                seed=self.args.seed,
                overwrite=overwrite,
            )
            val_dataset = self.load_samples(
                "val",
                n_samples=self.args.n_samples_dataset // 10,
                seed=self.args.seed,
                overwrite=overwrite,
            )
            test_dataset = self.load_samples(
                "test",
                n_samples=self.args.n_samples_dataset_test,
                seed=self.args.seed,
                overwrite=overwrite,
            )
        return train_dataset, val_dataset, test_dataset

    def parzen_log_prob(
        self, x: torch.Tensor, sigma: torch.Tensor, parzen_data: torch.Tensor
    ) -> torch.Tensor:
        """Evaluate ``log_prob_gaussian_mixture(x, parzen_data, sigma)``."""
        return log_prob_gaussian_mixture(x, parzen_data, sigma)


def kfold_log_likelihood(X, bandwidth, k=5):
    """Sum held-out log-densities of a Gaussian ``TreeKDE`` over a k-fold split of ``X``.

    The folds are shuffled with a fixed ``random_state`` of 42. For each fold a
    KDE with bandwidth ``bandwidth`` is fit on the training part and evaluated
    on the held-out part; the logs of those densities are summed over all folds.
    """
    splitter = KFold(n_splits=k, shuffle=True, random_state=42)
    total = 0.0

    for fit_idx, held_out_idx in splitter.split(X):
        estimator = TreeKDE(kernel="gaussian", bw=bandwidth)
        estimator.fit(X[fit_idx])
        densities = estimator.evaluate(X[held_out_idx])
        # The small offset keeps the logarithm finite for zero densities.
        fold_score = np.sum(np.log(densities + 1e-10))
        total += fold_score

    return total


class DensityDataset(DatasetIdentifier):
    """Dataset interface built around ``log_prob`` and ``sample``.

    Both methods are placeholders to be provided by subclasses. The dataset
    suffix inherited from the parent initializer is prefixed with ``"KL_"``.
    """

    def __init__(self, args):
        super().__init__(args)
        self.device = args.device
        # For visualization, assume the density is mostly contained in
        # [-domain_bound, domain_bound]^d.
        self.domain_bound = 1.0
        self.dataset_suffix = f"KL_{self.dataset_suffix}"

    def log_prob(self, points: torch.Tensor) -> torch.Tensor:
        """Log-density at ``points``; implemented by subclasses."""
        raise NotImplementedError

    def sample(self, n_samples: int) -> torch.Tensor:
        """Draw ``n_samples`` samples; implemented by subclasses."""
        raise NotImplementedError

    # def kde_cv(
    #     self,
    #     n_samples: int,
    #     min_log_bw: float = -2,
    #     max_log_bw: float = 0.7,
    #     n_bws: int = 20,
    #     k_folds: int = 5,
    # ) -> torch.Tensor:
    #     n_samples_kde = n_samples * 50
    #     filename = f"{self.dataset_folder}/{self.args.dataset}_n{n_samples_kde}_kde_bandwidth.txt"
    #
    #     if os.path.exists(filename):
    #         with open(filename, "r") as f:
    #             self.best_bw = float(f.read())
    #         print(f"Optimal bandwidth ({self.best_bw}) loaded from file: {filename}")
    #     else:
    #         bandwidths = np.logspace(start=min_log_bw, stop=max_log_bw, num=n_bws)
    #         X = self.sample(n_samples=n_samples)
    #         X_np = X.detach().cpu().numpy()
    #
    #         print(
    #             f"Kde: running {k_folds}-fold cv over {n_bws} bandwidths with {n_samples} samples"
    #         )
    #         # Grid search over bandwidths
    #         best_bw = None
    #         best_score = -float("inf")
    #         with torch.no_grad():
    #             for bw in tqdm(bandwidths, desc="Searching bandwidth"):
    #                 score = kfold_log_likelihood(X_np, bw, k_folds)
    #                 if score > best_score:
    #                     best_score = score
    #                     best_bw = bw
    #
    #         self.best_bw = best_bw.item()
    #         print(f"Best bandwidth: {self.best_bw:.4f}")
    #         with open(filename, "w") as f:
    #             f.write(str(self.best_bw))
    #         print(f"Optimal bandwitdh saved to file: {filename}")
    #
    #     X = self.sample(n_samples=n_samples).to(self.device)
    #     self.kde_cv = KernelDensity(bandwidth=self.best_bw, kernel="gaussian").fit(X)
    #     print(f"KDE fit with {n_samples_kde} is now available with dataset.kde_cv(x)")
    #     self.dataset_kde = True

    # def sample_mixtures(self, n_samples, sigma=0.05):
    #     self.X = self.sample(n_samples=n_samples)
    #     self.sigma = sigma


# mainly for datasets used in paper:
# Papamakarios, George, Theo Pavlakou, and Iain Murray. "Masked autoregressive flow for density estimation."
# Advances in neural information processing systems 30 (2017).
class DiscreteSamplesFromFileDataset(DiscreteSampleDataset):
    """Sample dataset backed by a fixed array returned by ``load_data``.

    On construction the array is normalized, shuffled and split into
    train/val/test parts. ``sample`` then serves consecutive rows of the
    requested part, keeping one read cursor per split and wrapping around at
    the end.
    """

    def __init__(self, args):
        super().__init__(args)
        # Read cursors into the train / val / test arrays.
        self.index = 0
        self.index_val = 0
        self.index_test = 0

        self.train, self.val, self.test = self.get_split_data()
        self.N = self.train.shape[0]
        self.N_val = self.val.shape[0]
        self.N_test = self.test.shape[0]
        self.D = self.train.shape[-1]

    def load_data(self) -> NDArray:
        """Return the full data array; implemented by subclasses."""
        raise NotImplementedError

    def get_split_data(self, train=-1, val=0.1, test=0.1):
        """Load, normalize, shuffle and split the data.

        Args:
            train: fraction of rows for training, or ``-1`` to use every row
                not assigned to validation or test.
            val: fraction of rows for validation.
            test: fraction of rows for testing.

        Returns:
            ``(train, val, test)`` as consecutive row slices of the shuffled data.
        """
        assert not train < -1
        data = self.load_data()
        data = self.normalize_np(data)
        data = self.shuffle_np(data)
        val_split = int(val * data.shape[0])
        test_split = int(test * data.shape[0])
        if train == -1:
            train_split = data.shape[0] - val_split - test_split
        else:
            train_split = int(train * data.shape[0])
        val_end = train_split + val_split
        return (
            data[:train_split],
            data[train_split:val_end],
            data[val_end : val_end + test_split],
        )

    def normalize_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` along dim 0 according to ``args.normalize_dataset_type``.

        Type 0 rescales by min/max, type 1 subtracts the mean and divides by
        ``x.std``; any other value returns ``x`` unchanged, as ``normalize_np`` does.
        """
        if self.args.normalize_dataset_type == 0:
            # With a ``dim`` argument, Tensor.min/max return a
            # (values, indices) pair; only the values take part in the arithmetic.
            x_min = x.min(0, keepdim=True).values
            x_max = x.max(0, keepdim=True).values
            return (x - x_min) / (x_max - x_min)
        if self.args.normalize_dataset_type == 1:
            return (x - x.mean(0, keepdim=True)) / (x.std(0, keepdim=True))
        return x

    def normalize_np(self, x: NDArray) -> NDArray:
        """NumPy counterpart of ``normalize_tensor`` (normalizes along axis 0)."""
        if self.args.normalize_dataset_type == 0:
            return (x - x.min(0, keepdims=True)) / (
                x.max(0, keepdims=True) - x.min(0, keepdims=True)
            )
        elif self.args.normalize_dataset_type == 1:
            return (x - x.mean(0, keepdims=True)) / (x.std(0, keepdims=True))
        else:
            return x

    def shuffle_tensor(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with its rows in a random order drawn by ``torch.randperm``."""
        perm = torch.randperm(x.shape[0])
        return x[perm]

    def shuffle_np(self, x: NDArray) -> NDArray:
        """Shuffle ``x`` in place with ``np.random.shuffle`` and return it."""
        np.random.shuffle(x)
        return x

    def sample(self, num_samples: int, data_type: str = "train") -> torch.Tensor:
        """Return the next ``num_samples`` rows of the requested split as a tensor.

        Rows are read starting at the split's cursor and wrap around to the
        start when the end is reached; the cursor is then advanced.

        Args:
            num_samples: number of rows requested. When
                ``args.accept_less_datapoints`` is true it is capped at the
                split size; otherwise it must not exceed the split size.
            data_type: ``"train"``, ``"val"`` or ``"test"``.

        Raises:
            ValueError: if ``data_type`` is not one of the three split names.
        """
        if data_type == "train":
            data, i, n = self.train, self.index, self.N
        elif data_type == "val":
            data, i, n = self.val, self.index_val, self.N_val
        elif data_type == "test":
            data, i, n = self.test, self.index_test, self.N_test
        else:
            raise ValueError(
                f"Unknown data_type {data_type!r}: expected 'train', 'val' or 'test'"
            )
        if not self.args.accept_less_datapoints:
            assert num_samples <= n, "num_samples must be <= available samples"
        else:
            num_samples = min(num_samples, n)

        if i + num_samples < n:
            samples = data[i : i + num_samples]
        else:
            # Wrap around: tail of the array followed by the required head rows.
            samples = data[np.r_[i:n, 0 : i + num_samples - n]]
        new_index = (i + num_samples) % n
        if data_type == "train":
            self.index = new_index
        elif data_type == "val":
            self.index_val = new_index
        elif data_type == "test":
            self.index_test = new_index
        return torch.tensor(samples, device=self.args.device, dtype=self.args.dtype)

    def parzen_log_prob(
        self, x: torch.Tensor, sigma: torch.Tensor, parzen_data: torch.Tensor = None
    ) -> torch.Tensor:
        """Parzen log-density of ``x``; defaults to the training split as centers."""
        if parzen_data is None:
            # ``self.train`` comes from the NumPy pipeline in get_split_data and
            # has no ``.to``; convert it to a tensor matching ``x``.
            parzen_data = torch.as_tensor(self.train, device=x.device, dtype=x.dtype)
        return super().parzen_log_prob(x, sigma, parzen_data)


class HybridDataset(DiscreteSampleDataset, DensityDataset):
    """Dataset combining ``DiscreteSampleDataset`` and ``DensityDataset``."""

    def __init__(self, args):
        # Cooperative super() call: runs the initializers of both bases in MRO order.
        super().__init__(args)


class NologpDensityDataset(DensityDataset):
    """Density dataset whose ``log_prob`` is estimated from its own samples.

    ``args.logp_estimator`` selects the estimator: ``"kde"`` (kernel density
    estimate with a cross-validated bandwidth) or ``"parzen"`` (Gaussian
    mixture centered on drawn samples, with ``sigma`` set to 0.05).
    """

    def __init__(self, args):
        """Build the estimator selected by ``args.logp_estimator`` from 10,000 samples.

        Raises:
            ValueError: if ``args.logp_estimator`` is neither ``"kde"`` nor ``"parzen"``.
        """
        super().__init__(args)

        n_samples = 10_000
        if args.logp_estimator == "kde":
            self.kde_cv(n_samples=n_samples)
            self.logp_estimator = "kde"
        elif args.logp_estimator == "parzen":
            self.sigma = 0.05
            self.sample_mixtures(n_samples=n_samples)
            self.logp_estimator = "parzen"
        else:
            # Only the two estimators handled above are accepted here.
            raise ValueError(
                f"Unknown logp_estimator {args.logp_estimator!r}: must be 'kde' or 'parzen'"
            )

    def kde_cv(
        self,
        n_samples: int,
        min_log_bw: float = -2,
        max_log_bw: float = 0.7,
        n_bws: int = 20,
        k_folds: int = 5,
    ) -> None:
        """Select a KDE bandwidth (cached on disk) and fit the KDE on fresh samples.

        The bandwidth is read from a text file in the dataset folder if it
        exists; otherwise it is chosen by grid search over ``n_bws``
        log-spaced values in ``[10**min_log_bw, 10**max_log_bw]`` using
        ``kfold_log_likelihood`` and then written to that file.

        Side effects: sets ``self.best_bw`` and ``self.dataset_kde``, and
        rebinds ``self.kde_cv`` to the fitted ``KernelDensity`` object, so on
        this instance the name no longer refers to this method afterwards.

        Raises:
            RuntimeError: if no candidate bandwidth yields a score above ``-inf``.
        """
        # NOTE(review): n_samples_kde is used only in the cache file name and
        # the final message; both fits below use n_samples — confirm intent.
        n_samples_kde = n_samples * 50
        filename = f"{self.dataset_folder}/{self.args.dataset}_n{n_samples_kde}_kde_bandwidth.txt"

        if os.path.exists(filename):
            with open(filename, "r") as f:
                self.best_bw = float(f.read())
            print(f"Optimal bandwidth ({self.best_bw}) loaded from file: {filename}")
        else:
            bandwidths = np.logspace(start=min_log_bw, stop=max_log_bw, num=n_bws)
            X = self.sample(n_samples=n_samples)
            X_np = X.detach().cpu().numpy()

            print(
                f"Kde: running {k_folds}-fold cv over {n_bws} bandwidths with {n_samples} samples"
            )
            # Grid search over bandwidths
            best_bw = None
            best_score = -float("inf")
            with torch.no_grad():
                for bw in tqdm(bandwidths, desc="Searching bandwidth"):
                    score = kfold_log_likelihood(X_np, bw, k_folds)
                    if score > best_score:
                        best_score = score
                        best_bw = bw

            if best_bw is None:
                # Happens when every score is NaN/-inf or the grid is empty;
                # fail clearly instead of with an AttributeError on None.
                raise RuntimeError(
                    "KDE bandwidth search failed: no candidate produced a valid score"
                )
            self.best_bw = float(best_bw)
            print(f"Best bandwidth: {self.best_bw:.4f}")
            with open(filename, "w") as f:
                f.write(str(self.best_bw))
            print(f"Optimal bandwitdh saved to file: {filename}")

        X = self.sample(n_samples=n_samples).to(self.device)
        self.kde_cv = KernelDensity(bandwidth=self.best_bw, kernel="gaussian").fit(X)
        print(f"KDE fit with {n_samples_kde} is now available with dataset.kde_cv(x)")
        self.dataset_kde = True

    def log_prob(self, x):
        """Estimated log-density at ``x`` using the configured estimator.

        For ``"kde"`` the scores are clamped from below at -20.

        Raises:
            NotImplementedError: if ``args.logp_estimator`` is not ``"kde"`` or ``"parzen"``.
        """
        if self.args.logp_estimator == "kde":
            logp = torch.clamp(self.kde_cv.score_samples(x).to(self.args.device), min=-20)
        elif self.args.logp_estimator == "parzen":
            logp = log_prob_gaussian_mixture(
                x=x,
                mixtures=self.mixtures,
                sigma=torch.tensor(self.sigma, device=self.args.device),
            )
        else:
            raise NotImplementedError("Log prob not available for this dataset. Use args.")
        return logp

    def sample_mixtures(self, n_samples):
        """Draw ``n_samples`` samples and store them as ``self.mixtures``."""
        self.mixtures = self.sample(n_samples=n_samples)

    def sample_estimator(self, n_samples):
        """Draw ``n_samples`` samples from the fitted estimator (not the true density).

        Raises:
            NotImplementedError: if ``args.logp_estimator`` is not ``"kde"`` or ``"parzen"``.
        """
        if self.args.logp_estimator == "kde":
            samples = self.kde_cv.sample(n_samples=n_samples)
        elif self.args.logp_estimator == "parzen":
            samples = sample_gaussian_mixture(
                n_samples=n_samples,
                means=self.mixtures,
                sigma=torch.tensor(self.sigma, device=self.args.device),
            )
        else:
            raise NotImplementedError(
                "Estimator sampling not available for this dataset: "
                "args.logp_estimator must be 'kde' or 'parzen'."
            )
        return samples


class HybridNologpDataset(DiscreteSampleDataset, NologpDensityDataset):
    """Dataset combining ``DiscreteSampleDataset`` and ``NologpDensityDataset``."""

    def __init__(self, args):
        # Explicit two-argument super(), as used by the other classes in this
        # module; resolves to the next initializer in the MRO.
        super(HybridNologpDataset, self).__init__(args)
