
import torch
import torch.backends.cudnn
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms, datasets
import numpy as np
import yaml
import random
import json
import os
from prepare_datasets import get_real_subdataset, get_full_dataset, get_transformations, extend_dataset,get_two_real_subsets
from resnet_training import train_resnet_lightning
from torch.utils.data import random_split, Subset
# Let cuDNN benchmark the available convolution algorithms and pick the
# fastest one for the input shapes it encounters.
# NOTE(review): benchmark mode can select different algorithms from run to
# run, so the fixed SEED set further down does not by itself guarantee
# bit-identical results across runs.
torch.backends.cudnn.benchmark = True

import argparse

# --- Command line -----------------------------------------------------------
parser = argparse.ArgumentParser(description="Run experiment")
parser.add_argument("--config", type=str, required=True, help="Path (relative) to the config")
parser.add_argument("--experiment_idx", type=int, required=False, help="Index of the experiment to run")
args = parser.parse_args()

CONFIG_name = args.config
# None when --experiment_idx is not given (the argument is optional).
experiment_idx = args.experiment_idx

# --- Config file ------------------------------------------------------------
# --config is interpreted relative to ./configs and without the ".yaml"
# suffix. The file is decoded explicitly as UTF-8 so parsing does not depend
# on the platform's default locale encoding.
with open(f"./configs/{CONFIG_name}.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# yaml.safe_load returns None for an empty file and may return a list or a
# scalar; fail here with a clear message instead of an opaque TypeError on the
# first config[...] lookup below.
if not isinstance(config, dict):
    raise ValueError(
        f"./configs/{CONFIG_name}.yaml must contain a top-level mapping, "
        f"got {type(config).__name__}"
    )

# Split the config argument into its directory part and its final component;
# `name` is the final component of the --config argument.
config_path_parts = CONFIG_name.split('/')
config_path_without_name = "/".join(config_path_parts[:-1])
path = "./configs/" + config_path_without_name
name = config_path_parts[-1]

# Resolve "@key" model references: a setting whose "model" value is a string
# starting with "@" is replaced by the top-level config entry named by the
# rest of that string (a missing target still raises KeyError).
# Settings without a "model" key are left untouched rather than raising
# KeyError, since this is the only place in the file that reads that key.
for setting in config["settings"]:
    model_ref = setting.get("model")
    if isinstance(model_ref, str) and model_ref.startswith("@"):
        setting["model"] = config[model_ref[1:]]

# Required config entries: each key below must be present, otherwise the
# lookup raises KeyError. Keys are read in the order listed.
PRETRAINED, LR, EPOCHS, BATCH_SIZE, DO_VALIDATION = (
    config[key]
    for key in ("PRETRAINED", "LR", "EPOCHS", "BATCH_SIZE", "DO_VALIDATION")
)
LEAK_EXPERIMENT, READ_EACH_GENERATED, DATASET_EXPANSION, EXPANDREALS = (
    config[key]
    for key in (
        "LEAK_EXPERIMENT",
        "READ_EACH_GENERATED",
        "DATASET_EXPANSION",
        "EXPANDREALS",
    )
)
real_subset_count, leak_subset_count, n_generated, experiment_count = (
    config[key]
    for key in (
        "real_subset_count",
        "leak_subset_count",
        "n_generated",
        "experiment_count",
    )
)
DO_W2S, do_clip, MODELS_USED, settings = (
    config[key] for key in ("DO_W2S", "do_clip", "MODELS_USED", "settings")
)

# Optional config entries, each with a fallback when the key is absent.
index_file = config.get("index_file", "<default index path>")
only_save_indices = config.get("only_save_indices", False)
PRUNE = config.get("PRUNE", False)

# Directory the results JSON is written to (fixed, not taken from the config).
output_dir = "<result path>"

print(f"Config loaded: {CONFIG_name}")

# Seed Python's `random`, NumPy, and PyTorch (default and CUDA generators)
# with the same fixed value, in that order.
SEED = 0
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
):
    _seed_rng(SEED)

# Evaluation transform: the third element returned by get_transformations.
transform_val = get_transformations("imagenet")[2]

# Training folder loaded without any transform attached.
train_set_full = datasets.ImageFolder(root='<imagenet100 path>', transform=None)

# Hold out a fraction of the training folder for validation. The split uses
# its own generator seeded with SEED.
val_ratio = 0.1
val_size = int(len(train_set_full) * val_ratio)
train_size = len(train_set_full) - val_size
split_generator = torch.Generator().manual_seed(SEED)
train_set_full, val_subset = random_split(
    train_set_full,
    [train_size, val_size],
    generator=split_generator,
)

# Rebuild the validation subset on a second ImageFolder over the same root
# that applies transform_val, reusing the indices chosen by the split.
val_dataset_with_transform = datasets.ImageFolder(root='<imagenet100 path>', transform=transform_val)
val_subset = Subset(val_dataset_with_transform, val_subset.indices)

validation_loader = torch.utils.data.DataLoader(
    val_subset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=4,
)
test_set = datasets.ImageFolder(root='<imagenet100 validation path>', transform=transform_val)

# Labels of the samples kept in the training split, in split-index order.
all_targets = train_set_full.dataset.targets
targets = np.array([all_targets[sample_idx] for sample_idx in train_set_full.indices])

# Accuracy of every executed run, collected per setting label. `results` is
# filled from these lists after all runs have finished.
accs = {setting["label"]: [] for setting in settings}
results = {}

# An index outside range(experiment_count) would make the loop below skip
# every run and leave all accuracy lists empty; fail early instead.
if experiment_idx is not None and not 0 <= experiment_idx < experiment_count:
    raise ValueError(
        f"--experiment_idx {experiment_idx} is out of range for "
        f"experiment_count={experiment_count}"
    )

for run in range(experiment_count):
    # With --experiment_idx given, only that single run is executed.
    if experiment_idx is not None and run != experiment_idx:
        continue

    # The real subset is selected once per run and shared by every setting.
    # The index file name is suffixed with the run number.
    subset_train2, realfeatures = get_real_subdataset(
        "imagenet",
        train_set_full,
        targets,
        subset_count=real_subset_count,
        clip=do_clip,
        index_file=f"{index_file}_{run}.pkl",
    )

    if only_save_indices:
        # Index-only mode: nothing is trained for this run.
        print(f"Indices saved for run {run + 1} in {index_file}_{run}.pkl")
        continue

    # The loop variable is deliberately not called `config`, so the loaded
    # module-level config dict is not overwritten by the last setting.
    for setting in settings:
        # No leak dataset is ever built in this script, so None is always
        # passed for both leak arguments.
        effective_subset_leak = None
        effective_features_leak = None
        print(f" Testing setting: {setting['label']}")

        trainloader, testloader, gen_dataset = get_full_dataset(
            dataset_name="imagenet",
            subset_train=subset_train2,
            test_set=test_set,
            use_generative=setting["use_generative"],
            cifar10_real_features=realfeatures,
            number_of_generated=n_generated,
            batch_size=BATCH_SIZE,
            method=setting["method"],
            zero_centered=setting["zero_centered"],
            clip=do_clip,
            expand=DATASET_EXPANSION,
            leak_dataset=effective_subset_leak,
            leak_features=effective_features_leak,
            read_amount=READ_EACH_GENERATED,
            SEED=SEED,
            Prune=PRUNE,
        )
        if gen_dataset is not None:
            # Show the first few entries of the generated dataset.
            print(gen_dataset.samples[:10])

        acc = train_resnet_lightning(
            trainloader, testloader,
            epochs=EPOCHS,
            learning_rate=LR,
            if_resnet18=True,
            pretrained=PRETRAINED,
            num_classes=100,
            validation_loader=validation_loader if DO_VALIDATION else None,
            plot_label=f"imagenet_{name}",
        )

        print(f"Accuracy: {acc:.4f}")
        accs[setting['label']].append(acc)

# Summarise the collected accuracies per setting label.
for label, scores in accs.items():
    if not scores:
        # Nothing was recorded for this label (e.g. only_save_indices mode).
        # The mean/std of an empty array would be NaN, which json.dump emits
        # as the non-standard token `NaN`.
        continue
    scores = np.array(scores)
    results[label] = {
        "mean": scores.mean().item(),
        "std": scores.std().item(),
        "all": scores.tolist()
    }

if results:
    os.makedirs(output_dir, exist_ok=True)
    # NOTE(review): the file name depends only on the config name, so separate
    # invocations with different --experiment_idx values write to the same
    # file and overwrite each other -- confirm whether that is intended.
    output_file = os.path.join(output_dir, f"{name}_results.json")

    with open(output_file, 'w') as f:
        json.dump(results, f, indent=2)
else:
    # Leave any existing results file untouched when there is nothing to save.
    print("No accuracies were recorded; results file not written.")
