import torch
import wandb
import random
import numpy as np
from omegaconf import DictConfig
from pytorch_lightning.loggers import WandbLogger
from src.t4_node_classification.model import NodeClassifier
from src.t3_edge_regression.model import EdgeRegressor
from pytorch_lightning import Trainer, LightningModule

from src.brain_topo_decoding.t2_dacs.data_module import DACSDataModule
from src.t4_node_classification.model import NodeClassifier
from src.brain_topo_decoding.t1_dynamics.model import DynamicsClassifier
from pytorch_lightning.utilities.model_summary import ModelSummary
from src.brain_topo_decoding.t2_dacs.model import DACClassifier
from src.brain_topo_decoding.t1_dynamics.data_module import FixedVolumeDataModule
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from src.t3_edge_regression.data_module import EdgeRegressionDataModule
from src.t4_node_classification.data_module import NodeClassificationDataModule
from utils.utils import get_masks


def get_model(
    cfg: DictConfig, data_module, task: str, routing: bool = False
) -> LightningModule:
    """
    Build the LightningModule for ``task`` from the Hydra config and data module.

    Args:
        cfg (DictConfig): The Hydra configuration. Model hyper-parameters are read
            from ``cfg.model``, optimizer/device settings from ``cfg.training``,
            and (edge regression only) the dataset name from ``cfg.dataset.name``.
        data_module: The already set-up data module for the current dataset/split.
            Which of its attributes are read depends on ``task``.
        task (str): One of ``"node-classification"``, ``"edge-regression"``,
            ``"dac-classification"`` or ``"dynamics-classification"``.
        routing (bool): If True, additionally pass per-node-type ``input_dims`` and
            ``k=cfg.model.k`` to the model constructor.

    Returns:
        LightningModule: The initialized model.

    Raises:
        ValueError: If ``task`` is not one of the supported tasks.
    """
    models = {
        "node-classification": NodeClassifier,
        "edge-regression": EdgeRegressor,
        "dac-classification": DACClassifier,
        "dynamics-classification": DynamicsClassifier,
    }

    # Validate before touching the data module: an unknown task would otherwise
    # skip every branch below and crash with UnboundLocalError on ``data``.
    if task not in models:
        raise ValueError(f"Invalid task: {task}")

    if task == "node-classification":
        data = data_module.full_dataset.data
        edge_types = data.edge_types
        # Class count inferred from the largest label id in the dataset.
        num_classes = data_module.full_dataset.y.max().item() + 1
        # NOTE(review): NLLLoss presumably expects log-probabilities from the model.
        loss = torch.nn.NLLLoss()
    elif task == "edge-regression":
        data = data_module.full_dataset.data
        edge_types = data.edge_types
        num_classes = 1  # regression task, only one output value
        loss = torch.nn.MSELoss()
    elif task == "dac-classification":
        # The first training sample provides the graph schema (node/edge types).
        data = data_module.train_dataset[0]
        edge_types = data.edge_types
        num_classes = 8  # hard-coded class count for this task
        loss = torch.nn.NLLLoss()
    else:  # "dynamics-classification" (task validated above)
        data = data_module.all_dynamics
        edge_types = list(data_module.edge_index_dict.keys())
        num_classes = 8  # hard-coded class count for this task
        loss = torch.nn.NLLLoss()

    # One hidden size per node type, paired positionally with the config list.
    hidden_sizes_list = cfg.model.hidden_sizes_list
    hidden_sizes = {
        node_type: hidden_sizes_list[i]
        for i, node_type in enumerate(data.node_types)
    }

    common_args = dict(
        edge_types=edge_types,
        hidden_sizes=hidden_sizes,
        n_classes=num_classes,
        num_layers=cfg.model.num_layers,
        dropout=cfg.model.dropout,
        jumping_knowledge=cfg.model.get("jumping_knowledge", None),
        normalize=cfg.model.get("normalize", False),
        learning_rate=cfg.training.lr,
        loss=loss,
        wd=cfg.training.get("wd", 0),
        conv_type=cfg.model.conv_type,
        device=cfg.training.device,
        lin_res=cfg.model.get("lin_res", False),
        in_aggr=cfg.model.in_aggr,
        out_aggr=cfg.model.out_aggr,
        bn=cfg.model.get("bn", False),
        ln=cfg.model.get("ln", False),
        routing=cfg.model.routing,
    )

    # Full-batch tasks select train/val/test items via masks from the data module.
    if task in ("node-classification", "edge-regression"):
        common_args["train_mask"] = data_module.train_mask
        common_args["val_mask"] = data_module.val_mask
        common_args["test_mask"] = data_module.test_mask
    if task == "edge-regression":
        common_args["loss_alpha"] = cfg.model.loss_alpha
        common_args["directed_masks"] = get_masks(cfg.dataset.name, "./data")

    model_cls = models[task]

    if routing:
        # NOTE(review): assumes node types are named "0", "1", ... so that str(i)
        # indexes x_dict, and that x is (num_nodes, num_features) so shape[1] is
        # the input feature dimension — confirm against the datasets.
        routing_args = {
            "input_dims": {
                str(i): data.x_dict[str(i)].shape[1]
                for i in range(len(data.node_types))
            },
            "k": cfg.model.k,
        }
        return model_cls(**common_args, **routing_args)

    return model_cls(**common_args)


def _build_data_module(cfg: DictConfig, task: str):
    """Instantiate the data module for ``task`` and run ``prepare_data``/``setup``."""
    if task == "node-classification":
        data_module = NodeClassificationDataModule(cfg.dataset)
    elif task == "edge-regression":
        data_module = EdgeRegressionDataModule(cfg.dataset)
    elif task == "dac-classification":
        data_module = DACSDataModule(cfg)
    elif task == "dynamics-classification":
        data_module = FixedVolumeDataModule(cfg)
    else:
        raise ValueError(f"Invalid task: {task}")

    data_module.prepare_data()
    data_module.setup()
    return data_module


def _build_trainer(cfg: DictConfig, logger, callbacks) -> Trainer:
    """Create a Trainer for ``cfg.training.device``, which must be "cuda" or "cpu"."""
    device = cfg.training.device
    trainer_kwargs = dict(
        logger=logger,
        max_epochs=cfg.training.n_epochs,
        log_every_n_steps=20,
        callbacks=callbacks,
        accelerator=device,
        check_val_every_n_epoch=cfg.training.validation,
        num_sanity_val_steps=0,
    )
    if device == "cuda":
        # Train on the single GPU selected in the config.
        trainer_kwargs["devices"] = [cfg.training.gpu_idx]
    elif device != "cpu":
        # Previously an unknown device left ``trainer`` unbound.
        raise ValueError(f"Unsupported training device: {device}")
    return Trainer(**trainer_kwargs)


def train_and_evaluate(cfg: DictConfig, task: str):
    """
    Train and evaluate a model over several data splits and repeated runs.

    For each of ``cfg.dataset.splits`` splits, the data module is rebuilt and the
    model is trained ``cfg.training.iterations`` times with early stopping and
    best-checkpoint tracking. Each run records the best validation metric and the
    test metric of the best checkpoint. Per-split and overall mean/std are written
    to ``wandb.run.summary`` and printed.

    Args:
        cfg (DictConfig): The Hydra configuration.
        task (str): One of ``"node-classification"``, ``"edge-regression"``,
            ``"dac-classification"`` or ``"dynamics-classification"``.

    Returns:
        tuple: ``(final_val_mean, final_val_std, final_test_mean, final_test_std)``
        over all splits and runs.

    Raises:
        ValueError: On an unknown task or device, or when splits/iterations < 1.
        RuntimeError: If a run finishes without a monitored validation score.
    """
    seed = cfg.training.seed if "seed" in cfg.training else 42  # Default seed
    random.seed(seed)  # Python's random module
    np.random.seed(seed)  # NumPy's global RNG
    torch.manual_seed(seed)  # PyTorch CPU RNG
    torch.cuda.manual_seed_all(seed)  # PyTorch RNG on all GPUs

    if task == "edge-regression":
        monitor_val_metric = "val_loss"
        monitor_test_metric = "test_loss"
        val_monitor_mode = "min"
    elif task in ["node-classification", "dac-classification", "dynamics-classification"]:
        monitor_val_metric = "val_acc"
        monitor_test_metric = "test_acc"
        val_monitor_mode = "max"
    else:
        raise ValueError(f"Invalid task: {task}")

    splits = cfg.dataset.splits
    iterations = cfg.training.iterations
    # With zero runs, the means below would be NaN and ``model`` would be unbound.
    if splits < 1 or iterations < 1:
        raise ValueError(
            f"cfg.dataset.splits ({splits}) and cfg.training.iterations "
            f"({iterations}) must both be >= 1"
        )

    wandb_logger = WandbLogger(project=cfg.project, log_model=False)
    all_val_metrics, all_test_metrics = [], []

    for split in range(splits):
        # The data module reads the active split from the config.
        cfg.dataset.split_number = split
        data_module = _build_data_module(cfg, task)

        val_metrics, test_metrics = [], []

        for run_idx in range(iterations):
            print(f"Split {split}, Run {run_idx + 1}/{iterations}")

            model = get_model(cfg, data_module, task, routing=cfg.model.routing)

            # Summary of the freshly built model (before training);
            # max_depth=-1 recurses into all submodules.
            print("Detailed Model Summary:")
            print(ModelSummary(model, max_depth=-1))

            # Fresh callbacks per run so state does not leak between runs.
            early_stopping = EarlyStopping(
                monitor=monitor_val_metric,
                mode=val_monitor_mode,
                patience=cfg.training.patience,
            )
            # NOTE(review): all runs/splits write into the same ./checkpoints dir;
            # older runs' files are not removed by this callback.
            model_checkpoint = ModelCheckpoint(
                monitor=monitor_val_metric,
                mode=val_monitor_mode,
                save_top_k=1,
                dirpath="./checkpoints",
                filename=f"best-checkpoint-{{epoch:02d}}-{{{monitor_val_metric}:.4f}}",
            )

            trainer = _build_trainer(
                cfg, wandb_logger, [early_stopping, model_checkpoint]
            )
            trainer.fit(model=model, datamodule=data_module)

            if model_checkpoint.best_model_score is None:
                raise RuntimeError(
                    f"No '{monitor_val_metric}' score was recorded during training; "
                    "check cfg.training.validation against cfg.training.n_epochs."
                )
            best_val_metric = model_checkpoint.best_model_score.item()

            # Evaluate the best checkpoint tracked by the callback above.
            test_metrics_dict = trainer.test(
                ckpt_path="best", dataloaders=data_module.test_dataloader()
            )[0]

            val_metrics.append(best_val_metric)
            test_metrics.append(test_metrics_dict[monitor_test_metric])

        # Per-split aggregates over all runs of this split.
        mean_val, std_val = np.mean(val_metrics), np.std(val_metrics)
        mean_test, std_test = np.mean(test_metrics), np.std(test_metrics)

        wandb.run.summary[f"split_{split}_{monitor_val_metric}_mean"] = mean_val
        wandb.run.summary[f"split_{split}_{monitor_val_metric}_std"] = std_val
        wandb.run.summary[f"split_{split}_{monitor_test_metric}_mean"] = mean_test
        wandb.run.summary[f"split_{split}_{monitor_test_metric}_std"] = std_test

        print(
            f"Split {split} - Validation {monitor_val_metric.upper()}: {mean_val:.4f} ± {std_val:.4f}"
        )
        print(
            f"Split {split} - Test {monitor_test_metric.upper()}: {mean_test:.4f} ± {std_test:.4f}"
        )

        all_val_metrics.extend(val_metrics)
        all_test_metrics.extend(test_metrics)

    # Overall aggregates across all splits and runs.
    final_val_mean, final_val_std = np.mean(all_val_metrics), np.std(all_val_metrics)
    final_test_mean, final_test_std = np.mean(all_test_metrics), np.std(all_test_metrics)

    wandb.run.summary[f"{monitor_val_metric}_mean"] = final_val_mean
    wandb.run.summary[f"{monitor_val_metric}_std"] = final_val_std
    wandb.run.summary[f"{monitor_test_metric}_mean"] = final_test_mean
    wandb.run.summary[f"{monitor_test_metric}_std"] = final_test_std

    # Parameter count of the model from the last run (all runs share the config).
    num_params = ModelSummary(model, max_depth=-1).total_parameters
    wandb.run.summary["num_parameters"] = num_params

    print(f"Total parameters logged to wandb: {num_params}")

    print(
        f"Final Validation {monitor_val_metric.upper()}: {final_val_mean:.4f} ± {final_val_std:.4f}"
    )
    print(
        f"Final Test {monitor_test_metric.upper()}: {final_test_mean:.4f} ± {final_test_std:.4f}"
    )

    return final_val_mean, final_val_std, final_test_mean, final_test_std
