import numpy as np
import torch
import random
import matplotlib.pyplot as plt


def divide_sequence(source, time_steps):
    """
    Splits a sequence into consecutive, non-overlapping chunks along axis 0.

    Parameters:
    -----------
    source : ndarray
        A 3D array that is sliced along its first axis.

    time_steps : int
        The number of leading-axis entries in each chunk.

    Returns:
    --------
    chunks : list
        A list of slices of `source`, each with `time_steps` entries along
        axis 0. Trailing entries that do not fill a complete chunk are dropped.
    """
    chunk_count = int(len(source) / time_steps)

    chunks = []
    for k in range(chunk_count):
        start = time_steps * k
        stop = time_steps * (k + 1)
        chunks.append(source[start:stop, :, :])
    return chunks


def create_patches(clip, patch_size):
    """
    Cuts every frame of a clip into a grid of non-overlapping square tiles.

    Parameters:
    -----------
    clip : ndarray
        A 3D array of shape (T, H, W): T frames of height H and width W.
        Both H and W must be multiples of patch_size.

    patch_size : int
        Side length of each square tile.

    Returns:
    --------
    out : ndarray
        A float32 array of shape (num_patches, T, patch_size, patch_size) with
        num_patches = (H / patch_size) * (W / patch_size). Tiles are indexed
        row by row: left to right within a row, rows from top to bottom.

    Raises:
    -------
    AssertionError:
        If H or W is not a multiple of patch_size.
    """
    n_frames = clip.shape[0]
    height = clip.shape[1]
    width = clip.shape[2]
    assert height % patch_size == 0
    assert width % patch_size == 0
    tile_rows = int(height / patch_size)
    tile_cols = int(width / patch_size)

    # Split H into (tile_rows, patch_size) and W into (tile_cols, patch_size),
    # then move the two grid axes to the front so each (row, col) pair
    # addresses one tile across all frames.
    grid = np.asarray(clip).reshape(n_frames, tile_rows, patch_size, tile_cols, patch_size)
    tiles = grid.transpose(1, 3, 0, 2, 4)

    out = np.empty((tile_rows * tile_cols, n_frames, patch_size, patch_size), dtype='float32')
    # Collapsing (row, col) in C order yields the flat index row * tile_cols + col.
    out[...] = tiles.reshape(out.shape)
    return out

def assemble_patches(patches, patch_size, sx, sy, overlap=0):
    """
    Reconstructs a sequence of images from their patch representation.

    Parameters:
    -----------
    patches : ndarray
        Either a 4D array of shape (timesteps, num_patches, patch_size, patch_size),
        or a 3D array of shape (timesteps, num_patches, patch_size * patch_size)
        whose last axis holds each patch flattened in row-major order (it is
        reshaped back to square patches before assembly). The patches should be
        ordered spatially from left to right, top to bottom.

    patch_size : int
        The size of each square patch (patch_size × patch_size).

    sx : int
        The height of the original image. Must be divisible by patch_size when overlap=0.

    sy : int
        The width of the original image. Must be divisible by patch_size when overlap=0.

    overlap : int, default=0
        The number of pixels that patches overlap. Currently only supports overlap=0.

    Returns:
    --------
    output : ndarray
        A float32 array of shape (timesteps, sx, sy) containing the reconstructed
        image sequence.

    Raises:
    -------
    AssertionError:
        If sx or sy is not divisible by patch_size when overlap=0.
    NotImplementedError:
        If overlap is not 0, as this functionality is not yet implemented.
        (NotImplementedError is a subclass of Exception, so callers that
        catch Exception keep working.)
    """
    if overlap != 0:
        raise NotImplementedError('Overlap !=0 not yet implemented')

    # Accept flattened patches (last axis of length patch_size**2) and restore
    # them to square patches.
    if len(patches.shape) == 3 and patches.shape[2] == patch_size * patch_size:
        patches = patches.reshape(patches.shape[0], patches.shape[1], patch_size, patch_size)

    assert sx % patch_size == 0, 'sx must be divisible by patch_size'
    assert sy % patch_size == 0, 'sy must be divisible by patch_size'
    nx, ny = sx // patch_size, sy // patch_size
    timesteps = patches.shape[0]

    output = np.empty((timesteps, sx, sy), dtype='float32')
    for i in range(nx):
        rows = slice(i * patch_size, (i + 1) * patch_size)
        for j in range(ny):
            cols = slice(j * patch_size, (j + 1) * patch_size)
            # Patch (i, j) lives at flat index i*ny + j; copy it for every
            # timestep in a single slice assignment.
            output[:, rows, cols] = patches[:, i * ny + j, ...]

    return output


# Util for plotting the reconstruction of a clip for a given model.
def plot_clip_reconstruction(data_loader, model, device, patch_size=16):
    """
    Plots the reconstruction of a randomly selected clip after passing through a model.

    This function:
    1. Selects a random clip from the data loader's training set
    2. Converts the clip into patches
    3. Passes the patches through the provided model (without gradient tracking)
    4. Reassembles the patches into full frames
    5. Displays the reconstructed frames

    Parameters:
    -----------
    data_loader : object
        A data loader object exposing a `train_clips` sequence; each clip is
        indexed as (T, H, W).

    model : torch.nn.Module
        The model to evaluate for reconstruction quality.
        Expected to take input of shape (T, num_patches, patch_size*patch_size)
        and return output of the same shape.

    device : torch.device
        The device (CPU/GPU) the patch tensor is moved to before calling the model.

    patch_size : int, default=16
        The size of each square patch (patch_size × patch_size).

    Returns:
    --------
    None
        Displays a matplotlib figure with the reconstructed frames, five per
        row, using as many rows as the clip length requires.
    """
    # Randomly select a clip
    clip_index = random.randint(0, len(data_loader.train_clips) - 1)
    data = data_loader.train_clips[clip_index]
    num_frames = data.shape[0]

    # Display the shape of the original data
    print(f"Original data shape: {data.shape}")  # (T, H, W)

    # create_patches returns patches-first data: (P, T, patch_size, patch_size),
    # where P is the number of patches and T the number of timesteps.
    patched = create_patches(data, patch_size)
    print(f"Patched data shape: {patched.shape}")  # (P, T, patch_size, patch_size)
    patched_tensor = torch.tensor(patched).to(device)

    # Swap to time-first order and flatten each patch:
    # (P, T, p, p) -> (T, P, p, p) -> (T, P, p*p)
    reshaped = patched_tensor.permute(1, 0, 2, 3).reshape(num_frames, -1, patch_size * patch_size)

    # Pass through the model; this is visualisation only, so skip autograd bookkeeping
    with torch.no_grad():
        output = model(reshaped)

    # Recover the original image
    output = output.reshape(num_frames, -1, patch_size, patch_size)
    assembled_output = assemble_patches(output.cpu().detach().numpy(), patch_size, sx=data.shape[1], sy=data.shape[2])

    # Display the shape of the assembled output
    print(f"Assembled output shape: {assembled_output.shape}")

    # Plot the reconstructed frames. The grid grows with the clip length so
    # clips longer than 20 frames no longer overflow the axes array; a
    # 20-frame clip still yields the 4x5, 20x16-inch layout.
    n_cols = 5
    n_rows = max(1, (num_frames + n_cols - 1) // n_cols)
    _, axes = plt.subplots(n_rows, n_cols, figsize=(20, 4 * n_rows), squeeze=False)
    axes = axes.flatten()

    for i, ax in enumerate(axes):
        # Turn every axis off, including unused trailing cells, so no empty
        # frames are drawn.
        ax.axis('off')
        if i < num_frames:
            ax.imshow(assembled_output[i], cmap='gray')
            ax.set_title(f'Frame {i+1}')

    plt.tight_layout()
    plt.show()