import datetime
import logging
import os
import random
import sys

import numpy as np
import torch
from torch_geometric import seed_everything
from torch_geometric.graphgym.cmd_args import parse_args as original_parse_args
from torch_geometric.graphgym.config import (
    cfg,
    dump_cfg,
    load_cfg,
    makedirs_rm_exist,
    set_cfg,
)
from torch_geometric.graphgym.loader import create_loader
from torch_geometric.graphgym.logger import set_printing
from torch_geometric.graphgym.model_builder import create_model
from torch_geometric.graphgym.optim import (
    OptimizerConfig,
    create_optimizer,
    create_scheduler,
)
from torch_geometric.graphgym.register import train_dict
from torch_geometric.graphgym.train import train
from torch_geometric.graphgym.utils.agg_runs import agg_runs
from torch_geometric.graphgym.utils.comp_budget import params_count
from torch_geometric.graphgym.utils.device import auto_select_device
from torch_geometric.loader import DataLoader
from tqdm import tqdm

from grit.finetuning import (
    init_model_from_pretrained,
    load_pretrained_model_cfg,
)
from grit.logger import create_logger
from grit.optimizer.extra_optimizers import ExtendedSchedulerConfig


def parse_args():
    """Parse the GraphGym CLI and expand ``slt.linear_sparsity`` into a list.

    The stock GraphGym parser runs first.  ``sys.argv`` is then scanned for
    the first token that contains ``slt.linear_sparsity``; the token right
    after it is split on commas and each piece is converted to ``float`` (a
    value without commas therefore becomes a one-element list).  The
    resulting list replaces the entry following ``slt.linear_sparsity`` in
    ``args.opts``, or the key/value pair is appended when ``args.opts`` has
    no such key.

    Returns:
        The parsed argument namespace, with ``opts`` adjusted as described.
    """
    key = "slt.linear_sparsity"
    args = original_parse_args()

    argv_pos = next(
        (pos for pos, token in enumerate(sys.argv) if key in token), None
    )
    if argv_pos is None:
        return args

    # ``split`` returns the whole string as a single piece when there is no
    # comma, so one expression covers both the single- and multi-value case.
    sparsities = [float(piece) for piece in sys.argv[argv_pos + 1].split(",")]

    opts_pos = next(
        (pos for pos, opt in enumerate(args.opts) if opt == key), None
    )
    if opts_pos is None:
        args.opts.extend([key, sparsities])
    else:
        args.opts[opts_pos + 1] = sparsities
    return args


def new_optimizer_config(cfg):
    """Build an ``OptimizerConfig`` from the ``optim`` section of *cfg*.

    Args:
        cfg: Configuration object exposing ``optim.optimizer``,
            ``optim.base_lr``, ``optim.weight_decay`` and ``optim.momentum``.

    Returns:
        An ``OptimizerConfig`` populated with those four values.
    """
    optim = cfg.optim
    settings = {
        "optimizer": optim.optimizer,
        "base_lr": optim.base_lr,
        "weight_decay": optim.weight_decay,
        "momentum": optim.momentum,
    }
    return OptimizerConfig(**settings)


def new_scheduler_config(cfg):
    """Build an ``ExtendedSchedulerConfig`` from *cfg*.

    Most fields are copied from ``cfg.optim``; ``train_mode`` and
    ``eval_period`` come from ``cfg.train``.

    Args:
        cfg: Configuration object exposing the ``optim`` and ``train``
            sections read below.

    Returns:
        An ``ExtendedSchedulerConfig`` populated from the configuration.
    """
    optim = cfg.optim
    training = cfg.train
    settings = {
        "scheduler": optim.scheduler,
        "steps": optim.steps,
        "lr_decay": optim.lr_decay,
        "max_epoch": optim.max_epoch,
        "reduce_factor": optim.reduce_factor,
        "schedule_patience": optim.schedule_patience,
        "min_lr": optim.min_lr,
        "num_warmup_epochs": optim.num_warmup_epochs,
        "train_mode": training.mode,
        "eval_period": training.eval_period,
        "num_cycles": optim.num_cycles,
        "min_lr_mode": optim.min_lr_mode,
    }
    return ExtendedSchedulerConfig(**settings)


def custom_set_out_dir(cfg, cfg_fname, name_tag):
    """Point ``cfg.out_dir`` at a sub-directory named after the config file.

    The sub-directory name is the base name of *cfg_fname* without its last
    extension, followed by ``-<name_tag>`` when *name_tag* is truthy.

    Args:
        cfg: Configuration object; its ``out_dir`` attribute is read and
            then overwritten with the extended path.
        cfg_fname: Path of the configuration file.
        name_tag: Optional suffix appended to the directory name.
    """
    stem, _extension = os.path.splitext(os.path.basename(cfg_fname))
    if name_tag:
        stem = f"{stem}-{name_tag}"
    cfg.out_dir = os.path.join(cfg.out_dir, stem)


def custom_set_run_dir(cfg, run_id):
    """Set ``cfg.run_dir`` to ``<cfg.out_dir>/<run_id>`` and create it.

    With ``cfg.train.auto_resume`` enabled the directory is created only if
    missing, so existing contents are kept.  Otherwise the directory is
    handed to ``makedirs_rm_exist``.

    Args:
        cfg: Configuration object providing ``out_dir`` and
            ``train.auto_resume``; ``run_dir`` is written.
        run_id: Identifier of the run; converted with ``str`` to form the
            directory name.
    """
    cfg.run_dir = os.path.join(cfg.out_dir, str(run_id))
    if not cfg.train.auto_resume:
        makedirs_rm_exist(cfg.run_dir)
        return
    os.makedirs(cfg.run_dir, exist_ok=True)


def run_loop_settings():
    """Return the per-run IDs, seeds and split indices for the main loop.

    Reads the module-level ``cfg`` and ``args`` (the latter is assigned in
    the ``__main__`` section of this script).

    * When ``cfg.run_multiple_splits`` is non-empty, one run is scheduled
      per listed split, all with ``cfg.seed``; the split indices double as
      run IDs.  ``args.repeat`` must be 1 in this mode.
    * Otherwise ``args.repeat`` runs are scheduled on
      ``cfg.dataset.split_index`` with seeds ``cfg.seed``, ``cfg.seed + 1``,
      ...; the seeds double as run IDs.

    Returns:
        A ``(run_ids, seeds, split_indices)`` tuple of equally long lists.

    Raises:
        NotImplementedError: If multiple splits are combined with
            ``args.repeat != 1``.
    """
    splits = cfg.run_multiple_splits
    if len(splits) > 0:
        if args.repeat != 1:
            raise NotImplementedError(
                "Running multiple repeats of multiple "
                "splits in one run is not supported."
            )
        same_seed = [cfg.seed] * len(splits)
        return splits, same_seed, splits

    repeats = args.repeat
    seeds = [cfg.seed + offset for offset in range(repeats)]
    same_split = [cfg.dataset.split_index] * repeats
    return seeds, seeds, same_split


def perturb_node_features(
    data, perturbation_rate, noise_level=0.1, perturbation_type="bernoulli_0.5"
):
    """Overwrite the features of a random subset of nodes with noise.

    ``int(num_nodes * perturbation_rate)`` distinct rows of ``data.x`` are
    chosen with :func:`random.sample` and *replaced* (not added to) by noise.
    ``data.x`` is modified in place.

    Args:
        data: Object with a 2-D tensor attribute ``x`` of shape
            ``(num_nodes, num_features)``.
        perturbation_rate: Fraction of nodes to overwrite.  Values outside
            ``[0, 1]`` make :func:`random.sample` raise ``ValueError``.
        noise_level: Scale applied to the standard-normal samples; only used
            when ``perturbation_type == "gaussian"``.
        perturbation_type: ``"bernoulli_0.5"`` (each entry is 0 or 1 with
            probability 0.5) or ``"gaussian"``.

    Returns:
        The same ``data`` object, with the selected rows of ``x`` replaced.

    Raises:
        ValueError: If ``perturbation_type`` is not one of the supported
            values.
    """
    num_nodes, num_features = data.x.size(0), data.x.size(1)
    num_perturb_nodes = int(num_nodes * perturbation_rate)
    perturb_indices = random.sample(range(num_nodes), num_perturb_nodes)

    # The noise is always drawn on the CPU in the default float dtype so the
    # random stream does not depend on where ``data.x`` lives.
    if perturbation_type == "bernoulli_0.5":
        noise = torch.zeros((num_perturb_nodes, num_features)).bernoulli(0.5)
    elif perturbation_type == "gaussian":
        noise = torch.randn((num_perturb_nodes, num_features)) * noise_level
    else:
        raise ValueError(f"Unsupported perturbation type: {perturbation_type}")

    # Indexed assignment needs matching dtype and device.  Round before
    # casting to a non-floating dtype so values are not simply truncated.
    target_dtype = data.x.dtype
    if not (target_dtype.is_floating_point or target_dtype.is_complex):
        noise = noise.round()
    noise = noise.to(device=data.x.device, dtype=target_dtype)

    rows = torch.tensor(
        perturb_indices, dtype=torch.long, device=data.x.device
    )
    data.x[rows] = noise
    return data


def perturb_node(
    data_loader,
    perturbation_rate,
    noise_level=0.1,
    perturbation_type="bernoulli_0.5",
):
    """Return a new loader whose graphs have perturbed node features.

    Every batch of *data_loader* is split back into individual graphs, each
    graph is passed through ``perturb_node_features`` and the results are
    wrapped in a fresh, non-shuffling ``DataLoader`` with the same batch
    size.

    Args:
        data_loader: Loader yielding batches that support ``to_data_list()``.
        perturbation_rate: Fraction of nodes per graph to overwrite.
        noise_level: Scale of the Gaussian noise; it only has an effect when
            ``perturbation_type`` is ``"gaussian"``.
        perturbation_type: Noise model forwarded to
            ``perturb_node_features``.  Defaults to ``"bernoulli_0.5"``,
            which is what this function always used before the parameter
            was exposed.

    Returns:
        A ``DataLoader`` over the perturbed graphs.
    """
    perturbed_graphs = [
        perturb_node_features(
            data, perturbation_rate, noise_level, perturbation_type
        )
        for batch in tqdm(data_loader, desc="Perturbing nodes")
        for data in batch.to_data_list()
    ]
    return DataLoader(
        perturbed_graphs,
        batch_size=data_loader.batch_size,
        shuffle=False,
    )


def perturb_edge(data_loader, edge_perturbation_rate):
    """Return a new loader whose graphs have some edge targets rewired.

    For every graph, ``int(num_edges * edge_perturbation_rate)`` distinct
    edges are drawn with ``np.random.choice``.  The target node (row 1 of
    ``edge_index``) of each drawn edge is replaced by a node chosen
    uniformly among all nodes other than the current target.  Only
    ``edge_index`` is modified, and it is modified in place.

    Graphs with fewer than two nodes are passed through unchanged because
    no alternative target exists for them.

    Args:
        data_loader: Loader yielding batches that support ``to_data_list()``.
        edge_perturbation_rate: Fraction of edges per graph to rewire.

    Returns:
        A non-shuffling ``DataLoader`` over the perturbed graphs with the
        same batch size as *data_loader*.
    """
    perturbed_graphs = []
    for batch in tqdm(data_loader, desc="Perturbing edges"):
        for data in batch.to_data_list():
            num_edges = data.edge_index.size(1)
            num_perturb_edges = int(num_edges * edge_perturbation_rate)
            # Always draw, even when nothing ends up rewired, so the NumPy
            # random stream advances the same way for every graph.
            edge_indices = np.random.choice(
                num_edges, num_perturb_edges, replace=False
            )
            num_nodes = data.num_nodes
            if num_nodes is not None and num_nodes > 1:
                all_nodes = np.arange(num_nodes)
                for edge_idx in edge_indices.tolist():
                    # ``int(...)`` pulls the scalar out of the tensor so the
                    # NumPy code below never sees a (possibly non-CPU) tensor.
                    current_target = int(data.edge_index[1, edge_idx])
                    candidates = all_nodes[all_nodes != current_target]
                    data.edge_index[1, edge_idx] = int(
                        np.random.choice(candidates)
                    )
            perturbed_graphs.append(data)
    return DataLoader(
        perturbed_graphs,
        batch_size=data_loader.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=8,
    )


def delete_random_graphs(loader, delete_percentage):
    """Return a new loader with a random fraction of the graphs removed.

    ``int(len(dataset) * delete_percentage)`` distinct dataset positions are
    drawn with ``np.random.choice`` and dropped; the remaining graphs keep
    their original order.

    Args:
        loader: Loader whose ``dataset`` supports ``len()`` and integer
            indexing.
        delete_percentage: Fraction of graphs to drop.  A value above 1
            makes ``np.random.choice`` raise ``ValueError``.

    Returns:
        A non-shuffling ``DataLoader`` over the remaining graphs with the
        same batch size as *loader*.
    """
    dataset = loader.dataset
    num_graphs = len(dataset)
    num_delete = int(num_graphs * delete_percentage)
    delete_indices = np.random.choice(num_graphs, num_delete, replace=False)
    # A set gives O(1) membership tests; ``in`` on the NumPy array would
    # rescan the whole array for every graph.
    deleted = set(delete_indices.tolist())
    remaining_graphs = [
        dataset[i]
        for i in tqdm(range(num_graphs), desc="Deleting graphs")
        if i not in deleted
    ]
    return DataLoader(
        remaining_graphs,
        batch_size=loader.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=8,
    )


def filter_loader(loader, ood_classes=None):
    """Return a new loader holding only graphs whose label is in *ood_classes*.

    Despite its name, *ood_classes* is the collection of labels to **keep**:
    a graph is retained when ``data.y in ood_classes`` is true.

    Args:
        loader: Loader yielding batches that support ``to_data_list()``.
        ood_classes: Container of labels to keep.  The default ``None`` is
            not a usable value: the membership test raises ``TypeError`` as
            soon as the first graph is examined.

    Returns:
        A non-shuffling ``DataLoader`` over the retained graphs with the
        same batch size as *loader*.
    """
    kept_graphs = [
        graph
        for batch in tqdm(loader, desc="OOD")
        for graph in batch.to_data_list()
        if graph.y in ood_classes
    ]
    loader_options = {
        "batch_size": loader.batch_size,
        "shuffle": False,
        "pin_memory": True,
        "num_workers": 8,
    }
    return DataLoader(kept_graphs, **loader_options)


def remove_ood_classes(loaders, num_ood_classes):
    """Hold out random classes from ``loaders[0]`` and keep only them elsewhere.

    The distinct labels found in ``loaders[0]`` are collected and
    *num_ood_classes* of them are drawn without replacement via
    ``np.random.choice``.  ``loaders[0]`` is then filtered down to graphs of
    the remaining classes, while ``loaders[1]`` and ``loaders[2]`` are
    filtered down to graphs of the drawn classes.  The list *loaders* is
    updated in place.

    Args:
        loaders: List of at least three loaders.
        num_ood_classes: Number of classes to hold out.

    Returns:
        A ``(loaders, ood_classes)`` tuple, where ``ood_classes`` is a tensor
        of the held-out labels.
    """
    label_chunks = [batch.y for batch in loaders[0]]
    all_classes = torch.cat(label_chunks, dim=0).unique()

    held_out = np.random.choice(
        all_classes.cpu().numpy(), num_ood_classes, replace=False
    )
    held_out_lookup = set(held_out)
    in_distribution = torch.tensor(
        [label for label in all_classes if label.item() not in held_out_lookup]
    )
    ood_classes = torch.tensor(held_out)

    loaders[0] = filter_loader(loaders[0], ood_classes=in_distribution)
    for position in (1, 2):
        loaders[position] = filter_loader(
            loaders[position], ood_classes=ood_classes
        )
    return loaders, ood_classes


if __name__ == "__main__":
    # ---- One-time setup shared by all runs --------------------------------
    # Parse the command line, reset the global GraphGym ``cfg`` and load the
    # user configuration into it.
    args = parse_args()
    set_cfg(cfg)
    # Permit configuration keys that ``set_cfg`` did not register.
    cfg.set_new_allowed(True)
    cfg.work_dir = os.getcwd()
    load_cfg(cfg, args)
    cfg.cfg_file = args.cfg_file
    # Output goes to <out_dir>/<config file stem>[-<name_tag>].
    custom_set_out_dir(cfg, args.cfg_file, cfg.name_tag)
    dump_cfg(cfg)
    torch.set_num_threads(cfg.num_threads)
    # ---- One iteration per (run_id, seed, split_index) --------------------
    for run_id, seed, split_index in zip(*run_loop_settings()):
        # Per-run directory, logging setup and seeding.
        custom_set_run_dir(cfg, run_id)
        set_printing()
        cfg.dataset.split_index = split_index
        cfg.seed = seed
        cfg.run_id = run_id
        seed_everything(cfg.seed)
        # Device: either let GraphGym choose or take it from cfg.accelerator.
        if cfg.get("auto_select_device", False):
            auto_select_device()
        else:
            cfg.device = cfg.accelerator
        # When fine-tuning, the module-level ``cfg`` is rebound to whatever
        # the helper returns.
        if cfg.pretrained.dir:
            cfg = load_pretrained_model_cfg(cfg)
        logging.info(
            f"[*] Run ID {run_id}: seed={cfg.seed}, "
            f"split_index={cfg.dataset.split_index}"
        )
        logging.info(f"    Starting now: {datetime.datetime.now()}")
        loaders = create_loader()
        # Optional data modifications.  loaders[0] is treated as the training
        # loader (see ``train_data_delete``); loaders[1] and loaders[2] are
        # presumably validation and test -- TODO confirm against
        # ``create_loader``.
        if cfg.slt.node_perturbation > 0.0:
            loaders[1] = perturb_node(loaders[1], cfg.slt.node_perturbation)
            loaders[2] = perturb_node(loaders[2], cfg.slt.node_perturbation)
        if cfg.slt.edge_perturbation > 0.0:
            loaders[1] = perturb_edge(loaders[1], cfg.slt.edge_perturbation)
            loaders[2] = perturb_edge(loaders[2], cfg.slt.edge_perturbation)
        if cfg.slt.train_data_delete > 0.0:
            loaders[0] = delete_random_graphs(
                loaders[0], cfg.slt.train_data_delete
            )
        # Hold out random classes: loaders[0] keeps only the other classes,
        # loaders[1] and loaders[2] keep only the held-out ones.
        if cfg.slt.num_ood_classes > 0.0:
            loaders, ood_classes = remove_ood_classes(
                loaders, cfg.slt.num_ood_classes
            )
        else:
            ood_classes = None

        # Build loggers and the model, optionally initialised from a
        # pretrained checkpoint directory.
        loggers = create_logger()
        model = create_model()
        if cfg.pretrained.dir:
            model = init_model_from_pretrained(
                model,
                cfg.pretrained.dir,
                cfg.pretrained.freeze_main,
                cfg.pretrained.reset_prediction_head,
            )
        for name, p in model.named_parameters():
            logging.info(
                f"{name}, {p.shape}, requires_grad, {p.requires_grad}"
            )
        optimizer = create_optimizer(
            model.parameters(), new_optimizer_config(cfg)
        )
        scheduler = create_scheduler(optimizer, new_scheduler_config(cfg))
        logging.info(model)
        logging.info(cfg)

        # NOTE(review): the parameters were already logged above through
        # ``logging``; this second pass writes them to stdout with ``print``.
        for name, param in model.named_parameters():
            print(
                f"Name: {name}, Size: {param.size()}, Number of elements: {param.numel()}"
            )
        cfg.params = params_count(model)
        logging.info("Num parameters: %s", cfg.params)
        # Training: "standard" uses GraphGym's ``train``; any other mode is
        # looked up in ``train_dict`` and additionally receives
        # ``ood_classes`` (``None`` when no classes were held out).
        if cfg.train.mode == "standard":
            if cfg.wandb.use:
                logging.warning(
                    "[W] WandB logging is not supported with the "
                    "default train.mode, set it to `custom`"
                )
            if cfg.mlflow.use:
                logging.warning(
                    "[ML] MLflow logging is not supported with the "
                    "default train.mode, set it to `custom`"
                )
            train(loggers, loaders, model, optimizer, scheduler)
        else:
            train_dict[cfg.train.mode](
                loggers, loaders, model, optimizer, scheduler, ood_classes
            )
    # ---- After all runs ---------------------------------------------------
    # Aggregation is best-effort: a failure is logged and does not abort the
    # script.
    try:
        agg_runs(cfg.out_dir, cfg.metric_best)
    except Exception as e:
        logging.info(f"Failed when trying to aggregate multiple runs: {e}")
    # Optionally rename the config file by appending "_done".
    if args.mark_done:
        os.rename(args.cfg_file, f"{args.cfg_file}_done")
    logging.info(f"[*] All done: {datetime.datetime.now()}")
