import os, sys
import pathlib
from pathlib import Path
sys.path.insert(0, os.path.dirname(pathlib.Path(__file__).parent.absolute())   )
from ddpm_examples.utils import load_config_from_yaml
from reconstruction.reverse_diffusion import ReverseDiffusion
from data.data_transforms import ImageDataTransform
from data.image_data import ImageNetDataset
from argparse import ArgumentParser
import copy 

def cli_main(args):
    """Run reverse-diffusion reconstruction on the ImageNet ``test`` split.

    Loads the reconstruction config from ``args.recon_config_path``, builds a
    ``ReverseDiffusion`` reconstructor, wraps the ImageNet ``test`` split (one
    image per class) with an ``ImageDataTransform`` configured from the
    reconstructor's config, and runs ``run_batch`` over the dataset.

    Args:
        args: Parsed command-line namespace (see ``build_args``). Reads
            ``recon_config_path``, ``model_ckpt_path``, ``data_path``,
            ``output_path``, ``experiment_name`` (optional), ``device``,
            ``save_outputs``, ``eval_results`` and ``evaluate_gen_metrics``.

    Returns:
        dict mapping each experiment name to whatever
        ``ReverseDiffusion.run_batch`` returned for it.

    Raises:
        RuntimeError: if the constructed test dataset does not contain the
            expected number of images (1000).
    """
    print(vars(args))
    config_dict = load_config_from_yaml(args.recon_config_path)

    # Honour --experiment_name when it is given; otherwise derive the name
    # from the reconstruction config file name.
    exp_name = getattr(args, 'experiment_name', None)
    if not exp_name:
        exp_name = Path(args.recon_config_path).stem + '_ImageNet_test'
    configs_to_run = [(config_dict, exp_name)]

    # The dataset is built with num_images_per_class=1; the full test set is
    # expected to contain exactly this many images (presumably one for each
    # of the 1000 ImageNet classes).
    expected_num_images = 1000

    all_results = {}
    for config, run_name in configs_to_run:
        print('Running ImageNet test data reconstruction {} with config {}.'.format(run_name, config))
        reconstructor = ReverseDiffusion(
            model_ckpt_path=args.model_ckpt_path,
            config_dict=config,
            device=args.device,
            output_path=args.output_path,
            experiment_name=run_name,
        )

        # Operator, noise and dt settings are read from reconstructor.config
        # so the data degradation matches the reconstructor's settings.
        test_transform = ImageDataTransform(
            is_train=False,
            operator_config=reconstructor.config['operator_config'],
            noise_config=reconstructor.config['noise_config'],
            dt=reconstructor.config['dt'],
        )

        test_dataset = ImageNetDataset(
            root=args.data_path,
            split='test',
            transform=test_transform,
            num_images_per_class=1,
        )

        # Explicit check rather than `assert` (which is stripped under
        # `python -O`), done before the data is handed to the reconstructor.
        num_images = len(test_dataset)
        if num_images != expected_num_images:
            raise RuntimeError(
                'Expected {} ImageNet test images, found {} under {!r}.'.format(
                    expected_num_images, num_images, args.data_path))

        reconstructor.load_data(test_dataset)
        # NOTE(review): the first argument is the dataset length; confirm its
        # exact meaning (image count vs. batch size) in ReverseDiffusion.run_batch.
        all_results[run_name] = reconstructor.run_batch(
            num_images,
            save=args.save_outputs,
            evaluate=args.eval_results,
            evaluate_gen_metrics=args.evaluate_gen_metrics,
        )
    return all_results

def build_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            ``sys.argv[1:]`` is parsed, as before. Passing a list lets other
            code and tests drive the parser directly.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = ArgumentParser(
        description='Run reverse-diffusion reconstruction on the ImageNet test split.',
    )

    # Required: cli_main builds a Path from it, which fails on None.
    parser.add_argument(
        '--recon_config_path',
        type=str,
        required=True,
        help='Reconstruction configuration will be loaded from this file.',
    )
    parser.add_argument(
        '--model_ckpt_path',
        type=str,
        help='Model will be loaded from this file.',
    )
    parser.add_argument(
        '--data_path',
        type=str,
        help='Root directory of the ImageNet dataset. The test split will be loaded from here.',
    )
    parser.add_argument(
        '--output_path',
        type=str,
        help='Root folder where outputs will be generated. Target and noisy images will be stored here.',
    )
    parser.add_argument(
        '--experiment_name',
        type=str,
        help=('Used to name the folder for the output reconstructions. If not given, '
              '"<reconstruction config file name>_ImageNet_test" will be used.'),
    )
    parser.add_argument(
        '--device',
        default='cuda:0',
        type=str,
        help='Which device to run the reconstruction on.',
    )
    # store_true flags already default to False.
    parser.add_argument(
        '--save_outputs',
        action='store_true',
        help='If set, outputs will be saved to specified dir.',
    )
    parser.add_argument(
        '--eval_results',
        action='store_true',
        help='If set, reconstructions will be evaluated against ground truth.',
    )
    parser.add_argument(
        '--evaluate_gen_metrics',
        action='store_true',
        help='If set, FID will be evaluated on final reconstructions.',
    )
    return parser.parse_args(argv)

def run_cli():
    """Command-line entry point: parse the CLI arguments, then run the reconstruction."""
    parsed_args = build_args()
    cli_main(parsed_args)


if __name__ == "__main__":
    # Only launch the reconstruction when executed as a script, not on import.
    run_cli()