import matplotlib
import pandas as pd

matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
import torch
from utilities import dataio
import os
import pickle
import torch.nn as nn
from functools import partial
from tqdm import trange

def map_predictor(model_trained, name_predictor: str):
    """Return the inference callable of ``model_trained`` selected by name.

    The model is switched to evaluation mode (``model_trained.eval()``) before
    the bound method is looked up.

    Args:
        model_trained: trained model exposing ``eval``,
            ``infer_global_importance`` and ``infer_with_subnetwork``.
        name_predictor: ``"global"`` (global importance, correlations kept),
            ``"indp"`` (global importance with ``IGNORE_CORR=True``) or
            ``"local"`` (sub-network inference).

    Returns:
        A ``functools.partial`` wrapping the selected inference method.

    Raises:
        ValueError: if ``name_predictor`` is not one of the supported names.
    """
    model_trained = model_trained.eval()
    if name_predictor == "global":
        return partial(model_trained.infer_global_importance, IGNORE_CORR=False)
    if name_predictor == "local":
        return partial(model_trained.infer_with_subnetwork)
    if name_predictor == "indp":
        return partial(model_trained.infer_global_importance, IGNORE_CORR=True)
    # A 'correlation' predictor (model_trained.infer_all_correlation) is disabled.
    raise ValueError(
        f"Unknown predictor {name_predictor!r}; expected 'global', 'local' or 'indp'."
    )

def batched_infer(name_predictor: str,
                  model_trained: torch.nn.Module,
                  arr_input_grids: torch.Tensor,
                  cov_name: str or list,
                  batch_size=10000):
    """Evaluate a named predictor on ``arr_input_grids`` chunk by chunk.

    The predictor is obtained from ``map_predictor(model_trained, name_predictor)``
    and called as ``predictor(model_input=chunk, set_S=cov_name)`` on consecutive
    row slices of ``batch_size`` rows, under ``torch.no_grad()`` with a tqdm
    progress bar.

    Returns:
        (f_mu, f_var): the per-chunk outputs concatenated along dim 0.
    """
    predictor = map_predictor(model_trained, name_predictor)
    mu_chunks, var_chunks = [], []
    with torch.no_grad():
        for start in trange(0, arr_input_grids.shape[0], batch_size):
            chunk = arr_input_grids[start:start + batch_size]
            chunk_mu, chunk_var = predictor(model_input=chunk, set_S=cov_name)
            mu_chunks.append(chunk_mu)
            var_chunks.append(chunk_var)
    return torch.cat(mu_chunks, dim=0), torch.cat(var_chunks, dim=0)

def batched_predict(
        model_trained: torch.nn.Module,
        arr_input_grids: torch.Tensor,
        batch_size=10000):
    """Call ``model_trained.infer_mu_and_var_testing`` on row chunks and stitch the results.

    Runs under ``torch.no_grad()``; each chunk holds at most ``batch_size`` rows.

    Returns:
        (f_mu, f_var): chunk outputs concatenated along dim 0.
    """
    outputs = []
    with torch.no_grad():
        for offset in range(0, arr_input_grids.shape[0], batch_size):
            mu_part, var_part = model_trained.infer_mu_and_var_testing(
                arr_input_grids[offset:offset + batch_size])
            outputs.append((mu_part, var_part))
    f_mu = torch.cat([mu for mu, _ in outputs], dim=0)
    f_var = torch.cat([var for _, var in outputs], dim=0)
    return f_mu, f_var



def save_model_in_pkl(model_trained, savepath_model):
    """Serialize ``model_trained`` to the file ``savepath_model`` using pickle."""
    with open(savepath_model, mode='wb') as fh:
        pickle.dump(model_trained, fh)

def load_model_in_pkl(path_model):
    """Load and return the object pickled in the file ``path_model``.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file,
    so only call this on files you created yourself or otherwise trust.
    """
    with open(path_model, 'rb') as file:
        return pickle.load(file)


def cond_mkdir(path):
    """Create directory ``path`` (including missing parents) unless it already exists.

    ``exist_ok=True`` closes the race between the existence check and the
    creation: if another process creates the directory in between, no
    ``FileExistsError`` is raised. An existing non-directory at ``path`` is
    still left untouched, as before.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)

def make_input_for_vis_disentangle(scalar_dataset, current_cov, pos, num_of_samples=500):
    """Build a model input that sweeps one covariate while all others are zero.

    ``current_cov`` is swept over its observed range (from
    ``scalar_dataset.valid_dict_features``), widened by 10% on each side. Every
    other covariate in ``scalar_dataset.covariate_names`` is 0. The coordinate(s)
    ``pos`` are repeated on every row.

    Args:
        scalar_dataset: object with ``valid_dict_features`` (name -> array) and
            ``covariate_names``.
        current_cov: name of the covariate to sweep.
        pos: scalar or 1-D array-like holding G coordinate values.
        num_of_samples: number of sweep points (rows).

    Returns:
        torch.FloatTensor of shape ``(num_of_samples, G + D)``, with
        ``D = len(scalar_dataset.covariate_names)``. The coordinate columns come
        first, then the covariates in ``covariate_names`` order.
    """
    current_features = scalar_dataset.valid_dict_features[current_cov]
    min_cov = float(current_features.min())
    max_cov = float(current_features.max())
    range_cov = max_cov - min_cov
    arr_current_cov = torch.linspace(min_cov - range_cov * 0.1, max_cov + range_cov * 0.1, num_of_samples)

    # Each covariate must be a column vector (N, 1). Concatenating the raw 1-D
    # tensors along dim=-1 would produce a single (N * D,) vector.
    list_covariates = []
    for ith_cov_name in scalar_dataset.covariate_names:
        if ith_cov_name == current_cov:
            list_covariates.append(arr_current_cov[:, None])
        else:
            list_covariates.append(torch.zeros(num_of_samples, 1))
    arr_covariates = torch.cat(list_covariates, dim=-1)

    # Repeat pos on every row -> (N, G). Works for a scalar as well as a 1-D pos.
    coords = np.repeat(np.asarray(pos).reshape(1, -1), repeats=num_of_samples, axis=0)
    coords = torch.from_numpy(coords).float()
    return torch.cat((coords, arr_covariates), dim=-1)


def make_input_for_vis_comp(dataset_name):
    """Return the ground-truth input builder that matches ``dataset_name``.

    Only the exact name ``'Airway'`` selects the airway (coordinate-aware)
    builder. Every other dataset uses the general builder.
    """
    is_airway = dataset_name == 'Airway'
    return make_input_for_vis_comp_airway if is_airway else make_input_for_vis_comp_general



def make_input_for_vis_comp_airway(scalar_dataset, pos=None, device='cpu', list_dropna_covariates=None):
    """Collect observed (input, target) pairs for coordinate-based datasets.

    Rows are taken from ``scalar_dataset.prepared_data_with_nans``. If ``pos``
    is given, only rows whose ``geo_var_name`` value is within 1e-2 of ``pos``
    are kept. Rows containing NaN are then dropped: only the columns in
    ``list_dropna_covariates`` are considered when it is provided, otherwise
    all columns are.

    Returns:
        (model_input, model_output), both float tensors on ``device``:
            model_input has the coordinate column, then the covariates in
            ``covariate_names`` order, then (if ``padding_muter``) each
            covariate's ``'_muter'`` column.
            model_output holds the ``tgt_var_name`` values.
    """
    frame = scalar_dataset.prepared_data_with_nans
    geo_col = scalar_dataset.geo_var_name
    if pos is not None:
        frame = frame[(frame[geo_col] - pos).abs() < 1e-2]
    if list_dropna_covariates is None:
        frame = frame.dropna()
    else:
        frame = frame.dropna(subset=list_dropna_covariates, how='any')

    column_names = list(scalar_dataset.covariate_names)
    if scalar_dataset.padding_muter:
        column_names += [name + '_muter' for name in scalar_dataset.covariate_names]
    covariate_cols = [torch.from_numpy(np.array(frame[name].values)[..., None]) for name in column_names]
    stacked_covariates = torch.cat(covariate_cols, dim=-1)

    coord_col = torch.from_numpy(np.array(frame[geo_col].values)[..., None]).float()
    model_input = torch.cat((coord_col, stacked_covariates), dim=-1).float().to(device)
    target = np.array(frame[scalar_dataset.tgt_var_name].values)
    return model_input, torch.from_numpy(target).float().to(device)








def make_shape_input_for_vis_comp_airway(scalar_dataset, device='cpu', list_dropna_covariates=None):
    """Split the observed airway data into per-scan records in original units.

    Rows of ``scalar_dataset.prepared_data_with_nans`` containing NaN are
    dropped first. Only ``list_dropna_covariates`` is considered when it is
    given, otherwise all columns are. The remaining rows are grouped by their
    ``'id'`` value (sorted, via ``np.unique``) and each group is converted with
    ``extract_airway_per_scan``.

    Returns:
        list of dicts ``{'input': tensor, 'output': ...}``, one per scan.
        Column 0 of ``'input'`` (the coordinate) is denormalized with the
        ``'pos'`` normalizer. ``'output'`` is the target denormalized with the
        ``'csa'`` normalizer, exactly as returned by ``denormalize``.
    """
    pd_data_at_pos = scalar_dataset.prepared_data_with_nans
    if list_dropna_covariates is not None:
        pd_data_at_pos = pd_data_at_pos.dropna(subset=list_dropna_covariates, how='any')
    else:
        pd_data_at_pos = pd_data_at_pos.dropna()

    list_data = []
    for ith_id in np.unique(pd_data_at_pos['id'].values):
        current_scan = pd_data_at_pos[pd_data_at_pos['id'] == ith_id]
        current_model_input, current_model_output = extract_airway_per_scan(scalar_dataset, current_scan, device)

        current_model_output = denormalize(ds_=scalar_dataset, arr_=current_model_output, var_name='csa')

        # denormalize() feeds an (n, 1) view to the normalizer, so its result is
        # presumably (n, 1). An (n, 1) value does not broadcast into the 1-D
        # slice [..., 0], so flatten it and match the destination dtype/device.
        denormed_pos = denormalize(ds_=scalar_dataset, arr_=current_model_input[..., 0], var_name='pos')
        current_model_input[..., 0] = torch.as_tensor(
            np.asarray(denormed_pos).reshape(-1),
            dtype=current_model_input.dtype,
            device=current_model_input.device,
        )

        list_data.append({'input': current_model_input, 'output': current_model_output})
    return list_data



def extract_airway_per_scan(scalar_dataset, current_scan, device):
    """Turn the rows of a single scan into (model_input, model_output) tensors on ``device``.

    model_input columns are ``current_scan['pos']``, then the covariates in
    ``scalar_dataset.covariate_names`` order, then (if ``padding_muter``) the
    matching ``'_muter'`` columns. model_output holds the
    ``scalar_dataset.tgt_var_name`` column. Both are float tensors.
    """
    names = list(scalar_dataset.covariate_names)
    if scalar_dataset.padding_muter:
        names.extend(name + '_muter' for name in scalar_dataset.covariate_names)
    column_tensors = [torch.from_numpy(np.array(current_scan[name].values)[..., None]) for name in names]
    stacked_covariates = torch.cat(column_tensors, dim=-1)

    position = torch.from_numpy(current_scan['pos'].values[..., None]).float()
    model_input = torch.cat((position, stacked_covariates), dim=-1).float().to(device)

    target = np.array(current_scan[scalar_dataset.tgt_var_name].values)
    return model_input, torch.from_numpy(target).float().to(device)








def make_input_for_vis_comp_general(scalar_dataset, device='cpu', list_dropna_covariates=None):
    """Collect observed (input, target) pairs for datasets without a coordinate column.

    Rows of ``scalar_dataset.prepared_data_with_nans`` containing NaN are
    dropped. Only ``list_dropna_covariates`` is considered when it is given,
    otherwise all columns are.

    Returns:
        (model_input, model_output), both float tensors on ``device``:
            model_input holds the covariates in ``covariate_names`` order,
            plus their ``'_muter'`` columns when ``padding_muter`` is set.
            model_output holds the ``tgt_var_name`` values.
    """
    frame = scalar_dataset.prepared_data_with_nans
    if list_dropna_covariates is None:
        frame = frame.dropna()
    else:
        frame = frame.dropna(subset=list_dropna_covariates, how='any')

    names = list(scalar_dataset.covariate_names)
    if scalar_dataset.padding_muter:
        names += [name + '_muter' for name in scalar_dataset.covariate_names]
    columns = [torch.from_numpy(np.array(frame[name].values)[..., None]) for name in names]
    model_input = torch.cat(columns, dim=-1).float().to(device)

    target = np.array(frame[scalar_dataset.tgt_var_name].values)
    return model_input, torch.from_numpy(target).float().to(device)



def make_input_for_eval(dataset_name):
    """Return the evaluation-input builder for ``dataset_name``.

    Names containing ``'Airway'``, and the exact name ``'ToyData'``, use the
    coordinate-aware (airway) builder. Every other dataset uses the general
    builder.
    """
    uses_spatial_builder = 'Airway' in dataset_name or dataset_name == 'ToyData'
    if uses_spatial_builder:
        return make_input_for_eval_airway
    return make_input_for_eval_general


def make_input_for_eval_airway(scalar_dataset, pos=None, device='cpu', list_dropna_covariates=None):
    """Build evaluation tensors from ``scalar_dataset.prepared_data`` for coordinate-based datasets.

    If ``pos`` is given, only rows whose ``geo_var_name`` value equals ``pos``
    exactly are kept. Rows containing NaN are then dropped: only
    ``list_dropna_covariates`` is considered when it is given, otherwise all
    columns are.

    Returns:
        (model_input, model_output, frame):
            model_input is a float tensor on ``device`` with the coordinate
            column, the covariates, and (if ``padding_muter``) the ``'_muter''
            columns.
            model_output is a float tensor on ``device`` with the
            ``tgt_var_name`` values.
            frame is the filtered DataFrame the tensors were built from.
    """
    source = scalar_dataset.prepared_data
    geo_col = scalar_dataset.geo_var_name
    subset = source if pos is None else source[source[geo_col] == pos]
    if list_dropna_covariates is None:
        subset = subset.dropna()
    else:
        subset = subset.dropna(subset=list_dropna_covariates, how='any')

    names = list(scalar_dataset.covariate_names)
    if scalar_dataset.padding_muter:
        names += [name + '_muter' for name in scalar_dataset.covariate_names]
    stacked_covariates = torch.cat(
        [torch.from_numpy(np.array(subset[name].values)[..., None]) for name in names], dim=-1)

    coord_col = torch.from_numpy(np.array(subset[geo_col].values)[..., None]).float()
    model_input = torch.cat((coord_col, stacked_covariates), dim=-1).float().to(device)
    target = np.array(subset[scalar_dataset.tgt_var_name].values)
    return model_input, torch.from_numpy(target).float().to(device), subset



def make_input_for_eval_general(scalar_dataset, pos=None, device='cpu', list_dropna_covariates=None):
    """Build evaluation tensors from ``scalar_dataset.prepared_data`` for datasets without coordinates.

    ``pos`` is unused. It is accepted so the signature matches
    ``make_input_for_eval_airway``. Rows containing NaN are dropped: only
    ``list_dropna_covariates`` is considered when it is given, otherwise all
    columns are.

    Returns:
        (model_input, model_output, frame), mirroring ``make_input_for_eval_airway``
        but without a coordinate column.
    """
    frame = scalar_dataset.prepared_data
    if list_dropna_covariates is None:
        frame = frame.dropna()
    else:
        frame = frame.dropna(subset=list_dropna_covariates, how='any')

    names = list(scalar_dataset.covariate_names)
    if scalar_dataset.padding_muter:
        names += [name + '_muter' for name in scalar_dataset.covariate_names]
    columns = [torch.from_numpy(np.array(frame[name].values)[..., None]) for name in names]
    model_input = torch.cat(columns, dim=-1).float().to(device)

    target = np.array(frame[scalar_dataset.tgt_var_name].values)
    return model_input, torch.from_numpy(target).float().to(device), frame





def make_input_for_vis_3d_comp(scalar_dataset, covs_to_plot, pos, num_of_samples=100):
    """Build a 2-D covariate grid at ``pos`` together with the observed data at ``pos``.

    Args:
        scalar_dataset: object providing ``valid_dict_features`` and
            ``covariate_names``, plus what ``make_input_for_vis_comp_airway`` reads.
        covs_to_plot: exactly two covariate names to sweep. Each is sampled at
            ``num_of_samples`` points between its observed min and max.
            Covariates not being swept are 0.
        pos: scalar or 1-D array-like coordinate(s), repeated on every grid row.
        num_of_samples: grid points per swept covariate.

    Returns:
        (model_input_grids, model_input_gt, arr_csa_at_pos):
            model_input_grids is a float tensor of shape
            ``(num_of_samples ** 2, G + D)``, with the coordinate columns first.
            model_input_gt and arr_csa_at_pos are the observed inputs and targets
            at ``pos``, as returned by ``make_input_for_vis_comp_airway``.
    """
    assert len(covs_to_plot) == 2
    cov1, cov2 = covs_to_plot[0], covs_to_plot[1]

    # load grids
    feat1 = scalar_dataset.valid_dict_features[cov1]
    arr_cov1 = torch.linspace(float(feat1.min()), float(feat1.max()), num_of_samples)
    feat2 = scalar_dataset.valid_dict_features[cov2]
    arr_cov2 = torch.linspace(float(feat2.min()), float(feat2.max()), num_of_samples)

    x1, x2 = torch.meshgrid(arr_cov1, arr_cov2)
    dict_cov_grids = {cov1: x1.flatten(), cov2: x2.flatten()}
    n_grid = x1.numel()

    list_covariates = [
        dict_cov_grids[name][..., None] if name in dict_cov_grids else torch.zeros(n_grid, 1)
        for name in scalar_dataset.covariate_names
    ]

    # Repeat pos on every row -> (n_grid, G). Works for a scalar as well as a 1-D pos.
    coords = np.repeat(np.asarray(pos).reshape(1, -1), repeats=n_grid, axis=0)
    coords = torch.from_numpy(coords).float()
    arr_covariates = torch.cat(list_covariates, dim=-1)
    model_input_grids = torch.cat((coords, arr_covariates), dim=-1).float()

    # load gt: make_input_for_vis_comp() is a dispatcher that takes only a
    # dataset name, so call the airway builder, which accepts (dataset, pos).
    model_input_gt, arr_csa_at_pos = make_input_for_vis_comp_airway(scalar_dataset, pos)
    return model_input_grids, model_input_gt, arr_csa_at_pos




def make_grids_and_dps_for_2d_vis(dataset_name='Airway'):
    """Return the 2-D visualization builder for ``dataset_name``.

    These datasets use the spatial builder (coordinate column plus covariates):
    names containing ``'Airway'`` or ``'AFQ'``, and the exact name
    ``'ToyData'``. All other datasets use the general builder, which has no
    coordinates.
    """
    spatial = any(tag in dataset_name for tag in ('Airway', 'AFQ')) or dataset_name == 'ToyData'
    return make_grids_and_dps_for_2d_vis_spatial if spatial else make_grids_and_dps_for_2d_vis_general

def make_grids_and_dps_for_2d_vis_general(scalar_dataset,
                                          covs_to_plot,
                                          pos=None,
                                          num_of_samples=100,
                                          device='cpu'):
    """Build a sweep grid over ``covs_to_plot`` plus the observed data (datasets without coordinates).

    Every covariate in ``covs_to_plot`` is swept jointly: row i uses the i-th
    of ``num_of_samples`` evenly spaced values between that covariate's min
    and max in ``scalar_dataset.prepared_data`` (NaNs are replaced by 0 before
    the range is taken). All other covariates are 0.

    Args:
        pos: unused. Accepted so the signature matches the spatial variant.

    Returns:
        (model_input_grids, model_input_gt, arr_csa_at_pos):
            model_input_grids is a ``(num_of_samples, D)`` float tensor on
            ``device``. The other two come from
            ``make_input_for_vis_comp_general``.
    """
    # load grids
    dict_cov_grids = {}
    for ith_cov_name in covs_to_plot:
        arr_feat = np.nan_to_num(scalar_dataset.prepared_data[ith_cov_name])
        dict_cov_grids[ith_cov_name] = torch.linspace(float(arr_feat.min()), float(arr_feat.max()), num_of_samples)

    # Covariates that are not plotted get zero columns sized by num_of_samples.
    # The length is not taken from a loop variable leaked from above, which
    # would be undefined for an empty covs_to_plot.
    list_covariates = [
        dict_cov_grids[name][..., None] if name in dict_cov_grids else torch.zeros(num_of_samples, 1)
        for name in scalar_dataset.covariate_names
    ]
    model_input_grids = torch.cat(list_covariates, dim=-1).float()

    # load gt
    model_input_gt, arr_csa_at_pos = make_input_for_vis_comp_general(scalar_dataset=scalar_dataset,
                                                                     device=device,
                                                                     list_dropna_covariates=covs_to_plot)
    model_input_grids = model_input_grids.to(device)

    return model_input_grids, model_input_gt, arr_csa_at_pos






def make_grids_and_dps_for_2d_vis_spatial(scalar_dataset,
                                          covs_to_plot,
                                          pos=None,
                                          num_of_samples=100,
                                          device='cpu'):
    """Build a sweep grid at coordinate ``pos`` plus the observed data at ``pos``.

    Every covariate in ``covs_to_plot`` is swept jointly: row i uses the i-th
    of ``num_of_samples`` evenly spaced values between that covariate's min
    and max in ``scalar_dataset.valid_dict_features`` (NaNs are replaced by 0
    before the range is taken). All other covariates are 0. The coordinate(s)
    ``pos`` are repeated on every row.

    Args:
        pos: scalar or 1-D array-like coordinate(s), G values. Required despite
            the ``None`` default.

    Returns:
        (model_input_grids, model_input_gt, arr_csa_at_pos):
            model_input_grids is a ``(num_of_samples, G + D)`` float tensor on
            ``device``. The other two come from
            ``make_input_for_vis_comp_airway``.
    """
    # load grids
    dict_cov_grids = {}
    for ith_cov_name in covs_to_plot:
        arr_feat = np.nan_to_num(scalar_dataset.valid_dict_features[ith_cov_name])
        dict_cov_grids[ith_cov_name] = torch.linspace(float(arr_feat.min()), float(arr_feat.max()), num_of_samples)

    # Covariates that are not plotted get zero columns sized by num_of_samples
    # rather than by a loop variable leaked from above.
    list_covariates = [
        dict_cov_grids[name][..., None] if name in dict_cov_grids else torch.zeros(num_of_samples, 1)
        for name in scalar_dataset.covariate_names
    ]

    # Always produce 2-D coords (num_of_samples, G). A scalar pos previously
    # produced a 1-D array that could not be concatenated with the covariates.
    coords = np.repeat(np.asarray(pos).reshape(1, -1), repeats=num_of_samples, axis=0)
    coords = torch.from_numpy(coords).float()
    arr_covariates = torch.cat(list_covariates, dim=-1)
    model_input_grids = torch.cat((coords, arr_covariates), dim=-1).float()

    # load gt
    model_input_gt, arr_csa_at_pos = make_input_for_vis_comp_airway(scalar_dataset, pos, device, list_dropna_covariates=covs_to_plot)
    model_input_grids = model_input_grids.to(device)

    return model_input_grids, model_input_gt, arr_csa_at_pos



def make_grids_and_dps_for_2d_vis_airway_shape(scalar_dataset,
                                               covs_to_plot,
                                               cov,
                                               num_of_samples=100,
                                               device='cpu'):
    """Build a grid along the airway position at a fixed covariate value, plus per-scan data.

    Args:
        scalar_dataset: provides ``covariate_names``, the ``'pos'`` normalizer
            and the data read by ``make_shape_input_for_vis_comp_airway``.
        covs_to_plot: covariate names. Only ``covs_to_plot[0]`` receives a grid
            (the constant ``cov``). Covariates outside ``covs_to_plot`` are 0.
            The list is also used to decide which rows with NaN are dropped.
        cov: value assigned to ``covs_to_plot[0]`` on every row.
        num_of_samples: number of positions, evenly spaced on [0, 1] and then
            passed through the ``'pos'`` normalizer.
        device: target device.

    Returns:
        (model_input_grids, list_data): the ``(num_of_samples, 1 + D)`` grid on
        ``device`` and the per-scan records from
        ``make_shape_input_for_vis_comp_airway``.
    """
    fixed_values = {covs_to_plot[0]: torch.Tensor([[cov] * num_of_samples]).flatten()}
    columns = [
        fixed_values[name][..., None] if name in covs_to_plot else torch.zeros(num_of_samples)[..., None]
        for name in scalar_dataset.covariate_names
    ]

    unit_positions = np.linspace(0, 1, num_of_samples)[..., None]
    coords = torch.from_numpy(normalize(scalar_dataset, unit_positions, 'pos')).float()
    model_input_grids = torch.cat((coords, torch.cat(columns, dim=-1)), dim=-1).float().to(device)

    list_data = make_shape_input_for_vis_comp_airway(scalar_dataset, device, list_dropna_covariates=covs_to_plot)
    return model_input_grids, list_data





def make_grids_and_dps_for_3d_vis(dataset_name='Airway'):
    """Return the 3-D (two-covariate grid) visualization builder for ``dataset_name``.

    Names containing ``'Airway'``, and the exact name ``'ToyData'``, use the
    airway builder. Every other dataset uses the general builder.
    """
    if 'Airway' in dataset_name or dataset_name == 'ToyData':
        return make_grids_and_dps_for_3d_vis_airway
    return make_grids_and_dps_for_3d_vis_general


def make_grids_and_dps_for_3d_vis_general(scalar_dataset,
                                          covs_to_plot,
                                          pos=None,
                                          num_of_samples=100,
                                          device='cpu'):
    """Build a Cartesian covariate grid plus the observed data, for datasets without coordinates.

    Args:
        scalar_dataset: dataset consumed by ``build_model_input_grid_general``
            and ``make_input_for_vis_comp_general``.
        covs_to_plot: covariates swept on the grid. They also select which rows
            with NaN are dropped from the observed data.
        pos: unused. Accepted so the signature matches the airway variant.
        num_of_samples: grid points per swept covariate.
        device: target device.

    Returns:
        (model_input_grids, model_input_gt, arr_val_at_pos). The grid is moved
        to ``device`` as float32.
    """
    grid = build_model_input_grid_general(scalar_dataset, covs_to_plot, num_of_samples)
    gt_input, gt_values = make_input_for_vis_comp_general(scalar_dataset=scalar_dataset,
                                                          device=device,
                                                          list_dropna_covariates=covs_to_plot)
    return grid.float().to(device), gt_input, gt_values




def make_grids_and_dps_for_3d_vis_airway(scalar_dataset,
                                         covs_to_plot: list,
                                         pos=None,
                                         num_of_samples=100,
                                         device='cpu'):
    """Build a Cartesian covariate grid at coordinate ``pos`` plus the observed data there.

    Args:
        scalar_dataset: dataset consumed by ``build_model_input_grid`` and
            ``make_input_for_vis_comp_airway``.
        covs_to_plot: covariates swept on the grid. They also select which rows
            with NaN are dropped from the observed data.
        pos: coordinate(s) forwarded to both helpers. ``build_model_input_grid``
            calls ``pos.reshape``, so a numpy array is expected here despite
            the ``None`` default.
        num_of_samples: grid points per swept covariate.
        device: target device.

    Returns:
        (model_input_grids, model_input_gt, arr_csa_at_pos). The grid is moved
        to ``device`` as float32.
    """
    grid = build_model_input_grid(scalar_dataset, covs_to_plot, pos, num_of_samples)
    gt_input, gt_csa = make_input_for_vis_comp_airway(scalar_dataset, pos, device,
                                                      list_dropna_covariates=covs_to_plot)
    return grid.float().to(device), gt_input, gt_csa

def build_model_input_grid(
    scalar_dataset,
    covs_to_plot: list,
    pos: np.ndarray,
    num_points: int = 50,
    proposed: str = 'median'
):
    """Build a Cartesian grid over ``covs_to_plot`` and the matching model input.

    Args:
        scalar_dataset: object providing
            - ``covariate_names``: list of all covariate names (defines the column order);
            - ``valid_dict_features[name]``: array of observed values of each covariate.
        covs_to_plot: names of the covariates swept on the grid. Each is sampled
            at ``num_points`` evenly spaced values between its min and max
            (NaNs are replaced by 0 before the range is taken).
        pos: coordinate feature(s), array-like of shape (G,). A scalar counts as G = 1.
        num_points: points per swept covariate. The grid has
            ``num_points ** len(covs_to_plot)`` rows, or 1 row when
            ``covs_to_plot`` is empty.
        proposed: value for covariates not in ``covs_to_plot``. ``'median'`` uses
            the median of the observed values with NaNs replaced by 0.
            ``'zero'`` uses 0.

    Returns:
        torch.FloatTensor of shape (M, G + D), with M the number of grid rows,
        G the size of ``pos`` and D = len(scalar_dataset.covariate_names). The
        coordinate columns come first.

    Raises:
        ValueError: if ``proposed`` is neither ``'median'`` nor ``'zero'``.
    """
    if proposed not in ('median', 'zero'):
        raise ValueError(f"proposed must be 'median' or 'zero', got {proposed!r}")

    cov_names = scalar_dataset.covariate_names
    D = len(cov_names)

    # 1) One 1-D grid per swept covariate.
    grids = []
    for name in covs_to_plot:
        arr = np.nan_to_num(scalar_dataset.valid_dict_features[name])
        grids.append(np.linspace(arr.min(), arr.max(), num_points))

    # 2) N-D Cartesian product. 'ij' indexing orders the axes like covs_to_plot.
    #    With nothing to sweep, the grid is a single row.
    mesh = np.meshgrid(*grids, indexing='ij')
    M = mesh[0].size if grids else 1
    mesh_flat = [m.reshape(-1) for m in mesh]  # each of length M

    # 3) Full covariate matrix, M x D.
    cov_grid = np.zeros((M, D), dtype=float)
    for j, name in enumerate(cov_names):
        if name in covs_to_plot:
            cov_grid[:, j] = mesh_flat[covs_to_plot.index(name)]
        elif proposed == 'median':
            cov_grid[:, j] = np.median(np.nan_to_num(scalar_dataset.valid_dict_features[name]))
        # 'zero': the column is already 0.

    # 4) Coordinates repeated on every row, M x G.
    coords = np.tile(np.asarray(pos).reshape(1, -1), (M, 1))

    # 5) Concatenate into M x (G + D).
    model_input = np.concatenate([coords, cov_grid], axis=1)
    return torch.from_numpy(model_input).float()



def build_model_input_grid_general(
    scalar_dataset,
    covs_to_plot: list,
    num_points: int = 50,
    proposed: str = 'zero'
):
    """Build a Cartesian covariate grid over ``covs_to_plot`` (no coordinate columns).

    Args:
        scalar_dataset: object providing
            - ``covariate_names``: list of all covariate names (defines the column order);
            - ``unique_covariates[name]``: values whose min/max bound the grid of a swept covariate;
            - ``valid_dict_features[name]``: observed values, used for the ``'median'`` fill.
        covs_to_plot: names of the covariates swept on the grid. Each is sampled
            at ``num_points`` evenly spaced values between its min and max
            (NaNs are replaced by 0 before the range is taken).
        num_points: points per swept covariate. The grid has
            ``num_points ** len(covs_to_plot)`` rows, or 1 row when
            ``covs_to_plot`` is empty.
        proposed: value for covariates not in ``covs_to_plot``. ``'median'`` uses
            the median of ``valid_dict_features`` with NaNs replaced by 0.
            ``'zero'`` uses 0.

    Returns:
        torch.FloatTensor of shape (M, D), with M the number of grid rows and
        D = len(scalar_dataset.covariate_names).

    Raises:
        ValueError: if ``proposed`` is neither ``'median'`` nor ``'zero'``.
    """
    if proposed not in ('median', 'zero'):
        raise ValueError(f"proposed must be 'median' or 'zero', got {proposed!r}")

    cov_names = scalar_dataset.covariate_names
    D = len(cov_names)

    # 1) One 1-D grid per swept covariate.
    grids = []
    for name in covs_to_plot:
        arr = np.nan_to_num(scalar_dataset.unique_covariates[name])
        grids.append(np.linspace(arr.min(), arr.max(), num_points))

    # 2) N-D Cartesian product. 'ij' indexing orders the axes like covs_to_plot.
    #    With nothing to sweep, the grid is a single row.
    mesh = np.meshgrid(*grids, indexing='ij')
    M = mesh[0].size if grids else 1
    mesh_flat = [m.reshape(-1) for m in mesh]  # each of length M

    # 3) Full covariate matrix, M x D.
    cov_grid = np.zeros((M, D), dtype=float)
    for j, name in enumerate(cov_names):
        if name in covs_to_plot:
            cov_grid[:, j] = mesh_flat[covs_to_plot.index(name)]
        elif proposed == 'median':
            cov_grid[:, j] = np.median(np.nan_to_num(scalar_dataset.valid_dict_features[name]))
        # 'zero': the column is already 0.

    return torch.from_numpy(cov_grid).float()










def make_input_for_vis_correlation(scalar_dataset, covs_to_plot,  num_of_samples=100, device='cpu'):
    """Build a grid on which every covariate in ``covs_to_plot`` is swept jointly.

    Row i holds, for each plotted covariate, the i-th of ``num_of_samples``
    evenly spaced values between its min and max in
    ``scalar_dataset.valid_dict_features`` (NaNs are replaced by 0 before the
    range is taken). Covariates not in ``covs_to_plot`` are 0.

    Returns:
        float tensor of shape ``(num_of_samples, D)`` on ``device``, with the
        columns in ``scalar_dataset.covariate_names`` order.
    """
    dict_cov_grids = {}
    for ith_cov_name in covs_to_plot:
        arr_feat = np.nan_to_num(scalar_dataset.valid_dict_features[ith_cov_name])
        dict_cov_grids[ith_cov_name] = torch.linspace(float(arr_feat.min()), float(arr_feat.max()), num_of_samples)

    # Zero columns are sized by num_of_samples. They are not taken from a loop
    # variable leaked from above, which would be undefined for an empty covs_to_plot.
    list_covariates = [
        dict_cov_grids[name][..., None] if name in dict_cov_grids else torch.zeros(num_of_samples, 1)
        for name in scalar_dataset.covariate_names
    ]
    return torch.cat(list_covariates, dim=-1).float().to(device)




def make_list(dict_input, len_input):
    """Turn a dict of sequences into a list of per-item dicts.

    For each index ``i`` in ``range(len_input)``, every key of ``dict_input``
    maps to ``value[i]``. Values that are themselves dicts are indexed one
    level deeper, giving ``{sub_key: sub_value[i]}``.
    """
    def _pick(value, idx):
        # Nested dicts are indexed per sub-key; everything else directly.
        if isinstance(value, dict):
            return {sub_key: sub_value[idx] for sub_key, sub_value in value.items()}
        return value[idx]

    return [{key: _pick(value, idx) for key, value in dict_input.items()} for idx in range(len_input)]




def make_list_input_for_vis_comp(scalar_dataset, pos):
    """Gather the rows of ``scalar_dataset.normed_data`` whose ``'pos'`` equals ``pos`` exactly.

    Returns:
        (model_input, arr_csa_at_pos):
            model_input is a dict with ``'coords'`` (float tensor of shape
            (n, 1) from the ``'pos'`` column) and ``'covariates'`` (a dict
            mapping each name in ``covariate_names`` to a 1-D tensor).
            arr_csa_at_pos is a numpy array of the ``'csa'`` column.
    """
    rows = scalar_dataset.normed_data[scalar_dataset.normed_data['pos'] == pos]
    coords_np = np.array(rows['pos'].values)[..., None]
    covariates = {name: torch.from_numpy(np.array(rows[name].values))
                  for name in scalar_dataset.covariate_names}
    model_input = {'coords': torch.from_numpy(coords_np).float(), 'covariates': covariates}
    return model_input, np.array(rows['csa'].values)



def plot_regression_per_cov(
    X_train,
    y_train,
    file_name="regression_example"
):
    """Scatter ``y_train`` against ``X_train`` and overlay the curve with points sorted by x.

    ``file_name`` is accepted but unused. Returns the new matplotlib Figure,
    which is not closed here.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1, sharey=True)
    x_flat = X_train.flatten()
    ax.scatter(x_flat, y_train.flatten(), alpha=0.3, color="tab:orange")

    order = np.argsort(x_flat)
    ax.plot(np.sort(x_flat).flatten(), y_train[order].flatten(), color="black")

    ax.set_ylabel("$y$")
    ax.set_xlabel("$x$")
    return fig


def plot_regression_for_comp(
    X,
    y_pred,
    y_gt,
    file_name="regression_example"
):
    """Plot predictions (blue) and ground truth (red) against ``X``.

    The predictions are also drawn as a black curve with points sorted by x.
    ``file_name`` is accepted but unused. Returns the new matplotlib Figure,
    which is not closed here.
    """
    fig, ax = plt.subplots(nrows=1, ncols=1, sharey=True)
    x_flat = X.flatten()
    ax.scatter(x_flat, y_pred.flatten(), alpha=0.3, color="blue")
    ax.scatter(x_flat, y_gt.flatten(), alpha=0.3, color="red")

    order = np.argsort(x_flat)
    ax.plot(np.sort(x_flat).flatten(), y_pred[order].flatten(), color="black")

    ax.set_ylabel("$y$")
    ax.set_xlabel("$x$")
    return fig




def write_scalar_summary(model,
                         scalar_dataset,
                         writer,
                         total_steps,
                         prefix='train_',
                         device='cpu'):
    """Log per-covariate regression figures to ``writer`` via ``add_figure``.

    ``writer`` is presumably a TensorBoard ``SummaryWriter``; confirm this
    against the callers. The loop covers every covariate in
    ``scalar_dataset.covariate_names`` and every coordinate at the
    0/20/40/50/60/80/100th percentiles of ``scalar_dataset.train_valid_pos``.
    For each pair it adds two figures at ``global_step=total_steps``:
      * ``'disent_<cov>_<pos>'``: the model output while that covariate is swept;
      * ``'comp_with_gt_<cov>_<pos>'``: the model output against the observed targets.

    ``prefix`` and ``device`` are accepted but not used.

    NOTE(review): this function looks out of sync with the helpers in this
    module and will likely fail before logging anything. Confirm before use:
      * ``make_input_for_vis_disentangle`` returns a tensor, but
        ``movedict2cuda`` iterates ``.items()``, and the code below indexes
        ``model_input['covariates']`` as if it were a dict;
      * ``make_input_for_vis_comp`` takes only a dataset name and returns a
        builder function, yet here it is called with ``(scalar_dataset, ith_pos)``;
      * ``model.device`` is not a standard ``torch.nn.Module`` attribute.
    """

    model.eval()
    '''
    1. x
    '''

    slt_percentiles = [0, 20, 40, 50, 60, 80, 100]
    list_pos = np.percentile(scalar_dataset.train_valid_pos, slt_percentiles)
    for ith_cov_name in scalar_dataset.covariate_names:
        for ith_pos in list_pos:
            # disentangle view
            model_input = make_input_for_vis_disentangle(scalar_dataset, ith_cov_name, ith_pos, num_of_samples=500)
            model_input = movedict2cuda(model_input, model.device)
            model_output = model(model_input)

            fig_per_cov = plot_regression_per_cov(X_train= model_input['covariates'][ith_cov_name].cpu().numpy(),
                            y_train = model_output['model_out'].detach().cpu().numpy(),
                            file_name= str(ith_cov_name) + '_' + str(ith_pos))
            writer.add_figure('disent_' + str(ith_cov_name) + '_' + str(ith_pos), fig_per_cov, global_step=total_steps)

            # comparison with the observed (ground-truth) data at this position

            model_input, arr_csa_at_pos = make_input_for_vis_comp(scalar_dataset, ith_pos)
            model_input = movedict2cuda(model_input, model.device)
            model_output = model(model_input)

            fig_comp = plot_regression_for_comp(X= model_input['covariates'][ith_cov_name].cpu().numpy(),
                            y_pred = model_output['model_out'].detach().cpu().numpy(),
                                                y_gt=arr_csa_at_pos,
                            file_name= str(ith_cov_name) + '_' + str(ith_pos))

            writer.add_figure('comp_with_gt_' + str(ith_cov_name) + '_' + str(ith_pos), fig_comp, global_step=total_steps)


def movedict2cuda(gt, device):
    """Move every tensor in ``gt``, and in its directly nested dicts, to ``device`` as float32.

    ``gt`` is modified in place and also returned. Values that are not
    tensors are left unchanged.
    """
    for key, value in gt.items():
        if isinstance(value, dict):
            for sub_key, sub_value in value.items():
                if isinstance(sub_value, torch.Tensor):
                    value[sub_key] = sub_value.to(device).float()
        elif isinstance(value, torch.Tensor):
            gt[key] = value.to(device).float()
    return gt




def record_prediction(pd_data_t1, dict_pred):
    """Return a copy of ``pd_data_t1`` with one extra column per entry of ``dict_pred``.

    Tensor values are detached and converted to numpy on the CPU first. The
    input frame is not modified.
    """
    result = pd_data_t1.copy()
    for column, values in dict_pred.items():
        result[column] = values.detach().cpu().numpy() if isinstance(values, torch.Tensor) else values
    return result





def name2idxes(list_covars: list):
    """Return the name-to-index and index-to-name lookups for ``list_covars``.

    Returns:
        (dict_cov_idx, dict_idx_covariates). For duplicate names,
        ``dict_cov_idx`` keeps the last index.
    """
    dict_idx_covariates = dict(enumerate(list_covars))
    dict_cov_idx = {name: idx for idx, name in dict_idx_covariates.items()}
    return dict_cov_idx, dict_idx_covariates


def denormlize_ds(ds_, pd_data):
    """Return a copy of ``pd_data`` with every normalized column mapped back to its original scale.

    Columns whose name is a key of ``ds_.dict_normalizer`` go through
    ``denormalize``. All other columns, including ``"SGS"``, are copied
    unchanged.
    """
    pd_new_data = pd_data.copy()
    normalized_columns = set(ds_.dict_normalizer)  # built once, not per column
    for column in pd_data.columns:
        if column in normalized_columns:
            denormed = denormalize(ds_=ds_, arr_=pd_data[column].values, var_name=column)
            # denormalize() works on an (n, 1) view, so its result is presumably
            # (n, 1). Flatten it so it is assigned as a plain 1-D column.
            pd_new_data[column] = np.asarray(denormed).reshape(-1)
    return pd_new_data


def denormlize_arr(ds_, arr_data: np.ndarray):
    """Denormalize every feature stored along the last axis of ``arr_data``.

    Feature ``i`` of the last axis is treated as variable ``i`` of
    ``[ds_.geo_var_name] + ds_.covariate_names``. Each is denormalized with
    ``denormalize`` (which flattens it to a single column), and the results are
    concatenated along the last axis.
    """
    var_names = [ds_.geo_var_name] + ds_.covariate_names
    columns = [
        denormalize(ds_=ds_, arr_=arr_data[..., position], var_name=name)
        for position, name in enumerate(var_names)
    ]
    return np.concatenate(columns, axis=-1)

#
def denormalize(ds_: torch.utils.data.Dataset, arr_: "np.ndarray | torch.Tensor", var_name: str, WHETHER_STD: bool = False):
    """Map normalized values of ``var_name`` back to the original scale.

    Args:
        ds_: Dataset whose ``dict_normalizer[var_name]`` provides ``inverse_transform``.
        arr_: Normalized values; a NumPy array or a torch tensor (on any device).
        var_name: Key into ``ds_.dict_normalizer``.
        WHETHER_STD: Forwarded to the normalizer's ``inverse_transform``. In the
            earlier mean/std implementation, True meant "apply only the std
            scaling" (for converting standard deviations); presumably the
            normalizer keeps that meaning — confirm.

    Returns:
        Whatever ``inverse_transform`` returns for ``arr_`` flattened to a
        ``(-1, 1)`` column.
    """
    # NOTE: the annotation above used to be ``np.ndarray or torch.Tensor``,
    # which evaluates to just ``np.ndarray``; a string union states the intent.
    if not isinstance(arr_, np.ndarray):
        # Tensors may require grad or live on GPU; detach and move to host first.
        arr_ = arr_.detach().cpu().numpy()

    normalizer = ds_.dict_normalizer[var_name]
    return normalizer.inverse_transform(arr_.reshape(-1, 1), WHETHER_STD=WHETHER_STD)


def normalize(ds_: torch.utils.data.Dataset, arr_: "np.ndarray | torch.Tensor", var_name: str):
    """Normalize ``arr_`` with the normalizer registered for ``var_name``.

    Args:
        ds_: Dataset whose ``dict_normalizer[var_name]`` provides ``transform``.
        arr_: Raw values; passed to ``transform`` unchanged (no tensor-to-NumPy
            conversion is done here, unlike ``denormalize``).
        var_name: Key into ``ds_.dict_normalizer``.

    Returns:
        Whatever the normalizer's ``transform`` returns for ``arr_``.
    """
    # NOTE: the annotation above used to be ``np.ndarray or torch.Tensor``,
    # which evaluates to just ``np.ndarray``; a string union states the intent.
    return ds_.dict_normalizer[var_name].transform(arr_)




def denormalize_covariates(ds_: torch.utils.data.Dataset, arr_: "np.ndarray | torch.Tensor", list_covars: list, WHETHER_STD=False):
    """Denormalize selected covariate columns of a model-input array.

    Each covariate in ``list_covars`` is read from the last axis of ``arr_`` at
    position ``ds_.in_geo_features + <its index in ds_.covariate_names>``, so the
    geo features occupy the first ``in_geo_features`` slots. The column is
    flattened with ``reshape(-1, 1)`` and passed through that covariate's
    normalizer.

    Args:
        ds_: Dataset exposing ``covariate_names``, ``in_geo_features`` and
            ``dict_normalizer`` (covariate name -> normalizer).
        arr_: Normalized NumPy array or torch tensor.
        list_covars: Names of the covariates to denormalize, in output column order.
        WHETHER_STD: Forwarded to each normalizer's ``inverse_transform``, the
            same way ``denormalize`` forwards it.

    Returns:
        ``np.ndarray`` formed by concatenating the per-covariate results along
        the last axis. This is presumably one column per covariate, assuming
        ``inverse_transform`` keeps the ``(n, 1)`` shape. Leading axes of
        ``arr_`` are flattened.

    Raises:
        KeyError: if a name in ``list_covars`` is not in ``ds_.covariate_names``.
        ValueError: if ``list_covars`` is empty (``np.concatenate`` of nothing).
    """
    if not isinstance(arr_, np.ndarray):
        # .detach() so tensors that require grad can be converted (matches denormalize).
        arr_ = arr_.detach().cpu().numpy()

    dict_cov_idx, _ = name2idxes(ds_.covariate_names)
    num_geo_dim = ds_.in_geo_features
    list_denormed = []
    for var_name in list_covars:
        current_idx = dict_cov_idx[var_name] + num_geo_dim
        current_var = arr_[..., current_idx]
        # WHETHER_STD was previously accepted but never used; forward it.
        denormed = ds_.dict_normalizer[var_name].inverse_transform(
            current_var.reshape(-1, 1), WHETHER_STD=WHETHER_STD
        )
        list_denormed.append(denormed)

    return np.concatenate(list_denormed, axis=-1)





# def denormalize(ds_, arr_, var_name, WHETHER_STD=False):
#     if not isinstance(arr_, np.ndarray):
#         arr_ = arr_.cpu().numpy()
#
#     var_mean = ds_.dict_feat_mean[var_name]
#     var_std = ds_.dict_feat_std[var_name]
#     if WHETHER_STD:
#         denormed = arr_ * var_std
#     else:
#         denormed = arr_ * var_std + var_mean
#     return denormed
#
#
# def normalize(ds_, arr_, var_name):
#     var_mean = ds_.dict_feat_mean[var_name]
#     var_std = ds_.dict_feat_std[var_name]
#     normed = (arr_ - var_mean) / var_std
#     return normed
#
#
# def denormalize_covariates(ds_: pd.DataFrame, arr_: np.ndarray, list_covars: list, WHETHER_STD=False):
#     if not isinstance(arr_, np.ndarray):
#         arr_ = arr_.cpu().numpy()
#
#     dict_cov_idx, dict_idx_covariates = name2idxes(ds_.covariate_names)
#     num_geo_dim = ds_.in_geo_features
#     list_deformed = []
#     for var_name in list_covars:
#         current_idx = dict_cov_idx[var_name] + num_geo_dim
#         var_mean = ds_.dict_feat_mean[var_name]
#         var_std = ds_.dict_feat_std[var_name]
#         if WHETHER_STD:
#             denormed = arr_[..., current_idx] * var_std
#         else:
#             denormed = arr_[..., current_idx] * var_std + var_mean
#         list_deformed.append(denormed[..., None])
#
#     arr_denormed = np.concatenate(list_deformed, axis=-1)
#     return arr_denormed


#
# def postprocess_log_normal(ds_: pd.DataFrame, mu: torch.Tensor, sigma: torch.Tensor, ci: float = 0.95):
#     """
#     Given log-domain outputs from a model, compute the mean, std, and confidence interval
#     in the original (exp-transformed) space.
#
#     Args:
#         ds_: Dataset whose normalizer for ds_.tgt_var_name undoes the standard scaling
#         mu: Tensor of predicted means in log-space
#         sigma: Tensor of predicted standard deviations in log-space
#         ci: Confidence level for interval (default=0.95)
#
#     Returns:
#         dict with:
#             - mean_y: E[y] in original space
#             - std_y: Std[y] in original space
#             - lower_ci: lower bound of log-normal CI (quantile-based)
#             - upper_ci: upper bound of log-normal CI (quantile-based)
#             - lower_std: mean - std (not a proper CI, for visualization)
#             - upper_std: mean + std (not a proper CI, for visualization)
#     """
#
#     from scipy.stats import norm
#
#     # Original space mean and std (log-normal closed-form)
#     mu_ = ds_.dict_normalizer[ds_.tgt_var_name].inverse_transform_only_standard_scaler(mu.reshape(-1, 1))
#     std_ = ds_.dict_normalizer[ds_.tgt_var_name].inverse_transform_only_standard_scaler(sigma.reshape(-1, 1),WHETHER_STD=True)
#
#
#     mean_y = np.exp(mu_ + 0.5 * std_**2)
#     var_y = (np.exp(std_**2) - 1.0) * np.exp(2 * mu_ + std_**2)
#     std_y = np.sqrt(var_y)
#
#     # Confidence interval from quantile
#     z = norm.ppf(0.5 + ci / 2.0)  # e.g., 1.96 for 95%
#     lower_ci = np.exp(mu_ - z * std_)
#     upper_ci = np.exp(mu_ + z * std_)
#
#
#
#     return {
#         'mean_y': mean_y,
#         'std_y': std_y,
#         'lower_ci': lower_ci,
#         'upper_ci': upper_ci,
#     }

def denormalize_from_distribution(ds_: torch.utils.data.Dataset, mu: torch.Tensor, sigma: torch.Tensor):
    """Convert predicted distribution parameters of the target into original-space statistics.

    Delegates to the target normalizer's ``postprocess_`` and returns its
    ``mean_y``, ``std_y``, ``lower_ci`` and ``upper_ci`` entries, each with
    singleton dimensions squeezed out.

    Returns:
        ``(mean_y, std_y, low_bd_map, high_bd_map)``.
    """
    target_normalizer = ds_.dict_normalizer[ds_.tgt_var_name]
    stats = target_normalizer.postprocess_(mu, sigma)

    def _squeezed(key):
        return stats[key].squeeze()

    mean_y = _squeezed('mean_y')
    std_y = _squeezed('std_y')
    upper = _squeezed('upper_ci')
    lower = _squeezed('lower_ci')
    return mean_y, std_y, lower, upper