import glob
import logging
import numpy as np
import os
import random
import torch
import torch.utils.data
import pandas as pd
import pyvista as pv
import progressbar as pb
from multiprocessing import *
import blosc
from collections import defaultdict
from tqdm import trange
from scipy.interpolate import interp1d
import numpy as np
from scipy.interpolate import PchipInterpolator
from sklearn.preprocessing import StandardScaler, PowerTransformer, QuantileTransformer, RobustScaler

class StandardScalerNormalizer:
    """Standardize targets with scikit-learn's ``StandardScaler``.

    Inputs may be torch tensors (any device, with or without grad) or NumPy
    arrays. ``transform`` returns float32 torch tensors (on CPU);
    ``inverse_transform`` and ``postprocess_`` return NumPy arrays.
    """

    def __init__(self):
        self.scaler = StandardScaler()

    @staticmethod
    def _to_numpy(x):
        """Return torch tensors as detached CPU NumPy arrays; pass anything else through."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return x

    def fit(self, y_ref: torch.Tensor) -> None:
        """Fit the scaler on reference targets ``y_ref`` (tensor or array)."""
        # Convert first: sklearn cannot consume CUDA tensors or tensors that
        # require grad, which ``transform`` already accepts.
        self.scaler = self.scaler.fit(self._to_numpy(y_ref))

    def transform(self, y: torch.Tensor) -> torch.Tensor:
        """Standardize ``y`` and return it as a float32 CPU tensor."""
        y_transformed = self.scaler.transform(self._to_numpy(y))
        return torch.tensor(y_transformed, dtype=torch.float32)

    def inverse_transform(self, z: torch.Tensor, WHETHER_STD: bool = False) -> np.ndarray:
        """Map standardized values back to the original scale.

        Args:
            z: standardized values (tensor or array).
            WHETHER_STD: if True, ``z`` is a standard deviation, so it is only
                multiplied by ``scaler.scale_`` and not shifted by the mean.

        Returns:
            NumPy array in the original units.
        """
        z = self._to_numpy(z)
        if WHETHER_STD:
            return self.scaler.scale_ * z
        return self.scaler.inverse_transform(z)

    def postprocess_(self, mu: torch.Tensor, sigma: torch.Tensor) -> dict:
        """Convert a predicted mean/std pair to original units with a +/-2 std band.

        Returns:
            dict with ``mean_y``, ``std_y``, ``lower_ci`` (mean - 2 std) and
            ``upper_ci`` (mean + 2 std), all NumPy arrays.
        """
        mean_y = self.inverse_transform(mu)
        std_y = self.inverse_transform(sigma, WHETHER_STD=True)
        return {
            'mean_y': mean_y,
            'std_y': std_y,
            'lower_ci': mean_y - std_y * 2,
            'upper_ci': mean_y + std_y * 2,
        }



class PredefinedStandardScalerNormalizer:
    """Standardize targets with fixed, per-variable statistics.

    Instead of fitting on data, ``fit(name)`` loads the mean/std registered
    under ``name`` into a ``StandardScaler`` so its ``transform`` and
    ``inverse_transform`` can be reused. ``transform`` returns float32 torch
    tensors; ``inverse_transform`` and ``postprocess_`` return NumPy arrays.
    """

    def __init__(self, mean=None, std=None):
        """Create the normalizer.

        Args:
            mean: optional mapping ``name -> mean`` that replaces the built-in table.
            std: optional mapping ``name -> std`` that replaces the built-in table.
        """
        self.scaler = StandardScaler()
        # Built-in statistics; presumably precomputed on the training data —
        # TODO confirm provenance.
        self.mean = dict(mean) if mean is not None else {
            "C1": 5.0008, "C2": 6.5522, "X": 5.0063, "Y": 3.9455,
        }
        self.std = dict(std) if std is not None else {
            "C1": 2.8857, "C2": 2.4229, "X": 2.8843, "Y": 1.5249,
        }

    @staticmethod
    def _to_numpy(x):
        """Return torch tensors as detached CPU NumPy arrays; pass anything else through."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return x

    def fit(self, name: str) -> None:
        """Load the statistics stored under ``name`` into the scaler.

        Raises:
            KeyError: if ``name`` is missing from ``self.mean`` or ``self.std``.
        """
        # Look up both values before touching the scaler so a bad name cannot
        # leave it half-configured.
        try:
            mean, std = self.mean[name], self.std[name]
        except KeyError:
            raise KeyError(
                f"no predefined statistics for {name!r}; known names: {list(self.mean)}"
            ) from None
        # Hand-set the fitted attributes so the scaler acts as if fitted.
        # NOTE(review): relies on sklearn accepting manually assigned
        # attributes — verify against the pinned sklearn version.
        self.scaler.mean_ = mean
        self.scaler.scale_ = std
        self.scaler.var_ = std ** 2

    def transform(self, y: torch.Tensor) -> torch.Tensor:
        """Standardize ``y`` and return it as a float32 CPU tensor."""
        y_transformed = self.scaler.transform(self._to_numpy(y))
        return torch.tensor(y_transformed, dtype=torch.float32)

    def inverse_transform(self, z: torch.Tensor, WHETHER_STD: bool = False) -> np.ndarray:
        """Map standardized values back to the original scale.

        Args:
            z: standardized values (tensor or array).
            WHETHER_STD: if True, ``z`` is a standard deviation, so it is only
                multiplied by ``scaler.scale_`` and not shifted by the mean.
        """
        z = self._to_numpy(z)
        if WHETHER_STD:
            return self.scaler.scale_ * z
        return self.scaler.inverse_transform(z)

    def postprocess_(self, mu: torch.Tensor, sigma: torch.Tensor) -> dict:
        """Convert a predicted mean/std pair to original units with a +/-2 std band.

        Returns:
            dict with ``mean_y``, ``std_y``, ``lower_ci`` (mean - 2 std) and
            ``upper_ci`` (mean + 2 std), all NumPy arrays.
        """
        mean_y = self.inverse_transform(mu)
        std_y = self.inverse_transform(sigma, WHETHER_STD=True)
        return {
            'mean_y': mean_y,
            'std_y': std_y,
            'lower_ci': mean_y - std_y * 2,
            'upper_ci': mean_y + std_y * 2,
        }




class PowerNormalizer:
    """Log-transform non-negative targets, then standardize them.

    Forward map: ``z = (log1p(max(y, 0)) - mean) / scale``, where ``mean`` and
    ``scale`` come from a ``StandardScaler`` fitted in log space. Negative
    targets are clipped to zero so the log is always defined.
    ``postprocess_`` treats a predicted (mu, sigma) as Gaussian in that log
    space and maps it back with log-normal closed forms.
    """

    def __init__(self):
        self.scaler = StandardScaler()

    @staticmethod
    def _to_numpy(x):
        """Return torch tensors as detached CPU NumPy arrays; pass anything else through."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return x

    @staticmethod
    def _log_shift(y):
        """Clip negative values to zero and return ``log(1 + y)``."""
        # One-sided clip. Clipping to [0, y.max()] instead would map every
        # value to y.max() when all values are negative (a_min > a_max),
        # would turn the whole array NaN if any entry is NaN, and would
        # raise on empty input.
        return np.log1p(np.maximum(y, 0))

    def fit(self, y_ref: np.ndarray) -> None:
        """Fit the log-space scaler on reference targets ``y_ref`` (array or tensor)."""
        self.scaler = self.scaler.fit(self._log_shift(self._to_numpy(y_ref)))

    def transform(self, y: torch.Tensor) -> torch.Tensor:
        """Apply clip + log1p + standardization; return a float32 CPU tensor."""
        y_transformed = self.scaler.transform(self._log_shift(self._to_numpy(y)))
        return torch.tensor(y_transformed, dtype=torch.float32)

    def inverse_transform(self, z: torch.Tensor, WHETHER_STD: bool = False) -> np.ndarray:
        """Map standardized log-space values back to the original space.

        Raises:
            ValueError: if ``WHETHER_STD`` is True. A log-space std has no
                pointwise inverse; use ``postprocess_`` for uncertainties.
        """
        if WHETHER_STD:
            # Previously this printed and silently returned None, which only
            # failed later and far from the cause.
            raise ValueError('wrong input, inverse transform cannot be used for std')
        log_y = self.scaler.inverse_transform(self._to_numpy(z))
        return np.expm1(log_y)

    def inverse_transform_only_standard_scaler(self, z: torch.Tensor, WHETHER_STD: bool = False) -> np.ndarray:
        """Undo only the standardization, leaving values in log space.

        Args:
            z: standardized values (tensor or array).
            WHETHER_STD: if True, ``z`` is a standard deviation, so it is only
                multiplied by ``scaler.scale_`` and not shifted by the mean.
        """
        z = self._to_numpy(z)
        if WHETHER_STD:
            return self.scaler.scale_ * z
        return self.scaler.inverse_transform(z)

    def postprocess_(self, mu: torch.Tensor, sigma: torch.Tensor, ci: float = 0.95) -> dict:
        """Convert standardized log-space predictions to original-space statistics.

        Assumes ``log1p(y) ~ N(mu_log, sigma_log**2)``, with ``mu_log``/``sigma_log``
        obtained by undoing the standardization of ``mu``/``sigma``.

        Args:
            mu: predicted means in standardized log space; flattened to (N, 1).
            sigma: predicted standard deviations in standardized log space;
                flattened to (N, 1).
            ci: two-sided confidence level for the interval (default 0.95).

        Returns:
            dict with (N, 1) NumPy arrays:
                - mean_y: E[y] in the original space.
                - std_y: Std[y] in the original space.
                - lower_ci / upper_ci: quantile-based log-normal interval bounds.
        """
        from scipy.stats import norm

        mu_log = self.inverse_transform_only_standard_scaler(
            np.reshape(self._to_numpy(mu), (-1, 1)))
        sigma_log = self.inverse_transform_only_standard_scaler(
            np.reshape(self._to_numpy(sigma), (-1, 1)), WHETHER_STD=True)
        var_log = sigma_log ** 2

        # 1 + y is log-normal. Subtracting 1 shifts the mean but leaves the
        # variance unchanged.
        mean_y = np.expm1(mu_log + 0.5 * var_log)
        std_y = np.sqrt(np.expm1(var_log) * np.exp(2 * mu_log + var_log))

        # Quantiles map exactly through the monotone expm1.
        z = norm.ppf(0.5 + ci / 2.0)  # e.g. 1.96 for ci=0.95
        lower_ci = np.expm1(mu_log - z * sigma_log)
        upper_ci = np.expm1(mu_log + z * sigma_log)

        return {
            'mean_y': mean_y,
            'std_y': std_y,
            'lower_ci': lower_ci,
            'upper_ci': upper_ci,
        }




class RobustNormalizer:
    """Scale targets with scikit-learn's ``RobustScaler`` (median / IQR based).

    Inputs may be torch tensors (any device, with or without grad) or NumPy
    arrays. ``transform`` returns float32 CPU tensors; ``inverse_transform``
    returns NumPy arrays.
    """

    def __init__(self):
        self.scaler = RobustScaler()

    @staticmethod
    def _to_numpy(x):
        """Return torch tensors as detached CPU NumPy arrays; pass anything else through."""
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().numpy()
        return x

    def fit(self, y_ref: torch.Tensor) -> None:
        """Fit the scaler on reference targets ``y_ref`` (tensor or array)."""
        # Convert first: sklearn cannot consume CUDA tensors or tensors that
        # require grad, which ``transform`` already accepts.
        self.scaler = self.scaler.fit(self._to_numpy(y_ref))

    def transform(self, y: torch.Tensor) -> torch.Tensor:
        """Robust-scale ``y`` and return it as a float32 CPU tensor."""
        y_transformed = self.scaler.transform(self._to_numpy(y))
        return torch.tensor(y_transformed, dtype=torch.float32)

    def inverse_transform(self, z: torch.Tensor, WHETHER_STD: bool = False) -> np.ndarray:
        """Map scaled values back to the original units.

        Args:
            z: scaled values (tensor or array).
            WHETHER_STD: if True, ``z`` is a spread measure, so it is only
                multiplied by ``scaler.scale_`` and not shifted by the center.
        """
        z = self._to_numpy(z)
        if WHETHER_STD:
            return self.scaler.scale_ * z
        return self.scaler.inverse_transform(z)
