import os
import numpy as np
import argparse
import torchvision
import torch as th

from composable_diffusion.model_creation import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    Sampler_create_gaussian_diffusion,
)

from composable_diffusion.model_creation import create_model_and_diffusion, model_and_diffusion_defaults
from anneal_samplers import AnnealedMALASampler, AnnealedCHASampler, AnnealedUHASampler,AnnealedULASampler, AnnealedLHMCSampler

def pil_image_to_norm_tensor(pil_image):
    """Convert an image into a float tensor with its last axis moved first.

    The input is read through ``np.asarray``, cast to float, reordered with
    ``permute(2, 0, 1)`` and mapped with ``v / 127.5 - 1.0`` (so 0 becomes
    -1.0 and 255 becomes 1.0).
    """
    pixels = np.asarray(pil_image)
    last_axis_first = th.from_numpy(pixels).float().permute(2, 0, 1)
    return last_axis_first / 127.5 - 1.0

def convert_images(batch: th.Tensor):
    """Map a 4-D batch from the [-1, 1] range to uint8 pixel values.

    Values are transformed with ``(v + 1) * 127.5``, rounded, clamped to
    ``[0, 255]`` and cast to ``uint8``; the axes are then reordered with
    ``permute(0, 2, 3, 1)`` (second axis moved last).
    """
    rescaled = (batch + 1) * 127.5
    as_bytes = rescaled.round().clamp(0, 255).to(th.uint8)
    return as_bytes.permute(0, 2, 3, 1)

def get_caption_simple(label):
    """Return the shape name for the class index stored in ``label[0]``.

    Index order: 0 = cube, 1 = sphere, 2 = cylinder, 3 = none.
    """
    shape_names = ("cube", "sphere", "cylinder", "none")
    return shape_names[label[0]]

# Seed torch's global RNG so runs are repeatable.
th.manual_seed(1)

parser = argparse.ArgumentParser()
parser.add_argument(
    '--ckpt_path',
    required=True,
    help="Path to the model checkpoint (a state dict loaded with torch.load).",
)
parser.add_argument(
    '--lr',
    type=float,
    default=0.1,
    help="Multiplier applied to the diffusion betas to obtain the sampler step sizes.",
)
parser.add_argument(
    '--pl',
    type=float,
    default=0.01,
    help="Value handed to the LHMC sampler as its step_pl argument (LHMC only).",
)
# argparse does not check `default` against `choices`, so the default must be
# spelled exactly like one of the choices; the previous lowercase "mala"
# slipped through parsing and then failed the sampler dispatch.
parser.add_argument(
    '--sampler',
    type=str,
    default="MALA",
    choices=["LHMC", "MALA", "HMC", "UHMC", "ULA", "Rev_Diff"],
    help="MCMC sampler to use; Rev_Diff runs reverse diffusion without an extra sampler.",
)
args = parser.parse_args()

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = th.device('cuda' if th.cuda.is_available() else 'cpu')
# Start from the library defaults and override the entries needed for this
# checkpoint in a single update (same keys, same order as before).
options = model_and_diffusion_defaults()
options.update(
    {
        "noise_schedule": "linear",
        "learn_sigma": False,
        "use_fp16": False,
        "num_classes": "3,",
        "dataset": "clevr_norel",
        "image_size": 64,
        "num_channels": 128,
        "num_res_blocks": 3,
        "energy_mode": True,
    }
)

model_path = args.ckpt_path  # 64 x 64
# Kept as a string: it is passed verbatim as `timestep_respacing` and
# converted to int for `steps`.
base_timestep_respacing = '100'

diffusion = Sampler_create_gaussian_diffusion(
    timestep_respacing=base_timestep_respacing,
    noise_schedule=options['noise_schedule'],
    learn_sigma=options['learn_sigma'],
    steps=int(base_timestep_respacing),  # 1000,
)

# Build the network first so `model` is always defined, even when no
# checkpoint path is given (previously that case ended in a NameError).
model, _ = create_model_and_diffusion(**options)

if model_path:
    # Raise explicitly instead of asserting: asserts are stripped under -O.
    if not os.path.exists(model_path):
        raise FileNotFoundError(
            f"Failed to resume from {model_path}, file does not exist."
        )
    # NOTE: th.load unpickles the file; only load checkpoints from a trusted source.
    weights = th.load(model_path, map_location="cpu")
    model.load_state_dict(weights)

# Use the device selected above rather than a hard-coded "cuda", so the
# script also runs on CPU-only machines.
model = model.to(device)
model.eval()
guidance_scale = 10.0
batch_size = 1

# Class indices to compose, one inner list per concept. Indices follow
# get_caption_simple: 0 = cube, 1 = sphere, 2 = cylinder.
# Alternatives:
#   th.tensor([[ [0], [1] ]]).long()  # cube + sphere
#   th.tensor([[ [1], [2] ]]).long()  # sphere + cylinder
#   th.tensor([[ [0], [2] ]]).long()  # cube + cylinder
#   th.tensor([[ [0], ]]).long()      # cube only
labels = th.tensor([[ [2], ]]).long()  # cylinder only

texts = [get_caption_simple(lab.numpy()) for lab in labels[0]]
print(texts)

# Split along dim 1 into one label tensor per concept.
num_concepts = labels.shape[1]
labels = [
    piece.squeeze(dim=1)
    for piece in th.chunk(labels, num_concepts, dim=1)
]

# One model row per concept plus one extra row (masked False below).
full_batch_size = batch_size * (len(labels) + 1)
masks = [True] * len(labels) * batch_size + [False] * batch_size

# The extra row carries class index 3 ("none" in get_caption_simple);
# presumably the unconditional label -- confirm against the model.
null_label = th.zeros_like(labels[0]) + 3
labels = th.cat(labels + [null_label], dim=0)

model_kwargs = {
    "y": labels.clone().detach().to(device),
    "masks": th.tensor(masks, dtype=th.bool, device=device),
}

def cfg_model_fn(x_t, ts, **kwargs):
    """Return the guided model output, tiled to the batch size of ``x_t``.

    Only the first sample of ``x_t`` is used: it is repeated once per row of
    ``kwargs['y']`` and passed to the global ``model`` with ``eval=True``.
    The last output row is taken as the unconditional prediction and every
    earlier row as a conditional one. The combined result is
    ``uncond + guidance_scale * sum(cond - uncond)``, repeated
    ``x_t.size(0)`` times along dim 0.
    """
    first = x_t[:1]
    model_in = th.cat([first] * kwargs['y'].size(0), dim=0)
    out = model(model_in, ts, eval=True, **kwargs)

    uncond_eps = out[-1:]
    cond_eps = out[:-1]
    guided = uncond_eps + guidance_scale * (cond_eps - uncond_eps).sum(dim=0, keepdim=True)
    return th.cat([guided] * x_t.size(0), dim=0)

def cfg_model_fn_noen(x_t, ts, **kwargs):
    """Guided model output like ``cfg_model_fn`` but without ``eval=True``.

    The first sample of ``x_t`` is repeated once per row of ``kwargs['y']``
    and passed to the global ``model``. The last output row is taken as the
    unconditional prediction, the others as conditional ones, and they are
    combined as ``uncond + guidance_scale * sum(cond - uncond)``. The result
    is repeated ``x_t.size(0)`` times along dim 0.
    """
    first = x_t[:1]
    model_in = th.cat([first] * kwargs['y'].size(0), dim=0)
    out = model(model_in, ts, **kwargs)

    uncond_eps = out[-1:]
    cond_eps = out[:-1]
    guided = uncond_eps + guidance_scale * (cond_eps - uncond_eps).sum(dim=0, keepdim=True)
    return th.cat([guided] * x_t.size(0), dim=0)

# Per-timestep scale table used by `gradient` and `gradient_cha`:
#   alphas[t]         = 1 - betas[t]
#   alphas_cumprod[t] = alphas[0] * ... * alphas[t]
#   scalar[t]         = sqrt(1 / (1 - alphas_cumprod[t]))
# The gradient functions multiply the guided eps by -scalar[t]
# (presumably to turn a noise prediction into a score -- confirm).
alphas = 1 - diffusion.betas
alphas_cumprod = np.cumprod(alphas)
scalar = np.sqrt(1 / (1 - alphas_cumprod))
def gradient(x_t, ts, **kwargs):
    """Return the guided eps multiplied by ``-scalar[ts[0]]``.

    The first sample of ``x_t`` is repeated once per row of ``kwargs['y']``
    and passed to the global ``model`` with ``eval=True``. The last output
    row is taken as the unconditional prediction, the others as conditional
    ones, combined as ``uncond + guidance_scale * sum(cond - uncond)`` and
    repeated ``x_t.size(0)`` times. The scale is looked up with the first
    entry of ``ts`` only (assumes every row shares one timestep -- confirm).
    """
    first = x_t[:1]
    model_in = th.cat([first] * kwargs['y'].size(0), dim=0)
    out = model(model_in, ts, eval=True, **kwargs)

    uncond_eps = out[-1:]
    cond_eps = out[:-1]
    guided = uncond_eps + guidance_scale * (cond_eps - uncond_eps).sum(dim=0, keepdim=True)
    eps = th.cat([guided] * x_t.size(0), dim=0)

    scale = scalar[ts[0]]
    return -1 * scale * eps

# HMC SAMPLER
# Hyper-parameters handed to the Hamiltonian-style samplers
# (AnnealedUHASampler / AnnealedCHASampler / AnnealedLHMCSampler) in the
# sampler selection further down.
num_steps = len(diffusion.betas)  # length of the beta schedule; also passed to the Langevin samplers
ha_steps = 5 # 2 # Hamiltonian steps to run (the "2" looks like an earlier setting)
num_leapfrog_steps = 4 # Steps to run in leapfrog
damping_coeff = 0.5
mass_diag_sqrt = diffusion.betas
# Initial value only: each HMC-type branch of the sampler selection
# overwrites this with diffusion.betas * args.lr.
ha_step_sizes  = diffusion.betas * 0.1 # 0.1

# MALA SAMPLER
# Hyper-parameters handed to AnnealedMALASampler / AnnealedULASampler.
la_steps = 20
# Initial value only: the MALA and ULA branches overwrite this with
# diffusion.betas * args.lr.
la_step_sizes = diffusion.betas * 0.035

def gradient_cha(x_t, ts, **kwargs):
    """Return ``(-scale * energy, -scale * eps)`` for the guided model.

    The first sample of ``x_t`` is repeated once per row of ``kwargs['y']``
    and passed to the global ``model`` with ``mala_sampler=True``, which
    yields an ``(energy, eps)`` pair. For both outputs the last row is taken
    as the unconditional term and the earlier rows as conditional terms.
    ``scale`` is ``scalar[ts[0]]``, i.e. looked up with the first entry of
    ``ts`` only.
    """
    first = x_t[:1]
    model_in = th.cat([first] * kwargs['y'].size(0), dim=0)
    energy_norm, out = model(model_in, ts, mala_sampler=True, **kwargs)

    # Scalar energy: uncond + guidance_scale * (sum(cond) - uncond).
    # NOTE(review): with more than one condition this subtracts the
    # unconditional energy once, while the eps combination below subtracts
    # it once per condition -- confirm this asymmetry is intended.
    uncond_energy = energy_norm[-1:]
    cond_energy = energy_norm[:-1]
    total_energy = uncond_energy.sum() + guidance_scale * (cond_energy.sum() - uncond_energy.sum())

    uncond_eps = out[-1:]
    cond_eps = out[:-1]
    guided = uncond_eps + guidance_scale * (cond_eps - uncond_eps).sum(dim=0, keepdim=True)
    eps = th.cat([guided] * x_t.size(0), dim=0)

    scale = scalar[ts[0]]
    return -scale * total_energy, -1 * scale * eps

# Build the requested sampler. The Langevin family (MALA/ULA) and the
# Hamiltonian family (UHMC/HMC/LHMC) each derive their step sizes from the
# beta schedule scaled by --lr. Samplers with an accept/reject energy use
# `gradient_cha`; the unadjusted ones use `gradient`.
sampler_name = args.sampler
if sampler_name in ('MALA', 'ULA'):
    la_step_sizes = diffusion.betas * args.lr
    if sampler_name == 'MALA':
        sampler = AnnealedMALASampler(num_steps, la_steps, la_step_sizes, gradient_cha)
    else:
        sampler = AnnealedULASampler(num_steps, la_steps, la_step_sizes, gradient)
elif sampler_name in ('UHMC', 'HMC', 'LHMC'):
    ha_step_sizes = diffusion.betas * args.lr
    if sampler_name == 'UHMC':
        sampler = AnnealedUHASampler(
            num_steps, ha_steps, ha_step_sizes, damping_coeff,
            mass_diag_sqrt, num_leapfrog_steps, gradient,
        )
    elif sampler_name == 'HMC':
        sampler = AnnealedCHASampler(
            num_steps, ha_steps, ha_step_sizes, damping_coeff,
            mass_diag_sqrt, num_leapfrog_steps, gradient_cha,
        )
    else:
        step_pl = args.pl
        sampler = AnnealedLHMCSampler(
            num_steps, ha_steps, ha_step_sizes, step_pl, damping_coeff,
            mass_diag_sqrt, num_leapfrog_steps, gradient_cha,
        )
elif sampler_name == 'Rev_Diff':
    # No extra sampler: p_sample_loop receives None.
    print("Using Reverse Diffusion Sampling only")
    sampler = None
else:
    raise ValueError('Not defined!')
print("Using Sampler: ",args.sampler)

# Run the sampling loop with the chosen sampler (None for plain reverse
# diffusion) and keep only the first `batch_size` rows of the result.
samples = diffusion.p_sample_loop(
    sampler,
    cfg_model_fn,
    (full_batch_size, 3, 64, 64),
    device=device,
    clip_denoised=True,
    progress=True,
    model_kwargs=model_kwargs,
    cond_fn=None,
)[:batch_size]
print(samples.shape)

# Convert to uint8 images and move them to host memory as a NumPy array.
sample = convert_images(samples.contiguous())
show_img = sample.cpu().detach().numpy()
all_samp = [show_img]

print('Now saving....')
arr = np.concatenate(all_samp, axis=0)
show_img = th.tensor(arr).permute(0, 3, 1, 2)  # N C H W
th.save(show_img, 'show_img.pt')

# save_image is given floats scaled down by 255.
show_img = show_img.float() / 255.

# The output directory is named after the upper-cased, dash-joined captions.
cap = "-".join(texts).upper()
os.makedirs(cap, exist_ok=True)
out_path = f"{cap}/{cap}_{args.sampler}_{guidance_scale:.1f}({args.lr}-{args.pl}).png"
torchvision.utils.save_image(show_img, out_path)