import argparse, os, sys, glob, datetime, yaml, pathlib, copy
import torch
from pathlib import Path
import time
import numpy as np
from tqdm import trange
sys.path.insert(0, os.path.dirname(pathlib.Path(__file__).parent.absolute())   )
from omegaconf import OmegaConf
from PIL import Image
import torch_fidelity
import random 

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, get_obj_from_str
from pl_modules.severity_encoder_module import SeverityEncoderModule
from pytorch_lightning import seed_everything
from data.data_transforms import ImageDataTransform
from data.image_data import CelebaDataset, FFHQDataset
from pl_modules.utils import load_config_from_yaml, load_np_to_tensor, str2int 
from data.metrics import psnr, LPIPS, ssim, nmse, mse
from data.operators import create_noise_schedule, create_operator

import matplotlib.pyplot as plt
# Affine maps between the [-1, 1] range and the [0, 1] range. They work on
# anything supporting + / * (torch tensors, numpy arrays, floats).
rescale_to_zero_one = lambda x: (x + 1.) / 2.
rescale_to_minusone_one = lambda x: x * 2. - 1.
# Shared LPIPS metric used by evaluate_results. It is constructed once, at import
# time (presumably the VGG-backbone variant, per the 'vgg' argument).
lpips = LPIPS('vgg')

def nested_get(dic, keys):
    """Walk ``keys`` down a nested mapping and return the value found at the end.

    An empty ``keys`` returns ``dic`` itself. A missing key raises whatever the
    mapping's ``__getitem__`` raises (``KeyError`` for a dict).
    """
    node = dic
    path = list(keys)
    while path:
        node = node[path.pop(0)]
    return node

def nested_set(dic, keys, value):
    """Assign ``value`` at the nested path ``keys`` inside ``dic``, in place.

    Intermediate mappings that do not exist yet are created as empty dicts.
    ``keys`` must be a non-empty sequence.
    """
    parents, leaf = keys[:-1], keys[-1]
    node = dic
    for key in parents:
        node = node.setdefault(key, {})
    node[leaf] = value

def custom_to_pil(x):
    """Convert one image tensor with values in [-1, 1] to an RGB PIL image.

    The tensor is detached, moved to CPU, clamped to [-1, 1], mapped to [0, 255],
    and has its first dimension moved last (presumably CHW -> HWC) before
    ``Image.fromarray``. Non-RGB results are converted to RGB.
    """
    unit = (torch.clamp(x.detach().cpu(), -1., 1.) + 1.) / 2.
    hwc = unit.permute(1, 2, 0).numpy()
    image = Image.fromarray((255 * hwc).astype(np.uint8))
    return image if image.mode == "RGB" else image.convert("RGB")

def custom_to_np(x):
    """Turn a 4-D batch in [-1, 1] into a contiguous uint8 tensor with dim 1 moved last.

    Mirrors the ADM sample-saving convention, see
    https://github.com/openai/guided-diffusion/blob/main/scripts/image_sample.py
    Values are mapped with ``(v + 1) * 127.5``, clamped to [0, 255] and truncated
    to uint8. Axes are permuted (0, 2, 3, 1), i.e. presumably NCHW -> NHWC.
    """
    as_bytes = ((x.detach().cpu() + 1) * 127.5).clamp(0, 255).to(torch.uint8)
    return as_bytes.permute(0, 2, 3, 1).contiguous()

def logs2pil(logs, keys=None):
    """Best-effort conversion of every entry in ``logs`` to a PIL image.

    For a 4-D value the first item along dim 0 is converted. A 3-D value is
    converted as-is. Any other rank prints a notice and maps to ``None``. A value
    that cannot be converted (e.g. it has no ``.shape``) also maps to ``None``.

    Args:
        logs: mapping from name to (presumably) image tensors in [-1, 1].
        keys: unused; kept only so existing callers that pass it keep working.

    Returns:
        dict with the same keys as ``logs``, each mapped to a PIL image or ``None``.
    """
    imgs = dict()
    for k in logs:
        value = logs[k]
        try:
            ndim = len(value.shape)
            if ndim == 4:
                img = custom_to_pil(value[0, ...])
            elif ndim == 3:
                img = custom_to_pil(value)
            else:
                print(f"Unknown format for key {k}. ")
                img = None
        # Deliberately best-effort, but no longer a bare `except:` that would
        # also swallow KeyboardInterrupt / SystemExit.
        except Exception:
            img = None
        imgs[k] = img
    return imgs

class LatentDataConsistency:
    """Measurement-consistency correction for latent diffusion sampling.

    ``modify_score`` decodes a predicted clean latent, applies the forward
    (degradation) operator, and moves the next latent against the gradient of
    the squared measurement error with respect to ``z_past``.
    """

    # Scaling methods that modify_score knows how to apply.
    _SCALING_METHODS = ('error', 'std', 'identity')

    def __init__(self, fwd_operator, noise_schedule, encode_fn, decode_fn, dc_step, scaling_method='std', scale_with_alphas=False, fwd_seed=None):
        """
        Args:
            fwd_operator: called as ``fwd_operator(x, t, seed=...)`` on images in [0, 1].
            noise_schedule: object exposing ``get_std(t)``.
            encode_fn: stored on the instance; not used by this class's methods.
            decode_fn: latent -> image in [-1, 1]. Must be differentiable, because
                the gradient in modify_score flows through it.
            dc_step: step size for the 'error' and 'std' scaling methods.
            scaling_method: how the gradient step is scaled in modify_score.
            scale_with_alphas: if True, modify_score requires an explicit ``scale``.
            fwd_seed: seed forwarded to ``fwd_operator``.
        """
        self.dc_step = dc_step
        self.fwd_operator = fwd_operator
        self.fwd_seed = fwd_seed
        self.encode_fn = encode_fn
        self.decode_fn = decode_fn
        self.scaling_method = scaling_method
        self.sigma = None
        self.noise_schedule = noise_schedule
        self.scale_with_alphas = scale_with_alphas
        # Measurement / latent state. Filled in by update_y() / update_z().
        self.y = None
        self.t = None
        self.z_mean = None
        self.z_var_pred = None

    def update_y(self, y, t):
        """Set the measurement ``y`` and severity ``t``, and derive the noise std."""
        self.y = y
        self.t = t
        self.sigma = self.noise_schedule.get_std(self.t) * 2 # Scale noise from [0, 1] to [-1, 1]

    def update_z(self, z_mean, z_var_pred):
        """Store the encoder's latent mean and predicted variance."""
        self.z_mean = z_mean
        self.z_var_pred = z_var_pred

    def update_fwd_seed(self, seed):
        """Set the seed passed to the forward operator."""
        self.fwd_seed = seed

    def get_noised_z(self, noise_fn):
        """Return ``noise_fn(z_mean, z_var_pred)`` for the stored latent statistics."""
        return noise_fn(self.z_mean, self.z_var_pred)

    def modify_score(self, z_past, z_next, z0_pred, scale=None):
        """Return ``z_next`` corrected by one data-consistency gradient step.

        Args:
            z_past: latent the gradient is taken with respect to. It must be part of
                the autograd graph that produced ``z0_pred``.
            z_next: latent to be corrected.
            z0_pred: predicted clean latent, decoded and pushed through the operator.
            scale: extra multiplier; required when ``scale_with_alphas`` is True,
                otherwise ignored and treated as 1.0.

        Raises:
            ValueError: if ``scaling_method`` is not one of 'error', 'std', 'identity'.
            RuntimeError: if update_y() has not been called yet.
        """
        # Fail clearly instead of with an UnboundLocalError on `step` below.
        if self.scaling_method not in self._SCALING_METHODS:
            raise ValueError(
                f"modify_score does not support scaling_method={self.scaling_method!r}; "
                f"expected one of {self._SCALING_METHODS}."
            )
        if self.y is None:
            raise RuntimeError('update_y() must be called before modify_score().')
        if self.scale_with_alphas:
            assert scale is not None
        else:
            scale = 1.0
        x0_pred = self.decode_fn(z0_pred)
        # The operator works in [0, 1]; measurements are compared in [-1, 1].
        y_pred = self.fwd_operator(rescale_to_zero_one(x0_pred), self.t * torch.ones(1).to(x0_pred.device), seed=self.fwd_seed)
        y_pred = rescale_to_minusone_one(y_pred)
        error = (self.y - y_pred).pow(2).sum()
        grad = torch.autograd.grad(error, z_past)[0]
        if self.scaling_method == 'error':
            step = self.dc_step / torch.sqrt(error) * grad
        elif self.scaling_method == 'std':
            step = self.dc_step / (self.sigma ** 2) * grad
        else:  # 'identity'
            step = grad
        z_out = z_next - scale * step
        return z_out
    
def _var_to_t(var, model, method, ddim_sampler=None):
    if method == 'ddpm':
        t_pred = [alpha_cumprod <= 1 / (1 + var) for alpha_cumprod in model.alphas_cumprod].index(True)
    elif method == 'ddim':
        assert ddim_sampler is not None
        t_pred = [alpha_cumprod <= 1 / (1 + var) for alpha_cumprod in ddim_sampler.ddim_alphas].index(True)
    return t_pred

def find_ldm_start(model, y, method, var_mult=None, corr_mult=None, ddim_sampler=None, start_time=None):
    """Encode the measurement ``y`` and choose the latent and timestep to start sampling from.

    Args:
        model: exposes ``encode_first_stage(y) -> (z_mean, var)`` and ``q_sample``.
        y: degraded measurement.
        method: 'ddpm' or 'ddim'.
        var_mult: optional multiplier on the predicted variance.
        corr_mult: if given, ``z_mean`` is mixed with Gaussian noise of variance
            ``var * corr_mult`` to form the start latent.
        ddim_sampler: needed to map variance to a DDIM index when ``start_time`` is None.
        start_time: fixed start timestep. When given, the start latent is
            ``model.q_sample(z_mean, start_time)`` and the variance is not used to pick t.

    Returns:
        (z_start, t, z_mean, var)

    Raises:
        ValueError: for an unsupported ``method``.
    """
    if method not in ('ddpm', 'ddim'):
        raise ValueError(f"Unsupported method {method!r}; expected 'ddpm' or 'ddim'.")
    z_mean, var = model.encode_first_stage(y)

    if var_mult is not None:
        var = var * var_mult

    if corr_mult is not None:
        var_corr = var * corr_mult
        z_start = torch.sqrt(1 - var_corr) * z_mean + torch.randn_like(z_mean) * torch.sqrt(var_corr)
    else:
        z_start = z_mean

    if start_time is None:
        t = _var_to_t(var, model, method, ddim_sampler)
    else:
        # NOTE: a corr_mult-noised z_start is discarded here. Its randn_like draw is
        # still made, which keeps RNG consumption identical across configurations.
        t = start_time
        ts = torch.full((1,), t, device=z_mean.device, dtype=torch.long)
        z_start = model.q_sample(z_mean, ts)

    return z_start, t, z_mean, var

@torch.no_grad()
def reconstruct_sample(model, x, method='ddpm', start_time=None, dc_corrector=None, dc_correct_freq=1.0, var_mult=None, corr_mult=None, custom_steps=None, eta=None, ddim=None, reconstructor=None):
    """Reconstruct one degraded sample with the latent diffusion model.

    Args:
        model: latent diffusion model; also provides ``first_stage_model.fwd_operator``.
        x: batch dict with keys "degraded_noisy", "t", "clean" and "fname".
        method: 'ddpm' or 'ddim'.
        start_time, var_mult, corr_mult: forwarded to find_ldm_start.
        dc_corrector: optional data-consistency corrector; its seed is derived from
            the first file name in the batch.
        dc_correct_freq: forwarded to the sampler's latent_reconstruction.
        custom_steps, eta: DDIM step count and eta (DDIM only).
        ddim: DDIMSampler instance, required for method='ddim'.
        reconstructor: accepted for interface compatibility; not used here.

    Returns:
        dict with the predicted variance, start timestep, input/target/recon images,
        the re-degraded reconstruction, severity, sampling wall time and throughput.

    Raises:
        ValueError: for an unsupported method, or method='ddim' without ``ddim``.
    """
    # Validate before any GPU work so misconfiguration fails fast and clearly.
    if method not in ('ddpm', 'ddim'):
        raise ValueError('Unsupported reconstruction method.')
    if method == 'ddim' and ddim is None:
        raise ValueError("A DDIMSampler must be passed as `ddim` when method='ddim'.")

    log = dict()

    degraded_img = x["degraded_noisy"].cuda()
    t = x['t'].cuda()
    clean_img = x["clean"].cuda()

    # Unconditional sampling: no conditioning / classifier-free guidance.
    cond = None
    uc = None
    uc_scale = None

    z_T, start_T, z_mean, var_pred = find_ldm_start(model, degraded_img, method, var_mult, corr_mult, ddim, start_time)
    x_auto_recon = model.decode_first_stage(z_mean)

    shape = [degraded_img.shape[0],
             model.model.diffusion_model.in_channels,
             model.model.diffusion_model.image_size,
             model.model.diffusion_model.image_size]

    if dc_corrector is not None:
        # Deterministic per-image operator seed derived from the file name.
        seed = str2int(x["fname"][0])
        dc_corrector.update_fwd_seed(seed)
        if dc_corrector.scaling_method in ['latent_std', 'latent_error']:
            dc_corrector.update_z(z_mean, var_pred)
        else:
            dc_corrector.update_y(degraded_img, t)

    t0 = time.time()
    model.zero_grad()

    if method == 'ddpm':
        sample, progrow = model.latent_reconstruction(x_T=z_T,
                                                      start_T=start_T,
                                                      shape=shape,
                                                      cond=cond,
                                                      unconditional_guidance_scale=uc_scale,
                                                      unconditional_conditioning=uc,
                                                      verbose=False,
                                                      dc_corrector=dc_corrector,
                                                      dc_correct_freq=dc_correct_freq,
                                                     )
    else:  # 'ddim' (validated above)
        bs = shape[0]
        shape_ddim = shape[1:]
        sample, progrow = ddim.latent_reconstruction(x_T=z_T,
                                                     start_T=start_T,
                                                     S=custom_steps,
                                                     conditioning=cond,
                                                     unconditional_guidance_scale=uc_scale,
                                                     unconditional_conditioning=uc,
                                                     dc_corrector=dc_corrector,
                                                     dc_correct_freq=dc_correct_freq,
                                                     batch_size=bs,
                                                     shape=shape_ddim,
                                                     eta=eta,
                                                     verbose=False,
                                                    )
    t1 = time.time()
    x_recon = model.decode_first_stage(sample)
    y_recon = model.first_stage_model.fwd_operator(x_recon, t)

    log["var_pred"] = var_pred
    log["start_T"] = start_T
    log["degraded_img"] = degraded_img
    log["clean_img"] = clean_img
    log["recon"] = x_recon
    log["x_auto_recon"] = x_auto_recon
    log["y_recon"] = y_recon
    log["severity"] = t
    log["time"] = t1 - t0
    log['throughput'] = sample.shape[0] / (t1 - t0)
    return log

@torch.no_grad()
def run(model,
        data,
        logdir,
        batch_size=1,
        method='ddpm',
        n_samples=5,
        dc_corrector=None,
        dc_correct_freq=1.0,
        custom_steps=None,
        eta=None,
        var_mult=None,
        corr_mult=None,
        reconstructor=None,
        start_time=None,
        expname='recon',
        evaluate=False,
        evaluate_gen_metrics=False,
       ):
    """Reconstruct up to ``n_samples`` items from ``data`` and optionally evaluate them.

    Every reconstruction is written under ``logdir`` via save_logs. A mapping from
    image index to file name is written to ``<logdir>/images.yml``.

    Args:
        batch_size: accepted for interface compatibility; not used here (items are
            processed one at a time as yielded by ``data``).
        Other arguments are forwarded to reconstruct_sample / evaluate_results.

    Returns:
        The metrics dict from evaluate_results when ``evaluate`` is True, else None.
    """
    tstart = time.time()

    if method == 'ddim':
        ddim = DDIMSampler(model)
        ddim.make_schedule(ddim_num_steps=custom_steps, ddim_eta=eta, verbose=False)
    else:
        ddim = None

    fnames = {}

    # Reconstruct
    with model.ema_scope("Reconstruction."):
        for i, item in enumerate(data):
            if i >= n_samples:
                break
            print("{}/{}".format(i+1, n_samples))
            logs = reconstruct_sample(model, item, method=method, dc_corrector=dc_corrector, dc_correct_freq=dc_correct_freq, var_mult=var_mult, corr_mult=corr_mult, custom_steps=custom_steps, eta=eta, ddim=ddim, reconstructor=reconstructor, start_time=start_time)
            fnames[i] = item['fname'][0]
            save_logs(logs, logdir, i, expname)

    # Report what was actually processed; `data` may hold fewer than n_samples items.
    print(f"reconstruction of {len(fnames)} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")

    # Save list of images in root dir. The directory normally exists via save_logs,
    # but not when zero samples were processed.
    Path(logdir).mkdir(parents=True, exist_ok=True)
    with open(Path(logdir) / 'images.yml', 'w') as outfile:
        yaml.dump(fnames, outfile)

    if evaluate:
        results = evaluate_results(logdir, expname, evaluate_gen_metrics, model.device)
    else:
        results = None

    return results
        
@torch.no_grad()
def save_logs(log, log_root, image_id, expname='recon'):
    """Persist one reconstruction's tensors, previews and metadata under ``log_root``.

    For each (folder, log key) pair below, ``<id>.npy`` gets the whole tensor mapped
    from [-1, 1] to [0, 1]. ``<id>.png`` gets its first batch item with dim 0 moved
    last, mapped to [0, 1] and clipped. A ``<id>_recon_info.yml`` with the predicted
    variance, severity, start timestep and wall time goes into ``<log_root>/<expname>``.
    """
    root = Path(log_root)
    # (sub-folder, key in `log`) in the order the files are written.
    outputs = [
        ('target', "clean_img"),
        ('noisy', "degraded_img"),
        (expname, "recon"),
        (expname + '_degraded', "y_recon"),
        (expname + '_auto', "x_auto_recon"),
    ]

    root.mkdir(parents=True, exist_ok=True)
    for folder, _ in outputs:
        (root / folder).mkdir(parents=True, exist_ok=True)

    stem = str(image_id)
    for folder, key in outputs:
        tensor = log[key]
        np.save(str(root / folder / (stem + '.npy')), rescale_to_zero_one(tensor.cpu().numpy()))
        preview = rescale_to_zero_one(tensor[0].permute(1, 2, 0).cpu().numpy()).clip(0.0, 1.0)
        plt.imsave(str(root / folder / (stem + '.png')), preview)

    recon_info = {
        "var_pred": float(log["var_pred"][0].cpu().numpy()),
        "severity": float(log["severity"][0].cpu().numpy()),
        "start_T": log["start_T"],
        "wall_time": log["time"],
    }
    with open(root / expname / (stem + '_recon_info.yml'), 'w') as outfile:
        yaml.dump(recon_info, outfile)

@torch.no_grad()
def evaluate_results(log_root, expname='recon', evaluate_gen_metrics=False, device='cuda:0'):
    """Compute and store aggregate metrics for the files written by save_logs.

    Reconstructions in ``<log_root>/<expname>`` are compared against
    ``<log_root>/target`` with SSIM, PSNR, NMSE and LPIPS. The re-degraded
    reconstructions in ``<log_root>/<expname>_degraded`` are compared against
    ``<log_root>/noisy`` with MSE. Files are paired by image id, and the mean start
    timestep is read from the per-image ``<id>_recon_info.yml``. Optionally adds
    torch_fidelity metrics (FID). Results are written to
    ``<log_root>/<expname>/eval/final_metrics.yml``, printed and returned.

    Raises:
        ValueError: if the folders hold different numbers of files or their ids
            do not line up.
    """
    print('Evaluating results.')
    root = Path(log_root)
    recon_dir = root / expname
    (recon_dir / 'eval').mkdir(parents=True, exist_ok=True)

    target_files = sorted((root / 'target').glob('*.npy'))
    degraded_files = sorted((root / 'noisy').glob('*.npy'))
    recon_files = sorted(recon_dir.glob('*.npy'))
    y_recon_files = sorted((root / (expname + '_degraded')).glob('*.npy'))
    if not (len(target_files) == len(recon_files) == len(y_recon_files) == len(degraded_files)):
        raise ValueError(
            'Mismatched number of files: target={}, recon={}, y_recon={}, noisy={}'.format(
                len(target_files), len(recon_files), len(y_recon_files), len(degraded_files)))

    ssim_arr = []
    psnr_arr = []
    nmse_arr = []
    lpips_arr = []
    mse_dc_arr = []
    start_T_arr = []

    num_images = len(target_files)
    print('Number of images found: ', num_images)
    for target, recon, y_recon, noisy in zip(target_files, recon_files, y_recon_files, degraded_files):
        if not (target.stem == recon.stem == y_recon.stem == noisy.stem):
            raise ValueError('Mismatched image ids: {}, {}, {}, {}'.format(
                target.name, recon.name, y_recon.name, noisy.name))
        # Look the metadata up by id instead of zipping a separately sorted list:
        # '<id>_recon_info.yml' and '<id>.npy' sort differently once ids reach two
        # digits ('10_...' < '1_...' but '1.npy' < '10.npy').
        recon_info = recon_dir / (target.stem + '_recon_info.yml')
        target_arr = load_np_to_tensor(target, device)
        recon_arr = load_np_to_tensor(recon, device).clip(0.0, 1.0)
        y_recon_arr = load_np_to_tensor(y_recon, device).clip(0.0, 1.0)
        noisy_arr = load_np_to_tensor(noisy, device)
        start_T = load_config_from_yaml(recon_info)['start_T']

        ssim_arr.append(ssim(target_arr, recon_arr).cpu().numpy())
        psnr_arr.append(psnr(target_arr, recon_arr).cpu().numpy())
        nmse_arr.append(nmse(target_arr, recon_arr).cpu().numpy())
        lpips_arr.append(lpips(target_arr, recon_arr).cpu().numpy())
        mse_dc_arr.append(mse(y_recon_arr, noisy_arr).cpu().numpy())
        start_T_arr.append(float(start_T))

    # Compute generative/distributional metrics
    if evaluate_gen_metrics:
        recon_path = str(recon_dir)
        target_path = str(root / 'target')
        gen_res = torch_fidelity.calculate_metrics(input1=recon_path,
                                               input2=target_path,
                                               cuda=True,
                                               isc=False,
                                               fid=True,
                                               kid=False,
                                               verbose=False,
                                              )

    # Aggregate results
    results = {'ssim': float(np.array(ssim_arr).mean()),
               'psnr': float(np.array(psnr_arr).mean()),
               'nmse': float(np.array(nmse_arr).mean()),
               'lpips': float(np.array(lpips_arr).mean()),
               'dc_mse': float(np.array(mse_dc_arr).mean()),
               'start_T_mean': float(np.array(start_T_arr).mean()),
              }

    if evaluate_gen_metrics:
        for k, v in gen_res.items():
            results[k] = v

    with open(recon_dir / 'eval' / 'final_metrics.yml', 'w') as outfile:
        yaml.dump(results, outfile)

    for k, v in results.items():
        print('{}:{}'.format(k, v))

    return results
        
    
def get_parser():
    """Build the CLI parser for this reconstruction script.

    Options: --logdir, --expname (default 'recon'), --recon_config_path,
    the --evaluate / --evaluate_gen_metrics switches, and --batch_run_key
    (one or more keys naming a list-valued config entry to sweep over).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--logdir", type=str, help="Logging directory.")
    parser.add_argument(
        "--expname",
        type=str,
        default='recon',
        help="Experiment name. Reconstructions will be saved under this folder.",
    )
    parser.add_argument("--recon_config_path", type=str, help="Path to config file.")

    switches = (
        ("--evaluate", "Set to evaluate reconstruction results."),
        ("--evaluate_gen_metrics", "Set to evaluate generative metrics, such as FID."),
    )
    for flag, description in switches:
        parser.add_argument(flag, action='store_true', default=False, help=description)

    parser.add_argument(
        '--batch_run_key',
        default=None,
        nargs='+',
        type=str,
        help='Run batch of experiments over the key specified here, where the values are read from the config file. The corresponding values must be in a list in the config.',
    )
    return parser


def load_model_from_config(config, sd):
    """Instantiate the model from ``config``, optionally load ``sd``, move to GPU, set eval mode.

    Args:
        config: model config passed to instantiate_from_config.
        sd: state dict to load non-strictly (missing/unexpected keys are tolerated),
            or None to keep the freshly instantiated weights. load_model passes None
            when no checkpoint path is given.
    """
    model = instantiate_from_config(config)
    # load_state_dict(None) raises, so only load when weights were actually provided.
    if sd is not None:
        model.load_state_dict(sd, strict=False)
    model.cuda()
    model.eval()
    return model

def load_model(config, ckpt):
    """Build ``config.model`` and load weights from the checkpoint at ``ckpt`` (if any).

    Returns:
        (model, global_step): ``global_step`` is the checkpoint's "global_step" entry
        when present, otherwise None (also None when ``ckpt`` is falsy).
    """
    if ckpt:
        print(f"Loading model from {ckpt}")
        # NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
        pl_sd = torch.load(ckpt, map_location="cpu")
    else:
        pl_sd = {"state_dict": None}
    global_step = pl_sd.get("global_step")
    model = load_model_from_config(config.model,
                                   pl_sd["state_dict"])

    return model, global_step

if __name__ == "__main__":
    sys.path.append(os.getcwd())

    # Load config files
    parser = get_parser()
    opt, unknown = parser.parse_known_args()
    recon_config = load_config_from_yaml(opt.recon_config_path)
    # All runs share one output root. Bind it before the loop so the summary at
    # the end never refers to an unbound name (e.g. when the sweep list is empty).
    logdir = opt.logdir

    # Handle hparam sweeps: one config, or one config per value of the list
    # stored under --batch_run_key.
    recon_configs = []
    results = {}
    if opt.batch_run_key is None:
        recon_configs.append((recon_config, opt.expname))
        print('Running a single config for reconstruction.')
    else:
        hparams = nested_get(recon_config, opt.batch_run_key)
        expname_key = opt.expname
        for key in opt.batch_run_key:
            expname_key += '_'
            expname_key += key
        for hparam in hparams:
            recon_config_single = copy.deepcopy(recon_config)
            nested_set(recon_config_single, opt.batch_run_key, hparam)
            exp_id = expname_key + '_' + str(hparam)
            recon_configs.append((recon_config_single, exp_id))
        print('Running a batch of {} configs:'.format(len(hparams)))
        for _, exp in recon_configs:
            print(exp)

    for i, (r_config, name) in enumerate(recon_configs):
        # Re-seed for every run so each sweep entry starts from identical RNG state.
        g = torch.Generator()
        g.manual_seed(0)
        np.random.seed(0)
        random.seed(0)
        torch.cuda.manual_seed(0)
        print('Starting {}/{}: {}'.format(i+1, len(recon_configs), name))
        opt.expname = name
        opt.resume = r_config['model']['ldm']
        ckpt = r_config['model']['ldm']
        model_dir = '/'.join(ckpt.split('/')[:-1])

        model_configs = sorted(glob.glob(os.path.join(model_dir, "config.yaml")))

        configs = [OmegaConf.load(cfg) for cfg in model_configs]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)

        print('Main configs: {}, reconstruction configs: {}'.format(config, r_config))

        # Set up model
        recon_method = r_config['reconstruction']['method']
        degradation_config = load_config_from_yaml(r_config['data']['degradation_config'])
        fwd_operator = create_operator(degradation_config['operator'])
        noise_schedule = create_noise_schedule(degradation_config['noise'])
        model, _ = load_model(config, ckpt)
        if recon_method not in ['ccdf-l']:
            severity_encoder = SeverityEncoderModule.load_from_checkpoint(
                r_config['model']['severity_encoder_ckpt'],
                ldm_model_ckpt_path=r_config['model']['ldm'],
                # The LDM config sits next to its checkpoint (same file as loaded
                # above). Note: str.rstrip removes a *set of characters*, not a
                # suffix, so it must not be used to drop the checkpoint file name.
                ldm_model_config_path=os.path.join(model_dir, 'config.yaml'),
                lr_milestones=None,
                dc_reg=0.0,
                img_space_reg=0.0,
            ).to(model.device)
            severity_encoder.eval()
            model.first_stage_model = severity_encoder
        else:
            model.first_stage_model.fwd_operator = fwd_operator

        # Set up dataset
        dataset_config = load_config_from_yaml('configs/data_configs/dataset_config.yaml')
        dataset_key = r_config['data']['dataset']
        dataset_path = dataset_config[dataset_key]['path']
        dataset_class = dataset_config[dataset_key]['dataset_class']

        # A severity outside [0, 1] means "sample severities" rather than fix one.
        if 0.0 <= r_config['data']['fixed_severity'] <= 1.0:
            fixed_t = r_config['data']['fixed_severity']
        else:
            fixed_t = None
        data_transform = ImageDataTransform(
            is_train=False,
            operator_schedule=degradation_config['operator'],
            noise_schedule=degradation_config['noise'],
            fixed_t=fixed_t,
        )

        num_images_per_class = {}
        if 'num_images_per_class' in r_config['data']:
            num_images_per_class = {'num_images_per_class' : r_config['data']['num_images_per_class']}

        dataset = get_obj_from_str(dataset_class)(
            root=dataset_path,
            split=r_config['data']['split'],
            transform=data_transform,
            **num_images_per_class,
            )

        def seed_worker(worker_id):
            worker_seed = torch.initial_seed() % 2**32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        dataloader = torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=1,
                num_workers=4,
                generator=g,
                worker_init_fn=seed_worker,
            )

        # Set up data consistency correction
        if r_config['data_consistency']['dc_step'] > 0.0:
            dc_corrector = LatentDataConsistency(
                fwd_operator=fwd_operator,
                noise_schedule=noise_schedule,
                encode_fn=lambda x: model.differentiable_encode_first_stage(x)[0],
                decode_fn=lambda x: model.differentiable_decode_first_stage(x),
                dc_step=r_config['data_consistency']['dc_step'],
                scaling_method=r_config['data_consistency']['scaling_method'],
                scale_with_alphas=r_config['data_consistency']['scale_with_alphas'],
            )
            dc_correct_freq = r_config['data_consistency']['dc_correct_freq'] if 'dc_correct_freq' in r_config['data_consistency'] else 1.0
            # Record the effective value in the config dumped below.
            r_config['data_consistency']['dc_correct_freq'] = dc_correct_freq
        else:
            dc_corrector = None
            dc_correct_freq = 1.0

        # Set up SwinIR reconstructor, if using CCDF-L
        if r_config['reconstruction']['method'] == 'ccdf-l':
            # NOTE(review): SwinIRModule does not appear among this script's imports,
            # so this line raises NameError as written — add the proper import
            # before using CCDF-L.
            reconstructor = SwinIRModule.load_from_checkpoint(r_config['reconstruction']['recon_model_ckpt']).to(model.device)
            start_time = r_config['reconstruction']['start_time']
        else:
            reconstructor = None
            if 'start_time' in r_config['reconstruction']:
                start_time = r_config['reconstruction']['start_time']
            else:
                start_time = None

        print("logging to:", logdir)
        os.makedirs(logdir, exist_ok=True)
        print(75 * "=")

        # Save reconstruction config
        sampling_file = os.path.join(logdir, opt.expname + "_config.yaml")

        with open(sampling_file, 'w') as f:
            yaml.dump(r_config, f, default_flow_style=False)

        # Run reconstruction
        result = run(model=model,
            logdir=logdir,
            data=dataloader,
            n_samples=r_config['data']['num_images'],
            batch_size=1,
            method=recon_method,
            dc_corrector=dc_corrector,
            dc_correct_freq=dc_correct_freq,
            custom_steps=r_config['reconstruction']['ddim_steps'] if recon_method == 'ddim' else None,
            eta=r_config['reconstruction']['ddim_eta'] if recon_method == 'ddim' else None,
            var_mult=r_config['reconstruction']['var_mult'],
            corr_mult=r_config['reconstruction']['corr_mult'],
            reconstructor=reconstructor,
            start_time=start_time,
            expname=opt.expname,
            evaluate=opt.evaluate,
            evaluate_gen_metrics=opt.evaluate_gen_metrics,
           )
        results[name] = result
        print('Finished ', name)

    for name, result in results.items():
        print(name)
        # run() returns None unless --evaluate is set; nothing to print then.
        if result is None:
            continue
        for k, v in result.items():
            print('{}: {}'.format(k, v))

    results_file = os.path.join(logdir, "results_summary.yaml")
    with open(results_file, 'w') as f:
        yaml.dump(results, f, default_flow_style=False)
    print("done.")
