import sys
import traceback
import torch
from torch_geometric.datasets import Planetoid, CitationFull
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import Compose, NormalizeFeatures
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.profilers import AdvancedProfiler, SimpleProfiler
import hydra
from omegaconf import DictConfig, OmegaConf
import neptune
from neptune.utils import stringify_unsupported

from tgp.datasets import PyGSPDataset
from tgp.poolers import pooler_map

# Local imports
from source.pl_modules import ClusterModule
from source.models import ClusterModel
from source.utils import (
    NeptuneLogger,
    register_resolvers,
    find_devices,
    SortNodes,
    CustomTensorBoardLogger,
    FilterCallback,
    log_gpu_memory_usage,
)

# Module-level setup: runs once at import time, before Hydra invokes ``run``.
# NOTE(review): presumably registers the custom OmegaConf resolvers that the
# YAML configs rely on — confirm against ``source.utils.register_resolvers``.
register_resolvers()
# Attach a ``FilterCallback`` instance as a logging filter on the logger used by
# Neptune's async operation processor.
# NOTE(review): presumably this silences noisy Neptune log records — confirm
# against ``source.utils.FilterCallback``.
neptune.internal.operation_processors.async_operation_processor.logger.addFilter(
    FilterCallback()
)

def _cfg_lookup(cfg: DictConfig, *path: str, default: str = "unknown"):
    """Best-effort nested config read used only for failure diagnostics.

    Walks ``cfg[path[0]][path[1]]...`` and returns the value found. Any error
    along the way (missing key, ``None`` or non-mapping intermediate node,
    unresolved interpolation, ...) yields ``default`` instead of raising.

    Args:
        cfg: Root configuration object.
        *path: Sequence of keys to follow from the root.
        default: Value returned when the path cannot be read.

    Returns:
        The value stored at ``path``, or ``default`` if it cannot be read.
    """
    node = cfg
    try:
        for key in path:
            node = node[key]
    except Exception:
        # Deliberately broad: this runs inside an error handler, and a failure
        # here must never mask the exception that is being reported.
        return default
    return node


@hydra.main(version_base=None, config_path="config", config_name="run_clustering")
def run(cfg: DictConfig) -> float:
    """Hydra entry point that runs one clustering job and shields sweeps from failures.

    Delegates the actual work to ``_run_clustering``. If that raises, a one-line
    failure summary and the full traceback are printed, and a sentinel value is
    returned instead of propagating the exception.

    Args:
        cfg: Configuration composed by Hydra from ``config/run_clustering``.

    Returns:
        The value returned by ``_run_clustering`` on success, or ``float("inf")``
        when the job failed.
    """
    try:
        return _run_clustering(cfg)
    except Exception as e:
        # Extract key configuration details for error reporting. The lookups
        # are non-raising so a malformed config cannot hide the real error.
        pooler_name = _cfg_lookup(cfg, "pooler", "name")
        dataset_name = _cfg_lookup(cfg, "dataset", "name")
        task_id = _cfg_lookup(cfg, "task")
        config_info = f"pooler={pooler_name}, dataset={dataset_name}, task={task_id}"

        print(f"❌ CLUSTERING_JOB_FAILED: {config_info} | Error: {str(e)}")
        print("Full traceback:")
        traceback.print_exc()

        # For multirun sweeps, we want individual job failures to not stop the
        # entire sweep, so return a sentinel value to indicate failure.
        return float("inf")

def _build_pre_transform(cfg: DictConfig) -> Compose:
    """Compose the dataset pre-transforms: node sorting, optional feature
    normalization, and whatever the selected pooler's ``data_transforms()`` returns."""
    steps = [SortNodes()]
    if cfg.dataset.family in {"CitationFull", "Planetoid"}:
        steps.append(NormalizeFeatures())
    steps.append(pooler_map[cfg.pooler.name].data_transforms())
    # Drop falsy entries (e.g. a pooler that contributes no transform).
    return Compose([step for step in steps if step])


def _load_dataset(cfg: DictConfig, pre_transform: Compose):
    """Instantiate the dataset for ``cfg.dataset.family``.

    Returns:
        A ``(dataset, num_classes)`` pair. For ``PyGSPDataset`` the second item
        is taken from ``cfg.architecture.hparams.pool_ratio``.

    Raises:
        ValueError: If the dataset family is not recognized.
    """
    family = cfg.dataset.family
    if family == "Planetoid":
        dataset = Planetoid(
            root="data/",
            name=cfg.dataset.name,
            split=cfg.dataset.hparams.split,
            pre_transform=pre_transform,
            force_reload=True,
        )
        return dataset, dataset.num_classes
    if family == "CitationFull":
        dataset = CitationFull(
            root="data/",
            name=cfg.dataset.name,
            pre_transform=pre_transform,
            force_reload=True,
        )
        return dataset, dataset.num_classes
    if family == "PyGSPDataset":
        dataset = PyGSPDataset(
            root="data/PyGSP",
            name=cfg.dataset.name,
            kwargs=cfg.dataset.params,
            force_reload=cfg.dataset.hparams.force_reload,
            pre_transform=pre_transform,
        )
        return dataset, cfg.architecture.hparams.pool_ratio
    raise ValueError(f"Dataset {cfg.dataset.family} not recognized")


def _build_logger(cfg: DictConfig, backend):
    """Create the experiment logger for ``backend`` (``None`` disables logging).

    Raises:
        NotImplementedError: If ``backend`` is neither 'neptune' nor 'tensorboard'.
    """
    if backend is None:
        return None
    if backend == "neptune":
        logger = NeptuneLogger(
            project_name=cfg.logger.project,
            save_dir=cfg.logger.logdir,
            tags=cfg.tags,
            params=stringify_unsupported(OmegaConf.to_container(cfg, resolve=True)),
            debug=cfg.logger.offline,
            source_files=[
                "run_clustering.py",
                "source/models/clustering_models.py",
                "source/pl_modules/clustering_module.py",
            ],  # Add other files to be logged
        )
        # Snapshot the config to a file and hand it to the logger as an artifact.
        OmegaConf.save(cfg, "run_config.yaml")
        logger.log_artifact("run_config.yaml", delete_after=True)
    elif backend == "tensorboard":
        logger = CustomTensorBoardLogger(
            save_dir=cfg.logger.logdir, name=None, version=""
        )
    else:
        raise NotImplementedError("Backend not in ['tensorboard','neptune']")
    logger.cfg = cfg
    return logger


def _build_callbacks(cfg: DictConfig) -> list:
    """Create the early-stopping and checkpoint callbacks enabled in ``cfg.callbacks``."""
    callbacks = []
    if cfg.callbacks.early_stop:
        callbacks.append(
            EarlyStopping(
                monitor=cfg.callbacks.monitor,
                patience=cfg.callbacks.patience,
                mode=cfg.callbacks.mode,
            )
        )
    if cfg.callbacks.checkpoints:
        callbacks.append(
            ModelCheckpoint(
                save_top_k=1,
                monitor=cfg.callbacks.monitor,
                mode=cfg.callbacks.mode,
                dirpath=cfg.logger.logdir + "/checkpoints/",
                filename=cfg.architecture.name
                + "_"
                + cfg.pooler.name
                + "___{epoch:03d}-{NMI:e}",
            )
        )
    return callbacks


def _build_profiler(cfg: DictConfig):
    """Create the profiler selected by ``cfg.profiler``, or ``None`` if it is unset.

    Raises:
        ValueError: If ``cfg.profiler.name`` is neither 'simple' nor 'advanced'.
    """
    if not cfg.profiler:
        return None
    dirpath = cfg.profiler.hparams.dirpath
    filename = cfg.profiler.hparams.filename
    if cfg.profiler.name == "simple":
        return SimpleProfiler(dirpath=dirpath, filename=filename)
    if cfg.profiler.name == "advanced":
        return AdvancedProfiler(dirpath=dirpath, filename=filename)
    # Fail early with a clear message instead of leaving the profiler unbound.
    raise ValueError(f"Profiler {cfg.profiler.name} not recognized")


def _upload_profiler_report(logger, cfg: DictConfig, stage: str) -> None:
    """Best-effort upload of the ``<dirpath>/<stage>-<filename>.txt`` profiler report."""
    report_path = (
        f"{cfg.profiler.hparams.dirpath}/{stage}-{cfg.profiler.hparams.filename}.txt"
    )
    try:
        logger.log_artifact(report_path, delete_after=True)
    except FileNotFoundError:
        print("Profiler file not found")


def _run_clustering(cfg: DictConfig) -> float:
    """Train and test a clustering model as described by ``cfg``.

    Prints the resolved config, builds the dataset, model, Lightning module,
    logger, callbacks and profiler, then runs ``trainer.fit`` followed by
    ``trainer.test`` on the same data loader.

    Args:
        cfg: Full run configuration.

    Returns:
        The ``"test_loss"`` entry of ``trainer.callback_metrics`` as a Python float.

    Raises:
        ValueError: If the dataset family or the profiler name is not recognized.
        NotImplementedError: If the logger backend is not supported.
    """
    print(OmegaConf.to_yaml(cfg, resolve=True))

    # Log initial GPU state
    log_gpu_memory_usage("BEFORE_TRAINING")

    ### 📊 Load data
    torch_dataset, num_classes = _load_dataset(cfg, _build_pre_transform(cfg))
    data_loader = DataLoader(torch_dataset, batch_size=cfg.batch_size, shuffle=False)

    ### 🧠 Load the model
    arch = cfg.architecture.hparams
    torch_model = ClusterModel(
        in_channels=torch_dataset.num_features,  # Size of node features
        num_layers_pre=arch.num_layers_pre,  # Number of GIN layers before pooling
        hidden_channels=arch.hidden_channels,  # Dimensionality of node embeddings
        activation=arch.activation,  # Activation of the MLP in GIN
        pooler=cfg.pooler.name,  # Pooling method
        pool_kwargs=cfg.pooler.hparams,  # Pooling method kwargs
        pooled_nodes=num_classes,  # Number of nodes after pooling
    )

    ### 📈 Optimizer scheduler
    if cfg.get("lr_scheduler") is not None:
        scheduler_class = getattr(torch.optim.lr_scheduler, cfg.lr_scheduler.name)
        scheduler_kwargs = dict(cfg.lr_scheduler.hparams)
    else:
        scheduler_class = scheduler_kwargs = None

    ### ⚡ Lightning module
    lightning_model = ClusterModule(
        model=torch_model,
        num_classes=num_classes,
        optim_class=torch.optim.Adam,
        optim_kwargs=dict(cfg.optimizer.hparams),
        scheduler_class=scheduler_class,
        scheduler_kwargs=scheduler_kwargs,
        log_lr=cfg.log_lr,
        log_grad_norm=cfg.log_grad_norm,
        plot_dict=dict(cfg.plot_preds_at_epoch),
    )

    ### 🪵 Logger
    # Read the backend once, defensively, and reuse it below so that a config
    # without a ``backend`` key behaves the same before and after training.
    backend = cfg.get("logger").get("backend")
    logger = _build_logger(cfg, backend)

    ### 📞 Callbacks
    callbacks = _build_callbacks(cfg)

    ### 📊 Profiler
    profiler = _build_profiler(cfg)
    # Profiler reports are only uploaded when logging to Neptune.
    upload_reports = backend == "neptune" and profiler is not None

    ### 🚀 Trainer
    trainer = pl.Trainer(
        logger=logger,
        profiler=profiler,
        callbacks=callbacks,
        devices=find_devices(1),  # Num of GPUs available
        max_epochs=cfg.epochs,
        gradient_clip_val=cfg.clip_val,
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        overfit_batches=0.0,  # >0 for debugging
    )

    trainer.fit(lightning_model, data_loader)
    if upload_reports:
        _upload_profiler_report(logger, cfg, "fit")

    # When checkpointing is enabled, evaluate the best checkpoint rather than
    # the weights left in memory at the end of training.
    if cfg.callbacks.checkpoints:
        trainer.test(lightning_model, data_loader, ckpt_path="best")
    else:
        trainer.test(lightning_model, data_loader)

    if upload_reports:
        _upload_profiler_report(logger, cfg, "test")

    if logger is not None:
        logger.finalize("success")

    result = trainer.callback_metrics["test_loss"].item()

    # Log final GPU state
    log_gpu_memory_usage("AFTER_TRAINING")

    print("✅ Clustering job completed successfully.")
    return result


# Script entry point: ``run`` is wrapped by ``@hydra.main``, so it is called
# without arguments here and receives its config from Hydra.
if __name__ == "__main__":
    run()
