import os
import torch
import datetime
import numpy as np
from PIL import Image
from pipeline.pipeline_svd_DragAnything_long import StableVideoDiffusionPipeline
from models.Drag3D import DragAnythingSDVModel
from models.unet_spatio_temporal_condition_controlnet import UNetSpatioTemporalConditionControlNetModel
import cv2
import re 
import glob
from os import path
from scipy.ndimage import distance_transform_edt
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.transforms import PILToTensor
import json
from tqdm import tqdm


def save_gifs_side_by_side(batch_output, validation_control_images,output_folder,name = 'none', target_size=(512 , 512),duration=200):
    """Write one animated GIF showing control frames (left) next to outputs (right).

    Frames are paired by position and concatenated in memory. The earlier
    approach wrote each sequence to a temporary GIF and read it back; that
    left the reopened files unclosed when they were deleted, and the GIF
    round trip may collapse consecutive identical frames, which can desync
    the two sequences.

    :param batch_output: Sequence of generated frames (PIL images or CxHxW tensors).
    :param validation_control_images: Sequence of control/condition frames.
    :param output_folder: Existing directory the GIF is written into.
    :param name: Tag embedded in the output filename.
    :param target_size: Size passed to ``validate_and_convert_image`` for every frame.
    :param duration: Per-frame display time of the combined GIF, in milliseconds.
    :return: Path of the written GIF.
    :raises ValueError: If no frame pair could be built from the inputs.
    """

    def to_frames(image_list):
        # Drop entries the converter rejects (it returns None for them).
        converted = [validate_and_convert_image(img, target_size=target_size) for img in image_list]
        return [img for img in converted if img is not None]

    def concat_horizontally(left, right):
        # Canvas is as tall as the taller image; both are top-aligned.
        canvas = Image.new('RGB', (left.width + right.width, max(left.height, right.height)))
        canvas.paste(left, (0, 0))
        canvas.paste(right, (left.width, 0))
        return canvas

    control_frames = to_frames(validation_control_images)
    output_frames = to_frames(batch_output)

    # zip() stops at the shorter sequence, so a length mismatch truncates
    # instead of raising midway through.
    frames = [
        concat_horizontally(control, generated)
        for control, generated in zip(control_frames, output_frames)
    ]
    if not frames:
        raise ValueError("No valid frame pairs available to build the side-by-side GIF")

    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    combined_gif_path = os.path.join(output_folder, f"combined_frames_{name}_{timestamp}.gif")
    frames[0].save(combined_gif_path, save_all=True, append_images=frames[1:], loop=0, duration=duration)

    return combined_gif_path

# Define functions
def validate_and_convert_image(image, target_size=(512 , 512)):
    """Normalize an image-like object to a PIL image of ``target_size``.

    :param image: ``None``, a PIL image, or a 3-D ``torch.Tensor`` whose first
        dimension is 1 or 3. Tensor values are multiplied by 255 and clamped
        to [0, 255], so they are presumably expected in [0, 1].
    :param target_size: Size handed to ``Image.resize``.
    :return: The resized PIL image, or ``None`` (after printing a message)
        when the input is ``None``, a badly shaped tensor, or another type.
    """
    if image is None:
        print("Encountered a None image")
        return None

    if isinstance(image, torch.Tensor):
        # Only CxHxW tensors with 1 or 3 channels are accepted.
        if image.ndim != 3 or image.shape[0] not in (1, 3):
            print(f"Invalid image tensor shape: {image.shape}")
            return None
        if image.shape[0] == 1:
            # Replicate the single channel so the result is 3-channel.
            image = image.repeat(3, 1, 1)
        # detach() so tensors that still track gradients can be exported.
        pixels = image.detach().mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
        image = Image.fromarray(pixels)
    elif not isinstance(image, Image.Image):
        print("Image is not a PIL Image or a PyTorch tensor")
        return None

    # Resize on every path: previously tensor inputs skipped this step and
    # came back at their native size, unlike PIL inputs.
    return image.resize(target_size)

def create_image_grid(images, rows, cols, target_size=(512 , 512)):
    """Tile images row by row onto a single RGB canvas.

    Each entry is passed through ``validate_and_convert_image``; entries it
    rejects (``None`` results) are skipped. The canvas is
    ``cols * width`` by ``rows * height`` where ``(width, height)`` is
    ``target_size``.

    :return: The assembled PIL image, or ``None`` when no entry is usable.
    """
    tiles = []
    for candidate in images:
        tile = validate_and_convert_image(candidate, target_size)
        if tile is not None:
            tiles.append(tile)

    if len(tiles) == 0:
        print("No valid images to create a grid")
        return None

    cell_w, cell_h = target_size
    canvas = Image.new('RGB', size=(cols * cell_w, rows * cell_h))

    # Fill left-to-right, then top-to-bottom.
    for position, tile in enumerate(tiles):
        row, col = divmod(position, cols)
        canvas.paste(tile, box=(col * cell_w, row * cell_h))

    return canvas

def tensor_to_pil(tensor):
    """Convert a channel-first PyTorch tensor to PIL image(s).

    :param tensor: A 4-D tensor (treated as a batch, one image per entry
        along the first dimension) or a single channel-first image tensor.
        The dtype must be one ``Image.fromarray`` accepts — presumably
        uint8; TODO confirm with callers.
    :return: A list of PIL images for a 4-D input, otherwise one PIL image.
    """
    # detach()/cpu() make the export work for CUDA tensors and for tensors
    # that still require grad; plain ``.numpy()`` raises on both.
    array = tensor.detach().cpu().numpy()
    if array.ndim == 4:  # batch of images
        return [Image.fromarray(img.transpose(1, 2, 0)) for img in array]
    # single image: move channels last
    return Image.fromarray(array.transpose(1, 2, 0))

def save_combined_frames(batch_output, validation_images, validation_control_images, output_folder):
    """Save inputs, control frames and outputs as one 3-column PNG grid.

    The three sequences are each converted to flat lists of images and
    concatenated in the order validation images, control images, outputs.
    The grid is written to ``output_folder`` as
    ``combined_frames_<timestamp>.png``. If no grid can be built, a message
    is printed and nothing is written.

    :param batch_output: Generated frames; entries may be tensors, images,
        or lists of images.
    :param validation_images: Input frames, same accepted entry kinds.
    :param validation_control_images: Control frames, same accepted entry kinds.
    :param output_folder: Directory the PNG is written into.
    """

    def to_flat_image_list(items):
        # Tensors go through tensor_to_pil (which may return a list for a
        # batch); any list entry is expanded in place.
        flat = []
        for item in items:
            converted = tensor_to_pil(item) if torch.is_tensor(item) else item
            if isinstance(converted, list):
                flat.extend(converted)
            else:
                flat.append(converted)
        return flat

    # A former eager pre-flattening of batch_output was removed: its result
    # was never used, and it raised TypeError for a flat list of PIL images.
    combined_frames = (
        to_flat_image_list(validation_images)
        + to_flat_image_list(validation_control_images)
        + to_flat_image_list(batch_output)
    )

    # Three columns; enough rows to hold every frame (ceiling division).
    cols = 3
    rows = (len(combined_frames) + cols - 1) // cols

    grid = create_image_grid(combined_frames, rows, cols, target_size=(512, 512))
    if grid is None:
        print("Failed to create image grid")
        return

    timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    output_path = os.path.join(output_folder, f"combined_frames_{timestamp}.png")
    grid.save(output_path)

def load_images_from_folder(folder):
    """Load every image in ``folder`` as RGB, ordered by frame number.

    Files are kept when their lowercase extension is one of a fixed set of
    image extensions. The frame number is the last run of digits in the
    filename, except that a trailing ``0000`` is skipped in favour of the
    run before it; names without digits sort last.

    :param folder: Directory to scan (not recursive).
    :return: List of RGB PIL images in frame order.
    """
    valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"}

    def frame_number(filename):
        matches = re.findall(r'\d+', filename)  # every run of digits in the name
        if not matches:
            return float('inf')  # no digits: sort after all numbered files
        if matches[-1] == '0000' and len(matches) > 1:
            return int(matches[-2])  # trailing '0000' is ignored when another run exists
        return int(matches[-1])

    # Filter before sorting so non-image files never reach the sort key.
    image_files = [
        filename for filename in os.listdir(folder)
        if os.path.splitext(filename)[1].lower() in valid_extensions
    ]

    images = []
    for filename in sorted(image_files, key=frame_number):
        # convert() returns an independent, fully loaded image, so the file
        # handle can be released as soon as the block exits.
        with Image.open(os.path.join(folder, filename)) as img:
            images.append(img.convert('RGB'))

    return images


def infer_model(model, image):
    """Run ``model`` on one image after resizing and normalizing it.

    The image is resized to 196x196, converted to a tensor, normalized with
    mean (0.485, 0.456, 0.406) and std (0.229, 0.224, 0.225), given a
    leading batch dimension and moved to CUDA. The model is then called
    with ``is_training=False``.

    :return: Whatever ``model`` returns — presumably a class-token
        embedding, going by the original variable name; confirm against
        the model.
    """
    preprocess = T.Compose([
        T.Resize((196, 196)),
        T.ToTensor(),
        T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    batch = preprocess(image)
    batch = batch.unsqueeze(0)  # add the batch dimension
    batch = batch.cuda()
    return model(batch, is_training=False)

def find_largest_inner_rectangle_coordinates(mask_gray):
    """Locate the peak of the mask's L2 distance transform.

    Despite the name, no rectangle is computed: the function returns the
    position where the distance transform is largest, together with that
    distance truncated to an int.

    :param mask_gray: Mask array; it is cast to uint8 before the transform.
    :return: ``(maxLoc, radius)`` — ``maxLoc`` as reported by
        ``cv2.minMaxLoc`` and the integer-truncated maximum distance.
    """
    mask_u8 = mask_gray.astype(np.uint8)
    # NOTE(review): the fourth positional argument is kept exactly as in the
    # original call; confirm which parameter of cv2.distanceTransform it binds to.
    distance_map = cv2.distanceTransform(mask_u8, cv2.DIST_L2, 5, cv2.DIST_LABEL_PIXEL)
    extrema = cv2.minMaxLoc(distance_map)
    peak_value, peak_location = extrema[1], extrema[3]

    return peak_location, int(peak_value)


def fill_and_concatenate( a, b):
    """
    Build one copy of ``b`` per value in ``a``, with non-zero cells overwritten.

    For each value in ``a``, a copy of ``b`` is made in which every element
    that is non-zero in ``b`` is replaced by that value; zero elements are
    left untouched. The copies are stacked along a new leading axis.

    :param a: Sequence of fill values (any length).
    :param b: Numpy array of any shape; it is not modified.
    :return: Array of shape ``(len(a),) + b.shape`` with ``b``'s dtype.
    """
    stacked = np.zeros((len(a),) + b.shape, dtype=b.dtype)
    nonzero = b != 0  # same mask for every fill value

    for layer_index, fill in enumerate(a):
        layer = b.copy()
        layer[nonzero] = fill
        stacked[layer_index] = layer

    return stacked

# cloud 
def get_condition(target_size, original_size, args, start_index, frame_num, first_frame=None, is_mask = False):
    """Load a window of condition frames from ``<dataset_root>/<data_name>/sparse_pose_frame``.

    The ``*.png`` files in that directory are sorted by path and the slice
    ``[start_index : start_index + frame_num]`` is loaded. Each frame is
    resized to ``target_size`` and converted from BGR to RGB.

    :param target_size: ``(height, width)`` of the returned frames.
    :param original_size: Unused; kept for interface compatibility.
    :param args: Object providing ``dataset_root`` and ``data_name``.
    :param start_index: Index of the first frame in the sorted file list.
    :param frame_num: Maximum number of frames to load.
    :param first_frame: Unused; kept for interface compatibility.
    :param is_mask: Unused; kept for interface compatibility.
    :return: ``(condition_2D_list, vis_2D_list, name_list)`` — the frames as
        arrays divided by 255, the same frames as PIL images, and the file
        paths as returned by ``glob``.
    :raises OSError: If a listed frame cannot be decoded by OpenCV.
    """
    condition_2D_list = []
    vis_2D_list = []
    name_list = []

    # cv2.resize takes (width, height); target_size is (height, width).
    size = (target_size[1], target_size[0])

    pose_path = path.join(args.dataset_root, args.data_name, 'sparse_pose_frame')
    # The slice already limits the window to frame_num entries.
    png_files = sorted(glob.glob(os.path.join(pose_path, "*.png")))[start_index:start_index+frame_num]

    for frame_path in png_files:
        # glob already returns paths that include pose_path; joining it
        # again duplicated the prefix whenever dataset_root was relative.
        img = cv2.imread(frame_path)
        if img is None:
            # imread signals failure by returning None rather than raising.
            raise OSError(f"Could not read condition frame: {frame_path}")
        img = cv2.resize(img, size)
        # With default flags imread always yields 3-channel BGR, so the
        # former grayscale branch could never run.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        condition_2D_list.append(img/255)
        vis_2D_list.append(Image.fromarray(img))
        name_list.append(frame_path)
    return condition_2D_list, vis_2D_list, name_list



# Channel-order conversion helper
def convert_list_bgra_to_rgba(image_list):
    """
    Swap the first and third bands of every image in a list.

    Images whose mode is 'RGBA' or 'BGRA' are split into four bands and
    re-merged as "RGBA" with bands one and three exchanged (alpha kept).
    Every other image is split into three bands and re-merged as "RGB"
    with the same exchange.

    Parameters:
    image_list (list of PIL.Image.Image): Images to convert.

    Returns:
    list of PIL.Image.Image: New images with the band order swapped.
    """
    swapped_images = []
    for picture in image_list:
        if picture.mode in ('RGBA', 'BGRA'):
            first, second, third, alpha = picture.split()
            swapped = Image.merge("RGBA", (third, second, first, alpha))
        else:
            # Non-alpha path: expects exactly three bands.
            first, second, third = picture.split()
            swapped = Image.merge("RGB", (third, second, first))
        swapped_images.append(swapped)

    return swapped_images

def show_mask(image, masks, random_color=False):
    """Blend randomly coloured masks over an image (40% image, 60% overlay).

    Every overlaid mask gets its own random colour in both modes. With
    ``random_color=False`` only the first two masks are overlaid (the
    behaviour this function has always had); with ``random_color=True``
    every mask is overlaid. The ``random_color=True`` path previously
    raised ``NameError`` because it referenced an undefined variable.

    :param image: Image convertible with ``np.array``; presumably HxWx3.
    :param masks: Sequence of masks, each reshapeable to ``(h, w, 1)`` where
        ``(h, w)`` comes from the first mask. A single 2-D array is also
        accepted and treated as one mask.
    :param random_color: Overlay all masks instead of only the first two.
    :return: The blended image as a uint8 array.
    """
    if isinstance(masks, np.ndarray) and masks.ndim == 2:
        # A bare 2-D array is one mask, not a sequence of rows.
        masks = [masks]

    h, w = masks[0].shape[:2]
    mask_image = 0
    for idx, mask in enumerate(masks):
        if not random_color and idx > 1:
            break  # legacy mode: only masks 0 and 1 are drawn
        color = np.random.random(3) * 255
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1) + mask_image

    blended = np.array(image).copy() * 0.4 + mask_image * 0.6
    # Overlapping masks can push the sum past 255; clip so the uint8 cast
    # does not wrap around.
    return np.clip(blended, 0, 255).astype(np.uint8)


# Main script
if __name__ == "__main__":
    # Generates a long video chunk by chunk: each chunk of up to
    # --frame_number frames is conditioned on the frames returned by
    # get_condition(), and the last frame of a chunk becomes the input
    # image of the next chunk.
    import argparse

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string, including
        "False", into True.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0", ""):
            return False
        raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")

    parser = argparse.ArgumentParser(description="")
    parser.add_argument("--pretrained_model_name_or_path", type=str, default="stabilityai/stable-video-diffusion-img2vid", help="Path to the pretrained model or model name.")
    parser.add_argument("--dataset_root", type=str, default="/users/zeyuzhu/dataset_project/Datasets/fallowshow/datasets")
    parser.add_argument("--height", type=int, default=320, help="Height of the output video frames.")
    parser.add_argument("--width", type=int, default=576, help="Width of the output video frames.")
    parser.add_argument("--frame_number", type=int, default=16, help="Number of frames to generate.")
    parser.add_argument("--output_root", type=str, default="./Validation", help="Output directory.")

    parser.add_argument("--Motion3D", type=str, default="./model_out/ControlSVD-2024.12.9_reference_motion/checkpoint-12000/controlnet", help="Path to Motion3D checkpoint.")
    parser.add_argument("--data_name", type=str, default="fallowshow_1_000290to000632", help="Name of the dataset.")
    # Accepts "--using_gt_ff", "--using_gt_ff True" and "--using_gt_ff False".
    parser.add_argument("--using_gt_ff", type=_str2bool, nargs="?", const=True, default=False,
                        help="Start each following chunk from the dataset frame instead of the generated frame.")
    args = parser.parse_args()

    ## Load and set up the pipeline
    print('loading')
    controlnet = DragAnythingSDVModel.from_pretrained(args.Motion3D)
    unet = UNetSpatioTemporalConditionControlNetModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
    pipeline = StableVideoDiffusionPipeline.from_pretrained(args.pretrained_model_name_or_path, controlnet=controlnet, unet=unet)
    pipeline.enable_model_cpu_offload()

    ## create output dir
    # NOTE(review): the output name is built from the 3rd and 4th
    # '/'-separated components of --Motion3D, which matches the layout of
    # the default value; other path layouts may raise IndexError.
    motion_parts = args.Motion3D.split('/')
    output_dir = path.join(args.output_root, motion_parts[2]+'_'+motion_parts[3]+'_reference')
    data_name = args.data_name
    if args.using_gt_ff:
        save_root = path.join(output_dir, data_name+'_using_gt_ff')
    else:
        save_root = path.join(output_dir, data_name)

    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(save_root, exist_ok=True)
    os.makedirs(path.join(save_root, 'frames'), exist_ok=True)
    os.makedirs(path.join(save_root, 'videos'), exist_ok=True)

    # The number of entries in the dataset 'frames' directory defines how
    # many frames are generated in total.
    video_length = len(os.listdir(path.join(args.dataset_root, data_name, 'frames')))
    validation_image = Image.open(os.path.join(args.dataset_root, data_name, 'frames', 'output_0001.png'))
    width, height = validation_image.size  # original size, reused for the preview GIFs
    validation_image = validation_image.resize((args.width, args.height))
    # The reference image stays fixed for every chunk; validation_image is
    # replaced after each chunk.
    reference_image = validation_image

    print("video_length", video_length)
    for start_index in tqdm(range(0, video_length, args.frame_number)):
        # The final chunk may be shorter than --frame_number.
        frame_num = min(args.frame_number, video_length-start_index)
        condition_2D_list, vis_2D_list, name_list = get_condition(
            target_size=(args.height, args.width),
            original_size=(height, width),
            start_index=start_index,
            frame_num=frame_num,
            args=args,
            first_frame=validation_image
        )
        # Stack the frames and move channels to dimension 1.
        condition_2D_list = torch.tensor(np.array(condition_2D_list))
        condition_2D_list = condition_2D_list.permute(0, 3, 1, 2)
        video_frames = pipeline(
            validation_image,
            condition_2D_list[:frame_num],
            decode_chunk_size=8,
            num_frames=frame_num,
            motion_bucket_id=180,
            controlnet_cond_scale=1.0,
            height=args.height,
            width=args.width,
            motion_frame=validation_image,
            reference_image=reference_image).frames
        video_frames = [img for sublist in video_frames for img in sublist]

        # Save each generated frame under the filename of its condition frame.
        last_saved_name = None
        for frame, name in zip(video_frames, name_list):
            frame_name = os.path.basename(name)
            frame.save(path.join(save_root, 'frames', frame_name))
            last_saved_name = frame_name
        save_gifs_side_by_side(video_frames, vis_2D_list[:args.frame_number], path.join(save_root, 'videos'), target_size=(width, height), duration=110)

        ## for next forward
        if last_saved_name is None:
            # Nothing was written for this chunk, so there is no frame to continue from.
            break
        # Continue from the frame written last in this chunk. Listing the
        # output directory instead could pick up a stale file left behind
        # by an earlier run.
        if args.using_gt_ff:
            validation_image = Image.open(path.join(args.dataset_root, data_name, 'frames', last_saved_name))
        else:
            validation_image = Image.open(path.join(save_root, 'frames', last_saved_name))
        validation_image = validation_image.resize((args.width, args.height))