import argparse
import ast
import os
from pathlib import Path

import torch
from yacs.config import CfgNode as CN

from .finetuning.dataset_config_finetuning import _C as dataset_cfg_finetuning
from .finetuning.model_config_finetuning import _C as model_cfg_finetuning
from .finetuning.train_config_finetuning import _C as train_cfg_finetuning


def get_args():
    """Parse the command-line options of a training run.

    Option names carry a ``model_`` / ``dataset_`` / ``train_`` prefix naming the
    config section they belong to (``--mark`` has no prefix). Options that are not
    given are ``None`` unless an explicit default is set below.

    Most on/off options (``--train_freeze_cnn``, ``--model_dist_mlp``, ...) are read
    as the strings ``"True"``/``"False"`` and returned as strings.
    ``--model_elbo_share_downstream_weights`` is parsed directly into a bool.

    Returns:
        argparse.Namespace: the parsed arguments.
    """

    def str2bool(text):
        # argparse's ``type=bool`` would call bool("False") -> True, so any value
        # given on the command line would switch the flag on. Parse the text instead.
        lowered = text.strip().lower()
        if lowered in ("true", "1", "yes"):
            return True
        if lowered in ("false", "0", "no"):
            return False
        raise argparse.ArgumentTypeError(f"expected True or False, got {text!r}")

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-M",
        "--model_type",
        help="model type, optionally with a version suffix (e.g. WavLM-Large)",
        default="WavLM-Large",
        type=str,
    )
    parser.add_argument(
        "-d",
        "--dataset_database",
        help="specify the database used",
        default="voxceleb1",
        type=str,
    )
    parser.add_argument("-f", "--dataset_feature", help="specify the feature used", type=str)
    parser.add_argument("-e", "--train_EPOCH", help="total training epoch", type=int)
    parser.add_argument("-b", "--train_batch_size", help="training batch size", type=int)
    parser.add_argument("-l", "--train_lr", help="learning rate", type=float)
    parser.add_argument("-g", "--train_device_id", help="GPU ids", type=str)
    parser.add_argument("-s", "--train_seed", help="random seed", type=int)
    parser.add_argument("--train_objective", help="ce or aamsoftmax", type=str)
    parser.add_argument("--train_ckpt_save_path", help="ckpt save path", type=str)
    parser.add_argument(
        "-S",
        "--train_save_best",
        help="save model with the best performance",
        action="store_true",
    )
    parser.add_argument(
        "-p",
        "--train_patience",
        help="the patience used in the early stopping",
        default=15,
        type=int,
    )
    parser.add_argument("-r", "--model_output_rep", help="weighted sum or last layer", type=str)
    parser.add_argument("-m", "--mark", help="mark the current run", type=str)
    parser.add_argument("--dataset_num_workers", help="the number of workers", default=2, type=int)
    parser.add_argument("--train_warmup_epoch", help="set the warmup epoch", default=0.05, type=float)
    parser.add_argument("--train_resume", help="resume an experiment", type=str)
    parser.add_argument("--train_load_model", help="load a model", type=str)
    parser.add_argument("--train_device", help="run on cuda or cpu", default="cuda", type=str)
    parser.add_argument(
        "--model_path_to_vesper",
        help="initialize model with WavLM's checkpoint",
        type=str,
    )
    parser.add_argument(
        "--model_path_to_wavlm",
        help="initialize model with pre-trained WavLM",
        type=str,
    )
    # WANDB SWEEP PRETRAIN
    parser.add_argument("--train_freeze_cnn", help="freeze or tune cnn", type=str)
    parser.add_argument("--model_init_style", help="strategy to initialize WavLM", type=str)
    parser.add_argument("--model_encoder_layers", help="WavLM size", type=int)
    parser.add_argument("--model_predictor_initialization", help="What predictor to use", type=str)
    # WANDB SWEEP FINETUNE
    parser.add_argument("--model_num_classes", help="Number of classes", type=int)
    parser.add_argument("--train_freeze_upstream", help="Freeze or tune wavlm", type=str)
    parser.add_argument("--dataset_meta_csv_file", help="Path to manifest", type=str)
    parser.add_argument("--train_weight_decay", help="Weight decay for Adam", type=float)
    parser.add_argument(
        "--train_ce_weights",
        help="Comma-separated per-class weights for the cross-entropy loss",
        type=str,
    )
    parser.add_argument("--model_dropout", help="Dropout in WavLM", type=float)
    parser.add_argument("--model_projector_dim", help="Proj dim", type=int)
    parser.add_argument("--model_dropout_input", help="Dropout before wavlm", type=float)
    parser.add_argument("--model_use_std", help="Use mean and std for classification if True", type=str)
    parser.add_argument("--model_attention_pooling", help="Use attention instead of mean pooling is True", type=str)
    parser.add_argument("--model_log_dir", help="Name of directory for logs", type=str)
    parser.add_argument("--model_chi2_nc", help="Gamma parameter for chi2 distribution", type=int)
    # SWEEP ELBO
    parser.add_argument("--train_kl_gamma", help="KL weight loss", type=float)
    parser.add_argument("--model_deep_model", help="Model type", type=str)
    # These two help texts were swapped before.
    parser.add_argument("--dataset_test_gender", help="gender for test", type=str)
    parser.add_argument("--dataset_eval_session", help="session for iemocap", type=str)
    parser.add_argument("--train_log_dir", help="Name of directory for logs", type=str)
    parser.add_argument("--model_prior_distribution", help="Prior distribution used in the KL term", type=str)
    parser.add_argument("--model_distribution_prediction", help="How layer distribution is predicted", type=str)
    parser.add_argument("--model_layer_used_for_inference", help="Layer used for inference", type=int)
    # Kept as a "True"/"False" string, like the other on/off sweep options;
    # type=bool would turn "False" into True.
    parser.add_argument("--model_dist_mlp", help="If True place mlp before attention", type=str)
    parser.add_argument(
        "--model_elbo_share_downstream_weights",
        help="If True using the same classifier weights",
        type=str2bool,
    )
    parser.add_argument("--train_save_best_model", help="If True best model is saved for test", type=str)
    parser.add_argument(
        "--dataset_have_test_set", help="If True running test set, train_save_best_model should also be True", type=str
    )
    parser.add_argument("--train_wandb_mode", help="offline if not logging to wandb now", type=str)
    return parser.parse_args()


def _ensure_fresh_dir(path, label, resume):
    """Create directory ``path``, or reuse it only when resuming a run.

    Args:
        path: Directory to create.
        label: Human-readable name of the directory used in the error message.
        resume: ``cfg.train.resume``; when not ``None`` an existing directory is reused.

    Raises:
        ValueError: if ``path`` already exists and ``resume`` is ``None``.
    """
    if os.path.exists(path):
        if resume is None:
            raise ValueError(f"{label} {path} already existed.")
    else:
        os.makedirs(path)


def create_workshop(cfg, mode, local_rank, fold=None):
    """Derive the run's output directory from its hyper-parameters and create it.

    The directory is ``os.path.join(cfg.train.logs_save_path, "<log_dir>/<suffix>")``.
    ``log_dir`` is ``cfg.train.log_dir``. The suffix encodes the seed, the global
    batch size, the projector dim, the learning rate and the output representation.
    It also includes the KL / distribution settings when
    ``cfg.model.output_rep == "elbo"``, and ends with ``_mark_<cfg.mark>`` when a
    mark is set.

    Sets ``cfg.workshop`` and ``cfg.ckpt_save_path`` (``<workshop>/checkpoints``).
    On ``local_rank == 0`` it also creates the workshop, checkpoint and
    ``distribution_predicted`` directories.

    Args:
        cfg: Experiment config (mutated in place).
        mode: Unused; kept for interface compatibility.
        local_rank: Process rank; only rank 0 touches the filesystem.
        fold: Unused; kept for interface compatibility.

    Returns:
        The same ``cfg`` object.

    Raises:
        ValueError: if one of the directories already exists and
            ``cfg.train.resume`` is ``None``.
    """
    output_rep = cfg.model.output_rep

    # cfg.train.batch_size holds the per-process batch size. Report the global one.
    # device_count() is 0 on CPU-only hosts; without the max() the name would say bs_0.
    world_size = max(torch.cuda.device_count(), 1)
    batch = cfg.train.batch_size * world_size

    suffix = (
        f"seed_{cfg.train.seed}_bs_{batch}_proj_dim_{cfg.model.projector_dim}"
        f"_lr_{cfg.train.lr}_output_rep_{output_rep}"
    )
    if output_rep == "elbo":
        suffix += (
            f"_kl_gamma{cfg.train.kl_gamma}_kl_prior_{cfg.model.prior_distribution}"
            f"_distribution_prediction_{cfg.model.distribution_prediction}"
            f"_chi2_nc_{cfg.model.chi2_nc}"
        )
    config_name = os.path.join(cfg.train.logs_save_path, f"{cfg.train.log_dir}/{suffix}")

    if cfg.mark is not None:
        config_name = config_name + "_mark_{}".format(cfg.mark)

    cfg.workshop = config_name
    cfg.ckpt_save_path = os.path.join(cfg.workshop, "checkpoints")

    if local_rank == 0:
        resume = cfg.train.resume
        _ensure_fresh_dir(cfg.workshop, "workshop", resume)
        _ensure_fresh_dir(cfg.ckpt_save_path, "cfg.ckpt_save_path", resume)
        _ensure_fresh_dir(
            str(Path(cfg.workshop) / "distribution_predicted"),
            "Path(cfg.workshop) / distribution_predicted",
            resume,
        )
    return cfg


def get_config(mode="_finetune"):
    """Build the experiment config from the finetuning defaults and CLI overrides.

    Steps:

    1. Load the defaults for the model type, the dataset and ``<model type> + mode``.
    2. Overlay every command-line argument. An argument named ``<section>_<field>``
       is written to ``cfg.<section>.<field>``; a prefix-less argument (``mark``)
       is written to ``cfg.<name>``. An argument left at ``None`` only fills a
       field that does not exist yet.
    3. Apply the WavLM size presets.
    4. When several GPUs are visible, divide ``cfg.train.batch_size`` among them.

    Args:
        mode: Suffix appended to the model type to pick the training defaults.
            For non-WavLM models, ``"pretrain"`` in ``mode`` switches
            ``init_with_wavlm`` on.

    Returns:
        The populated ``CfgNode``.

    Raises:
        ValueError: if the model type is WavLM without a ``-Base`` / ``-Large`` suffix.
    """
    # On/off options that arrive as "True"/"False" strings.
    bool_fields = {
        "freeze_cnn",
        "freeze_upstream",
        "use_std",
        "attention_pooling",
        "dist_mlp",
        "save_best_model",
        "have_test_set",
    }

    args = get_args()
    print(f"DATASET: {args.dataset_database}")
    cfg = CN(new_allowed=True)
    cfg.model = CN(new_allowed=True)
    cfg.dataset = CN(new_allowed=True)
    cfg.train = CN(new_allowed=True)

    # "WavLM-Large" -> type "WavLM", version "Large". Without a suffix the version
    # stays None, so a bare "WavLM" reaches the explicit ValueError below instead
    # of crashing with UnboundLocalError.
    version = None
    type_parts = args.model_type.split("-")
    if len(type_parts) > 1:
        args.model_type, version = type_parts[0], type_parts[1]
    is_wavlm = args.model_type == "WavLM"

    cfg.model.update(model_cfg_finetuning[args.model_type])
    cfg.dataset.update(dataset_cfg_finetuning[args.dataset_database])
    cfg.train.update(train_cfg_finetuning[args.model_type + mode])

    for key, value in vars(args).items():
        key_list = key.split("_", maxsplit=1)
        if len(key_list) == 1:
            if value is not None or not hasattr(cfg, key):
                cfg[key] = value
            continue

        section, field = key_list
        if value is not None:
            if field in bool_fields:
                # Accept a bool already produced by argparse as well as the "True" string.
                value = value is True or value == "True"
            elif field == "ce_weights":
                value = list(map(float, value.split(",")))
        if value is not None or not hasattr(cfg[section], field):
            cfg[section][field] = value

    if is_wavlm:
        cfg.model.init_with_wavlm = True
        if version == "Base":
            cfg.model.encoder_layers = 12
            cfg.model.encoder_embed_dim = 768
            cfg.model.ffn_embed_dim = 3072
            cfg.model.num_heads = 12
            cfg.model.extractor_mode = "default"
            cfg.model.normalize = False
            cfg.model.normalize_before = False
        elif version == "Large":
            cfg.model.encoder_layers = 24
        else:
            raise ValueError(f"Unknown WavLM version: {version}")
    else:
        cfg.model.init_with_wavlm = "pretrain" in mode
    cfg.model.init_with_ckpt = not cfg.model.init_with_wavlm

    cfg.dataset.num_classes = cfg.model.num_classes

    # Multi-GPU training: cfg.train.batch_size becomes the per-process batch size.
    # Clamp at 1 so a small batch never rounds down to 0.
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        ddp_batch_size = max(1, round(cfg.train.batch_size / num_gpus))
        print(f"Modified batch size: {cfg.train.batch_size} -> {ddp_batch_size}.")
        cfg.train.batch_size = ddp_batch_size
    return cfg


def dict_2_list(dict):
    """Flatten a mapping into ``[key1, value1, key2, value2, ...]``.

    Entries whose value is ``None`` are dropped; the remaining pairs keep the
    mapping's iteration order.
    """
    return [
        item
        for key, value in dict.items()
        if value is not None
        for item in (key, value)
    ]
