import logging
from pathlib import Path

import hydra
from omegaconf import DictConfig, OmegaConf

import torch
import lightning.pytorch as pl
from lightning.pytorch.loggers import WandbLogger
from transformers import Gemma3ForCausalLM, GPTNeoXForCausalLM, Qwen3ForCausalLM

from matt.datamodule import AIMDataModule
from matt.aim import AttentionInfluenceModelingTask
from matt.modeling.gemma3 import Gemma3ForAIM
from matt.modeling.gpt_neox import GPTNeoXForAIM
from matt.modeling.qwen3 import Qwen3ForAIM
from matt.callbacks import (
    PeakVRAMMonitorCallback,
    FLOPSMonitorCallback,
    TotalTrainingTimeCallback,
)


# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def train(
    cfg: DictConfig,
    datamodule: pl.LightningDataModule,
    task: pl.LightningModule
):
    """Fit ``task`` on ``datamodule`` with a Trainer configured from ``cfg``.

    Args:
        cfg: Run configuration. Reads ``checkpoint_name``,
            ``checkpoint_output_dir``, ``profiler``, ``accelerator``,
            ``devices``, ``model.name`` and the trainer-related ``task.*``
            entries listed below.
        datamodule: Data module handed to ``Trainer.fit``.
        task: Lightning module to train.
    """
    # Reading `.experiment` is what starts the W&B run, so keep this first.
    wandb_logger = WandbLogger(project='???')
    print(wandb_logger.experiment.id)

    monitoring_callbacks = [
        PeakVRAMMonitorCallback(),
        FLOPSMonitorCallback(),
        TotalTrainingTimeCallback(),
    ]

    # An explicitly configured checkpoint name wins. Otherwise derive one from
    # the W&B run name and the model name; the doubled braces keep `{step}` and
    # `{train_loss:.3f}` as literal placeholders for ModelCheckpoint to fill.
    checkpoint_filename = (
        cfg.checkpoint_name
        if cfg.checkpoint_name is not None
        else f'{wandb_logger.experiment.name}_{cfg.model.name}_step{{step}}_loss{{train_loss:.3f}}'
    )

    callbacks = [
        # save_top_k=-1 keeps every periodic checkpoint instead of pruning.
        pl.callbacks.ModelCheckpoint(
            dirpath=cfg.checkpoint_output_dir,
            filename=checkpoint_filename,
            save_last=True,
            save_top_k=-1,
            auto_insert_metric_name=False,
            every_n_train_steps=cfg.task.save_every_n_steps,
        ),
        pl.callbacks.LearningRateMonitor(logging_interval='step'),
        *monitoring_callbacks,
    ]

    # TODO: add grad_norm logging
    # https://lightning.ai/docs/pytorch/stable/debug/debugging_intermediate.html#look-out-for-exploding-gradients
    #
    # Debugging switches that can be added here when needed:
    #   overfit_batches=4, fast_dev_run=True, detect_anomaly=True
    trainer_kwargs = dict(
        logger=wandb_logger,
        callbacks=callbacks,
        profiler=cfg.profiler,
        max_epochs=cfg.task.max_epochs,
        max_steps=cfg.task.max_steps,
        val_check_interval=cfg.task.val_check_interval,
        log_every_n_steps=cfg.task.log_every_n_steps,
        accelerator=cfg.accelerator,
        devices=cfg.devices,
        strategy=cfg.task.strategy,
        precision=cfg.task.precision,
        gradient_clip_val=cfg.task.gradient_clip_val,
        gradient_clip_algorithm=cfg.task.gradient_clip_algorithm,
        accumulate_grad_batches=cfg.task.accumulate_grad_batches,
        enable_model_summary=True,
    )
    trainer = pl.Trainer(**trainer_kwargs)

    # No `ckpt_path` is passed, so nothing is resumed from a Lightning
    # checkpoint and the run starts from the very beginning.
    # TODO: add checkpoint path? (e.g. ckpt_path=cfg.model.checkpoint_path)
    trainer.fit(task, datamodule=datamodule)


@hydra.main(config_path='matt/conf', config_name='config', version_base=None)
def main(cfg: DictConfig):
    print(OmegaConf.to_yaml(cfg))

    pl.seed_everything(cfg.seed)

    match cfg.datamodule.name:
        case 'aim':
            datamodule = AIMDataModule(cfg)
        case _:
            raise ValueError(
                'unknown datamodule, can be either `aim` or ...'
            )

    match cfg.model.name:
        case 'gemma-3':
            teacher_model = Gemma3ForCausalLM.from_pretrained(
                cfg.model.teacher_checkpoint_path,
                torch_dtype=torch.bfloat16,
                attn_implementation='eager',
            )
            teacher_model = Gemma3ForAIM(
                teacher_model,
                n_layers=cfg.model.n_layers,
                all_layer_outputs=cfg.task.align_all_layers,
            )

            student_frozen_embeddings_mask_path = cfg.model.student_frozen_embeddings_mask \
                or (Path(cfg.model.student_checkpoint_path) / 'original_embeddings_mask.pt')

            student_model = Gemma3ForCausalLM.from_pretrained(
                cfg.model.student_checkpoint_path,
                torch_dtype=torch.bfloat16,
                attn_implementation='eager',
            )
            student_model = Gemma3ForAIM(
                student_model,
                n_layers=cfg.model.n_layers,
                frozen_embeddings_mask=torch.load(student_frozen_embeddings_mask_path) \
                    if cfg.model.freeze_original_embeddings else None,
                all_layer_outputs=cfg.task.align_all_layers,
            )
        case 'pythia' | 'gpt-neox':
            teacher_model = GPTNeoXForCausalLM.from_pretrained(
                cfg.model.teacher_checkpoint_path,
                torch_dtype=torch.bfloat16,
                attn_implementation='eager',
            )
            teacher_model = GPTNeoXForAIM(
                teacher_model,
                n_layers=cfg.model.n_layers,
                all_layer_outputs=cfg.task.align_all_layers,
            )

            student_frozen_embeddings_mask_path = cfg.model.student_frozen_embeddings_mask \
                or (Path(cfg.model.student_checkpoint_path) / 'original_embeddings_mask.pt')

            student_model = GPTNeoXForCausalLM.from_pretrained(
                cfg.model.student_checkpoint_path,
                torch_dtype=torch.bfloat16,
                attn_implementation='eager',
            )
            student_model = GPTNeoXForAIM(
                student_model,
                n_layers=cfg.model.n_layers,
                frozen_embeddings_mask=torch.load(student_frozen_embeddings_mask_path) \
                    if cfg.model.freeze_original_embeddings else None,
                all_layer_outputs=cfg.task.align_all_layers,
            )
        case 'qwen3':
            teacher_model = Qwen3ForCausalLM.from_pretrained(
                cfg.model.teacher_checkpoint_path,
                torch_dtype=torch.bfloat16,
                attn_implementation='eager',
            )
            teacher_model = Qwen3ForAIM(
                teacher_model,
                n_layers=cfg.model.n_layers,
                all_layer_outputs=cfg.task.align_all_layers,
            )

            student_frozen_embeddings_mask_path = cfg.model.student_frozen_embeddings_mask \
                or (Path(cfg.model.student_checkpoint_path) / 'original_embeddings_mask.pt')

            student_model = Qwen3ForCausalLM.from_pretrained(
                cfg.model.student_checkpoint_path,
                torch_dtype=torch.bfloat16,
                attn_implementation='eager',
            )
            student_model = Qwen3ForAIM(
                student_model,
                n_layers=cfg.model.n_layers,
                frozen_embeddings_mask=torch.load(student_frozen_embeddings_mask_path) \
                    if cfg.model.freeze_original_embeddings else None,
                all_layer_outputs=cfg.task.align_all_layers,
            )
        case _:
            raise ValueError(
                'unknown model, can be either `gemma-3`, `pythia`, `qwen3` or ...'
            )

    match cfg.task.name:
        case 'aim':
            if 'teacher_model' not in locals() or 'student_model' not in locals():
                raise ValueError('teacher_model and student_model must be defined for AIM task. '
                                 'You have probably chosen a wrong model.')
            task = AttentionInfluenceModelingTask(cfg, teacher_model, student_model)
        case _:
            raise ValueError(
                'unknown task, can be either `aim` or ...'
            )

    train(cfg, datamodule, task)


# Script entry point: call the Hydra-decorated main() only when this file is
# executed directly, not when it is imported.
if __name__ == '__main__':
    main()
