import PIL.Image
import numpy as np
import torch
import os
import pickle
import sys
from PIL import Image

sys.path.append('../')
from torch_utils import misc
from torch_utils import training_stats
from torch_utils.ops import conv2d_gradfix
from torch_utils.ops import grid_sample_gradfix

def save_image_grid(img, fname, drange, grid_size):
    '''
    Tile a batch of images into a single grid image and save it to ``fname``.

    This is a thin wrapper around ``get_image_grid``, which performs the value
    mapping and tiling. Keeping one implementation means the saved and returned
    grids can never drift apart.

    Args:
        img: Array-like of shape [N, C, H, W] with N == gw * gh and C in {1, 3}.
        fname (str): Output image path; the format is chosen by PIL from the
            extension.
        drange (tuple): (lo, hi) value range mapped linearly onto [0, 255].
        grid_size (tuple): (gw, gh), the number of grid columns and rows.

    Raises:
        AssertionError: If C is not 1 or 3 (raised by ``get_image_grid``).
    '''
    get_image_grid(img, drange, grid_size).save(fname)


def get_image_grid(img, drange, grid_size):
    """
    Tile a batch of images into one grid and return it as a PIL image.

    Args:
        img: Array-like of shape [N, C, H, W]; it is converted to float32.
            N must equal gw * gh.
        drange (tuple): (lo, hi) input value range. It is mapped linearly onto
            [0, 255], then rounded to the nearest integer and clipped.
        grid_size (tuple): (gw, gh), the number of grid columns and rows.

    Returns:
        PIL.Image.Image: A mode 'L' image when C == 1, or a mode 'RGB' image
        when C == 3.

    Raises:
        AssertionError: If C is not 1 or 3.
    """
    low, high = drange
    pixels = np.asarray(img, dtype=np.float32)
    pixels = (pixels - low) * (255 / (high - low))
    pixels = np.rint(pixels).clip(0, 255).astype(np.uint8)

    cols, rows = grid_size
    _, channels, height, width = pixels.shape
    # [rows*cols, C, H, W] -> [rows, H, cols, W, C] -> [rows*H, cols*W, C]
    grid = (
        pixels.reshape(rows, cols, channels, height, width)
        .transpose(0, 3, 1, 4, 2)
        .reshape(rows * height, cols * width, channels)
    )

    assert channels in [1, 3]
    if channels == 1:
        return PIL.Image.fromarray(grid[:, :, 0], 'L')
    if channels == 3:
        return PIL.Image.fromarray(grid, 'RGB')


def create_gif_from_directory(frame_folder, gif_filename, seq_len, fps, interval_after=0):
    '''
    Build a looping GIF from the image files in a directory.

    Frames are read in sorted filename order. Only files ending in
    .png/.jpg/.jpeg (case-insensitive) are considered, and at most ``seq_len``
    of them are used. Files that fail to open are reported and skipped.

    Usage: building a dataset, or generating GIFs from a directory of frames.

    Args:
        frame_folder (str): Directory containing the frame images.
        gif_filename (str): Output path of the GIF.
        seq_len (int): Maximum number of frame files to read.
        fps (int | float): Playback rate. Each frame is shown for
            1000 / fps milliseconds.
        interval_after (int | float): Length in milliseconds of a white pause
            appended after the frames. It is only applied when at least 8
            frames were loaded; 0 disables it.

    Returns:
        None. Prints a message and returns early (writing nothing) when no
        frame files exist or none could be loaded.
    '''
    valid_ext = ('.png', '.jpg', '.jpeg')
    frame_files = sorted(f for f in os.listdir(frame_folder) if f.lower().endswith(valid_ext))

    if not frame_files:
        print("No frame files found, please check the folder path.")
        return

    frames = []
    for frame_file in frame_files[:seq_len]:
        frame_path = os.path.join(frame_folder, frame_file)
        try:
            # copy() forces a full load, so the file handle can be closed
            # immediately instead of leaking until garbage collection.
            with Image.open(frame_path) as im:
                frames.append(im.copy())
        except (OSError, Image.DecompressionBombError) as e:
            print(f"Unable to open frame file {frame_path}: {e}")
    if not frames:
        print("No frame files could be loaded, please check file format and path.")
        return

    frame_duration = 1000 / fps

    if len(frames) >= 8 and interval_after > 0:
        # Pad with white frames covering roughly interval_after ms. int() keeps
        # list multiplication valid when fps or interval_after is a float.
        blank_frame = Image.new('RGB', frames[0].size, (255, 255, 255))
        frames = frames + [blank_frame] * int(interval_after * fps // 1000)

    # Write everything in a single pass. A second write to the same path
    # would overwrite the file rather than append to it.
    frames[0].save(
        gif_filename,
        save_all=True,
        append_images=frames[1:],
        duration=frame_duration,
        loop=0,
    )

def split_images_from_grid(image, width=256, height=256, vid_num=8, vid_len=8):
    """
    Cut a grid image into equally sized tiles, row by row.

    Row ``r`` of the grid is treated as one video, and column ``c`` as its
    c-th frame. Tiles are returned in row-major order, so the frames of each
    video are contiguous in the result.

    Args:
        image (PIL.Image): Grid image to split; only its ``crop`` method is used.
        width (int): Tile width in pixels.
        height (int): Tile height in pixels.
        vid_num (int): Number of grid rows (videos).
        vid_len (int): Number of grid columns (frames per video).

    Returns:
        list: ``vid_num * vid_len`` cropped tiles.
    """
    return [
        image.crop((
            col * width,             # left
            row * height,            # upper
            col * width + width,     # right
            row * height + height,   # lower
        ))
        for row in range(vid_num)
        for col in range(vid_len)
    ]


def save_videos_as_gifs(small_images, vid_num, vid_len, output_prefix="vid", frame_duration=100, pause_duration=500):
    """
    Group a flat list of frame images into videos and save each as its own GIF.

    Video ``i`` consists of ``small_images[i * vid_len:(i + 1) * vid_len]``.
    Its last frame is repeated ``int(pause_duration / frame_duration)`` times
    to create a pause before the GIF loops.

    Args:
        small_images (list): All frame images, video after video.
        vid_num (int): Number of videos to create.
        vid_len (int): Number of frames per video.
        output_prefix (str): Prefix for the output GIF file names.
        frame_duration (int): Duration of each frame in milliseconds.
        pause_duration (int): Approximate length in milliseconds of the pause
            at the end of each video. It is rounded down to a whole number of
            frames.

    Returns:
        list: Paths to all generated GIF files. With a single video the file is
        ``f"{output_prefix}.gif"``. With several videos the files are
        ``f"{output_prefix}_{i}.gif"``, so they do not overwrite each other.

    Raises:
        ValueError: If ``small_images`` contains no frames for one of the
            videos.
    """
    gif_paths = []
    # Loop-invariant: number of repeated last frames that form the pause.
    pause_frames = int(pause_duration / frame_duration)

    for i in range(vid_num):
        # Copy the slice into a list so extending it never touches the caller's data.
        vid = list(small_images[i * vid_len:(i + 1) * vid_len])
        if not vid:
            raise ValueError(f"No frames available for video {i} (got {len(small_images)} images)")

        # Distinct names per video; keep the historical single-file name when
        # only one video is written.
        gif_path = f"{output_prefix}.gif" if vid_num == 1 else f"{output_prefix}_{i}.gif"

        vid.extend([vid[-1]] * pause_frames)

        vid[0].save(
            gif_path,
            save_all=True,
            append_images=vid[1:],
            duration=frame_duration,
            loop=0
        )
        gif_paths.append(gif_path)

    return gif_paths


def create_combined_gif_from_images(
    small_images, vid_num, vid_len, layout=(2, 4), frame_size=(256, 256),
    output_path="final_output.gif", frame_duration=100, pause_duration=500
):
    """
    Tile several videos side by side into one looping GIF.

    Frame ``t`` of the output shows frame ``t`` of every video, each pasted
    into its own cell. Video ``v`` occupies row ``v // cols`` and column
    ``v % cols``. Frames missing from ``small_images`` leave their cell black.
    The last composed frame is repeated ``int(pause_duration / frame_duration)``
    times to pause before the loop restarts.

    Args:
        small_images (list): Frame images, video after video (``vid_len`` each).
        vid_num (int): Number of videos to place on the canvas.
        vid_len (int): Number of frames in each video.
        layout (tuple): (rows, cols) of the grid, e.g. (2, 4).
        frame_size (tuple): (width, height) of each cell, e.g. (256, 256).
        output_path (str): Path of the combined GIF.
        frame_duration (int): Duration of each frame in milliseconds.
        pause_duration (int): Approximate pause at the end of the loop, in
            milliseconds.

    Returns:
        None
    """
    n_rows, n_cols = layout
    tile_w, tile_h = frame_size
    canvas_size = (n_cols * tile_w, n_rows * tile_h)
    available = len(small_images)

    def _compose(t):
        # Build one output frame from frame t of every video.
        canvas = Image.new("RGB", canvas_size)
        for v in range(vid_num):
            src = v * vid_len + t
            if src >= available:
                continue
            row, col = divmod(v, n_cols)
            canvas.paste(small_images[src], (col * tile_w, row * tile_h))
        return canvas

    frames = [_compose(t) for t in range(vid_len)]

    hold = int(pause_duration / frame_duration)
    if hold > 0:
        frames.extend([frames[-1]] * hold)

    frames[0].save(
        output_path,
        save_all=True,
        append_images=frames[1:],
        duration=frame_duration,
        loop=0,
    )
    print(f"Combined GIF saved to {output_path}")

def tensor_to_gif(tensor, save_path, loop_pause=2, frame_duration=100):
    """
    Save a video tensor as a looping GIF with a pause at the end of each loop.

    Values are min-max normalised over the whole tensor to [0, 255] and
    rounded. A constant tensor, where min == max, is rendered as black frames
    instead of producing NaNs.

    Args:
        tensor (torch.Tensor): Input tensor of shape [1, channel, seq_len, h, w],
            with channel 1 (grayscale) or 3 (RGB). It may require grad and may
            live on any device.
        save_path (str): Path to save the GIF.
        loop_pause (int | float): Pause between loops in seconds, implemented by
            repeating the last frame.
        frame_duration (int): Duration of each frame in milliseconds.

    Raises:
        AssertionError: If the tensor is not 5-dimensional.
        ValueError: If the batch dimension is not 1, or channel is not 1 or 3.
    """
    assert tensor.ndim == 5, "Tensor must be 5-dimensional [1, channel, seq_len, h, w]"
    if tensor.shape[0] != 1:
        raise ValueError(f"Expected a batch dimension of 1, got {tensor.shape[0]}")

    # detach() allows numpy conversion of tensors that require grad; float()
    # makes integer inputs normalise with true division.
    video = tensor[0].detach().float()   # [channel, seq_len, h, w]
    channel, seq_len = video.shape[0], video.shape[1]
    if channel not in (1, 3):
        raise ValueError("Only inputs with channel 1 (grayscale) or 3 (RGB) are supported")

    lo, hi = video.min(), video.max()
    span = hi - lo
    if span > 0:
        video = (video - lo) / span
    else:
        # Constant input: avoid 0/0 (NaN) and emit black frames.
        video = torch.zeros_like(video)
    # Round rather than truncate, matching np.rint in the grid helpers; move to
    # the CPU once instead of once per frame.
    video = (video * 255).round().to(torch.uint8).cpu()

    if channel == 1:
        frames = [PIL.Image.fromarray(video[0, t].numpy(), mode="L") for t in range(seq_len)]
    else:
        frames = [PIL.Image.fromarray(video[:, t].permute(1, 2, 0).numpy(), mode="RGB") for t in range(seq_len)]

    # int() keeps list multiplication valid for float loop_pause values.
    pause_frames = frames[-1:] * int(loop_pause * 1000 // frame_duration)

    frames[0].save(
        save_path,
        save_all=True,
        append_images=frames[1:] + pause_frames,
        duration=frame_duration,
        loop=0
    )
    print(f"GIF saved at {save_path}")
