import torch
import os
import math
import argparse
import models_mage
import modified_mage
import numpy as np
from tqdm import tqdm
# import cv2
from PIL import Image

import time

def save_image(path, img_array):
    """Write ``img_array`` to ``path`` as an image file via PIL.

    The array is passed to ``Image.fromarray`` unchanged, so the caller must supply a
    dtype/shape PIL accepts (this script passes HxWxC ``uint8`` arrays). The output
    format is chosen by PIL from the path's extension.
    """
    Image.fromarray(img_array).save(path)

def build_halton_mask(input_size, nb_point=10_000):
    """Return the cells of an ``input_size`` x ``input_size`` grid in Halton sampling order.

    Adapted from https://github.com/valeoai/Halton-MaskGIT/blob/main/Sampler/halton_sampler.py

    A 2-D Halton sequence (base 2 for the first coordinate, base 3 for the second) is
    generated, each point is snapped to a grid cell, and repeated cells are dropped while
    keeping the order in which each cell first appears.

    :param input_size: side length of the square grid.
    :param nb_point: number of Halton points to draw. It must be large enough for every
        cell to be hit; otherwise the result is shorter than ``input_size ** 2``.
    :return: 1-D ``torch.LongTensor`` of flattened cell indices
        ``row * input_size + col`` (row from the base-2 coordinate, column from the
        base-3 coordinate), in sampling order.
    """

    def halton(b, n_sample):
        """Return the first ``n_sample`` terms of the 1-D Halton sequence in base ``b``."""
        n, d = 0, 1
        res = []
        for _ in range(n_sample):
            x = d - n
            if x == 1:
                n = 1
                d *= b
            else:
                y = d // b
                while x <= y:
                    y //= b
                n = (b + 1) * y - x
            res.append(n / d)
        return res

    # Sample 2D points in [0, 1)^2 and snap them to integer grid cells.
    data_x = torch.asarray(halton(2, nb_point)).view(-1, 1)
    data_y = torch.asarray(halton(3, nb_point)).view(-1, 1)
    cells = torch.floor(torch.cat([data_x, data_y], dim=1) * input_size).numpy()

    # Keep only the first occurrence of each cell; sorting the first-occurrence indices
    # restores the original sampling order.
    _, first_seen = np.unique(cells, return_index=True, axis=0)
    ordered = cells[np.sort(first_seen)].astype(np.int64)
    flat = ordered[:, 0] * input_size + ordered[:, 1]
    return torch.from_numpy(flat)

def mask_by_random_topk(mask_len, probs, temperature=1.0, is_log=False, multiple_stages=False):
    """Mark the least-confident positions of each row after adding Gumbel noise.

    Confidence is ``log(probs) + temperature * g``, where ``g`` is drawn from NumPy's
    global Gumbel RNG (so results follow ``np.random.seed``). In each row, the
    ``mask_len`` positions with the smallest confidence are set to ``True``.

    :param mask_len: number of positions to mask per row. With ``multiple_stages=True``,
        an iterable of counts; one mask is returned per count. All stages share a single
        noise draw and ordering, so a mask for a smaller count is a subset of the mask for
        a larger one.
    :param probs: 2-D tensor (rows, positions) of probabilities, or log-probabilities
        when ``is_log`` is True. Entries equal to +inf sort last and are therefore only
        selected if the count exceeds the number of finite entries.
    :param temperature: scale of the Gumbel noise; 0 makes the selection deterministic.
    :param is_log: treat ``probs`` as log-probabilities (skip the ``log``).
    :param multiple_stages: return a list of masks, one per entry of ``mask_len``.
    :return: bool tensor shaped like ``probs``, or a list of such tensors.
    """
    log_probs = probs if is_log else torch.log(probs)
    # Noise follows NumPy's global RNG and is placed on the device of ``probs``
    # (instead of always CUDA), so the function also works for CPU tensors.
    noise = temperature * np.random.gumbel(size=probs.shape)
    confidence = log_probs + torch.as_tensor(noise, dtype=torch.float32, device=probs.device)

    # Ascending sort: the front of each row holds the least-confident positions.
    _, sorted_indices = torch.sort(confidence, dim=-1)
    rows = torch.arange(probs.shape[0], device=probs.device).unsqueeze(1)

    def _lowest(count):
        """Bool mask flagging the ``count`` least-confident positions of each row."""
        count = int(count)
        masking = torch.zeros_like(confidence, dtype=torch.bool)
        # ``rows`` (R, 1) broadcasts against the (R, count) column indices.
        masking[rows, sorted_indices[:, :count]] = True
        return masking

    if multiple_stages:
        return [_lowest(count) for count in mask_len]
    return _lowest(mask_len)


def gen_image(model, bsz, seed, num_iter=12, choice_temperature=4.5, is_random=False, moment=False, unbiased=False, halton=False):
    """Generate ``bsz`` images by iterative masked-token decoding, then decode them with the VQGAN.

    All 256 token positions (a 16x16 grid) start as ``model.mask_token_label``. Each of the
    ``num_iter`` steps predicts every position, samples token ids for the still-unknown ones
    and re-masks a cosine-scheduled number of them for the next step.

    Requires a CUDA device: intermediate tensors are created with ``.cuda()``.

    :param model: MAGE-style model; this function uses its ``token_emb``, ``blocks``,
        ``norm``, ``forward_decoder``, ``mask_token_label``, ``fake_class_label`` and
        ``vqgan`` members.
    :param bsz: number of images to generate.
    :param seed: seeds both the torch and the NumPy global RNGs.
    :param num_iter: number of decoding steps.
    :param choice_temperature: in the default sampler, the scale of the Gumbel noise used
        to pick tokens to re-mask (annealed by ``1 - ratio``). In the moment sampler, it sets
        the logit weight ``1 + 1 / (choice_temperature * (1 - ratio))``.
    :param is_random: rank tokens with constant scores instead of model confidence, so
        only the Gumbel noise decides which are re-masked. Has no effect with ``halton``.
    :param moment: on every step but the last, sample from weighted logits and rank tokens
        by ``logsumexp`` of the weighted log-probabilities.
    :param unbiased: in the moment sampler, draw tokens from the unweighted logits.
    :param halton: re-mask positions following a fixed 16x16 Halton order shared by the
        whole batch; implies ``moment``.
    :return: output of ``model.vqgan.decode`` for the final token grid (presumably a
        (bsz, C, H, W) image tensor — TODO confirm).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    codebook_emb_dim = 256
    codebook_size = 1024
    mask_token_id = model.mask_token_label
    unknown_number_in_the_beginning = 256  # 16x16 token grid (see the decode reshape below)
    _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf  # known tokens sort last, so they are never re-masked
    
    halton_seq = None
    if halton:
        # Fixed ordering of the 256 positions; the tail of this order is what stays masked.
        halton_seq = build_halton_mask(16)
        moment = True

    initial_token_indices = mask_token_id * torch.ones(bsz, unknown_number_in_the_beginning)

    token_indices = initial_token_indices.cuda()

    for step in range(num_iter):
        cur_ids = token_indices.clone().long()

        # Prepend one extra position holding model.fake_class_label (the name suggests
        # unconditional generation — NOTE(review): confirm).
        token_indices = torch.cat(
            [torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
        token_indices[:, 0] = model.fake_class_label
        token_indices = token_indices.long()
        token_all_mask = token_indices == mask_token_id

        token_drop_mask = torch.zeros_like(token_indices)

        # token embedding
        input_embeddings = model.token_emb(token_indices)

        # encoder
        x = input_embeddings
        for blk in model.blocks:
            x = blk(x)
        x = model.norm(x)

        # decoder
        logits = model.forward_decoder(x, token_drop_mask, token_all_mask)
        # Drop the prepended class position and keep only the codebook entries.
        logits = logits[:, 1:, :codebook_size]

        # Defines the mask ratio for the next round. The number to mask out is
        # determined by mask_ratio * unknown_number_in_the_beginning.
        ratio = 1. * (step + 1) / num_iter
        mask_ratio = np.cos(math.pi / 2. * ratio)

        # get ids for next step
        unknown_map = (cur_ids == mask_token_id)

        mask_len = torch.Tensor([np.floor(unknown_number_in_the_beginning * mask_ratio)]).cuda()
        # Keeps at least one of prediction in this round and also masks out at least
        # one and for the next iteration
        # mask_len broadcasts to (bsz, 1); only mask_len[0] is used below. Every row has the
        # same unknown count because each step masks the same number of tokens per row.
        mask_len = torch.maximum(torch.Tensor([1]).cuda(),
                                 torch.minimum(torch.sum(unknown_map, dim=-1, keepdims=True) - 1, mask_len))          

        if moment and (step < num_iter-1):
            
            # Sharpening weight; grows as the schedule approaches its end (ratio -> 1).
            logits_weight = 1 + 1/(choice_temperature * (1 - ratio))
            if unbiased:
                sample_dist = torch.distributions.categorical.Categorical(logits=logits*1.)
            else:
                sample_dist = torch.distributions.categorical.Categorical(logits=logits*logits_weight)
            sampled_ids = sample_dist.sample()
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)

            if halton:
                # Mask the last mask_len positions of the Halton order for every row. The count
                # shrinks each step, so the revealed set grows along the same fixed order.
                masking = torch.zeros_like(sampled_ids[0]).bool()
                masking[halton_seq[-mask_len[0].long().item():]] = True
                masking = masking.unsqueeze(0).expand(bsz, -1)

            else:
                log_probs = torch.nn.functional.log_softmax(logits, dim=-1)
                if is_random:
                    # Constant log-score 0 for all tokens: only the Gumbel noise ranks them.
                    selected_probs = torch.zeros_like(sampled_ids)
                else:
                    selected_probs = torch.logsumexp(log_probs*logits_weight, -1) # moment computation
                selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
                # Scores are already log-space; the noise temperature is fixed at 1.
                masking = mask_by_random_topk(mask_len[0].long().item(), selected_probs, 1., is_log=True)

        else:

            logits_weight = 1.

            sample_dist = torch.distributions.categorical.Categorical(logits=logits*logits_weight)
            sampled_ids = sample_dist.sample()
            
            # keep already known tokens
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)


            # sample ids according to prediction confidence
            probs = torch.nn.functional.softmax(logits, dim=-1)
            if is_random:
                selected_probs = torch.ones_like(sampled_ids)
            else:
                # Confidence = probability of the token that was actually sampled.
                selected_probs = torch.squeeze(
                    torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
            selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
            # Sample masking tokens for next iteration
            # Gumbel noise annealed linearly to 0 over the schedule.
            masking = mask_by_random_topk(mask_len[0].long().item(), selected_probs, choice_temperature * (1 - ratio))
        
        # Masks tokens with lower confidence.
        # On the final step mask_len is still >= 1, so token_indices keeps a masked position;
        # that is harmless because the image is decoded from sampled_ids below.
        token_indices = torch.where(masking, mask_token_id, sampled_ids)

    # vqgan visualization
    z_q = model.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(bsz, 16, 16, codebook_emb_dim))
    gen_images = model.vqgan.decode(z_q)
    return gen_images


def gen_image_cache(model, bsz, seed, num_iter=12, choice_temperature=4.5, is_random=False, moment=False, num_substeps=4, guidance_w=0, unbiased=False, halton=False):
    """Generate ``bsz`` images with sub-step refinement on top of a KV-cached forward pass.

    Each of the ``num_iter`` outer steps does the following:
    - runs one full ``model.inference(..., cache_kv=True)`` pass;
    - samples every unknown token;
    - chooses nested re-mask sets for ``num_substeps`` sub-steps.
    Then ``num_substeps - 1`` extra ``model.inference(..., update_indices=...)`` calls
    re-predict only the positions revealed between consecutive sub-steps.

    Requires a CUDA device: intermediate tensors are created with ``.cuda()``.

    :param model: model exposing ``inference``, ``mask_token_label``, ``fake_class_label``
        and ``vqgan``.
    :param bsz: number of images to generate.
    :param seed: seeds both the torch and the NumPy global RNGs.
    :param num_iter: number of outer decoding steps.
    :param choice_temperature: in the default sampler, the Gumbel-noise scale for the
        re-mask choice (annealed by ``1 - ratio``). In the moment sampler, it sets the logit
        weight ``1 + 1 / (choice_temperature * (1 - ratio))``.
    :param is_random: rank tokens with constant scores instead of model confidence.
    :param moment: on every step but the last, sample from weighted logits and rank tokens
        by ``logsumexp`` of the weighted log-probabilities.
    :param num_substeps: sub-steps per outer step (1 disables refinement).
    :param guidance_w: if > 0, refined log-probabilities are extrapolated away from the
        full-pass ones as ``(1 + w) * (refined - base) + base``.
    :param unbiased: sample tokens from unweighted logits.
    :param halton: not implemented by this sampler; a warning is emitted and it is ignored.
    :return: output of ``model.vqgan.decode`` for the final token grid (presumably a
        (bsz, C, H, W) image tensor — TODO confirm).
    """
    if halton:
        import warnings
        warnings.warn("gen_image_cache does not implement the Halton scheduler; halton=True is ignored.")

    torch.manual_seed(seed)
    np.random.seed(seed)
    codebook_emb_dim = 256
    mask_token_id = model.mask_token_label
    unknown_number_in_the_beginning = 256  # 16x16 token grid (see the decode reshape below)
    _CONFIDENCE_OF_KNOWN_TOKENS = +np.inf  # known tokens sort last, so they are never re-masked

    initial_token_indices = mask_token_id * torch.ones(bsz, unknown_number_in_the_beginning)

    token_indices = initial_token_indices.cuda()

    for step in range(num_iter):
        cur_ids = token_indices.clone().long()

        # Prepend one extra position holding model.fake_class_label.
        token_indices = torch.cat(
            [torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
        token_indices[:, 0] = model.fake_class_label
        token_indices = token_indices.long()
        token_all_mask = token_indices == mask_token_id

        token_drop_mask = torch.zeros_like(token_indices)

        logits = model.inference(token_indices, token_drop_mask, token_all_mask, cache_kv=True)
        # Unweighted log-probabilities of the full pass. They are used for the moment confidence
        # and as the guidance base. They are computed every step so guidance also works without
        # ``moment`` or with ``is_random``.
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1)

        # Schedule position of each sub-step: ratios[0] is the end of this step, and each later
        # entry lies 1/num_substeps of a step earlier. So mask_ratios (and mask_lens) grow
        # with the index.
        ratios = (step + 1 - np.arange(num_substeps) * 1. / num_substeps) / num_iter
        mask_ratios = np.cos(math.pi / 2. * ratios)

        # Token weights: 1 (plain sampling) unless the moment sampler is active this step.
        # logits_next_weights is always bound so refinement works without ``moment``.
        use_moment = moment and (step < num_iter - 1)
        logits_weight = 1.
        logits_next_weights = np.ones(num_substeps)
        if use_moment:
            logits_weight = 1 + 1/(choice_temperature * (1 - ratios[0]))
            logits_next_weights = 1 + 1/(choice_temperature * (1 - ratios - 1./num_iter))
        if unbiased:
            sample_dist = torch.distributions.categorical.Categorical(logits=logits)
        else:
            sample_dist = torch.distributions.categorical.Categorical(logits=logits*logits_weight)
        sampled_ids = sample_dist.sample()

        # Keep already-known tokens; only masked positions take the new samples.
        unknown_map = (cur_ids == mask_token_id)
        sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)

        if step == num_iter - 1:
            # Final step: the image is decoded from sampled_ids, which re-masking and
            # refinement never modify, so skip that work.
            break

        mask_lens = np.int_(np.floor(unknown_number_in_the_beginning * mask_ratios))

        if use_moment:
            if is_random:
                selected_probs = torch.ones_like(sampled_ids)
            else:
                selected_probs = torch.logsumexp(log_probs*logits_weight, -1) # moment computation
            selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
            maskings = mask_by_random_topk(mask_lens, selected_probs, 1., is_log=True, multiple_stages=True)
        else:
            if is_random:
                selected_probs = torch.ones_like(sampled_ids)
            else:
                # Confidence = probability of the token that was actually sampled.
                probs = torch.nn.functional.softmax(logits, dim=-1)
                selected_probs = torch.squeeze(
                    torch.gather(probs, dim=-1, index=torch.unsqueeze(sampled_ids, -1)), -1)
            selected_probs = torch.where(unknown_map, selected_probs.double(), _CONFIDENCE_OF_KNOWN_TOKENS).float()
            # Sample masking tokens for next iteration
            maskings = mask_by_random_topk(mask_lens, selected_probs, choice_temperature * (1 - ratios[0]), multiple_stages=True)

        # Start from the earliest sub-step, which has the largest mask set (maskings[-1]).
        # The loop below then reveals tokens sub-step by sub-step until only maskings[0]
        # remains masked.
        token_indices = torch.where(maskings[-1], mask_token_id, sampled_ids)

        for i in range(num_substeps - 1):
            # Positions to re-predict: those revealed over this and the previous sub-step
            # (relative to the tokens unknown at the start of the step when i == 0).
            if i == 0:
                update_indices_bool = unknown_map & (~maskings[-i-2])
            else:
                update_indices_bool = maskings[-i] & (~maskings[-i-2])

            update_indices_list = [torch.where(update_indices_bool[b])[0] for b in range(bsz)]
            max_indices = max(len(indices) for indices in update_indices_list)
            if max_indices == 0:
                # Nothing to refine; token_indices would be left unchanged by this sub-step.
                continue

            # Right-pad the per-row position lists with 0 to a rectangular (bsz, max_indices) index.
            update_indices = torch.zeros(bsz, max_indices, dtype=torch.long, device=token_indices.device)
            for b, indices in enumerate(update_indices_list):
                update_indices[b, :len(indices)] = indices

            _token_indices = torch.cat(
                [torch.zeros(token_indices.size(0), 1).cuda(device=token_indices.device), token_indices], dim=1)
            _token_indices[:, 0] = model.fake_class_label
            _token_indices = _token_indices.long()
            refined_logits = model.inference(_token_indices, update_indices=update_indices)
            refined_logits = torch.nn.functional.log_softmax(refined_logits, dim=-1)

            if guidance_w > 0:
                # guidance on refined logits, relative to the full-pass prediction at the same positions
                batch_indices = torch.arange(bsz, device=refined_logits.device).unsqueeze(1).expand(-1, refined_logits.shape[1])
                base_logits = log_probs[batch_indices, update_indices]
                refined_logits = (1.+guidance_w) * (refined_logits - base_logits) + base_logits

            if unbiased:
                refined_sampled_ids = torch.distributions.categorical.Categorical(logits=refined_logits).sample()
            else:
                refined_sampled_ids = torch.distributions.categorical.Categorical(logits=refined_logits*logits_next_weights[-i-1]).sample()

            # Scatter refined samples back; drop the entries that correspond to padding.
            _sampled_ids = sampled_ids.clone()
            for b, indices in enumerate(update_indices_list):
                if len(indices) > 0:
                    _sampled_ids[b, indices] = refined_sampled_ids[b, :len(indices)]

            # Reveal only the positions that leave the mask at this sub-step.
            token_indices = torch.where(update_indices_bool & maskings[-i-1], _sampled_ids, token_indices)

    # vqgan visualization
    z_q = model.vqgan.quantize.get_codebook_entry(sampled_ids, shape=(bsz, 16, 16, codebook_emb_dim))
    gen_images = model.vqgan.decode(z_q)
    return gen_images

# Command-line options. This script is run directly, so --help is enabled. A checkpoint
# is mandatory: it is loaded unconditionally when the model is built.
parser = argparse.ArgumentParser('MAGE generation', add_help=True)
parser.add_argument('--temp', default=4.5, type=float,
                    help='sampling temperature')
parser.add_argument('--num_iter', default=12, type=int,
                    help='number of iterations for generation')
parser.add_argument('--batch_size', default=32, type=int,
                    help='batch size for generation')
parser.add_argument('--num_images', default=50000, type=int,
                    help='number of images to generate')
parser.add_argument('--ckpt', type=str, required=True,
                    help='checkpoint')
parser.add_argument('--model', default='mage_vit_base_patch16', type=str,
                    help='model')
parser.add_argument('--output_dir', default='output_dir/fid/gen/mage-vitb', type=str,
                    help='name')
# added
parser.add_argument('--random', action='store_true', help='ignore confidence')
parser.add_argument('--moment', action='store_true', help='use moment sampler')
# --cache > 0 selects the KV-cached sampler; the value is its number of sub-steps per step.
parser.add_argument('--cache', default=0, type=int, help='caching steps')
parser.add_argument('--guidance', default=0, type=float, help='guidance scale')
parser.add_argument('--unbiased', action='store_true', help='no-temperature sampling')
parser.add_argument('--halton', action='store_true', help='use Halton scheduler')
# --debug writes into a fixed "debug" sub-folder instead of a config-derived name.
parser.add_argument('--debug', action='store_true')

args = parser.parse_args()

vqgan_ckpt_path = 'vqgan_jax_strongaug.ckpt'

# The KV-cached sampler needs the modified model implementation; all other samplers use
# the stock MAGE model. Both builders receive identical constructor arguments.
model = (modified_mage if args.cache > 0 else models_mage).__dict__[args.model](
    norm_pix_loss=False,
    mask_ratio_mu=0.55,
    mask_ratio_std=0.25,
    mask_ratio_min=0.0,
    mask_ratio_max=1.0,
    vqgan_ckpt_path=vqgan_ckpt_path,
)
model.to(0)

# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(args.ckpt, map_location='cpu')
model.load_state_dict(checkpoint['model'])
model.eval()

# Ceiling division: exactly enough batches to cover num_images. The previous
# "// + 1" produced a trailing batch that saved nothing when num_images was a multiple
# of batch_size.
num_steps = -(-args.num_images // args.batch_size)
folder_name = "temp{}-iter{}".format(args.temp, args.num_iter)
if args.random:
    folder_name = folder_name + "-random"
if args.halton:
    folder_name = folder_name + "-halton"
if args.moment:
    folder_name = folder_name + "-moment"
if args.unbiased:
    folder_name = folder_name + "-unbiased"
if args.cache > 0:
    folder_name = folder_name + "-cache{}".format(args.cache)
if args.guidance > 0:
    folder_name = folder_name + "-w{}".format(args.guidance)

if args.debug:
    folder_name = "debug"

save_folder = os.path.join(args.output_dir, folder_name)
os.makedirs(save_folder, exist_ok=True)

# Set to True to report per-batch latency; the first batch is treated as warm-up.
measure_time = False
elapsed_times = []

for i in tqdm(range(num_steps)):
    if measure_time:
        start_time = time.time()
    with torch.no_grad():
        if args.cache > 0:
            gen_images_batch = gen_image_cache(model, bsz=args.batch_size, seed=i+200000, choice_temperature=args.temp, num_iter=args.num_iter, is_random=args.random, moment=args.moment, num_substeps=args.cache, guidance_w=args.guidance, unbiased=args.unbiased, halton=args.halton)
        else:
            gen_images_batch = gen_image(model, bsz=args.batch_size, seed=i+200000, choice_temperature=args.temp, num_iter=args.num_iter, is_random=args.random, moment=args.moment, unbiased=args.unbiased, halton=args.halton)

    if measure_time:
        # CUDA kernels run asynchronously; wait for them so the timing covers the real work.
        torch.cuda.synchronize()
        elapsed_times.append(time.time() - start_time)

    # Batches are written to disk immediately and not retained, keeping host memory flat.
    gen_images_batch = gen_images_batch.detach().cpu()

    # save img (the last batch may be only partially needed)
    first_index = i * args.batch_size
    n_valid = min(args.batch_size, args.num_images - first_index)
    for b_id in range(n_valid):
        gen_img = np.clip(gen_images_batch[b_id].numpy().transpose([1, 2, 0]) * 255, 0, 255)
        gen_img = gen_img.astype(np.uint8)
        save_image(os.path.join(save_folder, '{}.png'.format(str(first_index + b_id).zfill(5))), gen_img)

if measure_time:
    etimes = elapsed_times[1:]
    if etimes:
        print(f"Latency per batch: {np.mean(etimes):.4f} ± {np.std(etimes):.4f} seconds")
    else:
        print("Latency per batch: need at least two batches (the first one is warm-up).")
