import sys
import argparse
import os
import time
import logging
from datetime import datetime
import csv
import multiprocessing

import os
import copy
import signal
import json

def get_args():
    """Parse the command line.

    Returns:
        ``(args, extras)`` as produced by ``parse_known_args``: ``args`` is the
        namespace of recognised options and ``extras`` is the list of tokens
        argparse did not recognise (``main`` forwards them to ``load_config``
        as ``cli_args``).
    """
    cli = argparse.ArgumentParser()

    # Inputs, devices and resume sources.
    cli.add_argument('--config', required=True, help='path to config file')
    cli.add_argument('--gpu', default='0', help='GPU(s) to be used')
    cli.add_argument('--resume', default=None, help='path to the weights to be resumed')
    cli.add_argument('--completion', default=None, help='path to the exp folder to complete')
    cli.add_argument(
        '--resume_weights_only',
        action='store_true',
        help='specify this argument to restore only the weights (w/o training states), e.g. --resume path/to/resume --resume_weights_only'
    )

    # Run-mode switches. They are deliberately not mutually exclusive.
    for mode in ('train', 'validate', 'test', 'predict', 'debug', 'sanity'):
        cli.add_argument(f'--{mode}', action='store_true')

    # Output locations and verbosity.
    cli.add_argument('--exp_dir', default='./exp')
    cli.add_argument('--runs_dir', default='./runs')
    cli.add_argument('--verbose', action='store_true', help='if true, set logging level to DEBUG')

    parsed, unknown = cli.parse_known_args()

    pl_logger = logging.getLogger('pytorch_lightning')
    if parsed.verbose:
        pl_logger.setLevel(logging.DEBUG)

    return parsed, unknown

# Parse the CLI at import time: main() reads these module-level globals.
# NOTE(review): importing this module therefore parses sys.argv (and requires
# --config) as a side effect.
args, extras = get_args()

from utils.misc import load_config, resolve_config    
from utils.loggers import ResultLogger

import numpy as np


def get_callback(args, config):
    """Build the Lightning callbacks for this run.

    Training runs (``args.train``) get, in this order:
    - a ``ModelCheckpoint`` into ``config.ckpt_dir``, with extra options taken from ``config.checkpoint``;
    - a per-step ``LearningRateMonitor``;
    - ``CodeSnapshotCallback`` for ``config.code_dir``;
    - ``ConfigSnapshotCallback`` for ``config.config_dir``;
    - ``CustomProgressBar(refresh_rate=1)``.

    Every other mode gets an empty list.
    """
    from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
    from utils.callbacks import CodeSnapshotCallback, ConfigSnapshotCallback, CustomProgressBar

    if not args.train:
        return []

    checkpointing = ModelCheckpoint(dirpath=config.ckpt_dir, **config.checkpoint)
    lr_monitor = LearningRateMonitor(logging_interval='step')
    code_snapshot = CodeSnapshotCallback(config.code_dir, use_version=False)
    config_snapshot = ConfigSnapshotCallback(config, config.config_dir, use_version=False)
    progress_bar = CustomProgressBar(refresh_rate=1)

    return [checkpointing, lr_monitor, code_snapshot, config_snapshot, progress_bar]

def get_loggers(args, config):
    """Build the experiment loggers: TensorBoard + CSV for training runs, none otherwise.

    TensorBoard: ``save_dir=args.runs_dir``, ``name=config.name``, ``version=config.trial_name``.
    CSV: ``save_dir=config.exp_dir``, ``name=config.trial_name``, ``version='csv_logs'``.
    """
    from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger

    if not args.train:
        return []

    tb_logger = TensorBoardLogger(args.runs_dir, name=config.name, version=config.trial_name)
    csv_logger = CSVLogger(config.exp_dir, name=config.trial_name, version='csv_logs')
    return [tb_logger, csv_logger]

def get_distri_strategy(n_gpus=1, platform=None):
    """Return the Lightning ``strategy`` argument used by ``get_trainer``.

    Args:
        n_gpus: number of GPUs the trainer will use. It is only validated on
            Windows, where multi-GPU is not supported. The default of 1 matches
            the ``devices=1`` used by ``get_trainer``.
        platform: platform string to decide on. Defaults to ``sys.platform``.

    Returns:
        ``'dp'`` on Windows. Everywhere else, a ``SingleDeviceStrategy`` on
        ``cuda:0``, which is the only GPU a worker sees once its
        ``CUDA_VISIBLE_DEVICES`` has been set to the assigned device.

    Raises:
        ValueError: on Windows when ``n_gpus != 1``.
    """
    if platform is None:
        platform = sys.platform

    if platform == 'win32':
        # Multi-GPU is not supported on Windows. Raise explicitly rather than
        # assert, so the check survives ``python -O``.
        if n_gpus != 1:
            raise ValueError(f"Only a single GPU is supported on Windows, got n_gpus={n_gpus}")
        return 'dp'

    # Each sweep worker runs on exactly one GPU, so DDP is not needed.
    from pytorch_lightning.strategies.single_device import SingleDeviceStrategy
    return SingleDeviceStrategy(device="cuda:0")

def parse_n(n, n_images, n_all_images):
    """Resolve an image-count setting into an absolute number of images.

    Args:
        n: requested count, interpreted as follows:
            - ``0`` means all ``n_all_images``;
            - a value in ``(0, 1)`` is a fraction of ``n_all_images``, truncated by ``int``;
            - a value ``>= 1`` is returned unchanged as an absolute count.
        n_images: unused; kept so existing call sites keep working.
        n_all_images: total number of available images.

    Returns:
        The resolved number of images.

    Raises:
        ValueError: if ``n`` is negative (or otherwise matches none of the cases above).
    """
    if n == 0:
        return n_all_images
    if 0 < n < 1:
        return int(n * n_all_images)
    # Validate ``n`` itself. The value of ``n_images`` says nothing about
    # whether ``n`` is a usable count.
    if n >= 1:
        return n
    raise ValueError(f"Invalid n images: {n}")

def process(trainer, args, config):
    """Build the datamodule and system for ``config`` and run the requested stage(s).

    The first matching flag in ``args`` selects the stage(s):
    1. ``debug``: fit only.
    2. ``train``: fit, then test when ``test`` is also set.
    3. ``validate``.
    4. ``test``.
    5. ``predict``.

    If none of the flags is set, no stage runs.

    Args:
        trainer: Lightning ``Trainer`` used to run the stages.
        args: parsed CLI namespace. The stage flags, ``resume`` and ``resume_weights_only`` are read.
        config: fully specialised run config.

    Returns:
        The system module after the selected stage(s) have run.
    """
    import systems
    import datasets

    dm = datasets.make(config.dataset.name, config.dataset)
    # With --resume_weights_only the checkpoint is handed to systems.make as
    # initial weights. Otherwise args.resume (if any) is passed to the
    # trainer as ckpt_path.
    system = systems.make(config.system.name, config, load_from_checkpoint=None if not args.resume_weights_only else args.resume)
    # When called from launch(), CUDA_VISIBLE_DEVICES exposes only the
    # assigned GPU, which therefore appears as cuda:0.
    system = system.to("cuda:0")

    # Full-state resume only when the checkpoint was not already consumed as weights.
    fit_ckpt_path = args.resume if args.resume and not args.resume_weights_only else None

    if args.debug or args.train:
        trainer.fit(system, datamodule=dm, ckpt_path=fit_ckpt_path)
        if not args.debug and args.test:
            # Evaluate the weights that were just trained. Passing
            # ckpt_path=args.resume here would reload the checkpoint the run
            # started from and report its metrics instead.
            trainer.test(system, datamodule=dm)
    elif args.validate:
        trainer.validate(system, datamodule=dm, ckpt_path=args.resume)
    elif args.test:
        trainer.test(system, datamodule=dm, ckpt_path=args.resume)
    elif args.predict:
        trainer.predict(system, datamodule=dm, ckpt_path=args.resume)

    return system

def get_trainer(callbacks, loggers, strategy, config):
    """Create a single-GPU Lightning ``Trainer``.

    The fixed arguments are ``devices=1``, ``accelerator='gpu'``, and the given
    callbacks, loggers and strategy. Additional keyword arguments come from
    ``config.trainer``. A key there that repeats one of the fixed arguments
    (e.g. ``devices``) raises ``TypeError`` for a duplicate keyword.
    """
    from pytorch_lightning import Trainer

    fixed_kwargs = dict(
        devices=1,
        accelerator='gpu',
        callbacks=callbacks,
        logger=loggers,
        strategy=strategy,
    )
    return Trainer(**fixed_kwargs, **config.trainer)

def launch(result_logger, args, config, control_param, control_params_names, device_id): 
    """Run one sweep configuration end to end and record its test metrics.

    ``main()`` starts this as a ``multiprocessing.Process`` target, one process
    per parameter combination.

    Args:
        result_logger: ``ResultLogger`` that receives this run's row via
            ``log_new_item`` and is then exported with ``export_csv``.
            NOTE(review): in a forked child, in-memory changes to this object
            are not seen by the parent's copy. Whether results are shared
            therefore depends on ResultLogger persisting to disk (not visible
            here) — confirm.
        args: parsed CLI namespace. It is deep-copied before use.
        config: base config. It is deep-copied, then specialised by ``get_config``.
        control_param: one combination produced by ``parse_param``, in the
            order expected by ``get_config``.
        control_params_names: column names zipped with ``control_param`` to
            form the result row.
        device_id: GPU index exported as ``CUDA_VISIBLE_DEVICES`` for this process.
    """
    # Read before the deep copy; get_config receives this and overwrites
    # config.dataset.n_images.
    n_init_images = config.dataset.n_images
    # Work on private copies so per-run mutations never touch the caller's objects.
    args, config = copy.deepcopy(args), copy.deepcopy(config)
    
    # NOTE(review): ``device`` is never used below. CUDA_VISIBLE_DEVICES is
    # built from ``device_id`` directly, so a None device_id would export "None".
    if device_id is not None:
        device = device_id
    else:
        device = 0

    ### Do not move anything before this part!!!!! Do not change this part!!! 
    # These env vars must take effect before CUDA is initialised in this
    # process, so that only the assigned GPU is visible (as cuda:0).

    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device_id)
    #device = 1
    import torch.multiprocessing
    torch.multiprocessing.set_sharing_strategy('file_system')
    ### Do not move anything before this part!!!!! Do not change this part!!! 

    config = get_config(args, config, control_param, n_init_images)

    # This datamodule is built only to learn dm.n_train_images.
    # process() builds its own instance again.
    import datasets
    dm = datasets.make(config.dataset.name, config.dataset)
    print("max_num_new_imgs: ", config.nbv.max_num_new_imgs, "n_images: ", config.dataset.n_images, "dm.n_train_images: ", dm.n_train_images)
    config.nbv.max_num_new_imgs = parse_n(config.nbv.max_num_new_imgs, config.dataset.n_images, dm.n_train_images)
    # NOTE(review): get_config() sets config.nbv.use = False, so with the
    # current get_config the existing max_steps is always kept here.
    config.trainer.max_steps = (config.nbv.add_nbv_start_steps + config.nbv.add_nbv_n_steps*config.nbv.max_num_new_imgs \
                + config.system.steps_after_nbv) if config.nbv.use else config.trainer.max_steps
    
    #if config.checkpoint.get("save_last", False):
    #    config.checkpoint.every_n_train_steps = config.trainer.max_steps-2

    print(f"Total Training Steps: {config.trainer.max_steps}")

    callbacks = get_callback(args, config)
    loggers = get_loggers(args, config)
    strategy = get_distri_strategy()

    trainer = get_trainer(callbacks, loggers, strategy, config)
    system = process(trainer, args, config)

    #control_param[4] = ''.join([str(r) for r in control_param[4]])

    # One result row: the sweep parameters plus the system's test metrics.
    # The metrics are presumably populated by the test stage — TODO confirm
    # (they are read unconditionally here).
    log_item = dict(zip(control_params_names, control_param))
    log_item.update({"test_psnr": system.test_psnr})
    log_item.update({"test_ssim": system.test_ssim})
    log_item.update({"test_lpips": system.test_lpips})

    result_logger.log_new_item(log_item)
    result_logger.export_csv()

def parse_param(*params):
    """Return the cartesian product of ``params`` as a list of lists.

    The first argument varies slowest, e.g.
    ``parse_param([1, 2], ['a', 'b'])`` gives
    ``[[1, 'a'], [1, 'b'], [2, 'a'], [2, 'b']]``.

    With no arguments the result is ``[[]]``. If any argument is empty, the
    result is ``[]``.
    """
    combos = [[]]
    # Build from the last parameter outwards so the first one ends up slowest-varying.
    for choices in reversed(params):
        combos = [[value] + tail for value in choices for tail in combos]
    return combos

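'''
To add a new control parameter:
1. Append its values to `control_params_list` and its name to
   `control_params_names` in `main()`, at the same position in both lists.
2. Unpack it from `control_param` and apply it in `get_config()`.
3. Include it in `config.tag` in `get_config()` so that runs differing only
   in this parameter get distinct folder names.
'''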

def get_config(args, config, control_param, n_init_images=1):
    """Specialise ``config`` in place for one point of the parameter sweep and return it.

    Args:
        args: parsed CLI namespace. ``debug`` and ``sanity`` are read.
        config: run config to mutate.
        control_param: must unpack into
            ``(n_image, scene, seed, ig, add_nbv_n_steps, planner)``, the same
            order as ``control_params_names`` in ``main``.
        n_init_images: accepted for call compatibility; not used.

    Returns:
        The mutated ``config``. When ``args.sanity`` is set, this is the object
        returned by ``sanity_check_config``.
    """
    n_img, scene_name, seed_value, info_gain, nbv_steps, planner_name = control_param

    # The tag names the run folder, so it must differ for every combination.
    config.tag = (
        f"{scene_name}_n_image_{n_img}_planner_{planner_name}"
        f"_add_nbv_n_steps_{nbv_steps}_ig_{info_gain}_seed_{seed_value}"
    )

    config.seed = seed_value
    config.dataset.scene = scene_name
    config.dataset.root_dir = os.path.join(config.exp.root_dir, scene_name)

    # Next-best-view selection is always switched off for these runs.
    config.nbv.use = False
    config.dataset.n_images = n_img

    # A timestamp suffix keeps repeated runs with the same tag in separate folders.
    config.trial_name = config.tag + datetime.now().strftime('@%Y%m%d-%H%M%S')
    run_root = os.path.join(config.exp_dir, config.trial_name)
    for attr, subdir in (('save_dir', 'save'), ('ckpt_dir', 'ckpt'),
                         ('code_dir', 'code'), ('config_dir', 'config')):
        setattr(config, attr, os.path.join(run_root, subdir))

    if args.debug:
        config.system.debug = True
        config.trainer.max_steps = 1_000_000_000

    if args.sanity:
        config = sanity_check_config(config)

    print(config.trainer.max_steps, "config.trainer.max_steps")

    # No explicit seed: derive a pseudo-random one from the wall clock.
    if seed_value is None:
        config.seed = int(time.time() * 1000) % 1000

    import pytorch_lightning as pl
    pl.seed_everything(config.seed)

    return config

def sanity_check_config(config):
    """Shrink ``config`` in place for a fast smoke-test run and return it.

    The overrides are:
    - 4 images;
    - 1000 training steps, with 50 steps after NBV;
    - a single test batch;
    - validation and training videos every 250 steps;
    - an isosurface resolution of 64.
    """
    overrides = (
        ('dataset.n_images', 4),
        ('system.steps_after_nbv', 50),
        ('system.save_training_video_interval', 250),
        ('trainer.max_steps', 1000),
        ('trainer.limit_test_batches', 1),
        ('trainer.val_check_interval', 250),
        ('model.geometry.isosurface.resolution', 64),
    )
    # Walk each dotted path lazily, so the attribute accesses happen in the same order as plain assignments.
    for dotted_key, value in overrides:
        *parents, leaf = dotted_key.split('.')
        node = config
        for part in parents:
            node = getattr(node, part)
        setattr(node, leaf, value)
    return config

def _drop_completed(params, log_dict):
    """Remove parameter combinations that already have a recorded result.

    A combination counts as done when some record in ``log_dict`` (the loaded
    ``results.json``) has exactly ``param`` as its first ``len(param)`` values,
    in insertion order. Matching is positional, so it relies on each record
    having been built as ``dict(zip(control_params_names, param))``, as in
    ``launch``. A record with too few values simply does not match.

    Returns:
        ``(remaining_params, removed_count)``, with ``remaining_params`` in the original order.
    """
    recorded = [list(entry.values()) for entry in log_dict.values()]
    remaining = []
    removed = 0
    for param in params:
        wanted = list(param)
        if any(values[:len(wanted)] == wanted for values in recorded):
            removed += 1
        else:
            remaining.append(param)
    return remaining, removed


def _run_sweep(result_logger, config, params, control_params_names, device_ids, poll_interval=1.0):
    """Run ``launch`` once per parameter combination, one worker per entry of ``device_ids``.

    Jobs start in ``params`` order. When a worker finishes, its GPU id is
    handed to the next pending combination. On Ctrl-C (SIGINT), the running
    workers are terminated and no further jobs are started. The function
    returns once every started worker has exited.
    """
    pending = list(params)
    running = []
    stop_requested = False

    def start(param, device_id):
        proc = multiprocessing.Process(
            target=launch,
            args=(result_logger, args, config, param, control_params_names, device_id),
        )
        proc.device_id = device_id
        proc.start()
        return proc

    # First batch: one job per GPU id, or fewer if there are fewer combinations.
    for device_id in device_ids:
        if not pending:
            break
        running.append(start(pending.pop(0), device_id))

    def on_sigint(sig, frame):
        nonlocal stop_requested
        stop_requested = True
        for proc in running:
            proc.terminate()

    previous_handler = signal.signal(signal.SIGINT, on_sigint)
    try:
        while running:
            # Rebuild the list instead of removing from it while iterating.
            still_running = []
            for proc in running:
                if proc.is_alive():
                    still_running.append(proc)
                    continue
                proc.join()
                if pending and not stop_requested:
                    # Reuse the freed GPU for the next combination.
                    still_running.append(start(pending.pop(0), proc.device_id))
            running[:] = still_running
            if stop_requested:
                # A job may have been started between the signal and the check above.
                for proc in running:
                    proc.terminate()
            if running:
                # Poll instead of busy-spinning a CPU core.
                time.sleep(poll_interval)
    finally:
        signal.signal(signal.SIGINT, previous_handler)


def main():
    """Expand the config's ``exp`` section into a parameter sweep and run it.

    Reads the module-level ``args`` / ``extras`` parsed at import time.

    Every combination of ``n_images`` x ``scene`` x ``seed`` x ``ig`` x
    ``add_nbv_n_steps`` x ``planner`` is run by ``launch`` in its own process,
    with at most one process per GPU id listed in ``--gpu``.

    With ``--completion DIR``, results go to DIR, and combinations already
    recorded in ``DIR/results.json`` are skipped.
    """
    # Load the YAML config, applying the unrecognised CLI tokens as overrides.
    config = load_config(args.config, cli_args=extras)
    config.cmd_args = vars(args)

    yaml_name = args.config.split('/')[-1].split('.')[0]
    config.name = yaml_name+datetime.now().strftime('@%Y%m%d-%H%M%S')
    config.exp_dir = os.path.join(args.exp_dir, config.name)

    print("exp_dir: ", config.exp_dir)

    # Get the available GPU IDs
    device_ids = [int(gpu) for gpu in args.gpu.split(',')]
    print("device_ids: ", device_ids)

    # Can be removed if mp_launch is not proper to use when debugging
    if config.system.debug or args.debug or yaml_name.startswith("debug"):
        config.name = "debug"

    scene_list = config.exp.scene_list if "scene_list" in config.get('exp') else config.dataset.scene_list
    n_images = config.exp.n_images
    seeds = config.exp.seed
    ig = config.exp.ig
    add_nbv_n_steps = config.exp.add_nbv_n_steps
    planner = config.exp.planner

    # Order must match the unpacking in get_config().
    control_params_list = [n_images, scene_list, seeds, ig, add_nbv_n_steps, planner]
    control_params_names = ["n_image", "scene", "seeds", "ig", "add_nbv_n_steps", "planner"]
    params = parse_param(*control_params_list)

    if args.completion:
        with open(f"{args.completion}/results.json", 'r') as f:
            log_dict = json.load(f)
        params, remove_count = _drop_completed(params, log_dict)
        print(len(params), remove_count, "len(params), remove_count")
        config.exp_dir = args.completion

    os.makedirs(config.exp_dir, exist_ok=True)
    result_logger = ResultLogger(os.path.join(config.exp_dir, "results.json"))

    print(f"Remaining {len(params)} exps to be run")

    _run_sweep(result_logger, config, params, control_params_names, device_ids)

    result_logger.export_csv()

    print("exp_dir: ", config.exp_dir)

if __name__ == '__main__': 
    # Must run once, before main() creates any worker Process.
    # With 'fork', each launch() worker is a copy of this process, so the
    # args/config/result_logger passed to it are inherited rather than pickled.
    # NOTE(review): 'fork' is unavailable on Windows (set_start_method raises
    # ValueError there), although get_distri_strategy has a win32 branch —
    # confirm whether Windows is meant to be supported.
    multiprocessing.set_start_method('fork')
    main()
