import os
import os.path as osp

import torch
import torch.nn.functional as F
import numpy as np
import pdb
import cv2


def labels2image(all_indices, label_type='int_label', scale_schedule=None, vae=None):
    """Decode quantizer labels with ``vae`` and return the first image as a uint8 array.

    Module-level counterpart of ``BitwiseSelfCorrection.labels2image``; since
    there is no ``self`` here, the VAE must be passed explicitly.

    Args:
        all_indices: labels forwarded unchanged to ``vae.decode_from_indices``.
        label_type: label format string forwarded to ``vae.decode_from_indices``.
        scale_schedule: scale schedule forwarded to ``vae.decode_from_indices``.
        vae: object providing ``decode_from_indices(indices, schedule, label_type)``
            that returns ``(summed_codes, images)``.

    Returns:
        ``(H, W, C)`` uint8 array of ``images[0]`` with the channel order reversed
        (presumably RGB -> BGR for OpenCV).

    Raises:
        TypeError: if ``vae`` is not given.
    """
    if vae is None:
        raise TypeError("labels2image() missing required keyword argument: 'vae'")
    _, recons_imgs = vae.decode_from_indices(all_indices, scale_schedule, label_type)
    # Assumes decoder output roughly in [-1, 1]; map to [0, 255] and clamp so
    # overshoot saturates instead of wrapping around in the uint8 cast.
    recons_img = ((recons_imgs[0] + 1) / 2).permute(1, 2, 0).mul(255).clamp(0, 255)
    return recons_img.cpu().numpy().astype(np.uint8)[:, :, ::-1]

def features2image(raw_features, vae=None):
    """Decode continuous VAE features with ``vae`` and return the first image as uint8.

    Module-level counterpart of ``BitwiseSelfCorrection.features2image``; since
    there is no ``self`` here, the VAE must be passed explicitly.

    Args:
        raw_features: feature tensor; dim ``-3`` is squeezed before decoding
            (a no-op unless that dim has size 1).
        vae: object providing ``decode(features)`` that returns a batch of images.

    Returns:
        ``(H, W, C)`` uint8 array of the first decoded image with the channel
        order reversed (presumably RGB -> BGR for OpenCV).

    Raises:
        TypeError: if ``vae`` is not given.
    """
    if vae is None:
        raise TypeError("features2image() missing required keyword argument: 'vae'")
    recons_imgs = vae.decode(raw_features.squeeze(-3))
    # Assumes decoder output roughly in [-1, 1]; clamp so overshoot saturates
    # instead of wrapping around in the uint8 cast.
    recons_img = ((recons_imgs[0] + 1) / 2).permute(1, 2, 0).mul(255).clamp(0, 255)
    return recons_img.cpu().numpy().astype(np.uint8)[:, :, ::-1]

class BitwiseSelfCorrection(object):
    """Turn VAE features into per-scale input tokens and bit-label targets.

    Every ``*_flip_requant`` method runs the same coarse-to-fine residual
    quantization of ``raw_features`` through ``self.vae.quantizer.lfq``. The
    variants differ only in where the running (cumulative, upsampled)
    reconstruction is sampled to build the returned token sequence:

    * ``flip_requant``      -- after scale ``si``, sample at scale ``si + 1``
                               (no block for the last scale). This is the only
                               variant that applies random bit flips.
    * ``my_flip_requant``   -- after scale ``si``, sample at scale ``si``.
    * ``long_flip_requant`` -- like ``flip_requant`` plus one extra block
                               sampled at the last scale.
    * ``flow_flip_requant`` -- like ``my_flip_requant``, and additionally
                               returns ``raw_features`` flattened to ``(B, L, C)``.
    """

    def __init__(self, vae, args):
        # flip_requant flips bits on scales with index < noise_apply_layers.
        self.noise_apply_layers = args.noise_apply_layers
        # If true, flipped bits are re-encoded to codes and added to the running reconstruction.
        self.noise_apply_requant = args.noise_apply_requant
        # Upper bound of the per-scale flip probability (sampled in 1% steps).
        self.noise_apply_strength = args.noise_apply_strength
        # If true, tokens and labels are 2x2 pixel-unshuffled (4x channels, 1/4 length).
        self.apply_spatial_patchify = args.apply_spatial_patchify
        self.vae = vae
        # If true, every *_flip_requant call writes a debug image and enters the debugger.
        self.debug_bsc = args.debug_bsc

    def flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device):
        """Quantize with optional bit-flip noise; return next-scale inputs and labels.

        Args:
            vae_scale_schedule: list of ``(t, h, w)`` sizes, coarsest first. The
                last entry must match the spatial size of ``raw_features``
                (the running reconstruction at that size is subtracted from it).
            inp_B3HW: input image batch; only used when ``debug_bsc`` is set.
            raw_features: ``(B, d, H, W)`` or ``(B, d, 1, H, W)`` features.
            device: device the random flip mask is moved to.

        Returns:
            ``(x_BLC_wo_prefix, gt_ms_idx_Bl)``:

            * ``x_BLC_wo_prefix`` -- token blocks for scales ``1..K-1``,
              concatenated along dim 1. Block ``si`` is the reconstruction
              after scale ``si`` resampled to scale ``si + 1``. A single-scale
              schedule yields an empty ``(B, 0, C)`` tensor.
            * ``gt_ms_idx_Bl`` -- list with, per scale, the ``lfq`` bit labels
              before any flipping, shaped ``(B, L, codebook_dim)``, or
              ``(B, L / 4, 4 * codebook_dim)`` with spatial patchify.
        """
        last = len(vae_scale_schedule) - 1
        with torch.amp.autocast('cuda', enabled=False):
            gt_bits, pred_bits, x_BLC_wo_prefix, gt_ms_idx_Bl = self._build(
                vae_scale_schedule, raw_features, device,
                noise_apply_layers=self.noise_apply_layers,
                target_scale=lambda si: si + 1 if si < last else None,
            )
            if self.debug_bsc:
                self.visualize(vae_scale_schedule, inp_B3HW, gt_bits, pred_bits)
        return x_BLC_wo_prefix, gt_ms_idx_Bl

    def my_flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device):
        """Quantize without noise; return same-scale inputs and labels.

        Arguments are as in ``flip_requant``. Returns ``(x_BLC_w_prefix,
        gt_ms_idx_Bl)``. Here ``x_BLC_w_prefix`` holds one block per scale
        ``0..K-1``: block ``si`` is the reconstruction after scale ``si``,
        sampled at scale ``si``.
        """
        with torch.amp.autocast('cuda', enabled=False):
            gt_bits, pred_bits, x_BLC_w_prefix, gt_ms_idx_Bl = self._build(
                vae_scale_schedule, raw_features, device,
                noise_apply_layers=0,  # noise disabled for this variant
                target_scale=lambda si: si,
            )
            if self.debug_bsc:
                self.visualize(vae_scale_schedule, inp_B3HW, gt_bits, pred_bits)
        return x_BLC_w_prefix, gt_ms_idx_Bl

    def long_flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device):
        """Quantize without noise; return next-scale inputs plus the final reconstruction.

        Arguments are as in ``flip_requant``. Returns ``(x_BLC_w_prefix,
        gt_ms_idx_Bl)``. Here ``x_BLC_w_prefix`` holds the ``K - 1`` blocks
        ``flip_requant`` would produce (without noise), followed by one more
        block: the full reconstruction sampled at the last scale.
        """
        last = len(vae_scale_schedule) - 1
        with torch.amp.autocast('cuda', enabled=False):
            gt_bits, pred_bits, x_BLC_w_prefix, gt_ms_idx_Bl = self._build(
                vae_scale_schedule, raw_features, device,
                noise_apply_layers=0,  # noise disabled for this variant
                target_scale=lambda si: min(si + 1, last),
            )
            if self.debug_bsc:
                self.visualize(vae_scale_schedule, inp_B3HW, gt_bits, pred_bits)
        return x_BLC_w_prefix, gt_ms_idx_Bl

    def flow_flip_requant(self, vae_scale_schedule, inp_B3HW, raw_features, device):
        """Like ``my_flip_requant``, additionally returning the raw features as a sequence.

        Arguments are as in ``flip_requant``. Returns ``(x_BLC_w_prefix,
        gt_ms_idx_Bl, raw_features_seq)``. ``raw_features_seq`` is
        ``raw_features`` with all dims after the channel dim flattened,
        giving ``(B, L, d)``.
        """
        with torch.amp.autocast('cuda', enabled=False):
            gt_bits, pred_bits, x_BLC_w_prefix, gt_ms_idx_Bl = self._build(
                vae_scale_schedule, raw_features, device,
                noise_apply_layers=0,  # noise disabled for this variant
                target_scale=lambda si: si,
            )
            if self.debug_bsc:
                self.flow_visualize(vae_scale_schedule, inp_B3HW, gt_bits, pred_bits, raw_features)
        raw_features_seq = raw_features.reshape(*raw_features.shape[:2], -1).permute(0, 2, 1)
        return x_BLC_w_prefix, gt_ms_idx_Bl, raw_features_seq

    def _build(self, vae_scale_schedule, raw_features, device, noise_apply_layers, target_scale):
        """Run the quantization loop and assemble tokens and flattened labels.

        ``target_scale`` maps a scale index ``si`` to the schedule index at which
        the reconstruction (after scale ``si``) is sampled for a token block, or
        to ``None`` to emit no block for that scale.

        Returns ``(gt_bits, pred_bits, x_BLC, gt_ms_idx_Bl)``.
        """
        # 4-D input gets a singleton time axis at dim 2: (B,d,H,W) -> (B,d,1,H,W).
        codes_out = raw_features.unsqueeze(2) if raw_features.dim() == 4 else raw_features
        gt_bits, pred_bits, tokens = self._residual_quantize(
            vae_scale_schedule, codes_out, device, noise_apply_layers, target_scale)
        gt_ms_idx_Bl = self._flatten_gt_bits(gt_bits, raw_features.shape[0])
        x_BLC = self._cat_tokens(tokens, codes_out)
        return gt_bits, pred_bits, x_BLC, gt_ms_idx_Bl

    def _residual_quantize(self, vae_scale_schedule, codes_out, device, noise_apply_layers, target_scale):
        """Coarse-to-fine residual quantization of ``codes_out``.

        Returns ``(gt_bits, pred_bits, tokens)``:

        * ``gt_bits`` -- per-scale bit indices from ``lfq``.
        * ``pred_bits`` -- the same list with random flips applied on scales
          ``< noise_apply_layers``.
        * ``tokens`` -- the token blocks selected by ``target_scale``.
        """
        quantizer = self.vae.quantizer
        full_size = vae_scale_schedule[-1]
        last = len(vae_scale_schedule) - 1
        cum_var_input = 0
        gt_bits, pred_bits, tokens = [], [], []
        for si, scale_size in enumerate(vae_scale_schedule):
            residual = codes_out - cum_var_input
            if si != last:
                # The last scale is quantized at full resolution without resampling.
                residual = F.interpolate(residual, size=scale_size, mode=quantizer.z_interplote_down).contiguous()
            # quantized shape: [B, d_vae, 1, h, w], bit_indices shape: [B, 1, h, w, d_vae]
            quantized, _, bit_indices, _ = quantizer.lfq(residual)
            gt_bits.append(bit_indices)
            if si < noise_apply_layers:
                # Flip probability drawn uniformly from {0, 0.01, ..., noise_apply_strength}.
                strength = np.random.randint(0, self._noise_percent_high(self.noise_apply_strength)) * 0.01
                mask = torch.rand(*bit_indices.shape).to(device) < strength
                noisy_bits = bit_indices.clone()
                noisy_bits[mask] = 1 - noisy_bits[mask]
                pred_bits.append(noisy_bits)
                if self.noise_apply_requant:
                    # Later residuals are then computed against the corrupted reconstruction.
                    quantized = quantizer.lfq.indices_to_codes(noisy_bits, label_type='bit_label')
            else:
                pred_bits.append(bit_indices)
            cum_var_input = cum_var_input + F.interpolate(
                quantized, size=full_size, mode=quantizer.z_interplote_up).contiguous()
            target = target_scale(si)
            if target is not None:
                tokens.append(self._to_tokens(cum_var_input, vae_scale_schedule[target]))
        return gt_bits, pred_bits, tokens

    @staticmethod
    def _noise_percent_high(strength):
        """Exclusive ``randint`` upper bound, in percent, for a flip probability <= ``strength``.

        ``100 * strength`` is rounded first: e.g. ``100 * 0.29`` evaluates to
        ``28.999999999999996``, which truncation would turn into a 28% maximum.
        """
        return int(round(100 * strength)) + 1

    def _to_tokens(self, cum_var_input, size):
        """Resample the running reconstruction to ``size`` and flatten it to ``(B, L, C)``."""
        x = F.interpolate(cum_var_input, size=size, mode=self.vae.quantizer.z_interplote_up).contiguous()
        if self.apply_spatial_patchify:
            # (B,d,1,H,W) -> (B,d,H,W) -> (B,4d,H/2,W/2)
            x = F.pixel_unshuffle(x.squeeze(-3), 2)
        # -> (B,H/2*W/2,4d) with patchify, else (B,H*W,d)
        return x.reshape(*x.shape[:2], -1).permute(0, 2, 1)

    def _flatten_gt_bits(self, gt_bits, B):
        """Flatten per-scale bit indices to ``(B, L, codebook_dim)`` (or patchified ``4x``)."""
        d = self.vae.codebook_dim
        if not self.apply_spatial_patchify:
            return [item.reshape(B, -1, d) for item in gt_bits]
        flattened = []
        for item in gt_bits:
            # (B,1,H,W,d) -> (B,d,H,W) -> (B,4d,H/2,W/2)
            item = F.pixel_unshuffle(item.squeeze(1).permute(0, 3, 1, 2), 2)
            # (B,4d,H/2,W/2) -> (B,H/2,W/2,4d) -> (B,H/2*W/2,4d)
            flattened.append(item.permute(0, 2, 3, 1).reshape(B, -1, 4 * d))
        return flattened

    def _cat_tokens(self, tokens, codes_out):
        """Concatenate token blocks along dim 1; empty ``(B, 0, C)`` when there are none."""
        if tokens:
            return torch.cat(tokens, 1)
        channels = codes_out.shape[1] * (4 if self.apply_spatial_patchify else 1)
        return codes_out.new_zeros(codes_out.shape[0], 0, channels)

    def visualize(self, vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices):
        """Save ``[input | decode(gt bits) | decode(pred bits)]`` for batch item 0, then break.

        Writes ``gt-gt_indices-pred_indices_new.jpg`` in the current working
        directory and enters the debugger via ``breakpoint()``.
        """
        panels = [
            self._chw_to_bgr_uint8(inp_B3HW.squeeze(-3)[0]),
            self.labels2image(gt_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule),
            self.labels2image(pred_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule),
        ]
        self._save_debug_image('gt-gt_indices-pred_indices_new.jpg', np.concatenate(panels, axis=1))

    def flow_visualize(self, vae_scale_schedule, inp_B3HW, gt_all_bit_indices, pred_all_bit_indices, raw_features):
        """Like ``visualize`` with an extra panel decoded directly from ``raw_features``.

        Writes ``gt-gt_indices-pred_indices-raw_features_new.jpg`` in the
        current working directory and enters the debugger via ``breakpoint()``.
        """
        panels = [
            self._chw_to_bgr_uint8(inp_B3HW.squeeze(-3)[0]),
            self.labels2image(gt_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule),
            self.labels2image(pred_all_bit_indices, label_type='bit_label', scale_schedule=vae_scale_schedule),
            self.features2image(raw_features),
        ]
        self._save_debug_image('gt-gt_indices-pred_indices-raw_features_new.jpg', np.concatenate(panels, axis=1))

    @staticmethod
    def _save_debug_image(filename, cat_image):
        """Write ``cat_image`` to ``filename`` (absolute path printed), then break."""
        save_path = osp.abspath(filename)
        cv2.imwrite(save_path, cat_image)
        print(f'Save to {save_path}')
        print(cat_image.shape)
        # Defaults to pdb.set_trace(); set PYTHONBREAKPOINT=0 to skip the stop.
        breakpoint()

    @staticmethod
    def _chw_to_bgr_uint8(img):
        """Map a ``(C, H, W)`` tensor (assumed in [-1, 1]) to ``(H, W, C)`` uint8, channels reversed.

        Values are clamped to [0, 255] so out-of-range inputs saturate instead of
        wrapping around in the uint8 cast.
        """
        img = ((img + 1) / 2).permute(1, 2, 0).mul(255).clamp(0, 255)
        return img.cpu().numpy().astype(np.uint8)[:, :, ::-1]

    def labels2image(self, all_indices, label_type='int_label', scale_schedule=None):
        """Decode labels with ``self.vae.decode_from_indices``; return image 0 as uint8."""
        _, recons_imgs = self.vae.decode_from_indices(all_indices, scale_schedule, label_type)
        return self._chw_to_bgr_uint8(recons_imgs[0])

    def features2image(self, raw_features):
        """Decode features (dim -3 squeezed) with ``self.vae.decode``; return image 0 as uint8."""
        recons_imgs = self.vae.decode(raw_features.squeeze(-3))
        return self._chw_to_bgr_uint8(recons_imgs[0])
        