import numpy as np
import gc
import torch
from adversarial_networks import Classifier, Discriminator, Generator
from experiment_regret import train_model, train_model_with_stopping, \
    process_prediction, adversarial_predictions, gradient_step, train_baseline
from performance_metrics import metrics, metrics_adversarial
from datasets import get_batches, get_dataset, GrowingNumpyDataSet
from models import (
    TorchBinaryLogisticRegression,
    get_predictions,
    get_accuracies,
    get_accuracies_simple,
    get_error_breakdown,
)

# When True, full-minimization training calls ``train_model`` for a fixed
# number of steps; when False it calls ``train_model_with_stopping`` instead.
FIXED_STEPS = True
# Batch size requested for the test split.
TEST_BATCH_SIZE = 1000

# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Decision rules that build ``model_biased_prediction`` and require
# ``training_mode == "full_minimization"``.
_COUNTERFACTUAL_DECISION_TYPES = ("counterfactual", "justadversarial_counterfactual")
# Every decision rule handled by ``run_regret_experiment_pytorch``.
_DECISION_TYPES = ("simple",) + _COUNTERFACTUAL_DECISION_TYPES
# Every training mode handled by ``run_regret_experiment_pytorch``.
_TRAINING_MODES = ("full_minimization", "gradient_step")


def _make_logistic_model(nn_params, exploration_hparams, mlp, dim):
    """Build a ``TorchBinaryLogisticRegression`` with the experiment's settings."""
    return TorchBinaryLogisticRegression(
        random_init=nn_params.random_init,
        alpha=exploration_hparams.alpha,
        MLP=mlp,
        representation_layer_size=nn_params.representation_layer_size,
        dim=dim,
    )


def _make_adversarial_networks(params, input_dim):
    """Build the adversarial (classifier, generator, discriminator) networks.

    The creation order (classifier, generator, discriminator) is kept fixed so
    that any random initialisation draws happen in a stable sequence. All three
    networks are moved to ``DEVICE``.

    Returns:
        ``(unbiased_model, generator, discriminator)``.
    """
    unbiased_model = Classifier(params['encoded_dim'], params['hidden_layer'])
    generator = Generator(input_dim, params['encoded_dim'])
    discriminator = Discriminator(
        params['encoded_dim'], params['hidden_layer'], params['adv_loss_type']
    )
    generator.to(DEVICE)
    discriminator.to(DEVICE)
    unbiased_model.to(DEVICE)
    return unbiased_model, generator, discriminator


def _train_full_minimization(model, num_steps, dataset, nn_params, counter, verbose):
    """Train ``model`` on ``dataset`` and return the trained model.

    Uses ``train_model`` when ``FIXED_STEPS`` is set. Otherwise it uses
    ``train_model_with_stopping``, whose tolerance ``eps`` grows
    logarithmically with the step ``counter``. Runs ``gc.collect()`` afterwards.
    """
    if FIXED_STEPS:
        model = train_model(model, num_steps, dataset, nn_params.batch_size)
    else:
        model = train_model_with_stopping(
            model,
            num_steps,
            dataset,
            nn_params.batch_size,
            verbose=verbose,
            restart_model_full_minimization=nn_params.restart_model_full_minimization,
            eps=0.0001 * np.log(counter + 2) / 2,
        )
    gc.collect()
    return model


def _update_mahalanobis(cumulative_covariance, representation_X, exploration_hparams):
    """Fold a batch of representations into the discounted scatter matrix.

    Args:
        cumulative_covariance: previous accumulated ``X^T X`` (ndarray), or an
            empty list before the first update.
        representation_X: 2-D ndarray of representations for this batch.
        exploration_hparams: provides ``mahalanobis_discount_factor`` and
            ``mahalanobis_regularizer``.

    Returns:
        ``(covariance, inverse)`` where ``covariance`` is the updated
        accumulated ndarray and ``inverse`` is the float32 torch tensor
        ``inv(mahalanobis_regularizer * I + covariance)``.
    """
    batch_scatter = np.dot(np.transpose(representation_X), representation_X)
    if len(cumulative_covariance) == 0:
        cumulative_covariance = batch_scatter
    else:
        cumulative_covariance = (
            exploration_hparams.mahalanobis_discount_factor * cumulative_covariance
            + batch_scatter
        )
    # This can be done instead by using the Sherman-Morrison Formula.
    inverse = torch.from_numpy(
        np.linalg.inv(
            exploration_hparams.mahalanobis_regularizer
            * np.eye(representation_X.shape[1])
            + cumulative_covariance
        )
    ).float()
    return cumulative_covariance, inverse


def _running_mean(values, count):
    """Return ``cumsum(values) / [1, 2, ..., count]``.

    ``count`` must equal ``len(values)``.
    """
    return np.cumsum(values) / (np.arange(count) + 1)


def run_regret_experiment_pytorch(
        dataset,
        training_mode,
        nn_params,
        linear_model_hparams,
        exploration_hparams,
        logging_frequency,
        seed,
        baseline_accuracy=None,
        loss_validation_baseline=None,
):
    """Run an online accept/reject regret experiment with PyTorch models.

    At each of ``nn_params.max_num_steps`` steps a training batch is drawn.

    * An "unbiased" model ``model`` is updated on every batch.
    * A "biased" model ``model_biased`` is updated only on the points that
      ``process_prediction`` accepts.

    The per-point predictions handed to ``process_prediction`` depend on
    ``exploration_hparams.decision_type``:

    * ``"simple"``: ``get_predictions`` on ``model_biased``. When
      ``adjust_mahalanobis`` is set, the inverse regularised representation
      covariance is also passed in.
    * ``"counterfactual"``: boolean draws ``rand < exploration_hparams.epsilon``.
    * ``"justadversarial_counterfactual"``: outputs of an adversarially trained
      classifier/generator pair fitted on previously accepted data.

    While the biased dataset is empty, every prediction is 1. Every
    ``logging_frequency`` steps, diagnostics are recorded on a test batch.

    Args:
        dataset: dataset name forwarded to ``get_dataset``/``train_baseline``.
            ``"MNIST"`` and ``"Adult"`` load with ``nn_params.batch_size``;
            every other dataset loads with batch size 10.
        training_mode: ``"full_minimization"`` (retrain on the accumulated
            dataset each step) or ``"gradient_step"`` (update via
            ``gradient_step`` with an Adam optimiser, lr=0.01).
        nn_params: network/training settings. Reads ``batch_size``,
            ``num_full_minimization_steps``, ``random_init``,
            ``representation_layer_size``, ``max_num_steps`` and
            ``restart_model_full_minimization``.
        linear_model_hparams: forwarded to ``train_baseline`` and
            ``process_prediction``; its ``threshold`` is used for accuracies.
        exploration_hparams: decision rule and exploration settings.
        logging_frequency: diagnostics are recorded when
            ``counter % logging_frequency == 0``.
        seed: forwarded to ``get_dataset``.
        baseline_accuracy: precomputed baseline accuracy. It is only used when
            ``regret_wrt_baseline`` is True; if it is None in that case,
            ``train_baseline`` is called.
        loss_validation_baseline: precomputed baseline validation loss, with
            the same rules as ``baseline_accuracy``.

    Returns:
        An 18-tuple:

        * ``(timesteps, test_biased_accuracies_cum_averages,
          accuracies_cum_averages, train_biased_accuracies_cum_averages,
          train_cum_regret, train_cum_regret_justadv, protected_accepted_list,
          actual_protected_accepted, loss_validation, loss_validation_biased,
          loss_validation_baseline, baseline_accuracy,
          train_error_breakdown_list, test_error_breakdown_list,
          pseudo_error_breakdown_list, eps_error_breakdown_list,
          metrics_biased, metrics_adv)``.
        * ``train_cum_regret_justadv`` is always a cumulative sum of zeros.
        * ``pseudo_error_breakdown_list`` and ``eps_error_breakdown_list`` are
          always empty; they are kept for interface compatibility.

    Raises:
        ValueError: on an unknown ``decision_type`` or ``training_mode``.
        ValueError: for ``"counterfactual"`` combined with ``epsilon_greedy``
            or ``adjust_mahalanobis``.
        ValueError: for a counterfactual decision type with a training mode
            other than ``"full_minimization"``.
    """
    # TODO: remove/pull up into hparams
    MLP = True
    verbose = False
    regret_wrt_baseline = exploration_hparams.regret_wrt_baseline
    num_full_minimization_steps = nn_params.num_full_minimization_steps
    decision_type = exploration_hparams.decision_type
    is_counterfactual = decision_type in _COUNTERFACTUAL_DECISION_TYPES

    # Validate the configuration before any expensive dataset/baseline work.
    # An unknown decision type would otherwise surface as an UnboundLocalError
    # on ``global_biased_prediction`` inside the loop.
    if decision_type not in _DECISION_TYPES:
        raise ValueError("Unrecognized decision type {!r}".format(decision_type))
    if training_mode not in _TRAINING_MODES:
        raise ValueError("Unrecognized training mode")
    if decision_type == "counterfactual":
        if exploration_hparams.epsilon_greedy or exploration_hparams.adjust_mahalanobis:
            raise ValueError(
                "Decision type set to counterfactual, can't set exploration constants."
            )
    if is_counterfactual and training_mode != "full_minimization":
        raise ValueError(
            "The counterfactual decision mode is incompatible with all "
            "training modes different from full_minimization"
        )

    # TODO: the loader batch size for datasets other than MNIST/Adult is
    # hard-coded to 10.
    if dataset in ("MNIST", "Adult"):
        baseline_batch_size = nn_params.batch_size
    else:
        baseline_batch_size = 10
    (
        protected_datasets_train,
        protected_datasets_test,
        train_dataset,
        test_dataset,
    ) = get_dataset(
        dataset=dataset,
        batch_size=baseline_batch_size,
        test_batch_size=TEST_BATCH_SIZE,
        seed=seed,
        fit_intercept=False
    )

    if regret_wrt_baseline is True:
        if baseline_accuracy is None:
            baseline_accuracy, loss_validation_baseline = train_baseline(
                dataset,
                nn_params,
                linear_model_hparams
            )
        print("Baseline model accuracy {}".format(baseline_accuracy))
    else:
        baseline_accuracy, loss_validation_baseline = 0, 0
        print("No baseline")

    accuracies_list = []
    biased_accuracies_list = []
    pseudo_error_breakdown_list = []
    eps_error_breakdown_list = []
    train_error_breakdown_list = []
    test_error_breakdown_list = []
    loss_validation = []
    loss_validation_biased = []
    train_regret = []
    train_regret_justadv = []
    metrics_adv = []
    metrics_biased = []
    protected_accepted_list = []
    actual_protected_accepted_list = []

    counter = 0
    biased_data_totals = 0

    model = _make_logistic_model(
        nn_params, exploration_hparams, MLP, train_dataset.dimension
    )
    model_biased = _make_logistic_model(
        nn_params, exploration_hparams, MLP, train_dataset.dimension
    )
    model_biased_prediction = None
    if is_counterfactual:
        # NOTE(review): this model is initialised and used for the error
        # breakdowns, but it is never trained here — confirm this is intended.
        model_biased_prediction = _make_logistic_model(
            nn_params, exploration_hparams, MLP, train_dataset.dimension
        )

    # Mahalanobis state; the empty lists mean "not yet computed".
    cumulative_data_covariance = []
    inverse_cumulative_data_covariance = []

    train_accuracies_biased = []
    timesteps = []

    biased_dataset = GrowingNumpyDataSet()
    unbiased_dataset = GrowingNumpyDataSet()

    # training parameters
    params = {
        'encoded_dim': 100,
        'hidden_layer': 100,
        # 'lambd': 5,
        'lambd': 1,
        'adv_loss_type': 'BCE'
    }
    # Optimization parameters:
    # reg controls the amount of regularisation when using WGAN loss
    opt_params = {
        'clas_lr': 1e-4,
        'gen_lr': 1e-4,
        'disc_lr': 5e-4,
        'betas': (0., 0.99),
        'reg': 0.0002,
        'batch_size': 20,
        'n_epochs': exploration_hparams.adv_num_epochs
    }

    unbiased_model, generator, discriminator = _make_adversarial_networks(
        params, train_dataset.dimension
    )

    num_protected_characteristic = 0.
    actual_protected_characteristic = 0.

    while counter < nn_params.max_num_steps:
        if exploration_hparams.network_reset:
            unbiased_model, generator, discriminator = _make_adversarial_networks(
                params, train_dataset.dimension
            )

        counter += 1

        global_batch, protected_batches = get_batches(
            protected_datasets_train, train_dataset, nn_params.batch_size
        )
        batch_X, batch_y = global_batch

        if counter == 1:
            # Lazily size the networks from the first batch's feature count.
            model.initialize_model(batch_X.shape[1])
            model_biased.initialize_model(batch_X.shape[1])
            if model_biased_prediction is not None:
                model_biased_prediction.initialize_model(batch_X.shape[1])

            optimizer_model = torch.optim.Adam(model.network.parameters(),
                                               lr=0.01)
            optimizer_biased = torch.optim.Adam(
                model_biased.network.parameters(), lr=0.01
            )

        # --- Update the unbiased model on every point of the batch. ---
        if training_mode == "full_minimization":
            unbiased_dataset.add_data(batch_X, batch_y)
            model = _train_full_minimization(
                model, num_full_minimization_steps, unbiased_dataset,
                nn_params, counter, verbose,
            )
        else:  # "gradient_step" (training_mode was validated on entry)
            model, optimizer_model = gradient_step(
                model, optimizer_model, batch_X, batch_y
            )

        # --- Produce per-point predictions with the selected decision rule. ---
        if biased_dataset.get_size() == 0:
            # ACCEPT ALL POINTS IF THE BIASED DATASET IS NOT INITIALIZED
            global_biased_prediction = [1 for _ in
                                        range(nn_params.batch_size)]
        elif decision_type == "simple":
            global_biased_prediction, _ = get_predictions(
                global_batch,
                protected_batches,
                model_biased,
                inverse_cumulative_data_covariance,
            )
        else:
            # First get epsilon greedy, then apply pseudolabel.
            # batch_size x 1
            initial_biased_pred, _ = get_predictions(
                global_batch,
                protected_batches,
                model_biased,
                inverse_cumulative_data_covariance,
            )
            # TODO: check if epsilon set?
            if decision_type == "counterfactual":
                # NOTE(review): only the shape of ``initial_biased_pred`` is
                # used here; no pseudolabel is applied — confirm intent.
                proposed_pos = torch.rand_like(
                    initial_biased_pred) < exploration_hparams.epsilon
            else:  # "justadversarial_counterfactual"
                training_batch_size = min(biased_dataset.get_size(),
                                          baseline_batch_size * 100)
                historic_X, historic_y = biased_dataset.get_batch(
                    training_batch_size)
                historic_X = historic_X.to(torch.float32).to(DEVICE)
                historic_y = historic_y.to(torch.float32).to(DEVICE)
                if exploration_hparams.adv_full_test_dataset:
                    adv_batch_X, _ = unbiased_dataset.get_batch(
                        training_batch_size)
                else:
                    adv_batch_X = batch_X
                adv_batch_X = adv_batch_X.to(torch.float32).to(DEVICE)
                batch_X = batch_X.to(torch.float32).to(DEVICE)
                batch_y = batch_y.to(torch.float32).to(DEVICE)
                print(
                    "Start of adversarial training of the unbiased model -- timestep ",
                    counter
                )
                unbiased_model, generator = adversarial_predictions(
                    generator, discriminator,
                    unbiased_model, opt_params,
                    params,
                    adv_batch_X, historic_X, historic_y)

                metrics_biased.append(
                    metrics(batch_X, batch_y, model_biased))
                metrics_adv.append(
                    metrics_adversarial(batch_X, batch_y, unbiased_model,
                                        generator))
                proposed_pos = unbiased_model(generator(batch_X))
            global_biased_prediction = proposed_pos.squeeze()

        # --- Decide which points are accepted and accumulate regret. ---
        biased_batch_X = []
        biased_batch_y = []
        biased_batch_size = 0
        biased_train_accuracy = 0
        batch_regret = 0
        batch_regret_justadv = torch.tensor(0)

        # TODO: pull out and combine with method above.
        try:
            pred_len = len(global_biased_prediction)
        except TypeError:
            # A batch of size one squeezes to a 0-d tensor, which has no len().
            global_biased_prediction = global_biased_prediction.unsqueeze(-1)
            pred_len = len(global_biased_prediction)
        for i in range(pred_len):
            label = batch_y[i]
            accuracy, regret, accepted = process_prediction(
                global_biased_prediction[i], label, linear_model_hparams,
                exploration_hparams, regret_wrt_baseline, baseline_accuracy,
                counter
            )
            biased_train_accuracy += accuracy
            batch_regret += regret
            if accepted:
                biased_batch_X.append(batch_X[i].unsqueeze(0))
                biased_batch_y.append(label)
                biased_batch_size += 1
        biased_train_accuracy = biased_train_accuracy / pred_len
        batch_regret = batch_regret / pred_len

        biased_data_totals += biased_batch_size

        # --- Update the biased model on the accepted points only. ---
        if biased_batch_size > 0:
            biased_batch_X = torch.cat(biased_batch_X)
            biased_batch_y = torch.Tensor(biased_batch_y).to(DEVICE)

            if training_mode == "full_minimization":
                print("Adding data to biased dataset")
                biased_dataset.add_data(biased_batch_X, biased_batch_y)
                print(
                    "Training the biased model -- timestep ",
                    counter
                )
                model_biased = _train_full_minimization(
                    model_biased, num_full_minimization_steps, biased_dataset,
                    nn_params, counter, verbose,
                )
            else:  # "gradient_step" (training_mode was validated on entry)
                # NOTE(review): in this mode nothing is ever added to
                # ``biased_dataset``. It stays empty, so every step takes the
                # accept-all path and the diagnostics below read from an empty
                # dataset — confirm this is intended.
                model_biased, optimizer_biased = gradient_step(
                    model_biased, optimizer_biased, biased_batch_X,
                    biased_batch_y
                )

            if decision_type == "simple":
                representation_X = model_biased.get_representation(
                    biased_batch_X
                ).detach().cpu().numpy()
                if exploration_hparams.adjust_mahalanobis:
                    (
                        cumulative_data_covariance,
                        inverse_cumulative_data_covariance,
                    ) = _update_mahalanobis(
                        cumulative_data_covariance,
                        representation_X,
                        exploration_hparams,
                    )

        # Running counts over all drawn training points:
        # - how many have the protected feature set;
        # - how many of those also have label 1.
        protected_characteristic = exploration_hparams.protected_characteristic
        protected_column = batch_X[:, protected_characteristic]
        num_protected_characteristic += protected_column.sum()
        actual_protected_characteristic += (protected_column * batch_y).sum()

        # DIAGNOSTICS
        # Compute accuracy diagnostics
        if counter % logging_frequency == 0:
            accepted_X = biased_dataset.get_batch(biased_dataset.get_size())[0]
            protected_accepted = (
                accepted_X[:, protected_characteristic].sum()
                / num_protected_characteristic
            )
            protected_accepted_list.append(protected_accepted.cpu().numpy())
            actual_protected_accepted_list.append(
                (actual_protected_characteristic
                 / num_protected_characteristic).cpu().numpy()
            )

            train_regret.append(batch_regret.cpu())
            train_regret_justadv.append(batch_regret_justadv.cpu())
            train_accuracies_biased.append(biased_train_accuracy.cpu())
            timesteps.append(counter)
            global_batch_test, protected_batches_test = get_batches(
                protected_datasets_test, test_dataset, TEST_BATCH_SIZE
            )
            batch_X_test, batch_y_test = global_batch_test
            total_accuracy, _ = get_accuracies(
                global_batch_test,
                protected_batches_test,
                model,
                linear_model_hparams.threshold,
            )

            with torch.no_grad():
                # Compute loss diagnostics
                biased_loss = model_biased.get_loss(batch_X_test, batch_y_test)
                loss = model.get_loss(batch_X_test, batch_y_test)
                loss_validation.append(loss.detach().cpu())
                loss_validation_biased.append(biased_loss.detach().cpu())

            accuracies_list.append(total_accuracy)
            biased_total_accuracy, _ = get_accuracies(
                global_batch_test,
                protected_batches_test,
                model_biased,
                linear_model_hparams.threshold,
            )
            biased_accuracies_list.append(biased_total_accuracy)
            if model_biased_prediction is not None:
                train_breakdown = get_error_breakdown(
                    global_batch,
                    model_biased_prediction,
                    linear_model_hparams.threshold,
                )
                test_breakdown = get_error_breakdown(
                    global_batch_test,
                    model_biased_prediction,
                    linear_model_hparams.threshold,
                )
                train_error_breakdown_list.append(train_breakdown)
                test_error_breakdown_list.append(test_breakdown)
            # Compute training biased accuracy
            # TODO: this errors sometimes! is this too big?
            # TODO: dataset_X is a list, not numpy.
            train_biased_batch = biased_dataset.get_batch(1000)
            biased_train_accuracy = get_accuracies_simple(
                train_biased_batch, model_biased,
                linear_model_hparams.threshold
            )
            with torch.no_grad():
                loss_train_biased = model_biased.get_loss(
                    train_biased_batch[0], train_biased_batch[1]
                ).detach()

            if verbose:
                print("Iteration {}".format(counter))
                print(
                    "Total proportion of biased data {}".format(
                        1.0 * biased_data_totals / (
                            nn_params.batch_size * counter)
                    )
                )
                print("Biased TRAIN accuracy  ", biased_train_accuracy)
                print("Biased TRAIN loss ", loss_train_biased)

                print(f"Baseline accuracy: {baseline_accuracy}")
                # Compute the global accuracy.
                print(f"Unbiased Accuracy: {total_accuracy}")
                # Compute the global accuracy.
                print(f"Biased Accuracy {biased_total_accuracy}")
                print(f"Validation Loss Unbiased: {loss_validation[-1]}")
                print(f"Validation Loss Biased {loss_validation_biased[-1]}")

    # All of these lists are appended together in the logging branch, so they
    # have len(timesteps) entries each.
    num_logged = len(timesteps)
    test_biased_accuracies_cum_averages = _running_mean(
        biased_accuracies_list, num_logged
    )
    accuracies_cum_averages = _running_mean(accuracies_list, num_logged)
    train_biased_accuracies_cum_averages = _running_mean(
        train_accuracies_biased, num_logged
    )
    train_cum_regret = np.cumsum(train_regret)
    train_cum_regret_justadv = np.cumsum(train_regret_justadv)
    return (
        timesteps,
        test_biased_accuracies_cum_averages,
        accuracies_cum_averages,
        train_biased_accuracies_cum_averages,
        train_cum_regret,
        train_cum_regret_justadv,
        np.array(protected_accepted_list),
        np.array(actual_protected_accepted_list),
        loss_validation,
        loss_validation_biased,
        loss_validation_baseline,
        baseline_accuracy,
        train_error_breakdown_list,
        test_error_breakdown_list,
        pseudo_error_breakdown_list,
        eps_error_breakdown_list,
        metrics_biased,
        metrics_adv
    )
