import datetime
import logging
import os
import sys

import torch
from torch_geometric import seed_everything
from torch_geometric.graphgym.cmd_args import parse_args as original_parse_args
from torch_geometric.graphgym.config import (
    cfg,
    dump_cfg,
    load_cfg,
    makedirs_rm_exist,
    set_cfg,
)
from torch_geometric.graphgym.loader import create_loader
from torch_geometric.graphgym.logger import set_printing
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.optim import (
    OptimizerConfig,
    create_optimizer,
    create_scheduler,
)
from torch_geometric.graphgym.register import train_dict
from torch_geometric.graphgym.train import GraphGymDataModule, train
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.utils.device import auto_select_device

from graphgps.agg_runs import agg_runs
from graphgps.finetuning import (
    init_model_from_pretrained,
    load_pretrained_model_cfg,
)
from graphgps.layer.gps_layer import GPSLayer
from graphgps.logger import create_logger
from graphgps.ood_perturb import (
    delete_random_graphs,
    perturb_edge,
    perturb_node,
    remove_ood_classes,
)
from graphgps.optimizer.extra_optimizers import ExtendedSchedulerConfig
from graphgps.tome.utils import parse_r

# Allow TensorFloat-32 (TF32) for CUDA matrix multiplications and for cuDNN
# operations. These are process-wide PyTorch backend flags set at import time.
torch.backends.cuda.matmul.allow_tf32 = True  # Default False in PyTorch 1.12+
torch.backends.cudnn.allow_tf32 = True  # Default True


def parse_args():
    """Parse command line arguments with list support for sparsity values.

    Wraps GraphGym's ``parse_args`` and post-processes the
    ``slt.linear_sparsity`` option: its raw command line value (a single
    number or a comma-separated list such as ``0.5,0.7``) is converted to a
    list of floats and stored in ``args.opts`` in place of the raw string.

    Returns:
        The namespace returned by GraphGym's ``parse_args``. When
        ``slt.linear_sparsity`` is present on the command line, the value
        following it in ``args.opts`` is a list of floats.

    Raises:
        ValueError: If ``slt.linear_sparsity`` is the last command line token
            (no value follows it), or if one of its values cannot be parsed
            as a float.
    """
    args = original_parse_args()

    key = "slt.linear_sparsity"
    # Match the option name exactly so that unrelated tokens that merely
    # contain it (e.g. a config file path) are not mistaken for the option.
    if key not in sys.argv:
        return args

    value_idx = sys.argv.index(key) + 1
    if value_idx >= len(sys.argv):
        raise ValueError(
            f"Option '{key}' requires a value, e.g. '{key} 0.5,0.7'."
        )

    # `split` yields a one-element list when the value contains no comma, so
    # single values and comma-separated lists share the same code path.
    values = [float(part) for part in sys.argv[value_idx].split(",")]

    # Replace the raw string in args.opts, or add the key/value pair if the
    # key is not there yet.
    opts = args.opts
    if key in opts:
        target_idx = opts.index(key) + 1
        if target_idx < len(opts):
            opts[target_idx] = values
        else:
            opts.append(values)
    else:
        opts.extend([key, values])

    return args


def new_optimizer_config(cfg):
    """Build an :obj:`OptimizerConfig` from the ``cfg.optim`` section.

    Args:
        cfg (CfgNode): Configuration node providing ``optim.optimizer``,
            ``optim.base_lr``, ``optim.weight_decay`` and ``optim.momentum``

    Returns:
        OptimizerConfig populated with those four settings
    """
    optim = cfg.optim
    field_names = ("optimizer", "base_lr", "weight_decay", "momentum")
    settings = {name: getattr(optim, name) for name in field_names}
    return OptimizerConfig(**settings)


def new_scheduler_config(cfg):
    """Build an :obj:`ExtendedSchedulerConfig` from the current configuration.

    Args:
        cfg (CfgNode): Configuration node; the scheduler settings are read
            from ``cfg.optim`` while ``train_mode`` and ``eval_period`` are
            read from ``cfg.train``

    Returns:
        ExtendedSchedulerConfig populated with those settings
    """
    optim = cfg.optim
    # These fields carry the same name in cfg.optim and in the scheduler config
    optim_fields = (
        "scheduler",
        "steps",
        "lr_decay",
        "max_epoch",
        "reduce_factor",
        "schedule_patience",
        "min_lr",
        "num_warmup_epochs",
    )
    settings = {name: getattr(optim, name) for name in optim_fields}
    settings["train_mode"] = cfg.train.mode
    settings["eval_period"] = cfg.train.eval_period
    return ExtendedSchedulerConfig(**settings)


def custom_set_out_dir(cfg, cfg_fname, name_tag):
    """Point :obj:`cfg.out_dir` at a sub-directory named after this execution.

    The sub-directory name is the config filename without its directory and
    extension, followed by ``-<name_tag>`` when a non-empty tag is given.

    Args:
        cfg (CfgNode): Configuration node; its ``out_dir`` is updated in place
        cfg_fname (string): Filename for the yaml format configuration file
        name_tag (string): Additional name tag to identify this execution of
            the configuration file, specified in :obj:`cfg.name_tag`
    """
    stem, _extension = os.path.splitext(os.path.basename(cfg_fname))
    suffix = f"-{name_tag}" if name_tag else ""
    cfg.out_dir = os.path.join(cfg.out_dir, stem + suffix)


def custom_set_run_dir(cfg, run_id):
    """Set :obj:`cfg.run_dir` for one experiment run and create the directory.

    The run directory is ``<cfg.out_dir>/<run_id>``. With
    ``cfg.train.auto_resume`` enabled an already existing directory is kept;
    otherwise the directory is created through GraphGym's
    ``makedirs_rm_exist`` helper.

    Args:
        cfg (CfgNode): Configuration node
        run_id (int): Main for-loop iter id (the random seed or dataset split)
    """
    run_dir = os.path.join(cfg.out_dir, str(run_id))
    cfg.run_dir = run_dir
    if not cfg.train.auto_resume:
        makedirs_rm_exist(run_dir)
        return
    # Resuming: keep whatever is already stored in the run directory
    os.makedirs(run_dir, exist_ok=True)


def percentile(t, q):
    """Return the ``q``-th percentile over all elements of a tensor.

    The tensor is flattened and the value at rank
    ``1 + round(q / 100 * (numel - 1))`` (1-based, ascending order) is
    selected with ``kthvalue``; no interpolation between elements is done.

    Args:
        t (torch.Tensor): Input tensor of any shape; it may be non-contiguous
        q (float): Percentile to compute, where 0 selects the smallest and
            100 the largest element

    Returns:
        0-dimensional tensor holding the selected element
    """
    # `reshape` (unlike `view`) also accepts non-contiguous inputs such as
    # transposed or sliced tensors, copying only when it has to.
    flat = t.reshape(-1)
    k = 1 + round(0.01 * float(q) * (flat.numel() - 1))
    return flat.kthvalue(k).values


def run_loop_settings():
    """Derive the per-iteration settings of the main loop from cfg and args.

    Two execution modes are supported:
    1. 'multi-seed' - used when :obj:`cfg.run_multiple_splits` is empty. The
        experiment is repeated ``args.repeat`` times, each time with the seed
        incremented by one starting from :obj:`cfg.seed`, always on the
        dataset split :obj:`cfg.dataset.split_index`. The seeds double as
        run IDs.
    2. 'multi-split' - used otherwise. The experiment is run once per entry
        of :obj:`cfg.run_multiple_splits`, always with :obj:`cfg.seed`. The
        split indices double as run IDs.

    Returns:
        List of run IDs for each loop iteration
        List of rng seeds to loop over
        List of dataset split indices to loop over

    Raises:
        NotImplementedError: In 'multi-split' mode when ``args.repeat`` is
            not 1
    """
    multiple_splits = cfg.run_multiple_splits

    if len(multiple_splits) > 0:
        # 'multi-split' run mode: same seed for every split
        if args.repeat != 1:
            raise NotImplementedError(
                "Running multiple repeats of multiple "
                "splits in one run is not supported."
            )
        fixed_seeds = [cfg.seed] * len(multiple_splits)
        return multiple_splits, fixed_seeds, multiple_splits

    # 'multi-seed' run mode: same split for every seed
    repeats = args.repeat
    seeds = [cfg.seed + offset for offset in range(repeats)]
    fixed_splits = [cfg.dataset.split_index] * repeats
    return seeds, seeds, fixed_splits


def make_tome_class(transformer_class):
    """Create a subclass of a model class that resets ToMe state per forward.

    Args:
        transformer_class (type): Model class to derive from

    Returns:
        The :obj:`ToMeGraphGPS` subclass of :obj:`transformer_class`
    """

    class ToMeGraphGPS(transformer_class):
        """
        Modifications:
        - Initialize r, token size, and token sources.
        """

        def forward(self, *args, **kwargs) -> torch.Tensor:
            # Refresh the shared info dict before delegating to the parent:
            # "r" is recomputed from the config, "size" and "source" cleared.
            self._tome_info.update(
                {
                    "r": parse_r(cfg.gt.layers, cfg.slt.tome_r),
                    "size": None,
                    "source": None,
                }
            )
            return super().forward(*args, **kwargs)

    return ToMeGraphGPS


if __name__ == "__main__":
    # Script entry point. For every run (one per seed or per dataset split)
    # this builds the data loaders, loggers and model, then either hands a
    # pretrained model to the registered train function or trains a fresh
    # model. After all runs, the per-run results are aggregated.

    # Load cmd line args
    args = parse_args()
    # Load config file
    set_cfg(cfg)
    load_cfg(cfg, args)
    # Extend cfg.out_dir with the config file name and the optional name tag
    custom_set_out_dir(cfg, args.cfg_file, cfg.name_tag)
    dump_cfg(cfg)
    # Set Pytorch environment
    torch.set_num_threads(cfg.num_threads)
    # Repeat for multiple experiment runs
    for run_id, seed, split_index in zip(*run_loop_settings()):
        # Set configurations for each run
        custom_set_run_dir(cfg, run_id)
        set_printing()
        cfg.dataset.split_index = split_index
        cfg.seed = seed
        cfg.run_id = run_id
        seed_everything(cfg.seed)
        auto_select_device()
        if cfg.pretrained.dir:
            # Replace the global cfg by the one returned for the pretrained
            # model, then apply the command line args on top of it again.
            cfg = load_pretrained_model_cfg(cfg)
            load_cfg(cfg, args)
        logging.info(
            f"[*] Run ID {run_id}: seed={cfg.seed}, "
            f"split_index={cfg.dataset.split_index}"
        )
        logging.info(f"    Starting now: {datetime.datetime.now()}")
        loaders = create_loader()
        # Optional data perturbations, each enabled by a positive cfg.slt
        # value. Node/edge perturbations are applied to loaders[1] and
        # loaders[2] only (presumably the validation and test loaders --
        # TODO confirm against create_loader).
        if cfg.slt.node_perturbation > 0.0:
            loaders[1] = perturb_node(loaders[1], cfg.slt.node_perturbation)
            loaders[2] = perturb_node(loaders[2], cfg.slt.node_perturbation)

        if cfg.slt.edge_perturbation > 0.0:
            loaders[1] = perturb_edge(loaders[1], cfg.slt.edge_perturbation)
            loaders[2] = perturb_edge(loaders[2], cfg.slt.edge_perturbation)

        # Graph deletion only touches loaders[0] (presumably the training
        # loader, going by the option name -- TODO confirm).
        if cfg.slt.train_data_delete > 0.0:
            loaders[0] = delete_random_graphs(
                loaders[0], cfg.slt.train_data_delete
            )

        # remove_ood_classes returns the new loaders together with the OOD
        # classes; ood_classes stays None when the option is disabled.
        if cfg.slt.remove_ood_classes > 0.0:
            loaders, ood_classes = remove_ood_classes(
                loaders, cfg.slt.remove_ood_classes
            )
        else:
            ood_classes = None

        loggers = create_logger()
        model = create_model()
        if cfg.pretrained.dir:
            # Pretrained path: initialise the model from cfg.pretrained.dir
            # and call the train function without optimizer and scheduler.
            model = init_model_from_pretrained(
                model,
                cfg.pretrained.dir,
                cfg.pretrained.freeze_main,
                cfg.pretrained.reset_prediction_head,
                seed=cfg.seed,
            )
            if cfg.slt.tome:
                # Patch the model instance in place: swap its class for the
                # subclass built by make_tome_class, whose forward() resets
                # the "r", "size" and "source" entries of _tome_info.
                ToMeGraphGPS = make_tome_class(model.__class__)
                model.__class__ = ToMeGraphGPS
                model.r = 0
                model._tome_info = {
                    "r": model.r,
                    "size": None,
                    "source": None,
                    "trace_source": False,
                    "prop_attn": False,
                    # "class_token": model.cls_token is not None,
                    # "distill_token": False,
                }
                # "distill_token" is only added when the model exposes a
                # non-None dist_token attribute.
                if (
                    hasattr(model, "dist_token")
                    and model.dist_token is not None
                ):
                    model._tome_info["distill_token"] = True

                # Every GPSLayer gets a reference to the same dict object,
                # so updates made in the model's forward() are visible to
                # all layers.
                for module in model.modules():
                    if isinstance(module, GPSLayer):
                        # module.__class__ = ToMeCustomMultiheadAttention
                        module._tome_info = model._tome_info

            logging.info(model)
            logging.info(cfg)
            cfg.params = params_count(model)
            logging.info("Num parameters: %s", cfg.params)
            # NOTE(review): unlike the call in the non-pretrained branch
            # below, ood_classes is not passed here -- confirm this is
            # intended for the train modes used with pretrained models.
            train_dict[cfg.train.mode](loggers, loaders, model, None, None)
        else:
            # Log name, shape and requires_grad of every model parameter
            for name, p in model.named_parameters():
                logging.info(
                    f"{name}, {p.shape}, requires_grad, {p.requires_grad}"
                )
            optimizer = create_optimizer(
                model.parameters(), new_optimizer_config(cfg)
            )
            scheduler = create_scheduler(optimizer, new_scheduler_config(cfg))
            # Print model info
            logging.info(model)
            logging.info(cfg)
            cfg.params = params_count(model)
            logging.info("Num parameters: %s", cfg.params)

            # Start training
            if cfg.train.mode == "standard":
                # GraphGym's default trainer; the loaders, loggers, optimizer
                # and scheduler created above are not passed to it.
                if cfg.wandb.use:
                    logging.warning(
                        "[W] WandB logging is not supported with the "
                        "default train.mode, set it to `custom`"
                    )
                datamodule = GraphGymDataModule()
                train(model, datamodule, logger=True)
            else:
                # Registered train function selected by cfg.train.mode
                train_dict[cfg.train.mode](
                    loggers, loaders, model, optimizer, scheduler, ood_classes
                )

    # Aggregate results from different seeds
    # Best effort: a failure here is only logged, not re-raised
    try:
        agg_runs(cfg.out_dir, cfg.metric_best)
    except Exception as e:
        logging.info(f"Failed when trying to aggregate multiple runs: {e}")
    # When being launched in batch mode, mark a yaml as done
    # (the config file is renamed by appending "_done" to its path)
    if args.mark_done:
        os.rename(args.cfg_file, f"{args.cfg_file}_done")
    logging.info(f"[*] All done: {datetime.datetime.now()}")
