import argparse
import os
import pandas as pd

import numpy as np
from tqdm import tqdm
import wandb

from aiau.models.bootstrap_ensemble import BootstrapEnsemble
from sklearn.neural_network import MLPRegressor
from aiau.data.data_manager import DataManager
from aiau.oracles import (
    Oracle,
    NoisyBenchmarkOracle,
    aleotoric_and_epistemic_noise_function,
    CorrelatedNoiseOracle,
)
from aiau.data.synthetic_datasets import (
    generate_2d_synthetic_data,
    generate_non_regular_2d_synthetic_data,
    generate_freq_power_2d_synthetic_data,
)
from aiau.strategy import (
    RandomStrategy,
    LeastConfidenceStrategy,
    EMSEStrategy,
    EMSEBiasStrategy,
    BiasReductionStrategy,
    MaxTermSelectionStrategy,
    DiffLeastConfidenceStrategy,
    DiffEMSEStrategy,
    DiffBiasReductionStrategy,
    DiffMaxTermSelectionStrategy,
    BALDStrategy,
    PemsePemseDiffStrategy,
    BatchBALDStrategy,
)
from aiau.data.index_initialisation_utils import (
    lower_left_corner_indices_2d,
    random_indices,
)

#######################################################
#######################################################
# Constants
# Active-learning rounds (fit -> evaluate -> query -> label) run per replicate.
NUM_ITERATIONS: int = 100
# Initially-labelled set sizes; the whole experiment is repeated for each size.
NUM_INITIAL_LABELLED_INDICES: list = [10, 100]
# Points requested from the strategy per round (also used as the bootstrap
# ensemble size in run_experiment).
QUERY_SIZE: int = 10
#######################################################
#######################################################


def create_output_directory_single(
    base_dir,
    estimation_approach,
    oracle_type,
    oracle_noise_level,
    dataset,
    initialisation_type,
):
    """Create (if needed) and return the results directory for one configuration.

    Layout::

        <base_dir>/batch/<estimation_approach>/type_<k>[/<oracle_noise_level>]/<dataset>/<initialisation_type>

    The noise-level component is omitted for oracle type 1, whose oracle takes
    no noise parameters.

    Args:
        base_dir: Root results directory.
        estimation_approach: Name of the estimation approach (a path component).
        oracle_type: 1, 2 or 3.
        oracle_noise_level: Noise level; rendered with ``str()`` as a path
            component for oracle types 2 and 3.
        dataset: Dataset name (a path component).
        initialisation_type: Initialisation type (a path component).

    Returns:
        The created directory path.

    Raises:
        ValueError: If ``oracle_type`` is not 1, 2 or 3.
    """
    setting = "batch"
    oracle_type_strings = {1: "type_1", 2: "type_2", 3: "type_3"}
    if oracle_type not in oracle_type_strings:
        raise ValueError("Invalid oracle type")

    components = [
        base_dir,
        setting,
        estimation_approach,
        oracle_type_strings[oracle_type],
    ]
    # Only the noisy oracles (types 2 and 3) are split further by noise level.
    if oracle_type != 1:
        components.append(str(oracle_noise_level))
    components.extend([dataset, initialisation_type])

    path = os.path.join(*components)
    os.makedirs(path, exist_ok=True)
    return path


def setup_dataset(dataset):
    """Build the synthetic 2-D dataset identified by ``dataset``.

    Accepted names:
        * ``"SyntheticRegular"``
        * ``"SyntheticFreqPower_<freq>_<power>"`` (two integers)
        * ``"SyntheticNonRegular_<x1_mod>_<x2_mod>"`` (two integers)

    Every dataset is generated with ``x1_size=50`` and ``x2_size=50``.

    Returns:
        ``(X, y, indices)`` where ``indices`` is ``np.arange(X.shape[0])``.

    Raises:
        ValueError: If the name is unknown or its parameters are missing,
            surplus, or not integers.
    """
    if dataset == "SyntheticRegular":
        X, y = generate_2d_synthetic_data(x1_size=50, x2_size=50)
    else:
        # Parse "<Name>_<a>_<b>" once. Match on the exact name rather than a
        # substring, so malformed names fail loudly instead of with an IndexError.
        name, *params = dataset.split("_")
        if name not in ("SyntheticFreqPower", "SyntheticNonRegular"):
            raise ValueError("Invalid dataset name")
        if len(params) != 2:
            raise ValueError(
                f"Invalid dataset name {dataset!r}: expected '{name}_<int>_<int>'"
            )
        try:
            first, second = (int(param) for param in params)
        except ValueError as err:
            raise ValueError(
                f"Invalid dataset name {dataset!r}: parameters must be integers"
            ) from err

        if name == "SyntheticFreqPower":
            X, y = generate_freq_power_2d_synthetic_data(
                freq=first, power=second, x1_size=50, x2_size=50
            )
        else:
            X, y = generate_non_regular_2d_synthetic_data(
                x1_mod=first, x2_mod=second, x1_size=50, x2_size=50
            )

    indices = np.arange(X.shape[0])
    return X, y, indices


def setup_initial_labelled_indices(
    X, initialisation_type, num_initial_labelled_indices
):
    """Choose the indices of ``X`` that start out labelled.

    ``"corner"`` delegates to ``lower_left_corner_indices_2d`` and ``"random"``
    to ``random_indices``. Both receive ``X`` and
    ``num_samples=num_initial_labelled_indices``.

    Raises:
        ValueError: If ``initialisation_type`` is neither option.
    """
    sample_count = num_initial_labelled_indices
    if initialisation_type == "corner":
        return lower_left_corner_indices_2d(X, num_samples=sample_count)
    if initialisation_type == "random":
        return random_indices(X, num_samples=sample_count)
    raise ValueError("Invalid initialisation type")


def setup_oracle(oracle_type, oracle_noise_level, dm):
    """Construct the labelling oracle selected by ``oracle_type``.

    * 1 -> ``Oracle()`` (takes no noise parameters)
    * 2 -> ``NoisyBenchmarkOracle`` using
      ``aleotoric_and_epistemic_noise_function``
    * 3 -> ``CorrelatedNoiseOracle`` bound to the data manager ``dm``

    Types 2 and 3 receive ``oracle_noise_level`` and a fixed
    ``eps_noise_level`` of 1.0.

    Raises:
        ValueError: If ``oracle_type`` is not 1, 2 or 3.
    """
    epistemic_level = 1.0
    if oracle_type == 1:
        return Oracle()
    if oracle_type == 2:
        return NoisyBenchmarkOracle(
            noise_function=aleotoric_and_epistemic_noise_function,
            noise_level=oracle_noise_level,
            eps_noise_level=epistemic_level,
        )
    if oracle_type == 3:
        return CorrelatedNoiseOracle(
            data_manager=dm,
            noise_level=oracle_noise_level,
            eps_noise_level=epistemic_level,
        )
    raise ValueError("Invalid oracle type")


def setup_strategy(strategy, estimation_approach, oracle_type_string, oracle):
    """Instantiate the query strategy named ``strategy``.

    Strategies without configuration are constructed with no arguments. The
    others receive ``(estimation_approach, oracle_type_string, oracle)``
    positionally.

    Raises:
        ValueError: For an unrecognised strategy name.
    """
    shared_args = (estimation_approach, oracle_type_string, oracle)
    # Lambdas defer the class lookup until the matching name is found, so only
    # the selected strategy is ever constructed.
    builders = (
        ("random", lambda: RandomStrategy()),
        ("least_confidence", lambda: LeastConfidenceStrategy()),
        ("emse", lambda: EMSEStrategy(*shared_args)),
        ("emse_bias", lambda: EMSEBiasStrategy(*shared_args)),
        ("bias_reduction", lambda: BiasReductionStrategy(*shared_args)),
        ("max_term_selection", lambda: MaxTermSelectionStrategy(*shared_args)),
        ("diff_least_confidence", lambda: DiffLeastConfidenceStrategy()),
        ("diff_emse", lambda: DiffEMSEStrategy(*shared_args)),
        ("diff_bias_reduction", lambda: DiffBiasReductionStrategy(*shared_args)),
        (
            "diff_max_term_selection",
            lambda: DiffMaxTermSelectionStrategy(*shared_args),
        ),
        ("bald", lambda: BALDStrategy()),
        ("batchbald", lambda: BatchBALDStrategy()),
        ("pemse_pemse_diff", lambda: PemsePemseDiffStrategy(*shared_args)),
    )
    for known_name, build in builders:
        if strategy == known_name:
            return build()
    raise ValueError("Invalid strategy")


def _regression_metrics(ensemble_predictions, targets):
    """Score the ensemble-mean prediction against ``targets``.

    ``ensemble_predictions`` is averaged over axis 0 before scoring (presumably
    one row per ensemble member; confirm against ``BootstrapEnsemble.predict``).

    Returns:
        ``(mse, rmse, r2, explained_variance)`` where ``r2 = 1 - mse / var(y)``
        and ``explained_variance = 1 - var(y - mean_prediction) / var(y)``.
    """
    mean_prediction = np.mean(ensemble_predictions, axis=0)
    target_variance = np.var(targets)
    mse = np.mean((mean_prediction - targets) ** 2)
    rmse = np.sqrt(mse)
    r2 = 1 - mse / target_variance
    explained_variance = 1 - np.var(targets - mean_prediction) / target_variance
    return mse, rmse, r2, explained_variance


def run_experiment(
    estimation_approach,
    oracle_type,
    oracle_noise_level,
    dataset,
    initialisation_type,
    strategy_name,
    batch_strategy,
    num_replicates=10,
):
    """Run the batch active-learning loop and collect per-iteration metrics.

    For each size in ``NUM_INITIAL_LABELLED_INDICES`` and each of
    ``num_replicates`` seeded replicates, a fresh DataManager, oracle, bootstrap
    MLP ensemble and strategy are built. ``NUM_ITERATIONS`` rounds then run:
    fit on the current noisy labels, score on the full dataset, query
    ``QUERY_SIZE`` points, and label them.

    Args:
        estimation_approach: Passed through to the strategy and recorded.
        oracle_type: 1, 2 or 3 (see ``setup_oracle``).
        oracle_noise_level: Noise level for oracle types 2 and 3.
        dataset: Dataset name (see ``setup_dataset``).
        initialisation_type: ``"corner"`` or ``"random"``.
        strategy_name: Strategy name (see ``setup_strategy``).
        batch_strategy: Batch-selection mode forwarded to the strategy.
        num_replicates: Replicates per initial-set size; replicate ``r`` is
            seeded with ``12345 + r``.

    Returns:
        DataFrame with one row per (initial size, replicate, iteration). The
        ``query`` column holds the indices selected in that iteration.

    Raises:
        ValueError: For an unknown oracle type. The ``setup_*`` helpers raise
            for unknown dataset, initialisation or strategy names.
    """
    oracle_type_strings = {1: "type_1", 2: "type_2", 3: "type_3"}
    if oracle_type not in oracle_type_strings:
        raise ValueError("Invalid oracle type")
    oracle_type_string = oracle_type_strings[oracle_type]

    # Loop invariants, computed once instead of per replicate / iteration.
    # Re-querying is disabled only for the type-1 oracle.
    requery_value = oracle_type != 1
    nn_kwargs = {
        "hidden_layer_sizes": (32, 32, 16),
        "max_iter": 500,
        "solver": "adam",
        # Larger than the dataset, so each fit presumably uses full-batch updates.
        "batch_size": 1000000,
    }

    columns = [
        "strategy",
        "iteration",
        "mse",
        "rmse",
        "r2",
        "explained_variance",
        "replicate",
        "num_labelled",
        "num_initial_labelled",
        "estimation_approach",
        "batch_strategy",
        "query",
    ]
    results = []
    final_mse = []

    for num_initial_labelled_indices in NUM_INITIAL_LABELLED_INDICES:
        # Fixed seed so the dataset and initial labelled set are reproducible.
        np.random.seed(12345)
        X, y, indices = setup_dataset(dataset)
        initially_labelled_indices = setup_initial_labelled_indices(
            X, initialisation_type, num_initial_labelled_indices
        )

        for replicate in range(num_replicates):
            np.random.seed(12345 + replicate)
            dm = DataManager(
                indices=indices,
                observations=X,
                targets=y,
                initially_labelled_indices=initially_labelled_indices,
            )
            oracle = setup_oracle(oracle_type, oracle_noise_level, dm)
            dm.initialise(oracle)

            # NOTE(review): the ensemble size is tied to QUERY_SIZE; confirm this
            # coupling is intended rather than needing its own constant.
            # A copy of nn_kwargs keeps replicates isolated even if the ensemble
            # mutates its kwargs.
            ensemble = BootstrapEnsemble(
                num_models=QUERY_SIZE, model=MLPRegressor, model_kwargs=dict(nn_kwargs)
            )
            strategy = setup_strategy(
                strategy_name, estimation_approach, oracle_type_string, oracle
            )

            # NaN if NUM_ITERATIONS is 0, rather than NameError or a stale value
            # carried over from the previous replicate.
            mse = np.nan
            for iteration in tqdm(range(NUM_ITERATIONS)):
                train_X, train_y = dm.construct_noisy_dataset()
                ensemble.fit(train_X, train_y)
                ensemble_predictions = ensemble.predict(dm.full_X)
                mse, rmse, r2, explained_variance = _regression_metrics(
                    ensemble_predictions, dm.full_y
                )

                # The random strategy's select_next_indices takes no model.
                if strategy_name == "random":
                    indices_to_label = strategy.select_next_indices(
                        dm,
                        num_suggestions=QUERY_SIZE,
                        requery=requery_value,
                        batch_strategy=batch_strategy,
                    )
                else:
                    indices_to_label = strategy.select_next_indices(
                        dm,
                        ensemble,
                        num_suggestions=QUERY_SIZE,
                        requery=requery_value,
                        batch_strategy=batch_strategy,
                    )

                dm.update_noisy_targets(oracle, indices_to_label)
                dm.update_labelled_indices(indices_to_label, iteration + 1)

                # NOTE(review): the metrics describe the model trained *before*
                # this round's query, but num_labelled is read *after* the query
                # is added. Confirm this offset is intended for downstream plots.
                results.append(
                    [
                        strategy.name,
                        iteration,
                        mse,
                        rmse,
                        r2,
                        explained_variance,
                        replicate,
                        len(dm.labelled_indices),
                        num_initial_labelled_indices,
                        estimation_approach,
                        batch_strategy,
                        indices_to_label,
                    ]
                )

            final_mse.append(mse)

    if wandb.run:
        wandb.log({"mse_loss": np.mean(final_mse)})

    return pd.DataFrame(results, columns=columns)


def _build_parser():
    """Return the command-line argument parser used by ``main``."""
    parser = argparse.ArgumentParser(
        description="Run active learning experiments with various number of initial points in single setting."
    )
    parser.add_argument(
        "--estimation_approach",
        type=str,
        required=True,
        choices=["direct", "quadratic", "cheat"],
        help="Specify the estimation approach (direct, quadratic, or cheat)",
    )
    parser.add_argument(
        "--oracle_type",
        type=int,
        choices=[1, 2, 3],
        required=True,
        help="Specify the oracle type (1, 2, or 3)",
    )
    parser.add_argument(
        "--oracle_noise_level",
        type=float,
        required=True,
        help="Specify the noise level for the oracle",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Specify the problem or dataset name, e.g., SyntheticRegular, SyntheticFreqPower_4_4, SyntheticNonRegular_4_4",
    )
    parser.add_argument(
        "--init",
        type=str,
        choices=["corner", "random"],
        required=True,
        help="Specify the initialization type (corner or random)",
    )
    parser.add_argument(
        "--strategy",
        type=str,
        required=True,
        choices=[
            "random",
            "least_confidence",
            "emse",
            "emse_bias",
            "bias_reduction",
            "max_term_selection",
            "diff_least_confidence",
            "diff_emse",
            "diff_bias_reduction",
            "diff_max_term_selection",
            "bald",
            "batchbald",
            "pemse_pemse_diff",
        ],
        help="Specify the strategy to use",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./notebooks/results",
        help="Specify the base output directory",
    )
    parser.add_argument(
        "--batch_strategy",
        type=str,
        default="top-k",
        choices=["top-k", "eigen-decomposition"],
        help="Specify the batch strategy (top-k or eigen-decomposition)",
    )
    return parser


def main():
    """Command-line entry point: run one experiment configuration and save it.

    Parses the CLI arguments, prints the configuration, and writes results into
    the directory returned by ``create_output_directory_single``:

    * ``<strategy>-<batch_strategy>-queries.pkl``: every column, including the
      per-iteration queried indices.
    * ``<strategy>-<batch_strategy>.csv``: the same rows without ``query``.

    If the CSV already exists, the configuration is skipped and no wandb run is
    started.
    """
    args = _build_parser().parse_args()

    # Print the configuration
    print("Configuration:")
    print(f"Estimation Approach: {args.estimation_approach}")
    print(f"Oracle Type: {args.oracle_type}")
    print(f"Oracle Noise Level: {args.oracle_noise_level}")
    print(f"Dataset: {args.dataset}")
    print(f"Initialisation Type: {args.init}")
    print(f"Strategy: {args.strategy}")
    print(f"Batch Strategy: {args.batch_strategy}")
    print(f"Output Directory: {args.output_dir}")

    output_path = create_output_directory_single(
        args.output_dir,
        args.estimation_approach,
        args.oracle_type,
        args.oracle_noise_level,
        args.dataset,
        args.init,
    )
    csv_path = os.path.join(output_path, f"{args.strategy}-{args.batch_strategy}.csv")
    queries_path = os.path.join(
        output_path, f"{args.strategy}-{args.batch_strategy}-queries.pkl"
    )

    if os.path.exists(csv_path):
        # Skip if the file already exists
        print(f"Results already exist at {csv_path}")
        return

    # Start wandb only once we know the experiment will run. Previously every
    # skipped configuration still created an empty wandb run.
    wandb.init(entity="ada", project="aiau", config=vars(args))

    results_df = run_experiment(
        estimation_approach=args.estimation_approach,
        oracle_type=args.oracle_type,
        oracle_noise_level=args.oracle_noise_level,
        dataset=args.dataset,
        initialisation_type=args.init,
        strategy_name=args.strategy,
        batch_strategy=args.batch_strategy,
    )

    # The pickle keeps the per-iteration queries; the CSV drops that column.
    results_df.to_pickle(queries_path)
    results_df.drop(columns=["query"]).to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")


if __name__ == "__main__":
    main()
