from trident.core.module import TridentModule
from trident.utils.enums import Split
import torch

from trident.utils.logging import get_logger

log = get_logger(__name__)

# Flores -> MMS language-code mismatches.
# Keys are either a bare language code (e.g. "arb") or a "<lang>_<script>"
# pair (e.g. "zho_Hans"); values are the adapter names handed to
# `model.load_adapter`. Codes that do not appear here are used unchanged.
# NOTE(review): keys are presumably the Flores-style codes embedded in the
# dataset-spec names -- confirm against the datamodule config.
LANGUAGE_MAPPINGS: dict[str, str] = {
    "arb": "ara",
    "azj": "azj-script_latin",
    "fuv": "ful",
    "gaz": "orm",
    "khk": "mon",
    "lvs": "lav",
    "pbt": "pus",
    "pes": "fas",
    "srp": "srp-script_latin",
    "urd": "urd-script_arabic",
    "uzn": "uzb-script_latin",
    "zho_Hans": "cmn-script_simplified",
    "zho_Hant": "yue-script_traditional",
    "zsm": "zlm",
}

class MMSModule(TridentModule):
    """Swap MMS language adapters on demand for train/val/test.

    When the wrapped model is configured with adapters
    (``model.config.adapter_attn_dim is not None``), the adapter for
    ``train_lang`` is loaded at construction time. During validation and
    testing the adapter matching the language of the current dataloader is
    loaded before its first batch, and the training adapter is restored once
    the evaluation loop ends.

    Models without adapters are left untouched; every hook then only keeps
    ``current_val_idx`` up to date.
    """

    # Adapter used when the adapter for an evaluation language cannot be loaded.
    _FALLBACK_LANG = "eng"

    def __init__(self, train_lang: str = "eng", *args, **kwargs):
        """Initialise the module and load the training-language adapter.

        Args:
            train_lang: Adapter language code used for training; it is also
                the adapter restored after every validation/test loop.
            *args: Forwarded to ``TridentModule.__init__``.
            **kwargs: Forwarded to ``TridentModule.__init__``.
        """
        super().__init__(*args, **kwargs)

        self.train_lang = train_lang
        self.has_language_adapter = self.model.config.adapter_attn_dim is not None
        if self.has_language_adapter:
            self._load_adapter(train_lang)
            log.info(f"Loaded {train_lang} adapter.")
            self.freeze_adapters()

        # NOTE: -1 means training state
        self.current_val_idx: int = -1

    def _load_adapter(self, lang_code: str) -> None:
        """Point the model at ``lang_code`` and load the matching adapter."""
        self.model.target_lang = lang_code
        self.model.load_adapter(lang_code)

    @staticmethod
    def _resolve_adapter_lang(dataspec_name: str) -> str:
        """Derive the adapter language code from a dataset-spec name.

        The last two ``_``-separated fields of ``dataspec_name`` are read as
        ``<lang>_<script>``. The bare language code is translated through
        ``LANGUAGE_MAPPINGS`` first; the (possibly translated) code combined
        with the script is then looked up as well, so script-specific entries
        such as ``zho_Hans`` take precedence.

        Raises:
            IndexError: If ``dataspec_name`` contains no ``_`` separator.
        """
        # WARN: only works for SIB
        fields = dataspec_name.split("_")
        lang_code, script = fields[-2], fields[-1]
        lang_code = LANGUAGE_MAPPINGS.get(lang_code, lang_code)
        return LANGUAGE_MAPPINGS.get(f"{lang_code}_{script}", lang_code)

    def freeze_adapters(self) -> None:
        """Freeze the adapter layers and make everything else trainable.

        All model parameters are first marked trainable; parameters of every
        sub-module whose name ends with ``adapter_layer`` are then frozen.
        """
        for parameter in self.model.parameters():
            parameter.requires_grad = True
        for name, module in self.model.named_modules():
            if name.endswith("adapter_layer"):
                for parameter in module.parameters():
                    parameter.requires_grad = False
        log.info("Froze adapter layers")

    def on_eval_batch_start(
        self, batch, batch_idx: int, dataloader_idx: int, split: Split
    ) -> None:
        """Load the adapter for the dataloader's language when it changes.

        The adapter is only (re)loaded on the first batch of a dataloader,
        i.e. when ``dataloader_idx`` differs from ``current_val_idx``. If
        loading fails, the default ``eng`` adapter is loaded instead so that
        evaluation can proceed.

        Args:
            batch: Current batch (unused).
            batch_idx: Index of the batch within the dataloader (unused).
            dataloader_idx: Index of the dataloader being evaluated.
            split: Split whose dataset specs are consulted for the name.
        """
        if not self.has_language_adapter or self.current_val_idx == dataloader_idx:
            return

        dataspec_name, _ = self._get_datasetspec(split, dataloader_idx)
        lang_code = self._resolve_adapter_lang(dataspec_name)
        try:
            self._load_adapter(lang_code)
        except Exception:
            # Best effort: unsupported languages fall back to the default
            # adapter. `Exception` (rather than a bare `except`) lets
            # KeyboardInterrupt/SystemExit propagate; the traceback is logged
            # so the reason for the fallback is not lost.
            log.warning(
                f"Loading adapter for {lang_code} failed. "
                f"Loading default '{self._FALLBACK_LANG}' adapter.",
                exc_info=True,
            )
            self._load_adapter(self._FALLBACK_LANG)
        else:
            log.info(f"Loaded adapter for {lang_code}.")
        self.current_val_idx = dataloader_idx

    def on_validation_batch_start(  # type: ignore
        self, batch: dict[str, torch.Tensor], batch_idx: int, dataloader_idx: int
    ):
        """Switch adapters for the validation dataloader about to run."""
        self.on_eval_batch_start(batch, batch_idx, dataloader_idx, Split.VAL)

    def on_test_batch_start(  # type: ignore
        self, batch: dict[str, torch.Tensor], batch_idx: int, dataloader_idx: int
    ):
        """Switch adapters for the test dataloader about to run."""
        self.on_eval_batch_start(batch, batch_idx, dataloader_idx, Split.TEST)

    def _restore_train_adapter(self, stage: str) -> None:
        """Reload the training adapter and reset the dataloader tracker.

        Args:
            stage: Human-readable name of the loop that finished; only used
                in the log message.
        """
        if self.has_language_adapter:
            # Set `target_lang` together with the adapter, exactly as
            # `__init__` and the eval path do, so both stay in sync.
            self._load_adapter(self.train_lang)
            self.freeze_adapters()
            log.info(f"{stage} finished. Loaded adapter for {self.train_lang=}")
        self.current_val_idx = -1

    def on_validation_end(self) -> None:
        """Restore the training adapter once validation has finished."""
        super().on_validation_end()
        self._restore_train_adapter("Validation")

    def on_test_end(self) -> None:
        """Restore the training adapter once testing has finished."""
        # Chain to the matching parent hook (previously this called
        # `on_validation_end`, skipping the parent's test-end logic).
        super().on_test_end()
        self._restore_train_adapter("Testing")
