import numpy as np
import torch 
from torch.utils.data import DataLoader

import optuna

from functools import partial
import random
import gc

from src.dp_sgd.model import LinearHead
from src.dp_sgd.train import train_with_dp, evaluate
import src.dp_sgd.warning_ignores
from src.dp_sgd.config import num_classes, shots, min_epsilon



def objective(trial, train_dataset, validation_dataset, feature_dim, device):
    """Optuna objective: train one DP-SGD linear head and score it on validation data.

    Samples a learning rate, gradient-clipping norm, epoch count and batch
    size from ``trial``, trains a fresh ``LinearHead`` with ``train_with_dp``
    and returns the result of ``evaluate`` on ``validation_dataset``.

    Args:
        trial: Optuna trial used to sample the hyperparameters.
        train_dataset: Dataset passed to ``train_with_dp``; its length also
            determines ``delta`` (``1 / len(train_dataset)``).
        validation_dataset: Dataset the trained model is evaluated on.
        feature_dim: Input dimension of the linear head.
        device: Device the model is moved to and that is forwarded to the
            training and evaluation helpers.

    Returns:
        The value returned by ``evaluate`` (named validation accuracy here).
    """
    learning_rate = trial.suggest_float("learning_rate", 1e-7, 1e-2, log=True)
    max_grad_norm = trial.suggest_float("max_grad_norm", 0.2, 10.0, log=False)
    num_epochs = trial.suggest_int("num_epochs", 1, 200)
    # NOTE(review): the upper bound presumably mirrors a 70% training split of
    # a dataset with num_classes * shots samples -- confirm against the caller.
    batch_size = trial.suggest_int("batch_size", 10, int(num_classes * shots * 0.7))

    # Per-trial seed drawn from torch's global RNG, so the sequence of trial
    # seeds is reproducible whenever the global RNG has been seeded.
    seed = torch.randint(0, 2**31 - 1, (1,)).item()

    model = LinearHead(input_dim=feature_dim, num_classes=num_classes).to(device)
    train_with_dp(
        model=model,
        train_dataset=train_dataset,
        learning_rate=learning_rate,
        max_grad_norm=max_grad_norm,
        num_epochs=num_epochs,
        batch_size=batch_size,
        max_physical_batch_size=128,
        epsilon=min_epsilon,
        delta=1 / len(train_dataset),
        device=device,
        seed=seed,
    )
    validation_accuracy = evaluate(model, validation_dataset, device)

    # Drop the model between trials so memory does not accumulate.
    del model
    gc.collect()
    # torch.mps.empty_cache() raises on builds without the MPS backend, so
    # only call it when MPS is actually usable.
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()

    return validation_accuracy


if __name__ == "__main__":
    # Script entry point: load the few-shot features, split them 70/30 into
    # train/validation subsets and run a fresh Optuna study over `objective`.
    from pathlib import Path

    # device = torch.accelerator.current_accelerator()
    device = "cpu"
    print(device)

    # Seed every RNG used in this script so the data split and the per-trial
    # seeds drawn inside `objective` are reproducible.
    seed = 423472576
    torch.manual_seed(seed)
    torch.mps.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    train_features = torch.load(f"datasets/cifar10/few_shot_{shots}_train_vit-b-16-imagenet-21K_features.pt")
    train_labels = torch.load(f"datasets/cifar10/few_shot_{shots}_train_labels.pt")
    train_dataset = torch.utils.data.TensorDataset(train_features, train_labels)
    train_dataset, validation_dataset = torch.utils.data.random_split(
        train_dataset, [0.7, 0.3]
    )
    feature_dim = train_features.shape[1]

    study_name = "accuracy_first_dp_sgd_ex_post_conversion"
    # SQLite does not create missing parent directories, so make sure the
    # results directory exists before the storage is opened.
    Path("results/accuracy-first/dp-sgd").mkdir(parents=True, exist_ok=True)
    storage_name = f"sqlite:///results/accuracy-first/dp-sgd/{study_name}_hyperparameters.db"

    # Always start from a clean study. On the very first run there is nothing
    # to delete and Optuna raises KeyError, which is safe to ignore.
    try:
        optuna.delete_study(study_name=study_name, storage=storage_name)
    except KeyError:
        pass
    study = optuna.create_study(
        study_name=study_name,
        storage=storage_name,
        direction="maximize",
        load_if_exists=False,
    )
    study.optimize(
        partial(
            objective,
            train_dataset=train_dataset,
            validation_dataset=validation_dataset,
            feature_dim=feature_dim,
            device=device,
        ),
        n_trials=20,
        show_progress_bar=True,
    )