
import gc
import os
import io

import numpy as np
import tensorflow as tf
import tensorflow_gan as tfgan
import logging
from torch import nn
from torch.nn import functional as F
# Keep the import below for registering all model definitions
from models import ncsnpp, unet_classifier, cond_ncsnpp
import sampling
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import datasets
import losses
import sde_lib
import torch
from torch.utils import tensorboard
from torchvision.utils import make_grid, save_image
from utils import save_checkpoint, restore_checkpoint
import evaluation
import run_class as run_lib
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
import logging
from tqdm import trange, tqdm
from pathlib import Path

FLAGS = flags.FLAGS

# --config: path to an ml_collections config file; parsed into FLAGS.config.
# lock_config=True keeps the config's structure fixed (existing fields may
# still be assigned, which ClassfierGuidance relies on for `model.name`).
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)
# --eval_folder: root directory; samples go to <eval_folder>/<class_idx>/.
flags.DEFINE_string("eval_folder", "eval",
                    "The folder name for storing evaluation results")
# Only --config is required. --eval_folder has a non-None default ("eval"),
# so marking it required could never fail (absl warns about exactly that).
flags.mark_flags_as_required(["config"])


class ClassfierGuidance(nn.Module):
    """Score model wrapper that optionally adds classifier guidance.

    The unconditional score comes from the model named ``config.model.score``.
    When ``config.sampling.class_guidance`` is set, the gradient of the
    classifier's log-probability of the target labels ``self.y`` with respect
    to the input is added, scaled by ``config.model.scale``.

    ``self.y`` must be assigned (one label per batch element) before calling
    the module with guidance enabled. The class name keeps its historical
    spelling because callers refer to it.
    """

    def __init__(self, config):
        """Build both networks and load their EMA weights from checkpoints.

        Builds and restores the score model, plus the classifier if guidance
        is enabled.

        Args:
          config: ml_collections config. ``config.model.name`` is overwritten,
            first with ``config.model.score`` and then, if guidance is
            enabled, with ``config.model.classifier``. This makes
            ``mutils.create_model`` build the right architecture each time.
        """
        super(ClassfierGuidance, self).__init__()

        self.score_model = self._load_ema_model(
            config, config.model.score, config.model.score_path)

        if config.sampling.class_guidance:
            self.classifier = self._load_ema_model(
                config, config.model.classifier, config.model.class_path)

        self.config = config

        self.scale = config.model.scale
        print(f'Classifier guidance with scale {self.scale}')
        self.y = None  # target class labels; set by the caller before sampling
        self.num_classes = config.data.num_classes

        gc.collect()

    @staticmethod
    def _load_ema_model(config, model_name, ckpt_path):
        """Create ``model_name``, restore ``ckpt_path`` and copy in its EMA weights."""
        config.model.name = model_name
        model = mutils.create_model(config)
        # The optimizer and EMA are built so the checkpoint state dict has
        # matching slots to restore into.
        optimizer = losses.get_optimizer(config, model.parameters())
        ema = ExponentialMovingAverage(model.parameters(), decay=config.model.ema_rate)
        state = dict(optimizer=optimizer, model=model, ema=ema, step=0)
        # The return value was never used here. Presumably restore_checkpoint
        # loads into the objects held by `state` -- confirm against utils.
        restore_checkpoint(ckpt_path, state, device=config.device)
        ema.copy_to(model.parameters())
        return model

    def forward(self, x, time_cond):
        """Return the score for ``x`` at ``time_cond``, guided if enabled.

        Args:
          x: batch of noisy inputs fed to both networks.
          time_cond: time/noise conditioning passed unchanged to both networks.

        Returns:
          ``score_model(x, time_cond)``. With guidance enabled, the result
          also includes ``scale * d log p(y | x) / dx`` from the classifier.

        Raises:
          ValueError: if guidance is enabled but ``self.y`` is unset.
        """
        uncond = self.score_model(x, time_cond)
        if not self.config.sampling.class_guidance:
            return uncond

        if self.y is None:
            raise ValueError(
                "Set `y` (target class labels) before sampling with classifier guidance.")

        # Re-enable autograd locally so this works even when the caller
        # (e.g. the sampler) has disabled gradient tracking.
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            logits = self.classifier(x_in, time_cond)
            log_probs = F.log_softmax(logits, dim=-1)
            # A one-hot grad_outputs selects log p(y_i | x_i) for each sample.
            # Cast explicitly so it matches log_probs' floating dtype instead
            # of relying on autograd's implicit int64 -> float conversion.
            selector = F.one_hot(self.y, num_classes=self.num_classes).to(log_probs.dtype)
            cond = torch.autograd.grad(log_probs, x_in, grad_outputs=selector)[0]

        return uncond + cond * self.scale

def _get_sde(config):
    """Return ``(sde, sampling_eps)`` for the SDE named by ``config.training.sde``.

    Raises:
      NotImplementedError: if the name is not vpsde, subvpsde or vesde.
    """
    sde_name = config.training.sde.lower()
    if sde_name == 'vpsde':
        sde = sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max,
                            N=config.model.num_scales)
        return sde, 1e-3
    if sde_name == 'subvpsde':
        sde = sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max,
                               N=config.model.num_scales)
        return sde, 1e-3
    if sde_name == 'vesde':
        sde = sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max,
                            N=config.model.num_scales)
        return sde, 1e-5
    raise NotImplementedError(f"SDE {config.training.sde} unknown.")


def _save_samples(samples, class_dir, round_idx, nrow):
    """Write one batch of samples as ``sample_<round_idx>.png`` and ``.npz``.

    Args:
      samples: image batch laid out (N, C, H, W). It is tiled by `make_grid`
        and permuted to NHWC for the NPZ. Values are presumably in [0, 1]
        (they are scaled by 255 and clipped) -- confirm against the sampler.
      class_dir: output directory for this class.
      round_idx: index used in the output filenames.
      nrow: number of images per row in the PNG grid.
    """
    image_grid = make_grid(samples, nrow, padding=2)
    with tf.io.gfile.GFile(os.path.join(class_dir, f'sample_{round_idx}.png'), "wb") as fout:
        save_image(image_grid, fout)

    samples_uint8 = np.clip(
        samples.permute(0, 2, 3, 1).cpu().numpy() * 255., 0, 255).astype(np.uint8)
    with tf.io.gfile.GFile(os.path.join(class_dir, f'sample_{round_idx}.npz'), "wb") as fout:
        io_buffer = io.BytesIO()
        np.savez_compressed(io_buffer, samples=samples_uint8)
        fout.write(io_buffer.getvalue())


def main(argv):
    """Generate class-conditional samples with classifier guidance.

    Outputs go to ``<eval_folder>/<class_idx>/sample_<round>.{png,npz}``.

    If ``config.eval.num_samples`` is smaller than ``config.eval.batch_size``,
    several classes share each sampling batch. Each class then gets
    ``num_samples`` images, all written as ``sample_0``.

    Otherwise each class is sampled in full batches. The rounds run from
    ``config.sampling.start_round`` up to
    ``ceil(num_samples / batch_size)``.

    Args:
      argv: remaining command-line arguments from absl; unused.

    Raises:
      NotImplementedError: for an unknown ``config.training.sde``.
      ValueError: if the multi-class batch layout does not divide evenly.
    """
    config = FLAGS.config
    eval_folder = FLAGS.eval_folder
    for i in range(config.data.num_classes):
        Path(eval_folder, str(i)).mkdir(parents=True, exist_ok=True)

    sde, sampling_eps = _get_sde(config)

    model = ClassfierGuidance(config)
    model.eval()
    inverse_scaler = datasets.get_data_inverse_scaler(config)
    sampling_shape = (config.eval.batch_size,
                      config.data.num_channels,
                      config.data.image_size, config.data.image_size)
    sampling_fn = sampling.get_sampling_fn(config, sde, sampling_shape, inverse_scaler, sampling_eps)

    batch_size = config.eval.batch_size
    num_samples = config.eval.num_samples
    num_classes = config.data.num_classes

    if num_samples < batch_size:
        # `class_per_round` classes x `num_samples` images fill exactly one
        # batch. Validate with exceptions (not asserts, which -O strips).
        if num_samples <= 0 or batch_size % num_samples:
            raise ValueError(
                f"eval.batch_size ({batch_size}) must be a positive multiple of "
                f"eval.num_samples ({num_samples}).")
        class_per_round = batch_size // num_samples
        if num_classes % class_per_round:
            raise ValueError(
                f"data.num_classes ({num_classes}) must be divisible by "
                f"eval.batch_size // eval.num_samples ({class_per_round}).")
        num_sampling_rounds = num_classes // class_per_round
        nrow = int(num_samples ** 0.5)

        print('Multi-class batch')
        for r in trange(num_sampling_rounds):
            first_class = r * class_per_round
            # Labels are [c0]*num_samples + [c1]*num_samples + ..., matching
            # the contiguous per-class slices taken below.
            model.y = torch.arange(
                first_class, first_class + class_per_round, device=config.device
            ).repeat_interleave(num_samples)
            samples, _ = sampling_fn(model)

            for i in range(class_per_round):
                class_idx = first_class + i
                cur_samples = samples[i * num_samples:(i + 1) * num_samples]
                _save_samples(cur_samples, os.path.join(eval_folder, str(class_idx)), 0, nrow)
                gc.collect()
    else:
        # Ceil division: every round produces a full batch.
        num_sampling_rounds = num_samples // batch_size
        if num_samples % batch_size:
            num_sampling_rounds += 1
        nrow = int(batch_size ** 0.5)

        for class_idx in range(num_classes):
            class_dir = os.path.join(eval_folder, str(class_idx))
            for r in trange(config.sampling.start_round, num_sampling_rounds):
                model.y = torch.tensor([class_idx] * batch_size, device=config.device)
                samples, _ = sampling_fn(model)
                _save_samples(samples, class_dir, r, nrow)
                gc.collect()


if __name__ == "__main__":
    app.run(main)
