import argparse

import numpy as np
import torch
import torch.nn as nn
import torchvision
from transformers import AutoTokenizer, AutoModel

import configs
from data.load import get_ada_cl
from eval.evaluate import test
from methods import (
    agem,
    er,
    expvae,
    finetune,
    learn,
    oracle,
    online_ewc,
    mir,
    gdumb,
    separate,
)
from models.llm import llm_base
from models.mlp import MLP, MLP_LLM, reduced_MLP_LLM
from models.resnet import (
    ResNet18,
    # reduced_ResNet18,
    # reduce_ResNet18_mini,
    ResNet18_mini,
)
from models.vae import VAE, VAE_miniimagenet, VAE_MLP, VAE_tinyimagenet
from utils import get_task_indices, get_adapt_weights, calc_adapt, save_results


# Named groups accepted by ``args.method``; each expands to several methods.
_METHOD_GROUPS = {
    "main": ("finetune", "oracle", "er", "learn"),
    "else": ("online_ewc", "mir", "gdumb", "agem", "expvae"),
    "else2": ("expvae", "gdumb"),
}


def _resolve_method_names(requested, registry):
    """Expand ``requested`` into the ordered list of method names to run.

    Args:
        requested: ``"all"``, one of the keys of ``_METHOD_GROUPS``, or a
            single key of ``registry``.
        registry: Mapping from method name to the class implementing it.

    Returns:
        A list of keys of ``registry``.

    Raises:
        ValueError: If ``requested`` is neither a group name nor a registered
            method name.
    """
    if requested == "all":
        return list(registry)
    if requested in _METHOD_GROUPS:
        return list(_METHOD_GROUPS[requested])
    if requested not in registry:
        raise ValueError(
            f"Unknown method: {requested!r}. Choose one of {sorted(registry)} "
            f"or a group in {['all', *_METHOD_GROUPS]}."
        )
    return [requested]


def _select_model_classes(dataset_name, method_cls):
    """Return the ``(model_class, gen_class)`` pair used for ``dataset_name``.

    Only ``clinc150`` depends on the method: Oracle, ExpVAE, LEARN and Separate
    get ``reduced_MLP_LLM`` while every other method gets ``MLP_LLM``. Dataset
    names not listed here fall back to ``ResNet18`` with ``VAE``.
    """
    if dataset_name == "miniimagenet":
        return ResNet18_mini, VAE_miniimagenet
    if dataset_name == "tinyimagenet":
        return ResNet18_mini, VAE_tinyimagenet
    if dataset_name == "cub":
        return MLP, VAE_MLP
    if dataset_name == "clinc150":
        uses_reduced = method_cls in (
            oracle.Oracle,
            expvae.ExpVAE,
            learn.LEARN,
            separate.Separate,
        )
        return (reduced_MLP_LLM if uses_reduced else MLP_LLM), VAE_MLP
    return ResNet18, VAE


def _build_base(dataset_name, device):
    """Build the module stored on ``model.base`` for ``dataset_name``.

    ``cub`` uses a pretrained torchvision ResNet-18 in eval mode with its
    ``fc`` layer replaced by an identity; ``clinc150`` wraps a
    ``distilroberta-base`` transformer in ``llm_base``; every other dataset
    gets a plain ``nn.Identity``.
    """
    if dataset_name == "cub":
        # NOTE(review): ``pretrained=True`` is reported as deprecated in newer
        # torchvision releases in favour of ``weights=``; kept as-is so the
        # script still runs on the version this project pins — confirm.
        base = torchvision.models.resnet18(pretrained=True).to(device)
        base.eval()
        base.fc = nn.Identity()
        return base
    if dataset_name == "clinc150":
        model_name = "distilroberta-base"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # The tokenizer is only used to derive the padding id for the model.
        tokenizer.pad_token_id = tokenizer.eos_token_id
        transformer = AutoModel.from_pretrained(model_name)
        transformer.config.pad_token_id = tokenizer.pad_token_id
        transformer.eval()
        return llm_base(transformer, device)
    return nn.Identity()


def run(args, verbose: bool = True):
    """Train and evaluate the selected method(s) over several seeds and save results.

    For every selected method, one model is built and trained per seed in
    ``args.seed, ..., args.seed + args.seed_len - 1``. Per-seed results are
    concatenated along dimension 1 and written once per method with
    ``save_results`` under
    ``./store/results/<cl_type>/<dataset_name>/<seed>``.

    Args:
        args: Namespace of hyperparameters. This function reads ``cuda``,
            ``seed``, ``seed_len``, ``cl_type``, ``method``, ``dataset_name``,
            ``n_classes``, ``n_tasks`` and ``gamma``; the whole namespace is
            also forwarded as keyword arguments to ``get_ada_cl`` and to the
            method constructor. ``args.n_repeats`` is overwritten with ``1``
            when ``args.cl_type == "no_repeat"``.
        verbose: Forwarded to training and evaluation; also controls whether
            the selected device is printed.

    Raises:
        ValueError: If ``args.method`` names no known method or group, or if
            ``args.seed_len`` yields no seeds (nothing could be saved).
    """
    device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
    # ``separate.Separate`` is intentionally not registered here, so it cannot
    # be selected by name.
    method_registry = {
        "oracle": oracle.Oracle,
        "er": er.ExperienceReplay,
        "finetune": finetune.Finetune,
        "learn": learn.LEARN,
        "online_ewc": online_ewc.OnlineEWC,
        "mir": mir.MIR,
        "gdumb": gdumb.Gdumb,
        "agem": agem.AGEM,
        "expvae": expvae.ExpVAE,
    }
    if verbose:
        print(device)

    # Validate the cheap inputs first so a typo fails before any data loading.
    method_names = _resolve_method_names(args.method, method_registry)
    seed_list = np.arange(args.seed_len) + args.seed
    if len(seed_list) == 0:
        raise ValueError(f"seed_len must be at least 1, got {args.seed_len!r}.")

    # Global seeding for data loading.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if device.type == "cuda":
        torch.cuda.manual_seed(args.seed)

    if args.cl_type == "no_repeat":
        args.n_repeats = 1

    ada_cl, test_cl_list = get_ada_cl(**vars(args))

    for name in method_names:
        method_cls = method_registry[name]
        model_class, gen_class = _select_model_classes(args.dataset_name, method_cls)

        task_indices = get_task_indices(
            ada_cl, n_class_per_task=args.n_classes // args.n_tasks
        )
        adapt_weights = get_adapt_weights(task_indices, args.gamma)

        adapt_acc, adapt_loss = [], []
        # Per-seed results are appended as new columns (dimension 1).
        avg_acc = avg_loss = test_acc_list = None
        method_name = None

        for run_idx, seed in enumerate(seed_list):
            np.random.seed(seed)
            torch.manual_seed(seed)
            model = method_cls(
                **vars(args),
                model_class=model_class,
                gen_class=gen_class,
                device=device,
            )
            model.base = _build_base(args.dataset_name, device)

            method_name = model.method_name
            if run_idx == 0:
                print("-" * 30, method_name, "-" * 30)
            print(f"Training... seed: {seed}")
            model.train(ada_cl, verbose)

            avg_acc_, avg_loss_, acc_list_, loss_list_ = model.get_results()
            avg_acc = avg_acc_ if avg_acc is None else torch.cat((avg_acc, avg_acc_), 1)
            avg_loss = (
                avg_loss_ if avg_loss is None else torch.cat((avg_loss, avg_loss_), 1)
            )

            # Evaluate on the held-out loaders; one column per seed.
            test_acc_list_ = test(
                cl_type=args.cl_type,
                model=model.get_models(),
                dataloader_list=test_cl_list,
                n_tasks=args.n_tasks,
                n_classes=args.n_classes,
                verbose=verbose,
                base=model.base,
            ).view(-1, 1)
            test_acc_list = (
                test_acc_list_
                if test_acc_list is None
                else torch.cat((test_acc_list, test_acc_list_), 1)
            )

            adapt_acc.append(calc_adapt(acc_list_, adapt_weights))
            adapt_loss.append(calc_adapt(loss_list_, adapt_weights))

            print(avg_acc.view(-1)[-1])
            print(test_acc_list.mean(1))

        save_results(
            f"./store/results/{args.cl_type}/{args.dataset_name}/{args.seed}",
            avg_acc,
            avg_loss,
            test_acc_list,
            adapt_acc,
            adapt_loss,
            name_prefix=method_name,
        )


def _str_to_bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` is unsuitable for argparse because ``bool("False")`` is
    ``True``: every non-empty string is truthy. This parser accepts the usual
    spellings case-insensitively instead.

    Args:
        value: The raw token (or an actual ``bool``, returned unchanged).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string is kept falsy, matching what ``bool("")`` used to give.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def parse_args(argv=None):
    """Parse the experiment's command-line options.

    Every option except ``--config`` and ``--method`` defaults to ``None``.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with one attribute per option.
    """
    parser = argparse.ArgumentParser(description="Run experiments.")
    parser.add_argument(
        "--config",
        type=str,
        default="cifar10",
        help="The name of the configuration to use.",
    )
    parser.add_argument("--method_name", type=str, help="Name of the experiment.")
    parser.add_argument("--cl_type", type=str, help="CL type.")
    parser.add_argument("--n_tasks", type=int, help="Number of tasks.")
    parser.add_argument("--n_classes", type=int, help="Number of classes.")
    parser.add_argument("--n_repeats", type=int, help="Number of repeats.")
    parser.add_argument("--prune_min", type=float, help="Prune min.")
    parser.add_argument("--patience_threshold", type=float, help="Patience threshold.")
    parser.add_argument("--batch_size", type=int, help="Batch size.")
    parser.add_argument("--gamma", type=float, help="Gamma.")
    parser.add_argument("--mixing", type=float, help="Mixing.")
    parser.add_argument("--lr", type=float, help="Learning rate.")
    parser.add_argument("--buffer_size", type=int, help="Buffer size.")
    parser.add_argument("--eta", type=int, help="Eta.")
    parser.add_argument("--agem_batch_size", type=int, help="AGEM batch size.")
    # Parsed explicitly so that ``--cuda False`` really disables CUDA.
    parser.add_argument("--cuda", type=_str_to_bool, help="CUDA.")
    parser.add_argument("--seed", type=int, help="Seed.")
    parser.add_argument("--seed_len", type=int, help="Seed length.")
    parser.add_argument("--method", type=str, help="Method name.", default="gdumb")
    return parser.parse_args(argv)


def main():
    """Merge command-line overrides into the chosen configuration and call ``run``.

    The configuration is looked up as an attribute of the ``configs`` module
    named by ``--config``. Every other command-line option that was given
    (i.e. is not ``None``) overrides the entry of the same name.

    Raises:
        ValueError: If ``configs`` has no such attribute, or the attribute
            cannot be converted to a dict.
    """
    cmd_args = parse_args()

    selected = getattr(configs, cmd_args.config, None)
    if selected is None:
        raise ValueError(f"No such configuration: {cmd_args.config}")
    try:
        # Work on a copy so overrides never leak into the shared module-level
        # configuration object.
        config = dict(selected)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"Configuration {cmd_args.config!r} is not a mapping of hyperparameters"
        ) from err

    # Options left at ``None`` were not supplied and must not mask the config.
    overrides = {
        key: value
        for key, value in vars(cmd_args).items()
        if key != "config" and value is not None
    }
    config.update(overrides)

    args = argparse.Namespace(**config)

    run(args, verbose=True)


# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
