"""
Use SFT trainer to train rosetta model
"""

import gc
import torch
import torch.nn as nn
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.optimization import get_scheduler
from torch.optim import AdamW
from tqdm import tqdm
import os
import sys
import json
import argparse
import shutil
import wandb
import torch.distributed as dist  # Added for Distributed Data Parallel support
from torch.nn.parallel import DistributedDataParallel  # For type checking
from datetime import datetime
from typing import List, Dict, Any, Tuple, Optional
import math
import contextlib

# Add the project root to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))

from rosetta.model.wrapper import RosettaModel
from rosetta.model.projector import create_projector, save_projector
from rosetta.model.aggregator import save_aggregator, get_aggregator_class
from rosetta.train.dataset_adapters import ChatDataset, RosettaDataCollator, create_dataset, MMLUFilteredChatDataset, generate_kv_cache_index
from rosetta.train.model_utils import k_nearest_sources, last_aligned_sources

# NOTE(review): anomaly detection is switched on globally as soon as this module
# is imported. PyTorch documents this mode as a debugging aid that slows down
# autograd -- confirm it is meant to stay enabled for full training runs.
torch.autograd.set_detect_anomaly(True)

class MixedDataset(Dataset):
    """Concatenation of the rosetta_only and slm_only datasets, tagged by origin.

    Indices below ``len(rosetta_dataset)`` address the rosetta_only data and are
    labelled ``1``; all remaining indices address the slm_only data and are
    labelled ``0``.
    """

    def __init__(self, rosetta_only_path: str, slm_only_path: str, split: str = "train",
                 num_samples: Optional[int] = None):
        """
        Load both source datasets and record their combined length.

        Args:
            rosetta_only_path: Path to rosetta_only dataset
            slm_only_path: Path to slm_only dataset
            split: Dataset split passed to both underlying datasets
            num_samples: Accepted by the signature but not used by this class
        """
        self.rosetta_dataset = MMLUFilteredChatDataset(split=split, data_path=rosetta_only_path)
        self.slm_dataset = MMLUFilteredChatDataset(split=split, data_path=slm_only_path)

        # The combined length is fixed at construction time.
        self.total_len = len(self.rosetta_dataset) + len(self.slm_dataset)

    def __len__(self):
        return self.total_len

    def __getitem__(self, idx):
        rosetta_count = len(self.rosetta_dataset)

        if idx >= rosetta_count:
            # Past the rosetta block: shift the index into the slm dataset.
            return {
                'conversation': self.slm_dataset[idx - rosetta_count],
                'dataset_label': 0,  # slm_only
            }

        return {
            'conversation': self.rosetta_dataset[idx],
            'dataset_label': 1,  # rosetta_only
        }

class MixedChatDataset(Dataset):
    """Tokenizing view over a MixedDataset that carries the dataset label through."""

    def __init__(self, mixed_dataset: MixedDataset, tokenizer: AutoTokenizer, max_length: int = 2048):
        self.mixed_dataset = mixed_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.mixed_dataset)

    def __getitem__(self, idx):
        """Tokenize one conversation and mask the prompt tokens out of the labels."""
        entry = self.mixed_dataset[idx]
        messages = entry['conversation']
        dataset_label = entry['dataset_label']

        def render(turns, with_generation_prompt):
            # Chat template rendered to plain text (no tokenization yet).
            return self.tokenizer.apply_chat_template(
                turns,
                tokenize=False,
                add_generation_prompt=with_generation_prompt,
                enable_thinking=False,
            )

        def encode(text):
            return self.tokenizer(text, add_special_tokens=False)["input_ids"]

        # Prompt = first message plus generation prompt; full = whole conversation.
        prompt_text = render(messages[:1], True)
        conversation_text = render(messages, False)

        prompt_ids = encode(prompt_text)
        full_ids = encode(conversation_text)

        limit = self.max_length
        if len(full_ids) > limit:
            full_ids = full_ids[:limit]

        # Prompt positions get -100; the remaining positions keep their token ids.
        prompt_len = len(prompt_ids)
        labels = [-100] * prompt_len + full_ids[prompt_len:]
        if len(labels) > limit:
            labels = labels[:limit]

        return {
            "input_ids": full_ids,
            "labels": labels,
            "kv_cache_index": generate_kv_cache_index(prompt_len, len(full_ids)),
            "dataset_label": dataset_label,  # origin tag from MixedDataset
        }

class MixedRosettaDataCollator(RosettaDataCollator):
    """RosettaDataCollator variant that also batches the per-sample dataset label."""

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        # Strip the label from every feature dict (in place) so the parent
        # collator only receives the keys it was given before this subclass.
        collected_labels = []
        for feature in features:
            collected_labels.append(feature.pop('dataset_label'))

        collated = super().__call__(features)

        # Re-attach the labels as one long tensor with one entry per sample.
        collated['dataset_label'] = torch.tensor(collected_labels, dtype=torch.long)
        return collated

def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch (CPU and CUDA) RNGs and pin cuDNN settings."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # When a process group is already running, wait for every rank here.
    process_group_ready = torch.distributed.is_available() and torch.distributed.is_initialized()
    if process_group_ready:
        torch.distributed.barrier()

def enable_full_determinism():
    """Enable stricter determinism settings for reproducibility.

    Every setting is applied on a best-effort basis: a failure is reported
    through ``warnings.warn`` instead of being silently discarded, and does not
    prevent the remaining settings from being applied.
    """
    import warnings

    # Must be set before CUDA context creation for cuBLAS determinism.
    # setdefault keeps any value the user already exported.
    os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":4096:8")

    # PyTorch deterministic algorithms (ops without a deterministic
    # implementation may raise later, when they are executed).
    try:
        torch.use_deterministic_algorithms(True)
    except Exception as exc:
        warnings.warn(
            f"Could not enable deterministic algorithms: {exc}",
            RuntimeWarning,
            stacklevel=2,
        )

    # Disable TF32 to reduce numeric variability. Each flag is handled on its
    # own so a failure on one does not skip the other.
    tf32_targets = (
        ("torch.backends.cuda.matmul", lambda: torch.backends.cuda.matmul),
        ("torch.backends.cudnn", lambda: torch.backends.cudnn),
    )
    for label, resolve_backend in tf32_targets:
        try:
            resolve_backend().allow_tf32 = False
        except Exception as exc:
            warnings.warn(
                f"Could not disable TF32 on {label}: {exc}",
                RuntimeWarning,
                stacklevel=2,
            )

def broadcast_decision_from_rank0(decision: bool, distributed: bool, device: str, rank: int) -> bool:
    """Return rank 0's decision on every rank; a no-op outside distributed runs."""
    if distributed:
        # Rank 0 fills the flag; the other ranks allocate a buffer to receive it.
        flag = (
            torch.tensor([1 if decision else 0], device=device, dtype=torch.int)
            if rank == 0
            else torch.empty(1, device=device, dtype=torch.int)
        )
        dist.broadcast(flag, src=0)
        return bool(flag.item())

    return decision

def freeze_model(model: nn.Module):
    """Turn off gradient tracking for every parameter of ``model``."""
    model.requires_grad_(False)

def unfreeze_model(model: nn.Module):
    """Turn gradient tracking back on for every parameter of ``model``."""
    model.requires_grad_(True)


def unfreeze_projectors(rosetta_model: RosettaModel):
    """Mark every parameter of every projector in ``projector_list`` as trainable."""
    projector_params = (
        param
        for projector in rosetta_model.projector_list
        for param in projector.parameters()
    )
    for param in projector_params:
        param.requires_grad = True

def freeze_all_except_selectors(rosetta_model: RosettaModel):
    """Leave only the projectors' selector parameters trainable.

    Every parameter of the model is frozen first. Then, per projector, either
    the ``selector_generator`` sub-module (input-dependent selectors) or the
    ``selector_logit`` parameter is re-enabled; a warning is printed when the
    expected attribute is missing.
    """
    for frozen_param in rosetta_model.parameters():
        frozen_param.requires_grad = False

    for idx, proj in enumerate(rosetta_model.projector_list):
        input_dependent = getattr(proj, 'selector_depends_on_input', False)

        if input_dependent:
            # Input-dependent selectors are produced by selector_generator.
            if not hasattr(proj, 'selector_generator'):
                print(f"Warning: Projector {idx} has selector_depends_on_input=True but no selector_generator")
                continue
            for generator_param in proj.selector_generator.parameters():
                generator_param.requires_grad = True
            print(f"Unfroze selector_generator parameters in projector {idx}")
            continue

        # Parameter-based selectors are stored in selector_logit.
        if not hasattr(proj, 'selector_logit'):
            print(f"Warning: Projector {idx} does not have selector_logit attribute")
            continue
        proj.selector_logit.requires_grad = True
        print(f"Unfroze selector_logit with shape: {proj.selector_logit.shape} in projector {idx}")

def build_layer_mapping(n_target=28, n_source=36):
    """Pair every target layer with the source layer at the closest relative depth.

    Layers of both stacks are placed on the interval [0, 1] by relative
    position; each target layer is matched with the source layer whose position
    is nearest (the lowest source index wins ties).

    Args:
        n_target: Number of target layers.
        n_source: Number of source layers.

    Returns:
        A list of ``(target_layer_idx, source_layer_idx)`` tuples, one per
        target layer, in target order. Empty when ``n_target`` is not positive.

    Raises:
        ValueError: If target layers exist but ``n_source`` is not positive.
    """
    if n_target <= 0:
        return []
    if n_source <= 0:
        raise ValueError(
            f"n_source must be positive to map {n_target} target layer(s), got {n_source}"
        )

    # A single-layer stack sits at position 0.0; clamping the span to 1 avoids
    # the division by zero that (n - 1) would otherwise cause.
    source_span = max(n_source - 1, 1)
    target_span = max(n_target - 1, 1)
    source_positions = [src_idx / source_span for src_idx in range(n_source)]

    mapping = []
    for target_idx in range(n_target):
        target_pos = target_idx / target_span
        nearest_source = min(
            range(n_source),
            key=lambda src_idx: abs(source_positions[src_idx] - target_pos),
        )
        mapping.append((target_idx, nearest_source))

    return mapping

def build_shared_mlp(source_dim: int, hidden_dim: int, target_dim: int, num_layers: int, 
                use_layer_norm: bool, dropout: float, dtype: torch.dtype) -> nn.Sequential:
    """Assemble an MLP projecting ``source_dim`` features to ``target_dim``.

    With ``num_layers > 1`` the network is an input stage, ``num_layers - 2``
    hidden stages (each Linear -> optional LayerNorm -> GELU -> Dropout) and a
    final Linear. Otherwise it is a single Linear from source to target.
    """

    def make_stage(in_features: int) -> list:
        # One Linear -> [LayerNorm] -> GELU -> Dropout group ending at hidden_dim.
        stage = [nn.Linear(in_features, hidden_dim, dtype=dtype)]
        if use_layer_norm:
            stage.append(nn.LayerNorm(hidden_dim, dtype=dtype))
        stage.append(nn.GELU())
        stage.append(nn.Dropout(dropout))
        return stage

    # The input stage is always constructed first (even if it is discarded
    # below), which keeps the order of random weight initialisation unchanged.
    modules = make_stage(source_dim)
    for _ in range(num_layers - 2):
        modules.extend(make_stage(hidden_dim))

    if num_layers > 1:
        modules.append(nn.Linear(hidden_dim, target_dim, dtype=dtype))
    else:
        # Single layer case: a direct source -> target projection.
        modules = [nn.Linear(source_dim, target_dim, dtype=dtype)]

    return nn.Sequential(*modules)
    
def setup_models(model_config: Dict[str, Any], device: str = "cuda", dtype: torch.dtype = torch.bfloat16, 
                pretrained_path: Optional[str] = None):
    """Build the RosettaModel (base + teacher + projectors) and its tokenizer.

    Steps performed, in order:
      1. Load the tokenizer of the base model (pad token falls back to EOS).
      2. Load the base and teacher causal LMs in ``dtype``.
      3. Create one projector per base-model layer from ``model_config["projector"]``.
      4. Optionally create one aggregator per base-model layer from
         ``model_config["aggregator"]``.
      5. Wrap everything in a RosettaModel, move it to ``device`` and put it in eval mode.
      6. Register the teacher-layer -> base-layer projector (and aggregator)
         mapping chosen by ``model_config["mapping"]``.
      7. If ``pretrained_path`` is given, load ``projector_{i}.pt`` for every
         projector found there (non-strict).

    Args:
        model_config: Model section of the config. Keys read here: "base_model",
            "teacher_model", "projector" ("type", "params"), "mapping", and the
            optional "aggregator" and "include_response".
        device: Device the projectors and the RosettaModel are moved to.
        dtype: Dtype used to load both models and passed to the projectors.
        pretrained_path: Optional directory containing ``projector_{i}.pt`` files.

    Returns:
        Tuple ``(rosetta_model, tokenizer)``.

    Raises:
        ValueError: If ``model_config["mapping"]`` is neither "last_aligned"
            nor "k_nearest".
    """
    
    # Load tokenizer (use base model tokenizer)
    tokenizer = AutoTokenizer.from_pretrained(model_config["base_model"])

    # Reuse EOS as the padding token when the tokenizer defines none
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    # Load base model
    base_model = AutoModelForCausalLM.from_pretrained(
        model_config["base_model"],
        torch_dtype=dtype
    )
    
    # Load teacher model  
    teacher_model = AutoModelForCausalLM.from_pretrained(
        model_config["teacher_model"],
        torch_dtype=dtype
    )
    
    # Get model dimensions and layer counts
    # Per-head dimension is derived from the first layer's k_proj width divided
    # by the number of key/value heads (config.head_dim is not used here).
    # base_dim = base_model.config.head_dim
    # teacher_dim = teacher_model.config.head_dim
    base_dim = int(base_model.model.layers[0].self_attn.k_proj.out_features / base_model.config.num_key_value_heads)
    teacher_dim = int(teacher_model.model.layers[0].self_attn.k_proj.out_features / teacher_model.config.num_key_value_heads)
    base_num_heads = base_model.config.num_key_value_heads
    teacher_num_heads = teacher_model.config.num_key_value_heads
    slm_num_layers = base_model.config.num_hidden_layers
    llm_num_layers = teacher_model.config.num_hidden_layers
    
    # Create projector from config
    projector_config = model_config["projector"]
    # Copy so that adding "dtype" does not mutate the caller's config dict
    projector_params = projector_config["params"].copy()
    projector_params["dtype"] = dtype
    projector_list = []
    # Only M projectors (share projector across sources): one per target layer
    num_projectors = slm_num_layers

    # Disabled experiment: key/value projection MLPs shared by all projectors.
    # shared_key_projection=build_shared_mlp(
    #     source_dim=teacher_dim,
    #     hidden_dim=projector_params["hidden_dim"],
    #     target_dim=base_dim,
    #     num_layers=projector_params["num_layers"],
    #     use_layer_norm=projector_params["use_layer_norm"],
    #     dropout=projector_params["dropout"],
    #     dtype=dtype
    # )
    # shared_value_projection=build_shared_mlp(
    #     source_dim=teacher_dim,
    #     hidden_dim=projector_params["hidden_dim"],
    #     target_dim=base_dim,
    #     num_layers=projector_params["num_layers"],
    #     use_layer_norm=projector_params["use_layer_norm"],
    #     dropout=projector_params["dropout"],
    #     dtype=dtype
    # )
    for _ in range(num_projectors):
        projector = create_projector(
            projector_config["type"],
            source_dim=teacher_dim,
            target_dim=base_dim,
            source_num_heads=teacher_num_heads,
            target_num_heads=base_num_heads,
            # shared_key_projection=shared_key_projection,
            # shared_value_projection=shared_value_projection,
            **projector_params
        )
        projector_list.append(projector.to(device))
    
    # Init RosettaModel
    # Build aggregators from config (optional)
    aggregator_config = model_config.get("aggregator")
    if aggregator_config:
        aggregator_type = aggregator_config["type"]
        aggregator_params = aggregator_config.get("params", {}).copy()
        aggregator_cls = get_aggregator_class(aggregator_type)
        # K is later handed to the mapping helper; defaults to 1 when
        # "num_options" is not configured
        K = int(aggregator_params.get("num_options", 1))
        aggregator_list = [aggregator_cls(**aggregator_params) for _ in range(slm_num_layers)]
    else:
        # No aggregator configured
        K = 1
        aggregator_list = []

    # Index 0 is the base model and index 1 the teacher; the indices used in
    # the mapping calls below rely on this order.
    rosetta_model = RosettaModel(
        model_list=[base_model, teacher_model],
        base_model_idx=0,
        projector_list=projector_list,
        aggregator_list=aggregator_list,
        include_response=model_config.get("include_response", False)
    ).to(device).eval()
    
    
    # Mapping strategy: which teacher layers feed each base-model layer
    if model_config["mapping"] == "last_aligned":
        source_target_mapping = last_aligned_sources(slm_num_layers, llm_num_layers, K)
    elif model_config["mapping"] == "k_nearest":
        source_target_mapping = k_nearest_sources(slm_num_layers, llm_num_layers, K)
    else:
        raise ValueError(f"Invalid mapping strategy: {model_config['mapping']}")
    print(f"Using {model_config['mapping']} mapping strategy (target: [sources])")

    # set projector and aggregator
    for target_layer_idx, src_list in source_target_mapping.items():
        # Only set aggregator index when aggregators exist
        if len(aggregator_list) > 0:
            rosetta_model.set_aggregator_idx(
                source_model_idx=1,
                target_model_idx=0,
                target_model_layer_idx=target_layer_idx,
                aggregator_idx=target_layer_idx,
            )
        for source_layer_idx in src_list:
            rosetta_model.set_projector_config(
                source_model_idx=1,  # Teacher model
                source_model_layer_idx=source_layer_idx,
                target_model_idx=0,  # Base model
                target_model_layer_idx=target_layer_idx,
                projector_idx=target_layer_idx,  # share projector per target layer
            )

    # Load pretrained weights if provided
    if pretrained_path:
        print(f"Loading pretrained weights from {pretrained_path}")
        for i, projector in enumerate(rosetta_model.projector_list):
            projector_weight_path = os.path.join(pretrained_path, f"projector_{i}.pt")
            if os.path.exists(projector_weight_path):
                # Load state dict with strict=False to allow missing keys (e.g., selector_generator)
                pretrained_state_dict = torch.load(projector_weight_path, map_location=device)
                missing_keys, unexpected_keys = projector.load_state_dict(pretrained_state_dict, strict=False)
                
                print(f"Loaded projector {i} weights")
                if missing_keys:
                    print(f"  Missing keys (will be randomly initialized): {missing_keys}")
                if unexpected_keys:
                    print(f"  Unexpected keys (ignored): {unexpected_keys}")
            else:
                # A missing checkpoint is not fatal: the projector keeps its
                # freshly initialised weights
                print(f"Warning: Projector {i} weights not found at {projector_weight_path}")

    return rosetta_model, tokenizer


def train_step(model: RosettaModel, batch: Dict[str, Any], tokenizer: AutoTokenizer, max_length: int, device: str):
    """Run one forward pass and return the selector supervision loss.

    The forward pass is executed so that projectors can record their selector
    activations; its own output is not part of the returned loss. The loss is
    the binary cross entropy between every available selector value and the
    per-sample dataset label (1 = rosetta_only, 0 = slm_only), averaged over
    all selector terms.

    Args:
        model: RosettaModel, optionally wrapped (anything exposing ``.module``).
        batch: Collated batch with "input_ids", "attention_mask", "labels",
            "dataset_label" and "kv_cache_index".
        tokenizer: Unused; kept so existing callers do not need to change.
        max_length: Unused; kept so existing callers do not need to change.
        device: Device the batch tensors are moved to.

    Returns:
        A scalar tensor. When no projector exposes selector values the result
        is a zero tensor detached from the graph (previously the Python int
        ``0`` was returned, which broke callers using ``.detach()``/``.item()``).
    """
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    # Position ids count only attended tokens; padded positions are pinned to 0.
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids = position_ids.masked_fill(attention_mask == 0, 0)
    labels = batch["labels"].to(device)
    dataset_labels = batch["dataset_label"].to(device)  # Shape: (batch_size,)

    kv_cache_index = [x.to(device) for x in batch["kv_cache_index"]]

    # Clear accumulated selectors before the forward pass so that only values
    # from this step are used below.
    model_to_use = model.module if hasattr(model, "module") else model
    for proj in model_to_use.projector_list:
        if hasattr(proj, 'clear_accumulated_selectors'):
            proj.clear_accumulated_selectors()

    # Forward pass: run for its side effect on the projectors' accumulated
    # selectors; the returned outputs are intentionally not used.
    model.forward(
        kv_cache_index=kv_cache_index,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        labels=labels,
        use_cache=True
    )

    bce = torch.nn.functional.binary_cross_entropy
    selector_loss = 0.0
    selector_count = 0

    for i, proj in enumerate(model_to_use.projector_list):
        if hasattr(proj, 'accumulated_selectors') and len(proj.accumulated_selectors) > 0:
            # Selectors recorded during the forward pass (already passed
            # through sigmoid). The target is the dataset label broadcast to
            # the selector's shape, identically for every granularity.
            for selector_sigmoid in proj.accumulated_selectors:
                batch_size = selector_sigmoid.shape[0]
                # Slicing past the end is a no-op, so this keeps all labels
                # unless the selector holds fewer samples than the batch.
                target_labels = dataset_labels[:batch_size]
                target_selectors = target_labels.float().view(-1, 1, 1, 1).expand_as(selector_sigmoid)

                selector_loss += bce(
                    selector_sigmoid.to(torch.bfloat16),
                    target_selectors.to(torch.bfloat16),
                )
                selector_count += 1

        elif hasattr(proj, 'selector_logit') and not proj.selector_depends_on_input:
            # Fallback to the parameter-based selector (not input dependent).
            selector_sigmoid = torch.sigmoid(proj.selector_logit / proj.selector_temperature)

            batch_size = dataset_labels.shape[0]
            # Labels are 0/1, so casting them to the selector's dtype is exact
            # and keeps BCE input and target dtypes aligned.
            label_column = dataset_labels.to(selector_sigmoid.dtype).view(batch_size, 1, 1, 1)

            if proj.selector_granularity == "scalar":
                target_selectors = label_column
                expanded_selector = selector_sigmoid.expand(batch_size, -1, -1, -1)
            elif proj.selector_granularity == "token":
                seq_len = input_ids.shape[1]
                target_selectors = label_column.expand(batch_size, 1, seq_len, 1)
                expanded_selector = selector_sigmoid.expand(batch_size, 1, seq_len, -1)
            elif proj.selector_granularity == "head":
                num_heads = selector_sigmoid.shape[1] if len(selector_sigmoid.shape) > 1 else 1
                target_selectors = label_column.expand(batch_size, num_heads, 1, 1)
                expanded_selector = selector_sigmoid.expand(batch_size, num_heads, -1, -1)
            else:  # "value"
                target_selectors = label_column.expand_as(selector_sigmoid)
                expanded_selector = selector_sigmoid

            selector_loss += bce(expanded_selector, target_selectors)
            selector_count += 1
        else:
            print(f"Warning: Projector {i} has no accessible selector values for loss computation")

    # Gate regularisation (0.0025 * sigmoid(mean(gate_logit) / gate_temperature))
    # is disabled: the previous code computed the gate value on every step but
    # never added it to the loss, so that dead computation has been dropped.

    if selector_count == 0:
        # Always hand back a tensor so callers can use .detach()/.item().
        # It carries no graph, so calling backward() on it will raise.
        return torch.zeros((), device=device)

    # Average across all selector terms.
    return selector_loss / selector_count


def evaluate_model(model: RosettaModel, eval_loader: DataLoader, tokenizer: AutoTokenizer, max_length: int, device: str) -> float:
    """Return the mean ``train_step`` loss over ``eval_loader`` (0.0 if it is empty).

    The model is switched to eval mode for the pass and put back into train
    mode before returning.
    """
    model.eval()

    batch_losses = []
    with torch.no_grad():
        for eval_batch in eval_loader:
            step_loss = train_step(model, eval_batch, tokenizer, max_length, device)
            batch_losses.append(step_loss.item())

    model.train()  # Set back to train mode

    if not batch_losses:
        return 0.0
    return sum(batch_losses, 0.0) / len(batch_losses)


def main():
    """
    Train a RosettaModel using hyper-parameters defined in a JSON configuration
    file. The CLI is only used to specify the path to the config; all other
    settings live in the JSON. Training progress is tracked with Weights &
    Biases and the original config is copied alongside checkpoints for full
    reproducibility.
    """

    # ------------------------------------------------------------------
    # Configuration loading
    # ------------------------------------------------------------------
    parser = argparse.ArgumentParser(description="Train RosettaModel from a JSON config")
    parser.add_argument("--config", type=str, default="recipe/default_config.json", help="Path to JSON config file")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training")
    parser.add_argument("--output_dir", type=str, default="outputs", help="Directory to save outputs and checkpoints")
    args = parser.parse_args()

    with open(args.config, "r") as f:
        cfg: Dict[str, Any] = json.load(f)

    # Extract configuration sections
    model_config = cfg["model"]
    training_config = cfg["training"]
    output_config = cfg["output"]
    data_config = cfg["data"]

    # Set seed for reproducibility and enable stricter determinism
    set_seed(seed = training_config["seed"])
    enable_full_determinism()

    # Create datetime subfolder under output directory
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # timestamped_output_dir = os.path.join(output_config["output_dir"], "3")
    timestamped_output_dir = output_config["output_dir"]
    # timestamped_output_dir = args.output_dir
    
    # Ensure output directory exists and copy config for reproducibility
    os.makedirs(timestamped_output_dir, exist_ok=True)
    shutil.copy(args.config, os.path.join(timestamped_output_dir, "config.json"))

    # ------------------------------------------------------------------
    # Distributed training setup
    # ------------------------------------------------------------------
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    distributed = world_size > 1
    local_rank = args.local_rank

    if distributed:
        dist.init_process_group(backend="nccl")
        rank = dist.get_rank()
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        torch.cuda.set_device(local_rank)
        device = f"cuda:{local_rank}"
    else:
        rank = 0
        local_rank = 0
        device = training_config.get("device", "cuda")

    is_main_process = rank == 0

    # ------------------------------------------------------------------
    # Weights & Biases initialisation
    # ------------------------------------------------------------------
    run_name = f"{output_config['wandb_config']['run_name']}_{timestamp}"
    if is_main_process:
        wandb.init(
            project=output_config["wandb_config"]["project"],
            name=run_name,
            config=cfg,
            mode=output_config["wandb_config"]["mode"]
        )
    
    print(f"Outputs will be saved to: {timestamped_output_dir}")

    # ------------------------------------------------------------------
    # Model setup
    # ------------------------------------------------------------------
    if is_main_process:
        print("Setting up models…")
    
    # Add pretrained path
    pretrained_path = os.path.join("proj_ablation", "llm_only", "final")
    rosetta_model, tokenizer = setup_models(model_config, device, torch.bfloat16, pretrained_path)

    # Apply special freezing for selector training - freeze all except selectors
    if is_main_process:
        print("Freezing all weights except selector_logit parameters")
    freeze_all_except_selectors(rosetta_model) 

    # Wrap with DDP if needed
    if distributed:
        rosetta_model = torch.nn.parallel.DistributedDataParallel(
            rosetta_model,
            device_ids=[local_rank],
            output_device=local_rank,
            find_unused_parameters=True
        )

    total_params = sum(p.numel() for p in rosetta_model.parameters())
    trainable_params = sum(p.numel() for p in rosetta_model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Percentage of trainable parameters: {trainable_params / total_params * 100:.4f}%")

    # ------------------------------------------------------------------
    # Dataset & dataloaders
    # ------------------------------------------------------------------
    print("Loading mixed dataset…")
    
    # Create mixed dataset with both rosetta_only and slm_only
    mixed_dataset = MixedDataset(
        rosetta_only_path="./teacher_datasets/Rosetta_llm_only",
        slm_only_path="./teacher_datasets/slm_only_new", 
        split="test",
        num_samples=data_config.get("num_samples", None)
    )
    
    # Wrap with MixedChatDataset to handle tokenization
    full_dataset = MixedChatDataset(mixed_dataset, tokenizer, max_length=2048)

    train_size = int(data_config["train_ratio"] * len(full_dataset))
    eval_size = len(full_dataset) - train_size
    train_dataset, eval_dataset = torch.utils.data.random_split(full_dataset, [train_size, eval_size])

    per_device_batch_size = training_config["per_device_train_batch_size"]
    grad_accum_steps = training_config.get("gradient_accumulation_steps", 1)

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset, shuffle=True, seed=training_config["seed"]
        )
        eval_sampler = torch.utils.data.distributed.DistributedSampler(
            eval_dataset, shuffle=False, seed=training_config["seed"]
        )
    else:
        train_sampler = None
        eval_sampler = None

    collator = MixedRosettaDataCollator(
        tokenizer,
        pad_to_multiple_of=training_config.get("pad_to_multiple_of", None),
        max_length=2048
    )

    # Ensure per-worker seeding if num_workers > 0
    def _worker_init_fn(worker_id):
        worker_seed = training_config["seed"] + worker_id
        np.random.seed(worker_seed)
        random.seed(worker_seed)

    train_loader = DataLoader(
        train_dataset,
        batch_size=per_device_batch_size,
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        collate_fn=collator,
        worker_init_fn=_worker_init_fn,
    )
    eval_loader = DataLoader(
        eval_dataset,
        batch_size=per_device_batch_size,
        shuffle=False,
        sampler=eval_sampler,
        collate_fn=collator,
        worker_init_fn=_worker_init_fn,
    )

    updates_per_epoch = math.ceil(len(train_loader) / grad_accum_steps)
    total_steps = updates_per_epoch * training_config["num_epochs"]

    # ------------------------------------------------------------------
    # Optimiser & scheduler
    # ------------------------------------------------------------------
    gate_params = []
    weight_params = []
    other_params = []

    for name, param in rosetta_model.named_parameters():
        if param.requires_grad:
            if "gate" in name:
                gate_params.append(param)
            elif "key_weight" in name or "value_weight" in name:
                weight_params.append(param)
            else:
                other_params.append(param)

    optimizer = AdamW([
        {"params": gate_params, "lr": 3e-4},
        {"params": weight_params, "lr": 3e-4},
        {"params": other_params, "lr": 3e-4}
        ], weight_decay=training_config["weight_decay"])

    scheduler = get_scheduler(
        training_config["scheduler_type"],
        optimizer=optimizer,
        num_warmup_steps=int(training_config["warmup_ratio"] * total_steps),
        num_training_steps=total_steps,
    )

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    print("Starting training…")
    global_step = 0
    optimizer.zero_grad()
    for epoch in range(training_config["num_epochs"]):
        if distributed and train_sampler is not None:
            # Ensure different shuffles across epochs in distributed setup
            train_sampler.set_epoch(epoch)
        rosetta_model.train()
        epoch_loss = 0.0
        progress_bar = tqdm(total=updates_per_epoch, desc=f"Epoch {epoch + 1}/{training_config['num_epochs']}", disable=not is_main_process)

        macro_step_in_epoch = 0
        accum_true_loss = 0.0
        micro_in_window = 0

        for batch_idx, batch in enumerate(train_loader):
            # Forward/backward with gradient accumulation and DDP no_sync for micro-steps
            is_accum_step = ((batch_idx + 1) % grad_accum_steps) != 0
            sync_ctx = rosetta_model.no_sync() if distributed and hasattr(rosetta_model, "no_sync") and is_accum_step else contextlib.nullcontext()

            with sync_ctx:
                loss = train_step(rosetta_model, batch, tokenizer, training_config["max_length"], device)
                true_loss_value = loss.detach().item()
                scaled_loss = loss / grad_accum_steps  # Gradient accumulation
                scaled_loss.backward()

            # accumulate true (unscaled) loss for averaging/printing
            epoch_loss += true_loss_value
            accum_true_loss += true_loss_value
            micro_in_window += 1

            # Optimizer step on boundaries or at last batch of the epoch
            did_step = (not is_accum_step) or (batch_idx + 1 == len(train_loader))
            grad_norm_value = None
            if did_step:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    [p for p in rosetta_model.parameters() if p.requires_grad],
                    max_norm=training_config["max_grad_norm"]
                )
                grad_norm_value = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else float(grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                macro_step_in_epoch += 1

                # Update temperatures AFTER optimizer step, using actual training step count
                model_to_use = rosetta_model.module if hasattr(rosetta_model, "module") else rosetta_model
                for proj in model_to_use.projector_list:
                    if hasattr(proj, 'update_temperature') and callable(proj.update_temperature):
                        proj.update_temperature(global_step)

                # Anneal aggregator temperatures if supported
                for agg in model_to_use.aggregator_list:
                    if hasattr(agg, 'update_temperature') and callable(agg.update_temperature):
                        agg.update_temperature(global_step)

            # Progress bar and logging
            if is_main_process and did_step:
                # Calculate fractional epoch based on macro steps
                fractional_epoch = epoch + (macro_step_in_epoch / updates_per_epoch)

                avg_window_loss = accum_true_loss / max(1, micro_in_window)
                postfix = {
                    "loss": f"{avg_window_loss:.4f}",
                    "avg_loss": f"{epoch_loss / (batch_idx + 1):.4f}",
                    "lr": f"{scheduler.get_last_lr()[0]:.2e}",
                }
                progress_bar.set_postfix(postfix)
                progress_bar.update(1)

                wandb.log({
                    "train/loss": avg_window_loss,
                    "train/lr": scheduler.get_last_lr()[0],
                    "train/grad_norm": grad_norm_value,
                    "train/epoch": fractional_epoch,
                }, step=global_step)

                # reset window accumulators
                accum_true_loss = 0.0
                micro_in_window = 0

            # Evaluation and checkpointing only on real optimizer steps
            if did_step:
                # Calculate fractional epoch based on macro steps
                fractional_epoch = epoch + (macro_step_in_epoch / updates_per_epoch)
                # Evaluation at regular intervals under DDP using broadcasted decision
                want_eval = (global_step % output_config["eval_steps"] == 0)
                want_eval = broadcast_decision_from_rank0(want_eval, distributed, device, rank)
                if want_eval:
                    if distributed:
                        # All ranks evaluate their shard and average
                        local_eval_loss = evaluate_model(rosetta_model, eval_loader, tokenizer, training_config["max_length"], device)
                        loss_tensor = torch.tensor([local_eval_loss], device=device, dtype=torch.float32)
                        dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
                        avg_eval_loss = loss_tensor.item()
                        if is_main_process:
                            print(f"\nEvaluation (mid-epoch) at step {global_step}: {avg_eval_loss:.4f}")
                            wandb.log({
                                "eval/loss": avg_eval_loss,
                                "eval/step": global_step,
                                "eval/epoch": fractional_epoch
                            }, step=global_step)
                    else:
                        eval_loss = evaluate_model(rosetta_model, eval_loader, tokenizer, training_config["max_length"], device)
                        print(f"\nEvaluation loss at step {global_step}: {eval_loss:.4f}")
                        wandb.log({
                            "eval/loss": eval_loss,
                            "eval/step": global_step,
                            "eval/epoch": fractional_epoch
                        }, step=global_step)

                # Checkpointing under DDP using broadcasted decision
                want_save = (global_step % output_config["save_steps"] == 0)
                want_save = broadcast_decision_from_rank0(want_save, distributed, device, rank)
                if want_save:
                    if is_main_process:
                        checkpoint_dir = os.path.join(timestamped_output_dir, f"checkpoint-{global_step}")
                        os.makedirs(checkpoint_dir, exist_ok=True)

                        # Unwrap DDP to access underlying RosettaModel
                        base_model_ref = rosetta_model.module if isinstance(rosetta_model, DistributedDataParallel) else rosetta_model

                        for i, proj in enumerate(base_model_ref.projector_list):
                            # We save both the trainable weights and the constructor config
                            torch.save(proj.state_dict(), os.path.join(checkpoint_dir, f"projector_{i}.pt"))
                            save_projector(proj, os.path.join(checkpoint_dir, f"projector_{i}.json"))
                        for i, agg in enumerate(base_model_ref.aggregator_list):
                            torch.save(agg.state_dict(), os.path.join(checkpoint_dir, f"aggregator_{i}.pt"))
                            save_aggregator(agg, os.path.join(checkpoint_dir, f"aggregator_{i}.json"))
                        base_model_ref.save_projector_config(os.path.join(checkpoint_dir, "projector_config.json"))
                        base_model_ref.save_aggregator_config(os.path.join(checkpoint_dir, "aggregator_config.json"))

                        torch.save({
                            "step": global_step,
                            "epoch": epoch,
                            "optimizer_state_dict": optimizer.state_dict(),
                            "scheduler_state_dict": scheduler.state_dict(),
                            "loss": true_loss_value,  # true loss for this batch window
                        }, os.path.join(checkpoint_dir, "training_state.pt"))
                        print(f"\nCheckpoint saved at step {global_step}")

        avg_epoch_loss = epoch_loss / len(train_loader)

        # ------------------------------------------------------------------
        # Evaluation phase
        # ------------------------------------------------------------------
        if distributed:
            # Run eval on all ranks and average for deterministic sync
            local_eval_loss = evaluate_model(rosetta_model, eval_loader, tokenizer, training_config["max_length"], device)
            loss_tensor = torch.tensor([local_eval_loss], device=device, dtype=torch.float32)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
            avg_eval_loss = loss_tensor.item()
            if is_main_process:
                print(f"Epoch {epoch + 1} completed. Train loss: {avg_epoch_loss:.4f} | Eval loss: {avg_eval_loss:.4f}")
                wandb.log({
                    "eval/epoch_loss": avg_eval_loss,
                    "epoch": epoch + 1,
                    "train/epoch_avg_loss": avg_epoch_loss
                }, step=global_step)
        else:
            print(f"Running end-of-epoch evaluation for epoch {epoch + 1}...")
            avg_eval_loss = evaluate_model(rosetta_model, eval_loader, tokenizer, training_config["max_length"], device)
            print(f"Epoch {epoch + 1} completed. Train loss: {avg_epoch_loss:.4f} | Eval loss: {avg_eval_loss:.4f}")
            wandb.log({
                "eval/epoch_loss": avg_eval_loss,
                "epoch": epoch + 1,
                "train/epoch_avg_loss": avg_epoch_loss
            }, step=global_step)

    # ------------------------------------------------------------------
    # Save final artefacts
    # ------------------------------------------------------------------
    if is_main_process:
        final_dir = os.path.join(timestamped_output_dir, "final")
        os.makedirs(final_dir, exist_ok=True)

        base_model_ref = rosetta_model.module if isinstance(rosetta_model, DistributedDataParallel) else rosetta_model

        for i, proj in enumerate(base_model_ref.projector_list):
            torch.save(proj.state_dict(), os.path.join(final_dir, f"projector_{i}.pt"))
            save_projector(proj, os.path.join(final_dir, f"projector_{i}.json"))
        for i, agg in enumerate(base_model_ref.aggregator_list):
            torch.save(agg.state_dict(), os.path.join(final_dir, f"aggregator_{i}.pt"))
            save_aggregator(agg, os.path.join(final_dir, f"aggregator_{i}.json"))
        base_model_ref.save_projector_config(os.path.join(final_dir, "projector_config.json"))
        base_model_ref.save_aggregator_config(os.path.join(final_dir, "aggregator_config.json"))

    if is_main_process:
        print("Training completed!")
        wandb.finish()

    # Clean up distributed training
    if distributed:
        dist.destroy_process_group()


if __name__ == "__main__":
    # Script entry point: main() runs only when this file is executed
    # directly, not when it is imported as a module.
    #
    # Optional remote-debugging hook (disabled by default). Uncomment the
    # lines below to have the process listen for a debugpy client on port
    # 5678 and wait for it to attach before main() starts. The last line
    # turns on autograd anomaly detection for the debug session.
    # NOTE(review): under a multi-process (distributed) launch every rank
    # would presumably try to bind the same port — confirm before enabling.
    # import debugpy
    # debugpy.listen(("0.0.0.0", 5678))
    # print("Waiting for debugger attach...")
    # debugpy.wait_for_client()
    # print("Debugger attached, running...")
    # torch.autograd.set_detect_anomaly(True)
    main()
