import numpy as np
import torch
import gc
import pickle
import os
import argparse
from training.cached_feature_loader import CachedFeatureLoader
import time

from training.learn_linear_layer import learn_linear_layer, optimize_hyperparameters
from torch.utils.data import DataLoader, TensorDataset
from lira.lira import convert_logit_to_prob, calculate_statistic, log_loss

# Device that inputs are moved to in this module: the first CUDA GPU when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def parse_command_line():
    """Build the command-line interface of the experiment and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the training hyperparameters, the
        differential-privacy settings, the hyperparameter-search bounds, the
        LiRA / data-selection options, the seeds and the feature-cache path.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # --- training hyperparameters ---
    add("--learning_rate", "-lr", type=float, default=0.003, help="Learning rate.")
    add("--epochs", "-e", type=int, default=400, help="Number of fine-tune epochs.")
    add("--train_batch_size", "-b", type=int, default=200, help="Batch size.")

    # --- differential privacy ---
    add("--private", dest="private", default=False, action="store_true", help="If true, use differential privacy.")
    add("--noise_multiplier", type=float, default=1.0, help="Noise multiplier.")
    add("--max_grad_norm", type=float, default=1.0, help="Maximum gradient norm.")
    add("--target_epsilon", type=float, default=10.0, help="Maximum value of epsilon allowed.")

    # --- hyperparameter search space: (flag, type, default, help) ---
    search_space = (
        ("--epochs_lb", int, 1, "LB of fine-tune epochs."),
        ("--epochs_ub", int, 200, "UB of fine-tune epochs."),
        ("--train_batch_size_lb", int, 10, "LB of Batch size."),
        ("--train_batch_size_ub", int, 1000, "UB of Batch size."),
        ("--max_grad_norm_lb", float, 0.2, "LB of maximum gradient norm."),
        ("--max_grad_norm_ub", float, 10.0, "UB of maximum gradient norm."),
        ("--learning_rate_lb", float, 1e-7, "LB of learning rate"),
        ("--learning_rate_ub", float, 1e-2, "UB of learning rate"),
        ("--number_of_trials", int, 20, "The number of trials for optuna"),
    )
    for flag, value_type, default_value, help_text in search_space:
        add(flag, type=value_type, default=default_value, help=help_text)

    # --- LiRA options ---
    add("--num_shadow_models", type=int, default=256, help="Number of shadow models to train for the LiRA attack.")
    add("--checkpoint_dir", "-c", default="../checkpoints", help="Directory to save checkpoint to.")

    # --- data selection ---
    add("--n_classes", type=int, default=-1, help="Number of classes.")
    add("--shots", type=int, default=-1, help="Shots (per class). Default -1 means all")
    supported_extractors = ["vit-b-16", "BiT-M-R50x1"]
    add("--feature_extractor", choices=supported_extractors, help="Feature extractor used for shadow models.")
    supported_datasets = [
        "cifar100",
        "cifar10",
        "svhn",
        "oxford_iiit_pet",
        "patch_camelyon",
        "resisc45",
        "dtd",
        "oxford_flowers102",
        "diabetic_retinopathy_detection",
        "eurosat",
    ]
    add("--dataset", choices=supported_datasets, help="Dataset used for shadow models.")

    # --- behaviour switches ---
    add("--record_l2_norms", action="store_true", help="If specified, record l2 gradient norms.")
    add(
        "--use_specified_hypers",
        action="store_true",
        help="If specified, use specified hypers and don't optimize them.",
    )

    # --- seeding ---
    add("--seed", type=int, default=0, help="Seed for hyperparameter optimization and dataset sampling.")
    add("--data_seed", type=int, default=0, help="Seed for dataset sampling.")

    # --- utils ---
    add("--data_path", help="Location of the cached feature dims.")

    return cli.parse_args()


def get_stat_and_loss_aug(model, x, y, sample_weight=None):
    """Evaluate ``model`` on ``x`` and derive the quantities used by the attack.

    The inputs are moved to ``DEVICE``, pushed through the model without
    gradient tracking, and the resulting logits are converted to
    probabilities before the loss and the statistic are computed.

    Args:
        model: callable mapping a batch of inputs to logits.
        x: tensor of samples; it is evaluated in a single forward pass.
        y: true labels of the samples (integer valued). It is compared
            against a NumPy array of predictions, so presumably a NumPy
            array itself -- confirm against callers.
        sample_weight: optional per-sample weights, forwarded unchanged to
            ``log_loss`` and ``calculate_statistic``.

    Returns:
        A 4-tuple ``(statistics, losses, accuracies, logits)``. The first
        three have a new axis inserted at position 1; ``accuracies`` holds
        the float result of ``y == argmax(logits)``. ``logits`` is the raw
        model output as a NumPy array.
    """

    def _as_column(values):
        # Equivalent to concatenating a one-element list, then adding axis 1.
        return np.expand_dims(np.concatenate([values]), axis=1)

    data = x.to(DEVICE)
    with torch.no_grad():
        logits = model(data).cpu().numpy()

    predicted_labels = np.argmax(logits, axis=-1)
    is_correct = (y == predicted_labels).astype(float)

    prob = convert_logit_to_prob(logits)
    sample_losses = log_loss(y, prob, sample_weight=sample_weight)
    # ``prob`` already holds probabilities, hence ``is_logits=False``.
    sample_stats = calculate_statistic(prob, y, sample_weight=sample_weight, is_logits=False)

    return (
        _as_column(sample_stats),
        _as_column(sample_losses),
        _as_column(is_correct),
        logits,
    )


def _save_dict_to_csv(dict, path):
    """Write a mapping to ``path`` as a two-column ``key,value`` text file.

    The file starts with the header line ``key,value`` followed by one
    ``<key>,<value>`` line per entry, in the mapping's iteration order. Keys
    and values are converted with ``str``; no CSV quoting or escaping is
    applied, so entries containing commas or newlines are written verbatim.

    Args:
        dict: mapping to serialise. The name shadows the builtin but is kept
            so that existing keyword callers continue to work.
        path: destination file path; an existing file is overwritten.
    """
    rows = ["key,value\n"]
    # Stringify the key as well as the value so that non-string keys
    # (e.g. integer model indices) do not raise a TypeError.
    rows.extend(f"{key!s},{dict[key]!s}\n" for key in dict)
    with open(path, "w") as f:
        f.writelines(rows)


def _lira_artifact_path(args, prefix, extension):
    """Return the checkpoint path of one run artifact, ``<prefix>_<run tag>.<extension>``."""
    run_tag = (
        f"{args.dataset}_{args.feature_extractor}_{args.n_classes}_{args.shots}"
        f"_{args.target_epsilon}_{args.data_seed}_{args.seed}"
    )
    return os.path.join(args.checkpoint_dir, f"{prefix}_{run_tag}.{extension}")


def _dump_pickle(obj, path):
    """Pickle ``obj`` to ``path``, overwriting any existing file."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)


def run_lira(args, x_train, y_train, feature_dim, num_classes, x_test, y_test):
    """Train the target/shadow models and store everything the attack needs.

    ``args.num_shadow_models + 1`` models are trained. Each one is fitted on a
    random subset of ``x_train`` in which every example is included
    independently with probability 0.5. After training, each model is
    evaluated on the *complete* training set, and its statistics, logits and
    per-sample accuracies are recorded together with the membership mask.

    The following files are written to ``args.checkpoint_dir`` (all sharing
    the same run-specific suffix): ``logits_*.pkl``, ``labels_*.pkl``,
    ``stat_*.pkl``, ``in_indices_*.pkl``, ``accuracies_*.pkl``,
    ``args_*.csv``, ``test_acuracies_*.csv`` and, when
    ``args.record_l2_norms`` is set, ``l2_norms_*.pkl``.

    Args:
        args: parsed command-line namespace. It is forwarded to
            ``learn_linear_layer`` and also provides ``num_shadow_models``,
            ``train_batch_size``, ``record_l2_norms``, ``dataset``,
            ``checkpoint_dir`` and the fields used in the file names.
        x_train: tensor with the training features, indexed along dim 0.
        y_train: tensor with the training labels (must support ``.numpy()``).
        feature_dim: feature dimensionality, forwarded to
            ``learn_linear_layer``.
        num_classes: number of classes, forwarded to ``learn_linear_layer``.
        x_test: test features; only used when the dataset is neither
            ``resisc45`` nor ``eurosat`` (may be ``None`` otherwise).
        y_test: test labels; same conditions as ``x_test``.
    """
    # Fail fast: without this, a missing output directory is only noticed
    # after all models have been trained. An empty path means the current
    # working directory, which always exists.
    if args.checkpoint_dir:
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    # Sample weights are set to `None` by default, but can be changed here.
    sample_weight = None
    n = x_train.shape[0]
    y_train_np = y_train.numpy()

    # Test accuracy is only measured for the first few models, and never for
    # the datasets for which the caller does not load a test split.
    num_test_evaluated_models = 10
    has_test_split = args.dataset not in ("resisc45", "eurosat")

    # Train the target and shadow models. We will use one of the models
    # as target and the rest as shadow.
    # Here we use the same architecture and optimizer. In practice, they might
    # differ between the target and shadow models.
    in_indices = []  # a list of in-training indices for all models
    stat = []  # a list of statistics for all models
    accuracies_all = []  # a list of accuracies for all models
    test_accuracies_all = {}  # model index (as str) -> mean test accuracy
    all_l2_norms = []
    all_logits = []

    for i in range(args.num_shadow_models + 1):
        # Generate a binary array indicating which example to include for training
        in_indices_i = np.random.binomial(1, 0.5, n).astype(bool)
        in_indices.append(in_indices_i)

        # train the model
        model_train_images = x_train[in_indices_i].to(DEVICE)
        model_train_labels = y_train[in_indices_i].to(DEVICE)

        train_loader = DataLoader(
            TensorDataset(model_train_images, model_train_labels),
            batch_size=min(args.train_batch_size, len(model_train_labels)),
            shuffle=True,
        )

        _, _, model, all_l2_norms_model = learn_linear_layer(
            train_loader=train_loader,
            val_loader=None,
            args=args,
            feature_dim=feature_dim,
            num_classes=num_classes,
            record_l2_norms=args.record_l2_norms,
            complete_x_train=x_train,
            complete_y_train=y_train,
        )

        # Get the statistics of the current model on the full training set
        # (members and non-members alike).
        s, _, accuracies, logits = get_stat_and_loss_aug(model, x_train, y_train_np, sample_weight)
        stat.append(s)
        all_logits.append(logits)
        accuracies_all.append(accuracies)
        print(
            f"Trained model #{i} with {in_indices[-1].sum()} examples. In train acc {accuracies[in_indices_i].mean()} Out acc: {round(accuracies[~in_indices_i].mean(), 2)}",
            flush=True,
        )

        # save L2 norms of gradients
        if args.record_l2_norms:
            all_l2_norms.append(all_l2_norms_model)

        # evaluate test accuracy on the first models only
        if i < num_test_evaluated_models and has_test_split:
            # get_stat_and_loss_aug returns four values; only the accuracies
            # are needed here.
            _, _, test_accuracies, _ = get_stat_and_loss_aug(model, x_test, y_test.numpy(), sample_weight)
            test_accuracies_all[str(i)] = np.average(test_accuracies)
            print(f"test accuracy is {np.average(test_accuracies)}")

        # Avoid OOM
        del model
        torch.cuda.empty_cache()
        gc.collect()

    # save logits, labels, stat, in_indices, accuracies and args
    _dump_pickle(all_logits, _lira_artifact_path(args, "logits", "pkl"))
    _dump_pickle(y_train, _lira_artifact_path(args, "labels", "pkl"))
    _dump_pickle(stat, _lira_artifact_path(args, "stat", "pkl"))
    _dump_pickle(in_indices, _lira_artifact_path(args, "in_indices", "pkl"))
    _dump_pickle(accuracies_all, _lira_artifact_path(args, "accuracies", "pkl"))
    _save_dict_to_csv(args.__dict__, _lira_artifact_path(args, "args", "csv"))
    # NOTE: the "test_acuracies" spelling is kept on purpose so that existing
    # consumers of this file keep finding it.
    _save_dict_to_csv(test_accuracies_all, _lira_artifact_path(args, "test_acuracies", "csv"))

    if args.record_l2_norms:
        _dump_pickle(all_l2_norms, _lira_artifact_path(args, "l2_norms", "pkl"))


def main():
    """Run the whole experiment from the command line.

    Steps:
        1. Parse the command-line arguments.
        2. Load the cached training features (and the test features, except
           for ``resisc45`` and ``eurosat``).
        3. Save the selected training elements to ``args.checkpoint_dir``.
        4. Unless ``--use_specified_hypers`` is given, optimise the
           hyperparameters on a random subset of the training data in which
           each example is included with probability 0.5.
        5. Train the shadow models and save the results via ``run_lira``.

    The elapsed wall-clock time is printed at the end.
    """
    start_time = time.time()
    args = parse_command_line()

    # Make sure the output directory exists before anything is written to it;
    # otherwise the first ``open`` below fails with FileNotFoundError. An empty
    # path means the current working directory, which always exists.
    if args.checkpoint_dir:
        os.makedirs(args.checkpoint_dir, exist_ok=True)

    cached_feature_loader = CachedFeatureLoader(
        dataset=args.dataset, path_to_cache_dir=args.data_path, feature_extractor=args.feature_extractor
    )
    train_features, train_labels, selected_elements, class_mappings = cached_feature_loader.load_train_data(
        args.shots, args.n_classes, args.data_seed
    )
    feature_dim = cached_feature_loader.obtain_feature_dim()
    # The number of classes is taken from the labels actually present.
    num_classes = len(torch.unique(train_labels))
    print(args.dataset)

    # No test split is loaded for these two datasets.
    if args.dataset not in ("resisc45", "eurosat"):
        test_features, test_labels = cached_feature_loader.load_test_data(class_mappings)
    else:
        test_features, test_labels = None, None

    # save selected elements
    selected_elements_path = os.path.join(
        args.checkpoint_dir,
        f"selected_elements_{args.dataset}_{args.feature_extractor}_{args.n_classes}_{args.shots}_{args.target_epsilon}_{args.data_seed}_{args.seed}.pkl",
    )
    with open(selected_elements_path, "wb") as f:
        pickle.dump(selected_elements, f)

    # optimize hyperparameters
    if not args.use_specified_hypers:
        print("optimize hyperparameters", flush=True)
        # Tune on a random subset drawn the same way as the shadow-model
        # training sets (each example kept with probability 0.5).
        in_indices_i = np.random.binomial(1, 0.5, train_features.shape[0]).astype(bool)
        model_train_images = train_features[in_indices_i]
        model_train_labels = train_labels[in_indices_i]
        args = optimize_hyperparameters(
            args, model_train_images, model_train_labels, feature_dim, num_classes, args.seed
        )

    # train the shadow models
    run_lira(
        args=args,
        x_train=train_features,
        y_train=train_labels,
        feature_dim=feature_dim,
        num_classes=num_classes,
        x_test=test_features,
        y_test=test_labels,
    )
    print(f"elapsed time {(time.time()-start_time) / 60} minutes")


# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
