import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import os
import torch
import torch.backends.cudnn
import torchvision
from torchvision import transforms
import numpy as np
from prepare_datasets import get_real_subdataset, get_full_dataset, get_transformations, extend_dataset,get_two_real_subsets
from resnet_training import train_resnet
from theoryestimates import calculate_alpha_estimate

# Experiment configuration for this comparison run.
DATASET_EXPANSION = 1  # forwarded to get_full_dataset as `expand`
real_subset_count = 800  # passed to get_two_real_subsets as `subset_count`
leak_subset_count = 500  # passed to get_two_real_subsets as `leak_count`
n_generated = 800  # forwarded to get_full_dataset as `number_of_generated`
experiment_count = 10  # NOTE(review): not referenced anywhere in this script
do_clip = False  # forwarded as `clip` to both dataset builders
READ_EACH_GENERATED = 4000  # forwarded to get_full_dataset as `read_amount`
BATCH_SIZE = 128

# Echo the subset sizes used for this run (these are the names defined above;
# the *_fraction names previously printed here do not exist).
print(real_subset_count, leak_subset_count, n_generated)

transform_test = get_transformations("cifar10")[2]
train_set_full = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=None)
# Needed by get_full_dataset(test_set=...) below, which would otherwise raise
# NameError. Assumes it expects the CIFAR-10 test split with the test-time
# transform from get_transformations -- confirm against prepare_datasets.
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
targets = np.array(train_set_full.targets)

# Draw the leak subset and the real training subset, plus their features.
# The feature containers are indexed below by the class id as a string.
subset_leak, subset_train2, features_leak, cifar10features = get_two_real_subsets(
    "cifar10",
    train_set_full,
    targets,
    leak_count=leak_subset_count,
    subset_count=real_subset_count,
    clip=do_clip,
)

NAME = "MODELComparison"
# One plot directory per experiment, keyed by the subset sizes and the number
# of generated samples. os.makedirs also creates the parent "leak_plots".
output_dir = os.path.join(
    "leak_plots",
    f"{NAME}-rln-{real_subset_count}-{leak_subset_count}-{n_generated}",
)
os.makedirs(output_dir, exist_ok=True)

# Configurations to visualise, one output figure per entry. Keys:
#   label          - figure title and output file stem
#   use_generative - forwarded to get_full_dataset as `use_generative`
#   method         - forwarded to get_full_dataset as `method`
#   zero_centered  - forwarded to get_full_dataset as `zero_centered`
#   model          - list of generator names forwarded as `model_names`
#   leak           - when True, the leak subset and its features are passed on
# The active set draws random samples from each generator in turn, without leaks.
settings = [
    {
        "label": f"random {model_name}",
        "use_generative": True,
        "method": "random",
        "zero_centered": True,
        "model": [model_name],
        "leak": False,
    }
    for model_name in ("flux", "sd14", "sd3")
]
# Other configurations used in earlier runs (append to `settings` to enable):
# {"label": "real-only", "use_generative": False, "method": None, "zero_centered": False, "model": ["flux"], "leak": False},
# {"label": "l2-near", "use_generative": True, "method": "l2-near", "zero_centered": True, "model": ["flux"], "leak": True},
# {"label": "frobenius", "use_generative": True, "method": "MPfast", "zero_centered": True, "model": ["flux"], "leak": True},
# {"label": "random/sd14", "use_generative": True, "method": "random", "zero_centered": True, "model": "SD14", "leak": True},
# {"label": "l2-near/sd14", "use_generative": True, "method": "l2-near", "zero_centered": True, "model": "SD14", "leak": True},
# {"label": "frobenius/sd14", "use_generative": True, "method": "MPfast", "zero_centered": True, "model": "SD14", "leak": True},
# {"label": "greedy frobenius", "use_generative": True, "method": "greedy-frobenius", "zero_centered": True, "model": ["sd14", "flux", "pixart", "sana1.5", "sd3"], "leak": False},
# {"label": "maxmin", "use_generative": True, "method": "cover-maxmin-batched", "zero_centered": True},
# {"label": "kmean", "use_generative": True, "method": "kmeans-diverse", "zero_centered": True},
# For each configuration: build the generated dataset, then save a 2x5 grid of
# per-class PCA scatter plots (real vs. generated features) titled with the
# per-class alpha estimate, plus the mean over classes whose estimate succeeded.
for config in settings:
    print(f"Processing config: {config['label']}")

    # Only hand the leak subset/features to get_full_dataset if the config asks.
    use_leak = config.get("leak", True)
    effective_subset_leak = subset_leak if use_leak else None
    effective_features_leak = features_leak if use_leak else None

    # Only the third return value (the generated dataset) is used below.
    _, _, gendataset = get_full_dataset(
        dataset_name="cifar10",
        model_names=config["model"],
        subset_train=subset_train2,
        test_set=test_set,
        use_generative=config["use_generative"],
        cifar10_real_features=cifar10features,
        number_of_generated=n_generated,
        batch_size=BATCH_SIZE,
        method=config["method"],
        zero_centered=config["zero_centered"],
        clip=do_clip,
        expand=DATASET_EXPANSION,
        leak_dataset=effective_subset_leak,
        leak_features=effective_features_leak,
        read_amount=READ_EACH_GENERATED,
    )

    leak_usage_count = (
        gendataset.get_number_of_chosen_leaks()
        if effective_subset_leak is not None
        else 0
    )

    # One panel per CIFAR-10 class.
    fig, axes = plt.subplots(2, 5, figsize=(16, 8))
    axes = axes.flatten()

    alpha_vals = []
    for class_id, ax in enumerate(axes):
        # Pick the "generated" features to compare against the real ones.
        if config["label"] == "full-generated":
            class_prefix = f"{class_id:02d}_"
            class_dir = next(
                (k for k in gendataset.dinofeatures.keys() if k.startswith(class_prefix)),
                None,
            )
            if class_dir is None:
                raise KeyError(f"No generated feature entry starts with {class_prefix!r}")
            gen_features = gendataset.dinofeatures[class_dir]
        elif config["label"] == "leaks":
            gen_features = features_leak[f'{class_id}']
        else:
            gen_features = gendataset.get_class_generated_features(class_id)
        real_features = cifar10features[f'{class_id}']

        # Project both sets onto the top-2 principal axes of the real features.
        pca = PCA(n_components=2)
        pca.fit(real_features)
        real_2d = pca.transform(real_features)
        gen_2d = pca.transform(gen_features)

        try:
            alpha = np.round(calculate_alpha_estimate(real_features, gen_features), 4)
        except Exception as e:
            # Best-effort: one failing class should not abort the whole figure.
            # NaN (rather than a -1 sentinel) keeps failures out of the mean.
            print(f"Error calculating alpha for class_id {class_id}: {e}")
            alpha = np.nan
        alpha_vals.append(alpha)

        ax.scatter(real_2d[:, 0], real_2d[:, 1], alpha=0.5, label='Real', c='blue', s=15)
        ax.scatter(gen_2d[:, 0], gen_2d[:, 1], alpha=0.5, label='Generated', c='red', s=15)
        ax.set_title(f'Class {class_id} - test alpha {alpha}')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(True)

    alpha_vals = np.array(alpha_vals)
    failed = np.isnan(alpha_vals)
    mean_alpha = float(np.mean(alpha_vals[~failed])) if not failed.all() else float("nan")
    failed_note = f" ({int(failed.sum())} failed)" if failed.any() else ""

    # Every panel carries the same two series, so the last one's handles suffice.
    handles, labels_ = ax.get_legend_handles_labels()
    fig.legend(handles, labels_, loc='upper right')
    fig.suptitle(
        f'Real & Chosen Generated (2D): {config["label"]} - leaks= {leak_usage_count}, '
        f'alpha = {mean_alpha:.4f}{failed_note}',
        fontsize=16,
    )
    fig.tight_layout(rect=[0, 0, 0.98, 0.95])

    filename = config["label"].replace(" ", "_").replace("/", "-").lower() + ".png"
    out_path = os.path.join(output_dir, filename)
    fig.savefig(out_path, dpi=150)
    plt.close(fig)

    print(f"Saved plot to {out_path}")