import gc
import io
import os
import time

import os

import numpy as np
import logging
# Keep the import below for registering all model definitions
from models import ddpm, ncsnv2, ncsnpp, classifier
from models import utils as mutils
from models.ema import ExponentialMovingAverage
import datasets2 as datasets
# import evaluation
import sde_lib
from absl import flags
import torch
from torch.utils import tensorboard
from torchvision.utils import make_grid, save_image
from utils import save_checkpoint, restore_checkpoint

import pytorch_lightning as pl
import score_model
from pytorch_lightning.plugins import DDPPlugin

from pytorch_lightning.callbacks import ModelCheckpoint

import torch.nn.functional as F
import pickle
import ml_collections
import random

from matplotlib import pyplot as plt
import argparse
import torch.distributed as dist
# Global torch performance knobs for this sampling script.
# Permit lower-precision internal arithmetic for float32 matrix multiplies
# ('medium' level) in exchange for speed.
torch.set_float32_matmul_precision('medium')
# Let cuDNN benchmark and pick convolution algorithms (fastest when input
# shapes stay fixed between calls).
torch.backends.cudnn.benchmark = True

def get_model(path):
    """Load a trained ScoreModel and its config from ``workdirs/<path>``.

    Reads ``config.pkl`` and ``last.ckpt`` from the work directory and restores
    the checkpoint's ``state_dict`` into a freshly built ``ScoreModel``. It then
    copies the EMA weights of the score network and the classifier into their
    live parameters.

    Args:
        path: Sub-directory name under ``workdirs/``.

    Returns:
        ``(model, config)``. ``model`` is on CUDA in eval mode. ``config`` is
        the result of calling ``.unlock()`` on the unpickled config.
    """
    workdir = f"workdirs/{path}"
    # NOTE: pickle.load and torch.load can execute arbitrary code embedded in
    # the files; only point this at trusted work directories.
    with open(f"{workdir}/config.pkl", "rb") as f:  # close the handle promptly
        config = pickle.load(f).unlock()
    ckpt_path = f"{workdir}/last.ckpt"
    model = score_model.ScoreModel(config, workdir).cuda()
    sd = torch.load(ckpt_path)
    model.load_state_dict(sd["state_dict"])
    # Presumably restores state stored outside state_dict (e.g. the EMA
    # trackers) -- confirm against score_model.ScoreModel.on_load_checkpoint.
    model.on_load_checkpoint(sd)
    # Sample with the EMA-averaged weights rather than the raw trained ones.
    model.ema.copy_to(model.score_model.parameters())
    model.clf_ema.copy_to(model.clf.parameters())
    # .cuda() again is a no-op for modules already on the GPU; kept in case
    # on_load_checkpoint attached anything new on the CPU.
    model = model.cuda().eval()

    return model, config


def count_existing_batches(sample_dir, run_id):
    """Return how many ``.npz`` sample batches run ``run_id`` already wrote.

    Batches are saved as ``samples_{batch}_{run_id}.npz``. Matching the full
    ``_{run_id}.npz`` suffix, rather than just ``{run_id}.npz``, keeps run 1
    from also counting files written by runs 11, 21, ...
    """
    suffix = f"_{run_id}.npz"
    return sum(
        1
        for name in os.listdir(sample_dir)
        if name.startswith("samples_") and name.endswith(suffix)
    )


def to_uint8_images(x, image_size, num_channels):
    """Convert a 4-D float image tensor into a channels-last uint8 array.

    Axis 1 is moved last (presumably (N, C, H, W) -> (N, H, W, C)). Values are
    scaled by 255, clipped to [0, 255] and truncated to uint8. The result is
    reshaped to ``(-1, image_size, image_size, num_channels)``.
    """
    arr = x.permute(0, 2, 3, 1).cpu().numpy() * 255.
    arr = np.clip(arr, 0, 255).astype(np.uint8)
    return arr.reshape((-1, image_size, image_size, num_channels))


def main():
    """Sample ``--num_samples`` images per class and save them batch by batch.

    For each class ``c`` in ``range(--num_classes)``, batches of at most
    ``--bs`` samples are written under ``workdirs/<workdir>/samples/<c>/``.
    Each batch produces ``samples_<batch>_<id>.npz`` (uint8, channels-last)
    and a matching ``.png``. Batches already written by this ``--id`` are
    skipped, so an interrupted run can be resumed with the same arguments.
    """
    import tqdm

    parser = argparse.ArgumentParser()
    defaults = {
        'workdir': 'none',
        'num_samples': 100,
        'bs': 128,
        'id': 0,
        # Number of class labels to sample (was hard-coded to 10).
        'num_classes': 10,
    }
    for k, v in defaults.items():
        parser.add_argument(f"--{k}", default=v if v != 'none' else '', type=type(v))
    args = parser.parse_args()
    ID = args.id
    BS = args.bs
    N = args.num_samples
    print(args)
    model, config = get_model(args.workdir)
    # NOTE(review): no RNG seed is set here. Unless the sampler seeds itself,
    # runs with different --id (or a resumed run) may draw identical noise.
    for c in range(args.num_classes):
        sample_dir = f'workdirs/{args.workdir}/samples/{c}'
        os.makedirs(sample_dir, exist_ok=True)
        labels = torch.LongTensor([c] * N)
        # Resume after the batches this run id already saved; this assumes
        # --bs is the same as in the interrupted run.
        j = count_existing_batches(sample_dir, ID)
        for i in tqdm.tqdm(range(j * BS, len(labels), BS)):
            L = labels[i:i + BS]
            with torch.no_grad():
                x_con = model.reversediffusion_langevin_samples(num_samples=L.shape[0], labels=L)

            samples = to_uint8_images(x_con, config.data.image_size, config.data.num_channels)
            np.savez_compressed(f"{sample_dir}/samples_{j}_{ID}.npz", samples=samples)
            save_image(
                    x_con,
                    f"{sample_dir}/samples_{j}_{ID}.png")
            j += 1


if __name__ == '__main__':
    main()
