import argparse
import ast
import os
from pathlib import Path

import torch
from yacs.config import CfgNode as CN

from .finetuning.dataset_config_finetuning import _C as dataset_cfg_finetuning
from .finetuning.model_config_finetuning import _C as model_cfg_finetuning
from .finetuning.train_config_finetuning import _C as train_cfg_finetuning


def _str_to_bool(text):
    """Convert a command-line token into a real ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool(text)``, which is True for
    every non-empty string (including ``"False"``). This converter interprets
    the usual spellings instead. The empty string maps to False, which is what
    ``bool("")`` produced before.

    Raises:
        argparse.ArgumentTypeError: if *text* is not a recognised boolean.
    """
    lowered = text.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {text!r}")


def get_args(argv=None):
    """Parse the command-line options of the fine-tuning scripts.

    Option names are prefixed with ``dataset_``, ``train_`` or ``model_``.
    Options without an explicit default are left as ``None`` when they are
    not given on the command line.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # Dataset arguments
    parser.add_argument(
        "-d",
        "--dataset_database",
        help="specify the database used",
        default="librispeech",
        type=str,
    )
    parser.add_argument("--dataset_num_workers", help="the number of workers", default=2, type=int)
    parser.add_argument("--dataset_train_folder", help="Path to the train dataset folder", type=str)
    parser.add_argument("--dataset_eval_folder", help="Stem of validation dataset", type=str)
    parser.add_argument("--dataset_test_folder", help="Stem of test dataset", type=str)
    parser.add_argument("--dataset_bucket_file", help="Path to the bucket files", type=str)
    # Training arguments
    parser.add_argument("-e", "--train_epoch", help="total training epoch", type=int)
    parser.add_argument("-b", "--train_batch_size", help="training batch size", type=int)
    parser.add_argument("--train_val_batch_size", help="validation batch size", type=int)
    parser.add_argument("-l", "--train_lr", help="learning rate", type=float)
    parser.add_argument("-g", "--train_device_id", help="GPU ids", type=str)
    parser.add_argument("-s", "--train_seed", help="random seed", type=int)
    parser.add_argument(
        "-p",
        "--train_patience",
        help="the patience used in the early stopping",
        default=3,
        type=int,
    )
    parser.add_argument("--train_warmup_steps", help="number of steps lr is raising from min_lr to lr", type=int)
    parser.add_argument("--train_resume", help="path to the .ckpt file if resume training / for inference", type=str)
    parser.add_argument("--train_device", help="run on cuda or cpu", default="cuda", type=str)
    parser.add_argument("--train_wandb_mode", help="log to wandb online or offline", type=str)
    parser.add_argument("--train_wandb_project", help="wandb project name", type=str)
    parser.add_argument("--train_freeze_upstream", help="Freeze or tune WavLM", type=str)
    parser.add_argument("--train_weight_decay", help="Weight decay for Adam", type=float)
    parser.add_argument("--train_kl_gamma", help="KL weight loss", type=float)
    parser.add_argument("--train_log_dir", help="Name of directory for logs", type=str)
    parser.add_argument("--train_ds_size", help="Debug dataset size", type=int)
    # Boolean options use _str_to_bool: type=bool would turn "False" into True.
    parser.add_argument(
        "--train_disable_kl_loss", help="If train with KL loss or without even with ELBO", type=_str_to_bool
    )
    parser.add_argument("--train_validation_metric", help="wer or loss", type=str)
    parser.add_argument(
        "--train_prior_weight_ctc_loss", help="use posterior or prior distribution weight for ctc layer loss", type=str
    )
    parser.add_argument(
        "--train_save_logs_nfs", help="Save ckpts to /app/data or to the /app/nfs", type=_str_to_bool
    )
    # Model arguments
    parser.add_argument("-r", "--model_output_rep", help="weighted sum or last layer", type=str)
    parser.add_argument("--model_init_with_ckpt", help="if init model with ckpt", type=_str_to_bool)
    parser.add_argument(
        "--model_path_to_wavlm",
        help="initialize model with pre-trained WavLM",
        type=str,
    )
    parser.add_argument(
        "-M",
        "--model_type",
        help="modify cfg.train.model.type",
        default="WavLM-Large",
        type=str,
    )
    parser.add_argument("--model_encoder_layers", help="WavLM size", type=int)
    parser.add_argument("--model_dropout", help="Dropout in WavLM", type=float)
    parser.add_argument("--model_projector_dim", help="Proj dim", type=int)
    parser.add_argument("--model_dropout_input", help="Dropout before WavLM", type=float)
    parser.add_argument("--model_log_dir", help="Name of directory for logs", type=str)
    parser.add_argument("--model_deep_model", help="Model type", type=str)
    parser.add_argument("--model_prior_distribution", help="Prior distribution (e.g. geometric)", type=str)
    parser.add_argument("--model_p_for_geometric_pmf", help="P for geometric distribution", type=float)
    parser.add_argument("--model_layer_position_encoding", help="Use or not position encoding", type=str)
    parser.add_argument(
        "--model_distribution_prediction",
        help="How layer distribution is predicted",
        type=str,
    )
    # ASR
    # rnn_dim / rnn_dropout stay strings here; they are given as space-separated lists.
    parser.add_argument("--model_rnn_dim", help="RNN hidden size", type=str)
    parser.add_argument("--model_rnn_dropout", help="RNN dropout", type=str)
    parser.add_argument("--model_beam_threshold", help="Probability under which", type=int)
    parser.add_argument("--model_beam", help="Number of beam", type=int)
    parser.add_argument("--model_n_asr_models", help="24 ASR models to train or 1", type=str)
    parser.add_argument("--model_distribution_prediction_architecture", help="Linear layer or sequential", type=str)
    parser.add_argument("--model_mhfa_head_nb", help="Number of mha heads for mhfa", type=int)
    parser.add_argument("--model_mhfa_compression_dim", help="Dimension for compression dim", type=int)
    parser.add_argument("--train_accumulate_each_n_steps", help="Number of steps for Gradient Accumulation", type=int)
    parser.add_argument("--model_chi2_nc", help="Coefficient for chi2 prior distribution ELBO", type=int)

    return parser.parse_args(argv)


def _workshop_folder_name(folder):
    """Return a single folder name from a string or a sequence.

    A string starting with ``"["`` has its first two and last two characters
    removed (presumably the stringified one-element list ``"['name']"`` —
    verify against callers); any other string is returned unchanged. A
    non-string value is indexed with ``[0]``.
    """
    if isinstance(folder, str):
        return folder[2:-2] if folder.startswith("[") else folder
    return folder[0]


def create_workshop(cfg, mode, local_rank, fold=None):
    """Derive the run directory from the config and create it on rank 0.

    Sets ``cfg.workshop`` and ``cfg.ckpt_save_path``. On ``local_rank == 0``
    the workshop directory (with a ``distribution_predicted`` sub-directory)
    and the checkpoint directory are created when the workshop does not exist.

    Args:
        cfg: Config object exposing ``model``, ``train`` and ``dataset`` sections.
        mode: Unused; kept for interface compatibility.
        local_rank: Process rank; only rank 0 touches the filesystem.
        fold: Unused; kept for interface compatibility.

    Returns:
        The same ``cfg`` object, updated in place.

    Raises:
        ValueError: on rank 0 when the workshop already exists and
            ``cfg.train.resume`` is ``None``.
    """
    output_rep = cfg.model.output_rep
    beam = cfg.model.beam
    thresh = cfg.model.beam_threshold
    seed = cfg.train.seed
    kl_gamma = cfg.train.kl_gamma
    distribution_prediction = cfg.model.distribution_prediction
    kl_distribution = cfg.model.prior_distribution
    lr = cfg.train.lr
    p_for_geometric_pmf = cfg.model.p_for_geometric_pmf
    logs_dir = cfg.train.log_dir
    layer_position_encoding = cfg.model.layer_position_encoding
    model_dropout_input = cfg.model.dropout_input
    warmup_steps = cfg.train.warmup_steps
    mhfa_compression_dim = cfg.model.mhfa_compression_dim
    mhfa_head_nb = cfg.model.mhfa_head_nb
    chi2_nc = cfg.model.chi2_nc

    # device_count() is 0 on a CPU-only host; clamp so the effective batch
    # size written into the directory name is not multiplied down to 0.
    world_size = max(torch.cuda.device_count(), 1)
    batch = cfg.train.batch_size * cfg.train.accumulate_each_n_steps * world_size

    # NOTE: the name templates below are kept byte-for-byte (including the
    # doubled / missing "_" separators) so existing run directories still match.
    if output_rep == "elbo":
        pre_suffix = (
            f"seed_{seed}_kl_gamma{kl_gamma}_kl_prior_{kl_distribution}_"
            f"beam_{beam}_thresh_{thresh}_bs_{batch}_"
            f"_distribution_prediction_{distribution_prediction}_output_rep_{output_rep}"
            f"_lr_{lr}"
        )
        if kl_distribution == "geometric":
            pre_suffix += f"_p_geom_{p_for_geometric_pmf}"
        else:
            pre_suffix += f"_chi2_nc_{chi2_nc}"
    elif output_rep == "mhfa":
        pre_suffix = (
            f"output_rep_{output_rep}_mhfa_compression_dim_{mhfa_compression_dim}_mhfa_head_nb_{mhfa_head_nb}"
            f"seed_{seed}_beam_{beam}_thresh_{thresh}_bs_{batch}_layer_position_encoding_{layer_position_encoding}"
            f"_model_dropout_input_{model_dropout_input}_lr_{lr}_warmup_steps_{warmup_steps}"
        )
    else:
        pre_suffix = (
            f"output_rep_{output_rep}_seed_{seed}_beam_{beam}_thresh_{thresh}_bs_{batch}"
            f"_model_dropout_input_{model_dropout_input}_lr_{lr}_warmup_steps_{warmup_steps}"
        )

    val_folder = _workshop_folder_name(cfg.dataset.eval_folder)
    test_folder = _workshop_folder_name(cfg.dataset.test_folder)

    suffix = f"{pre_suffix}/val_session_{val_folder}_test_speaker_{test_folder}"
    config_name = f"/app/data/{logs_dir}/{suffix}"

    cfg.workshop = config_name
    cfg.ckpt_save_path = f"{config_name}/checkpoint"

    # Only rank 0 creates directories, so the other ranks never touch the filesystem.
    if local_rank == 0:
        if os.path.exists(cfg.workshop):
            # An existing workshop is only acceptable when resuming a run.
            if cfg.train.resume is None:
                message = f"workshop {cfg.workshop} already existed."
                raise ValueError(message)
        else:
            print(f"Created workshop on local rank: {local_rank}")
            os.makedirs(str(Path(cfg.workshop) / "distribution_predicted"))
            os.makedirs(cfg.ckpt_save_path)
    return cfg


class ParseAction(argparse.Action):
    """argparse action that stores a whitespace-separated string as a list of ints."""

    def __call__(self, parser, namespace, values, option_string=None):
        """Print the raw inputs, then set ``namespace.<dest>`` to the parsed int list."""
        debug_fields = (namespace, values, option_string)
        print("%r %r %r" % debug_fields)
        parsed = [int(token) for token in values.split()]
        setattr(namespace, self.dest, parsed)


def _parse_bracketed_list(text, cast):
    """Split a space-separated string, optionally wrapped in ``[...]``, and cast each item."""
    if text.startswith("["):
        text = text[1:-1]
    return [cast(item) for item in text.split(" ")]


def get_config(mode=""):
    """Build the experiment config from the preset configs and the CLI options.

    The presets for the model family, the dataset and ``model family + mode``
    are loaded first. Every CLI option ``<section>_<name>`` is then written to
    ``cfg.<section>.<name>`` when it was given explicitly, or when the preset
    has no such key (in which case it is stored as ``None``).

    Args:
        mode: Suffix appended to the model family to select the train preset.

    Returns:
        The assembled config node.

    Raises:
        ValueError: if the ``--model_type`` version part is neither ``Base``
            nor ``Large`` (including when it is missing altogether).
    """
    args = get_args()
    cfg = CN(new_allowed=True)
    cfg.model = CN(new_allowed=True)
    cfg.dataset = CN(new_allowed=True)
    cfg.train = CN(new_allowed=True)

    # "<family>-<version>[-...]": always bind `version`, so a model type
    # without "-" reaches the ValueError below instead of an UnboundLocalError.
    type_parts = args.model_type.split("-")
    version = type_parts[1] if len(type_parts) > 1 else ""
    args.model_type = type_parts[0]

    cfg.model.update(model_cfg_finetuning[args.model_type])
    cfg.dataset.update(dataset_cfg_finetuning[args.dataset_database])
    cfg.train.update(train_cfg_finetuning[args.model_type + mode])

    # Options that arrive as the strings "True"/"False" and are stored as bools.
    string_bool_options = (
        "freeze_cnn",
        "freeze_upstream",
        "n_asr_models",
        "layer_position_encoding",
        "prior_weight_ctc_loss",
    )
    for key, value in vars(args).items():
        section, separator, option = key.partition("_")
        if separator:
            if value is not None and option in string_bool_options:
                value = value == "True"
            # Explicit CLI values win; None only fills keys the preset lacks.
            if value is not None or not hasattr(cfg[section], option):
                cfg[section][option] = value
        elif value is not None or not hasattr(cfg, section):
            cfg[section] = value

    if version == "Base":
        cfg.model.encoder_layers = 12
        cfg.model.encoder_embed_dim = 768
        cfg.model.ffn_embed_dim = 3072
        cfg.model.num_heads = 12
        cfg.model.extractor_mode = "default"
        cfg.model.normalize = False
        cfg.model.normalize_before = False
    elif version == "Large":
        cfg.model.encoder_layers = 24
    else:
        raise ValueError(f"Unknown WavLM version: {version}")

    # modify cfg.train.batch_size in the case of multi-GPUs training
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        # Clamp to 1: round() would give 0 when the batch is much smaller than num_gpus.
        ddp_batch_size = max(1, round(cfg.train.batch_size / num_gpus))
        print(f"Modified batch size: {cfg.train.batch_size} -> {ddp_batch_size}.")
        cfg.train.batch_size = ddp_batch_size
        ddp_batch_size_val = max(1, round(cfg.train.val_batch_size / num_gpus))
        print(f"Modified val batch size: {cfg.train.val_batch_size} -> {ddp_batch_size_val}.")
        cfg.train.val_batch_size = ddp_batch_size_val

    # List-valued options given on the command line arrive as plain strings.
    if isinstance(cfg.model.rnn_dim, str):
        cfg.model.rnn_dim = _parse_bracketed_list(cfg.model.rnn_dim, int)
    if isinstance(cfg.model.rnn_dropout, str):
        cfg.model.rnn_dropout = _parse_bracketed_list(cfg.model.rnn_dropout, float)
    if isinstance(cfg.dataset.train_folder, str):
        cfg.dataset.train_folder = _parse_bracketed_list(cfg.dataset.train_folder, str)
    print(f"Batch size: {cfg.train.batch_size}.")
    print(f"Applying layer_position_encoding: {cfg.model.layer_position_encoding}.")
    return cfg


def dict_2_list(dict):
    """Flatten a mapping into ``[key, value, key, value, ...]``, skipping ``None`` values."""
    return [
        element
        for name, item in dict.items()
        if item is not None
        for element in (name, item)
    ]
