# scripts/train_cbm.py
import os
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import hydra
from omegaconf import DictConfig, OmegaConf
import sys
import numpy as np
import wandb
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from concept_alignment.models.utils import load_pretrained_llm
from concept_alignment.models.cbm import ConceptBottleneckModel, CBMLightningModule
from concept_alignment.models.property_predictor import PropertyPredictor
from concept_alignment.data.dataset import MolecularPropertyDataset, create_dataloaders

# Logging setup: INFO-and-above records go to stdout with a timestamped format.
_stdout_handler = logging.StreamHandler(stream=sys.stdout)
logging.basicConfig(
    handlers=[_stdout_handler],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

class BatchCheckpointCallback(pl.Callback):
    """Save an intermediate Lightning checkpoint every N training batches.

    Files are written to ``<checkpoint_dir>/epoch_<E>/checkpoint_batch_<B>.ckpt``,
    where ``B`` is the zero-based batch index within epoch ``E``.

    Defined at module level (not inside ``main``) so it can be imported,
    reused and pickled.

    Args:
        save_every_n_batches: Checkpoint interval in batches. A falsy value
            (0 or None) disables intermediate checkpointing.
        checkpoint_dir: Root directory for the per-epoch checkpoint folders.
    """

    def __init__(self, save_every_n_batches, checkpoint_dir):
        super().__init__()
        self.save_every_n_batches = save_every_n_batches
        self.checkpoint_dir = checkpoint_dir

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        interval = self.save_every_n_batches
        # Guard against 0/None, which would otherwise raise on the modulo below.
        if not interval:
            return
        if (batch_idx + 1) % interval != 0:
            return

        epoch = trainer.current_epoch
        batch_checkpoint_dir = os.path.join(self.checkpoint_dir, f"epoch_{epoch}")
        os.makedirs(batch_checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(batch_checkpoint_dir, f"checkpoint_batch_{batch_idx}.ckpt")

        trainer.save_checkpoint(checkpoint_path)
        logger.info(f"Saved intermediate checkpoint at epoch {epoch}, batch {batch_idx}")


def _resolve_device(device_setting):
    """Map the ``device`` config value to a torch device string ("auto" -> cuda/cpu)."""
    if device_setting == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return device_setting


def _lightning_accelerator(device):
    """Translate a torch device string into Lightning ``(accelerator, devices)`` arguments.

    Keeps the Trainer on the same hardware as the LLM/dataset, e.g. ``"cpu"`` stays
    on CPU even when CUDA is available, and ``"cuda:1"`` selects GPU index 1.
    """
    device_type, _, index = str(device).partition(":")
    if device_type == "cuda":
        return "gpu", ([int(index)] if index else 1)
    if device_type == "mps":
        return "mps", 1
    if device_type == "cpu":
        return "cpu", 1
    # Unrecognised device string: fall back to CUDA auto-detection.
    return ("gpu" if torch.cuda.is_available() else "cpu"), 1


def _save_sample_concepts(cbm_module, test_loader, device, save_dir):
    """Run the first test example through the trained CBM and save its concept outputs.

    Writes ``sample_concepts.pt`` (keys ``representation``, ``concept_values``,
    ``concept_embeddings``) into ``save_dir``. Skips with a warning when the
    test loader yields no batches.
    """
    try:
        sample_batch = next(iter(test_loader))
    except StopIteration:
        logger.warning("Test loader is empty; skipping sample concept export")
        return

    # Assumes element 0 of each batch is the input representation tensor
    # (batch first) — TODO confirm against create_dataloaders.
    sample_repr = sample_batch[0][0:1].to(device)
    cbm_module = cbm_module.to(device)
    # Inference: disable dropout / use running batch-norm stats, if the model has any.
    cbm_module.eval()
    with torch.no_grad():
        _, sample_concept_values, _ = cbm_module(sample_repr)
        sample_data = {
            'representation': sample_repr.cpu(),
            'concept_values': sample_concept_values.cpu(),
            'concept_embeddings': cbm_module.cbm.concept_embeddings.cpu(),
        }

    torch.save(sample_data, os.path.join(save_dir, "sample_concepts.pt"))
    logger.info("Saved sample concept values for later use")


@hydra.main(version_base=None, config_path="../config", config_name="config")
def main(cfg: DictConfig):
    """Main function to train the Concept Bottleneck Model.

    Loads the pretrained LLM and tokenizer, and builds the molecular property
    dataset and its dataloaders. It then trains a ``ConceptBottleneckModel``
    with a ``PropertyPredictor`` head under PyTorch Lightning, saves both
    state dicts to ``cfg.save_dir``, and exports one sample of concept
    values.

    Args:
        cfg: Hydra/OmegaConf config. The optional ``wandb.mode`` key
            (default ``"offline"``) controls the ``WANDB_MODE`` environment
            variable when ``use_wandb`` is set.
    """
    logger.info(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")

    # Seed RNGs via Lightning for reproducibility.
    pl.seed_everything(cfg.seed)

    device = _resolve_device(cfg.device)

    if cfg.use_wandb:
        # Offline by default (previous behaviour); set `wandb.mode` to e.g. "online" to sync.
        os.environ["WANDB_MODE"] = cfg.wandb.get("mode") or "offline"
        logger_to_use = WandbLogger(
            project=cfg.wandb.project,
            name=cfg.wandb.name
        )
    else:
        logger_to_use = None

    # Load model and tokenizer
    model, tokenizer = load_pretrained_llm(model_name=cfg.model.name, tokenizer_name=cfg.model.tokenizer, device=device)

    special_tokens = [
        '[WAVELENGTH]', '[/WAVELENGTH]', '[F_OSC]', '[/F_OSC]',
        '[QED]', '[/QED]', '[LOGP]', '[/LOGP]',
        '[START_SMILES]', '[END_SMILES]', '[SEP]'
    ]
    tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
    # Grow the embedding matrix so the newly added token ids are valid.
    model.resize_token_embeddings(len(tokenizer))
    # NOTE(review): left padding presumably keeps the last real token at a fixed
    # position for representation extraction — confirm in MolecularPropertyDataset.
    tokenizer.padding_side = 'left'

    logger.info(f"Model configuration: {model.config}")
    hidden_dim = model.config.hidden_size
    logger.info(f"Model hidden dimension: {hidden_dim}")

    property_cols = cfg.data.property_columns

    dataset = MolecularPropertyDataset(
        data_path=cfg.data.path,
        tokenizer=tokenizer,
        model=model,
        property_cols=property_cols,
        layer_idx=cfg.model.layer_idx,
        max_length=cfg.data.max_length,
        device=device
    )

    train_loader, val_loader, test_loader = create_dataloaders(
        dataset,
        batch_size=cfg.batch_size,
        train_ratio=cfg.data.train_ratio,
        val_ratio=cfg.data.val_ratio,
        seed=cfg.seed
    )

    # Concept embeddings share the LLM hidden size.
    cbm = ConceptBottleneckModel(
        hidden_dim=hidden_dim,
        num_concepts=cfg.cbm.num_concepts,
        concept_dim=hidden_dim
    )

    property_predictor = PropertyPredictor(
        num_concepts=cfg.cbm.num_concepts,
        concept_dim=hidden_dim,
        num_properties=len(property_cols)
    )

    cbm_module = CBMLightningModule(
        cbm=cbm,
        property_predictor=property_predictor,
        learning_rate=cfg.training.learning_rate,
        lambda_ortho=cfg.training.lambda_ortho,
        property_cols=property_cols
    )

    # The metric name contains '/', so disable auto-insertion of "name=" into the
    # filename; otherwise Lightning creates a nested "...-val/" directory.
    checkpoint_callback = ModelCheckpoint(
        dirpath=cfg.save_dir,
        filename='best-cbm-epoch{epoch:02d}-val_total_loss{val/total_loss:.4f}',
        auto_insert_metric_name=False,
        save_top_k=3,
        monitor='val/total_loss',
        mode='min',
        save_last=True
    )

    batch_ckpt_callback = BatchCheckpointCallback(
        save_every_n_batches=cfg.training.save_every_n_batches,
        checkpoint_dir=cfg.training.checkpoint_dir
    )

    lr_monitor = LearningRateMonitor(logging_interval='epoch')

    accelerator, devices = _lightning_accelerator(device)
    trainer = pl.Trainer(
        max_epochs=cfg.training.num_epochs,
        accelerator=accelerator,
        devices=devices,
        logger=logger_to_use,
        callbacks=[checkpoint_callback, batch_ckpt_callback, lr_monitor],
        deterministic=False,
        log_every_n_steps=1000
    )

    # Resume only when both the flag and a non-empty path are set.
    ckpt_path = None
    if cfg.resume_from_checkpoint and cfg.resume_checkpoint_path:
        ckpt_path = cfg.resume_checkpoint_path
        logger.info(f"Resuming from checkpoint: {ckpt_path}")

    logger.info("Starting training...")
    trainer.fit(cbm_module, train_loader, val_loader, ckpt_path=ckpt_path)

    save_dir = cfg.save_dir
    os.makedirs(save_dir, exist_ok=True)

    cbm_path = os.path.join(save_dir, "cbm_model.pt")
    property_predictor_path = os.path.join(save_dir, "property_predictor.pt")

    torch.save(cbm_module.cbm.state_dict(), cbm_path)
    torch.save(cbm_module.property_predictor.state_dict(), property_predictor_path)

    logger.info(f"Models saved to {save_dir}")

    # Save a sample of concept embeddings for later interpretability analysis.
    _save_sample_concepts(cbm_module, test_loader, device, save_dir)

    if cfg.use_wandb:
        wandb.finish()

if __name__ == '__main__':
    # Hydra's decorator builds `cfg` from config files / CLI overrides,
    # so the entry point is invoked without arguments.
    main()