
import torch
import torch.backends.cudnn
import torchvision
from torchvision import transforms
import numpy as np
from prepare_datasets import get_real_subdataset, get_full_dataset, get_transformations, extend_dataset,get_two_real_subsets
from resnet_training import train_resnet

# Let cuDNN benchmark candidate convolution algorithms and cache the fastest
# one. This pays off when input shapes stay constant between batches, and it
# trades run-to-run determinism for speed.
torch.backends.cudnn.benchmark = True

# Training-time augmentations. The pipeline yields PIL images: no ToTensor /
# Normalize / RandomErasing step is included here (RandomErasing would need
# tensor input). NOTE(review): `transform_pretrain` is not referenced anywhere
# below in this script -- confirm whether it is still needed.
_pretrain_augmentations = [
    transforms.RandomCrop(32, padding=4),  # 32x32 crop after 4-px padding
    transforms.RandomHorizontalFlip(),  # default p=0.5
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
]
transform_pretrain = transforms.Compose(_pretrain_augmentations)

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
PRETRAINED = "Full"  # forwarded to train_resnet(pretrained=...)
LR = 0.001  # forwarded to train_resnet(learning_rate=...)
EPOCHS = 25  # forwarded to train_resnet(epochs=...)
EXTEND_TRAIN_DATASET = False  # when True, the train set is passed through extend_dataset()

real_subset_fraction = 0.02  # forwarded to get_real_subdataset(subset_fraction=...)
n_generated = 100  # forwarded to get_full_dataset(number_of_generated=...)
experiment_count = 10  # number of repeated runs over all settings
do_clip = False  # forwarded as clip=... to get_real_subdataset / get_full_dataset

# ---------------------------------------------------------------------------
# CIFAR-10 data (download=True fetches it into the root directory if missing)
# ---------------------------------------------------------------------------
_DATA_ROOT = './data'

# NOTE(review): element 2 of get_transformations(...) is assumed to be the
# evaluation transform (it is used for the test split) -- confirm.
transform_test = get_transformations("cifar10")[2]

# The training split is loaded without a transform, so items are raw PIL
# images; presumably the dataset helpers apply transforms later -- confirm.
train_set_full = torchvision.datasets.CIFAR10(root=_DATA_ROOT, train=True, download=True, transform=None)
if EXTEND_TRAIN_DATASET:
    # NOTE(review): extend_dataset must return an object exposing `.targets`,
    # since `targets` is read from it below.
    train_set_full = extend_dataset(train_set_full)
test_set = torchvision.datasets.CIFAR10(root=_DATA_ROOT, train=False, download=True, transform=transform_test)

# Training labels as a NumPy array, handed to get_real_subdataset each run.
targets = np.array(train_set_full.targets)


# Echo the active configuration so each log file records what was run.
# Every label is left-aligned to 21 columns before the ": " separator.
_settings_summary = [
    ("Pretrained Model", PRETRAINED),
    ("Learning Rate", LR),
    ("Epochs", EPOCHS),
    ("Extend Train Dataset", EXTEND_TRAIN_DATASET),
    ("Subset Fraction", real_subset_fraction),
    ("Generated Samples", n_generated),
    ("Experiment Count", experiment_count),
    ("Using CLIP Features", do_clip),
]

print("===== Experiment Settings =====")
for _setting_name, _setting_value in _settings_summary:
    print(f"{_setting_name:<21}: {_setting_value}")
print("================================")

# Conditions compared in every run, as (label, method, zero_centered).
# `method` is the string forwarded to get_full_dataset(method=...); None marks
# the real-data-only baseline, which also sets use_generative=False.
_SETTING_SPECS = (
    ("real-only", None, False),
    ("random", "random", True),
    ("l2-near", "l2-near", True),
    ("frobenius", "MPfast", True),
    ("maxmin", "cover-maxmin-batched", True),
    ("kmean", "kmeans-diverse", True),
)

settings = [
    {
        "label": label,
        "use_generative": method is not None,
        "method": method,
        "zero_centered": zero_centered,
    }
    for label, method, zero_centered in _SETTING_SPECS
]

# One accuracy list per setting label; each run appends one value per setting.
accs = {spec[0]: [] for spec in _SETTING_SPECS}
# Filled after all runs with per-label summary statistics.
results = {}

# Loader parameters passed to get_full_dataset (formerly inline literals).
BATCH_SIZE = 128
GENERATED_ROOT = "./cifar10_sd14-4000"

# Repeat the full comparison `experiment_count` times. Within a run,
# get_real_subdataset is called once and its result is reused for every
# setting, so all settings in a run share the same real training data.
for run in range(experiment_count):
    print(f"\n Run {run + 1}/{experiment_count}")
    # NOTE(review): presumably each call draws a new random subset -- confirm
    # in prepare_datasets.get_real_subdataset.
    subset_train2, cifar10features = get_real_subdataset(
        "cifar10",
        train_set_full,
        targets,
        subset_fraction=real_subset_fraction,
        clip=do_clip,
    )
    for config in settings:
        print(f" Testing setting: {config['label']}")

        # The third value returned by get_full_dataset is not used here.
        trainloader, testloader, _ = get_full_dataset(
            dataset_name="cifar10",
            subset_train=subset_train2,
            test_set=test_set,
            use_generative=config["use_generative"],
            cifar10_real_features=cifar10features,
            number_of_generated=n_generated,
            batch_size=BATCH_SIZE,
            generated_root=GENERATED_ROOT,
            method=config["method"],
            # Taken from the setting: False for the "real-only" baseline and
            # True for every generative setting listed in `settings`.
            zero_centered=config["zero_centered"],
            clip=do_clip,
        )

        # Train a ResNet-18 (if_resnet18=True) for this (run, setting) pair
        # and record the returned test accuracy under the setting's label.
        acc = train_resnet(
            trainloader,
            testloader,
            epochs=EPOCHS,
            learning_rate=LR,
            if_resnet18=True,
            pretrained=PRETRAINED,
            num_classes=10,
        )

        print(f"Accuracy: {acc:.4f}")
        accs[config['label']].append(acc)

def _summarize(run_accuracies):
    """Return mean, std and the raw array for one setting's per-run accuracies.

    `std` is NumPy's default population standard deviation (ddof=0).
    """
    values = np.array(run_accuracies)
    return {"mean": values.mean(), "std": values.std(), "all": values}


# Collapse each setting's per-run accuracies into summary statistics.
results.update({label: _summarize(run_accs) for label, run_accs in accs.items()})

print("Results:")
for setting_label, summary in results.items():
    print(f"{setting_label:>10s} → Avg: {summary['mean']:.4f}, Std: {summary['std']:.4f}")