import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import glob
import logging
import numpy as np
import os
import random
import torch
import torch.utils.data
import pandas as pd
import pyvista as pv
import progressbar as pb
from multiprocessing import *
import blosc
from collections import defaultdict
from tqdm import trange
from model.networks.normalizers import StandardScalerNormalizer, PredefinedStandardScalerNormalizer

def conditional_gaussian_2d(
    mu: np.ndarray,
    Sigma: np.ndarray,
    y_obs: float
):
    """
    Condition a bivariate Gaussian [x1, x2] ~ N(mu, Sigma) on x2 = y_obs.

    Args:
        mu: (2,) mean vector of [x1, x2]
        Sigma: (2,2) covariance matrix; only Sigma[0, 0], Sigma[1, 1] and
            Sigma[0, 1] are read
        y_obs: observed value of x2

    Returns:
        mu_cond: mean of x1 given x2 = y_obs
        sigma_cond: standard deviation of x1 given x2 = y_obs
    """
    mean_x1, mean_x2 = mu
    var_x1, var_x2 = Sigma[0, 0], Sigma[1, 1]
    cov_x1_x2 = Sigma[0, 1]

    # Regression coefficient of x1 on x2, applied to the observed deviation.
    gain = cov_x1_x2 / var_x2
    mu_cond = mean_x1 + gain * (y_obs - mean_x2)

    # Variance of x1 left over once x2 is known.
    var_cond = var_x1 - cov_x1_x2**2 / var_x2
    return mu_cond, np.sqrt(var_cond)


def inverse_distri_pars(a, mu_c1, std_c1, std_eps, c2):
    """
    Compute the posterior p(X | Y = c2) of the linear-Gaussian model
        Y = a * X + eps,
        X ~ N(mu_c1, std_c1^2),
        eps ~ N(0, std_eps^2)

    Inputs may be Python scalars or torch tensors (broadcast against each
    other); the results are always torch tensors.

    Args:
        a: coefficient in Y = a * X + eps
        mu_c1: prior mean of X
        std_c1: prior std of X
        std_eps: std of the observation noise eps
        c2: observed value of Y

    Returns:
        mu_post: conditional mean of X | Y = c2
        std_post: conditional std of X | Y = c2
    """
    # as_tensor lets plain Python numbers through; torch.sqrt below needs a tensor.
    var_X = torch.as_tensor(std_c1) ** 2
    var_eps = torch.as_tensor(std_eps) ** 2

    # Same posterior as the precision form
    #   1 / (1 / var_X + a^2 / var_eps),
    # but with numerator and denominator multiplied by var_X * var_eps, so a
    # noise-free observation (std_eps == 0) or a point-mass prior
    # (std_c1 == 0) yields the exact limit instead of inf/NaN.
    denom = var_eps + (a ** 2) * var_X

    # posterior variance
    sigma_post_sq = var_X * var_eps / denom

    # posterior mean
    mu_post = (mu_c1 * var_eps + a * c2 * var_X) / denom
    std_post = torch.sqrt(sigma_post_sq)

    return mu_post, std_post




def design_system():
    """
    Build the ground-truth functions of the toy system.

    Returns:
        dict_f: {"C1": f1, "C2": f2}, the effect of each covariate on the
            outcome as a function of position x and covariate value c
        dict_g: {"C1": g1, "C2": g2}, the relation between the two
            covariates; g1 maps C1 to the distribution of C2 and g2 maps C2
            back to C1
        dict_noise_par: {"C1": noise_par_1, "C2": noise_par_2}, the noise
            variance contributed by each covariate at position x
    """
    # Std of the noise on C2 given C1. It is zero here, so C2 is a
    # deterministic function of C1.
    std_eps = 0.

    def f1(x, c):
        # Effect of C1: sum of a sine and a cosine whose frequency grows with c.
        return torch.sin(0.15 * (c + 1) * x) + torch.cos(0.3 * (c + 1) * x)

    def f2(x, c):
        # Effect of C2: exponential in (c + x).
        return torch.exp(0.1 * (c + x + 1))

    def g1(c1):
        """Return (mean, variance, covariance) of [C1, C2] given C1 = c1."""
        # reshape(-1) rather than squeeze(): squeeze() turns a single-element
        # batch into a 0-d tensor, on which len() below raises TypeError.
        c1 = c1.reshape(-1)
        mu_c2 = torch.log(torch.clamp(c1, min=1e-4) + 1) * 4
        std_c2 = torch.zeros_like(c1) + std_eps

        mu_all = torch.stack((c1, mu_c2), dim=-1)
        # C1 is observed, so its own std is zero.
        std_all = torch.stack((torch.zeros_like(c1), std_c2), dim=-1)
        var_all = std_all ** 2

        # Only the C2 variance entry is non-trivial; all other entries stay 0.
        Sigma_cond = torch.zeros((len(c1), 2, 2), device=c1.device)
        Sigma_cond[..., 1, 1] = std_eps ** 2

        return mu_all, var_all, Sigma_cond

    def g2(c2):
        """Return (mean, variance, covariance) of [C1, C2] given C2 = c2."""
        c2 = c2.reshape(-1)  # see g1 for why this is not squeeze()
        # Inverse of the mean map in g1: c1 = exp(c2 / 4) - 1.
        mu_post = torch.exp(c2 / 4) - 1
        std_post = torch.zeros_like(c2)

        mu_all = torch.stack((mu_post, c2), dim=-1)
        std_all = torch.stack((std_post, torch.zeros_like(c2)), dim=-1)
        var_all = std_all ** 2

        Sigma_cond = torch.zeros((len(c2), 2, 2), device=c2.device)
        Sigma_cond[..., 0, 0] = std_post ** 2

        return mu_all, var_all, Sigma_cond

    def noise_par_1(x, c1):
        # Noise variance from C1: the std is proportional to x * c1.
        sigma1 = 0.05 * x * (0.1 * c1)
        return sigma1 ** 2

    def noise_par_2(x, c2):
        # Noise variance from C2: the std is proportional to
        # x * (exp(0.1 * c2) - 1).
        sigma2 = 0.05 * x * (torch.exp(c2 * 0.1) - 1)
        return sigma2 ** 2

    dict_f = {"C1": f1, "C2": f2, }
    dict_g = {"C1": g1, "C2": g2, }
    dict_noise_par = {"C1": noise_par_1, "C2": noise_par_2}
    return dict_f, dict_g, dict_noise_par



def generate_covaraites(n_samples, g):
    """
    Sample the two covariates of the toy system.

    C1 is drawn uniformly from [0, 10). C2 is then drawn from a Gaussian
    whose mean and variance come from the last component of what ``g``
    returns for the sampled C1.

    Args:
        n_samples: number of samples; converted with int(), so integral
            floats such as 1e3 are accepted
        g: callable mapping C1 to (mean, variance, covariance); only the
            mean and variance are used

    Returns:
        cov1: (n_samples,) tensor of C1 values
        cov2: (n_samples,) tensor of C2 values
    """
    # Convert once so both samplers below see the same integer count
    # (torch.randn rejects a float size).
    n_samples = int(n_samples)
    print(n_samples)
    cov1 = torch.rand(n_samples) * 10
    mu_c2, var_c2, _ = g(cov1)
    cov2 = torch.randn(n_samples) * var_c2[..., -1].sqrt() + mu_c2[..., -1]
    return cov1, cov2


def generate_data(n_samples = 1000, VIS=True):
    """
    Draw a synthetic dataset from the toy system built by design_system().

    Covariates come from generate_covaraites() (C2 is sampled conditionally
    on C1), positions X are uniform on [0, 10), and the outcome is the sum of
    the two covariate effects plus Gaussian noise whose std is the sum of the
    two per-covariate noise stds. No random seed is set here.

    Args:
        n_samples: number of rows to generate
        VIS: when true, also save two figures (mean curve with a +/- std
            band) for C1 fixed at 10 and at 1 to hard-coded paths under
            /playpen-raid/Author/LucidAtlas/data/ToySTData/

    Returns:
        df: DataFrame with columns X, C1, C2, C1_muter, C2_muter, PID,
            effect1, effect2, Y (noisy outcome), y_gt (noise-free outcome), ID
        dict_f, dict_g, dict_noise_par: the dictionaries from design_system()
    """
    dict_f, dict_g, dict_noise_par = design_system()
    f1, f2 = dict_f["C1"], dict_f["C2"]
    noise_var_1, noise_var_2 = dict_noise_par["C1"], dict_noise_par["C2"]

    # Sampling order matters for reproducibility under a fixed seed:
    # covariates first, then positions, then the outcome noise.
    cov1, cov2 = generate_covaraites(n_samples, dict_g["C1"])
    x = torch.rand(n_samples) * 10

    effect1 = f1(x, cov1)
    effect2 = f2(x, cov2)

    # The noise helpers return variances; the total std is the sum of the stds.
    sigma_spatial_htsk = noise_var_1(x, cov1).sqrt() + noise_var_2(x, cov2).sqrt()
    spatiotemporal_noise = torch.randn(len(x)) * sigma_spatial_htsk

    y = effect1 + effect2
    y_with_noise = y + spatiotemporal_noise

    row_ids = np.arange(n_samples)
    df = pd.DataFrame({
        "X": x,
        'C1': cov1,
        'C2': cov2,
        "C1_muter": np.zeros(n_samples),
        "C2_muter": np.zeros(n_samples),
        "PID": row_ids,
        'effect1': effect1,
        'effect2': effect2,
        "Y": y_with_noise,
        'y_gt': y,
        "ID": np.arange(n_samples),
    })

    print(df.head())

    if VIS:
        arr_pts = torch.linspace(0, 10, n_samples)

        def _save_band_plot(arr_cov1, out_path):
            # C2 is set to its conditional mean given the fixed C1 level.
            mu_c2, _, _ = dict_g["C1"](arr_cov1)
            arr_cov2 = mu_c2[..., -1]

            mean_curve = f1(arr_pts, arr_cov1) + f2(arr_pts, arr_cov2)
            band = noise_var_1(arr_pts, arr_cov1).sqrt() + noise_var_2(arr_pts, arr_cov2).sqrt()

            fig, ax = plt.subplots(figsize=(8, 4))
            ax.plot(arr_pts, mean_curve, alpha=0.5, color='black', label='X1', linewidth=2.0)
            ax.fill_between(arr_pts, mean_curve - band, mean_curve + band, color='black', alpha=0.3)
            plt.savefig(out_path)
            plt.close()

        _save_band_plot(
            torch.ones_like(arr_pts) * 10,
            "/playpen-raid/Author/LucidAtlas/data/ToySTData/vis_spatial_at_max_cov.png",
        )
        _save_band_plot(
            torch.ones_like(arr_pts),
            "/playpen-raid/Author/LucidAtlas/data/ToySTData/vis_spatial_at_min_cov.png",
        )

    return df, dict_f, dict_g, dict_noise_par




def create_missingness(df_data: pd.DataFrame, nan_ratio: float=0.3):
    """
    Blank out a random fraction of the 'C2' column.

    The frame is modified in place and also returned.

    Args:
        df_data: frame with a 'C2' column
        nan_ratio: fraction of rows whose 'C2' value is set to NaN (e.g. 0.3
            for 30%); the row count is truncated to an integer

    Returns:
        The same DataFrame object, with NaNs written into 'C2'.
    """
    # Number of rows that lose their 'C2' value.
    n_missing = int(len(df_data) * nan_ratio)

    # Pick that many distinct index labels at random.
    masked_labels = np.random.choice(df_data.index, size=n_missing, replace=False)

    # Set 'C2' to NaN on the chosen rows.
    df_data.loc[masked_labels, 'C2'] = np.nan
    return df_data



class ToySTDataset(torch.utils.data.Dataset):
    """
    Synthetic toy dataset for making training cases with missingness.

    Data is generated on construction with generate_data(). Each item is a
    pair (model_input, target): model_input holds the normalized position
    'X', then the normalized covariates, then (when padding_muter is set) one
    missingness indicator per covariate.
    """

    # Values of `split` understood by __init__.
    _VALID_SPLITS = ('train', 'train_val', 'val', 'test')

    def __init__(
            self,
            filename_datasource: str=None,
            filename_split: str=None,
            slt_percentile: float=None,
            covariate_names: list=["C1", "C2"],
            tgt_var_name: str='Y',
            split: str='train',
            augtype: str='none',
            allow_missingness: bool=False,
            padding_muter: bool=False,
            training_sample_size: int=200000,
    ):
        """
        Args:
            filename_datasource: unused by this dataset
            filename_split: unused by this dataset
            slt_percentile: unused by this dataset
            covariate_names: names of the covariate columns to expose
            tgt_var_name: name of the target column
            split: one of 'train', 'train_val', 'val', 'test'
            augtype: augmentation mode; only 'none' is implemented
            allow_missingness: keep rows with missing covariates (filled with
                0); only honoured when 'train' is in ``split``
            padding_muter: append per-covariate missingness indicators to the
                model input
            training_sample_size: number of training rows to generate

        Raises:
            ValueError: if ``split`` is not a supported value, or if
                ``augtype`` does not produce prepared data
        """
        # Fail before any data is generated instead of hitting an unbound
        # local variable further down.
        if split not in self._VALID_SPLITS:
            raise ValueError(
                "Unknown split " + repr(split) + "; expected one of " + str(self._VALID_SPLITS)
            )

        self.DATASETNANE = 'ToyData'
        self.geo_var_name = 'X'
        # Copy so the shared default list (or the caller's list) is never aliased.
        self.covariate_names = list(covariate_names)
        self.tgt_var_name = tgt_var_name
        self.allow_missingness = allow_missingness if 'train' in split else False
        self.padding_muter = padding_muter

        # 1. generate data
        # Training data is always generated: the normalizers and the
        # train_* attributes below are built from it for every split.
        df_data_train, dict_f, dict_g, dict_noise_par = generate_data(n_samples = training_sample_size)

        self.gt_f_mu = dict_f
        self.gt_g = dict_g
        self.gt_f_var = dict_noise_par

        df_data_train = create_missingness(df_data_train , nan_ratio=0.3)

        if split in ('train', 'train_val'):
            df_data = df_data_train
        else:
            # 'val' and 'test' each draw a fresh sample without missingness.
            df_data, _, _, _ = generate_data(n_samples=100000)

        # 2. get normalizer
        self.dict_covariate_normalizer = self.get_statistics_of_covariates(df_data_train)
        self.dict_geo_normalizer = self.get_statistics_of_geo_features(df_data_train)
        self.dict_normalizer = {**self.dict_geo_normalizer, **self.dict_covariate_normalizer}

        # 3. normalize data
        self.train_valid_dict_features, \
        self.train_valid_arr_features, \
        self.train_normed_data = \
        self.normalize(df_data_train)

        self.valid_dict_features, \
        self.valid_arr_features, \
        self.normed_data = \
        self.normalize(df_data)

        # 4. augment the dataset
        self.customize_aug(augtype=augtype)
        if not hasattr(self, 'prepared_data'):
            # customize_aug only fills prepared_data for the modes it implements.
            raise ValueError(
                "Unsupported augtype " + repr(augtype) + "; only 'none' is implemented"
            )
        self.prepared_data_with_nans = self.prepared_data.copy()

        # 5. process missingness
        # If allow_missingness is True the dataset yields samples with
        # missing covariates (filled with 0 here; the muter columns indicate
        # which ones were missing). Otherwise only complete samples are kept.
        if self.allow_missingness:
            self.prepared_data = self.prepared_data.fillna(0)
        else:
            self.prepared_data = self.prepared_data.dropna()
        self.NUM_OF_CASES = len(self.prepared_data)
        print("There are " + str(len(self.normed_data)) + " records.")
        print("There are " + str(self.NUM_OF_CASES) + " records after augmentation or filtering.")

        print("There are " + str(len(np.unique(self.prepared_data_with_nans['PID']))) + " patients.")
        print("There are " + str(len(np.unique(self.prepared_data['PID'])))  + " patients after augmentation or filtering.")

        self.normalize_unique()

        self.train_valid_pos = np.array(self.train_normed_data['X'].values)

    def normalize_unique(self):
        """Normalize self.unique_covariates in place, column by column."""
        for ith_cov in self.covariate_names:
            self.unique_covariates[ith_cov] = self.dict_covariate_normalizer[ith_cov].transform(self.unique_covariates[ith_cov].values.reshape(-1,1))
        return

    def __len__(self):
        return len(self.prepared_data)

    def customize_aug(self, augtype='full'):
        """
        Build self.prepared_data from self.normed_data.

        Only augtype == 'none' is implemented; any other value leaves
        self.prepared_data untouched.
        """
        if augtype == 'none':
            self.prepared_data = self.not_augment_dataset(self.normed_data)

    def not_augment_dataset(self, pd_data_ori):
        """
        Rebuild the frame row by row without augmentation.

        Each output row holds the covariates, the missingness indicators
        (only when padding_muter is set), 'X', the target, 'ID' and 'PID'.
        Duplicate rows and rows with every covariate missing are dropped.
        """
        records = pd_data_ori.to_dict('records')
        list_data = []

        for data_idx in trange(len(records)):
            record = records[data_idx]
            cov_values = np.array([record[name] for name in self.covariate_names])

            if self.padding_muter:
                # Append a 1/0 flag per covariate marking which ones are NaN.
                cov_values = np.concatenate([cov_values, np.isnan(cov_values)], axis=-1)

            row = self.make_dict_from_arr_data(cov_values)
            row['X'] = record['X']
            row[self.tgt_var_name] = record[self.tgt_var_name]
            row['ID'] = record['ID']
            row['PID'] = record['PID']
            list_data.append(row)

        pd_aug = pd.DataFrame.from_records(list_data)
        pd_aug = pd_aug.drop_duplicates(subset=self.covariate_names + ['X', self.tgt_var_name, 'ID'])
        pd_aug = pd_aug.dropna(subset=self.covariate_names, how='all')
        return pd_aug

    def make_dict_from_arr_data(self, arr):
        """
        Map a flat array to a {column name: value} dict.

        Args:
            arr: either one value per covariate, or the covariate values
                followed by one muter value per covariate

        Raises:
            ValueError: if len(arr) is neither of the two accepted lengths
        """
        n_cov = len(self.covariate_names)
        if len(arr) not in (n_cov, 2 * n_cov):
            raise ValueError(
                "Expected " + str(n_cov) + " or " + str(2 * n_cov)
                + " values, got " + str(len(arr))
            )

        dict_data = {name: arr[idx] for idx, name in enumerate(self.covariate_names)}
        if len(arr) == 2 * n_cov:
            for idx, name in enumerate(self.covariate_names):
                dict_data[name + '_muter'] = arr[idx + n_cov]

        return dict_data

    def get_statistics_of_covariates(self, df_data_train):
        """
        Record one covariate row per ID and create the covariate normalizers.

        Sets self.unique_covariates (first row of each ID, ordered by ID) and
        self.dict_covariate_normalizer, and returns the latter.
        """
        # One pass instead of filtering the whole frame once per ID.
        df_unique_covariates = (
            df_data_train.drop_duplicates(subset='ID', keep='first')
            .sort_values('ID')
            .reset_index(drop=True)
        )
        self.unique_covariates = df_unique_covariates[self.covariate_names].copy()

        self.dict_covariate_normalizer = {}
        for ith_covariate in self.covariate_names:
            normalizer = PredefinedStandardScalerNormalizer()
            # NOTE(review): fit() receives the covariate *name*, not data;
            # presumably the normalizer looks up predefined statistics by
            # name. Confirm against model.networks.normalizers.
            normalizer.fit(ith_covariate)
            self.dict_covariate_normalizer[ith_covariate] = normalizer

        return self.dict_covariate_normalizer

    def get_statistics_of_geo_features(self, df_data_train):
        """
        Create the normalizers for position 'X' and for the target variable.

        Sets and returns self.dict_geo_normalizer. ``df_data_train`` is not
        read: the normalizers are fitted by feature name only.
        """
        self.dict_geo_normalizer = {}
        for ith_geo_feat in ['X', self.tgt_var_name]:
            normalizer = PredefinedStandardScalerNormalizer()
            normalizer.fit(ith_geo_feat)
            self.dict_geo_normalizer[ith_geo_feat] = normalizer
        return self.dict_geo_normalizer

    def normalize(self, df_data):
        """
        Normalize covariates, position and target of ``df_data``.

        Returns:
            dict_covariates: {covariate name: normalized values}
            arr_covariates: the normalized covariates stacked as columns
            pd_data: DataFrame with ID, PID, C1_muter, C2_muter, the
                normalized covariates, normalized 'X' and normalized target
        """
        list_covariates = []
        dict_covariates = {}
        for ith_cov in self.covariate_names:
            raw_values = np.array(df_data[ith_cov]).reshape(-1, 1)
            normed_values = self.dict_covariate_normalizer[ith_cov].transform(raw_values).squeeze()
            list_covariates.append(normed_values)
            dict_covariates[ith_cov] = normed_values

        arr_covariates = np.array(list_covariates).T

        vol_values = self.dict_geo_normalizer[self.tgt_var_name].transform(np.array(df_data[self.tgt_var_name]).reshape(-1, 1))
        pos_values = self.dict_geo_normalizer['X'].transform(np.array(df_data['X']).reshape(-1, 1))

        # Identifier and muter columns are carried over unchanged.
        dict_normalized_pd_data = {'ID': df_data['ID'], 'PID': df_data['PID'], "C1_muter": df_data["C1_muter"], "C2_muter": df_data["C2_muter"]}
        dict_normalized_pd_data.update(dict_covariates)
        dict_normalized_pd_data['X'] = pos_values.squeeze()
        dict_normalized_pd_data[self.tgt_var_name] = vol_values.squeeze()
        pd_data = pd.DataFrame.from_dict(dict_normalized_pd_data)

        return dict_covariates, arr_covariates, pd_data

    def __getitem__(self, idx):
        """
        Return (model_input, target) for the row at position ``idx``.

        model_input is a 1-D float tensor [X, covariates..., muters...]
        (muters only when padding_muter is set); target has shape (1,).
        """
        def _scalar(column):
            return torch.tensor(self.prepared_data[column].iloc[idx]).float()

        features = [_scalar('X')]
        features += [_scalar(name) for name in self.covariate_names]
        if self.padding_muter:
            features += [_scalar(name + '_muter') for name in self.covariate_names]
        model_input = torch.tensor(features).float()

        vol = _scalar(self.tgt_var_name)[None, ...]

        if torch.isnan(vol):
            logging.warning("NaN target value at dataset index %s", idx)
        return model_input, vol



def make_toydata_model_input(ds_, ds_prepared_data, tgt_var_name):
    """
    Turn a prepared-data frame into batched model inputs and targets.

    Args:
        ds_: object exposing ``covariate_names`` and ``padding_muter``
        ds_prepared_data: DataFrame holding the covariate columns, the
            ``<name>_muter`` columns (read only when padding_muter is set)
            and the target column
        tgt_var_name: name of the target column

    Returns:
        model_input: float tensor with one row per frame row; columns are the
            covariates followed by the muters when padding_muter is set.
            Unlike ToySTDataset.__getitem__, the position 'X' is not included.
        vol: float tensor of targets with a trailing singleton dimension
    """
    names = ds_.covariate_names

    def _column(col_name):
        return torch.tensor(ds_prepared_data[col_name].values).float()

    # Covariate columns first, then the optional missingness indicators.
    columns = [_column(name) for name in names]
    if ds_.padding_muter:
        columns.extend(_column(name + '_muter') for name in names)

    model_input = torch.stack(columns, dim=-1)
    vol = _column(tgt_var_name)[..., None]

    return model_input, vol


def get_toy_data_for_id(test_idx, ds_, pos=None):
    """
    Collect the prepared rows of one ID and convert them to model tensors.

    Args:
        test_idx: value matched against the 'ID' column of ds_.prepared_data
        ds_: dataset exposing ``prepared_data`` and ``tgt_var_name``
        pos: unused

    Returns:
        model_input, csa: tensors from make_toydata_model_input()
        pd_slt_data: the selected rows of ds_.prepared_data
    """
    prepared = ds_.prepared_data
    id_mask = prepared['ID'] == test_idx
    pd_slt_data = prepared[id_mask]

    model_input, csa = make_toydata_model_input(ds_, pd_slt_data, ds_.tgt_var_name)

    # Debug marker printed when the tensor and frame row counts disagree.
    row_count_matches = len(model_input) == len(pd_slt_data)
    if not row_count_matches:
        print('1')
    return model_input, csa, pd_slt_data




if __name__ == "__main__":
    # Smoke run: generate a large sample (saving the figures), then build a
    # small training dataset and report its length.
    a = generate_data(n_samples=1000000, VIS=True)

    train_kwargs = dict(
        filename_datasource="",
        filename_split="",
        covariate_names=["C1", "C2"],
        tgt_var_name="Y",
        split='train',
        training_sample_size=100,
    )
    ds_train = ToySTDataset(**train_kwargs)

    print(len(ds_train))