import os
import torch
import random
import pandas as pd
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler

import numpy as np 
from PIL import Image
from einops import rearrange
from torchvision.utils import make_grid
from tqdm import tqdm


def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed: Integer seed applied to every random number generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed the current CUDA device as well as every other visible GPU.
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Prefer reproducible cuDNN kernels over autotuned (possibly faster) ones.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Global seed for the whole sampling run; applied once at import time.
SEED = 42
set_seed(SEED)


def load_model_from_config(config, ckpt, device='cuda:0'):
    """Build the model described by ``config`` and load weights from ``ckpt``.

    Args:
        config: OmegaConf config whose ``model`` node is passed to
            ``instantiate_from_config``.
        ckpt: Path to a checkpoint file whose loaded dict holds a
            ``"state_dict"`` entry.
        device: Target device; replaced by ``'cpu'`` when CUDA is unavailable.

    Returns:
        The model, moved to the selected device and switched to eval mode.
    """
    print(f"Loading model from {ckpt}")
    # Load onto the CPU first so the checkpoint does not need GPU memory.
    pl_sd = torch.load(ckpt, map_location="cpu")
    sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    # strict=False tolerates key mismatches. Report them instead of discarding
    # them, so a wrong checkpoint/config pairing does not load silently.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    if missing:
        print(f"Missing keys ({len(missing)}): {missing}")
    if unexpected:
        print(f"Unexpected keys ({len(unexpected)}): {unexpected}")
    device = device if torch.cuda.is_available() else 'cpu'
    model.to(device)
    print(model.device)
    model.eval()
    return model


def get_model(config_name, ckpt_path, device='cuda:0', config_dir='configs/latent-diffusion'):
    """Load ``<config_dir>/<config_name>.yaml`` and build the model from ``ckpt_path``.

    Args:
        config_name: Config file stem (without the ``.yaml`` suffix).
        ckpt_path: Path of the checkpoint handed to ``load_model_from_config``.
        device: Requested device, forwarded to ``load_model_from_config``.
        config_dir: Directory that holds the YAML configs. The default keeps the
            previously hard-coded location.

    Returns:
        The model returned by ``load_model_from_config``.
    """
    config_path = os.path.join(config_dir, f'{config_name}.yaml')
    config = OmegaConf.load(config_path)
    return load_model_from_config(config, ckpt_path, device=device)


# Sampling settings.
batch_size = 50  # max images per sampler.sample call
n_samples_per_cond = 100  # images generated for each condition row
ddim_steps = 20  # passed as S= to DDIMSampler.sample
ddim_eta = 0.0  # passed as eta= to DDIMSampler.sample
uncondition_guidance_scale = 3.0  # passed as unconditional_guidance_scale=
device = 'cuda:1'  # requested device; the loader falls back to CPU without CUDA

config_name = 'tai-sem-ldm-vq-f8'
ckpt_path = 'logs/2024-07-14T19-31-56_tai-sem-ldm-vq-f8/checkpoints/epoch=000044.ckpt'
eval_dir = os.path.join('logs/eval', config_name)

# exist_ok replaces the racy "check with os.access, then create" pattern.
os.makedirs(eval_dir, exist_ok=True)

# Load the model, plus a DDIM sampler bound to the same device.
model = get_model(config_name, ckpt_path, device=device)
# load_model_from_config falls back to the CPU when CUDA is unavailable.
# Follow the model's actual device rather than the requested one.
sampler = DDIMSampler(model, device=model.device)

# Load conditions: the last 7 CSV columns, one condition per row.
conds = pd.read_csv('tai_data/sum1126.csv').iloc[:, -7:]

# For every condition row, draw n_samples_per_cond images with DDIM and save
# each one as its own PNG under eval_dir.
with torch.no_grad():
    with model.ema_scope():
        n_batches = (n_samples_per_cond - 1) // batch_size + 1
        for j in range(conds.shape[0]):
            y1, y2, y3, y4, y5, y6, y7 = conds.iloc[j].to_numpy()
            ys = (y1, y2, y3, y4, y5, y6, y7)
            print('Condition {} / {}: '.format(j + 1, conds.shape[0]), y1, y2, y3, y4, y5, y6, y7)

            # Use model.device, not the requested `device`, so the condition
            # tensors follow the model when the loader fell back to the CPU.
            # The unconditional embedding uses the default return_oc.
            uc = model.get_learned_7_conditioning(
                *[torch.tensor(v).to(model.device) for v in ys]
            )
            # The conditional embedding does not depend on the batch index,
            # so compute it once per condition rather than once per batch.
            c = model.get_learned_7_conditioning(
                *[torch.tensor(v).to(model.device) for v in ys],
                return_oc=False,
            )

            all_samples = []
            for bid in tqdm(range(n_batches)):
                # The last batch may be smaller than batch_size.
                n_samples = min(batch_size, n_samples_per_cond - bid * batch_size)
                # shape is the latent shape per sample; presumably it must
                # match the first-stage model's latent size -- TODO confirm.
                samples_ddim, _ = sampler.sample(S=ddim_steps,
                                                 conditioning=c.unsqueeze(0).repeat((n_samples, 1)),
                                                 batch_size=n_samples,
                                                 shape=[4, 32, 32],
                                                 verbose=False,
                                                 unconditional_guidance_scale=uncondition_guidance_scale,
                                                 unconditional_conditioning=uc.unsqueeze(0).repeat((n_samples, 1)),
                                                 eta=ddim_eta)

                x_samples_ddim = model.decode_first_stage(samples_ddim)
                # Map decoder output from [-1, 1] to [0, 1].
                x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0,
                                             min=0.0, max=1.0)
                # Move each decoded batch off the accelerator right away so
                # device memory does not accumulate across batches.
                all_samples.append(x_samples_ddim.cpu())
            all_samples = torch.cat(all_samples, dim=0)

            for i in tqdm(range(n_samples_per_cond)):
                # make_grid on a one-image batch returns that image. It also
                # expands single-channel images to 3 channels.
                grid = make_grid(all_samples[i:i + 1], nrow=1)

                # to image: CHW float in [0, 1] -> HWC uint8 (values truncated)
                grid = 255. * rearrange(grid, 'c h w -> h w c').numpy()
                name = 'c{:>03d}_i{:>03d}'.format(j, i)
                save_path = os.path.join(eval_dir, 'samples_{}.png'.format(name))
                Image.fromarray(grid.astype(np.uint8)).save(save_path)