import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import optuna
from recorders.gradient_recorder import compute_gradient_norms
# Device used for the head and all batches: the first CUDA device when CUDA is
# available, otherwise the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

def predict_by_max_logit(logits):
    """
    Return, for each entry along the last dimension of ``logits``, the index
    of the largest value.
    """
    return logits.argmax(dim=-1)


def compute_accuracy_from_predictions(predictions, labels):
    """
    Compute classification accuracy.

    Compares ``predictions`` and ``labels`` element-wise and returns the mean
    of the matches as a scalar float tensor (1.0 means every entry agrees).
    """
    is_correct = labels.eq(predictions)
    return is_correct.float().mean()

def create_head(feature_dim: int, num_classes: int):
    """
    Build a linear classification head whose weight and bias are all zeros.

    The layer maps ``feature_dim`` inputs to ``num_classes`` outputs and is
    moved to ``DEVICE`` before being returned.
    """
    linear_head = nn.Linear(in_features=feature_dim, out_features=num_classes)
    # Overwrite the default random initialisation so every head starts from
    # the same all-zero state.
    for parameter in (linear_head.weight, linear_head.bias):
        nn.init.zeros_(parameter)
    return linear_head.to(DEVICE)


def validate_linear(model, val_loader):
    """
    Compute the classification accuracy of ``model`` over ``val_loader``.

    The model is switched to evaluation mode (and left in that mode), every
    batch is run without gradient tracking, and the arg-max predictions of all
    batches are compared against the labels in a single accuracy computation.

    Args:
        model: Module mapping a batch of inputs to per-class logits.
        val_loader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        A scalar tensor holding the fraction of correctly classified samples.
    """
    model.eval()

    predictions = []
    labels = []
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            batch_images = batch_images.to(DEVICE)
            # Cast and move in one call: the legacy ``.type(torch.LongTensor)``
            # forces the labels onto the CPU before the device transfer.
            batch_labels = batch_labels.to(device=DEVICE, dtype=torch.long)
            predictions.append(predict_by_max_logit(model(batch_images)))
            labels.append(batch_labels)
        # Stack the per-batch results so accuracy is weighted per sample,
        # not per batch.
        accuracy = compute_accuracy_from_predictions(torch.hstack(predictions), torch.hstack(labels))
    return accuracy


def learn_linear_layer(
    train_loader: DataLoader,
    args,
    feature_dim: int,
    num_classes: int,
    val_loader: DataLoader = None,
    record_l2_norms=False,
    complete_x_train=None,
    complete_y_train=None,
):
    """
    Train a zero-initialised linear head with Adam and cross-entropy loss.

    Args:
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        args: Object with attribute access providing ``learning_rate`` and
            ``epochs`` (not a plain ``dict``).
        feature_dim: Input dimensionality of the head.
        num_classes: Number of output classes.
        val_loader: Optional loader used to compute validation accuracy.
        record_l2_norms: When true, gradient norms over the complete training
            set are recorded before every optimisation step.
        complete_x_train: Inputs handed to ``compute_gradient_norms``;
            required when ``record_l2_norms`` is true.
        complete_y_train: Labels handed to ``compute_gradient_norms``;
            required when ``record_l2_norms`` is true.

    Returns:
        ``(train_accuracy, val_accuracy, head, all_l2_norms)``.
        ``val_accuracy`` is ``None`` when no ``val_loader`` is given.
        ``all_l2_norms`` has one row per optimisation step and
        ``len(complete_y_train)`` columns when recording is enabled, and is
        the placeholder ``torch.ones(1)`` otherwise.

    Raises:
        ValueError: If ``record_l2_norms`` is true but the complete training
            data is missing.
    """
    if record_l2_norms and (complete_x_train is None or complete_y_train is None):
        # Fail early with a clear message instead of an opaque ``len(None)`` error.
        raise ValueError("record_l2_norms=True requires both complete_x_train and complete_y_train")

    head = create_head(feature_dim=feature_dim, num_classes=num_classes)
    optimizer = torch.optim.Adam(head.parameters(), lr=args.learning_rate)
    head.train()
    if record_l2_norms:
        # One row per optimisation step, one column per training sample.
        all_l2_norms = torch.zeros(len(train_loader) * args.epochs, len(complete_y_train))
    else:
        all_l2_norms = torch.ones(1)
    step_number = 0
    for _ in range(args.epochs):
        for batch_images, batch_labels in train_loader:
            if record_l2_norms:
                # Recorded before the update, so row ``step_number`` reflects
                # the head as it was going into this step.
                # NOTE(review): presumably returns one norm per training
                # sample; confirm against compute_gradient_norms.
                all_l2_norms[step_number, :] = compute_gradient_norms(
                    original_model=head,
                    X=complete_x_train,
                    y=complete_y_train,
                    loss_function=torch.nn.functional.cross_entropy,
                )
            batch_images = batch_images.to(DEVICE)
            # Cast and move in one call: the legacy ``.type(torch.LongTensor)``
            # forces the labels onto the CPU before the device transfer.
            batch_labels = batch_labels.to(device=DEVICE, dtype=torch.long)
            optimizer.zero_grad()
            logits = head(batch_images)
            loss = torch.nn.functional.cross_entropy(logits, batch_labels)
            loss.backward()
            optimizer.step()
            step_number += 1

    train_accuracy = validate_linear(head, train_loader)
    val_accuracy = validate_linear(head, val_loader) if val_loader is not None else None

    return train_accuracy, val_accuracy, head, all_l2_norms


def _loader_with_batch_size(loader, batch_size):
    """Return a copy of ``loader`` that batches its dataset with ``batch_size``."""
    if not isinstance(loader, DataLoader) or loader.batch_size is None:
        # Not an auto-batching DataLoader: there is nothing safe to rebuild.
        return loader
    dataset = loader.dataset
    try:
        # Clamp to the dataset size, as the initial loader construction does.
        batch_size = max(1, min(batch_size, len(dataset)))
    except TypeError:
        # Dataset without a length: use the suggested value as-is.
        pass
    return DataLoader(
        dataset,
        batch_size=batch_size,
        # Keep the shuffling behaviour of the loader that was passed in.
        shuffle=isinstance(loader.sampler, torch.utils.data.RandomSampler),
        num_workers=loader.num_workers,
        collate_fn=loader.collate_fn,
        pin_memory=loader.pin_memory,
        drop_last=loader.drop_last,
        generator=loader.generator,
    )


def objective_func(trial, train_loader: DataLoader, val_loader: DataLoader, args, feature_dim: int, num_classes: int):
    """
    Optuna objective: sample hyperparameters, train a head, return validation accuracy.

    The sampled values are written onto ``args`` (``max_grad_norm`` only when
    ``args.private`` is set, then ``train_batch_size``, ``learning_rate`` and
    ``epochs``), each drawn from the matching ``*_lb`` / ``*_ub`` bounds on
    ``args``.

    Args:
        trial: Optuna trial used to draw the hyperparameters.
        train_loader: Loader over the training split; it is re-batched with
            the suggested batch size for this trial.
        val_loader: Loader over the validation split.
        args: Object with attribute access holding bounds and receiving the
            sampled values.
        feature_dim: Input dimensionality of the head.
        num_classes: Number of output classes.

    Returns:
        The validation accuracy reported by ``learn_linear_layer``.
    """
    if args.private:
        args.max_grad_norm = trial.suggest_float("max_grad_norm", args.max_grad_norm_lb, args.max_grad_norm_ub)

    args.train_batch_size = trial.suggest_int("batch_size", args.train_batch_size_lb, args.train_batch_size_ub)

    args.learning_rate = trial.suggest_float("learning_rate", args.learning_rate_lb, args.learning_rate_ub)
    args.epochs = trial.suggest_int("epochs", args.epochs_lb, args.epochs_ub)

    # The incoming loader was built before this trial's batch size was chosen.
    # Training with it unchanged would make the batch-size search a no-op, so
    # rebuild it with the suggested value.
    trial_train_loader = _loader_with_batch_size(train_loader, args.train_batch_size)

    _, val_accuracy, _, _ = learn_linear_layer(
        train_loader=trial_train_loader,
        val_loader=val_loader,
        args=args,
        feature_dim=feature_dim,
        num_classes=num_classes,
    )

    return val_accuracy


def optimize_hyperparameters(args, features, labels, feature_dim, num_classes, seed):
    """
    Tune the training hyperparameters with Optuna and store the best ones on ``args``.

    ``features`` and ``labels`` are wrapped in a ``TensorDataset`` and split
    70/30 into training and validation subsets. A TPE-sampled study (seeded
    with ``seed``) maximises the value returned by ``objective_func`` over
    ``args.number_of_trials`` trials. The best trial is printed and its
    parameters are written back onto ``args``, which is also returned.
    """
    dataset = TensorDataset(features, labels)
    train_split, val_split = torch.utils.data.random_split(dataset, [0.7, 0.3])
    # The batch size may not exceed the number of training samples.
    initial_batch_size = min(args.train_batch_size, len(train_split))
    train_loader = DataLoader(train_split, batch_size=initial_batch_size, shuffle=True)
    val_loader = DataLoader(val_split, batch_size=10, shuffle=False)

    def run_trial(trial):
        # Bind the fixed arguments so Optuna only has to supply the trial.
        return objective_func(trial, train_loader, val_loader, args, feature_dim, num_classes)

    # hyperparameter optimization
    study = optuna.create_study(
        study_name="dp_mia",
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=seed),
    )
    study.optimize(run_trial, n_trials=args.number_of_trials)

    print("Best trial:")
    best = study.best_trial

    print("Value: ", best.value)

    print("Params: ")
    for name, chosen in best.params.items():
        print("{}: {}".format(name, chosen))

    # Copy the winning values back onto args; max_grad_norm is only tuned in
    # private mode.
    tuned = best.params
    if args.private:
        args.max_grad_norm = tuned["max_grad_norm"]
    for attribute, param_name in (
        ("train_batch_size", "batch_size"),
        ("learning_rate", "learning_rate"),
        ("epochs", "epochs"),
    ):
        setattr(args, attribute, tuned[param_name])

    return args
