from yaml import load, Loader
import argparse
from datetime import datetime
from datetime import timedelta
import os
import os.path as osp
from copy import deepcopy
import json
import yaml
from pathlib import Path
import math

import torch
from torch.utils.tensorboard import SummaryWriter
from torch import distributed as dist

from accelerate import Accelerator, DistributedType
from accelerate.logging import get_logger
from accelerate.utils import (
    DeepSpeedPlugin,
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
    ProjectConfiguration,
    set_seed,
    gather_object,
)

import transformers
import diffusers
from diffusers.training_utils import cast_training_params
from diffusers.optimization import get_scheduler

from lib.utils.logging import print_and_save_logging
from lib.utils.optimizer_utils import get_optimizer, gradient_norm
from lib.utils.misc import Dict


# Module-level logger shared by the trainer methods below; created through
# accelerate's `get_logger` and configured to emit INFO-level messages and above.
logger = get_logger('acwm_cosmos2_trainer')
logger.setLevel('INFO')


class State:
    """Container for storing training state variables (configurations, model info, etc.).

    Every field is declared as a class-level default. ``BaseTrainer`` creates a
    single instance and overwrites individual fields as its setup methods run.
    """
    # Training state
    seed: int = None  # copied from args.seed in _init_distributed when a seed is configured
    model_name: str = None  # copied from args.model_name at the end of BaseTrainer.__init__
    accelerator: Accelerator = None  # created in _init_distributed
    weight_dtype: torch.dtype = None  # derived from the accelerator's mixed-precision mode
    train_epochs: int = None  # copied from args.train_epochs in prepare_optimizer
    train_steps: int = None  # copied from args.train_steps, or derived from epochs when None
    overwrote_max_train_steps: bool = False  # True when train_steps was derived from epochs
    num_trainable_parameters: int = 0  # element count of the parameters handed to the optimizer
    learning_rate: float = None  # args.lr, optionally scaled when args.scale_lr is set
    train_batch_size: int = None  # not assigned anywhere in this part of the file
    generator: torch.Generator = None  # not assigned anywhere in this part of the file

    # Hub state (for model uploading, if used)
    repo_id: str = None
    # Artifacts state (output directories)
    output_dir: str = None  # assigned in _init_directories_and_repositories


class BaseTrainer(object):
    """Base trainer class for initializing distributed training, logging, and optimization components."""
    def __init__(self, config_file, val_only=False):
        """Load the YAML configuration and set up the distributed environment.

        Args:
            config_file: Path to a YAML file whose top-level mapping becomes
                the attributes of ``self.args``.
            val_only: When True, logging and output-directory initialization
                are skipped (only the accelerator is set up).

        Raises:
            TypeError: If the YAML document is empty or not a top-level mapping.
        """
        # Read the config inside a context manager so the file handle is closed
        # even when parsing fails.
        # NOTE(security): PyYAML's full `Loader` is not restricted to plain data
        # types; only load configuration files from trusted sources.
        with open(config_file, 'r') as config_stream:
            cd = load(config_stream, Loader=Loader)

        # An empty file parses to None; fail with an explicit message instead of
        # the opaque error `Namespace(**None)` would produce.
        if not isinstance(cd, dict):
            raise TypeError(
                f"Config file '{config_file}' must contain a top-level mapping, "
                f"got {type(cd).__name__}."
            )

        # Convert config dict to argparse Namespace for easy attribute access
        args = argparse.Namespace(**cd)
        # YAML parses values such as `1e-4` as strings, so force the numerical
        # hyper-parameters to float before anything consumes them.
        args.lr = float(args.lr)
        args.epsilon = float(args.epsilon)
        args.weight_decay = float(args.weight_decay)

        self.args = args
        self.state = State()  # Initialize training state container
        self.val_only = val_only

        # Initialize distributed training (Accelerator, mixed precision, etc.)
        self._init_distributed()
        # Initialize logging and output directories (skipped if only validation is run)
        if not val_only:
            self._init_logging()
            self._init_directories_and_repositories()

        # Set model name from config
        self.state.model_name = self.args.model_name

    def _init_logging(self):
        """Set up logging verbosity, the run folder, the config dump and TensorBoard.

        The main process creates the run folder (``args.sub_folder`` when set,
        otherwise a timestamp), writes ``config.yaml`` into it and opens a
        ``SummaryWriter``. When ``torch.distributed`` is initialized, the folder
        path is then broadcast from rank 0 so that every process hands the same
        path to ``print_and_save_logging``.
        """
        accelerator = self.state.accelerator

        # Set logging verbosity: main process shows warnings/info, others only show errors
        if accelerator.is_local_main_process:
            transformers.utils.logging.set_verbosity_warning()
            diffusers.utils.logging.set_verbosity_info()
        else:
            transformers.utils.logging.set_verbosity_error()
            diffusers.utils.logging.set_verbosity_error()

        is_main = accelerator.is_main_process
        if is_main:
            # Use the configured sub-folder when given, else a timestamped folder.
            if getattr(self.args, "sub_folder", False):
                run_name = self.args.sub_folder
            else:
                run_name = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
            self.save_folder = osp.join(self.args.output_dir, run_name)

            # Create directory (ignore if it already exists)
            os.makedirs(self.save_folder, exist_ok=True)

            def to_yaml_safe(value):
                """Keep YAML-native values as they are; stringify everything else."""
                if value is None or isinstance(value, (bool, int, float, str)):
                    return value
                if isinstance(value, dict):
                    return {to_yaml_safe(k): to_yaml_safe(v) for k, v in value.items()}
                if isinstance(value, (list, tuple)):
                    return [to_yaml_safe(item) for item in value]
                return str(value)

            # Save configuration to YAML file (for reproducibility). Only values
            # YAML cannot represent are converted to strings, so numbers and
            # nested sections keep their types and the dump can be loaded again.
            if isinstance(self.args, Dict):
                args_dict = self.args.to_dict()
            else:
                args_dict = {k: to_yaml_safe(v) for k, v in vars(self.args).items()}

            with open(osp.join(self.save_folder, 'config.yaml'), "w") as file:
                yaml.dump(args_dict, file, sort_keys=False, indent=2)

            # Initialize TensorBoard writer
            self.writer = SummaryWriter(log_dir=self.save_folder)

        # Share the folder path with every process. Both sides must issue the
        # same two broadcasts in the same order: the byte length, then the bytes.
        # NOTE(review): a non-main process without an initialized process group
        # never receives `save_folder`; confirm that setup cannot occur.
        if dist.is_initialized():
            device = accelerator.device
            payload = self.save_folder.encode() if is_main else b""

            length_tensor = torch.tensor([len(payload)], device=device)
            dist.broadcast(length_tensor, src=0)

            if is_main:
                path_tensor = torch.ByteTensor(list(payload)).to(device)
            else:
                # The receive buffer is sized from the length received above.
                path_tensor = torch.empty(length_tensor.item(), dtype=torch.uint8, device=device)
            dist.broadcast(path_tensor, src=0)

            if not is_main:
                self.save_folder = bytes(path_tensor.tolist()).decode()

        # Initialize cross-process logging (saves logs to the shared save folder)
        print_and_save_logging(self.save_folder, rank=accelerator.process_index)

    def _init_distributed(self):
        """Create the Hugging Face ``Accelerator`` and derive the run-wide settings.

        Stores the accelerator, the optional seed and the weight dtype implied
        by the accelerator's mixed-precision mode on ``self.state``. When
        ``args.use_deepspeed`` is set, the total train batch size is written
        into ``args.deepspeed`` before the DeepSpeed plugin is built.
        """
        # Configure project and logging directories
        logging_dir = Path(self.args.output_dir, self.args.logging_dir)
        project_config = ProjectConfiguration(
            project_dir=self.args.output_dir,
            logging_dir=logging_dir,
        )

        # DDP kwargs: allow unused parameters (common in transformer fine-tuning)
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        # NCCL timeout configuration (prevents hanging in distributed setups)
        init_process_group_kwargs = InitProcessGroupKwargs(
            backend="nccl",
            timeout=timedelta(seconds=self.args.nccl_timeout),
        )

        # Disable mixed precision for MPS (Apple Silicon) due to limited support
        mps_available = torch.backends.mps.is_available()
        mixed_precision = "no" if mps_available else self.args.mixed_precision

        # Configure external logging tools (e.g., Weights & Biases). The string
        # "none" (any casing) disables tracking. A YAML `null` (None) or a list
        # of tracker names is passed through unchanged instead of crashing on
        # `.lower()`.
        report_to = self.args.report_to
        if isinstance(report_to, str) and report_to.lower() == "none":
            report_to = None

        # Initialize DeepSpeed plugin if enabled (for large-scale training)
        ds_plugin = None
        if getattr(self.args, "use_deepspeed", False):
            per_device_batch_size = self.args.batch_size
            world_size = int(os.environ.get("WORLD_SIZE", 1))  # Number of processes across all nodes
            grad_accum_steps = self.args.gradient_accumulation_steps

            # Calculate total effective batch size (for DeepSpeed)
            total_train_batch_size = per_device_batch_size * world_size * grad_accum_steps
            self.args.deepspeed["train_batch_size"] = total_train_batch_size

            # Initialize DeepSpeed plugin with config
            ds_plugin = DeepSpeedPlugin(
                hf_ds_config=self.args.deepspeed,
                gradient_accumulation_steps=grad_accum_steps,
            )

        # Initialize Accelerator (handles DDP, mixed precision, gradient accumulation)
        accelerator = Accelerator(
            project_config=project_config,
            gradient_accumulation_steps=self.args.gradient_accumulation_steps,
            mixed_precision=mixed_precision,
            log_with=report_to,
            kwargs_handlers=[ddp_kwargs, init_process_group_kwargs],
            deepspeed_plugin=ds_plugin,
        )

        # Disable native AMP for MPS (Apple Silicon)
        if mps_available:
            accelerator.native_amp = False

        # Store accelerator in state for global access
        self.state.accelerator = accelerator

        # Set random seed for reproducibility (if specified)
        if self.args.seed is not None:
            self.state.seed = self.args.seed
            set_seed(self.args.seed)

        # Determine weight dtype based on the accelerator's mixed precision
        # setting; any mode other than fp16/bf16 keeps full precision.
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16
        else:
            weight_dtype = torch.float32

        self.state.weight_dtype = weight_dtype

    def _init_directories_and_repositories(self):
        """Create output directories (only run by main process)."""
        if self.state.accelerator.is_main_process:
            self.args.output_dir = Path(self.args.output_dir)
            # Create parent directories if they don't exist
            self.args.output_dir.mkdir(parents=True, exist_ok=True)
            # Store output directory in state
            self.state.output_dir = self.args.output_dir


    def prepare_trainable_parameters(self):
        """Freeze components and apply memory/precision options before training.

        Raises:
            ValueError: If bfloat16 weights are requested while MPS is available.
            NotImplementedError: If ``args.prev_checkpoint`` is set, since this
                base class does not implement checkpoint loading.
        """
        logger.info("Initializing trainable parameters")

        cfg = self.args

        # With LoRA the base transformer is frozen; full fine-tuning freezes nothing.
        frozen_modules = [self.transformer] if cfg.train_type == "lora" else []
        for module in frozen_modules:
            if module is None:
                continue
            module.requires_grad_(False)

        # bfloat16 weights are rejected when running on MPS (Apple Silicon).
        if torch.backends.mps.is_available() and self.state.weight_dtype == torch.bfloat16:
            raise ValueError(
                "Mixed precision training with bfloat16 is not supported on MPS. "
                "Please use fp16 (recommended) or fp32 instead."
            )

        # Gradient checkpointing is opt-in through the config.
        if cfg.gradient_checkpointing:
            self.transformer.enable_gradient_checkpointing()

        # TF32 matmuls are switched on only when requested and CUDA is present.
        if cfg.allow_tf32 and torch.cuda.is_available():
            torch.backends.cuda.matmul.allow_tf32 = True

        # Resuming from a previous checkpoint is left to subclasses.
        if cfg.prev_checkpoint is not None:
            raise NotImplementedError("Checkpoint loading is not implemented in BaseTrainer.")

    def prepare_optimizer(self):
        """Select trainable parameters and build the optimizer and LR scheduler.

        ``args.train_mode`` decides which transformer parameters are trained:

        * ``"action_only"``: parameters whose name contains ``"action_"``.
        * ``"video_only"``: parameters whose name contains neither
          ``"action_"`` nor ``"lang_"``.
        * ``"all"`` / ``"action_full"``: every parameter.

        The results are stored on ``self.optimizer`` and ``self.lr_scheduler``;
        step counts, learning rate and parameter count go on ``self.state``.

        Raises:
            NotImplementedError: If ``args.train_mode`` is not one of the above.
        """
        logger.info("Initializing optimizer and learning rate scheduler")

        # Get training mode (determines which parameters are optimized)
        train_mode = self.args.train_mode
        accelerator = self.state.accelerator

        # Store training epochs/steps in state
        self.state.train_epochs = self.args.train_epochs
        self.state.train_steps = self.args.train_steps

        # Calculate effective learning rate (if scaling by batch size is enabled)
        self.state.learning_rate = self.args.lr
        if self.args.scale_lr:
            self.state.learning_rate = (
                self.state.learning_rate
                * self.args.gradient_accumulation_steps
                * self.args.batch_size
                * accelerator.num_processes
            )

        # Resolve the name filter first so an unsupported mode fails before any
        # parameter is modified.
        if train_mode == 'action_only':
            def is_trainable(name):
                return 'action_' in name
        elif train_mode == "video_only":
            def is_trainable(name):
                return 'action_' not in name and "lang_" not in name
        elif train_mode in ("all", "action_full"):
            def is_trainable(name):
                return True
        else:
            raise NotImplementedError(f"Training mode '{train_mode}' is not supported.")

        # Set requires_grad on every parameter and collect the trainable ones.
        transformer_trainable_params = []
        for name, param in self.transformer.named_parameters():
            param.requires_grad = is_trainable(name)
            if param.requires_grad:
                transformer_trainable_params.append(param)

        # Upcast trainable parameters to float32 for stability in mixed precision
        # training. This has to run after the requires_grad flags are final:
        # `cast_training_params` only touches parameters that require gradients.
        if self.args.mixed_precision == "fp16":
            cast_training_params([self.transformer], dtype=torch.float32)

        # Log number of trainable parameters
        num_trainable_params = sum(p.numel() for p in transformer_trainable_params)
        logger.info(f'Total trainable parameters: {num_trainable_params}')
        # Store number of trainable parameters in state
        self.state.num_trainable_parameters = num_trainable_params

        # Wrap parameters for optimizer (with specific learning rate)
        params_to_optimize = [
            {
                "params": transformer_trainable_params,
                "lr": self.state.learning_rate,
            }
        ]

        # Initialize optimizer (uses custom `get_optimizer` utility). The
        # optimizer-level learning rate is the same effective value used in the
        # parameter group, so the two cannot disagree when `scale_lr` is set.
        optimizer = get_optimizer(
            params_to_optimize=params_to_optimize,
            optimizer_name=self.args.optimizer,
            learning_rate=self.state.learning_rate,
            beta1=self.args.beta1,
            beta2=self.args.beta2,
            beta3=self.args.beta3,
            epsilon=self.args.epsilon,
            weight_decay=self.args.weight_decay,
            use_8bit=self.args.optimizer_8bit,
            use_torchao=self.args.optimizer_torchao,
        )

        # Calculate number of update steps per epoch (for scheduler)
        num_update_steps_per_epoch = math.ceil(
            len(self.train_dataloader) / self.args.gradient_accumulation_steps
        )

        # Override train steps if not specified (use epochs * steps per epoch)
        if self.state.train_steps is None:
            self.state.train_steps = self.state.train_epochs * num_update_steps_per_epoch
            self.state.overwrote_max_train_steps = True

        # Initialize learning rate scheduler; warmup and total steps are both
        # multiplied by the number of processes.
        lr_scheduler = get_scheduler(
            name=self.args.lr_scheduler,
            optimizer=optimizer,
            num_warmup_steps=self.args.lr_warmup_steps * accelerator.num_processes,
            num_training_steps=self.state.train_steps * accelerator.num_processes,
            num_cycles=self.args.lr_num_cycles,
            power=self.args.lr_power,
        )

        # Store optimizer and scheduler for training loop
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler


    def prepare_for_training(self):
        """Pass model, optimizer, dataloader and scheduler through ``accelerator.prepare``.

        The four objects returned by the accelerator replace the originals on
        ``self``, in the same order they were passed in.
        """
        accelerator = self.state.accelerator
        prepared = accelerator.prepare(
            self.transformer,
            self.optimizer,
            self.train_dataloader,
            self.lr_scheduler,
        )
        (
            self.transformer,
            self.optimizer,
            self.train_dataloader,
            self.lr_scheduler,
        ) = prepared

    def prepare_trackers(self):
        """Start the accelerator's experiment trackers for this run.

        The tracker name is ``args.tracker_name`` when it is truthy and
        ``"acwm_cosmos2"`` otherwise; the attribute dict of ``self.args`` is
        passed as the run configuration.
        """
        logger.info("Initializing trackers")

        # Fall back to the default project name when none is configured.
        tracker_name = self.args.tracker_name
        if not tracker_name:
            tracker_name = "acwm_cosmos2"

        # NOTE(review): the raw args dict may hold values some tracker backends
        # reject (nested dicts, Path objects) - confirm against the configured
        # `report_to` backend.
        self.state.accelerator.init_trackers(tracker_name, config=vars(self.args))
