from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, Tuple, List, Union


@dataclass
class TrainConfig:
    """Configuration for preprocessing, model construction and training.

    Every field has a default, so ``TrainConfig()`` yields a complete
    configuration; override individual values via keyword arguments.

    Two attributes are derived in ``__post_init__`` and are not constructor
    arguments:

    * ``hop_in_ms`` -- ``hop_length / sample_rate``.
    * ``emb_size_dis`` -- ``emotion_emb_hidden_size``, plus
      ``speaker_emb_hidden_size`` when
      ``stack_speaker_with_emotion_embedding`` is true.
    """

    # Preprocess
    n_threads: int = field(default=200)  # number of threads used to process utterances in parallel
    # (min_speaker_id - 1) is subtracted from each speaker id so that the smallest speaker index becomes 1
    min_speaker_id: int = field(default=11)
    include_empty_intervals: bool = field(default=True)  # if True silence will be loaded from .TextGrid

    mel_fmin: int = field(default=0)
    mel_fmax: int = field(default=8000)
    hop_length: int = field(default=192)
    stft_length: int = field(default=768)
    sample_rate: int = field(default=16000)
    window_length: int = field(default=768)
    n_mel_channels: int = field(default=80)

    raw_data_path: Path = field(default=Path("/app/data/ssw_esd"))
    val_ids_path: Path = field(default=Path("/app/data/val_ids.txt"))
    test_ids_path: Path = field(default=Path("/app/data/test_ids.txt"))
    preprocessed_data_path: Path = field(default=Path("/app/data/debug_preprocessed"))

    # Variable-length tuple of feature names.
    egemap_feature_names: Tuple[str, ...] = field(
        default=(
            "F0semitoneFrom27.5Hz_sma3nz_percentile50.0",
            "F0semitoneFrom27.5Hz_sma3nz_percentile80.0",
            "F0semitoneFrom27.5Hz_sma3nz_pctlrange0-2",
            "spectralFlux_sma3_amean",
            "HNRdBACF_sma3nz_amean",
            "mfcc1V_sma3nz_amean",
            "equivalentSoundLevel_dBp",
        )
    )

    # Vocoder
    vocoder_checkpoint_path: str = field(default="/app/data/g_01800000")
    istft_resblock_kernel_sizes: Tuple[int, ...] = field(default=(3, 7, 11))
    istft_upsample_rates: Tuple[int, ...] = field(default=(6, 8))
    istft_upsample_initial_channel: int = field(default=512)
    istft_upsample_kernel_sizes: Tuple[int, ...] = field(default=(16, 16))
    istft_resblock_dilation_sizes: Tuple[Tuple[int, ...], ...] = field(
        default=((1, 3, 5), (1, 3, 5), (1, 3, 5))
    )
    gen_istft_n_fft: int = field(default=16)
    gen_istft_hop_size: int = field(default=4)

    # Transformer Encoder
    padding_index: int = field(default=0)
    max_seq_len: int = field(default=2000)
    phones_mapping_path: Path = field(default=Path("/app/data/debug_preprocessed/phones.json"))
    transformer_encoder_hidden: int = field(default=512)
    transformer_encoder_layer: int = field(default=6)
    transformer_encoder_head: int = field(default=2)
    transformer_conv_filter_size: int = field(default=512)
    transformer_conv_kernel_size: Tuple[int, int] = field(default=(9, 1))
    transformer_encoder_dropout: float = field(default=0.2)

    # Transformer Decoder
    transformer_decoder_hidden: int = field(default=512)
    transformer_decoder_layer: int = field(default=6)
    transformer_decoder_head: int = field(default=2)
    transformer_decoder_dropout: float = field(default=0.2)

    # Emotion Conditioning
    emotion_emb_hidden_size: int = field(default=256)
    # NOTE(review): egemap_feature_names lists 7 names while this defaults to 2 -- confirm this is intended.
    n_egemap_features: int = field(default=2)
    conditional_cross_attention: bool = field(default=True)
    conditional_layer_norm: bool = field(default=True)
    stack_speaker_with_emotion_embedding: bool = field(default=True)

    # Variance Predictor
    variance_embedding_n_bins: int = field(default=256)
    variance_predictor_kernel_size: int = field(default=3)
    variance_predictor_filter_size: int = field(default=256)
    variance_predictor_dropout: float = field(default=0.5)

    # Dataset
    multi_speaker: bool = field(default=True)
    multi_emotion: bool = field(default=True)
    n_emotions: int = field(default=5)
    n_speakers: int = field(default=10)
    train_batch_size: int = field(default=64)
    val_batch_size: int = field(default=64)
    device: str = field(default="cuda")

    # FastSpeech
    postnet_norm: str = field(default="BN")  # "BN" or "IN"
    speaker_emb_hidden_size: int = field(default=256)

    # Discriminator
    optimizer_lrate_d: float = field(default=1e-4)
    optimizer_betas_d: tuple[float, float] = field(default=(0.5, 0.9))
    kernels_d: tuple[float, ...] = field(default=(3, 5, 5, 5, 3))
    strides_d: tuple[float, ...] = field(default=(1, 2, 2, 1, 1))
    compute_adversarial_loss: bool = field(default=True)
    compute_fm_loss: bool = field(default=True)

    # Train
    seed: int = field(default=55)
    # The default is the int 32, so the annotation admits both int and str values.
    precision: Union[int, str] = field(default=32)
    matmul_precision: str = field(default="medium")
    lightning_checkpoint_path: str = field(default="/app/data/model1_checkpoint")
    train_from_checkpoint: Optional[str] = field(default=None)
    num_workers: int = field(default=0)
    test_wav_files_directory: str = field(default="/app/data/model1_fix_wav")
    test_mos_files_directory: str = field(default="/app/data/model1_fix_mos")
    total_training_steps: int = field(default=50000)
    val_each_epoch: int = field(default=10)
    val_audio_log_each_step: int = field(default=1000)  # audio is logged to wandb every this many steps

    # Optimizer
    optimizer_grad_clip_val: float = field(default=1.)
    optimizer_warm_up_step: float = field(default=4000)
    optimizer_anneal_steps: tuple[float, ...] = field(default=(300000, 400000, 500000))
    optimizer_anneal_rate: float = field(default=0.3)
    fastspeech_optimizer_betas: tuple[float, float] = field(default=(0.9, 0.98))
    fastspeech_optimizer_eps: float = field(default=1e-9)
    fastspeech_optimizer_weight_decay: float = field(default=0.)

    # Wandb
    wandb_project: str = field(default="Model1")
    wandb_run_id: Optional[str] = field(default=None)  # None means no run id is specified
    resume_wandb_run: bool = field(default=False)  # if true will log data to the last wandb run in the specified project
    strategy: str = field(default="ddp_find_unused_parameters_true")
    wandb_offline: bool = field(default=False)
    wandb_progress_bar_refresh_rate: int = field(default=1)
    wandb_log_every_n_steps: int = field(default=1)
    devices: Union[tuple, int] = field(default=(0, 1, 2, 3))
    limit_val_batches: Optional[int] = field(default=4)
    limit_test_batches: Optional[int] = field(default=4)
    num_sanity_val_steps: int = field(default=4)
    save_top_k_model_weights: int = field(default=3)
    metric_monitor_mode: str = field(default="max")  # 'min' or 'max'

    def __post_init__(self):
        """Compute the derived attributes ``hop_in_ms`` and ``emb_size_dis``."""
        # NOTE(review): despite the "_ms" suffix, hop_length / sample_rate is a
        # duration in seconds if sample_rate is in Hz (192 / 16000 == 0.012).
        # The value is kept as-is because callers may depend on it -- confirm
        # the intended unit before renaming or rescaling.
        self.hop_in_ms = self.hop_length / self.sample_rate

        # The emotion embedding size always counts; the speaker embedding size
        # is added only when the two embeddings are stacked.
        if self.stack_speaker_with_emotion_embedding:
            self.emb_size_dis = self.speaker_emb_hidden_size + self.emotion_emb_hidden_size
        else:
            self.emb_size_dis = self.emotion_emb_hidden_size