import os
import os.path as osp

import torch
import torch.nn.functional as F
import numpy as np


def labels2image(all_indices, label_type='int_label', scale_schedule=None):
    summed_codes, recons_imgs = self.vae.decode_from_indices(all_indices, scale_schedule, label_type)
    recons_img = recons_imgs[0]
    recons_img = (recons_img + 1) / 2
    recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1]
    return recons_img

def features2image(raw_features):
    recons_imgs = self.vae.decode(raw_features.squeeze(-3))
    recons_img = recons_imgs[0]
    recons_img = (recons_img + 1) / 2
    recons_img = recons_img.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)[:,:,::-1]
    return recons_img

class BitwiseSelfCorrection(object):
    def __init__(self, vae, args):
        self.noise_apply_layers = args.noise_apply_layers
        self.noise_apply_requant = args.noise_apply_requant
        self.noise_apply_strength = args.noise_apply_strength
        self.apply_spatial_patchify = args.apply_spatial_patchify
        self.vae = vae
        self.debug_bsc = args.debug_bsc

    def flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device):
        with torch.amp.autocast('cuda', enabled = False):
            B = raw_features.shape[0]
            if raw_features.dim() == 4:
                codes_out = raw_features.unsqueeze(2)
            else:
                codes_out = raw_features
            cum_var_input = 0
            gt_all_bit_indices = []
            pred_all_bit_indices = []
            x_BLC_wo_prefix = []
            for si, (pt, ph, pw) in enumerate(vae_scale_schedule):
                residual = codes_out - cum_var_input
                if si != len(vae_scale_schedule)-1:
                    residual = F.interpolate(residual, size=vae_scale_schedule[si], mode=self.vae.quantizer.z_interplote_down).contiguous()
                quantized, _, bit_indices, loss = self.vae.quantizer.lfq(residual) # quantized shape: [B, d_vae, 1, h, w], bit_indices shape: [B,1,h,w,d_vae]
                gt_all_bit_indices.append(bit_indices)
                if si < self.noise_apply_layers:
                    noise_apply_strength = np.random.randint(0, 100 * self.noise_apply_strength+1) * 0.01
                    mask = torch.rand(*bit_indices.shape).to(device) < noise_apply_strength
                    pred_bit_indices = bit_indices.clone()
                    pred_bit_indices[mask] = 1 - pred_bit_indices[mask]
                    pred_all_bit_indices.append(pred_bit_indices)
                    if self.noise_apply_requant:
                        quantized = self.vae.quantizer.lfq.indices_to_codes(pred_bit_indices, label_type = 'bit_label')
                else:
                    pred_all_bit_indices.append(bit_indices)
                cum_var_input = cum_var_input + F.interpolate(quantized, size=vae_scale_schedule[-1], mode=self.vae.quantizer.z_interplote_up).contiguous()
                if si < len(vae_scale_schedule)-1:
                    this_scale_input = F.interpolate(cum_var_input, size=vae_scale_schedule[si+1], mode=self.vae.quantizer.z_interplote_up).contiguous()
                    if self.apply_spatial_patchify:
                        # (B,d,1,H,W) -> (B,d,H,W) -> (B,4d,H/2,W/2)
                        this_scale_input = torch.nn.functional.pixel_unshuffle(this_scale_input.squeeze(-3), 2)
                    x_BLC_wo_prefix.append(this_scale_input.reshape(*this_scale_input.shape[:2], -1).permute(0,2,1)) # (B,H/2*W/2,4C) or (B,H*W,C)

            if self.apply_spatial_patchify:
                gt_ms_idx_Bl = []
                for item in gt_all_bit_indices:
                    # item shape: (B,1,H,W,d)
                    item = item.squeeze(1).permute(0,3,1,2) # (B,d,H,W)
                    # (B,d,H,W) -> (B,4d,H/2,W/2)
                    item = torch.nn.functional.pixel_unshuffle(item, 2)
                    # (B,4d,H/2,W/2) -> (B,H/2,W/2,4d) -> (B,H/2*w/2,4d)
                    item = item.permute(0,2,3,1).reshape(B, -1, 4*self.vae.codebook_dim)
                    gt_ms_idx_Bl.append(item)
            else:
                gt_ms_idx_Bl = [item.reshape(B, -1, self.vae.codebook_dim) for item in gt_all_bit_indices]
            x_BLC_wo_prefix = torch.cat(x_BLC_wo_prefix, 1)

            if self.debug_bsc:
                self.visualize(vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices)
        
        return x_BLC_wo_prefix, gt_ms_idx_Bl
    
    def visualize(self, vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices):
        gt_img = (inp_B3HW.squeeze(-3) + 1) / 2 * 255
        gt_img = gt_img[0].permute(1,2,0).cpu().numpy().astype(np.uint8)[:,:,::-1]
        recons_img_2 = labels2image(gt_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
        recons_img_3 = labels2image(pred_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule)
        cat_image = np.concatenate([gt_img, recons_img_2, recons_img_3], axis=1)
        save_path = osp.abspath('non_teacher_force.jpg')
        cv2.imwrite(save_path, cat_image)
        print(f'Save to {save_path}')
        import pdb; pdb.set_trace()
        print(cat_image.shape)
        