from typing import Tuple
import os
import json
import dataclasses
from dataclasses import dataclass


@dataclass
class CommonConfig:
    """Values shared by several config sections.

    Other sections read these as *class* attributes (e.g. ``CommonConfig.task``)
    while they are being defined. Editing a default here therefore changes
    theirs too; mutating a ``CommonConfig`` instance at runtime does not.
    """

    seed: int = 6666             # reused as Config.seed
    embedding_size: int = 100    # reused by Embedding and Combiner
    dataset: str = "mimic3"      # directory name under ../data and ../saved
    task: str = "mimic3-50"      # reused by IOConfig.task, Dataset.version, Config.task


@dataclass
class IOConfig:
    """Input/output locations for the active task.

    Defaults are computed once, when this class is defined, from the class
    attributes of ``CommonConfig``. The paths are relative (they start with
    ``..``), so they resolve against the working directory at the time they
    are used.
    """

    # Task identifier string (same value as CommonConfig.task, e.g. 'mimic3-50').
    task: str = CommonConfig.task
    data_path: str = os.path.join('..', 'data', CommonConfig.dataset)
    saved_path: str = os.path.join('..', 'saved', CommonConfig.dataset)


@dataclass
class BaseConfig:
    """Base class that gives dataclass configs a readable JSON ``str()``."""

    def __str__(self):
        """Return all fields, with nested dataclasses expanded, as 4-space-indented JSON."""
        return json.dumps(dataclasses.asdict(self), indent=4)


@dataclass
class Train:
    """Training-run settings (hardware, length, optimizer, LR schedule)."""

    # Hardware
    accelerator: str = "gpu"
    devices: int = 1
    # Run length
    epochs: int = 50
    # Optimizer
    weight_decay: float = 0.01
    learning_rate: float = 5e-4
    adam_epsilon: float = 1e-8
    # Learning-rate schedule
    scheduler: str = "linear"
    warmup_ratio: float = 0.0


@dataclass
class Embedding:
    """Word-embedding layer settings."""

    num_words: int = 150697
    # Copied from CommonConfig when this class is defined.
    embedding_size: int = CommonConfig.embedding_size
    # NOTE(review): presumably the vocabulary index used for padding — confirm.
    padding_idx: int = 150695
    dropout_prob: float = 0.2
    use_init_embeddings: bool = True


@dataclass
class Combiner:
    """Settings for the text combiner (its fields suggest an RNN stack).

    ``Decoder.input_dim`` takes its default from this class's ``rnn_units``
    default, so the two should be changed together.
    """

    # Copied from CommonConfig when this class is defined.
    embedding_size: int = CommonConfig.embedding_size
    rnn_units: int = 512
    num_layers: int = 1
    dropout_prob: float = 0.0


@dataclass
class TextEncoder:
    """Text-encoder settings: the embedding and combiner sub-configs.

    Each instance builds its own sub-configs via ``default_factory``. A plain
    ``Embedding()`` default would be one object shared by every TextEncoder,
    so mutations would leak between configs. Python 3.11+ also rejects such
    unhashable dataclass instances as field defaults with ``ValueError``.
    NOTE: the fields are instance attributes only (``TextEncoder().embedding``);
    class-level access ``TextEncoder.embedding`` is not available.
    """

    embedding: Embedding = dataclasses.field(default_factory=Embedding)
    combiner: Combiner = dataclasses.field(default_factory=Combiner)


@dataclass
class Decoder:
    """Decoder / attention settings.

    ``input_dim`` defaults to ``Combiner.rnn_units`` as it was when this
    class was defined. Overriding ``rnn_units`` on a Combiner *instance* does
    not update it, so keep the two in sync by hand.
    """

    input_dim: int = Combiner.rnn_units
    attention_dim: int = 512
    attention_head: int = 8
    dropout_prob: float = 0.2
    pooling: str = "max"
    activation: str = "tanh"


@dataclass
class LabelEncoder:
    """Label-encoder settings."""

    pooling: str = "max"  # pooling mode name


@dataclass
class Loss:
    """Weights of the individual loss terms."""

    kl_loss_weight: float = 5.0
    code_loss_weight: float = 1.0


@dataclass
class ICDModel:
    """Model settings grouped by component (text encoder, decoder, label encoder, loss).

    Each instance builds its own sub-configs via ``default_factory``. Instance
    defaults such as ``Decoder()`` would be shared by every ICDModel, and
    Python 3.11+ rejects them outright with ``ValueError`` because
    non-frozen dataclass instances are unhashable.
    NOTE: the fields are instance attributes only (``ICDModel().decoder``).
    """

    text_encoder: TextEncoder = dataclasses.field(default_factory=TextEncoder)
    decoder: Decoder = dataclasses.field(default_factory=Decoder)
    label_encoder: LabelEncoder = dataclasses.field(default_factory=LabelEncoder)
    loss: Loss = dataclasses.field(default_factory=Loss)


@dataclass
class Dataset:
    """Dataset version and truncation settings."""

    # Mirrors CommonConfig.task as it was when this class is defined.
    version: str = CommonConfig.task
    truncate_length: int = 4000
    label_truncate_length: int = 30
    term_count: int = 8


@dataclass
class Metrics:
    """Evaluation-metric settings."""

    # Integer cut-offs; presumably the k values of "@k" metrics — confirm.
    # A tuple is immutable, so it is safe as a shared field default.
    ks: Tuple[int, ...] = (5, 8, 15)


@dataclass
class TrainLoader:
    """Data-loader settings for training batches."""

    batch_size: int = 16
    shuffle: bool = True  # boolean flag, so annotated as bool rather than int
    num_workers: int = 8


@dataclass
class DevLoader:
    """Data-loader settings for dev/evaluation batches."""

    batch_size: int = 64
    shuffle: bool = False  # boolean flag, so annotated as bool rather than int
    num_workers: int = 8


@dataclass
class Config(BaseConfig):
    """Top-level configuration that aggregates every section.

    ``str(config)`` renders the whole tree as indented JSON (inherited
    ``BaseConfig.__str__``).

    Section fields use ``default_factory``, so every ``Config()`` gets fresh
    section objects. A plain instance default (e.g. ``Train()``) would be one
    object shared by all configs. On Python 3.11+, dataclasses also reject
    such unhashable defaults with ``ValueError`` at class definition.
    NOTE: sections are instance attributes only; use ``Config().train``,
    not ``Config.train``.
    """

    task: str = CommonConfig.task
    seed: int = CommonConfig.seed
    io: IOConfig = dataclasses.field(default_factory=IOConfig)
    train: Train = dataclasses.field(default_factory=Train)
    icd_model: ICDModel = dataclasses.field(default_factory=ICDModel)
    dataset: Dataset = dataclasses.field(default_factory=Dataset)
    metrics: Metrics = dataclasses.field(default_factory=Metrics)
    train_loader: TrainLoader = dataclasses.field(default_factory=TrainLoader)
    dev_loader: DevLoader = dataclasses.field(default_factory=DevLoader)
