import torch
import math
import matplotlib.pyplot as plt
from torchvision import transforms
import PL_distribution as PL

def perm_list_to_mat(pi):
    """
    Convert permutations given as index lists into permutation matrices.

    Args:
        pi: index tensor of shape [batch_shape, n]

    Returns:
        float tensor of shape [batch_shape, n, n] whose row i is row
        pi[..., i] of the n x n identity matrix.
    """
    n = pi.size(-1)
    identity = torch.eye(n, device=pi.device)
    # Advanced indexing selects one identity row per entry of pi.
    rows = identity[pi]
    return rows.float()

def permute_image_perm_list(perm_list, x):
    """
    Reorder the chunks of a batch of chunked images according to a permutation.

    Output chunk i is input chunk perm_list[..., i].

    Args:
        perm_list: [batch_shape, n]
        x: [batch_shape, n, c, h, w]

    Returns:
        Permuted set of image chunks, shape [batch_shape, n, c, h, w].
    """
    # Append singleton c, h, w axes so the indices broadcast over every pixel.
    index = perm_list[..., None, None, None]
    chunks, index = torch.broadcast_tensors(x, index)
    # Gather along the chunk axis (4th from the end).
    return chunks.gather(-4, index)

def permute_int_list(perm_list, x):
    """
    Reorder x along its last axis by perm_list and cast the result to int64.

    Output element i is x[..., perm_list[..., i]].

    Args:
        perm_list: [batch_shape, n]
        x: [batch_shape, n]

    Returns:
        long tensor of shape [batch_shape, n]
    """
    # Broadcast first so a single x (or a single permutation) can be shared
    # across the batch.
    values, index = torch.broadcast_tensors(x, perm_list)
    permuted = values.gather(-1, index)
    return permuted.long()

def permute_list(perm_list, x):
    """
    Reorder x along its last axis by perm_list, keeping the dtype of x.

    Output element i is x[..., perm_list[..., i]].

    Args:
        perm_list: [batch_shape, n]
        x: [batch_shape, n]

    Returns:
        shape [batch_shape, n]
    """
    # Broadcast first so a single x (or a single permutation) can be shared
    # across the batch.
    values, index = torch.broadcast_tensors(x, perm_list)
    return values.gather(-1, index)

def permute_embd(perm_list, x):
    """
    Reorder a sequence of embedding vectors according to a permutation.

    Output vector i is x[..., perm_list[..., i], :].

    Args:
        perm_list: [batch_shape, n]
        x: [batch_shape, n, d]

    Returns:
        shape [batch_shape, n, d]
    """
    # A trailing singleton axis lets each index broadcast across the d features.
    index = perm_list[..., None]
    embeddings, index = torch.broadcast_tensors(x, index)
    return embeddings.gather(-2, index)

@torch.no_grad()
def insert_back_to_idx(x, idx):
    """
    Move the last element of each row to position idx, shifting the elements
    that were at positions idx..n-2 one step to the right.

    Example: x = [a, b, c, d], idx = 1  ->  [a, d, b, c]

    The result is assembled with torch.where only (the same approach as
    insert_back_to_idx_images) rather than a boolean-mask assignment, so x and
    idx merely need to be broadcastable, e.g. a scalar idx can be shared by a
    whole batch of rows.

    Args:
        x: shape [batch_shape, n]
        idx: shape [batch_shape], target positions, expected in [0, n - 1]

    Returns:
        shape [batch_shape, n]; x itself is not modified.
    """
    positions = torch.arange(x.size(-1), device=x.device)
    target = idx.unsqueeze(-1)  # [batch_shape, 1], broadcasts against positions

    # shifted[..., j] == x[..., j - 1]: every element pushed one slot right.
    shifted = torch.roll(x, shifts=1, dims=-1)
    last = x[..., -1:]  # kept as a length-1 axis so it broadcasts over n

    # Positions after the target take the shifted values, the target position
    # takes the former last element, everything before it is left untouched.
    moved = torch.where(positions > target, shifted, x)
    return torch.where(positions == target, last, moved)

@torch.no_grad()
def insert_idx_to_back(x, idx):
    """
    Move the element at position idx of each row to the end, shifting the
    elements that followed it one step to the left.

    Example: x = [a, b, c, d], idx = 1  ->  [a, c, d, b]

    Args:
        x: shape [batch_shape, n]
        idx: shape [batch_shape]

    Returns:
        shape [batch_shape, n]; x itself is not modified.
    """
    target = idx.unsqueeze(-1)
    positions = torch.arange(x.size(-1)).to(x.device)

    # Remember the element being moved before the shift is applied.
    picked = torch.gather(x, -1, target).squeeze(-1)

    # From the target position onwards every slot takes its right neighbour.
    shifted_left = torch.roll(x, shifts=-1, dims=-1)
    out = torch.where(positions >= target, shifted_left, x)

    # The final slot currently holds a wrapped-around value; overwrite it.
    out[..., -1] = picked
    return out

@torch.no_grad()
def insert_back_to_idx_images(x, idx):
    """
    Move the last image chunk to position idx, shifting the chunks that were
    at positions idx..n-2 one step towards the end.

    Args:
        x: shape [b, n, c, h, w]
        idx: shape [b]

    Returns:
        shape [b, n, c, h, w]; x itself is not modified.
    """
    positions = torch.arange(x.size(-4)).to(x.device)  # [n]
    target = idx.unsqueeze(-1)  # [b, 1]

    # Both masks get trailing singleton c, h, w axes -> [b, n, 1, 1, 1].
    at_or_after = (positions >= target)[..., None, None, None]
    exactly_at = (positions == target)[..., None, None, None]

    last_chunk = x[..., [-1], :, :, :]  # [b, 1, c, h, w]
    shifted = torch.where(at_or_after, torch.roll(x, shifts=1, dims=-4), x)

    # The target slot receives the former last chunk.
    return torch.where(exactly_at, last_chunk, shifted)

@torch.no_grad()
def insert_idx_to_back_images(x, idx):
    """
    Move the image chunk at position idx to the end, shifting the chunks that
    followed it one step towards the front.

    Args:
        x: shape [batch_shape, n, c, h, w]
        idx: shape [batch_shape]

    Returns:
        shape [batch_shape, n, c, h, w]; x itself is not modified.
    """
    positions = torch.arange(x.size(-4)).to(x.device)
    at_or_after = (positions >= idx.unsqueeze(-1))[..., None, None, None]  # [b, n, 1, 1, 1]

    # From the target position onwards every slot takes the next chunk.
    out = torch.where(at_or_after, torch.roll(x, shifts=-1, dims=-4), x)

    # Index of the chunk to move, expanded over c, h, w: [b, 1, c, h, w].
    expand_sizes = (-1,) * (x.dim() - 3) + tuple(x.shape[-3:])
    gather_index = idx[..., None, None, None, None].expand(*expand_sizes)
    picked = torch.gather(x, -4, gather_index)  # [b, 1, c, h, w]

    # The final slot currently holds a wrapped-around chunk; overwrite it.
    out[..., [-1], :, :, :] = picked
    return out

@torch.no_grad()
def swap_by_idx(x, idx):
    """
    Exchange two elements in each row of x.

    Args:
        x: shape [batch_shape, n]
        idx: shape [batch_shape, 2], the two positions to exchange

    Returns:
        shape [batch_shape, n]; a new tensor, x itself is not modified.
    """
    pos_a = idx[..., :1]
    pos_b = idx[..., 1:2]

    # Read both values from the untouched input before writing anything.
    val_a = torch.gather(x, -1, pos_a)
    val_b = torch.gather(x, -1, pos_b)

    # Out-of-place scatter copies x, so the caller's tensor stays intact.
    return x.scatter(-1, pos_a, val_b).scatter(-1, pos_b, val_a)

@torch.no_grad()
def swap_by_idx_images(x, idx):
    """
    Exchange two image chunks in each batch element of x.

    Args:
        x: shape [batch_shape, n, c, h, w]
        idx: shape [batch_shape, 2], the two chunk positions to exchange

    Returns:
        shape [batch_shape, n, c, h, w]; a new tensor, x itself is not modified.
    """
    # Expand the two positions over the pixel dims: [batch_shape, 2, c, h, w].
    expand_sizes = (-1,) * (x.dim() - 3) + tuple(x.shape[-3:])
    full_idx = idx[..., None, None, None].expand(*expand_sizes)
    pos_a = full_idx[..., [0], :, :, :]
    pos_b = full_idx[..., [1], :, :, :]

    # Read both chunks from the untouched input before writing anything.
    chunk_a = torch.gather(x, -4, pos_a)
    chunk_b = torch.gather(x, -4, pos_b)

    # Out-of-place scatter copies x, so the caller's tensor stays intact.
    return x.scatter(-4, pos_a, chunk_b).scatter(-4, pos_b, chunk_a)

@torch.no_grad()
def complete_range(x, n):
    """
    Extend each row of x with the values of 0..n-1 that it does not contain.

    The missing values are appended after x in ascending order. Each row is
    expected to hold k distinct values from 0..n-1; otherwise the number of
    missing values differs from n - k and the reshape below fails.

    Args:
        x: shape [batch_shape, k]
        n: int
        1 <= k <= n
    Returns:
        shape [batch_shape, n]
    """
    leading_shape = x.shape[:-1]
    num_missing = n - x.size(-1)
    candidates = torch.arange(n, device=x.device)

    # present[..., j] is True when value j occurs somewhere in the row.
    present = (x[..., None] == candidates).any(dim=-2)  # [batch_shape, n]
    absent = ~present

    # masked_select broadcasts candidates against the mask and returns a flat
    # tensor, which is folded back into one row of missing values per batch.
    missing = torch.masked_select(candidates, absent).reshape(*leading_shape, num_missing)
    return torch.cat((x, missing), dim=-1)

@torch.no_grad()
def plot_image(image, num_pieces, save="show"):
    """
    Plot chunked images on a num_pieces x num_pieces grid, filled row by row.

    Args:
        image: tensor shape [num_pieces**2, 1, h, w]
        num_pieces: int, number of chunks per side of the grid
        save: "show" displays the figure via plt.show(); any other value is
            used as the file stem and the figure is written to
            ./demo_images/{save}.png
    """
    # One shared colour range so the chunks are comparable with each other.
    pixel_min = torch.min(image)
    pixel_max = torch.max(image)

    # squeeze=False keeps `axes` a 2-D array even when num_pieces == 1;
    # otherwise plt.subplots returns a bare Axes and the indexing below fails.
    _, axes = plt.subplots(num_pieces, num_pieces, figsize=(4, 4), squeeze=False)
    for idx, piece in enumerate(image):
        row, col = divmod(idx, num_pieces)
        axes[row, col].imshow(piece.cpu().squeeze(), vmin=pixel_min, vmax=pixel_max)
    plt.subplots_adjust(wspace=0, hspace=0)

    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])

    if save == "show":
        plt.show()
    else:
        plt.savefig(f"./demo_images/{save}.png")

@torch.no_grad()
def plot_CIFAR10_image(image, num_pieces, save="show"):
    """
    Plot chunked CIFAR10 images on a num_pieces x num_pieces grid, filled row
    by row, after undoing the per-channel normalisation.

    Args:
        image: tensor shape [num_pieces**2, 3, h, w], normalised with
            mean [0.485, 0.456, 0.406] and std [0.229, 0.224, 0.225]
            (the statistics inverted below)
        num_pieces: int, number of chunks per side of the grid
        save: "show" displays the figure via plt.show(); any other value is
            used as the file stem and the figure is written to
            ./demo_images/{save}.png
    """
    # Inverse of Normalize(mean, std): first multiply by std, then add mean.
    # Built once here rather than once per chunk.
    inv_trans = transforms.Compose([
        transforms.Normalize(mean=[0., 0., 0.],
                             std=[1 / 0.229, 1 / 0.224, 1 / 0.225]),
        transforms.Normalize(mean=[-0.485, -0.456, -0.406],
                             std=[1., 1., 1.]),
    ])

    # squeeze=False keeps `axes` a 2-D array even when num_pieces == 1.
    _, axes = plt.subplots(num_pieces, num_pieces, figsize=(4, 4), squeeze=False)
    for idx, piece in enumerate(image):
        # Un-normalise the chunk itself (the previous code read `img` before
        # it was assigned and then discarded the un-normalised result).
        img = inv_trans(piece.cpu())
        # Transpose the image to have the channels as the last dimension
        img = img.permute(1, 2, 0).squeeze()
        # imshow ignores vmin/vmax for RGB data and expects floats in [0, 1];
        # clamp away tiny overshoots left by the inverse normalisation.
        img = img.clamp(0, 1)
        row, col = divmod(idx, num_pieces)
        axes[row, col].imshow(img)
    plt.subplots_adjust(wspace=0, hspace=0)

    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])

    if save == "show":
        plt.show()
    else:
        plt.savefig(f"./demo_images/{save}.png")

@torch.no_grad()
def plot_image_serial(image, num_pieces, save="show"):
    """
    Plot chunked images one after another in a single column of axes.

    Args:
        image: tensor shape [num_pieces, 1, h, w]
        num_pieces: int, number of axes to create
        save: "show" displays the figure via plt.show(); any other value is
            used as the file stem and the figure is written to
            ./demo_images/{save}.png
    """
    # One shared colour range so the chunks are comparable with each other.
    pixel_min = torch.min(image)
    pixel_max = torch.max(image)

    # squeeze=False always yields a [num_pieces, 1] array of axes; without it
    # num_pieces == 1 returns a bare Axes and the indexing below fails.
    # NOTE(review): nrows=num_pieces stacks the chunks vertically while the
    # figsize is wider than tall — confirm which layout is intended.
    _, axes = plt.subplots(nrows=num_pieces, ncols=1,
                           figsize=(4 * num_pieces, 4), squeeze=False)
    axes = axes[:, 0]
    for idx, piece in enumerate(image):
        # The modulo wraps around if more than num_pieces chunks are passed.
        axes[idx % num_pieces].imshow(piece.cpu().squeeze(), vmin=pixel_min, vmax=pixel_max)
    plt.subplots_adjust(wspace=0, hspace=0)

    for ax in axes:
        ax.set_xticks([])
        ax.set_yticks([])

    if save == "show":
        plt.show()
    else:
        plt.savefig(f"./demo_images/{save}.png")

@torch.no_grad()
def plot_non_chunked_image(image):
    """
    Display a single image.

    Args:
        image: array-like exposing .squeeze(); singleton dimensions are
            dropped before the data is handed to plt.imshow
    """
    squeezed = image.squeeze()
    plt.imshow(squeezed)
    plt.show()

@torch.no_grad()
def find_perm_images(x1, x2):
    """
    Recover the permutation that maps the chunks of x1 onto the chunks of x2
    x1 --perm--> x2, i.e. x2[..., j, :, :, :] == x1[..., perm[..., j], :, :, :]

    Two chunks count as equal when their L1 distance is below 1e-8.

    Args:
        x_1: shape [batch_shape, n, c, h, w]
        x_2: shape [batch_shape, n, c, h, w]

    Returns:
        shape [batch_shape, n]
    """
    # Flatten every chunk to a vector: [batch_shape, n, c*h*w].
    flat1 = x1.flatten(start_dim=-3)
    flat2 = x2.flatten(start_dim=-3)

    # pairwise[..., i, j] is the L1 distance between chunk i of x1 and chunk j of x2.
    pairwise = torch.cdist(flat1, flat2, p=1)
    matches = (pairwise < 1e-8).int()

    # For every chunk j of x2, pick the x1 row index flagged as a match.
    return matches.argmax(dim=-2)  # shape [batch, n]

@torch.no_grad()
def find_perm(x1, x2):
    """
    Recover the permutation that maps x1 onto x2
    x1 --perm--> x2, i.e. x2[..., j] == x1[..., perm[..., j]]

    Values are compared after conversion to float; two entries count as
    equal when their absolute difference is below 1e-8.

    Args:
        x_1: shape [batch_shape, n]
        x_2: shape [batch_shape, n]

    Returns:
        shape [batch_shape, n]
    """
    # Treat every scalar as a 1-d point so cdist can compare all pairs.
    points1 = x1.float()[..., None]
    points2 = x2.float()[..., None]

    # pairwise[..., i, j] == |x1[..., i] - x2[..., j]|
    pairwise = torch.cdist(points1, points2, p=1)
    matches = (pairwise < 1e-8).int()

    # For every entry j of x2, pick the x1 index flagged as a match.
    return matches.argmax(dim=-2)  # shape [batch, n]

@torch.no_grad()
def log_prob_normal_dist_images(x, mean, var=1., no_const=False):
    """
    Computes log p(x) under N(x | mean, var I)

    With D = n * c * h * w the density is

        log p(x) = -D/2 * log(2 * pi) - D/2 * log(var) - ||x - mean||^2 / (2 * var)

    The log(var) term is scaled by D because the covariance var * I has
    determinant var**D.

    Args:
        x: shape [batch_shape, n, c, h, w]
        mean: shape [batch_shape, n, c, h, w]
        var: scalar variance shared by all D dimensions
        no_const: if True, return only the quadratic term
            -||x - mean||^2 / (2 * var) and skip the normalising constant

    Returns:
        shape [batch_shape]
    """
    # Collapse n, c, h, w into one axis of length D.
    x = x.flatten(start_dim=-4)
    mean = mean.flatten(start_dim=-4)
    D = x.size(-1)

    diff = x - mean
    quad = -(diff * diff).sum(-1) / (2 * var)

    if no_const:
        return quad

    log_norm = -0.5 * D * (math.log(2 * math.pi) + math.log(var))
    return log_norm + quad

@torch.no_grad()
def count_rising_sequence(perm):
    """
    Count the maximal runs of non-decreasing consecutive entries in each row,
    computed as the number of descents (an entry smaller than its
    predecessor) plus one.

    Args:
        perm: [batch, n]

    Returns:
        [batch]
    """
    steps = perm.diff(dim=-1)
    descents = steps.lt(0)
    return descents.sum(dim=-1) + 1

@torch.no_grad()
def batch_randperm(batch, n, device=None):
    """
    Draw a batch of permutations of length n by sampling with all-equal scores.

    Args:
        batch: int, number of permutations to draw
        n: int, length of each permutation
        device: optional device for the score tensor passed to PL.sample.
            Defaults to "cuda:0" when CUDA is available and to the CPU
            otherwise, which is the behaviour this function always had.

    Returns:
        the result of PL.sample with its leading axis squeezed away;
        expected shape [batch, n]
    """
    if device is None:
        device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

    # All-zero scores give every item the same weight.
    # NOTE(review): presumably PL.sample draws from a Plackett-Luce
    # distribution, so equal scores yield uniformly random permutations, and
    # the second argument is the number of samples — confirm against
    # PL_distribution.
    uniform_scores = torch.zeros(batch, n, device=device)
    randperms = PL.sample(uniform_scores, 1).squeeze(0)  # [batch, n]
    return randperms
