import copy
import warnings
import logging
import torch
import k2
from torch.cuda.amp import GradScaler
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP

from ..utils.icefall_utils import (
    MetricsTracker,
    get_parameter_groups_with_lrs,)
from ..models.zipformer.optim import ScaledAdam, Eden
from ..models.zipformer_adam.optim import Noam
from ..tokenizer.tokenizer_module import build_tokenizer
from ..peft.lora.utils import (
    inject_lora_to_model,
    mark_only_lora_as_trainable,
    register_backward_hook_for_extra_tokens
)
from .icefall_trainer import IcefallAsrTrainer
from ..utils.icefall_utils import MetricsTracker
from auden.data.whisper_data_module import WhisperAsrDatamodule


class WhisperAsrTrainer(IcefallAsrTrainer):
    def __init__(self, cfg, model, rank=0, local_rank=0, world_size=1):
        """
        Build all training state for a Whisper ASR model.

        In order, this method: builds the tokenizer and checks it against the
        model config, optionally injects LoRA adapters, keeps a float64 copy
        of the model on rank 0 (``self.model_avg``), moves the model to its
        CUDA device, wraps it in DDP when ``world_size > 1``, creates the data
        module, optimizer, LR scheduler and AMP grad scaler, calls
        ``self._load_checkpoint_if_available()``, and finally creates a
        TensorBoard writer on rank 0 when ``cfg.trainer.tensorboard`` is set.

        NOTE(review): ``super().__init__()`` is not called here; presumably
        this is intentional because this method performs its own setup --
        confirm against ``IcefallAsrTrainer``.

        Args:
            cfg (DictConfig or similar): Your configuration (e.g., Hydra config).
            model (nn.Module): A PyTorch model already instantiated outside of Trainer.
            rank (int): Global rank (0 <= rank < world_size).
            local_rank (int): Local rank on the current node (0 <= local_rank < num_gpus_per_node).
            world_size (int): Total number of processes across all nodes.

        Raises:
            AssertionError: If the tokenizer's ``blank_id`` or ``vocab_size``
                differs from the model config.
            ValueError: If ``cfg.trainer.optimizer`` or
                ``cfg.trainer.scheduler`` names an unsupported choice.
        """
        self.cfg = cfg
        self.model = model
        self.rank = rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.use_fp16 = cfg.trainer.use_fp16
        self.global_step = cfg.trainer.start_batch
        self.model_avg = None

        # 1) Tokenizer initialization
        self.tokenizer = build_tokenizer(cfg.tokenizer)
        assert self.tokenizer.blank_id == self.model.config.blank_id
        assert self.tokenizer.vocab_size == self.model.config.vocab_size

        # (Optional) add lora to model
        if hasattr(cfg, "lora") and cfg.lora.use_lora:
            import json
            with open(cfg.lora.config, encoding="utf-8") as f:
                lora_config = json.load(f)
            logging.info(f"LoRA configuration: {lora_config}")
            inject_lora_to_model(self.model, lora_config)
            mark_only_lora_as_trainable(self.model, 'none')

            # when using LoRA (or other scenarios with backbone frozen),
            # the extra LID/tokens in decoder.token_embedding should be active.
            for p in self.model.decoder.token_embedding.parameters():
                p.requires_grad = True
            new_tokens = [
                self.tokenizer.to_language_token(lang)
                for lang in self.model.config.extra_languages.keys()
            ]
            new_tokens += [
                getattr(self.tokenizer, name)
                for name in self.model.config.extra_tokens
            ]
            register_backward_hook_for_extra_tokens(self.model, new_tokens)

        # Parameter counts are reported in units of 1024 * 1024 parameters.
        num_param = sum(p.numel() for p in self.model.parameters()) / 1024 / 1024
        num_trainable_param = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        ) / 1024 / 1024
        logging.info(f"Number of model parameters: {num_param}M")
        logging.info(f"Number of trainable model parameters: {num_trainable_param}M")

        if rank == 0:
            # model_avg is only used with rank 0 and is on CPU.
            # The copy is taken before the model is moved to the GPU below.
            self.model_avg = copy.deepcopy(self.model).to(torch.float64)

        # 2) Device setup
        self.device = torch.device("cuda", self.local_rank)
        self.model.to(self.device)

        # 3) Wrap model in DistributedDataParallel if multi-GPU
        find_unused_parameters = getattr(cfg.trainer, 'find_unused_parameters', True)
        if self.world_size > 1:
            self.model = DDP(
                self.model,
                device_ids=[self.local_rank],
                find_unused_parameters=find_unused_parameters,
            )

        # 4) Dataset and dataloader initialization
        self.data_module = WhisperAsrDatamodule(cfg)
        self.train_dl = self.data_module.train_dl
        self.valid_dl = self.data_module.valid_dl

        # 5) Optimizer and Scheduler initialization
        if cfg.trainer.optimizer == 'scaled_adam':
            self.optimizer = ScaledAdam(
                get_parameter_groups_with_lrs(self.model, lr=cfg.trainer.base_lr, include_names=True),
                lr=cfg.trainer.base_lr,  # should have no effect
                clipping_scale=2.0,
            )
        elif cfg.trainer.optimizer == 'adamw':
            # Only trainable parameters are handed to AdamW.
            self.optimizer = torch.optim.AdamW(
                [p for p in self.model.parameters() if p.requires_grad],
                lr=cfg.trainer.base_lr
            )
        elif cfg.trainer.optimizer == 'adam':
            self.optimizer = torch.optim.Adam(
                self.model.parameters(),
                lr=cfg.trainer.base_lr,
                betas=(0.9, 0.98),
                eps=1e-9,
                weight_decay=1e-6
            )
        else:
            raise ValueError(f"Optimizer={cfg.trainer.optimizer} is not supported")

        if cfg.trainer.scheduler == 'eden':
            self.scheduler = Eden(
                self.optimizer, cfg.trainer.lr_batches, cfg.trainer.lr_epochs,
                warmup_batches=cfg.trainer.warmup_batches
            )
        elif cfg.trainer.scheduler == 'noam':
            self.scheduler = Noam(self.optimizer, cfg.trainer.warmup_batches, base_lr=cfg.trainer.base_lr)
        else:
            # Fail fast, mirroring the optimizer branch, instead of leaving
            # self.scheduler unset and failing later with an AttributeError.
            raise ValueError(f"Scheduler={cfg.trainer.scheduler} is not supported")

        # 6) Automatic Mixed Precision Scaler
        self.scaler = GradScaler(
            enabled=self.use_fp16,
            init_scale=1.0
        )

        # 7) Load checkpoint if available
        self._load_checkpoint_if_available()

        # Only rank 0 writes TensorBoard logs; other ranks keep tb_writer = None.
        if cfg.trainer.tensorboard and self.rank == 0:
            self.tb_writer = SummaryWriter(log_dir=f"{cfg.exp_dir}/tensorboard")
        else:
            self.tb_writer = None

    def _compute_loss(
        self, batch: dict, is_training: bool,
    ):
        """
        Compute loss given the model and its inputs.

        Args:
        batch:
            A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
            for the content in it. This method reads ``batch["inputs"]`` and,
            from ``batch["supervisions"]``, the keys ``"num_frames"``,
            ``"text"`` and ``"cut"``.
        is_training:
            True for training. False for validation. When it is True, this
            function enables autograd during computation; when it is False, it
            disables autograd.

        Returns:
            A tuple ``(loss, info)``. ``loss`` is the value returned by
            ``self.model``; ``info`` is a ``MetricsTracker`` holding
            ``"frames"`` (sum of ``feature_lens // 2``) and ``"loss"``
            (the detached loss as a Python number).
        """
        device = self.device
        feature = batch["inputs"]
        # at entry, feature is (N, T, C)
        assert feature.ndim == 3
        feature = feature.to(device)
        # the model is fed (N, C, T)
        feature = feature.permute(0, 2, 1)

        supervisions = batch["supervisions"]
        feature_lens = supervisions["num_frames"].to(device)

        # Maximum number of frames passed to the model.
        # NOTE(review): presumably Whisper's fixed input window -- confirm.
        MAXT = 3000
        if feature.size(2) > MAXT:
            logging.warning(f"Get feature input with shape={feature.shape}")
            feature = feature[:, :, :MAXT]
            # Keep the lengths consistent with the truncated features so
            # x_lens never exceeds the time dimension of x.
            feature_lens = feature_lens.clamp(max=MAXT)

        texts = supervisions["text"]
        langs = [c.supervisions[0].language for c in supervisions["cut"]]

        # Token sequence per utterance:
        #   sot [lang] transcribe no_timestamps <text tokens> eot
        # The language token is omitted when the cut has no language set.
        text_tokens_list = [
            [self.tokenizer.sot]
            + ([self.tokenizer.to_language_token(lang)] if lang is not None else [])
            + [self.tokenizer.transcribe]
            + [self.tokenizer.no_timestamps]
            + self.tokenizer.encode(text)
            + [self.tokenizer.eot]
            for text, lang in zip(texts, langs)
        ]
        # Teacher forcing: decoder input drops the last token, target drops
        # the first, so target[i] is the token following y[i].
        target = k2.RaggedTensor([i[1:] for i in text_tokens_list]).to(device)
        y = k2.RaggedTensor([i[:-1] for i in text_tokens_list]).to(device)

        with torch.set_grad_enabled(is_training):
            loss = self.model(
                x=feature,
                x_lens=feature_lens,
                y=y,
                target=target
            )

        assert loss.requires_grad == is_training

        info = MetricsTracker()
        with warnings.catch_warnings():
            # Silence any warning emitted by the tensor floor division below.
            warnings.simplefilter("ignore")
            info["frames"] = (feature_lens // 2).sum().item()

        # Note: We use reduction=sum while computing the loss.
        info["loss"] = loss.detach().cpu().item()

        return loss, info
