import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision import models
from PIL import Image
import csv
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch.nn.functional as F
import matplotlib.pyplot as plt
from PIL import Image
from sklearn.neighbors import KernelDensity
from scipy.integrate import quad


from model.CDQAE import GoldenRoad

from torchvision.transforms import RandomApply
from torch.utils import data
from data.standard_data import StandardData,ValStandardData
import argparse
from scipy.stats import wasserstein_distance


def normalize_features(full_features, subset_features):
    """
    Min-max scale two feature arrays column-wise using the full set's statistics.

    Each column is mapped with ``(x - min) / (max - min)``, where ``min`` and
    ``max`` are taken over ``full_features`` along axis 0. A column whose range
    is zero uses a divisor of 1 instead. Because the subset is scaled with the
    full set's statistics, subset values outside the full range land outside
    [0, 1].

    :param full_features: array-like, typically shape (N, D), full-set features
    :param subset_features: array-like, shape broadcastable against the column
        statistics (typically (n, D)), subset features
    :return: (ndarray, ndarray), the scaled full and subset features
    """
    lo = np.min(full_features, axis=0)
    span = np.max(full_features, axis=0) - lo
    # Constant columns would divide by zero; fall back to a divisor of 1.
    span = np.where(span == 0, 1, span)
    scaled_full = (full_features - lo) / span
    scaled_subset = (subset_features - lo) / span
    return scaled_full, scaled_subset

def calculate_information_entropy(full_features, subset_features):
    """
    Score each feature dimension of a subset against the full set and average the scores.

    Both arrays are min-max scaled column-wise with the full set's statistics
    (via ``normalize_features``). Each column is then scored with
    ``calculate_I_Di``. Every per-dimension score and the mean are printed.

    :param full_features: ndarray, shape (N, D), features of the full set
    :param subset_features: ndarray, shape (n, D), features of the subset
    :return: (float, list), (mean score over dimensions, per-dimension scores)
    """
    full_norm, subset_norm = normalize_features(full_features, subset_features)

    info_list = []
    for i in range(full_norm.shape[1]):
        info_value = calculate_I_Di(full_norm[:, i], subset_norm[:, i])
        info_list.append(info_value)
        print(f"Information content of dimension {i}: {info_value:.4f} bits")

    mean_entropy = np.mean(info_list)
    print(f"\nAverage entropy Ĥ(S): {mean_entropy:.4f} bits")
    return mean_entropy, info_list


def estimate_entropy_kde(values, bandwidth=1, lower=0.0, upper=1.0):
    """
    Estimate the differential entropy (in bits) of 1-D samples with a Gaussian KDE.

    Fits ``sklearn.neighbors.KernelDensity`` to ``values`` and numerically
    integrates ``-p(x) * log2(p(x))`` over ``[lower, upper]`` with
    ``scipy.integrate.quad``. The default interval [0, 1] matches min-max
    normalized input. Density mass outside the interval is not counted, which
    matters when the bandwidth is large relative to the interval width.

    :param values: array-like of samples, flattened to a single column
    :param bandwidth: Gaussian kernel bandwidth passed to ``KernelDensity``
    :param lower: lower integration bound (default 0.0)
    :param upper: upper integration bound (default 1.0)
    :return: float, the entropy estimate in bits
    """
    samples = np.asarray(values).reshape(-1, 1)
    kde = KernelDensity(kernel='gaussian', bandwidth=bandwidth)
    kde.fit(samples)

    def entropy_integrand(x):
        # score_samples returns the log-density; exponentiate to get p(x).
        p = np.exp(kde.score_samples(np.array([[x]])))[0]
        # The 1e-12 keeps log2 finite where the density underflows to 0.
        return -p * np.log2(p + 1e-12)

    entropy, _ = quad(entropy_integrand, lower, upper, limit=100)
    return entropy

def calculate_I_cover(full_values, subset_values, k=None, alpha=3.0, beta=3.0, bandwidth=1):
    """
    Score how well a subset covers the value range of the full set in one dimension.

    The score is ``H_P * log2(N + 1) * coverage * size``, where:

    * ``H_P`` is the KDE entropy estimate of ``full_values``
      (``estimate_entropy_kde``);
    * ``coverage`` is the fraction of the full set's non-empty histogram bins
      (``k`` equal-width bins over ``[min, max]`` of ``full_values``) that
      contain at least one subset value;
    * ``size`` is ``log2(n + 1) / log2(N + 1)``.

    :param full_values: 1-D array-like, full-set values for one dimension
    :param subset_values: 1-D array-like, subset values for the same dimension
    :param k: number of histogram bins; defaults to ``int(sqrt(N))``
    :param alpha: unused; kept for interface compatibility
    :param beta: unused; kept for interface compatibility
    :param bandwidth: KDE bandwidth passed to ``estimate_entropy_kde``
    :return: float, the score. It is 0.0 when the full set has fewer than two
        values, has zero range, or occupies at most one bin.
    """
    full_values = np.asarray(full_values, dtype=float)
    subset_values = np.asarray(subset_values, dtype=float)
    N = len(full_values)
    n = len(subset_values)

    # Check the size first: np.min / np.max raise on an empty array.
    if N <= 1:
        return 0.0
    x_min, x_max = np.min(full_values), np.max(full_values)
    if x_max == x_min:
        return 0.0

    if k is None:
        k = int(np.sqrt(N))

    bin_edges = np.linspace(x_min, x_max, k + 1)
    P_hist, _ = np.histogram(full_values, bins=bin_edges)
    occupied = P_hist > 0
    k_eff = int(np.count_nonzero(occupied))
    if k_eff <= 1:
        return 0.0

    # Bin the subset with exactly the same edges as the full set. Values
    # outside [x_min, x_max] are clipped into the first/last bin rather than
    # yielding negative or out-of-range bin indices.
    S_hist, _ = np.histogram(np.clip(subset_values, x_min, x_max), bins=bin_edges)
    # Count only bins the full set occupies, so coverage stays within [0, 1].
    c = int(np.count_nonzero((S_hist > 0) & occupied))

    # The KDE integration is the expensive step; it runs only after the cheap
    # early exits above.
    H_P = estimate_entropy_kde(full_values, bandwidth=bandwidth)
    total_info = H_P * np.log2(N + 1)
    coverage_score = c / k_eff
    size_score = np.log2(n + 1) / np.log2(N + 1)
    return total_info * coverage_score * size_score


def calculate_I_Di(full_values, subset_values):
    """
    Return the information score of a subset for a single feature dimension.

    Delegates to ``calculate_I_cover`` with all of its optional settings
    (``k``, ``alpha``, ``beta``, ``bandwidth``) left at their defaults.
    """
    return calculate_I_cover(full_values, subset_values)






import pandas as pd
import numpy as np
import os

def select_subset_with_limited_classes(
    csv_path: str,
    output_path: str,
    feature_col: str,
    subset_size: int,
    allowed_class_num: int,
    seed: int = 81769
):
    """
    Sample rows drawn from only a limited number of randomly chosen classes.

    The classes are the distinct values of ``feature_col``. ``allowed_class_num``
    of them are picked at random. ``subset_size`` rows are then sampled (without
    replacement) from the rows of those classes, tagged with ``label = 1``, and
    written to ``output_path`` as CSV without the index.

    :param csv_path: path of the source CSV file
    :param output_path: path of the CSV file to write. Its parent directory is
        created when the path has one.
    :param feature_col: column whose distinct values define the classes
    :param subset_size: number of rows to sample
    :param allowed_class_num: number of classes the sample may come from
    :param seed: seed for the global NumPy RNG (used to pick the classes) and
        for ``DataFrame.sample``
    :raises ValueError: if ``allowed_class_num`` exceeds the number of classes,
        or the chosen classes together have fewer than ``subset_size`` rows
    """
    df = pd.read_csv(csv_path)
    # Seeds the global legacy RNG that np.random.choice below draws from. The
    # global reseed is a side effect of this function and is kept as-is.
    np.random.seed(seed)
    unique_classes = df[feature_col].unique()

    if allowed_class_num > len(unique_classes):
        raise ValueError(f"Requested number of classes {allowed_class_num} exceeds the total number of classes {len(unique_classes)}")

    selected_classes = np.random.choice(unique_classes, size=allowed_class_num, replace=False)
    df_selected = df[df[feature_col].isin(selected_classes)]

    if len(df_selected) < subset_size:
        raise ValueError(
            f"From the selected {allowed_class_num} classes, at most {len(df_selected)} samples can be drawn, "
            f"which is fewer than the target {subset_size}"
        )

    df_sampled = df_selected.sample(n=subset_size, random_state=seed)
    df_sampled["label"] = 1

    # os.makedirs('') raises, so only create a directory when the path has one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    df_sampled.to_csv(output_path, index=False)
    print(f"Subset saved to {output_path}, number of samples: {len(df_sampled)}, number of allowed classes: {allowed_class_num}")


def select_random_subset(csv_path: str, output_path: str, subset_size: int, seed: int = 123):
    """
    Draw a uniform random subset of rows from a CSV file and save it.

    Every sampled row gets ``label = 1``. The result is written without the index.

    :param csv_path: path of the source CSV file
    :param output_path: path of the CSV file to write. Its parent directory is
        created when the path has one.
    :param subset_size: number of rows to sample (without replacement)
    :param seed: seed for the global NumPy RNG and for ``DataFrame.sample``
    :raises ValueError: if the file has fewer rows than ``subset_size``
    """
    table = pd.read_csv(csv_path)
    np.random.seed(seed)  # the global NumPy RNG is reseeded as a side effect

    total_rows = len(table)
    if subset_size > total_rows:
        raise ValueError(f"Requested subset size {subset_size} exceeds the total number of samples {total_rows}")

    subset = table.sample(n=subset_size, random_state=seed)
    subset["label"] = 1

    parent = os.path.dirname(output_path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    subset.to_csv(output_path, index=False)
    print(f"Random subset saved to {output_path}, number of samples: {len(subset)}")

# --- Driver script (runs at import time) -----------------------------------
# For class counts 1..15: build an 800-row subset restricted to that many
# classes, extract features for it with the pretrained model, score it against
# the precomputed full-set features, and write the scores to a text report.

# Use the first CUDA device when available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 10  
feature_col = 'object_type'  # column whose distinct values define the classes
# NOTE(review): the paths below are empty placeholders and must be filled in
# before running; torch.load / open / os.makedirs cannot work with ''.
csv_dir = ''  # directory that receives the per-class-count subset CSVs
txt_output_path = ''  # text report written by the loop below
full_feature_path = ""  # saved tensor of full-set features
whole_csv = ''  # NOTE(review): not referenced in the visible code
attribute_csv = ''  # source CSV that the subsets are sampled from
pretrained_path = ''  # model checkpoint

# Full-set features, loaded onto the CPU and converted to NumPy. They must be
# 2-D (N, D), because calculate_information_entropy unpacks their shape into N, D.
features_All = torch.load(full_feature_path, map_location='cpu')
features_Allnp = features_All.numpy()

model = GoldenRoad().to(device)
checkpoint = torch.load(pretrained_path, map_location=device)

from collections import OrderedDict
# Accept either a dict that wraps the weights under 'state_dict' or a bare
# state dict.
state_dict = checkpoint.get('state_dict', checkpoint)
new_state_dict = OrderedDict()
# Strip a leading 'model.' from parameter names so they match GoldenRoad's own
# keys (presumably the checkpoint was saved from a wrapper that held the
# network as `.model` -- TODO confirm).
for k, v in state_dict.items():
    k = k[len('model.'):] if k.startswith('model.') else k
    new_state_dict[k] = v
model.load_state_dict(new_state_dict)
model.eval()

os.makedirs(csv_dir, exist_ok=True)
with open(txt_output_path, 'w') as f_out:
    for class_num in range(1, 16):
        csv_path = os.path.join(csv_dir, f'{feature_col}_{class_num}.csv')
        
        # 800 rows drawn from `class_num` randomly chosen classes. The seed is
        # not passed, so the sampler's default seed applies on every iteration.
        select_subset_with_limited_classes(
            csv_path=attribute_csv,
            output_path=csv_path,
            feature_col=feature_col,
            subset_size=800,
            allowed_class_num=class_num
        )

        val_dataset = ValStandardData(csv_file=csv_path, train=False)
        val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

        feature_sublist = []
        with torch.no_grad():
            loop = tqdm(val_loader, desc=f"Extracting ({class_num})", leave=False)
            # The dataset yields (images1, images2, <ignored>) triples; both
            # image batches are fed to the model together.
            for images1, images2, _ in loop:
                x1 = images1.to(device)
                x2 = images2.to(device)

                # Only the first of the four outputs is kept as the feature.
                # NOTE(review): the meaning of the second argument (5) is
                # defined by GoldenRoad -- TODO confirm.
                z_quant, _, _, _ = model([x1, x2], 5)

                # Treat a 1-D output as a single-row batch so torch.cat below
                # always concatenates 2-D pieces along dim 0.
                if z_quant.dim() == 1:
                    z_quant = z_quant.unsqueeze(0)

                feature_sublist.append(z_quant.cpu())

        features_sub = torch.cat(feature_sublist, dim=0)
        features_subnp = features_sub.numpy()

        mean_info, info_values = calculate_information_entropy(features_Allnp, features_subnp)

        # Report: one summary line per class count, then one line per dimension.
        f_out.write(f'ClassNum: {class_num}, MeanInfo: {mean_info:.6f} bits\n')
        for i, val in enumerate(info_values):
            f_out.write(f'  Dim {i:02d}: {val:.6f} bits\n')
        f_out.write('\n')
