import os
import sys
import argparse

# Suppress TensorFlow informational messages (TensorFlow may be imported by wandb or tensorboard-logger)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # 0=all, 1=info, 2=warnings, 3=errors only
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # Disable oneDNN custom operations messages


def parse_gpu_ids(spec):
    """Parse a comma-separated GPU list such as ``'0'`` or ``'0,1,2,3'``.

    Whitespace around each entry is ignored. The IDs cannot be checked against
    the real hardware here because torch has not been imported yet.

    Args:
        spec: The raw ``--gpu`` string.

    Returns:
        The GPU IDs as a list of ints, in the order given.

    Raises:
        ValueError: if an entry is empty, not an integer, or negative.
    """
    gpu_ids = [int(token.strip()) for token in spec.split(',')]
    if any(gpu_id < 0 for gpu_id in gpu_ids):
        raise ValueError("GPU IDs must be non-negative")
    return gpu_ids


# Parse --gpu argument EARLY (before importing torch) to set CUDA_VISIBLE_DEVICES
# This must happen before PyTorch initializes CUDA, otherwise CUDA_VISIBLE_DEVICES has no effect
parser = argparse.ArgumentParser(add_help=False)  # Don't show help, we'll parse full args later
parser.add_argument('--gpu', type=str, default=None)
parser.add_argument('--no_cuda', action='store_true')
early_args, _ = parser.parse_known_args()

# Set CUDA_VISIBLE_DEVICES before importing torch (critical for GPU selection to work)
if early_args.gpu is not None and not early_args.no_cuda:
    try:
        gpu_ids = parse_gpu_ids(early_args.gpu)
    except ValueError as e:
        # Diagnostics go to stderr so they are not mixed into normal output.
        print(f"Error: Invalid --gpu argument '{early_args.gpu}': {e}", file=sys.stderr)
        print("Expected format: '0' or '0,1,2,3'", file=sys.stderr)
        sys.exit(1)
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, gpu_ids))
    print(f"Setting CUDA_VISIBLE_DEVICES={os.environ['CUDA_VISIBLE_DEVICES']} (GPU(s): {gpu_ids})")

import json
import torch
import random
import pprint
from tensorboard_logger import Logger as TbLogger
import warnings
from options import get_options

from problems.problem_pdtsp import PDTSP
from problems.problem_pdtspl import PDTSPL
from agent.ppo import PPO

# Weights & Biases (wandb) experiment tracking, enabled with --wandb
import wandb

def load_agent(name):
    """Return the RL agent class registered under ``name``.

    Args:
        name: Agent key; currently only ``'ppo'`` is registered.

    Returns:
        The agent class (not an instance).

    Raises:
        ValueError: if ``name`` is not a registered agent. An explicit raise is
            used instead of ``assert`` so the check survives ``python -O``.
    """
    agent = {
        'ppo': PPO,
    }.get(name)
    if agent is None:
        raise ValueError("Currently unsupported agent: {}!".format(name))
    return agent

def load_problem(name):
    """Return the problem class registered under ``name``.

    Args:
        name: Problem key; ``'pdtsp'`` or ``'pdtspl'``.

    Returns:
        The problem class (not an instance).

    Raises:
        ValueError: if ``name`` is not a registered problem. An explicit raise
            is used instead of ``assert`` so the check survives ``python -O``.
    """
    problem = {
        'pdtsp': PDTSP,
        'pdtspl': PDTSPL,
    }.get(name)
    if problem is None:
        raise ValueError("Currently unsupported problem: {}!".format(name))
    return problem


def _init_wandb(opts):
    """Start a Weights & Biases run if ``opts.wandb`` is set; otherwise print a hint.

    The run name is ``<Eval|Train>/<problem>/<graph_size>/ori``, chosen by
    ``opts.eval_only``.
    """
    if not opts.wandb:
        # print rather than warnings.warn: the __main__ guard installs
        # warnings.filterwarnings("ignore"), which would hide a warning.
        print("WandB is not enabled, please enable it by adding the --wandb flag")
        return
    mode = "Eval" if opts.eval_only else "Train"
    wandb.init(
        project="HADES-N2S",
        name=f"{mode}/{opts.problem}/{opts.graph_size}/ori",
    )


def _select_device(opts):
    """Return the torch device to run on: ``cuda:0`` if ``opts.use_cuda``, else CPU.

    ``--gpu`` is applied through CUDA_VISIBLE_DEVICES before torch is imported,
    so the selected GPUs are renumbered from 0 and the first one is always
    ``cuda:0``. Use of additional GPUs is presumably handled by the
    distributed-training path; this only picks the primary device.
    """
    return torch.device("cuda:0" if opts.use_cuda else "cpu")


def _resume_epoch(resume_path):
    """Return the epoch number encoded in a checkpoint name like ``<prefix>-<N>.<ext>``.

    The number is taken from the second ``-``-separated field of the file name
    with its extension removed.

    Raises:
        ValueError: if that field is missing or is not an integer.
    """
    stem = os.path.splitext(os.path.basename(resume_path))[0]
    try:
        return int(stem.split("-")[1])
    except (IndexError, ValueError) as err:
        raise ValueError(
            f"Cannot infer the epoch from checkpoint name {resume_path!r}; "
            "expected '<prefix>-<epoch>.<ext>'"
        ) from err


def run(opts):
    """Set up logging, seeding and the device, then evaluate or train an agent.

    Steps: validate checkpoint options, optional W&B run, print the options,
    seed ``torch``/``random``, optional TensorBoard logger, save ``args.json``,
    pick ``opts.device``, build the problem and agent, optionally load a
    checkpoint, then call ``agent.start_inference`` (when ``opts.eval_only``)
    or ``agent.start_training``.

    Args:
        opts: Options namespace from ``get_options()``. ``opts.device`` is set
            here, and ``agent.opts.epoch_start`` is set when resuming.

    Raises:
        ValueError: if both ``opts.load_path`` and ``opts.resume`` are given, or
            the resume checkpoint name has no parsable epoch number.
    """
    # Validate up front so a bad configuration fails before any W&B run, log
    # directory or model is created. Explicit raise (not assert) so the check
    # survives ``python -O``.
    if opts.load_path is not None and opts.resume is not None:
        raise ValueError("Only one of load path and resume can be given")

    _init_wandb(opts)

    # Pretty print the run args
    pprint.pprint(vars(opts))

    # Seed PyTorch's and Python's RNGs
    torch.manual_seed(opts.seed)
    random.seed(opts.seed)

    # Optionally configure tensorboard (skipped for distributed runs)
    tb_logger = None
    if not opts.no_tb and not opts.distributed:
        tb_logger = TbLogger(os.path.join(
            opts.log_dir, f"{opts.problem}_{opts.graph_size}", opts.run_name))

    # Save arguments so the exact configuration can always be found. Keep this
    # before opts.device is assigned: a torch.device is not JSON-serializable.
    if not opts.no_saving:
        os.makedirs(opts.save_dir, exist_ok=True)
        with open(os.path.join(opts.save_dir, "args.json"), 'w') as f:
            json.dump(vars(opts), f, indent=4)

    opts.device = _select_device(opts)

    # Figure out what's the problem
    problem = load_problem(opts.problem)(
        p_size=opts.graph_size,
        init_val_met=opts.init_val_met,
        with_assert=opts.use_assert)

    # Figure out the RL algorithm
    agent = load_agent(opts.RL_agent)(problem.NAME, problem.size, opts)

    # Load weights from either load_path or resume (never both; checked above)
    load_path = opts.load_path if opts.load_path is not None else opts.resume
    if load_path is not None:
        agent.load(load_path)

    if opts.eval_only:
        # Validation / inference only
        agent.start_inference(problem, opts.val_dataset, tb_logger)
        return

    if opts.resume:
        epoch_resume = _resume_epoch(opts.resume)
        print("Resuming after {}".format(epoch_resume))
        agent.opts.epoch_start = epoch_resume + 1

    # Start the actual training loop
    agent.start_training(problem, opts.val_dataset, tb_logger)
            


if __name__ == "__main__":

    # Silence all Python warnings for the rest of the process.
    warnings.filterwarnings("ignore")

    # Allow a duplicate Intel OpenMP runtime instead of aborting; presumably a
    # workaround for the "OMP: Error #15" crash. NOTE(review): torch is already
    # imported by this point, so confirm the variable still has the intended effect.
    os.environ.update(KMP_DUPLICATE_LIB_OK='True')

    # Prefer deterministic cuDNN kernels over autotuned ones, for reproducibility.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    cli_opts = get_options()
    run(cli_opts)
