import os
import json
import copy
import logging
from pathlib import Path
from typing import Optional, Dict, Any

import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from ..optim.scheduler import Eden, Eden2, Eden3
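# NOTE(review): `Noam` is still referenced by BaseTrainer.build_scheduler
# (the 'noam' branch); that branch cannot work while this import is disabled.
#from ...optim.noam import Noam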
from ..optim.optimizer import ScaledAdam
from ..optim.utils import get_parameter_groups_with_lrs

# Model avg and checkpointing
from ..utils.checkpoint import (
    load_checkpoint,
    save_checkpoint,  
    remove_checkpoints,
    update_averaged_model,
    load_model_params,
)

# LoRA injection (if used)
from ..peft.lora.utils import (
    inject_lora_to_model, 
    mark_only_lora_as_trainable,
    register_backward_hook_for_extra_tokens
)

# Metrics and tracking
from ..utils.metric_tracker import MetricsTracker

# Data loader types (if needed in signatures)
from torch.utils.data import DataLoader


class BaseTrainer:
    def __init__(self, cfg, model, rank=0, local_rank=0, world_size=1):
        """Build the complete training state for one process.

        Steps, in order: optional weight initialization from a checkpoint,
        optional LoRA injection, optional module freezing, optional creation
        of the averaged model, device placement (plus DDP wrapping when
        ``world_size > 1``), creation of the grad scaler, optimizer,
        scheduler and data loaders, optional TensorBoard writer (rank 0
        only), and finally resumption from a training checkpoint.

        Args:
            cfg: Experiment configuration. Read here: ``cfg.exp_dir``,
                ``cfg.trainer`` (``use_fp16``, ``start_batch``,
                ``tensorboard`` and the optional ``freeze_modules``,
                ``use_averaged_model``, ``find_unused_parameters``) and the
                optional ``cfg.lora``.
            model: The model to train.
            rank: Global rank of this process. Rank 0 owns the TensorBoard
                writer and the averaged model.
            local_rank: Index of the CUDA device used by this process.
            world_size: Total number of processes; DDP is used when > 1.
        """
        self.cfg = cfg
        self.exp_dir = cfg.exp_dir
        self.model = model
        self.rank = rank
        self.local_rank = local_rank
        self.world_size = world_size
        self.device = torch.device("cuda", local_rank)
        self.use_fp16 = cfg.trainer.use_fp16
        self.global_step = cfg.trainer.start_batch
        self.tb_writer = None
        # Always define model_avg: the resume and checkpoint-saving code reads
        # it unconditionally, including when model averaging is disabled.
        self.model_avg = None

        # Optionally initialize (part of) the model from a checkpoint.
        self.initialize_model_from_checkpoint()

        # (Optional) add LoRA to the model.
        # NOTE(review): _init_peft reads self.tokenizer, which this class
        # never assigns; presumably a subclass sets it before calling this
        # constructor - confirm.
        if hasattr(cfg, "lora") and cfg.lora.use_lora:
            self._init_peft()

        if cfg.trainer.get("freeze_modules"):
            self._freeze_modules()

        # Read with a default, like the other users of this option.
        if cfg.trainer.get("use_averaged_model", False):
            self._init_model_avg()

        self.model.to(self.device)
        if self.world_size > 1:
            self.model = DDP(
                self.model,
                device_ids=[self.local_rank],
                find_unused_parameters=cfg.trainer.get("find_unused_parameters", True)
            )

        self.scaler = GradScaler(enabled=self.use_fp16)
        self.optimizer = self.build_optimizer()
        self.scheduler = self.build_scheduler()
        self.train_dl, self.valid_dl = self.build_dataloaders(cfg)

        if self.rank == 0 and cfg.trainer.tensorboard:
            self.tb_writer = SummaryWriter(log_dir=f"{self.exp_dir}/tensorboard")

        # Restore model/optimizer/scheduler/scaler state when resuming.
        self.resume_training_from_checkpoint()

        num_param = sum(p.numel() for p in self.model.parameters())
        num_trainable_param = sum(
            p.numel() for p in self.model.parameters() if p.requires_grad
        )
        logging.info(f"Number of model parameters: {num_param}")
        logging.info(f"Number of trainable model parameters: {num_trainable_param}")
        
    def _init_model_avg(self):
        """Create the float64 copy of the model used for weight averaging.

        Rank 0 stores in ``self.model_avg`` a deep copy of the model (with a
        ``.module`` wrapper stripped, if present) cast to ``torch.float64``.
        Every other rank stores ``None``.
        """
        if self.rank != 0:
            self.model_avg = None
            return

        # Copy the wrapped module when the model exposes one (e.g. DDP).
        unwrapped = getattr(self.model, "module", self.model)
        self.model_avg = copy.deepcopy(unwrapped).to(torch.float64)
            
    def _init_peft(self):
        """Inject LoRA into the model and re-enable the token embedding.

        Loads the LoRA configuration (a JSON file at ``cfg.lora.config``),
        injects LoRA into ``self.model`` and marks only the LoRA parameters
        as trainable. The decoder token embedding is then made trainable
        again, and a backward hook is registered for the extra language and
        special tokens listed in ``self.model.config``.
        """
        # JSON is UTF-8 by specification; do not depend on the platform's
        # default locale encoding.
        with open(self.cfg.lora.config, encoding="utf-8") as f:
            lora_config = json.load(f)
        logging.info(f"LoRA configuration: {lora_config}")
        inject_lora_to_model(self.model, lora_config)
        mark_only_lora_as_trainable(self.model, 'none')

        # when using LoRA (or other scenarios with backbone frozen),
        # the extra LID/tokens in decoder.token_embedding should be active.
        for p in self.model.decoder.token_embedding.parameters():
            p.requires_grad = True

        # NOTE(review): self.tokenizer is not assigned anywhere in this
        # class; presumably a subclass sets it before this runs - confirm.
        model_config = self.model.config
        new_tokens = [
            self.tokenizer.to_language_token(lang)
            for lang in model_config.extra_languages
        ]
        new_tokens += [
            getattr(self.tokenizer, token_name)
            for token_name in model_config.extra_tokens
        ]
        # Presumably limits embedding updates to the new tokens - verify
        # against the helper's implementation.
        register_backward_hook_for_extra_tokens(self.model, new_tokens)
        
    def _freeze_modules(self):
        """
        Freeze parameters of the submodules listed in the trainer config.

        Each entry of ``cfg.trainer.freeze_modules`` names an attribute of
        ``self.model``. Dotted names (e.g. ``"encoder.embed"``) are resolved
        one attribute at a time. Every parameter of a resolved submodule gets
        ``requires_grad = False``. Names that cannot be resolved are logged
        as a warning and skipped.
        """
        for name in self.cfg.trainer.freeze_modules:
            # Walk the (possibly dotted) attribute path from the model root.
            submodule = self.model
            for attr in name.split("."):
                submodule = getattr(submodule, attr, None)
                if submodule is None:
                    break

            if submodule is None:
                # Use logging (not print) so the warning reaches the same
                # sinks as every other message emitted by the trainer.
                logging.warning(f"model has no attribute '{name}'; nothing frozen for it")
                continue

            logging.info(f'freeze {name} for model in training')
            for param in submodule.parameters():
                param.requires_grad = False

    def build_optimizer(self):
        """Create the optimizer selected by ``cfg.trainer.optimizer``.

        Supported names:
          * ``"scaled_adam"``: ``ScaledAdam`` over named parameter groups,
            with ``clipping_scale=2.0``.
          * ``"adamw"``: ``torch.optim.AdamW`` over the parameters that
            currently require gradients.
          * ``"adam"``: ``torch.optim.Adam`` over all parameters with fixed
            betas, eps and weight decay.

        The learning rate is ``cfg.trainer.base_lr`` in every case.

        Raises:
            ValueError: If the configured optimizer name is not supported.
        """
        name = self.cfg.trainer.optimizer
        base_lr = self.cfg.trainer.base_lr

        if name == 'scaled_adam':
            param_groups = get_parameter_groups_with_lrs(
                self.model, lr=base_lr, include_names=True
            )
            return ScaledAdam(param_groups, lr=base_lr, clipping_scale=2.0)

        if name == 'adamw':
            trainable = [p for p in self.model.parameters() if p.requires_grad]
            return torch.optim.AdamW(trainable, lr=base_lr)

        if name == 'adam':
            adam_kwargs = dict(
                lr=base_lr,
                betas=(0.9, 0.999),
                eps=1e-8,
                weight_decay=1e-6,
            )
            return torch.optim.Adam(self.model.parameters(), **adam_kwargs)

        raise ValueError(f"Unsupported optimizer: {name}")

    def build_scheduler(self):
        """Create the LR scheduler selected by ``cfg.trainer.scheduler``.

        ``"eden"`` picks one of three variants:
          * ``Eden`` when ``cfg.data.use_infinite_dataset`` is false
            (uses ``cfg.trainer.lr_epochs``);
          * ``Eden3`` for an infinite dataset with a positive
            ``cfg.trainer.lr_steps_per_epoch``;
          * ``Eden2`` for an infinite dataset otherwise.

        ``"noam"`` picks ``Noam``.

        Returns:
            A scheduler constructed around ``self.optimizer``.

        Raises:
            NameError: If ``"noam"`` is requested while ``Noam`` is not
                imported in this module.
            ValueError: If the configured scheduler name is not supported.
        """
        trainer_cfg = self.cfg.trainer
        sch_name = trainer_cfg.scheduler
        warmup_batches = trainer_cfg.warmup_batches
        lr_batches = trainer_cfg.lr_batches
        base_lr = trainer_cfg.base_lr

        if sch_name == 'eden':
            if not self.cfg.data.use_infinite_dataset:
                return Eden(self.optimizer, lr_batches, trainer_cfg.lr_epochs, warmup_batches)
            if trainer_cfg.lr_steps_per_epoch > 0:
                return Eden3(self.optimizer, lr_batches, trainer_cfg.lr_steps_per_epoch, warmup_batches)
            return Eden2(self.optimizer, lr_batches, warmup_batches)

        if sch_name == 'noam':
            # The module-level import of Noam is currently commented out.
            # Fail with an actionable message instead of an opaque
            # "name 'Noam' is not defined"; works as-is once it is restored.
            try:
                noam_cls = Noam
            except NameError:
                raise NameError(
                    "Scheduler 'noam' was requested, but `Noam` is not imported "
                    "in this module; restore its import to use it."
                ) from None
            return noam_cls(self.optimizer, warmup_batches, base_lr)

        raise ValueError(f"Unsupported scheduler: {sch_name}")
        
    def initialize_model_from_checkpoint(self):
        """Optionally load initial model weights from a checkpoint.

        Does nothing unless ``cfg.trainer.initialization`` has a non-``None``
        ``"checkpoint"`` entry. Otherwise the parameters are loaded into
        ``self.model`` (restricted to ``"init_modules"`` when given). On
        rank 0 with ``use_averaged_model`` enabled, ``self.model_avg`` is
        then set to a float64 deep copy of the loaded model.
        """
        init_cfg = self.cfg.trainer.initialization
        if init_cfg.get("checkpoint", None) is None:
            return

        init_ckpt = init_cfg["checkpoint"]
        init_modules = init_cfg.get("init_modules", None)

        logging.info(f"Initializing {init_modules} from: {init_ckpt}")
        load_model_params(
            model=self.model,
            ckpt_path=init_ckpt,
            init_modules=init_modules,
        )

        wants_average = self.cfg.trainer.get("use_averaged_model", False)
        if not (wants_average and self.rank == 0):
            return

        # Copy the wrapped module when the model exposes one (e.g. DDP).
        unwrapped = getattr(self.model, 'module', self.model)
        self.model_avg = copy.deepcopy(unwrapped).to(torch.float64)


    def resume_training_from_checkpoint(self):
        """Restore training state when the config asks to resume.

        The checkpoint is looked up in ``self.exp_dir``:
          * ``checkpoint-<start_batch>.pt`` when ``start_batch > 0``;
          * otherwise ``epoch-<start_epoch - 1>.pt`` when ``start_epoch > 1``;
          * otherwise nothing is loaded.

        Model, averaged model, optimizer, scheduler and scaler are passed to
        ``load_checkpoint``. ``self.global_step`` is updated from the
        checkpoint's ``"batch_idx_train"`` entry when present.

        Raises:
            FileNotFoundError: If the selected checkpoint file does not exist.
        """
        trainer_cfg = self.cfg.trainer
        if trainer_cfg.start_batch > 0:
            ckpt_name = f"checkpoint-{trainer_cfg.start_batch}.pt"
        elif trainer_cfg.start_epoch > 1:
            ckpt_name = f"epoch-{trainer_cfg.start_epoch - 1}.pt"
        else:
            # Fresh run: nothing to resume from.
            return

        resume_ckpt = Path(self.exp_dir) / ckpt_name
        if not resume_ckpt.is_file():
            raise FileNotFoundError(f"Checkpoint file not found: {resume_ckpt}")

        logging.info(f"Resuming training from: {resume_ckpt}")
        state = load_checkpoint(
            resume_ckpt,
            model=self.model,
            model_avg=self.model_avg,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            scaler=self.scaler,
        )
        self.global_step = state.get("batch_idx_train", self.global_step)

    def build_dataloaders(self, cfg):
        """Create the training and validation data loaders (abstract hook).

        Subclasses must override this method. The constructor unpacks the
        result as ``(train_dl, valid_dl)``; ``validate`` iterates over
        ``valid_dl`` and treats each item as one validation loader.

        Args:
            cfg: The experiment configuration passed to the constructor.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement build_dataloaders(cfg)"
        )

    def train_one_epoch(self, epoch: int):
        """Train on every batch produced by ``self.train_dl`` once.

        For each batch: advance ``self.global_step``, run the
        forward/backward/optimizer step, accumulate metrics, then call the
        periodic hooks (model averaging, validation + checkpointing, AMP
        grad-scale check, logging). The CUDA cache is emptied at the end.

        Args:
            epoch: Epoch number, forwarded to validation and logging.
        """
        self.model.train()
        running_metrics = MetricsTracker()

        for loader_idx, batch in enumerate(self.train_dl):
            # Refresh the model's batch count on every 10th loader batch.
            if loader_idx % 10 == 0:
                self._maybe_update_batch_count()

            # With an infinite dataset the batch index used by the periodic
            # hooks is the global step (its value before the increment).
            if self.cfg.data.use_infinite_dataset:
                batch_idx = self.global_step
            else:
                batch_idx = loader_idx

            self.global_step += 1
            batch_size = batch["inputs"].size(0)

            _, batch_metrics = self._forward_backward_optimize(batch)

            # Accumulate into the running (epoch-level) metrics.
            running_metrics.update(batch_metrics, self.cfg.trainer.reset_interval)

            # Periodic hooks; each one decides for itself whether to act.
            self._maybe_update_model_average()
            self._maybe_validate_and_save(epoch)
            self._maybe_rescale_grad_amp(batch_idx)
            self._maybe_log_training_status(
                epoch, batch_idx, batch_size, batch_metrics, running_metrics
            )

        torch.cuda.empty_cache()

    def validate(self, epoch: int):
        """
        Runs validation on one or more validation sets.

        ``self.valid_dl`` is iterated as a sequence of loaders; metrics are
        accumulated per loader, reduced across processes when
        ``world_size > 1``, and logged (plus written to TensorBoard when a
        writer exists) on rank 0 only. The model is switched to eval mode
        for the duration and always put back in train mode afterwards, even
        if validation raises.

        Args:
            epoch: Epoch number, used only in the log message.
        """
        self.model.eval()
        try:
            with torch.no_grad():
                for i, valid_loader in enumerate(self.valid_dl):
                    total_metrics = MetricsTracker()
                    # Fallback device for the reduce below: a loader that
                    # yields no batch must not leave the device undefined.
                    reduce_device = self.device

                    # NOTE(review): unlike the training step, this forward
                    # pass is not wrapped in autocast - confirm intended.
                    for batch in valid_loader:
                        loss, batch_metrics = self._forward_one_batch(batch=batch, is_training=False)
                        assert not loss.requires_grad
                        total_metrics.update(batch_metrics)
                        reduce_device = loss.device

                    # DDP reduce
                    if self.world_size > 1:
                        total_metrics.reduce(device=reduce_device)

                    # Logging
                    if self.rank == 0:
                        logging.info(f"Epoch {epoch}, global step {self.global_step}, validation set {i}: {total_metrics}")

                        if self.tb_writer is not None:
                            total_metrics.write_summary(
                                self.tb_writer,
                                tag=f"train/valid_{i}",
                                step=self.global_step
                            )
        finally:
            # Never leave the model in eval mode for subsequent training.
            self.model.train()

                  
    def run(self):
        """
        Train from ``cfg.trainer.start_epoch`` through ``num_epochs`` (inclusive).

        Per epoch: step the scheduler's epoch counter, reseed the sampler
        (when it supports ``set_epoch``), record the epoch in TensorBoard,
        train, validate, and save an epoch checkpoint.
        """
        first_epoch = self.cfg.trainer.start_epoch
        last_epoch = self.cfg.trainer.num_epochs

        epoch = first_epoch
        while epoch <= last_epoch:
            # Scheduler and sampler use zero-based epoch numbers.
            zero_based_epoch = epoch - 1
            self.scheduler.step_epoch(zero_based_epoch)

            # Keep shuffling consistent across workers when supported.
            sampler = getattr(self.train_dl, "sampler", None)
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(zero_based_epoch)

            # Epoch marker in TensorBoard.
            if self.tb_writer and self.rank == 0:
                self.tb_writer.add_scalar("train/epoch", epoch, self.global_step)

            self.train_one_epoch(epoch)

            # Full validation pass at the end of the epoch.
            self.validate(epoch)

            # Epoch-based checkpoint.
            self._maybe_save_epoch_checkpoint(epoch)

            epoch += 1

                
    def _maybe_update_batch_count(self):
        """
        Optionally update the model's batch_count, if the model supports it.

        Does nothing when the (DDP-unwrapped) model has no
        ``set_batch_count`` method, or when ``cfg.trainer`` has no
        ``ref_duration`` (a warning is logged in that case). Otherwise::

            batch_count = global_step * (max_duration * world_size) / ref_duration

        and 10000 is added when an initialization checkpoint is configured.
        """
        model = self.model.module if isinstance(self.model, torch.nn.parallel.DistributedDataParallel) else self.model

        if not hasattr(model, "set_batch_count"):
            return

        if not hasattr(self.cfg.trainer, "ref_duration"):
            logging.warning("trainer config has no ref_duration; skipping batch count update.")
            return
        ref_duration = self.cfg.trainer.ref_duration

        batch_count = (
            self.global_step
            * (self.cfg.data.max_duration * self.world_size)
            / ref_duration
        )

        # Read the checkpoint entry as an optional mapping key, the same way
        # initialize_model_from_checkpoint does, so a plain dict or a missing
        # key does not raise.
        init_ckpt = self.cfg.trainer.initialization.get("checkpoint", None)
        if init_ckpt:
            # Presumably lets a warm-started model skip the early part of
            # batch-count-driven schedules - confirm the intended offset.
            batch_count = batch_count + 10000

        model.set_batch_count(batch_count)
        
    def _forward_backward_optimize(self, batch):
        """
        Performs one training step: forward pass under AMP autocast, backward
        pass on the scaled loss, scheduler step, optimizer step, scaler
        update, and gradient reset.

        The forward pass is delegated to ``self._forward_one_batch``, which
        is not defined in this class (presumably provided by a subclass -
        confirm).

        Args:
            batch (dict): A batch of training data.

        Returns:
            loss (Tensor): The loss returned by ``_forward_one_batch``
                (the value before gradient scaling).
            batch_metrics: The second value returned by
                ``_forward_one_batch`` (presumably a MetricsTracker - confirm).
        """
        # Autocast is enabled only when fp16 training is requested.
        with torch.cuda.amp.autocast(enabled=self.use_fp16):
            loss, batch_metrics = self._forward_one_batch(batch=batch, is_training=True)

        # Backprop and optimization step
        # The backward pass runs on the scaler-scaled loss.
        self.scaler.scale(loss).backward()
        # The scheduler is advanced with the current global step before the
        # optimizer step is taken.
        self.scheduler.step_batch(self.global_step)
        # The optimizer step goes through the scaler, which is then updated.
        self.scaler.step(self.optimizer)
        self.scaler.update()
        # Clear gradients so they do not accumulate into the next batch.
        self.optimizer.zero_grad()

        return loss, batch_metrics
    
    def _maybe_update_model_average(self):
        """
        Update the averaged model every ``cfg.trainer.average_period`` steps.

        Acts only when all of the following hold:
            - this process is rank 0,
            - ``cfg.trainer.use_averaged_model`` is enabled,
            - ``self.global_step`` is a non-zero multiple of
              ``cfg.trainer.average_period``.

        The update itself is delegated to ``update_averaged_model`` with the
        current model and ``self.model_avg``.
        """
        if self.rank != 0:
            return
        if not self.cfg.trainer.get("use_averaged_model", False):
            return

        step = self.global_step
        if step == 0 or step % self.cfg.trainer.average_period != 0:
            return

        update_averaged_model(
            average_period=self.cfg.trainer.average_period,
            batch_idx_train=self.global_step,
            model_cur=self.model,
            model_avg=self.model_avg,
        )

    def _maybe_validate_and_save(self, epoch: int):
        """
        Run validation every ``valid_interval`` steps and, on every
        ``save_every_n``-th such validation, save a step-based checkpoint.

        Validation runs on every rank; the checkpoint is written (and old
        checkpoints pruned to ``keep_last_k``) on rank 0 only.

        Args:
            epoch: Epoch number forwarded to ``validate``.
        """
        step = self.global_step
        if step == 0 or step % self.cfg.trainer.valid_interval != 0:
            return

        self.validate(epoch)

        trainer_cfg = self.cfg.trainer
        save_every = trainer_cfg.save_every_n * trainer_cfg.valid_interval
        if self.global_step % save_every != 0 or self.rank != 0:
            return

        ckpt_path = Path(self.exp_dir) / f"checkpoint-{self.global_step}.pt"
        save_checkpoint(
            filename=ckpt_path,
            model=self.model,
            model_avg=self.model_avg,
            batch_idx_train=self.global_step,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            scaler=self.scaler,
            sampler=self.train_dl.sampler,
            rank=self.rank,
        )
        logging.info(f"Saving model checkpoint to: {ckpt_path}")

        # Prune older step-based checkpoints.
        remove_checkpoints(
            out_dir=self.exp_dir,
            topk=self.cfg.trainer.keep_last_k,
            rank=self.rank,
        )

            
    def _maybe_rescale_grad_amp(self, batch_idx: int):
        """
        Check the AMP grad scale every 100 batches and nudge it upwards.

        Only active when ``use_fp16`` is set. The scale is doubled when it is
        below 8, or below 32 on every 400th batch. A warning is logged when
        the scale observed at the check was below 0.01.

        Args:
            batch_idx: Batch index used to decide whether to run the check.

        Raises:
            RuntimeError: If the observed scale is below 1e-5.
        """
        if not self.use_fp16:
            return
        if batch_idx % 100 != 0:
            return

        scale = self.scaler.get_scale()

        # Proactively raise a scale that is low or recovering too slowly.
        needs_boost = scale < 8.0
        slow_growth = scale < 32.0 and batch_idx % 400 == 0
        if needs_boost or slow_growth:
            self.scaler.update(scale * 2.0)

        # Both checks use the scale as observed before the update above.
        if scale < 0.01:
            logging.warning(f"Grad scale is small: {scale}")
        if scale < 1.0e-5:
            raise RuntimeError(f"grad_scale is too small, exiting: {scale}")
        
    def _maybe_log_training_status(
        self,
        epoch: int,
        batch_idx: int,
        batch_size: int,
        batch_metrics,
        total_metrics,
    ):
        """
        Log training progress every ``cfg.trainer.log_interval`` batches.

        Runs on rank 0 only. Emits one info line with the epoch, batch
        index, current and accumulated metrics, batch size, learning rate
        and (for fp16) the grad scale. When a TensorBoard writer exists, the
        same quantities are written as scalars at ``self.global_step``.

        Args:
            epoch: Current epoch number.
            batch_idx: Batch index used for the interval check and the log.
            batch_size: Number of items in the current batch.
            batch_metrics: Metrics of the current batch.
            total_metrics: Metrics accumulated over the epoch so far.
        """
        if self.rank != 0:
            return
        if batch_idx % self.cfg.trainer.log_interval != 0:
            return

        cur_lr = max(self.scheduler.get_last_lr())
        if self.use_fp16:
            cur_grad_scale = self.scaler.get_scale()
            scale_suffix = f"grad_scale: {cur_grad_scale}"
        else:
            cur_grad_scale = 1.0
            scale_suffix = ""

        status = (
            f"Epoch {epoch}, "
            f"batch {batch_idx}, info[{batch_metrics}], "
            f"tot_info[{total_metrics}], batch size: {batch_size}, "
            f"lr: {cur_lr:.2e}, "
        )
        logging.info(status + scale_suffix)

        writer = self.tb_writer
        if writer is None:
            return

        step = self.global_step
        writer.add_scalar("train/learning_rate", cur_lr, step)
        if self.use_fp16:
            writer.add_scalar("train/grad_scale", cur_grad_scale, step)

        batch_metrics.write_summary(writer, "train/current_", step)
        total_metrics.write_summary(writer, "train/tot_", step)

    def _maybe_save_epoch_checkpoint(self, epoch: int):
        """
        Write ``epoch-<epoch>.pt`` into ``self.exp_dir`` (rank 0 only).

        The checkpoint receives the model, averaged model, optimizer,
        scheduler, scaler, the training sampler and the current global step.

        Args:
            epoch: Epoch number used in the checkpoint file name.
        """
        if self.rank == 0:
            ckpt_path = Path(self.exp_dir) / f"epoch-{epoch}.pt"
            logging.info(f"Saving epoch checkpoint to: {ckpt_path}")

            training_state = dict(
                model=self.model,
                model_avg=self.model_avg,
                optimizer=self.optimizer,
                scheduler=self.scheduler,
                scaler=self.scaler,
                sampler=self.train_dl.sampler,
            )
            save_checkpoint(
                filename=ckpt_path,
                batch_idx_train=self.global_step,
                rank=self.rank,
                **training_state,
            )



