import argparse
import os

# Select the GPU before importing torch so the choice is already in place when
# CUDA initialises. GPU 7 is only the fallback: a CUDA_VISIBLE_DEVICES value
# already present in the environment (shell export, job scheduler, ...) is
# left untouched instead of being silently overwritten.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '7')

import torch
from train.train import train
from train.train_ogbn import train_ogbn
from train.train_batch import train_batch
from data.dataloader import load_dataset
from post_process.post_process import post_process
from utils import generate_random_numbers, set_seed, load_hyperparameters, set_device, create_directory, get_base_model_path, get_processed_model_path
from model.utils import clone_base_model, save_model, clone_processed_model, init_processed_model
from experiment.exp import run_experiment
from experiment.plot import plot_results

def train_and_save(data, params, device, args, iteration):
    """Train one base model and return it.

    The training routine is picked from ``args.dataset``: ``train_ogbn`` for
    'ogbn-arxiv', ``train_batch`` for 'flickr' and 'reddit', and ``train`` for
    every other dataset. The path returned by ``get_base_model_path`` is passed
    to ``create_directory`` first and then handed to the routine as its save
    path.

    Args:
        data: Dataset object, forwarded unchanged to the training routine.
        params: Hyperparameters, forwarded unchanged to the training routine.
        device: Device, forwarded unchanged to the training routine.
        args: Parsed command-line arguments; ``dataset`` and ``architecture``
            are read here.
        iteration: Zero-based repeat index (printed one-based in the log line).

    Returns:
        The first element of the pair returned by the training routine.
    """
    output_dir = get_base_model_path(args, iteration)
    create_directory(output_dir)

    # Choose the trainer up front so the call itself is written only once.
    dataset_name = args.dataset
    if dataset_name == 'ogbn-arxiv':
        trainer = train_ogbn
    elif dataset_name in ('flickr', 'reddit'):
        trainer = train_batch
    else:
        trainer = train

    trained_model, _ = trainer(data, params, device, output_dir, args)
    print(f"Completed training {args.architecture} on {args.dataset} iteration {iteration+1}.")

    return trained_model

def process_and_save(data, params, device, args, train_num, processed_num):
    """Run one post-processing round and save the resulting model.

    Round 0 starts from the base model of repeat ``train_num`` (loaded with
    ``clone_base_model``); every later round starts from the model saved by
    round ``processed_num - 1`` (loaded with ``clone_processed_model``). The
    result is always saved under
    ``get_processed_model_path(args, train_num, processed_num)``.

    Args:
        data: Dataset object, forwarded to the model helpers and ``post_process``.
        params: Hyperparameters; ``params["step_size"]`` is passed to
            ``post_process``.
        device: Device, forwarded to the model helpers and ``post_process``.
        args: Parsed command-line arguments, forwarded to the helpers.
        train_num: Index of the training repeat whose model is processed.
        processed_num: Zero-based index of this post-processing round.

    Returns:
        None. The processed model is written with ``save_model``.
    """
    output_dir = get_processed_model_path(args, train_num, processed_num)

    # Resolve where the source model lives and which loader matches it.
    if processed_num == 0:
        source_dir = get_base_model_path(args, train_num)
        load_source = clone_base_model
    else:
        source_dir = get_processed_model_path(args, train_num, processed_num - 1)
        load_source = clone_processed_model

    create_directory(output_dir)
    source_model = load_source(data, source_dir, params, device, args)

    processed = init_processed_model(data, params, device, args)
    # NOTE(review): the meaning of the literal 1.0 is defined by post_process,
    # which is not visible here.
    processed = post_process(
        data, source_model, processed, device, args, 1.0, params["step_size"]
    )
    save_model(processed, output_dir)

    # Drop the local reference to the processed model once it is saved.
    del processed

def main():
    """Command-line entry point: train, post-process, run experiments and plot.

    ``--step`` selects what is executed:

    * ``train``: train and save one base model per repeat.
    * ``post_process``: for each repeat, run ``--process_num`` post-processing
      rounds and then the experiment; finally plot the results.
    * ``plot``: only plot the results. The per-repeat loop is skipped, so no
      dataset is loaded.
    * ``all``: train, post-process, run the experiment and plot.

    Seeds are derived from ``--random_seed``: ``3 * repeat_num`` values are
    generated; the first third seeds training (and dataset loading), the
    second third seeds post-processing and the last third seeds the experiment.
    """
    parser = argparse.ArgumentParser(description="FILLER")
    # Parameters that affect which version of the model is produced
    parser.add_argument('--architecture', type=str, required=True, choices=['gcn', 'graphsage', 'sgc', 'gat', 'gin'])
    parser.add_argument('--dataset', type=str, required=True, choices=['cora', 'citeseer', 'pubmed', 'computers', 'photo', 'CS', 'physics', 'ogbn-arxiv', 'reddit', 'flickr'])
    parser.add_argument('--edge_ratio', type=float, default=1.0, help='Ratio of edges used in training')
    parser.add_argument("--random_seed", type=int, default=42)
    parser.add_argument('--repeat_num', type=int, default=5, help='Number of exp repeats')
    parser.add_argument('--process_num', type=int, default=5, help='Number of processing iterations')
    parser.add_argument('--pp_method', type=str, default='advanced',  choices=['simple', 'advanced'])

    # Parameter that controls which stages of the framework are run
    parser.add_argument('--step', type=str, required=True, choices=['train', 'post_process', 'plot', 'all'])
    args = parser.parse_args()

    # Report which GPU(s) are visible to this process, then pick the device
    print(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")
    device = set_device()

    # {3*repeat_num} random numbers will be used as the random seeds:
    # the first {repeat_num} for training, the second {repeat_num} for
    # post-processing and the third {repeat_num} for the experiment.
    seed_list = generate_random_numbers(args.repeat_num*3, args.random_seed)

    # In the experiment, we divide edge_ratio into {exp_num} blocks.
    # For large datasets, we reduce this to 4 blocks due to time constraints.
    exp_num = 20
    if args.dataset in ['ogbn-arxiv', 'reddit', 'flickr']:
        exp_num = 4

    do_train = args.step in ('train', 'all')
    do_post_process = args.step in ('post_process', 'all')

    # '--step plot' uses neither the dataset nor the hyperparameters
    # (plot_results only receives args and exp_num), so the per-repeat loop and
    # its dataset loading are skipped entirely in that case.
    if do_train or do_post_process:
        for i in range(args.repeat_num):
            # Seed before loading so everything in this repeat is reproducible
            set_seed(seed_list[i])

            data = load_dataset(args).to(device)
            print(data.num_nodes, data.edge_index.size(), data.train_mask.sum(), data.val_mask.sum(), data.test_mask.sum())

            # load hyperparameters according to model
            params = load_hyperparameters(args, data)

            if do_train:
                train_and_save(data, params, device, args, i)

            if do_post_process:
                set_seed(seed_list[args.repeat_num + i])
                for j in range(args.process_num):
                    process_and_save(data, params, device, args, i, j)

                set_seed(seed_list[args.repeat_num*2 + i])
                # NOTE(review): the seed handed to run_experiment is the
                # training seed (seed_list[i]), not the experiment seed that
                # was just set globally - confirm this is intended.
                run_experiment(data, params, device, args, exp_num, i, seed_list[i])
                # Release cached GPU memory before the next repeat
                torch.cuda.empty_cache()

    if args.step in ['post_process', 'plot', 'all']:
        plot_results(args, exp_num)
    
        
# Run the pipeline only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()