import argparse
import ast
import os
from pathlib import Path

import torch
from yacs.config import CfgNode as CN

from .finetuning.dataset_config_finetuning import _C as dataset_cfg_finetuning
from .finetuning.model_config_finetuning import _C as model_cfg_finetuning
from .finetuning.train_config_finetuning import _C as train_cfg_finetuning


def _str_to_bool(text):
    """Parse a command-line boolean spelled as text.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is ``True`` for
    every non-empty string (including ``"False"``), so an explicit parser is
    required for boolean-valued options.

    Raises:
        argparse.ArgumentTypeError: if *text* is not a recognised spelling.
    """
    normalized = text.strip().lower()
    if normalized in ("true", "1", "yes"):
        return True
    if normalized in ("false", "0", "no", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def get_args():
    """Parse the command-line options that override the experiment config.

    Option names follow the ``<section>_<key>`` pattern (``model_*``,
    ``dataset_*``, ``train_*``); ``get_config`` splits each name on the first
    underscore to find the config entry to override. Options without a default
    are ``None`` when not supplied.

    Returns:
        argparse.Namespace: the parsed options, read from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-M",
        "--model_type",
        help="modify cfg.train.model.type",
        default="WavLM-Large",
        type=str,
    )
    parser.add_argument(
        "-d",
        "--dataset_database",
        help="specify the database used",
        default="iemocap",
        type=str,
    )
    parser.add_argument("-f", "--dataset_feature", help="specify the feature used", type=str)
    parser.add_argument("-e", "--train_epoch", help="total training epoch", type=int)
    parser.add_argument("-b", "--train_batch_size", help="training batch size", type=int)
    parser.add_argument("-l", "--train_lr", help="learning rate", type=float)
    parser.add_argument("-g", "--train_device_id", help="GPU ids", type=str)
    parser.add_argument("-s", "--train_seed", help="random seed", type=int)
    parser.add_argument(
        "-S",
        "--train_save_best",
        help="save model with the best performance",
        action="store_true",
    )
    parser.add_argument(
        "-p",
        "--train_patience",
        help="the patience used in the early stopping",
        default=15,
        type=int,
    )
    parser.add_argument("-r", "--model_output_rep", help="weighted sum or last layer", type=str)
    parser.add_argument("-m", "--mark", help="mark the current run", type=str)
    parser.add_argument("--dataset_num_workers", help="the number of workers", default=1, type=int)
    parser.add_argument("--train_warmup_epoch", help="set the warmup epoch", default=0.05, type=float)
    parser.add_argument("--train_resume", help="resume an experiment", type=str)
    parser.add_argument("--train_load_model", help="load a model", type=str)
    parser.add_argument("--train_device", help="run on cuda or cpu", default="cuda", type=str)
    parser.add_argument("--train_ds_size", help="Debug dataset size", type=int)
    parser.add_argument("--train_save_logs_nfs", help="Save ckpts to /app/data or to the /app/nfs", type=str)
    parser.add_argument(
        "--model_path_to_wavlm",
        help="initialize model with pre-trained WavLM",
        type=str,
    )
    # WANDB SWEEP PRETRAIN
    parser.add_argument("--train_freeze_cnn", help="freeze or tune cnn", type=str)
    parser.add_argument("--model_encoder_layers", help="WavLM size", type=int)
    parser.add_argument("--model_predictor_initialization", help="What predictor to use", type=str)
    # WANDB SWEEP FINETUNE
    parser.add_argument("--model_num_classes", help="Number of classes", type=int)
    parser.add_argument("--model_layer_used_for_inference", help="Layer number that is used for inference", type=int)
    parser.add_argument("--train_freeze_upstream", help="Freeze or tune wavlm", type=str)
    parser.add_argument("--dataset_meta_csv_file", help="Path to manifest", type=str)
    parser.add_argument("--train_weight_decay", help="Weight decay for Adam", type=float)
    parser.add_argument("--train_ce_weights", help="Comma-separated cross-entropy class weights", type=str)
    parser.add_argument("--model_dropout", help="Dropout in WavLM", type=float)
    parser.add_argument("--model_projector_dim", help="Proj dim", type=int)
    parser.add_argument("--model_dropout_input", help="Dropout before wavlm", type=float)
    parser.add_argument("--model_use_std", help="Use mean and std for classification if True", type=str)
    parser.add_argument("--model_attention_pooling", help="Use attention instead of mean pooling if True", type=str)
    parser.add_argument("--model_log_dir", help="Name of directory for logs", type=str)
    # SWEEP ELBO
    parser.add_argument("--train_kl_gamma", help="KL weight loss", type=float)
    parser.add_argument("--model_deep_model", help="Model type", type=str)
    parser.add_argument("--dataset_test_gender", help="gender for test", type=str)
    parser.add_argument("--dataset_eval_session", help="session for iemocap", type=str)
    parser.add_argument("--train_log_dir", help="Name of directory for logs", type=str)
    parser.add_argument("--model_prior_distribution", help="Prior distribution used in the KL term", type=str)
    parser.add_argument("--model_distribution_prediction", help="How layer distribution is predicted", type=str)
    parser.add_argument("--model_chi2_nc", help="Gamma for chi2 prior distribution", type=int)
    parser.add_argument("--model_backbone", help="data2vec or wavlm", type=str)
    # Kept as a string on purpose: get_config turns the text "True"/"False" into
    # a bool for this key, exactly as it does for the other freeze_* options.
    # (type=bool would turn "False" into True before get_config ever saw it.)
    parser.add_argument("--model_freeze_backbone", help="True if backbone is frozen", type=str)
    parser.add_argument("--train_accumulate_each_n_steps", help="Make zero_grad each n step", type=int)
    parser.add_argument("--model_dist_mlp", help="If True place mlp before attention", type=str)
    # get_config stores this value untouched, so it has to be parsed into a real
    # bool here.
    parser.add_argument(
        "--model_elbo_share_downstream_weights",
        help="If share downstream weights",
        type=_str_to_bool,
    )
    parser.add_argument("--train_val_batch_size", help="Val batch size", type=int)
    parser.add_argument("--model_p_for_geometric_pmf", help="Param for geometric prior distribution", type=float)
    parser.add_argument("--train_clip_grad_value", type=float)
    parser.add_argument("--train_save_model_val", type=str)
    parser.add_argument("--dataset_have_test_set", type=str)

    return parser.parse_args()


def create_workshop(cfg, mode, local_rank, fold=None):
    """Derive the run directory from the config and create it on rank 0.

    The directory name encodes the main hyper-parameters (output
    representation, projector dim, effective batch size, learning rate, weight
    decay and, where relevant, the ELBO / prior settings) followed by the
    validation session and test speaker. The result is stored in
    ``cfg.workshop``; ``cfg.ckpt_save_path`` is its ``checkpoint``
    sub-directory.

    Args:
        cfg: Config object exposing ``train``, ``model`` and ``dataset``
            sections. It is modified in place.
        mode: Not used by this function; kept for call compatibility.
        local_rank: Process rank; only rank 0 touches the filesystem.
        fold: Not used by this function; kept for call compatibility.

    Returns:
        The same ``cfg`` object, with ``workshop`` and ``ckpt_save_path`` set.

    Raises:
        ValueError: on rank 0, if the run directory already exists and
            ``cfg.train.resume`` is ``None``.
    """
    train_cfg = cfg.train
    model_cfg = cfg.model
    dataset_cfg = cfg.dataset

    output_rep = model_cfg.output_rep
    prior = model_cfg.prior_distribution

    # The run is named after train.batch_size times the visible device count.
    # On a CPU-only host device_count() is 0, which would always produce
    # "bs_0", so a single process is assumed in that case.
    world_size = max(torch.cuda.device_count(), 1)
    effective_batch = train_cfg.batch_size * world_size * train_cfg.accumulate_each_n_steps

    name_parts = [
        f"output_rep_{output_rep}_proj_dim_{model_cfg.projector_dim}",
        f"_bs_{effective_batch}_lr_{train_cfg.lr}_weight_decay_{train_cfg.weight_decay}",
    ]
    if output_rep == "elbo":
        name_parts.append(
            f"_kl_gamma{train_cfg.kl_gamma}_kl_prior_{prior}"
            f"_distribution_prediction_{model_cfg.distribution_prediction}"
        )
    if prior == "chi2":
        name_parts.append(f"_chi2_nc_{model_cfg.chi2_nc}")
    if prior == "geometric":
        name_parts.append(f"_model_p_for_geometric_pmf_{model_cfg.p_for_geometric_pmf}")

    suffix = (
        f"{''.join(name_parts)}"
        f"/val_session_{dataset_cfg.eval_session}_test_speaker_{dataset_cfg.test_gender}"
    )
    workshop = f"/app/nfs_small/{train_cfg.log_dir}/{suffix}"

    cfg.workshop = workshop
    # train.save_logs_nfs is not consulted here: checkpoints always live under
    # the workshop directory.
    cfg.ckpt_save_path = f"{workshop}/checkpoint"

    if local_rank == 0:
        if os.path.exists(workshop) and train_cfg.resume is None:
            raise ValueError(f"workshop {cfg.workshop} already existed.")
        # exist_ok lets a resumed run re-create any sub-directory that is
        # missing from an earlier, partially created workshop.
        for directory in (
            workshop,
            cfg.ckpt_save_path,
            str(Path(workshop) / "distribution_predicted"),
        ):
            os.makedirs(directory, exist_ok=True)
    return cfg


# Options that arrive from the command line as the text "True"/"False" and are
# stored in the config as real booleans.
_BOOLEAN_OPTIONS = frozenset(
    {
        "freeze_cnn",
        "freeze_upstream",
        "freeze_backbone",
        "use_std",
        "attention_pooling",
        "save_logs_nfs",
        "dist_mlp",
        "save_model_val",
        "have_test_set",
    }
)


def _coerce_override(option, value):
    """Convert a raw command-line value to the type stored in the config.

    ``None`` (option not supplied) is returned unchanged. Boolean options map
    the text "true" (any casing) to ``True`` and any other text to ``False``;
    values that are already ``bool`` pass through. ``ce_weights`` is parsed
    from a comma-separated string into a list of floats.
    """
    if value is None:
        return None
    if option in _BOOLEAN_OPTIONS:
        if isinstance(value, bool):
            return value
        return str(value).strip().lower() == "true"
    if option == "ce_weights":
        return [float(weight) for weight in value.split(",")]
    return value


def get_config(mode=""):
    """Build the experiment config from defaults plus command-line overrides.

    Defaults are looked up in the fine-tuning config tables by model name,
    database name, and model name + ``mode``. Every parsed option named
    ``<section>_<key>`` then overrides ``cfg.<section>.<key>`` when it was
    supplied; options that were not supplied are stored as ``None`` only when
    the key does not exist yet. Finally the WavLM size settings are fixed from
    the ``-Base`` / ``-Large`` suffix of ``--model_type`` and, with more than
    one visible GPU, the batch sizes are divided by the GPU count.

    Args:
        mode: Suffix appended to the model name for the train-config lookup.

    Returns:
        CfgNode: the assembled config with ``model``, ``dataset`` and
        ``train`` sections.

    Raises:
        ValueError: if ``--model_type`` has no ``-Base`` / ``-Large`` suffix
            or the suffix is not recognised.
    """
    args = get_args()
    cfg = CN(new_allowed=True)
    cfg.model = CN(new_allowed=True)
    cfg.dataset = CN(new_allowed=True)
    cfg.train = CN(new_allowed=True)

    # "<name>-<version>": the name selects the default configs, the version
    # selects the architecture size below. Without a "-" there is no version;
    # it stays None and is rejected by the ValueError further down instead of
    # surfacing as an unbound-variable error.
    type_parts = args.model_type.split("-")
    version = type_parts[1] if len(type_parts) > 1 else None
    args.model_type = type_parts[0]

    cfg.model.update(model_cfg_finetuning[args.model_type])
    cfg.dataset.update(dataset_cfg_finetuning[args.dataset_database])
    cfg.train.update(train_cfg_finetuning[args.model_type + mode])

    for name, raw_value in vars(args).items():
        section, separator, option = name.partition("_")
        if separator:
            node, key = cfg[section], option
            value = _coerce_override(option, raw_value)
        else:
            # Options without a section prefix (e.g. "mark") live at top level.
            node, key, value = cfg, name, raw_value
        # A supplied value always wins; an omitted one only fills in a key that
        # is missing so later attribute access does not fail.
        if value is not None or not hasattr(node, key):
            node[key] = value

    cfg.model.init_with_wavlm = True
    cfg.model.init_with_ckpt = not cfg.model.init_with_wavlm
    if version == "Base":
        cfg.model.encoder_layers = 12
        cfg.model.encoder_embed_dim = 768
        cfg.model.ffn_embed_dim = 3072
        cfg.model.num_heads = 12
        cfg.model.extractor_mode = "default"
        cfg.model.normalize = False
        cfg.model.normalize_before = False
    elif version == "Large":
        cfg.model.encoder_layers = 24
    else:
        raise ValueError(f"Unknown WavLM version: {version}")
    cfg.dataset.num_classes = cfg.model.num_classes

    # modify cfg.train.batch_size in the case of multi-GPUs training
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        # Clamp to 1: rounding a batch smaller than the GPU count gives 0.
        ddp_batch_size = max(1, round(cfg.train.batch_size / num_gpus))
        ddp_batch_size_val = max(1, round(cfg.train.val_batch_size / num_gpus))
        print(f"Modified batch size: {cfg.train.batch_size} -> {ddp_batch_size}.")
        print(f"Modified val batch size: {cfg.train.val_batch_size} -> {ddp_batch_size_val}.")
        cfg.train.batch_size = ddp_batch_size
        cfg.train.val_batch_size = ddp_batch_size_val
    print(f"Batch size: {cfg.train.batch_size}.")
    print(f"Batch size val: {cfg.train.val_batch_size}.")
    return cfg


def dict_2_list(dict):
    """Flatten a mapping into ``[key, value, key, value, ...]``.

    Pairs whose value is ``None`` are left out; the remaining pairs keep the
    mapping's iteration order.
    """
    return [
        item
        for key, value in dict.items()
        if value is not None
        for item in (key, value)
    ]
