"""Training script for baseline model with offline data."""

from pathlib import Path

import warnings
import hydra
import random
import torch
import numpy as np
from omegaconf import DictConfig

from src.models.benchmarks import TNP
from src.utils import DataAttr
from src.train_model import Trainer

# wandb is an optional dependency: record whether the import succeeded in
# WANDB_AVAILABLE. When the import fails, the name `wandb` is left unbound and
# a hint is printed instead of raising.
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False
    print("Warning: wandb not available. Install with 'pip install wandb' for experiment tracking.")

def set_seed(seed: int = 42):
    """Seed every random number generator this script relies on.

    Seeds, in order: Python's ``random`` module, NumPy's global generator,
    PyTorch's default generator, and PyTorch's generators for all CUDA devices.

    Args:
        seed: Value handed to each seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

class BaselineTrainer(Trainer):
    """``Trainer`` subclass that builds its model from ``self.cfg.model``.

    Block-mask creation is deliberately unsupported here; see
    ``_create_block_mask``.
    """

    # TNPs & PFNv1 make their own masks; no need to mask here.
    masked_forward: bool = False

    def _build_model(self) -> TNP:
        """Construct the model described by ``self.cfg.model``.

        A target whose final dotted component is ``PFN`` is built by hand,
        because it takes an additional ``head_bucket_samples`` tensor; random
        normal samples are used as a placeholder and a warning is emitted.
        Every other target is passed straight to ``hydra.utils.instantiate``.

        Returns:
            The constructed model. Parameter counts are printed before
            returning.
        """
        print("Building model...")

        model_cfg = self.cfg.model
        target_class_name = model_cfg._target_.rpartition(".")[2]

        if target_class_name != "PFN":
            model = hydra.utils.instantiate(model_cfg)
        else:
            # PFN model requires head_bucket_samples
            pfn_cls = hydra.utils.get_class(model_cfg._target_)
            warnings.warn(
                "PFN requires prior samples to determine bucket borders. Use random samples for now."
            )
            # Placeholder samples: ten draws per bucket.
            bucket_samples = torch.randn([model_cfg.head_num_buckets * 10])
            model = pfn_cls(
                model_cfg.dim_x,
                model_cfg.dim_y,
                d_model=model_cfg.d_model,
                dim_feedforward=model_cfg.dim_feedforward,
                nhead=model_cfg.nhead,
                dropout=model_cfg.dropout,
                num_layers=model_cfg.num_layers,
                head_num_buckets=model_cfg.head_num_buckets,
                head_bucket_samples=bucket_samples,
            )

        # Tally parameter counts in a single pass over the model.
        n_total = 0
        n_trainable = 0
        for param in model.parameters():
            n_elements = param.numel()
            n_total += n_elements
            if param.requires_grad:
                n_trainable += n_elements
        print(f"Model built: {n_trainable:,} trainable parameters (total: {n_total:,})")

        return model

    def _create_block_mask(self, batch: DataAttr) -> torch.Tensor:
        """Always raise: block masks are not created by this trainer.

        Raises:
            NotImplementedError: Unconditionally.
        """
        raise NotImplementedError("TNPs & PFNv1 make their own masks; no need to create block mask.")


@hydra.main(version_base=None, config_path="../configs/baselines", config_name="train_gp_pfn")
def main(cfg: DictConfig):
    """Main training function.

    Seeds the random number generators from ``cfg.seed``, makes sure the
    checkpoint directory ``cfg.checkpoint.save_dir`` exists, builds a
    ``BaselineTrainer`` from the config and runs its training loop. The wandb
    run is closed afterwards when the trainer reports that it used wandb.

    Args:
        cfg: Hydra configuration; this function reads ``cfg.seed`` and
            ``cfg.checkpoint.save_dir`` and passes the whole config on to the
            trainer.
    """
    set_seed(cfg.seed)

    # Create checkpoint directory
    checkpoint_dir = Path(cfg.checkpoint.save_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Initialize trainer and run training
    trainer = BaselineTrainer(cfg)
    trainer.train()

    # Close wandb. The name `wandb` is only bound when the optional import at
    # the top of the file succeeded, so also check WANDB_AVAILABLE to avoid a
    # NameError after training has already completed.
    if trainer.use_wandb and WANDB_AVAILABLE:
        wandb.finish()


# Script entry point: `main` is wrapped by `@hydra.main`, so it is called
# without arguments here and receives its config from the decorator.
if __name__ == "__main__":
    main()