import argparse
from copy import deepcopy

import numpy as np
import torch
from torch import nn

import configs
from data.load import get_ada_cl
from eval.evaluate import test
from eval.learn_ablation import LEARN_ablation
from models.mlp import MLP
from models.resnet import (
    ResNet18,
    # reduced_ResNet18,
    # reduce_ResNet18_mini,
    ResNet18_mini,
)
from models.vae import VAE, VAE_miniimagenet, VAE_MLP, VAE_tinyimagenet
from utils import get_task_indices, get_adapt_weights, calc_adapt, save_results


def run_ablations(
    args, patience_region, mixing_region, verbose: bool = True, seed_len: int = 1
):
    """Run hyperparameter and architecture ablations of LEARN on one dataset.

    First sweeps ``patience_threshold`` over ``patience_region`` and ``mixing``
    over ``mixing_region`` (one hyperparameter changed at a time, the other kept
    at its value in ``args``). Then runs LEARN three more times, each with
    exactly one of the exploring / recalling / refining components disabled.
    Each run's results are written with ``save_results`` under
    ``./store/results/ablation/<dataset_name>/``.

    Args:
        args: experiment namespace. It must provide at least ``dataset_name``,
            ``seed``, ``cuda``, ``patience_threshold`` and ``mixing``, plus
            whatever ``get_ada_cl`` / ``run_LEARN`` read. It is not modified.
        patience_region: iterable of patience thresholds to try.
        mixing_region: iterable of mixing values to try.
        verbose: print progress and forward verbosity to every ``run_LEARN`` call.
        seed_len: number of seeds each configuration is run with. It is stored
            as ``seed_len`` on the copied namespace that ``run_LEARN`` reads.
    """
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")

    if verbose:
        print(device)
        print("-" * 20, f"Ablation for {args.dataset_name}", "-" * 20)
        print(f"patience_region: {patience_region}, mixing_region: {mixing_region}")

    # Seed before loading data; presumably this fixes any randomness inside
    # get_ada_cl (run_LEARN reseeds separately for every training run).
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if device.type == "cuda":
        torch.cuda.manual_seed(args.seed)

    # Load data with the caller's namespace exactly as given.
    ada_cl, test_cl_list = get_ada_cl(**vars(args))

    # Copy so the caller's namespace is untouched. Also make the seed_len
    # argument actually reach run_LEARN, which reads it from args.seed_len.
    base_args = deepcopy(args)
    base_args.seed_len = seed_len

    # Hyperparameter ablation: vary one hyperparameter at a time.
    args_list = []
    for patience in patience_region:
        args_ = deepcopy(base_args)
        args_.patience_threshold = patience
        args_list.append(args_)
    for mixing in mixing_region:
        args_ = deepcopy(base_args)
        args_.mixing = mixing
        args_list.append(args_)

    for args_ in args_list:
        results = run_LEARN(
            args_,
            ada_cl=ada_cl,
            test_cl_list=test_cl_list,
            device=device,
            verbose=verbose,
        )
        save_results(
            f"./store/results/ablation/{base_args.dataset_name}/patience_{args_.patience_threshold}_mixing_{args_.mixing}",
            **results,
        )

    # Architecture ablation: disable exactly one LEARN component per run.
    all_flags = ("is_exploring", "is_recalling", "is_refining")
    for exp_type in ("disable_exploring", "disable_recalling", "disable_refining"):
        # "disable_exploring" -> "is_exploring", etc.
        disabled_flag = "is_" + exp_type.split("_")[1]
        flags = {flag: flag != disabled_flag for flag in all_flags}
        if verbose:
            print(f"Running experiment: {exp_type}")
        results = run_LEARN(
            base_args,
            **flags,
            ada_cl=ada_cl,
            test_cl_list=test_cl_list,
            device=device,
            verbose=verbose,
        )
        save_results(
            f"./store/results/ablation/{base_args.dataset_name}/{exp_type}", **results
        )


def run_LEARN(
    args,
    ada_cl,
    test_cl_list,
    device,
    is_exploring=True,
    is_recalling=True,
    is_refining=True,
    verbose: bool = True,
):
    """Train and evaluate ``LEARN_ablation`` once per seed in ``range(args.seed_len)``.

    The classifier backbone and generator are chosen from ``args.dataset_name``;
    unlisted datasets fall back to ``ResNet18`` + ``VAE``. For each seed the
    model is trained on ``ada_cl`` and then evaluated with ``test`` on
    ``test_cl_list``. Adaptation scores are computed with ``calc_adapt``.

    Args:
        args: experiment namespace (reads ``dataset_name``, ``n_classes``,
            ``n_tasks``, ``gamma``, ``seed_len``, ``cl_type``,
            ``patience_threshold``, ``mixing``). All of its attributes are also
            forwarded to ``LEARN_ablation``.
        ada_cl: training data passed to ``get_task_indices`` and ``model.train``.
        test_cl_list: list of test dataloaders passed to ``test``.
        device: torch device forwarded to ``LEARN_ablation``.
        is_exploring, is_recalling, is_refining: component switches forwarded
            to ``LEARN_ablation``.
        verbose: print progress and forward verbosity to training/evaluation.

    Returns:
        dict with keys ``"avg_acc"``, ``"avg_loss"``, ``"test_acc_list"`` and
        ``"adapt_acc"``, ``"adapt_loss"``. For a single seed the first three are
        that run's tensors as-is; for several seeds they are concatenated along
        dim 1. The last two are lists holding one ``calc_adapt`` result per seed.

    Raises:
        ValueError: if ``args.seed_len`` is less than 1 (nothing to run).
    """
    if args.seed_len < 1:
        raise ValueError(f"seed_len must be >= 1, got {args.seed_len}")

    # Backbone / generator per dataset; anything else uses ResNet18 + VAE.
    model_class, gen_class = {
        "miniimagenet": (ResNet18_mini, VAE_miniimagenet),
        "tinyimagenet": (ResNet18_mini, VAE_tinyimagenet),
        "cub": (MLP, VAE_MLP),
    }.get(args.dataset_name, (ResNet18, VAE))

    task_indices = get_task_indices(
        ada_cl, n_class_per_task=args.n_classes // args.n_tasks
    )
    adapt_weights = get_adapt_weights(task_indices, args.gamma)

    # Collect per-seed tensors and concatenate once at the end.
    avg_acc_runs, avg_loss_runs, test_acc_runs = [], [], []
    adapt_acc, adapt_loss = [], []
    for i, seed in enumerate(np.arange(args.seed_len)):
        np.random.seed(seed)
        torch.manual_seed(seed)
        model = LEARN_ablation(
            **vars(args),
            is_exploring=is_exploring,
            is_recalling=is_recalling,
            is_refining=is_refining,
            model_class=model_class,
            gen_class=gen_class,
            device=device,
        )
        # The base is replaced by an identity and the same object is passed to
        # test() as base= below. Presumably this means inputs reach the models
        # unchanged — confirm against LEARN_ablation.
        model.base = nn.Identity()
        if verbose:
            if i == 0:
                print(
                    f"patience_threshold: {args.patience_threshold}, mixing: {args.mixing}"
                )
            print(f"Training... seed: {seed}")

        # train
        model.train(ada_cl, verbose)
        avg_acc_, avg_loss_, acc_list_, loss_list_ = model.get_results()
        avg_acc_runs.append(avg_acc_)
        avg_loss_runs.append(avg_loss_)

        # eval: one column of test accuracies per seed
        test_acc_runs.append(
            test(
                cl_type=args.cl_type,
                model=model.get_models(),
                dataloader_list=test_cl_list,
                n_tasks=args.n_tasks,
                n_classes=args.n_classes,
                verbose=verbose,
                base=model.base,
            ).view(-1, 1)
        )

        # adapt
        adapt_acc.append(calc_adapt(acc_list_, adapt_weights))
        adapt_loss.append(calc_adapt(loss_list_, adapt_weights))

    def _merge_runs(runs):
        # A single run is returned as-is; several runs are joined along dim 1.
        return runs[0] if len(runs) == 1 else torch.cat(runs, dim=1)

    return {
        "avg_acc": _merge_runs(avg_acc_runs),
        "avg_loss": _merge_runs(avg_loss_runs),
        "test_acc_list": _merge_runs(test_acc_runs),
        "adapt_acc": adapt_acc,
        "adapt_loss": adapt_loss,
    }


def main(config, seed_len, patience_region=None, mixing_region=None):
    """Run the ablation suite for every dataset selected by ``config``.

    Args:
        config: ``"all"``, a group name (``"group1"`` / ``"group2"``, also
            accepted as ``"group_1"`` / ``"group_2"``), or the name of a single
            dataset configuration attribute in ``configs``.
        seed_len: number of seeds per setting, forwarded to ``run_ablations``.
        patience_region: patience thresholds to sweep (default ``[1, 5, 10]``).
        mixing_region: mixing values to sweep (default ``[0.1, 0.3, 0.4, 0.5]``).
    """
    # None sentinels avoid shared mutable default arguments.
    if patience_region is None:
        patience_region = [1, 5, 10]
    if mixing_region is None:
        mixing_region = [0.1, 0.3, 0.4, 0.5]

    groups = {
        "all": ["cifar10", "cifar100", "miniimagenet", "tinyimagenet"],
        "group1": ["cifar10", "tinyimagenet"],
        "group2": ["cifar100", "miniimagenet"],
    }
    # Accept the underscored spellings too, so "group_1" / "group_2" resolve
    # to the groups instead of being looked up as a dataset config.
    groups["group_1"] = groups["group1"]
    groups["group_2"] = groups["group2"]
    dataset_list = groups.get(config, [config])

    for dataset_name in dataset_list:
        # configs.<name> must be a mapping: it is expanded into a Namespace.
        dataset_config = getattr(configs, dataset_name)
        args = argparse.Namespace(**dataset_config)
        run_ablations(
            args,
            patience_region=patience_region,
            mixing_region=mixing_region,
            seed_len=seed_len,
        )


if __name__ == "__main__":
    # CLI entry point: pick which dataset configs to ablate and how many seeds.
    parser = argparse.ArgumentParser(description="Run experiments.")
    parser.add_argument(
        "--config",
        type=str,
        # Must be a name main() understands: "all", "group1", "group2",
        # or a single dataset config name.
        default="group1",
        help="The name of the configuration to use.",
    )
    parser.add_argument(
        "--seed_len",
        type=int,
        default=1,
        help="Number of random seeds to run for each ablation setting.",
    )
    args = parser.parse_args()

    print(args.config, args.seed_len)
    main(args.config, args.seed_len)
