# train_unified.py
import torch
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor, RichProgressBar
import argparse
import yaml
import shutil # Import shutil for file copying
import os # Import os for makedirs
import wandb
from pytorch_lightning.callbacks import ModelCheckpoint
from pathlib import Path # For constructing paths

# --- Multiprocessing sharing strategy ---
# Ask torch.multiprocessing to use the 'file_system' strategy for sharing
# tensors between processes. This is best-effort: a RuntimeError (strategy
# already fixed, or not changeable on this platform) is reported and ignored.
try:
    torch.multiprocessing.set_sharing_strategy('file_system')
except RuntimeError:
    print("Could not set torch multiprocessing sharing strategy (possibly already set or not supported).")
else:
    print("Successfully set torch multiprocessing sharing strategy to 'file_system'.")
# --- End of multiprocessing sharing strategy ---

from model import NeuralToPhonemeTransformer # Your main model
from model_attention_vis_v1 import NeuralToPhonemeTransformerAttentionRollout # Your main model
from data_modules import ECoGDataModule, CurriculumUpdateCallback # Your ECoG data module

def _select_monitor_metric(training_stage, hparams_dict):
    """Pick the validation metric monitored by callbacks and the LR scheduler.

    Args:
        training_stage: Name of the current training stage.
        hparams_dict: Merged hyperparameters. Only the flags
            'train_bart_text_decoder' and 'train_whisper_text_decoder' are
            read; both default to False when absent.

    Returns:
        The name of the metric to monitor.

    Raises:
        ValueError: If the stage is 'secondary_only' and neither text decoder
            flag is set, or if no metric is defined for the stage.
    """
    is_bart_active = hparams_dict.get('train_bart_text_decoder', False)
    is_whisper_active = hparams_dict.get('train_whisper_text_decoder', False)

    monitor_metric = None
    if training_stage == 'secondary_only':
        if is_whisper_active:
            # Whisper takes priority when both text decoders are enabled.
            monitor_metric = 'val_wer_whisper_epoch'
        elif is_bart_active:
            monitor_metric = 'val_wer_bart_epoch'
        else:
            raise ValueError("For 'secondary_only' stage, at least one of 'train_bart_text_decoder' or 'train_whisper_text_decoder' must be true.")
    elif training_stage in ('joint_teacher_forcing', 'joint_sequential_generation'):
        # Joint stages default to phoneme PER; an enabled text head overrides it.
        if is_whisper_active:
            monitor_metric = 'val_wer_whisper_epoch'
        elif is_bart_active:
            monitor_metric = 'val_wer_bart_epoch'
        else:
            monitor_metric = 'val_per_sg_epoch'
    elif training_stage == 'encoder_only':
        monitor_metric = 'val_total_aux_loss_epoch'

    if monitor_metric is None:
        # Stage accepted upstream but with no metric mapped here.
        raise ValueError(f"Could not determine a monitor metric for training stage: {training_stage}")
    return monitor_metric


def _filter_state_dict_for_model(state_dict, bart_active, whisper_active):
    """Drop checkpoint entries that belong to text-decoder parts the model does not use.

    `load_state_dict(strict=False)` tolerates missing/extra keys but not shape
    mismatches, so weights of an inactive decoder (and of its ECoG projection
    layer) are removed before loading. Weights of an active decoder, including
    its projection layer, are kept.

    Args:
        state_dict: Mapping of parameter name to tensor from the checkpoint.
        bart_active: Whether the BART text decoder is active in the new model.
        whisper_active: Whether the Whisper text decoder is active in the new model.

    Returns:
        A new dict containing only the entries that should be loaded.
    """
    # Each prefix is tied explicitly to the decoder it belongs to. Deriving this
    # from the prefix text (e.g. startswith("bart")) misses the projection
    # layers, whose names begin with "ecog_to_".
    prefix_is_active = {
        "bart_model.": bart_active,
        "ecog_to_bart_hidden_projection.": bart_active,
        "whisper_model.": whisper_active,
        "ecog_to_whisper_hidden_projection.": whisper_active,
    }
    inactive_prefixes = tuple(
        prefix for prefix, active in prefix_is_active.items() if not active
    )
    # str.startswith with an empty tuple is False, so nothing is dropped when
    # both decoders are active.
    return {
        key: value
        for key, value in state_dict.items()
        if not key.startswith(inactive_prefixes)
    }


def main_train_unified(args):
    """Run one training stage end to end: config, data, model, logging, callbacks, fit.

    The YAML file at `args.config_file` supplies the hyperparameters
    ('model_params', 'data_params_ecog', 'train_params_common' plus the section
    for the selected stage). When the WANDB_SWEEP_ID environment variable is
    set, the W&B sweep config overrides them and also provides the training
    stage. Run artefacts (copied config, checkpoints) are placed under
    experiments/<project>[/<experiment_group_name>]/<run name>/.

    Args:
        args: Parsed CLI namespace. Attributes read here: config_file,
            training_stage, seed, initial_encoder_checkpoint_path,
            initial_decoder_checkpoint_path, train_data_fraction,
            wandb_project, run_name_suffix, experiment_group_name,
            resume_from_checkpoint, load_weights_from_checkpoint,
            save_top_k, save_dual_checkpoints, accelerator, devices.
            `training_stage` is overwritten in place during a sweep.

    Raises:
        ValueError: If no monitor metric can be determined for the stage
            (see `_select_monitor_metric`).
    """
    # --- Set seed for reproducibility ---
    if args.seed is not None:
        print(f"Setting global seed to {args.seed}")
        pl.seed_everything(args.seed, workers=True)

    # --- Check for W&B sweep run and initialize ---
    is_sweep = os.environ.get("WANDB_SWEEP_ID") is not None
    if is_sweep:
        # Initialising here exposes wandb.config; project/entity come from the sweep.
        wandb.init()
        args.training_stage = wandb.config.training_stage

    with open(args.config_file, 'r') as f:
        # An empty YAML document parses to None; treat it as an empty mapping.
        config = yaml.safe_load(f) or {}

    # --- Combine config sections into one flat hparams dict ---
    hparams_dict = {
        **config.get('model_params', {}),
        **config.get('data_params_ecog', {}),  # ECoG specific data params
        **config.get('train_params_common', {}),  # Common training params
    }
    # Stage-specific training params are layered on top of the common ones.
    stage_param_sections = {
        'encoder_only': 'train_params_encoder_only',
        'joint_teacher_forcing': 'train_params_joint_tf',
        'joint_sequential_generation': 'train_params_joint_sg',
        'secondary_only': 'train_params_secondary_only',
    }
    stage_section = stage_param_sections.get(args.training_stage)
    if stage_section is not None:
        hparams_dict.update(config.get(stage_section, {}))

    # --- If in a sweep, override hparams with sweep config ---
    if is_sweep:
        print("RUNNING AS A W&B SWEEP: Overriding config with sweep parameters.")
        sweep_hparams = dict(wandb.config.items())
        hparams_dict.update(sweep_hparams)
        print(f"Sweep Hparams: {sweep_hparams}")

    # CLI-provided values always win over config/sweep values for these keys.
    hparams_dict['training_stage'] = args.training_stage
    hparams_dict['initial_encoder_checkpoint_path'] = args.initial_encoder_checkpoint_path
    hparams_dict['initial_decoder_checkpoint_path'] = args.initial_decoder_checkpoint_path
    hparams_dict.setdefault('batch_size', 32)
    hparams_dict.setdefault('num_workers', 4)
    hparams_dict['train_data_fraction'] = args.train_data_fraction
    hparams_dict['seed'] = args.seed

    if 'PHONEME_MAP' in config:
        hparams_dict['PHONEME_MAP'] = config['PHONEME_MAP']
    else:
        print("Warning: PHONEME_MAP not found in config. Phoneme string conversion in validation will fail.")

    # --- Determine Monitor Metric (Single Source of Truth) ---
    # Resolved before the model is built so that callbacks and the scheduler
    # (via hparams.scheduler_monitor) agree on the same metric.
    monitor_metric = _select_monitor_metric(args.training_stage, hparams_dict)
    monitor_mode = 'min'
    print(f"--- Determined monitor metric for callbacks and scheduler: '{monitor_metric}' ---")
    hparams_dict['scheduler_monitor'] = monitor_metric

    hparams = argparse.Namespace(**hparams_dict)

    # --- Initialize DataModule ---
    ecog_datamodule = ECoGDataModule(hparams)

    # --- Initialize Model ---
    if "gigant_attention" in args.wandb_project:
        print("Using ATTENTION VISUALISATION MODEL NeuralToPhonemeTransformerAttentionRollout")
        model = NeuralToPhonemeTransformerAttentionRollout(hparams)
    else:
        model = NeuralToPhonemeTransformer(hparams)

    # The model may or may not expose these flags; treat a missing flag as inactive.
    bart_head_active = bool(getattr(model, 'bart_text_decoder_active', False))
    whisper_head_active = bool(getattr(model, 'whisper_text_decoder_active', False))

    # --- Initialize Logger ---
    run_name = None  # In a sweep, let W&B name the run.
    if not is_sweep:
        base_name = f"{args.training_stage}-{args.run_name_suffix}"
        if args.seed is not None:
            run_name = f"SEED_{args.seed}-{base_name}"
        else:
            run_name = base_name

        if args.train_data_fraction < 1.0:
            run_name += f"-{int(args.train_data_fraction*100)}pct"
        if args.initial_encoder_checkpoint_path:
            run_name += "-encFT"
        if args.initial_decoder_checkpoint_path:
            run_name += "-decFT"
        if args.resume_from_checkpoint:
            run_name += "-Resumed"

    # In a sweep the logger attaches to the run created by wandb.init() above.
    wandb_logger = WandbLogger(
        project=args.wandb_project,
        name=run_name,
        config=hparams_dict,  # Log all hyperparameters
        log_model=False,  # No model artifact upload, to save space
    )

    # Prefer the project/run names reported by the initialised logger, since
    # W&B may assign or adjust them; fall back to the CLI-derived values.
    run_name_for_path = run_name
    project_name_for_path = args.wandb_project
    if hasattr(wandb_logger, 'experiment') and wandb_logger.experiment is not None:
        exp_project = getattr(wandb_logger.experiment, 'project', None)
        if exp_project:
            project_name_for_path = exp_project

        exp_run_name = getattr(wandb_logger.experiment, 'name', None)
        if exp_run_name:
            run_name_for_path = exp_run_name
        elif hasattr(wandb_logger, 'id') and wandb_logger.id is not None:
            run_name_for_path = f"run_{wandb_logger.id}"
    elif hasattr(wandb_logger, 'project_name') and wandb_logger.project_name is not None:
        project_name_for_path = wandb_logger.project_name
        if hasattr(wandb_logger, 'name') and wandb_logger.name is not None:
            run_name_for_path = wandb_logger.name

    # --- Create Experiment Directories and Save Config ---
    base_save_path_with_group = Path("experiments") / str(project_name_for_path)
    if args.experiment_group_name:
        base_save_path_with_group = base_save_path_with_group / args.experiment_group_name
    base_run_dir = base_save_path_with_group / str(run_name_for_path)

    # Expose the run directory on the model (the original notes it is used for
    # saving validation CSVs; that code is not visible in this file).
    model.base_run_dir = str(base_run_dir)

    config_save_dir = base_run_dir / "config"
    os.makedirs(config_save_dir, exist_ok=True)

    # Best effort: a failed config copy must not abort training.
    try:
        destination_config_path = config_save_dir / "config_copied.yaml"
        shutil.copy(Path(args.config_file), destination_config_path)
        print(f"Saved configuration file to: {destination_config_path}")
    except Exception as e:
        print(f"Warning: Could not save configuration file. Error: {e}")

    # --- Callbacks ---
    checkpoint_dir_for_callback = str(base_run_dir / "checkpoints")

    early_stopping_callback = EarlyStopping(
        monitor=monitor_metric,
        patience=getattr(hparams, 'early_stopping_patience', 10),
        mode=monitor_mode,
        verbose=True
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    progress_bar_callback = RichProgressBar()
    callbacks = [early_stopping_callback, lr_monitor, progress_bar_callback]

    # --- Handle Model Checkpointing ---
    # save_top_k precedence: CLI > config > default of 1.
    if args.save_top_k is not None:
        save_top_k_value = args.save_top_k
        print(f"Using save_top_k={save_top_k_value} from command-line argument.")
    else:
        save_top_k_value = getattr(hparams, 'save_top_k', 1)
        print(f"Using save_top_k={save_top_k_value} from config file or default.")

    def make_checkpoint_callback(filename, monitor, mode='min'):
        """Build a ModelCheckpoint for this run's checkpoint dir and register it."""
        checkpoint_callback = ModelCheckpoint(
            dirpath=checkpoint_dir_for_callback,
            filename=filename,
            monitor=monitor,
            mode=mode,
            save_top_k=save_top_k_value
        )
        callbacks.append(checkpoint_callback)
        return checkpoint_callback

    # Kept in variables so their best_model_path can be reported after training.
    per_checkpoint_callback = None
    wer_checkpoint_callback = None
    whisper_wer_checkpoint_callback = None
    single_checkpoint_callback = None
    single_checkpoint_filename = f'{args.training_stage}-{{epoch:02d}}-{{{monitor_metric}:.3f}}'

    if args.save_dual_checkpoints:
        print("Dual checkpointing enabled: saving best PER and best WER.")

        # 1. PER checkpoint (always, unless encoder-only)
        if args.training_stage != 'encoder_only':
            per_checkpoint_callback = make_checkpoint_callback(
                f"{args.training_stage}-{{epoch:02d}}-per={{val_per_sg_epoch:.3f}}",
                'val_per_sg_epoch',
            )

        # 2. BART WER checkpoint (only if BART is active in the model)
        if bart_head_active:
            wer_checkpoint_callback = make_checkpoint_callback(
                f"{args.training_stage}-{{epoch:02d}}-wer={{val_wer_bart_epoch:.3f}}",
                'val_wer_bart_epoch',
            )

        # 3. Whisper WER checkpoint (only if Whisper is active in the model)
        if whisper_head_active:
            whisper_wer_checkpoint_callback = make_checkpoint_callback(
                f"{args.training_stage}-{{epoch:02d}}-wer_whisper={{val_wer_whisper_epoch:.3f}}",
                'val_wer_whisper_epoch',
            )

        # 4. Safety net: never train without any checkpoint callback.
        if not per_checkpoint_callback and not wer_checkpoint_callback and not whisper_wer_checkpoint_callback:
            print(f"WARNING: Dual checkpointing was enabled, but no conditions were met. "
                  f"Falling back to a single checkpoint on '{monitor_metric}'.")
            single_checkpoint_callback = make_checkpoint_callback(
                single_checkpoint_filename, monitor_metric, monitor_mode
            )
    else:
        print(f"Single checkpointing enabled, monitoring: '{monitor_metric}'")
        single_checkpoint_callback = make_checkpoint_callback(
            single_checkpoint_filename, monitor_metric, monitor_mode
        )

    # Curriculum learning applies only to the sequential-generation stage. The
    # same flag drives both the callback and per-epoch dataloader reloading so
    # the two can never disagree.
    use_curriculum = (getattr(hparams, 'curriculum_learning_enabled', False)
                      and args.training_stage == 'joint_sequential_generation')
    if use_curriculum:
        callbacks.append(CurriculumUpdateCallback())

    # --- Weight initialisation from a checkpoint (CLI > config 'ckpt_path') ---
    # Loads weights only (no optimizer/epoch state); skipped when resuming.
    load_weights_ckpt_path = args.load_weights_from_checkpoint or config.get('ckpt_path', None)
    if not args.resume_from_checkpoint and load_weights_ckpt_path:
        print(f"INITIALIZING model weights from checkpoint: {load_weights_ckpt_path}")
        checkpoint = torch.load(load_weights_ckpt_path, map_location='cpu')
        # Accept either a checkpoint dict with a 'state_dict' entry or a bare state dict.
        original_state_dict = checkpoint.get('state_dict', checkpoint)

        filtered_state_dict = _filter_state_dict_for_model(
            original_state_dict, bart_head_active, whisper_head_active
        )
        print(f"Loading weights. Total keys in checkpoint: {len(original_state_dict)}. After filtering: {len(filtered_state_dict)} keys.")

        # strict=False: keys filtered out above (or new heads) are expected to be missing.
        model.load_state_dict(filtered_state_dict, strict=False)
        print("Model weights initialized.")

    # --- Resuming (full trainer state: optimizer, epoch, ...) ---
    resume_ckpt_path = args.resume_from_checkpoint
    if resume_ckpt_path:
        print(f"RESUMING training from checkpoint: {resume_ckpt_path}")
        if load_weights_ckpt_path:
            print("Warning: Both `ckpt_path` in config and `--resume_from_checkpoint` on CLI were provided.")
            print("The `--resume_from_checkpoint` will be used to restore the full trainer state.")

    # --- Trainer ---
    # `[index]` selects one specific device. Lightning's CPU accelerator only
    # accepts a positive integer count, so a list (e.g. the default [0]) would
    # be rejected there; use a single CPU process instead.
    trainer_devices = 1 if args.accelerator == 'cpu' else [args.devices]
    trainer_kwargs = dict(
        logger=wandb_logger,
        callbacks=callbacks,
        max_epochs=getattr(hparams, 'max_epochs', 50),
        accelerator=args.accelerator,
        devices=trainer_devices,
        precision=getattr(hparams, 'precision', 32),
        gradient_clip_val=getattr(hparams, 'gradient_clip_val', None),
        accumulate_grad_batches=getattr(hparams, 'accumulate_grad_batches', 1),
        strategy=getattr(hparams, 'strategy', 'auto'),
    )
    if use_curriculum:
        print("TRAINER WITH DATALOADER RELOADING EVERY EPOCH")
        trainer_kwargs['reload_dataloaders_every_n_epochs'] = 1
    trainer = pl.Trainer(**trainer_kwargs)

    print(f"Starting training for stage: {args.training_stage}")
    print(f"Hyperparameters Namespace: {vars(hparams)}")

    trainer.fit(
        model,
        datamodule=ecog_datamodule,
        ckpt_path=resume_ckpt_path  # None when not resuming
    )

    print(f"Training for stage {args.training_stage} finished.")

    # --- Print Best Checkpoint Paths ---
    # The exact text "Best checkpoint for this stage:" is grepped by the
    # two-stage shell script; do not change it.
    primary_checkpoint_callback = per_checkpoint_callback or single_checkpoint_callback
    if primary_checkpoint_callback and primary_checkpoint_callback.best_model_path:
        print(f"Best checkpoint for this stage: {primary_checkpoint_callback.best_model_path}")

    if wer_checkpoint_callback and wer_checkpoint_callback.best_model_path:
        print(f"Informational - Best WER Checkpoint: {wer_checkpoint_callback.best_model_path}")

    if whisper_wer_checkpoint_callback and whisper_wer_checkpoint_callback.best_model_path:
        print(f"Informational - Best Whisper WER Checkpoint: {whisper_wer_checkpoint_callback.best_model_path}")

    wandb.finish()


if __name__ == '__main__':
    # Command-line entry point: parse the arguments and run one training stage.
    parser = argparse.ArgumentParser(description="Unified Training Script for ECoG-to-Phoneme Model")

    # --- Required: config file and stage selection ---
    parser.add_argument('--config_file', type=str, required=True, help="Path to YAML configuration file.")
    parser.add_argument('--training_stage', type=str, required=True, 
                        choices=['encoder_only', 'joint_teacher_forcing', 'joint_sequential_generation', 'secondary_only', "bart_tf"],
                        help="Current training stage.")
    
    # --- Checkpoints: component initialisation, full resume, weights-only init ---
    parser.add_argument('--initial_encoder_checkpoint_path', type=str, default=None, 
                        help="Path to pre-trained encoder checkpoint (for joint stages or continuing encoder_only).")
    parser.add_argument('--initial_decoder_checkpoint_path', type=str, default=None,
                        help="Path to pre-trained standalone decoder checkpoint (for joint stages).")
    parser.add_argument('--resume_from_checkpoint', type=str, default=None,
                        help="Path to a full model checkpoint to resume training for the current stage.")
    parser.add_argument('--load_weights_from_checkpoint', type=str, default=None,
                        help="Path to a checkpoint to initialize model weights from (for starting a new stage).")

    # --- Data ---
    parser.add_argument('--train_data_fraction', type=float, default=1.0,
                        help="Fraction of the training data to use (for scaling law experiments).")

    # --- Logging and local experiment layout ---
    parser.add_argument('--wandb_project', type=str, default="aa", help="WandB project name.")
    parser.add_argument('--experiment_group_name', type=str, default=None, help="Name for a sub-directory to group runs locally under experiments/[wandb_project]/")
    parser.add_argument('--run_name_suffix', type=str, default="exp1", help="Suffix for WandB run name and checkpoint dir.")
    
    # --- Hardware, reproducibility, checkpointing policy ---
    parser.add_argument('--accelerator', type=str, default='gpu' if torch.cuda.is_available() else 'cpu', help="Accelerator (gpu, cpu, tpu, mps, etc.).")
    # The value is handed to the Trainer as a one-element list, i.e. it selects
    # a single device by index rather than giving a device count.
    parser.add_argument('--devices', type=int, default=0, help="Index of the accelerator device to train on (e.g., GPU id 0).")
    parser.add_argument('--seed', type=int, default=None, help="Global seed for reproducibility.")
    parser.add_argument('--save_dual_checkpoints', action='store_true', help="If set, save checkpoints for both PER and WER metrics.")
    parser.add_argument('--save_top_k', type=int, default=None, help="How many top checkpoints to save (per metric). Overrides config file setting.")
    
    cli_args = parser.parse_args()
    main_train_unified(cli_args)