import cv2
import logging
logger = logging.getLogger(__name__)
import os.path as osp
import numpy as np
from PIL import Image
import PIL.Image as PImage
from torchvision.transforms.functional import to_tensor
import torch
import torch.nn.functional as F
from infinity.utils.dynamic_resolution import dynamic_resolution_h_w
from torchvision.transforms.functional import rotate
from torchvision.transforms import InterpolationMode
import random
def save_images(img_list, save_path):
    """Write every image of ``img_list`` to disk next to ``save_path``.

    The index of each image is inserted between the file stem and the
    extension of ``save_path`` (``out.png`` -> ``out0.png``, ``out1.png``...).

    Args:
        img_list: Iterable of images exposing ``.cpu().numpy()``.
        save_path (str): Path template; its extension is kept for every file.
    """
    # splitext keeps extensions of any length intact (".jpeg", ".tiff", none).
    root, ext = osp.splitext(save_path)
    last_written = None
    for idx, img in enumerate(img_list):
        tmp_path = f"{root}{idx}{ext}"
        if cv2.imwrite(tmp_path, img.cpu().numpy()):
            last_written = tmp_path
        else:
            logger.warning(f"Failed saving image at {osp.abspath(tmp_path)}")
    # Nothing to report when the list was empty or every write failed.
    if last_written is not None:
        logger.info(f".. saved last image to {osp.abspath(last_written)}")


def save_single_image(img_maybe_batch, save_path):
    """Save one image, or every image of a batch, with ``cv2.imwrite``.

    Args:
        img_maybe_batch: Object exposing ``.shape`` and ``.cpu().numpy()``.
            An input with 2 or 3 dimensions is written as a single image;
            anything else is treated as a batch indexed along dimension 0.
        save_path (str | list | tuple): One path, or one path per image.

    Raises:
        IndexError: If fewer paths than images are supplied (raised when the
            first image without a path is reached).
    """
    paths = list(save_path) if isinstance(save_path, (list, tuple)) else [save_path]

    # Decide on the rank alone: comparing against shape[0] would mistake a
    # 3-D image whose first dimension is 1 for a batch of one.
    if len(img_maybe_batch.shape) in (2, 3):
        images = [img_maybe_batch]
    else:
        images = [img_maybe_batch[i] for i in range(img_maybe_batch.shape[0])]

    for i, img in enumerate(images):
        target = paths[i]
        written = cv2.imwrite(target, img.cpu().numpy())
        if written:
            logger.info(f"Save to {osp.abspath(target)}")
        else:
            logger.warning(f"Failed saving image at {osp.abspath(target)}")

def get_stripped_delta(delta):
    """Return ``delta`` as a dot-free string for use in file names.

    All decimal points are removed. If the value contained a decimal point
    and the result ends in "0", that single trailing zero is dropped, so
    ``20.0 -> "20"`` while ``0.4 -> "04"``. Values without a decimal point
    (for example the integer ``20``) are returned unchanged, so they cannot
    collide with a different number.

    Args:
        delta: Number (or anything with a ``str`` representation).

    Returns:
        str: The compact representation.
    """
    text = str(delta)
    stripped_delta = text.replace('.', '')
    # Only a zero that originates from a fractional part may be removed.
    if '.' in text and stripped_delta.endswith('0'):
        stripped_delta = stripped_delta[:-1]
    return stripped_delta


def transform(pil_img, tgt_h, tgt_w):
    """Resize an image to cover the target size, center-crop it and rescale it.

    The image is resized with LANCZOS resampling so that one side matches the
    target and the other is at least as large (up to integer truncation),
    then the central ``tgt_h`` x ``tgt_w`` window is converted with
    ``to_tensor`` and mapped through ``2 * x - 1``.

    Args:
        pil_img: Image exposing ``.size`` as ``(width, height)`` and ``.resize``.
        tgt_h (int): Height of the crop.
        tgt_w (int): Width of the crop.

    Returns:
        The tensor produced by ``to_tensor`` for the crop, doubled and shifted
        by -1.
    """
    src_w, src_h = pil_img.size
    src_aspect = src_w / src_h

    if src_aspect > tgt_w / tgt_h:
        # Source is relatively wider: fit the height, let the width overflow.
        new_h = tgt_h
        new_w = int(src_aspect * tgt_h)
    else:
        # Source is relatively taller (or equal): fit the width instead.
        new_w = tgt_w
        new_h = int(tgt_w / src_aspect)

    resized = pil_img.resize((new_w, new_h), resample=PImage.LANCZOS)

    # Cut the central window out of the resized pixels.
    pixels = np.array(resized)
    top = (pixels.shape[0] - tgt_h) // 2
    left = (pixels.shape[1] - tgt_w) // 2
    window = pixels[top: top + tgt_h, left: left + tgt_w]

    tensor = to_tensor(window)
    # x + x - 1 == 2 * x - 1
    return tensor.add(tensor).add_(-1)

def joint_vi_vae_encode_decode(vae, image_or_path, scale_schedule, device, tgt_h, tgt_w, watermark_transform = None, apply_spatial_patchify=False):
    """Encode an image with ``vae.encode`` and return its per-scale bit indices.

    Args:
        vae: Object whose ``encode(x, scale_schedule=...)`` returns six values.
        image_or_path: A path string (opened and converted to RGB), a list of
            paths/images (paths are opened, other entries kept as they are),
            or an already loaded image. The result is handed to ``transform``
            unchanged, i.e. a list is forwarded as a list.
        scale_schedule: Iterable of ``(t, h, w)`` entries.
        device: Device the transformed input is moved to.
        tgt_h (int): Target height passed to ``transform``.
        tgt_w (int): Target width passed to ``transform``.
        watermark_transform: Not used by this function.
        apply_spatial_patchify (bool): If set, the spatial entries of the
            schedule are doubled before encoding and every returned bit-index
            tensor is folded with ``pixel_unshuffle(…, 2)``.

    Returns:
        tuple: ``(x, x, all_bit_indices, img_embedding)`` where ``x`` is the
        fifth value returned by ``vae.encode`` (the original code returned the
        throw-away variable ``_`` twice; its meaning is not visible here).
    """
    if apply_spatial_patchify:
        scale_schedule = [(pt, 2 * ph, 2 * pw) for pt, ph, pw in scale_schedule]

    if type(image_or_path) == str:
        pil_image = Image.open(image_or_path).convert('RGB')
    elif type(image_or_path) == list:
        pil_image = [
            Image.open(entry).convert('RGB') if type(entry) == str else entry
            for entry in image_or_path
        ]
    else:
        pil_image = image_or_path
    inp = transform(pil_image, tgt_h, tgt_w)

    # Normalise every schedule entry to a plain 3-tuple.
    scale_schedule = [(entry[0], entry[1], entry[2]) for entry in scale_schedule]

    encoded = vae.encode(inp.unsqueeze(0).to(device), scale_schedule=scale_schedule)
    img_embedding, _z, _third, all_bit_indices, fifth_output, _infinity_input = encoded

    if apply_spatial_patchify:  # patchify operation
        for scale_no, bits in enumerate(all_bit_indices):
            # presumably [B, 1, h, w, d] -> [B, d, h, w] (shapes per the original comments)
            channels_first = bits.squeeze(1).permute(0, 3, 1, 2)
            folded = F.pixel_unshuffle(channels_first, 2)  # [B, 4d, h//2, w//2]
            # back to channels-last and restore the squeezed dimension
            all_bit_indices[scale_no] = folded.permute(0, 2, 3, 1).unsqueeze(1)
    return fifth_output, fifth_output, all_bit_indices, img_embedding

def decode_codes(summed_codes, vae):
    """Decode latent codes with ``vae.decode`` into a uint8 image tensor.

    The decoder output is mapped from [-1, 1] to [0, 1], clamped, scaled to
    [0, 255], moved to channels-last layout and its channel order is reversed
    (last dimension flipped). A leading batch dimension of size 1 is removed.

    Args:
        summed_codes (torch.Tensor): Codes; dimension -3 is squeezed before
            decoding (a no-op unless that dimension has size 1).
        vae: Object exposing ``decode(codes)`` that returns a channels-first
            4-D tensor.

    Returns:
        torch.Tensor: ``uint8`` tensor, channels-last.
    """
    img = vae.decode(summed_codes.squeeze(-3))
    # Clamp before the integer cast: out-of-range floats would otherwise wrap
    # around (or be undefined) when converted to uint8.
    img = ((img + 1) / 2).clamp_(0, 1)
    img = img.permute(0, 2, 3, 1).mul_(255).to(torch.uint8).flip(dims=(3,)).squeeze(0)
    return img
        

def compose_scales_to_image(args, bit_encoding, vae, scale_schedule, save_scale_img_path=""):
    """Accumulate the codes of all scales into one summed code tensor.

    For each scale the bit indices are reshaped to the grid given by the
    schedule, optionally un-patchified, converted to codes through
    ``vae.quantizer.lfq.indices_to_codes`` and interpolated to the size of the
    last schedule entry before being added to the running sum.

    Args:
        args: Namespace providing ``apply_spatial_patchify``.
        bit_encoding: Iterable with one bit-index tensor per scale.
        vae: Model providing ``quantizer.lfq.indices_to_codes`` and
            ``quantizer.z_interplote_up``.
        scale_schedule: Sequence of ``(t, h, w)`` entries, indexed per scale.
        save_scale_img_path: Acts only as a flag; when truthy, the running sum
            and the current scale are decoded and saved as
            ``remove_noise{idx}.png`` / ``remove_noise_code{idx}.png``.

    Returns:
        The summed codes with dimension -3 squeezed.
    """
    summed_codes = 0
    for scale_no, scale_bits in enumerate(bit_encoding):
        grid = scale_schedule[scale_no]
        bits = scale_bits.reshape(1, grid[1], grid[2], -1)
        if args.apply_spatial_patchify:  # unpatchify operation
            # [B, h, w, 4d] -> [B, 4d, h, w] -> [B, d, 2h, 2w] -> [B, 2h, 2w, d]
            bits = F.pixel_shuffle(bits.permute(0, 3, 1, 2), 2).permute(0, 2, 3, 1)

        # [B, d, 1, h, w] or [B, d, 1, 2h, 2w] according to the original comment
        codes = vae.quantizer.lfq.indices_to_codes(bits.unsqueeze(1), label_type='bit_label')
        codes = F.interpolate(codes, size=scale_schedule[-1], mode=vae.quantizer.z_interplote_up)
        summed_codes += codes

        if not save_scale_img_path:
            continue
        # Debug output: image of the running sum, then of this scale alone.
        save_single_image(decode_codes(summed_codes.squeeze(-3), vae), f"remove_noise{scale_no}.png")
        save_single_image(decode_codes(codes.squeeze(-3), vae), f"remove_noise_code{scale_no}.png")
    return summed_codes.squeeze(-3)

def decompose_image_into_scales(args,img_path:str, vae, img_transform):
    """Encode an image into per-scale bits, rebuild it and write debug images.

    The schedule and pixel size are looked up in ``dynamic_resolution_h_w``
    using ``args.h_div_w_template`` and ``args.pn``. The reconstruction is
    saved as ``tmp.png`` and the first value returned by the encoder helper
    is written to ``remove_gt.png``.

    Args:
        args: Namespace with ``h_div_w_template``, ``pn`` and
            ``apply_spatial_patchify``.
        img_path (str): Image handed to ``joint_vi_vae_encode_decode``.
        vae: Model used for encoding and decoding.
        img_transform: When different from ``[]``, the image is encoded a
            second time with the same arguments and the two encodings are
            compared on scales ``range(13)``.
    """
    resolution = dynamic_resolution_h_w[args.h_div_w_template][args.pn]
    scale_schedule = [(1, h, w) for (_, h, w) in resolution["scales"]]
    tgt_h, tgt_w = resolution["pixel"]

    def encode_once():
        # Both encodings use identical arguments, including an empty transform list.
        return joint_vi_vae_encode_decode(
            vae, img_path, scale_schedule, "cuda", tgt_h, tgt_w, [], apply_spatial_patchify=args.apply_spatial_patchify
        )

    gt_img, _, encoding_bit_indices, _ = encode_once()
    has_transform = img_transform != []
    if has_transform:
        _, _, no_trans_encoding_bit_indices, _ = encode_once()

    summed_codes = compose_scales_to_image(args, encoding_bit_indices, vae, scale_schedule)
    save_single_image(decode_codes(summed_codes, vae), 'tmp.png')

    if has_transform:
        count_match_after_reencoding(no_trans_encoding_bit_indices, encoding_bit_indices, watermark_scales=range(13))

    cv2.imwrite("remove_gt.png", gt_img)
    
def count_match_after_reencoding(
    encoding_bit_indices, gen_bit_indices: list[torch.Tensor], watermark_scales: list[int], count_bit_match:bool = True, compare_only_on_watermarked_scales:bool = True
) -> tuple[dict, np.ndarray, np.ndarray]:
    """Count how much bit or token information is lost after reencoding

    Args:
        encoding_bit_indices (list[torch.Tensor]): Quantized residuals (bits) after encoding
        gen_bit_indices (list[torch.Tensor]): Quantized residuals (bits) after generation before decoding
        watermark_scales (list[int]): Which scales have been watermarked
        count_bit_match (bool): Specify if matches should be checked per token or per bit
        compare_only_on_watermarked_scales (bool): If the matches should be counted only on scales that have been watermarked
    Returns:
        tuple[dict, np.ndarray, np.ndarray]: A per-scale report keyed ``scale_{i}`` (``i`` counts the
        compared scales, i.e. positions in the filtered lists), the number of matches per scale and
        the total number of bits or tokens per scale.
    """
    if compare_only_on_watermarked_scales:
        encoding_bit_indices = [encoding_bit_indices[i] for i in watermark_scales]
        gen_bit_indices = [gen_bit_indices[i] for i in watermark_scales]

    def count_equal(pair, idx):
        # Number of equal elements between both tensors of the pair at ``idx``.
        return int(torch.eq(pair[0][idx], pair[1][idx]).sum().cpu().item())

    num_matches_list = []
    num_total_list = []
    ret = {}
    for i, (ebi, gbi) in enumerate(zip(encoding_bit_indices, gen_bit_indices)):
        indices_shape = ebi.shape
        token_size = gbi.shape[-1]  # bits per token = size of the last dimension
        flat_ebi = ebi.reshape(-1)
        scales, token, bit_of_tokens = analyse_bits_of_different_levels(
            data=(flat_ebi, gbi.reshape(-1)), function=count_equal, token_size=token_size, seq_len=flat_ebi.shape[-1]
        )
        ret[f"scale_{i}"] = {"scale": {"match_reencoding": scales}, "token": {"match_reencoding": token}, "bit_n_of_tokens": {"match_reencoding":bit_of_tokens} }
        if count_bit_match: # counts matches per bit
            num_total = indices_shape[2] * indices_shape[3] * indices_shape[4] # num of bits in this scale
            num_matches = scales
        else: # counts matches per token
            num_total = indices_shape[2] * indices_shape[3] # num of tokens in this scale
            # A token matches only if every one of its bits matches.
            num_matches = int(torch.eq(ebi, gbi).all(dim=-1).sum().cpu().item())
        num_total_list.append(num_total)
        num_matches_list.append(num_matches)
    num_total_list = np.array(num_total_list)
    num_matches_list = np.array(num_matches_list)
    for scale in range(len(num_total_list)):
        logger.info(f"{num_matches_list[scale]}/{num_total_list[scale]}, {num_matches_list[scale]/num_total_list[scale]*100:.2f}%")
    return ret, num_matches_list, num_total_list

def count_frequency_ones_zeros(encoding_bit_indices, gen_bit_indices: list[torch.Tensor], watermark_scales: list[int], count_bit_match:bool = True, compare_only_on_watermarked_scales:bool = True
):
    """Report the share of ones and zeros per scale for encoded and generated bits.

    Args:
        encoding_bit_indices: One tensor per scale with at least five
            dimensions; dimensions 2, 3 and 4 give the bit count of the scale.
        gen_bit_indices (list[torch.Tensor]): Generated bits, paired with the
            encoded ones scale by scale.
        watermark_scales (list[int]): Not used by this function.
        count_bit_match (bool): Not used by this function.
        compare_only_on_watermarked_scales (bool): Not used by this function.

    Returns:
        dict: One ``scale_{i}`` entry per scale holding the four frequencies,
        plus ``encoding_frequency_ones`` and ``generated_frequency_ones``,
        the bit-weighted averages over all scales.
    """
    ret = {}
    weighted_ones_enc = []
    weighted_ones_gen = []
    full_num_bits = 0

    for scale_no, (ebi, gbi) in enumerate(zip(encoding_bit_indices, gen_bit_indices)):
        dims = ebi.shape
        bits_in_scale = dims[2] * dims[3] * dims[4]
        full_num_bits += bits_in_scale

        ones_gen = torch.sum(gbi) / bits_in_scale
        ones_enc = torch.sum(ebi) / bits_in_scale
        ones_gen_value = ones_gen.cpu().item()
        ones_enc_value = ones_enc.cpu().item()

        ret[f"scale_{scale_no}"] = {
            "frequency_ones_gen": ones_gen_value,
            "frequency_zeroes_gen": (1 - ones_gen).cpu().item(),
            "frequency_ones_enc": ones_enc_value,
            "frequency_zeros_enc": (1 - ones_enc).cpu().item(),
        }
        # Weight by the scale size so the global average is per bit.
        weighted_ones_gen.append(ones_gen_value * bits_in_scale)
        weighted_ones_enc.append(ones_enc_value * bits_in_scale)

    ret["encoding_frequency_ones"] = (np.sum(weighted_ones_enc) / full_num_bits).item()
    ret["generated_frequency_ones"] = (np.sum(weighted_ones_gen) / full_num_bits).item()
    return ret


def analyse_bits_of_different_levels(data, function, token_size, seq_len):
    """Apply ``function`` to the whole sequence and return it with two empty lists.

    Does not work on batches.

    Args:
        data: Passed through to ``function`` unchanged.
        function: Callable invoked as ``function(data, slice(None))``.
        token_size: Used only to derive the (currently unused) token count.
        seq_len: Used only to derive the (currently unused) token count.

    Returns:
        tuple: ``(scale_result, [], [])`` - the per-token and per-bit results
        are not computed here and are always empty lists.
    """
    # Derived but not consumed; kept so invalid sizes still fail here.
    token_count = int(seq_len / token_size)

    whole_sequence = slice(None)
    scale_result = function(data, whole_sequence)
    return scale_result, [], []

def calc_entropy(data):
    """Return, per batch item, the mean ``entr`` value of the first softmax class.

    A softmax is taken over the last dimension, ``torch.special.entr`` is
    applied to entry 0 of that dimension, and the result is divided by the
    size of dimension -2 and summed over dimension 1.

    Args:
        data (torch.Tensor): 3-D tensor (indexed as ``[:, :, 0]``).

    Returns:
        list: One float per element of dimension 0.
    """
    probabilities = F.softmax(data, dim=-1, dtype=None)
    first_class = probabilities[:, :, 0]
    scaled = torch.special.entr(first_class) / probabilities.shape[-2]
    per_item = scaled.sum(dim=1)
    return per_item.cpu().numpy().tolist()
    
def count_consecutive_sequences(array):
    """Return the lengths of the runs of equal consecutive values.

    Example: ``[1, 1, 0, 0, 0, 1] -> [2, 3, 1]``.

    Args:
        array: 1-D sequence of numbers (list or array).

    Returns:
        list[int]: Run lengths in order of appearance; an empty list for
        empty input, so the return type is the same for every input.
    """
    if len(array) == 0:
        return []

    # Positions where a new run starts (value differs from its predecessor).
    change_indices = np.where(np.diff(array) != 0)[0] + 1
    # Add the start and the end so every run is delimited by two boundaries.
    boundaries = np.concatenate(([0], change_indices, [len(array)]))

    # The distance between neighbouring boundaries is the run length.
    return np.diff(boundaries).tolist()

def rotate_features(encoded_bits, angle = 90):
    """Rotate every scale of ``encoded_bits`` channel by channel, in place.

    Each entry has its two leading dimensions squeezed, is converted to float
    and rotated per channel with bilinear interpolation; the rounded result
    replaces the entry in the list.

    Args:
        encoded_bits (list): One tensor per scale; presumably ``[1, 1, H, W, D]``
            (after squeezing, a 3-D ``H, W, D`` tensor is required).
        angle: Rotation angle handed to ``rotate``.

    Returns:
        list: The same list object, whose entries are now ``(1, H, W, D)``
        float tensors.
    """
    for scale_no, bits in enumerate(encoded_bits):
        # (H, W, D) -> (1, D, H, W)
        channels_first = bits.squeeze(0).squeeze(0).permute(2, 0, 1).unsqueeze(0).float()

        rotated_channels = []
        for channel in channels_first[0]:
            turned = rotate(channel.unsqueeze(0), angle, interpolation=InterpolationMode.BILINEAR)
            rotated_channels.append(turned.squeeze(0))

        stacked = torch.stack(rotated_channels).unsqueeze(0)
        # Back to channels-last; rounding snaps interpolated values to integers.
        encoded_bits[scale_no] = torch.round(stacked.permute(0, 2, 3, 1))
    return encoded_bits

def set_seeds(seed):
    """Seed the torch, ``random`` and NumPy global random generators.

    Args:
        seed (int): Seed value handed to all three libraries.
    """
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

def count_matching_n_bit_sequences(a, b, max_n=32):
    """Count, for every window length up to ``max_n``, the matching windows.

    A window of length ``n`` matches when ``a`` and ``b`` agree on all ``n``
    consecutive positions. Windows slide with a stride of 1.

    Args:
        a: 1-D tensor, or a list of tensors that are flattened and
            concatenated.
        b: Same layout as ``a`` (also concatenated when ``a`` is a list).
        max_n (int): Largest window length to evaluate.

    Returns:
        list[int]: ``max_n`` counts; entry ``n - 1`` belongs to window length
        ``n``. Lengths longer than the sequence yield 0 because no such
        window exists.
    """
    if isinstance(a, list):
        a = torch.cat([t.reshape(-1) for t in a], dim=0)
        b = torch.cat([t.reshape(-1) for t in b], dim=0)

    # Compare once; a window matches iff all of its positions are equal.
    equal = a == b
    length = equal.shape[0]

    res = []
    for n in range(1, max_n + 1):
        if n > length:
            # unfold would raise for windows longer than the sequence.
            res.append(0)
            continue
        # Rolling windows over the equality mask: shape (length - n + 1, n)
        windows = equal.unfold(0, n, 1)
        res.append(windows.all(dim=1).sum().item())
    return res
