import os
import torch
import random
import pandas as pd
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

import numpy as np 
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid
from tqdm import tqdm


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and pin cuDNN to deterministic mode.

    Args:
        seed: Integer seed handed to every random number generator.
    """
    # CPU-side generators: PyTorch, NumPy and the stdlib `random` module.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)      # current CUDA device
        torch.cuda.manual_seed_all(seed)  # every CUDA device (multi-GPU)

    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Fix the seed of every RNG at import time, before the model, the sampler
# or any data is created further down in this script.
SEED = 42
set_seed(SEED)


def load_model_from_config(config, ckpt, device='cuda:0'):
    """Instantiate the model described by ``config.model`` and load ``ckpt`` into it.

    Args:
        config: Configuration object whose ``model`` attribute is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file that contains a ``"state_dict"`` entry.
        device: Device the model is moved to. When CUDA is not available the
            model is placed on the CPU regardless of this value.

    Returns:
        The instantiated model, moved to the selected device and switched to
        eval mode.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    print(f"Loading model from {ckpt}")
    # Deserialize onto the CPU so loading works when GPU memory is insufficient.
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)

    # strict=False tolerates key mismatches, so report them instead of silently
    # discarding them: a wrong config/checkpoint pairing would otherwise leave
    # parts of the model at their initial values without any notice.
    missing_keys, unexpected_keys = model.load_state_dict(sd, strict=False)
    if missing_keys:
        print(f"Missing keys ({len(missing_keys)}): {missing_keys}")
    if unexpected_keys:
        print(f"Unexpected keys ({len(unexpected_keys)}): {unexpected_keys}")

    device = device if torch.cuda.is_available() else 'cpu'
    model.to(device)
    print(model.device)
    model.eval()
    return model


def get_model(config_name, ckpt_path, device='cuda:0'):
    """Load ``configs/latent-diffusion/<config_name>.yaml`` and build the model from it.

    Args:
        config_name: Base name (without ``.yaml``) of the config file.
        ckpt_path: Checkpoint path forwarded to ``load_model_from_config``.
        device: Device forwarded to ``load_model_from_config``.

    Returns:
        Whatever ``load_model_from_config`` returns for the loaded config.
    """
    yaml_name = '{}.yaml'.format(config_name)
    config_path = os.path.join('configs/latent-diffusion', yaml_name)
    return load_model_from_config(OmegaConf.load(config_path), ckpt_path, device=device)


# ---- Run configuration -----------------------------------------------------
batch_size = 20                    # number of conditions sampled per batch
ddim_steps = 20                    # passed to the sampler as S
ddim_eta = 0.0                     # passed to the sampler as eta
uncondition_guidance_scale = 3.0   # passed as unconditional_guidance_scale
# load_model_from_config places the model on the CPU when CUDA is unavailable;
# apply the same fallback here so the sampler and the conditioning tensors
# target the device the model actually lives on.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

cond_file = 'metric/intra_conds_10000.npy'
config_name = 'tai-om-ldm-vq-f8'
exp_name = 'tai-om-ldm-vq-f8-uncond-10000'
# The checkpoint is selected by whether the experiment name contains 'sem'.
if 'sem' in exp_name:
    ckpt_path = 'logs/2024-07-14T19-31-56_tai-sem-ldm-vq-f8/checkpoints/epoch=000044.ckpt'
else:
    ckpt_path = 'logs/2024-07-14T16-35-34_tai-om-ldm-vq-f8/checkpoints/epoch=000049.ckpt'
eval_dir = os.path.join('logs/eval', exp_name)

# exist_ok avoids the check-then-create race of a separate existence test.
os.makedirs(eval_dir, exist_ok=True)

# Build the model from its config, load the checkpoint weights, and wrap the
# model in a DDIM sampler.
model = get_model(config_name, ckpt_path, device=device)
sampler = DDIMSampler(model, device=device)

# Load the condition array. The sampling loop below indexes columns 0..6 of
# it, so it must be 2-D indexable with at least 7 entries along axis 1.
conds = np.load(cond_file)

# Sample one image per row of `conds` in batches and write each image to
# eval_dir as a PNG. Runs without gradients and inside the model's EMA scope.
with torch.no_grad(), model.ema_scope():
    num_conds = conds.shape[0]
    num_batches = (num_conds-1)//batch_size + 1  # ceil(num_conds / batch_size)
    for bid in tqdm(range(num_batches)):
        start = bid*batch_size
        stop = min(start + batch_size, num_conds)
        n_samples = stop - start

        # Only the shape/dtype of the batch is kept: every condition value is
        # replaced by zero (unconditional sampling).
        ys = np.zeros_like(conds[start:stop])
        print('Batch {} / {}: '.format(bid+1, num_batches))

        # One tensor per condition column (7 columns), passed positionally.
        cond_columns = [torch.tensor(ys[:, col]).to(device) for col in range(7)]
        c = model.get_learned_7_conditioning(*cond_columns, return_oc=False)

        # The same conditioning is supplied as both the conditional and the
        # "unconditional" input of the sampler.
        latents, _ = sampler.sample(
            S=ddim_steps,
            conditioning=c,
            batch_size=n_samples,
            shape=[4, 32, 32],
            verbose=False,
            unconditional_guidance_scale=uncondition_guidance_scale,
            unconditional_conditioning=c,
            eta=ddim_eta,
        )

        # Decode, then map from [-1, 1] to [0, 1] with clamping.
        images = model.decode_first_stage(latents)
        images = torch.clamp((images+1.0)/2.0, min=0.0, max=1.0)

        for i in range(n_samples):
            # Single-image "grid": keep a leading batch axis for make_grid.
            single = rearrange(images[i:i+1].unsqueeze(0), 'n b c h w -> (n b) c h w')
            grid = make_grid(single, nrow=1)

            # Channels-last array scaled to 0..255; astype(np.uint8) truncates.
            pixels = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
            name = 'c{:>06d}'.format(start + i)
            save_path = os.path.join(eval_dir, 'samples_{}.png'.format(name))
            Image.fromarray(pixels.astype(np.uint8)).save(save_path)
