import traceback
import torch
import torch_geometric
from torch_geometric.transforms import NormalizeFeatures, Compose
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.profilers import AdvancedProfiler, SimpleProfiler
import hydra
from omegaconf import DictConfig, OmegaConf
import neptune
from neptune.utils import stringify_unsupported
from tgp.poolers import pooler_map

# Local imports
from source.pl_modules import NodeClassificationModule
from source.models import AutoEncoderModel, GCNModel, GraphUNet
from source.utils import (
    NeptuneLogger,
    register_resolvers,
    find_devices,
    FilterCallback,
    log_gpu_memory_usage,
)
from source.data import NodeClassDataModule

# Import-time setup: both calls below run when this module is loaded, i.e.
# before Hydra composes the config and invokes ``run``.
# NOTE(review): presumably this registers custom OmegaConf resolvers used by
# the YAML configs — confirm in source.utils.
register_resolvers()

# Attach FilterCallback (via ``addFilter``) to the logger of Neptune's async
# operation processor. ``neptune.internal`` is, by its name, not public API,
# so this attribute path may move between Neptune releases.
_neptune_async_processor_logger = (
    neptune.internal.operation_processors.async_operation_processor.logger
)
_neptune_async_processor_logger.addFilter(FilterCallback())
del _neptune_async_processor_logger  # keep the module namespace unchanged


def _cfg_value(cfg, *keys: str, default: object = "unknown") -> object:
    """Best-effort nested lookup into ``cfg`` for error reporting only.

    Walks ``keys`` with ``.get`` and returns ``default`` when any section along
    the path is missing, null, not a mapping, or fails to resolve. This way the
    failure handler in ``run`` can never raise itself and hide the original
    exception.
    """
    node = cfg
    try:
        for key in keys:
            node = node.get(key)
            if node is None:
                return default
    except Exception:  # deliberately broad: diagnostics-only path
        return default
    return node


@hydra.main(
    version_base=None, config_path="config", config_name="run_node_classification"
)
def run(cfg: DictConfig) -> float:
    """Hydra entry point: run one node-classification job and return its score.

    Hydra composes ``cfg`` from ``config/run_node_classification`` and then
    calls this function. The work is delegated to ``_run_node_classification``.

    Args:
        cfg: Hydra-composed run configuration.

    Returns:
        On success, the value of the monitored metric
        (``cfg.callbacks.monitor``) that the sweeper optimises.
        On failure, the worst possible score for that metric:
        ``float('-inf')`` if ``cfg.callbacks.mode == "max"``, otherwise
        ``float('inf')``.
        The exception is not re-raised, so one failing job does not stop an
        entire multirun sweep.
    """
    try:
        return _run_node_classification(cfg)
    except Exception as e:
        # Extract key configuration details for error reporting. The lookups
        # are best-effort so a malformed config cannot mask the real error.
        config_info = (
            f"pooler={_cfg_value(cfg, 'pooler', 'name')}, "
            f"dataset={_cfg_value(cfg, 'dataset', 'name')}, "
            f"task={_cfg_value(cfg, 'task')}"
        )

        print(f"❌ NODE_CLASSIFICATION_JOB_FAILED: {config_info} | Error: {e}")
        print("Full traceback:")
        traceback.print_exc()

        # Sentinel for the sweeper: the worst value the monitored metric can
        # take. ``callbacks.mode`` is the direction in which the monitored
        # metric improves (it is passed to EarlyStopping/ModelCheckpoint), so
        # "worst" is -inf for "max" and +inf otherwise. A blanket +inf would
        # rank a crashed job as the best trial when maximising.
        if _cfg_value(cfg, "callbacks", "mode", default="min") == "max":
            return float("-inf")
        return float("inf")


def _run_node_classification(cfg: DictConfig) -> float:
    """Train and test one node-classification model described by ``cfg``.

    Pipeline:
        1. Seed.
        2. Build the datamodule.
        3. Build the torch model and wrap it in ``NodeClassificationModule``.
        4. Build the logger, callbacks and profiler.
        5. Run ``trainer.fit``, then ``trainer.test``.

    Args:
        cfg: Hydra-composed run configuration.

    Returns:
        ``trainer.callback_metrics[cfg.callbacks.monitor]`` as a Python float,
        read right after fitting. The sweeper optimises this value.

    Raises:
        NotImplementedError: for an unknown architecture, logger backend or
            profiler name.
        KeyError: if ``cfg.callbacks.monitor`` was never logged during fit.
    """
    print(OmegaConf.to_yaml(cfg, resolve=True))

    # Log initial GPU state
    log_gpu_memory_usage("BEFORE_TRAINING")

    ### 🌱 Seed everything (before any data or weights are created)
    dataset_hp = cfg.dataset.hparams
    if "seed" in dataset_hp:
        print(f"Setting seed to {dataset_hp.seed}")
        torch_geometric.seed.seed_everything(dataset_hp.seed)

    ### 📊 Load data
    data_module = NodeClassDataModule(
        cfg.dataset.name, pre_transform=_build_pre_transform(cfg), **dataset_hp
    )

    ### 🧠 Load the model
    torch_model = _build_model(cfg, data_module)

    ### ⚡ Lightning module (with optional LR scheduler)
    scheduler_class, scheduler_kwargs = _build_scheduler(cfg)
    lightning_model = NodeClassificationModule(
        model=torch_model,
        optim_class=getattr(torch.optim, cfg.optimizer.name),
        optim_kwargs=dict(cfg.optimizer.hparams),
        scheduler_class=scheduler_class,
        scheduler_kwargs=scheduler_kwargs,
        log_lr=cfg.log_lr,
        log_grad_norm=cfg.log_grad_norm,
        plot_dict=dict(cfg.plot_preds_at_epoch),
        fold=data_module.fold,
    )

    ### 🪵 Logger, 📞 callbacks, 📊 profiler
    logger = _build_logger(cfg)
    callbacks = _build_callbacks(cfg)
    profiler = _build_profiler(cfg)
    # Profiler reports are uploaded only when the Neptune logger is active.
    # ``.get`` is used here as in ``_build_logger``, so a logger section
    # without ``backend`` no longer crashes after training.
    upload_profiles = (
        profiler is not None and cfg.logger.get("backend") == "neptune"
    )

    ### 🚀 Trainer
    trainer = pl.Trainer(
        logger=logger,
        profiler=profiler,
        callbacks=callbacks,
        devices=find_devices(1),  # Num of GPUs available
        max_epochs=cfg.epochs,
        limit_train_batches=cfg.limit_train_batches,
        limit_val_batches=cfg.limit_val_batches,
        gradient_clip_val=cfg.clip_val,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        overfit_batches=0.0,  # >0 for debugging
    )

    # NOTE(review): the same dataloader is used for train, val and test.
    # Presumably NodeClassificationModule selects the split masks for ``fold``
    # itself; confirm.
    trainer.fit(lightning_model, data_module.dataloader(), data_module.dataloader())
    # Used by the sweeper to optimize the hyperparameters.
    val_loss = trainer.callback_metrics[cfg.callbacks.monitor].item()

    if upload_profiles:
        _log_profiler_artifact(logger, cfg, "fit")

    # When checkpointing is enabled, test the best checkpoint instead of the
    # final weights.
    if cfg.callbacks.checkpoints:
        trainer.test(lightning_model, data_module.dataloader(), ckpt_path="best")
    else:
        trainer.test(lightning_model, data_module.dataloader())

    if upload_profiles:
        _log_profiler_artifact(logger, cfg, "test")

    if logger is not None:
        logger.finalize("success")

    # Log final GPU state
    log_gpu_memory_usage("AFTER_TRAINING")

    print("✅ Node classification job completed successfully.")
    return val_loss


def _build_pre_transform(cfg: DictConfig) -> Compose:
    """Compose the dataset ``pre_transform`` for the configured pooler.

    ``NormalizeFeatures`` is included when ``cfg.dataset.hparams.norm_feats``
    is truthy. It is followed by ``data_transforms()`` of the selected pooler
    class. Falsy entries, such as a pooler returning ``None``, are dropped.
    """
    candidates = (
        NormalizeFeatures() if cfg.dataset.hparams.norm_feats else None,
        pooler_map[cfg.pooler.name].data_transforms(),
    )
    return Compose([transform for transform in candidates if transform])


def _build_model(cfg: DictConfig, data_module: NodeClassDataModule):
    """Instantiate the torch model named by ``cfg.architecture.model``.

    Supported names and the classes they build:
        * ``"bottleneck"``: ``AutoEncoderModel``
        * ``"gcn"``: ``GCNModel``
        * ``"graph_unet"``: ``GraphUNet``

    Input and output sizes are taken from ``data_module.torch_dataset``.

    Raises:
        NotImplementedError: for any other model name.
    """
    dataset = data_module.torch_dataset
    model_name = cfg.architecture.model
    hp = cfg.architecture.hparams

    if model_name == "bottleneck":
        return AutoEncoderModel(
            in_channels=dataset.num_features,
            out_channels=dataset.num_classes,
            hidden_channels=hp.hidden_channels,
            num_mp_layers=hp.num_mp_layers,
            activation=hp.activation,
            pooler=cfg.pooler.name,
            pool_kwargs=cfg.pooler.hparams,
            # pool_ratio x node count of the first graph in the dataset,
            # truncated towards zero by int().
            pooled_nodes=int(dataset[0].x.size(0) * hp.pool_ratio),
            use_gine_enc=hp.use_gine_enc,
            use_gine_bottleneck=hp.use_gine_bottleneck,
            res_connect=hp.res_connect,
            dropout=hp.dropout,
            dropout_decoder=hp.dropout_decoder,
        )
    if model_name == "gcn":
        return GCNModel(
            in_channels=dataset.num_features,
            out_channels=dataset.num_classes,
            hidden_channels=hp.hidden_channels,
        )
    if model_name == "graph_unet":
        return GraphUNet(
            in_channels=dataset.num_features,
            hidden_channels=hp.hidden_channels,
            out_channels=dataset.num_classes,
            # Configurable via ``architecture.hparams.depth``; defaults to 2.
            depth=hp.get("depth", 2),
            res_connect=hp.res_connect,
        )
    raise NotImplementedError(f"Model {cfg.architecture.model} not implemented")


def _build_scheduler(cfg: DictConfig):
    """Return ``(scheduler_class, scheduler_kwargs)``, or ``(None, None)``.

    The class is looked up by name in ``torch.optim.lr_scheduler``. The kwargs
    are ``cfg.lr_scheduler.hparams`` converted to a plain dict.
    """
    if cfg.get("lr_scheduler") is None:
        return None, None
    return (
        getattr(torch.optim.lr_scheduler, cfg.lr_scheduler.name),
        dict(cfg.lr_scheduler.hparams),
    )


def _build_logger(cfg: DictConfig):
    """Create the experiment logger selected by ``cfg.logger.backend``.

    Returns ``None`` when no backend is set. For ``"neptune"``, it also:
        * saves ``cfg`` to ``run_config.yaml``;
        * uploads that file with ``log_artifact(..., delete_after=True)``;
        * attaches ``cfg`` to the logger as ``logger.cfg``.

    Raises:
        NotImplementedError: for any other backend.
    """
    backend = cfg.logger.get("backend")
    if backend is None:
        return None
    if backend != "neptune":
        raise NotImplementedError("Logger backend not supported.")

    logger = NeptuneLogger(
        project_name=cfg.logger.project,
        save_dir=cfg.logger.logdir,
        tags=cfg.tags,
        params=stringify_unsupported(OmegaConf.to_container(cfg, resolve=True)),
        debug=cfg.logger.offline,
        source_files=[
            "run_node_classification.py",
            "source/models/node_classification_models.py",
            "source/pl_modules/node_classification_module.py",
        ],
    )
    OmegaConf.save(cfg, "run_config.yaml")
    logger.log_artifact("run_config.yaml", delete_after=True)
    logger.cfg = cfg
    return logger


def _build_callbacks(cfg: DictConfig) -> list:
    """Build the Lightning callbacks enabled in ``cfg.callbacks``.

    EarlyStopping and ModelCheckpoint (top-1) both track
    ``cfg.callbacks.monitor`` in direction ``cfg.callbacks.mode``.
    """
    cb_cfg = cfg.callbacks
    callbacks = []

    if cb_cfg.early_stop:
        callbacks.append(
            EarlyStopping(
                monitor=cb_cfg.monitor,
                patience=cb_cfg.patience,
                mode=cb_cfg.mode,
            )
        )

    if cb_cfg.checkpoints:
        # ModelCheckpoint fills ``{metric:fmt}`` placeholders from the logged
        # metrics, so the braces must contain the monitored metric's actual
        # *name*. It is interpolated here. A literal
        # "{cfg.callbacks.monitor:e}" would name a metric that is never logged,
        # and Lightning would render it as 0.
        filename = (
            f"{cfg.architecture.name}_{cfg.pooler.name}"
            f"___{{epoch:03d}}-{{{cb_cfg.monitor}:e}}"
        )
        callbacks.append(
            ModelCheckpoint(
                save_top_k=1,
                monitor=cb_cfg.monitor,
                mode=cb_cfg.mode,
                dirpath=cfg.logger.logdir + "/checkpoints/",
                filename=filename,
            )
        )

    return callbacks


def _build_profiler(cfg: DictConfig):
    """Create the profiler selected by ``cfg.profiler.name``.

    Returns ``None`` when ``cfg.profiler`` is unset or empty.

    Raises:
        NotImplementedError: for names other than ``"simple"`` and
            ``"advanced"``.
    """
    profiler_cfg = cfg.get("profiler")
    if not profiler_cfg:
        return None

    # NOTE(review): these profiler classes are imported from
    # ``lightning.pytorch``, while Trainer and callbacks come from
    # ``pytorch_lightning``. Consider using a single package.
    profiler_classes = {"simple": SimpleProfiler, "advanced": AdvancedProfiler}
    name = profiler_cfg.name
    if name not in profiler_classes:
        raise NotImplementedError(f"Profiler {name} not implemented")
    return profiler_classes[name](
        dirpath=profiler_cfg.hparams.dirpath,
        filename=profiler_cfg.hparams.filename,
    )


def _log_profiler_artifact(logger, cfg: DictConfig, stage: str) -> None:
    """Upload ``<dirpath>/<stage>-<filename>.txt`` via ``logger.log_artifact``.

    This is best effort: a missing report file is reported and otherwise
    ignored.
    """
    hparams = cfg.profiler.hparams
    path = f"{hparams.dirpath}/{stage}-{hparams.filename}.txt"
    try:
        logger.log_artifact(path, delete_after=True)
    except FileNotFoundError:
        print("Profiler artifact not found")


if __name__ == "__main__":
    # Script entry point. ``run`` is wrapped by ``hydra.main``, which composes
    # ``cfg`` from ``config/run_node_classification``, so it is called here
    # without arguments.
    run()
