"""
Class-conditional image translation from one ImageNet class to another.
"""
import torch
import argparse
import os
import logging
import shutil
import accelerate
from pathlib import Path

import blobfile as bf
import numpy as np
import torch as th
import torch.distributed as dist
import torch.nn.functional as F
from torch import multiprocessing as mp
from PIL import Image

from guided_diffusion import logger
from guided_diffusion.image_datasets import (
    load_source_data_for_domain_translation,
    get_image_filenames_for_label
)
from guided_diffusion.script_util import (
    model_and_diffusion_defaults,
    classifier_defaults,
    create_model_and_diffusion,
    create_classifier,
    add_dict_to_argparser,
    args_to_dict,
)

## Copy the source-class images out of the ImageNet split directory into out_dir.
def copy_imagenet_dataset(imagenet_dir, val_dir, out_dir, classes):
    """Copy every image of the given ImageNet classes into ``out_dir``.

    The entries of ``<imagenet_dir>/<val_dir>`` (e.g.
    ``/raid/common/imagenet-raw/train``) are sorted by name. Class label ``k``
    selects the ``k``-th entry, which assumes that directory holds only one
    folder per class. Each image in that folder is copied to
    ``<out_dir>/<label>_<n>.PNG``, where ``n`` counts from 1 in sorted
    filename order.

    The bytes are copied unchanged (``shutil.copyfile``). The ``.PNG`` suffix
    is only a naming convention, not a format conversion.

    Args:
        imagenet_dir: dataset root directory.
        val_dir: split subdirectory of ``imagenet_dir`` holding the class folders.
        out_dir: destination directory; created with parents if missing.
        classes: iterable of integer class labels to copy.

    Raises:
        IndexError: if a label is outside ``[0, number of entries in the split dir)``.
    """
    Path(out_dir).mkdir(parents=True, exist_ok=True)
    data_dir = os.path.join(imagenet_dir, val_dir)
    class_dirs = sorted(bf.listdir(data_dir))
    for source_label in classes:
        logger.log(f"Copying image files for class {source_label}.")
        # Reject negatives explicitly: plain list indexing would silently
        # pick a class from the end of the list.
        if not 0 <= source_label < len(class_dirs):
            raise IndexError(
                f"class label {source_label} is out of range: {data_dir} has "
                f"{len(class_dirs)} entries"
            )
        class_dir = os.path.join(data_dir, class_dirs[source_label])
        # listdir order is arbitrary. Sort so the "<label>_<n>" numbering is
        # reproducible across runs and filesystems.
        for i, filename in enumerate(sorted(bf.listdir(class_dir)), start=1):
            shutil.copyfile(
                os.path.join(class_dir, filename),
                os.path.join(out_dir, f"{source_label}_{i}.PNG"),
            )


def _parse_class_mapping(source_spec, target_spec):
    """Parse paired comma-separated class labels into a source -> target map.

    Args:
        source_spec: comma-separated integer labels, e.g. ``"1,2"``.
        target_spec: comma-separated integer labels, paired positionally.

    Returns:
        ``(source_labels, mapping)``, where ``mapping[s]`` is the target label
        at the same position as source label ``s``.

    Raises:
        ValueError: if an entry is not an integer, or the lists differ in
            length. ``zip`` would otherwise drop the extras silently, and the
            translation loop would later fail with a ``KeyError``.
    """
    source = [int(v) for v in source_spec.split(",")]
    target = [int(v) for v in target_spec.split(",")]
    if len(source) != len(target):
        raise ValueError(
            f"sample.source lists {len(source)} labels but sample.target lists "
            f"{len(target)}; they must pair up one-to-one"
        )
    return source, dict(zip(source, target))


def _to_uint8_nhwc(x):
    """Map a (N, C, H, W) tensor from [-1, 1] to a uint8 (N, H, W, C) numpy array."""
    x = ((x + 1) * 127.5).clamp(0, 255).to(th.uint8)
    return x.permute(0, 2, 3, 1).contiguous().cpu().numpy()


def _translated_path(out_dir, source_path, target_label):
    """Return ``<out_dir>/<stem>_translated_<target_label><ext>`` for ``source_path``."""
    stem, ext = os.path.splitext(os.path.basename(source_path))
    return os.path.join(out_dir, f"{stem}_translated_{target_label}{ext}")


def evaluate(config):
    """Translate ImageNet images from source classes to paired target classes.

    Steps:
      1. Load the diffusion model and the time-conditioned classifier used for
         guidance.
      2. On the main process only, copy the source-class images into
         ``<config.output_dir>/source``. Each loaded batch is then re-saved to
         its ``extra["filepath"]``.
      3. Encode each batch to a latent with the model conditioned on the
         source class. Then decode it conditioned on the target class, with
         classifier guidance, using either the ODE (DDIM) or the SDE loops of
         ``diffusion``.
      4. Save results to ``<config.output_dir>/translated`` as
         ``<stem>_translated_<target><ext>``.

    Args:
        config: ml_collections config. The fields read here are ``seed``,
            ``model_path``, ``classifier_path``, ``output_dir``,
            ``imagenet_path``, ``val_path``, ``net_diffusion``, ``classifier``
            and ``sample``. From ``sample`` it reads ``source``, ``target``,
            ``batch_size``, ``algorithm``, ``classifier_scale``,
            ``clip_denoised`` and ``eta``.

    Raises:
        ValueError: if ``config.sample.algorithm`` is neither ``"ode"`` nor
            ``"sde"``, or if ``sample.source`` and ``sample.target`` differ
            in length.
    """
    # Raises RuntimeError if a start method was already set in this process.
    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()

    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    if accelerator.is_main_process:
        logger.log(f"arguments: {config}")
        logging.info(f'Process {accelerator.process_index} using device: {device}')

    # Validate the config before the expensive work. An unknown algorithm used
    # to surface only as an UnboundLocalError on `sample`, after the networks
    # were loaded and the dataset copied.
    algorithm = config.sample.algorithm
    if algorithm not in ("ode", "sde"):
        raise ValueError(
            f"unknown sample.algorithm {algorithm!r}; expected 'ode' or 'sde'"
        )
    source, source_to_target_mapping = _parse_class_mapping(
        config.sample.source, config.sample.target
    )

    logger.configure()

    logger.log("creating model and diffusion...")
    model, diffusion = create_model_and_diffusion(**config.net_diffusion)
    # Load on the CPU first, so every process does not materialize the
    # checkpoint on the GPU it was saved from before moving it.
    model.load_state_dict(torch.load(config.model_path, map_location="cpu"))
    model.to(device)
    if config.net_diffusion.use_fp16:
        model.convert_to_fp16()
    model.eval()

    logger.log("loading classifier...")
    classifier = create_classifier(**config.classifier)
    classifier.load_state_dict(torch.load(config.classifier_path, map_location="cpu"))
    classifier.to(device)
    if config.classifier.classifier_use_fp16:
        classifier.convert_to_fp16()
    classifier.eval()
    logger.info("Loading all networks sucessfully")

    def cond_fn(x, t, y=None):
        """Gradient of log p(y | x, t) w.r.t. x, scaled by sample.classifier_scale."""
        assert y is not None
        with th.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = classifier(x_in, t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), y.view(-1)]
            return th.autograd.grad(selected.sum(), x_in)[0] * config.sample.classifier_scale

    def model_fn(x, t, y=None):
        """Run the diffusion model, passing ``y`` only if it is class-conditional."""
        assert y is not None
        return model(x, t, y if config.net_diffusion.class_cond else None)

    logger.log("copying source dataset.")
    output_final_dir = os.path.join(config.output_dir, "source")
    translated_dir = os.path.join(config.output_dir, "translated")
    os.makedirs(output_final_dir, exist_ok=True)
    os.makedirs(translated_dir, exist_ok=True)
    # Copy from one process only. Concurrent copies onto the same destination
    # paths could let another process's loader read a half-written file.
    if accelerator.is_main_process:
        copy_imagenet_dataset(config.imagenet_path, config.val_path, output_final_dir, source)
    accelerator.wait_for_everyone()

    logger.log("running image translation...")
    dataloader = load_source_data_for_domain_translation(
        batch_size=config.sample.batch_size,
        image_size=config.net_diffusion.image_size,
        data_dir=output_final_dir,
    )
    model, classifier, dataloader = accelerator.prepare(model, classifier, dataloader)

    for i, (batch, extra) in enumerate(dataloader):
        logger.log(f"translating batch {i}, shape {batch.shape}.")
        logger.log(f"translating batch {i}, mean {batch.mean()}.")
        logger.log("saving the original, cropped images.")
        for image, filepath in zip(_to_uint8_nhwc(batch), extra["filepath"]):
            Image.fromarray(image).save(filepath)
            logger.log(f"    saving: {filepath}")

        batch = batch.to(device)
        # Use the real batch shape rather than config.sample.batch_size. The
        # last batch can be smaller, and the sampling loops presumably size
        # per-step tensors from shape[0]; upstream guided_diffusion's DDIM loop
        # does. TODO: confirm for sde_sample_loop.
        shape = tuple(batch.shape)

        # Class labels for the source (encode) and target (decode) passes.
        source_y = dict(y=extra["y"].to(device))
        target_y_list = [source_to_target_mapping[v.item()] for v in extra["y"]]
        target_y = dict(y=th.tensor(target_y_list).to(device))

        logger.log("encoding the source images.")
        if algorithm == 'ode':
            # Invert to a latent with the source-class conditioning.
            noise = diffusion.ddim_reverse_sample_loop(
                model_fn,
                batch,
                clip_denoised=False,
                model_kwargs=source_y,
                device=device,
            )
            logger.log(f"obtained latent representation for {batch.shape[0]} samples...")
            logger.log(f"latent with mean {noise.mean()} and std {noise.std()}")

            # Decode the latent toward the target class, with classifier guidance.
            sample = diffusion.ddim_sample_loop(
                model_fn,
                shape,
                noise=noise,
                clip_denoised=config.sample.clip_denoised,
                model_kwargs=target_y,
                cond_fn=cond_fn,
                device=device,
                eta=config.sample.eta,
            )
        else:  # algorithm == 'sde' (validated above)
            noise, noise_list = diffusion.sde_reverse_sample_loop(
                model_fn,
                batch,
                clip_denoised=False,
                model_kwargs=source_y,
                device=device,
            )
            logger.log(f"obtained latent representation for {batch.shape[0]} samples...")
            logger.log(f"latent with mean {noise.mean()} and std {noise.std()}")
            logger.log(f"the length of noise list is {len(noise_list)}")

            # Decode toward the target class. The SDE sampler also takes the
            # noise_list produced by the reverse pass.
            sample = diffusion.sde_sample_loop(
                model_fn,
                shape,
                noise=noise,
                clip_denoised=config.sample.clip_denoised,
                model_kwargs=target_y,
                cond_fn=cond_fn,
                device=device,
                eta=config.sample.eta,
                sde=True,
                noise_list=noise_list,
            )

        images = _to_uint8_nhwc(sample)
        logger.log(f"created {images.shape[0]} samples")

        logger.log("saving translated images.")
        for index, image in enumerate(images):
            filepath = _translated_path(
                translated_dir, extra["filepath"][index], target_y_list[index]
            )
            Image.fromarray(image).save(filepath)
            logger.log(f"    saving: {filepath}")

    # accelerate's barrier is a no-op in a single plain process. In that case
    # torch.distributed.barrier() raises, because no process group exists.
    accelerator.wait_for_everyone()
    logger.log("domain translation complete")

from absl import flags
from absl import app
from ml_collections import config_flags
import os

# Command-line flags (absl). `--config` is an ml_collections config file and is
# required. lock_config=False leaves the loaded ConfigDict unlocked, so main()
# can add new fields such as nnet_path and output_path to it.
FLAGS = flags.FLAGS
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=False)
flags.mark_flags_as_required(["config"])
# NOTE(review): main() copies these onto config.nnet_path / config.output_path,
# but evaluate() reads config.model_path / config.output_dir. Confirm these
# flags still have an effect.
flags.DEFINE_string("nnet_path", None, "The nnet to evaluate.")
flags.DEFINE_string("output_path", None, "The path to output log.")


def main(argv):
    """absl entry point: attach CLI flag values to the config, then run evaluate().

    ``argv`` holds the leftover positional arguments from absl and is unused.

    NOTE(review): evaluate() reads ``config.model_path`` and ``config.output_dir``.
    No code visible in this file reads the ``nnet_path`` / ``output_path``
    fields set here. Confirm they are still needed.
    """
    del argv  # unused
    cfg = FLAGS.config
    for flag_name in ("nnet_path", "output_path"):
        setattr(cfg, flag_name, getattr(FLAGS, flag_name))
    evaluate(cfg)


if __name__ == "__main__":
    app.run(main)
