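# NEW WAY: CIFAR-10 experiment driver. Each run trains a ResNet-18 on a small
# real training subset, optionally mixed with generated images selected by the
# per-setting `method`, and the script reports mean/std test accuracy per setting.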

import torch
import torch.backends.cudnn
import torchvision
from torchvision import transforms
import numpy as np
from prepare_datasets import get_real_subdataset, get_full_dataset, get_transformations, extend_dataset,get_two_real_subsets, get_validation
from resnet_training import train_resnet

# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays
# off when input sizes stay constant across batches (expected for CIFAR-10).
torch.backends.cudnn.benchmark = True

# --- Training hyper-parameters ------------------------------------------------
PRETRAINED = "No"        # passed verbatim to train_resnet(pretrained=...)
LR = 0.1                 # previously 1e-4
EPOCHS = 50
BATCH_SIZE = 128         # also passed to get_validation
DO_VALIDATION = True     # hold out a validation split via get_validation
LEAK_EXPERIMENT = False  # when True, each run also draws a "leak" real subset
READ_EACH_GENERATED = [1000, 3000, 3000]  # read_amount for get_full_dataset
DATASET_EXPANSION = 1    # expand factor for get_full_dataset

# --- Datasets -----------------------------------------------------------------
# The third transform returned by get_transformations is the one applied to the test split.
transform_test = get_transformations("cifar10")[2]
# Training images are loaded without a transform at this stage (transform=None).
train_set_full = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=None)
if not DO_VALIDATION:
    validation_loader = None
    targets = np.array(train_set_full.targets)
else:
    # get_validation hands back (train subset, validation loader). The subset
    # exposes .dataset and .indices, so its labels are gathered from the parent.
    train_set_full, validation_loader = get_validation(train_set_full, BATCH_SIZE)
    parent_targets = train_set_full.dataset.targets
    targets = np.array([parent_targets[idx] for idx in train_set_full.indices])
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# --- Experiment sizes ---------------------------------------------------------
real_subset_count = 300   # subset_count requested from the real-data helpers
leak_subset_count = 10    # leak_count for get_two_real_subsets (leak mode only)
n_generated = 700         # number_of_generated for get_full_dataset
experiment_count = 10     # independent runs; mean/std are taken over these
do_clip = False           # `clip` flag forwarded to the dataset helpers


def _setting(label, method, *, use_generative=True, zero_centered=True,
             models=("sd14", "sana1.5", "pixart"), leak=False):
    """Build one experiment-setting dict.

    Keys, in order: label, use_generative, method, zero_centered, model, leak.
    ``model`` is always a fresh list, so no two settings share a mutable list.
    """
    return {
        "label": label,
        "use_generative": use_generative,
        "method": method,
        "zero_centered": zero_centered,
        "model": list(models),
        "leak": leak,
    }


settings = [
    _setting("real-only", None, use_generative=False, zero_centered=False),
    _setting("random", "random"),
    _setting("l2", "l2-near"),
    _setting("center", "nearest-to-center"),
    _setting("frobenius", "MPfast"),
    # Disabled alternatives -- uncomment a line to include it.
    # _setting("style 0.8", "random", models=["stylegan0.8"]),
    # _setting("style 1.0", "random", models=["stylegan1.0"]),
    # _setting("style 0.2", "random", models=["stylegan0.2"]),
    # _setting("style 0.4", "random", models=["stylegan0.4"]),
    # _setting("style 0.6", "random", models=["stylegan0.6"]),
    # _setting("flux", "random", models=["flux"]),
    # _setting("pixart", "random", models=["pixart"]),
    # _setting("sd3", "random", models=["sd3"]),
    # NOTE(review): the leak variants below were drafted with the bare string
    # "SD14" as the model (not a list, and a different case from "sd14");
    # confirm what get_full_dataset expects before enabling them.
    # _setting("random/sd14", "random", models=["SD14"], leak=True),
    # _setting("l2-near/sd14", "l2-near", models=["SD14"], leak=True),
    # _setting("frobenius/sd14", "MPfast", models=["SD14"], leak=True),
    # _setting("greedy frobenius", "greedy-frobenius",
    #          models=["sd14", "flux", "pixart", "sana1.5", "sd3"]),
    # NOTE(review): maxmin / kmean were drafted without "model" or "leak" keys
    # (the run loop treats a missing "leak" as True); the defaults here fill
    # both in -- confirm that is intended.
    # _setting("maxmin", "cover-maxmin-batched"),
    # _setting("kmean", "kmeans-diverse"),
]

# One list of per-run test accuracies per setting label.
accs = {cfg["label"]: [] for cfg in settings}
# Filled after all runs with mean/std/all for each label.
results = {}

# Repeat the whole comparison `experiment_count` times. Each run draws a fresh
# real subset, then trains and evaluates one ResNet-18 per setting. Each test
# accuracy is appended to accs[label].
for run in range(experiment_count):
    print(f"\n Run {run + 1}/{experiment_count}")
    if LEAK_EXPERIMENT:
        subset_leak, subset_train2, features_leak, cifar10features = get_two_real_subsets(
            "cifar10", train_set_full, targets,
            leak_count=leak_subset_count, subset_count=real_subset_count, clip=do_clip,
        )
        for i in features_leak:
            print(features_leak[i].shape)
    else:
        subset_train2, cifar10features = get_real_subdataset(
            "cifar10", train_set_full, targets,
            subset_count=real_subset_count, clip=do_clip,
        )
        # No leak subset is drawn in this mode; bind the names so nothing
        # below can hit a NameError.
        subset_leak = features_leak = None

    for config in settings:
        print(f" Testing setting: {config['label']}")
        # A setting without a "leak" key defaults to using leak data, but that
        # data only exists in leak mode. Gate on LEAK_EXPERIMENT so such a
        # setting gets None here instead of crashing.
        use_leak = LEAK_EXPERIMENT and config.get("leak", True)
        effective_subset_leak = subset_leak if use_leak else None
        effective_features_leak = features_leak if use_leak else None
        trainloader, testloader, _ = get_full_dataset(
            dataset_name="cifar10",
            model_names=config["model"],
            subset_train=subset_train2,
            test_set=test_set,
            use_generative=config["use_generative"],
            cifar10_real_features=cifar10features,
            number_of_generated=n_generated,
            batch_size=BATCH_SIZE,
            method=config["method"],
            zero_centered=config["zero_centered"],  # per-setting flag; False only for "real-only"
            clip=do_clip,
            expand=DATASET_EXPANSION,
            leak_dataset=effective_subset_leak,
            leak_features=effective_features_leak,
            read_amount=READ_EACH_GENERATED,
        )

        acc = train_resnet(
            trainloader, testloader,
            epochs=EPOCHS,
            learning_rate=LR,
            if_resnet18=True,
            pretrained=PRETRAINED,
            num_classes=10,
            validation_loader=validation_loader,
        )

        print(f"Accuracy: {acc:.4f}")
        accs[config['label']].append(acc)

def _summarise(run_accuracies):
    """Return mean, std (NumPy default, ddof=0) and the raw array of one setting's accuracies."""
    arr = np.array(run_accuracies)
    return {"mean": arr.mean(), "std": arr.std(), "all": arr}


# Aggregate every setting's per-run accuracies, keeping the settings order.
results.update({label: _summarise(runs) for label, runs in accs.items()})

print("Results:")
for label, summary in results.items():
    print(f"{label:>10s} → Avg: {summary['mean']:.4f}, Std: {summary['std']:.4f}")