import os
import copy
import warnings
import logging
import torch
import k2
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter
from torch.nn.parallel import DistributedDataParallel as DDP

from ..utils.icefall_checkpoint import (
    remove_checkpoints,
    save_checkpoint_with_global_batch_idx,
    save_checkpoint,
    load_model_params,
    load_checkpoint,
    update_averaged_model,
    remove_ds_checkpoints,
)
from ..utils.icefall_utils import (
    MetricsTracker,
    get_parameter_groups_with_lrs,)
from ..data.icefall_data_module import IcefallAsrDatamodule
from ..tokenizer.tokenizer_module import build_tokenizer
from ..models.zipformer_adam.optim import Noam

def log_memory(rank, tag=""):
    """Print the CUDA memory statistics of device ``cuda:{rank}``.

    Values are reported in MB (bytes / 1024**2): currently allocated,
    currently reserved, peak allocated and peak reserved.

    Args:
        rank: Index used to build the ``cuda:{rank}`` device that is queried;
            it is also echoed in the printed prefix.
        tag: Free-form label included in the printed prefix.
    """
    cuda = torch.cuda
    dev = torch.device(f"cuda:{rank}")  # or use next(model.parameters()).device
    allocated, reserved, max_allocated, max_reserved = (
        nbytes / 1024**2
        for nbytes in (
            cuda.memory_allocated(dev),
            cuda.memory_reserved(dev),
            cuda.max_memory_allocated(dev),
            cuda.max_memory_reserved(dev),
        )
    )

    print(f"[Rank {rank}][{tag}] "
          f"Allocated: {allocated:.2f} MB, "
          f"Reserved: {reserved:.2f} MB, "
          f"Max Allocated: {max_allocated:.2f} MB, "
          f"Max Reserved: {max_reserved:.2f} MB")


class AsrTrainer:
    """Train and validate an ASR model wrapped in a DeepSpeed engine.

    The trainer owns the tokenizer, the train/valid data loaders, an optional
    Noam scheduler, a rank-0 copy of the model used for model averaging, and
    an optional rank-0 TensorBoard writer.
    """

    def __init__(self, cfg, model, optimizer, scheduler, world_size=1):
        """
        Args:
            cfg (DictConfig): Your Hydra configuration.
            model (DeepSpeedEngine): The DeepSpeed-wrapped model engine.
            optimizer: deepspeed optimizer.
            scheduler: deepspeed scheduler. Replaced by a ``Noam`` scheduler
                when ``cfg.trainer.scheduler == 'noam'``.
            world_size (int): Total number of processes.
        """
        self.cfg = cfg
        self.rank = model.global_rank
        self.local_rank = model.local_rank
        self.world_size = world_size
        # Step counter; overwritten from the checkpoint file when resuming.
        self.global_step = cfg.trainer.start_batch
        self.device = torch.device("cuda", self.local_rank)
        # Maps supervision language names to the integer ids passed to the model.
        self.language_to_id = {'Chinese': 0, 'English': 1}

        # 1) Use the DeepSpeed engine directly.
        self.model = model
        self.optimizer = optimizer

        # 2) Rank 0 keeps a separate model instance used for model averaging.
        if self.rank == 0:
            underlying_model = self._unwrap_model()
            # NOTE(review): load_state_dict copies these values into the
            # parameters of the freshly built model, which keep their own
            # dtype, so this float64 conversion alone does not make model_avg
            # float64. Confirm whether a float64 average was intended.
            state = {
                k: v.cpu().to(torch.float64)
                for k, v in underlying_model.state_dict().items()
            }

            # Create a new instance of the model and load the weights into it.
            from auden.models.model_module import build_model
            model_copy = build_model(cfg.model)  # adjust arguments as needed
            model_copy.load_state_dict(state)

            self.model_avg = model_copy
        else:
            self.model_avg = None

        # 3) Tokenizer initialization; it must agree with the model config.
        self.tokenizer = build_tokenizer(cfg.tokenizer)
        model_config = self._unwrap_model().config
        assert self.tokenizer.blank_id == model_config.blank_id, "Mismatch in blank_id"
        assert self.tokenizer.vocab_size == model_config.vocab_size, "Mismatch in vocab_size"

        # 4) Dataset and DataLoader initialization.
        self.data_module = IcefallAsrDatamodule(cfg)
        self.train_dl = self.data_module.train_dl
        self.valid_dl = self.data_module.valid_dl

        # 5) Build a Noam scheduler when requested; otherwise keep the given one.
        if cfg.trainer.scheduler == 'noam':
            self.scheduler = Noam(self.optimizer, cfg.trainer.warmup_batches, base_lr=cfg.trainer.base_lr)
        else:
            self.scheduler = scheduler

        # 6) Resume from checkpoints if start_batch / start_epoch request it.
        self._load_checkpoint_if_available()

        # 7) TensorBoard setup (rank 0 only).
        if cfg.trainer.tensorboard and self.rank == 0:
            self.tb_writer = SummaryWriter(log_dir=f"{cfg.exp_dir}/tensorboard")
        else:
            self.tb_writer = None

    def _unwrap_model(self):
        """Return the network wrapped by ``self.model``, or ``self.model`` itself.

        DeepSpeed engines and DDP wrappers expose the wrapped network as
        ``.module``; when that attribute is absent the object is returned as is.
        """
        return self.model.module if hasattr(self.model, "module") else self.model

    def _load_checkpoint_if_available(self):
        """Resume training state when ``start_batch > 0`` or ``start_epoch > 1``.

        Two checkpoints are loaded. The first is the DeepSpeed engine
        checkpoint, tagged ``ds-checkpoint-<batch>`` or ``ds-epoch-<epoch>``.
        The second is the non-DeepSpeed ``.pt`` file (``checkpoint-<batch>.pt``
        or ``epoch-<epoch>.pt``), which is loaded into ``model_avg`` and
        provides ``batch_idx_train``. ``start_batch`` takes precedence over
        ``start_epoch``.

        Raises:
            AssertionError: if the expected ``.pt`` file does not exist.
        """
        trainer_cfg = self.cfg.trainer
        if trainer_cfg.start_batch > 0:
            ds_tag = f'ds-checkpoint-{self.global_step}'
            filename = Path(self.cfg.exp_dir) / f"checkpoint-{trainer_cfg.start_batch}.pt"
        elif trainer_cfg.start_epoch > 1:
            ds_tag = f'ds-epoch-{trainer_cfg.start_epoch-1}'
            filename = Path(self.cfg.exp_dir) / f"epoch-{trainer_cfg.start_epoch-1}.pt"
        else:
            # Fresh run: nothing to resume.
            return

        # Restore the DeepSpeed engine state first.
        self.model.load_checkpoint(self.cfg.exp_dir, tag=ds_tag)

        # Then load the checkpoint kept outside DeepSpeed (averaged model, step).
        assert filename.is_file(), f"{filename} does not exist!"
        checkpoints = load_checkpoint(
            filename,
            model_avg=self.model_avg,
        )
        self.global_step = checkpoints["batch_idx_train"]
        if trainer_cfg.scheduler == 'noam':
            # Rebuild the Noam scheduler after the engine state has been restored.
            self.scheduler = Noam(self.model.optimizer, trainer_cfg.warmup_batches, base_lr=trainer_cfg.base_lr)

    def _set_batch_count(self, batch_count):
        """Write ``batch_count`` into every submodule that has such an attribute.

        Also sets ``module.name`` to the submodule's qualified name for every
        submodule that has a ``name`` attribute.
        """
        for name, module in self._unwrap_model().named_modules():
            if hasattr(module, "batch_count"):
                module.batch_count = batch_count
            if hasattr(module, "name"):
                module.name = name

    def _get_adjusted_batch_count(self):
        """Return the batch count rescaled to the reference batch duration.

        This is the number of batches that would have been used so far with
        batches of ``ref_duration``, given ``max_duration * world_size`` per
        global step, offset by ``initial_batch_count``. Used by
        ``_set_batch_count()``.
        """
        return (
            self.global_step
            * (self.cfg.data.max_duration * self.world_size)
            / self.cfg.trainer.ref_duration
        ) + self.cfg.trainer.initial_batch_count

    def _compute_loss(self, batch: dict, is_training: bool):
        """
        Compute loss given the model and its inputs.

        Args:
        batch:
            A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
            for the content in it.
        is_training:
            True for training. False for validation. When it is True, this
            function enables autograd during computation; when it is False, it
            disables autograd.

        Returns:
            ``(loss, info)``: the weighted sum of the loss terms returned by
            the model, and a ``MetricsTracker`` with per-term values and
            frame counts.

        Raises:
            RuntimeError: if the model returned no (truthy) loss term at all.
        """
        device = self.device
        feature = batch["inputs"]
        # at entry, feature is (N, T, C)
        assert feature.ndim == 3
        # Match the input dtype to the engine's mixed-precision mode.
        if self.model.fp16_enabled():
            feature = feature.to(torch.half)
        elif self.model.bfloat16_enabled():
            feature = feature.to(torch.bfloat16)
        feature = feature.to(device)

        supervisions = batch["supervisions"]
        feature_lens = supervisions["num_frames"].to(device)

        batch_idx_train = self.global_step
        warm_step = self.cfg.trainer.rnnt_warm_step

        texts = supervisions["text"]
        langs = [c.supervisions[0].language for c in supervisions["cut"]]
        lang_ids = torch.tensor([self.language_to_id[lang] for lang in langs]).to(device)
        y = self.tokenizer.encode(texts)
        y = k2.RaggedTensor(y).to(device)

        with torch.set_grad_enabled(is_training):
            simple_loss, pruned_loss, ctc_loss, _, lid_output, lid_loss = self.model(
                x=feature,
                x_lens=feature_lens,
                y=y,
                prune_range=self.cfg.trainer.prune_range,
                am_scale=self.cfg.trainer.am_scale,
                lm_scale=self.cfg.trainer.lm_scale,
                language=lang_ids,
            )

            loss = 0.0

            if simple_loss:
                s = self.cfg.trainer.simple_loss_scale
                # take down the scale on the simple loss from 1.0 at the start
                # to params.simple_loss scale by warm_step.
                simple_loss_scale = (
                    s
                    if batch_idx_train >= warm_step
                    else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
                )
                # ramp the pruned loss scale from 0.1 up to 1.0 by warm_step.
                pruned_loss_scale = (
                    1.0
                    if batch_idx_train >= warm_step
                    else 0.1 + 0.9 * (batch_idx_train / warm_step)
                )

                loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss

            if ctc_loss:
                loss += self.cfg.trainer.ctc_loss_scale * ctc_loss

            if lid_loss:
                loss += self.cfg.trainer.lid_loss_scale * lid_loss

            # if self.cfg.trainer.balance_loss_scale > 0:
            #     loss += self.cfg.trainer.balance_loss_scale * balance_loss

            # if self.cfg.trainer.specialization_loss_scale > 0:
            #     loss += self.cfg.trainer.specialization_loss_scale * specialization_loss

        if not torch.is_tensor(loss):
            # Every loss term was falsy, so nothing was accumulated.
            raise RuntimeError(
                "The model returned no loss terms (simple/pruned, CTC or LID); "
                "cannot compute a training loss."
            )
        assert loss.requires_grad == is_training

        info = MetricsTracker()
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            info["frames"] = (feature_lens // 4).sum().item()

        # Note: We use reduction=sum while computing the loss.
        info["loss"] = loss.detach().cpu().item()
        if simple_loss:
            info["simple_loss"] = simple_loss.detach().cpu().item()
            info["pruned_loss"] = pruned_loss.detach().cpu().item()
        if ctc_loss:
            info["ctc_loss"] = ctc_loss.detach().cpu().item()
        if lid_loss:
            info["utterances"] = feature.shape[0]
            info["lid_loss"] = lid_loss.detach().cpu().item() * info["frames"] # it will be normalized by num_frames for tb and log
            # calculate_accuracy lives on the wrapped network, not on the engine.
            accuracy = self._unwrap_model().calculate_accuracy(lid_output, langs)
            info["lid_acc"] = accuracy * info["frames"] # it will be normalized by num_frames for tb and log
        # if self.cfg.trainer.balance_loss_scale > 0:
        #     info["balance_loss"] = balance_loss.detach().cpu().item() * info["frames"]
        # if self.cfg.trainer.specialization_loss_scale > 0:
        #     info["specialization_loss"] = specialization_loss.detach().cpu().item() * info["frames"]

        return loss, info

    def train_one_epoch(self, epoch):
        """Run one pass over ``self.train_dl``.

        For each batch this method does the following:

        - forward, backward and optimizer step through the DeepSpeed engine;
        - a Noam ``step_batch`` when that scheduler is used;
        - a rank-0 model-average update every ``average_period`` steps;
        - validation every ``valid_interval`` steps, with checkpointing every
          ``save_every_n`` validations;
        - rank-0 logging every ``log_interval`` batches.
        """
        self.model.train()
        tot_loss = MetricsTracker()
        trainer_cfg = self.cfg.trainer

        for batch_idx, batch in enumerate(self.train_dl):
            # this is to make sure the training setting is consistent for different batch_size/max_duration
            if batch_idx % 10 == 0:
                self._set_batch_count(self._get_adjusted_batch_count())
            self.global_step += 1
            batch_size = len(batch["supervisions"]["text"])
            loss, loss_info = self._compute_loss(
                batch=batch,
                is_training=True,
            )
            # summary stats (exponentially decayed running totals)
            tot_loss = (tot_loss * (1 - 1 / trainer_cfg.reset_interval)) + loss_info
            # NOTE: We use reduction==sum and loss is computed over utterances
            # in the batch and there is no normalization to it so far.

            # 1) Backward pass using DeepSpeed.
            self.model.backward(loss)

            # 2) DeepSpeed optimizer step (which internally calls optimizer.step() and zero_grad).
            self.model.step()

            if trainer_cfg.scheduler == 'noam':
                self.scheduler.step_batch(self.global_step)

            # keep track of the averaged_model that is a moving average of the model
            # it could save time when we want to take the averaged checkpoint during inference
            if (
                self.rank == 0
                and self.global_step > 0
                and self.global_step % trainer_cfg.average_period == 0
            ):
                update_averaged_model(
                    average_period=trainer_cfg.average_period,
                    batch_idx_train=self.global_step,
                    model_cur=self._unwrap_model(),
                    model_avg=self.model_avg,
                )

            # validate and save the model
            if (
                self.global_step > 0
                and self.global_step % trainer_cfg.valid_interval == 0
            ):
                self.validate_model(epoch)
                if self.global_step % (trainer_cfg.save_every_n * trainer_cfg.valid_interval) == 0:
                    # save deepspeed model engine, they will be loaded for continual training
                    self.model.save_checkpoint(self.cfg.exp_dir, tag=f'ds-checkpoint-{self.global_step}')
                    # these are for general non-deepspeed things like model state dict that would be loaded for inference
                    save_checkpoint_with_global_batch_idx(
                        out_dir=self.cfg.exp_dir,
                        global_batch_idx=self.global_step,
                        model_avg=self.model_avg,
                        params=self.cfg.model,
                        sampler=self.train_dl.sampler,
                        rank=self.rank,
                    )

                    # only keep the latest k checkpoints to save disk space
                    remove_ds_checkpoints(self.cfg.exp_dir, trainer_cfg.keep_last_k, self.rank)
                    remove_checkpoints(
                        out_dir=self.cfg.exp_dir,
                        topk=trainer_cfg.keep_last_k,
                        rank=self.rank,
                    )

            if batch_idx % trainer_cfg.log_interval == 0 and self.rank == 0:
                cur_lr = max(self.scheduler.get_last_lr())

                logging.info(
                    f"Epoch {epoch}, "
                    f"batch {batch_idx}, loss[{loss_info}], "
                    f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                    f"lr: {cur_lr:.2e}, "
                )

                if self.tb_writer is not None:
                    self.tb_writer.add_scalar(
                        "train/learning_rate", cur_lr, self.global_step
                    )

                    loss_info.write_summary(
                        self.tb_writer, "train/current_", self.global_step
                    )
                    tot_loss.write_summary(self.tb_writer, "train/tot_", self.global_step)

    def validate_model(self, epoch):
        """Run the validation process.

        Accumulates losses separately for each loader in ``self.valid_dl``,
        all-reduces them across ranks when ``world_size > 1``, logs the result
        on rank 0, and puts the model back in training mode afterwards.
        """
        self.model.eval()
        with torch.no_grad():
            for i, valid_dl_i in enumerate(self.valid_dl):
                tot_loss = MetricsTracker()
                for batch in valid_dl_i:
                    loss, loss_info = self._compute_loss(
                        batch=batch,
                        is_training=False,
                    )
                    assert loss.requires_grad is False
                    tot_loss = tot_loss + loss_info

                if self.world_size > 1:
                    # Use self.device rather than loss.device: `loss` is
                    # unbound (UnboundLocalError) if this loader was empty.
                    tot_loss.reduce(self.device)

                if self.rank == 0:
                    logging.info(f"Epoch {epoch}, global batch {self.global_step}, validation: {tot_loss}")
                    if self.tb_writer is not None:
                        tot_loss.write_summary(
                            self.tb_writer, f"train/valid_{i}", self.global_step
                        )

        self.model.train()

    def run(self):
        """
        Entry point to run the entire training process.

        Epochs are 1-based (as in icefall). After each epoch, all ranks save
        the DeepSpeed engine checkpoint. Rank 0 also writes ``epoch-<n>.pt``
        with the averaged model.
        """
        num_epochs = self.cfg.trainer.num_epochs
        start_epoch = self.cfg.trainer.start_epoch
        for epoch in range(start_epoch, num_epochs + 1):  # start from 1 instead of 0 in icefall
            if self.cfg.trainer.scheduler == 'noam':
                self.scheduler.step_epoch(epoch - 1)
            # fix_random_seed(params.seed + epoch - 1)
            self.train_dl.sampler.set_epoch(epoch - 1)
            if self.tb_writer is not None:
                self.tb_writer.add_scalar("train/epoch", epoch, self.global_step)

            self.train_one_epoch(epoch)
            self.model.save_checkpoint(self.cfg.exp_dir, tag=f'ds-epoch-{epoch}')  # save deepspeed model engine
            if self.rank == 0:
                filename = Path(self.cfg.exp_dir) / f"epoch-{epoch}.pt"
                save_checkpoint(
                    filename=filename,
                    params=self.cfg.model,
                    batch_idx_train=self.global_step,
                    model_avg=self.model_avg,
                    sampler=self.train_dl.sampler,
                    rank=self.rank,
                )