"""Training orchestration runner for token classification training."""

import os
import torch
import wandb
from datasets import concatenate_datasets
from transformers import TrainingArguments

from calib.config.args import TrainArgs
from calib.utils.gpu import init_gpu
from calib.utils.io import load_completions_dict
from calib.utils.names import (
    build_run_name, get_label_mode, get_padding_info
)
from calib.data.dataset import build_dataset
from calib.data.collate import HiddenStateDataCollator
from calib.models.builder import build_models_and_tokenizer
from calib.training.trainer import BinaryTokenRewardTrainer
from calib.metrics.compute_metrics import create_compute_metrics
from calib.training.trainer import create_optimizer
from calib.utils.hdf5_loader import load_precomputed_data, get_embedding_count


class Runner:
    """Orchestrates the entire training workflow."""
    
    def __init__(self, args: TrainArgs):
        self.args = args
        self.device = None
        self.base_device = None
        self.base_model = None
        self.model = None
        self.tokenizer = None
        self.dataset = None
        self.paragraph_times = None
        self.save_dict = None
        self.eval_save_dict = None
        self.precomputed_train_data = None
        self.precomputed_eval_data = None
        
    def run(self) -> None:
        """Execute the full training pipeline."""
        print("Starting token classification training...")
        
        # Initialize GPU
        self._init_gpu()

        # Load precomputed hidden states if available
        self._load_precomputed_data()
        
        # Load data
        self._load_data()

        
        # Build models and tokenizer
        self._build_models_and_tokenizer()
        
        # Print training model architecture
        self._print_training_model()
        
        # Build dataset
        self._build_dataset()
        
        # Setup wandb and run name
        self._setup_wandb_and_naming()
        
        # Setup training
        self._setup_training()
        
        # Run training or evaluation
        self._execute_training()
        
        print("Training completed successfully!")
    
    def _init_gpu(self):
        """Select the training and base devices and print a device report.

        The training device is the first value returned by ``init_gpu`` (called
        with ``args.seed``). The base device is ``args.base_device`` when that
        is set to a truthy value, otherwise the training device.
        """
        print("Initializing GPU...")
        training_device, _ = init_gpu(seed=self.args.seed)
        self.device = training_device

        # A falsy base_device (e.g. None) means "share the training device".
        self.base_device = self.args.base_device or self.device

        print(f"Training device: {self.device}")
        print(f"Base device: {self.base_device}")

        cuda_ready = torch.cuda.is_available()
        print(f"CUDA available: {cuda_ready}")
        if not cuda_ready:
            return

        print(f"Training GPU: {torch.cuda.get_device_name(self.device)}")
        # Only report the base GPU separately when it differs from the training one.
        if self.base_device != self.device:
            print(f"Base GPU: {torch.cuda.get_device_name(self.base_device)}")
    
    def _load_data(self):
        """Load the completion dicts for training and evaluation.

        For each split, the completions are loaded from the path given in
        ``self.args`` unless a precomputed path was configured for that split.
        In that case the completion path recorded in the precomputed payload's
        ``args`` is used instead, and the payload's ``eval_dataset`` and
        ``inputs_embeds`` are truncated when a prompt limit
        (``args.num_prompts`` / ``args.eval_num_prompts``) is set.

        Must run after ``_load_precomputed_data``, which fills the payloads
        read here.
        """

        def restrict(precomputed, save_dict, num_prompts, split):
            # Truncate a precomputed payload in place to its first
            # num_prompts * num_completions_per_prompt rows.
            # Assumes rows are grouped by prompt — TODO confirm.
            if num_prompts is None:
                return
            per_prompt = save_dict["args"].num_completions_per_prompt
            row_limit = num_prompts * per_prompt
            dataset = precomputed["eval_dataset"]
            kept = min(row_limit, len(dataset))
            precomputed["eval_dataset"] = dataset.select(range(kept))
            precomputed["inputs_embeds"] = precomputed["inputs_embeds"][:row_limit]
            print(f"Restricted {split} dataset to {num_prompts} prompts ({kept} completions)")

        if self.args.precomputed_path is None:
            self.save_dict = load_completions_dict(self.args.completion_path)
        else:
            recorded_args = self.precomputed_train_data["args"]
            self.save_dict = load_completions_dict(recorded_args.completion_path)
            restrict(self.precomputed_train_data, self.save_dict, self.args.num_prompts, "training")

        if self.args.eval_precomputed_path is None:
            self.eval_save_dict = load_completions_dict(self.args.eval_completion_path)
        else:
            # NOTE(review): the evaluation payload is read through
            # ``eval_completion_path`` while the training payload uses
            # ``completion_path`` — confirm this matches how the payloads are written.
            recorded_eval_args = self.precomputed_eval_data["args"]
            self.eval_save_dict = load_completions_dict(recorded_eval_args.eval_completion_path)
            restrict(self.precomputed_eval_data, self.eval_save_dict, self.args.eval_num_prompts, "evaluation")

    def _load_precomputed_data(self):
        """Load precomputed hidden states if available."""
        # Load training precomputed data
        if self.args.precomputed_path is not None:
            print(f"Loading precomputed training hidden states from: {self.args.precomputed_path}")
            self.precomputed_train_data = load_precomputed_data(self.args.precomputed_path)
            print(f"Loaded precomputed training data with {get_embedding_count(self.precomputed_train_data)} examples")

            # Verify that the precomputed data contains required fields
            if 'eval_dataset' not in self.precomputed_train_data:
                raise ValueError("Precomputed training data must contain 'eval_dataset' field")
            if 'inputs_embeds' not in self.precomputed_train_data:
                raise ValueError("Precomputed training data must contain 'hidden_states' field")

        else:
            print("No precomputed training hidden states specified, will compute on-the-fly")

        # Load evaluation precomputed data
        if self.args.eval_precomputed_path is not None:
            print(f"Loading precomputed evaluation hidden states from: {self.args.eval_precomputed_path}")
            self.precomputed_eval_data = load_precomputed_data(self.args.eval_precomputed_path)
            print(f"Loaded precomputed evaluation data with {get_embedding_count(self.precomputed_eval_data)} examples")

            # Verify that the precomputed data contains required fields
            if 'eval_dataset' not in self.precomputed_eval_data:
                raise ValueError("Precomputed evaluation data must contain 'eval_dataset' field")
            if 'inputs_embeds' not in self.precomputed_eval_data:
                raise ValueError("Precomputed evaluation data must contain 'inputs_embeds' field")

        else:
            print("No precomputed evaluation hidden states specified, will compute on-the-fly")

    def _build_models_and_tokenizer(self):
        """Build the base model, LM head, training model, and tokenizer.

        Stores the four objects returned by ``build_models_and_tokenizer`` on
        ``self.base_model``, ``self.lm_head``, ``self.model`` and
        ``self.tokenizer``, then prints a short summary.
        """
        print(f"Building models and tokenizer: {self.args.model_name}")

        # Everything is built regardless of whether precomputed hidden states
        # were loaded. Per the original note here: base-model loading could be
        # skipped for evaluation with precomputed data, but it is still needed
        # for training, so nothing is skipped.
        built = build_models_and_tokenizer(
            args=self.args,
            base_device=self.base_device,
            training_device=self.device,
            save_dict=self.save_dict,
        )
        self.base_model, self.lm_head, self.model, self.tokenizer = built

        print("Models built successfully:")
        print(f"  Base: {self.args.model_name} on {self.base_device}")
        print(f"  Training: {self.device}")
        print(f"  Tokenizer max_length: {self.tokenizer.model_max_length}")
    
    def _print_training_model(self):
        """Print training model architecture and configuration."""
        print("\n" + "="*60)
        print("TRAINING MODEL ARCHITECTURE")
        print("="*60)
        print(self.model)
        print("="*60)
        
        # Print configuration details
        if hasattr(self.model, 'config'):
            config = self.model.config
            print("TRAINING MODEL CONFIGURATION:")
            print("-" * 40)
            print(f"Model type: {getattr(config, 'model_type', 'Unknown')}")
            print(f"Architecture: {getattr(config, 'architecture', 'Unknown')}")
            print(f"Hidden size: {getattr(config, 'hidden_size', 'Unknown')}")
            print(f"Number of layers: {getattr(config, 'num_hidden_layers', 'Unknown')}")
            print(f"Number of attention heads: {getattr(config, 'num_attention_heads', 'Unknown')}")
            print(f"Group size: {getattr(config, 'group_size', 'Unknown')}")
            print(f"MLP hidden size: {getattr(config, 'mlp_hidden_size', 'Unknown')}")
            print(f"Max context length: {getattr(config, 'max_context_len', 'Unknown')}")
            print(f"Num summarization vectors: {getattr(config, 'num_summarization_vectors', 'Unknown')}")
            print(f"Number of labels: {getattr(config, 'num_labels', 'Unknown')}")
            print(f"Classifier dropout: {getattr(config, 'classifier_dropout', 'Unknown')}")
            print(f"Torch dtype: {getattr(config, 'torch_dtype', 'Unknown')}")
            print("-" * 40)
        
        # Count parameters
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        print(f"Non-trainable parameters: {total_params - trainable_params:,}")
        print("="*60 + "\n")
    
    def _build_dataset(self):
        """Build dataset with proper mode configuration."""
        print("Building dataset...")

        # Build training dataset - use precomputed if available
        if self.precomputed_train_data is not None:
            print("Using precomputed training dataset")
            self.dataset = self.precomputed_train_data['eval_dataset']
            self.dataset = self.dataset.map(lambda x, idx: {'train': True, "precomputed_idx": idx}, with_indices=True)
        else:
            if self.save_dict is None:
                raise ValueError("No training data available: either provide --completion_path or --precomputed_path")
            print("Building training dataset from scratch")
            self.dataset = build_dataset(
                save_dict=self.save_dict,
                tokenizer=self.tokenizer,
                restrict_num_prompts=self.args.num_prompts,
                args=self.args,
                balance_difficulty=self.args.balance_difficulty,
                select_difficulty=self.args.select_difficulty,
            )

        # Build evaluation dataset - use precomputed if available
        if self.precomputed_eval_data is not None:
            print("Using precomputed evaluation dataset")
            self.eval_dataset = self.precomputed_eval_data['eval_dataset']
            self.eval_dataset = self.eval_dataset.map(lambda x, idx: {'train': False, "precomputed_idx": idx}, with_indices=True)
        else:
            if self.eval_save_dict is None:
                raise ValueError("No evaluation data available: either provide --eval_completion_path or --eval_precomputed_path")
            print("Building evaluation dataset from scratch")
            self.eval_dataset = build_dataset(
                save_dict=self.eval_save_dict,
                tokenizer=self.tokenizer,
                restrict_num_prompts=self.args.eval_num_prompts,
                args=self.args,
                balance_difficulty=self.args.eval_balance_difficulty,
                select_difficulty=self.args.eval_select_difficulty,
            )
        
        # Print dataset info
        num_train = len(self.dataset)
        num_valid = len(self.eval_dataset)
        
        print(f"Dataset built: {num_train} train, {num_valid} validation examples")
        
        # Log mode information
        input_description = "prefix with injected rollouts (truncated to first unbalanced })"
        print(f"Training & Evaluation: Using token classification with labels at segment boundaries (input = {input_description})")
    
    def _setup_wandb_and_naming(self):
        """Choose the save directory, persist the config, and start wandb logging.

        Sets ``self.save_dir`` to ``args.load_model_from`` when that is given,
        otherwise to ``args.output_dir`` joined with a freshly built run name.
        The directory is created if needed and the arguments are written to
        ``config.pt`` inside it. wandb is initialised unless ``args.eval_only``
        is set.
        """
        print("Setting up wandb and generating run name...")

        # Setup components for run name (inj_roll_trunc mode hardcoded)
        label_mode = get_label_mode()
        train_mode = "full"
        padding_info = get_padding_info()

        if self.args.load_model_from is not None:
            # When loading from an existing model, reuse its directory as-is.
            self.save_dir = self.args.load_model_from
        else:
            # New training run: derive the directory from the run name.
            run_name = build_run_name(
                args=self.args,
                label_mode=label_mode,
                train_mode=train_mode,
                padding_info=padding_info,
            )
            self.save_dir = os.path.join(self.args.output_dir, run_name)
        os.makedirs(self.save_dir, exist_ok=True)
        print(f"Will save to: {self.save_dir}")

        # NOTE(review): when load_model_from is set, this overwrites any
        # config.pt already present in that directory with the current args.
        torch.save(vars(self.args), os.path.join(self.save_dir, "config.pt"))

        # wandb config: the run arguments plus two extra descriptive fields.
        config_dict = vars(self.args).copy()
        config_dict.update({
            "tokenizer_model_max_length": self.tokenizer.model_max_length,
            "task": "calibrator",
        })

        if not self.args.eval_only:
            wandb.init(project=self.args.wandb_project, config=config_dict)
            # Report the project actually passed to wandb.init above.
            print(f"Wandb initialized with project: {self.args.wandb_project}")

        print(
            f"Effective batch size: {self.args.effective_batch_size}"
            f" (batch_size={self.args.batch_size} * gradient_accumulation_steps={self.args.gradient_accumulation_steps})"
        )
    
    def _setup_training(self):
        """Assemble the trainer and store it on ``self.trainer``.

        Builds, in order: the training arguments, the data collator, the
        ``compute_metrics`` callable, the optimizer, and finally the
        ``BinaryTokenRewardTrainer`` that ties them together.
        """
        print("Setting up training components...")

        hf_args = self._create_training_arguments()
        collator = self._create_data_collator()

        # Completions-per-prompt for each split; None when the stored args lack it.
        train_group = getattr(self.save_dict["args"], "num_completions_per_prompt", None)
        eval_group = getattr(self.eval_save_dict["args"], "num_completions_per_prompt", None)

        run_args = self.args
        model_cfg = self.model.config

        # The eval completion path is forwarded only when evaluation is not precomputed.
        if run_args.eval_precomputed_path is None:
            eval_completion_path = run_args.eval_completion_path
        else:
            eval_completion_path = None

        metrics_fn = create_compute_metrics(
            num_completions_per_prompt=eval_group,
            eval_temps=run_args.eval_temps,
            compute_tradeoff=run_args.eval_tradeoff,
            eval_ds=self.eval_dataset,
            save_dir=self.save_dir,
            group_softmax=model_cfg.group_softmax,
            sum_group_softmax=model_cfg.sum_group_softmax,
            group_size=model_cfg.group_size,
            wm=self.model.wm_group_size is not None,
            save_pred=run_args.save_pred,
            eval_completion_path=eval_completion_path,
            eval_balance_difficulty=run_args.eval_balance_difficulty,
            eval_num_prompts=run_args.eval_num_prompts,
            last_rollout_only=run_args.last_rollout_only,
            eval_precomputed_path=run_args.eval_precomputed_path,
            eval_select_difficulty=run_args.eval_select_difficulty,
            dataset_seed=run_args.dataset_seed,
            num_difficulty_bins=run_args.num_difficulty_bins,
            model_name=run_args.model_name,
            tokenizer_name=run_args.tokenizer_name,
        )

        # The sampler groups by the explicit size when given, else by completions per prompt.
        if run_args.sampler_group_size is not None:
            sampler_group = run_args.sampler_group_size
        else:
            sampler_group = train_group

        # Optimizer is supplied explicitly; the scheduler slot is left as None.
        optimizer = create_optimizer(self.model, run_args)

        self.trainer = BinaryTokenRewardTrainer(
            model=self.model,
            args=hf_args,
            train_dataset=self.dataset,
            eval_dataset=self.eval_dataset,
            processing_class=self.tokenizer,
            data_collator=collator,
            compute_metrics=metrics_fn,
            config_args=run_args,
            num_completions_per_prompt=train_group,
            sampler_group_size=sampler_group,
            optimizers=(optimizer, None),
            group_softmax=model_cfg.group_softmax,
            sum_group_softmax=model_cfg.sum_group_softmax,
            group_size=model_cfg.group_size,
        )

        print("Training setup completed")
    
    
    def _create_training_arguments(self):
        """Translate the run configuration into a ``transformers.TrainingArguments``.

        Returns:
            A ``TrainingArguments`` instance writing to ``self.save_dir``.
        """
        cfg = self.args
        eval_only = cfg.eval_only

        # Periodic evaluation / checkpointing are enabled only when a step interval is set.
        eval_strategy = "no" if cfg.eval_steps is None else "steps"
        save_strategy = "no" if cfg.save_steps is None else "steps"

        # Evaluate before the first step when asked to, or whenever a model is being loaded.
        eval_on_start = cfg.eval_on_start or (cfg.load_model_from is not None)

        # Eval-only runs: zero epochs, no checkpoint cap, no wandb reporting,
        # and no gradient checkpointing.
        epochs = 0 if eval_only else cfg.epochs
        checkpoint_cap = None if eval_only else 10
        reporter = "none" if eval_only else "wandb"

        return TrainingArguments(
            output_dir=self.save_dir,
            eval_strategy=eval_strategy,
            eval_steps=cfg.eval_steps,
            eval_on_start=eval_on_start,
            save_strategy=save_strategy,
            save_steps=cfg.save_steps,
            save_total_limit=checkpoint_cap,
            save_only_model=True,
            logging_steps=1,
            num_train_epochs=epochs,
            per_device_train_batch_size=cfg.batch_size,
            per_device_eval_batch_size=cfg.batch_size,
            gradient_accumulation_steps=cfg.gradient_accumulation_steps,
            gradient_checkpointing=not eval_only,
            # learning_rate is deliberately not set here; the optimizer is
            # built separately with create_optimizer in _setup_training.
            lr_scheduler_type=cfg.lr_scheduler_type,
            max_grad_norm=cfg.max_grad_norm,
            warmup_ratio=0.05,
            seed=cfg.seed,
            bf16=False,
            fp16=False,
            report_to=reporter,
            # NOTE(review): drop_last may also apply to the evaluation
            # dataloader — confirm dropping a trailing partial eval batch is intended.
            dataloader_drop_last=True,
            dataloader_pin_memory=False,  # Disable pin_memory since collator handles GPU placement
            remove_unused_columns=False,  # keep labels aligned with input_ids
        )
    
    def _create_data_collator(self):
        """Build the ``HiddenStateDataCollator`` used for both training and evaluation.

        Returns:
            The collator, constructed with the base model, LM head, tokenizer,
            run arguments, training device, and both precomputed payloads
            (either of which may be ``None``).
        """
        # Announce which splits will read precomputed hidden states.
        announcements = (
            (self.precomputed_train_data, "Will use precomputed hidden states for training"),
            (self.precomputed_eval_data, "Will use precomputed hidden states for evaluation"),
        )
        for payload, message in announcements:
            if payload is not None:
                print(message)

        print("Using hidden state data collator (handles both precomputed and base model processing)")

        collator_kwargs = dict(
            base_model=self.base_model,
            lm_head=self.lm_head,
            tokenizer=self.tokenizer,
            args=self.args,
            training_device=self.device,
            precomputed_train_data=self.precomputed_train_data,
            precomputed_eval_data=self.precomputed_eval_data,
        )
        return HiddenStateDataCollator(**collator_kwargs)
    
    def _execute_training(self):
        """Execute training or evaluation."""
        if self.args.eval_only:
            print("Running evaluation only (no training)...")
            if self.trainer.eval_dataset is not None:
                eval_results = self.trainer.evaluate()
                print(f"Evaluation results: {eval_results}")
            else:
                print("Warning: No validation dataset available for evaluation.")
        else:
            print("Starting training...")
            self.trainer.train(resume_from_checkpoint=self.args.resume_from_checkpoint)
            self.trainer.save_model(os.path.join(self.save_dir, "final"))
            print("Training completed and model saved.")
            
            # Run final evaluation after training
            if self.args.eval_steps is None:
                print("Running final evaluation...")
                final_eval_results = self.trainer.evaluate()
                print(f"Final evaluation results: {final_eval_results}")
