import os, sys
import pathlib
from pathlib import Path
sys.path.insert(0, os.path.dirname(pathlib.Path(__file__).parent.absolute())   )
from ddpm_examples.utils import load_config_from_yaml
from reconstruction.reverse_diffusion import ReverseDiffusion
from data.data_transforms import ImageDataTransform
from data.image_data import ImageNetDataset
from argparse import ArgumentParser
import copy 

def nested_get(dic, keys):
    """Follow the path ``keys`` into ``dic`` and return whatever sits at its end.

    Each key indexes the result of the previous lookup, so
    ``nested_get(d, ['a', 'b'])`` is ``d['a']['b']``. An empty ``keys``
    returns ``dic`` itself; a missing key propagates the container's own
    lookup error (e.g. ``KeyError`` for dicts).
    """
    node = dic
    for step in keys:
        node = node[step]
    return node

def nested_set(dic, keys, value):
    """Store ``value`` at the path ``keys`` inside ``dic``, in place.

    Intermediate levels that do not exist yet are created as empty dicts
    (via ``setdefault``); existing intermediate containers are reused. An
    empty ``keys`` raises ``IndexError``.
    """
    parents, leaf = keys[:-1], keys[-1]
    target = dic
    for step in parents:
        target = target.setdefault(step, {})
    target[leaf] = value

def _base_experiment_name(args):
    """Return ``args.experiment_name``, or the config file's stem when it is not given."""
    if args.experiment_name is None:
        return str(Path(args.recon_config_path).stem)
    return args.experiment_name


def _expand_configs(args, config_dict):
    """Return the list of ``(config, experiment_name)`` pairs that ``cli_main`` runs.

    Without ``--batch_run_key`` this is the loaded config under the base
    experiment name. With it, the value found at that (possibly nested) key
    path must be a list; one independent deep copy of the config is produced
    per element, with the list replaced by that single element, and the run
    is named ``<base>_<key1>_..._<keyN>_<value>``.

    Raises:
        TypeError: if the key path is not a list, or the config value stored
            at it is not a list.
    """
    base_name = _base_experiment_name(args)
    keys = args.batch_run_key
    if keys is None:
        return [(config_dict, base_name)]

    # Explicit checks rather than ``assert`` so validation survives ``python -O``.
    if not isinstance(keys, list):
        raise TypeError('batch_run_key must be a list of keys, got {!r}.'.format(keys))
    values = nested_get(config_dict, keys)
    if not isinstance(values, list):
        raise TypeError(
            'Config value at batch_run_key {} must be a list to batch over, '
            'got {!r}.'.format(keys, values)
        )

    key_suffix = ''.join('_' + str(k) for k in keys)
    runs = []
    for v in values:
        config_v = copy.deepcopy(config_dict)
        # ``v`` still belongs to the original config's list; copy it so no two
        # runs (or the original config) share a mutable object.
        nested_set(config_v, keys, copy.deepcopy(v))
        runs.append((config_v, base_name + key_suffix + '_' + str(v)))
    return runs


def cli_main(args):
    """Run one or more reverse-diffusion reconstruction experiments.

    Loads the reconstruction config from ``args.recon_config_path``, expands it
    into one or more runs (see ``_expand_configs``), and for each run builds a
    ``ReverseDiffusion`` reconstructor, the ImageNet validation dataset with a
    transform configured from the run's config, and reconstructs
    ``args.num_images`` images. All configs are expanded (and validated) before
    the first run starts.

    Args:
        args: Parsed namespace as produced by ``build_args``.
    """
    print(args.__dict__)
    config_dict = load_config_from_yaml(args.recon_config_path)

    for config, exp_name in _expand_configs(args, config_dict):
        print('Running reconstruction experiment {} with config {}.'.format(exp_name, config))
        reconstructor = ReverseDiffusion(model_ckpt_path=args.model_ckpt_path,
                                         config_dict=config,
                                         device=args.device,
                                         output_path=args.output_path,
                                         experiment_name=exp_name,
                                         )

        # The transform reads from ``reconstructor.config`` (not ``config``),
        # matching whatever form the reconstructor stores internally.
        val_transform = ImageDataTransform(is_train=False,
                                           operator_config=reconstructor.config['operator_config'],
                                           noise_config=reconstructor.config['noise_config'],
                                           dt=reconstructor.config['dt'],
                                           )

        val_dataset = ImageNetDataset(
            root=args.data_path,
            split='val',
            transform=val_transform,
            num_images_per_class=1,
        )

        reconstructor.load_data(val_dataset)
        reconstructor.run_batch(args.num_images,
                                save=args.save_outputs,
                                evaluate=args.eval_results,
                                evaluate_gen_metrics=args.evaluate_gen_metrics,
                                save_intermediate_times=args.fid_eval_times,
                                )

def build_args(argv=None):
    """Build the reconstruction CLI parser and parse the command line.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse reads ``sys.argv[1:]``, exactly as before; passing
            a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = ArgumentParser()

    parser.add_argument(
        '--recon_config_path',
        type=str,
        help='Reconstruction configuration will be loaded from this file.',
    )
    parser.add_argument(
        '--model_ckpt_path',
        type=str,
        help='Model will be loaded from this file.',
    )
    parser.add_argument(
        '--data_path',
        type=str,
        help='Containing directory of imagenet dataset. Val split will be loaded from here.',
    )
    parser.add_argument(
        '--output_path',
        type=str,
        help='Root folder where outputs will be generated. Target and noisy images will be stored here.',
    )
    parser.add_argument(
        '--experiment_name',
        type=str,
        help='Used to name the folder for the output reconstructions. If not given, the reconstruction config name will be used.',
    )
    parser.add_argument(
        '--device',
        default='cuda:0',
        type=str,
        help='Which device to run the reconstruction on.',
    )
    # One or more keys forming a path into the nested config dict.
    parser.add_argument(
        '--batch_run_key',
        default=None,
        nargs='+',
        type=str,
        help='Run batch of experiments over the key specified here, where the values are read from the config file. The corresponding values must be in a list in the config.',
    )
    parser.add_argument(
        '--num_images',
        default=3,
        type=int,
        help='Number of images to be reconstructed from validation dataset.',
    )
    # store_true flags already default to False.
    parser.add_argument(
        '--save_outputs',
        action='store_true',
        help='If set, outputs will be saved to specified dir.',
    )
    parser.add_argument(
        '--eval_results',
        action='store_true',
        help='If set, reconstructions will be evaluated against ground truth.',
    )
    parser.add_argument(
        '--evaluate_gen_metrics',
        action='store_true',
        help='If set, FID will be evaluated on final reconstructions.',
    )
    parser.add_argument(
        '--fid_eval_times',
        default=None,
        type=int,
        help='If set to an integer > 2, FID will be evaluated at this many intermediate steps during reconstruction.',
    )
    return parser.parse_args(argv)

def run_cli():
    """Command-line entry point: parse the process arguments, then run the reconstruction(s)."""
    cli_main(build_args())


if __name__ == "__main__":
    # Only runs when this file is executed as a script, not when it is imported.
    run_cli()