
import torch
import torch.backends.cudnn
import torchvision
import json
from torchvision import transforms
import numpy as np
from prepare_datasets import get_real_subdataset, get_full_dataset, get_transformations, extend_dataset,get_two_real_subsets, get_validation
from resnet_training import train_resnet, train_resnet_w2s
import yaml
import random
import os

import argparse

# Command-line interface: a single required argument naming the config,
# given relative to the configs directory and without the ".yaml" suffix.
parser = argparse.ArgumentParser(description="Run experiment")
parser.add_argument("--config", type=str, required=True, help="Path (relative) to the config")
args = parser.parse_args()

CONFIG_name = args.config
config_file = f"./<configs path>/{CONFIG_name}.yaml"
with open(config_file, "r") as config_stream:
    config = yaml.safe_load(config_stream)

# Split "<dir>/.../<name>" into its directory part and its final component.
# str.split always yields at least one element, so `name` is always bound.
config_path_parts = CONFIG_name.split('/')
*parent_parts, name = config_path_parts
config_path_without_name = "/".join(parent_parts)
path = "./<configs path>/" + config_path_without_name

# Resolve "@key" model references: when a setting's "model" is a string that
# starts with "@", replace it with the top-level config entry of that name.
for setting in config["settings"]:
    model_spec = setting["model"]
    if not isinstance(model_spec, str):
        continue
    if model_spec.startswith("@"):
        setting["model"] = config[model_spec[1:]]

# Required config keys: a missing one raises KeyError right here.
PRETRAINED = config["PRETRAINED"]
LR = config["LR"]
EPOCHS = config["EPOCHS"]
BATCH_SIZE = config["BATCH_SIZE"]
DO_VALIDATION = config["DO_VALIDATION"]
LEAK_EXPERIMENT = config["LEAK_EXPERIMENT"]
READ_EACH_GENERATED = config["READ_EACH_GENERATED"]
DATASET_EXPANSION = config["DATASET_EXPANSION"]
EXPANDREALS = config["EXPANDREALS"]
real_subset_count = config["real_subset_count"]
leak_subset_count = config["leak_subset_count"]
n_generated = config["n_generated"]
experiment_count = config["experiment_count"]
DO_W2S = config["DO_W2S"]
do_clip = config["do_clip"]
MODELS_USED = config["MODELS_USED"]
# Same list object that the loop above mutated, so "@" references are resolved.
settings = config["settings"]

# Optional keys (with their defaults) and fixed values.
index_file = config.get("index_file", "./<default index path>")
output_dir = f"<results path>"
PRUNE = config.get("PRUNE", False)
SEED = 10
only_save_indices = config.get("only_save_indices", False)
add_synthetic = config.get("add_synthetic", False)
model_name = config.get("model_name", "resnet18")

# True only for a "vit" model (case-insensitive) whose PRETRAINED value is
# anything other than the string "No".
is_vit = model_name.lower() == "vit" and PRETRAINED != "No"


# Seed the Python, NumPy, PyTorch and PyTorch-CUDA generators, in that order.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(SEED)

# Index 2 of the returned transformations is used for the test set
# (presumably the evaluation-time transform; confirm in prepare_datasets).
transform_test = get_transformations("cifar10", is_vit=is_vit)[2]
train_set_full = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=None)

if not DO_VALIDATION:
    validation_loader = None
    targets = np.array(train_set_full.targets)
else:
    # get_validation hands back a replacement training set plus a loader; the
    # training set exposes `.dataset` and `.indices`, from which the labels of
    # the retained samples are gathered.
    train_set_full, validation_loader = get_validation(train_set_full, BATCH_SIZE, is_vit=is_vit)
    base_targets = train_set_full.dataset.targets
    targets = np.array([base_targets[idx] for idx in train_set_full.indices])

test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Per-setting accumulators, keyed by each setting's label.
accs = {entry['label']: [] for entry in settings}
leak_counts = {entry['label']: [] for entry in settings}
results = {}

# Main experiment loop: for every run, draw the real subset(s), then train and
# evaluate one model per configured setting, recording accuracy (and, in leak
# experiments, the number of chosen leaks) under the setting's label.
for run in range(experiment_count):
    # Each run uses its own index file path, suffixed with the run number.
    run_index_file = f"{index_file}_{run}.pkl"
    if LEAK_EXPERIMENT:
        subset_leak, subset_train2, features_leak, cifar10features = get_two_real_subsets(
            "cifar10",
            train_set_full,
            targets,
            leak_count=leak_subset_count,
            subset_count=real_subset_count,
            clip=do_clip,
            index_file=run_index_file,
            leak_index_file=f"{index_file}_leak_{run}.pkl",
        )
    else:
        subset_train2, cifar10features = get_real_subdataset(
            "cifar10",
            train_set_full,
            targets,
            subset_count=real_subset_count,
            clip=do_clip,
            expand_reals=EXPANDREALS,
            index_file=run_index_file,
        )
        # There is no leak data in this mode. Bind the names explicitly so a
        # setting that leaves "leak" at its default of True passes None
        # instead of raising NameError.
        subset_leak = None
        features_leak = None

    if only_save_indices:
        print(f"Indices saved for run {run + 1} in {index_file}_{run}.pkl")
        continue

    # The loop variable is deliberately not called `config`, so the
    # module-level config dict is not overwritten by the last setting.
    for setting in settings:
        label = setting['label']
        print(f" Testing setting: {label}")

        # A setting can opt out of the leak data with `leak: false`.
        use_leak = setting.get("leak", True)
        effective_subset_leak = subset_leak if use_leak else None
        effective_features_leak = features_leak if use_leak else None

        trainloader, testloader, generator_model = get_full_dataset(
            dataset_name="cifar10",
            model_names=setting["model"],
            subset_train=subset_train2,
            test_set=test_set,
            use_generative=setting["use_generative"],
            cifar10_real_features=cifar10features,
            number_of_generated=n_generated,
            batch_size=BATCH_SIZE,
            method=setting["method"],
            zero_centered=setting["zero_centered"],
            clip=do_clip,
            expand=DATASET_EXPANSION,
            leak_dataset=effective_subset_leak,
            leak_features=effective_features_leak,
            read_amount=READ_EACH_GENERATED,
            Prune=PRUNE,
            SEED=SEED,
            add_synthetic=add_synthetic,
            is_vit=is_vit,
        )

        # Both trainers take the same arguments; only the routine differs.
        if DO_W2S:
            print("Training with W2S method")
            train_fn = train_resnet_w2s
        else:
            print("Training with standard method")
            train_fn = train_resnet
        acc = train_fn(
            trainloader, testloader,
            epochs=EPOCHS,
            learning_rate=LR,
            if_resnet18=True,
            pretrained=PRETRAINED,
            num_classes=10,
            validation_loader=validation_loader,
            model_name=model_name,
        )

        print(f"Accuracy: {acc:.4f}")
        accs[label].append(acc)
        if LEAK_EXPERIMENT:
            leak_counts[label].append(generator_model.get_number_of_chosen_leaks())

        # Drop this setting's loaders (using the names actually bound above)
        # so they can be reclaimed before the next setting builds new ones.
        trainloader = None
        testloader = None
        torch.cuda.empty_cache()

# Summarise each setting's accuracies (and leak counts) across runs.
for label, scores in accs.items():
    if not scores:
        # Nothing was recorded for this label (e.g. only_save_indices mode or
        # experiment_count == 0). The mean/std of an empty array is NaN, which
        # json.dump would emit as a non-standard `NaN` token, so skip it.
        continue
    score_arr = np.array(scores)
    results[label] = {
        "mean": score_arr.mean().item(),
        "std": score_arr.std().item(),
        "all": score_arr.tolist()
    }
    if LEAK_EXPERIMENT:
        lc = np.array(leak_counts[label])
        results[label]["leak_mean"] = lc.mean().item()
        results[label]["leak_all"] = lc.tolist()

output_file = os.path.join(output_dir, f"{name}_results.json")

if results:
    os.makedirs(output_dir, exist_ok=True)
    with open(output_file, 'w') as f:
        json.dump(results, f, indent=2)
else:
    # Leave any existing results file alone rather than overwriting it with
    # an empty or NaN-only summary.
    print("No accuracies recorded; results file not written.")
