import traceback
import torch
import torch_geometric
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Timer
from lightning.pytorch.profilers import AdvancedProfiler, SimpleProfiler
import hydra
from omegaconf import DictConfig, OmegaConf
import neptune
from neptune.utils import stringify_unsupported
from tgp.poolers import get_pooler
from tgp.data import PreCoarsening

# Local imports
from source.data import (
    EXPWL1DataModule,
    BenchHardDataModule,
    MultipartiteDataModule,
    TUDataModule,
    OGBDataModule,
    LRGBDataModule,
    ZINCDataModule,
)
from source.pl_modules import (
    SingleClassificationModule,
    MultiClassificationModule,
    RegressionModule,
)
from source.models import ClassificationModel, OGBModel
from source.utils import (
    NeptuneLogger,
    register_resolvers,
    find_devices,
    FilterCallback,
    log_gpu_memory_usage,
)

# Import-time setup, executed before Hydra invokes `run`.
# NOTE(review): presumably registers the custom OmegaConf resolvers used by
# the YAML configs -- confirm in source.utils.
register_resolvers()
# Attach the project's FilterCallback to the logger of Neptune's asynchronous
# operation processor. NOTE(review): this reaches into `neptune.internal`,
# a private module path that may move between Neptune releases.
neptune.internal.operation_processors.async_operation_processor.logger.addFilter(
    FilterCallback()
)


@hydra.main(version_base=None, config_path="config", config_name="run_classification")
def run(cfg: DictConfig) -> float:
    """Hydra entry point: run one classification job and never propagate errors.

    Any exception raised by ``_run_classification`` is reported on stdout
    (summary line plus full traceback) and converted into a sentinel return
    value, so that a single failing job does not abort a multirun sweep.

    Args:
        cfg: The composed Hydra configuration for this job.

    Returns:
        The value returned by ``_run_classification`` on success. On failure,
        the worst possible objective value for the configured optimisation
        direction: ``-inf`` when ``cfg.callbacks.mode == "max"``, ``+inf``
        otherwise.
    """
    try:
        return _run_classification(cfg)
    except Exception as e:
        # Extract key configuration details for error reporting. `or {}`
        # guards against sections that are present but set to None, so the
        # handler itself cannot raise and mask the original error.
        pooler_name = (cfg.get("pooler") or {}).get("name", "unknown")
        dataset_name = (cfg.get("dataset") or {}).get("name", "unknown")
        task_id = cfg.get("task", "unknown")
        config_info = f"pooler={pooler_name}, dataset={dataset_name}, task={task_id}"

        print(f"❌ CLASSIFICATION_JOB_FAILED: {config_info} | Error: {e}")
        print("Full traceback:")
        traceback.print_exc()

        # For multirun sweeps, we want individual job failures to not stop the
        # entire sweep. The sentinel must be the *worst* value for the
        # monitored metric: +inf when it is minimised, -inf when maximised.
        # Otherwise a failed run would look like the best trial.
        mode = (cfg.get("callbacks") or {}).get("mode", "min")
        return float("-inf") if mode == "max" else float("inf")


def _build_data_module(cfg: DictConfig, pre_transform: list):
    """Instantiate the data module selected by ``cfg.dataset``.

    When ``cfg.dataset.family`` is set, the family picks the class and the
    dataset name is forwarded to it; otherwise the dataset name alone picks
    the class.

    Raises:
        ValueError: If the family (or, without a family, the name) is unknown.
    """
    dataset = cfg.dataset

    if dataset.get("family") is not None:
        family_modules = {
            "TUDataset": TUDataModule,
            "OGBDataset": OGBDataModule,
            "LRGBDataset": LRGBDataModule,
        }
        if dataset.family not in family_modules:
            raise ValueError(f"Dataset family {cfg.dataset.family} not recognized")
        return family_modules[dataset.family](
            dataset.name, dataset.hparams, pre_transform=pre_transform
        )

    named_modules = {
        "EXPWL1": EXPWL1DataModule,
        "BenchHard": BenchHardDataModule,
        "Multipartite": MultipartiteDataModule,
        "ZINC": ZINCDataModule,
    }
    if dataset.name not in named_modules:
        raise ValueError(f"Dataset {cfg.dataset.name} not recognized")
    return named_modules[dataset.name](dataset.hparams, pre_transform=pre_transform)


def _build_callbacks(cfg: DictConfig):
    """Create the Lightning callbacks enabled in ``cfg.callbacks``.

    Returns:
        A ``(callbacks, timer_callback)`` tuple; ``timer_callback`` is ``None``
        when ``cfg.callbacks.timer`` is disabled.
    """
    callbacks = []

    if cfg.callbacks.early_stop:
        callbacks.append(
            EarlyStopping(
                monitor=cfg.callbacks.monitor,
                patience=cfg.callbacks.patience,
                mode=cfg.callbacks.mode,
            )
        )

    if cfg.callbacks.checkpoints:
        # The doubled braces keep `{epoch:03d}` and `{<monitor>:e}` as
        # placeholders for ModelCheckpoint to fill in, while the monitored
        # metric's *name* is substituted here. (A plain string would hand the
        # literal text "cfg.callbacks.monitor" to ModelCheckpoint.)
        checkpoint_name = (
            f"{cfg.architecture.name}_{cfg.pooler.name}"
            f"___{{epoch:03d}}-{{{cfg.callbacks.monitor}:e}}"
        )
        callbacks.append(
            ModelCheckpoint(
                save_top_k=1,
                monitor=cfg.callbacks.monitor,
                mode=cfg.callbacks.mode,
                dirpath=cfg.logger.logdir + "/checkpoints/",
                filename=checkpoint_name,
            )
        )

    timer_callback = None
    if cfg.callbacks.timer:
        timer_callback = Timer()
        callbacks.append(timer_callback)

    return callbacks, timer_callback


def _build_profiler(cfg: DictConfig):
    """Create the profiler described by ``cfg.profiler``, or ``None`` if unset.

    Raises:
        ValueError: If ``cfg.profiler.name`` is neither "simple" nor "advanced".
    """
    if not cfg.profiler:
        return None

    profiler_classes = {"simple": SimpleProfiler, "advanced": AdvancedProfiler}
    if cfg.profiler.name not in profiler_classes:
        # Fail here with a clear message instead of leaving the profiler
        # undefined and crashing later when the Trainer is built.
        raise ValueError(f"Profiler {cfg.profiler.name} not recognized")
    return profiler_classes[cfg.profiler.name](
        dirpath=cfg.profiler.hparams.dirpath,
        filename=cfg.profiler.hparams.filename,
    )


def _log_profiler_artifact(logger, cfg: DictConfig, stage: str) -> None:
    """Upload the profiler report ``<dirpath>/<stage>-<filename>.txt``, if present."""
    report_path = (
        cfg.profiler.hparams.dirpath
        + "/"
        + stage
        + "-"
        + cfg.profiler.hparams.filename
        + ".txt"
    )
    try:
        logger.log_artifact(report_path, delete_after=True)
    except FileNotFoundError:
        print("Profiler artifact not found")


def _run_classification(cfg: DictConfig) -> float:
    """Train and test one model as described by ``cfg``.

    Steps, in order: print the resolved config, seed (if a seed is
    configured), build the pooler-dependent pre-transforms and the data
    module, build the torch model and its Lightning wrapper, set up the
    logger, callbacks and profiler, fit, test, and log timings.

    Args:
        cfg: The composed Hydra configuration for this job.

    Returns:
        The value of the monitored metric (``cfg.callbacks.monitor``) after
        fitting; this is what the sweeper optimises.

    Raises:
        ValueError: Unknown dataset / dataset family / profiler name.
        NotImplementedError: Unsupported logger backend.
    """
    print(OmegaConf.to_yaml(cfg, resolve=True))

    # Log initial GPU state
    log_gpu_memory_usage("BEFORE_TRAINING")

    ### 🌱 Seed everything
    if "seed" in cfg.dataset.hparams:
        print(f"Setting seed to {cfg.dataset.hparams.seed}")
        torch_geometric.seed.seed_everything(cfg.dataset.hparams.seed)

    ### 📊 Load data
    # This pooler instance is only used to derive the dataset pre-transforms;
    # the model builds its own pooler from cfg below.
    pooler = get_pooler(
        cfg.pooler.name,
        in_channels=3,  # dummy values
        k=25,
        **cfg.pooler.hparams,
    )
    candidate_transforms = (
        pooler.data_transforms(),
        None if pooler.is_trainable else PreCoarsening(pooler=pooler),
    )
    pre_transform = [t for t in candidate_transforms if t]

    data_module = _build_data_module(cfg, pre_transform)

    ### 🧠 Load the model
    ogb_style_datasets = (
        "ogbg-molhiv",
        "ogbg-molpcba",
        "ogbg-ppa",
        "peptides-func",
        "peptides-struct",
    )
    model = OGBModel if cfg.dataset.name in ogb_style_datasets else ClassificationModel

    out_channels = (
        data_module.train_dataset.num_tasks
        if cfg.dataset.name == "ogbg-molpcba"
        else data_module.num_classes
    )

    arch = cfg.architecture.hparams
    torch_model = model(
        in_channels=data_module.num_features,  # Size of node features
        out_channels=out_channels,  # Number of classes
        edge_channels=data_module.num_edge_features,  # Size of edge features
        num_layers_pre=arch.num_layers_pre,  # Number of GIN layers before pooling
        num_layers_post=arch.num_layers_post,  # Number of GIN layers after pooling
        hidden_channels=arch.hidden_channels,  # Dimensionality of node embeddings
        activation=arch.activation,  # Activation of the MLP in GIN
        dropout=arch.dropout,  # Dropout in the MLP
        pooler=cfg.pooler.name,  # Pooling method
        pool_kwargs=cfg.pooler.hparams,  # Pooling method kwargs
        pooled_nodes=int(
            data_module.avg_nodes * arch.pool_ratio
        ),  # Number of nodes after pooling
        use_gine=data_module.num_edge_features > 0,  # Use GINE instead of GIN
    )

    ### 📈 Optimizer scheduler
    if cfg.get("lr_scheduler") is not None:
        scheduler_class = getattr(torch.optim.lr_scheduler, cfg.lr_scheduler.name)
        scheduler_kwargs = dict(cfg.lr_scheduler.hparams)
    else:
        scheduler_class = scheduler_kwargs = None

    ### ⚡ Lightning module
    if cfg.dataset.name in ["ogbg-molpcba", "peptides-func"]:
        module = MultiClassificationModule
    elif cfg.dataset.name in ["peptides-struct", "ZINC"]:
        module = RegressionModule
    else:
        module = SingleClassificationModule

    # Resolve the devices once so that `sync_dist` and the Trainer's `devices`
    # are guaranteed to describe the same device selection.
    devices = find_devices(2)
    lightning_model = module(
        model=torch_model,
        optim_class=getattr(torch.optim, cfg.optimizer.name),
        optim_kwargs=dict(cfg.optimizer.hparams),
        scheduler_class=scheduler_class,
        scheduler_kwargs=scheduler_kwargs,
        log_lr=cfg.log_lr,
        log_grad_norm=cfg.log_grad_norm,
        plot_dict=dict(cfg.plot_preds_at_epoch),
        sync_dist=isinstance(devices, list),
        ogbg_evaluator=cfg.dataset.get("ogbg_evaluator", False),
        ogbg_dataset=cfg.dataset.name,
    )

    ### 🪵 Logger
    backend = cfg.get("logger").get("backend")
    if backend is None:
        logger = None
    elif backend == "neptune":
        logger = NeptuneLogger(
            project_name=cfg.logger.project,
            save_dir=cfg.logger.logdir,
            tags=cfg.tags,
            params=stringify_unsupported(OmegaConf.to_container(cfg, resolve=True)),
            debug=cfg.logger.offline,
            source_files=[
                "run_classification.py",
                "source/models/classification_models.py",
                "source/pl_modules/classification_module.py",
            ],
        )
        OmegaConf.save(cfg, "run_config.yaml")
        logger.log_artifact("run_config.yaml", delete_after=True)
        logger.cfg = cfg
    else:
        raise NotImplementedError("Logger backend not supported.")

    ### 📞 Callbacks
    cb, timer_callback = _build_callbacks(cfg)

    ### 📊 Profiler
    profiler = _build_profiler(cfg)
    upload_profiler_reports = backend == "neptune" and profiler is not None

    ### 🚀 Trainer
    trainer = pl.Trainer(
        logger=logger,
        profiler=profiler,
        callbacks=cb,
        devices=devices,  # Num of GPUs available
        max_epochs=cfg.epochs,
        limit_train_batches=cfg.limit_train_batches,
        limit_val_batches=cfg.limit_val_batches,
        gradient_clip_val=cfg.clip_val,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        overfit_batches=0.0,  # >0 for debugging
        detect_anomaly=False,  # helps to find NaNs
    )

    trainer.fit(
        lightning_model, data_module.train_dataloader(), data_module.val_dataloader()
    )
    val_loss = trainer.callback_metrics[
        cfg.callbacks.monitor
    ].item()  # Used by the sweeper to optimize the hyperparameters

    if upload_profiler_reports:
        _log_profiler_artifact(logger, cfg, "fit")

    if cfg.callbacks.checkpoints:
        trainer.test(lightning_model, data_module.test_dataloader(), ckpt_path="best")
    else:
        trainer.test(lightning_model, data_module.test_dataloader())

    if upload_profiler_reports:
        _log_profiler_artifact(logger, cfg, "test")

    # Log train and val time
    if timer_callback is not None:
        tr_time = timer_callback.time_elapsed("train")
        val_time = timer_callback.time_elapsed("validate")
        if logger is not None:
            logger.log_metric("tr_time", tr_time)
            logger.log_metric("val_time", val_time)
        else:
            # Without a logger backend there is nowhere to record the metrics;
            # print them rather than crash on `None.log_metric`.
            print(f"tr_time: {tr_time}, val_time: {val_time}")
    log_gpu_memory_usage("AFTER_TRAINING")

    if logger is not None:
        logger.finalize("success")

    print("✅ Classification job completed successfully.")
    return val_loss


# Script entry point. `run` is called without arguments because it is wrapped
# by the `@hydra.main` decorator above.
if __name__ == "__main__":
    run()
