import torch
from attrdict import AttrDict
from torch.distributions import StudentT

def img_to_task(img, num_ctx=None,
        max_num_points=None, target_all=False, t_noise=None, conv_mode=False):
    """Turn a batch of images into a context/target point-set task.

    For every batch element an independent random subset of pixels is drawn.
    The first ``num_ctx`` drawn pixels form the context set; the following
    ``num_tar`` pixels form the target set.

    Args:
        img: Tensor of shape (B, C, H, W). It is not modified in place.
        num_ctx: Number of context points. If falsy (None or 0) it is drawn
            uniformly from [3, max_num_points - 3).
        max_num_points: Total context + target budget; defaults to H * W
            when falsy.
        target_all: If True, ``num_tar = max_num_points - num_ctx``;
            otherwise ``num_tar`` is drawn uniformly from
            [3, max_num_points - num_ctx).
        t_noise: Optional scale of additive Student-t (df=2.1) noise. The
            special value -1 draws a per-element scale uniformly in [0, 0.09).
        conv_mode: If True, return full images plus binary pixel masks
            instead of coordinate/value point sets.

    Returns:
        AttrDict with ``x, y, xc, yc, xt, yt``.

        Point mode: ``x*`` hold (row, col) coordinates scaled to [-1, 1],
        shape (B, N, 2). ``y*`` hold the pixel values minus 0.5, shape
        (B, N, C).

        Conv mode: ``y*`` are the (possibly noised) images (B, C, H, W).
        ``x*`` are float masks of the same shape, with 1 at the selected
        pixels.
    """
    B, C, H, W = img.shape
    num_pixels = H*W
    # reshape (not view) so non-contiguous inputs are accepted too.
    img = img.reshape(B, C, -1)

    if t_noise is not None:
        if t_noise == -1:
            t_noise = 0.09 * torch.rand(img.shape, device=img.device, dtype=img.dtype)
        noise = StudentT(2.1).rsample(img.shape).to(device=img.device, dtype=img.dtype)
        # Out-of-place on purpose: an in-place `+=` on a view would write the
        # noise back into the caller's tensor.
        img = img + t_noise * noise

    batch = AttrDict()
    max_num_points = max_num_points or num_pixels
    num_ctx = num_ctx or \
            torch.randint(low=3, high=max_num_points-3, size=[1]).item()
    num_tar = max_num_points - num_ctx if target_all else \
            torch.randint(low=3, high=max_num_points-num_ctx, size=[1]).item()
    num_points = num_ctx + num_tar

    # Independent random permutation of the flat pixel indices per batch
    # element. It is drawn on the image's device, so CUDA is not required.
    idxs = torch.rand(B, num_pixels, device=img.device).argsort(-1)[..., :num_points]
    x1, x2 = idxs // W, idxs % W
    batch.x = torch.stack([
        2*x1.float()/(H-1) - 1,
        2*x2.float()/(W-1) - 1], -1)
    batch.y = torch.gather(img, -1, idxs.unsqueeze(-2).repeat(1, C, 1)) \
            .transpose(-2, -1) - 0.5

    if conv_mode:
        def _mask(sel):
            # (B, C, H, W) float mask with 1 at the flat pixel indices `sel`.
            mask = torch.zeros(B, C, num_pixels, device=img.device)
            mask.scatter_(-1, sel.unsqueeze(1).repeat(1, C, 1), 1.0)
            return mask.reshape(B, C, H, W)

        full = img.reshape(B, C, H, W)
        batch.xc, batch.yc = _mask(idxs[:, :num_ctx]), full
        batch.xt, batch.yt = _mask(idxs[:, num_ctx:]), full
        batch.x, batch.y = _mask(idxs), full

    else:
        batch.xc = batch.x[:, :num_ctx]
        batch.xt = batch.x[:, num_ctx:]
        batch.yc = batch.y[:, :num_ctx]
        batch.yt = batch.y[:, num_ctx:]

    return batch

def video_to_task(video, num_ctx=None,
        max_num_points=None, target_all=False, t_noise=None):
    """Turn a batch of videos into a context/target point-set task.

    For every batch element an independent random subset of voxels is drawn.
    The first ``num_ctx`` drawn voxels form the context set; the following
    ``num_tar`` voxels form the target set.

    Args:
        video: Tensor of shape (B, C, T, H, W). It is not modified in place.
        num_ctx: Number of context points. If falsy (None or 0) it is drawn
            uniformly from [3, max_num_points - 3).
        max_num_points: Total context + target budget; defaults to T * H * W
            when falsy.
        target_all: If True, ``num_tar = max_num_points - num_ctx``;
            otherwise ``num_tar`` is drawn uniformly from
            [3, max_num_points - num_ctx).
        t_noise: Optional scale of additive Student-t (df=2.1) noise. The
            special value -1 draws a per-element scale uniformly in [0, 0.09).

    Returns:
        AttrDict with ``x, y, xc, yc, xt, yt``.

        ``x*`` hold (t, row, col) coordinates scaled to [-1, 1], shape
        (B, N, 3). ``y*`` hold the voxel values minus 0.5, shape (B, N, C).
    """
    B, C, T, H, W = video.shape
    num_pixels = T*H*W
    # reshape (not view) so non-contiguous inputs are accepted too.
    video = video.reshape(B, C, -1)

    if t_noise is not None:
        if t_noise == -1:
            t_noise = 0.09 * torch.rand(video.shape, device=video.device, dtype=video.dtype)
        noise = StudentT(2.1).rsample(video.shape).to(device=video.device, dtype=video.dtype)
        # Out-of-place on purpose: an in-place `+=` on a view would write the
        # noise back into the caller's tensor.
        video = video + t_noise * noise

    batch = AttrDict()
    max_num_points = max_num_points or num_pixels
    num_ctx = num_ctx or \
            torch.randint(low=3, high=max_num_points-3, size=[1]).item()
    num_tar = max_num_points - num_ctx if target_all else \
            torch.randint(low=3, high=max_num_points-num_ctx, size=[1]).item()
    num_points = num_ctx + num_tar

    # Independent random permutation per batch element, drawn on the input's
    # device so CUDA is not required.
    idxs = torch.rand(B, num_pixels, device=video.device).argsort(-1)[..., :num_points]
    # Flat index = t*H*W + h*W + w.
    x1, x2, x3 = (idxs // W) // H, (idxs // W) % H, idxs % W
    batch.x = torch.stack([
        2*x1.float()/(T-1) - 1,
        2*x2.float()/(H-1) - 1,
        2*x3.float()/(W-1) - 1,], -1)
    batch.y = torch.gather(video, -1, idxs.unsqueeze(-2).repeat(1, C, 1)) \
            .transpose(-2, -1) - 0.5

    batch.xc = batch.x[:, :num_ctx]
    batch.xt = batch.x[:, num_ctx:]
    batch.yc = batch.y[:, :num_ctx]
    batch.yt = batch.y[:, num_ctx:]

    return batch

def img_to_task_landmark(img, 
                         labels, 
                         num_ctx=None,
                         max_num_points=None, 
                         target_all=False, 
                         device=None):
    """Build a per-pixel classification task from an image and pixel labels.

    Inputs are (row, col) coordinates in [-1, 1] concatenated with the pixel's
    C channel values. Outputs are the argmax label index of that pixel, as a
    float. One random pixel permutation is drawn and shared by every batch
    element. The first ``num_ctx`` pixels become context and the next
    ``num_tar`` pixels become targets, so the two sets are disjoint.

    Args:
        img: Tensor of shape (B, C, H, W).
        labels: Per-pixel label scores/one-hots. Assumed to be reshapeable to
            (B, H*W, K), e.g. (B, H, W, K) — TODO confirm against callers.
        num_ctx: Number of context points. If falsy it is drawn uniformly
            from [3, max_num_points // 2).
        max_num_points: Total context + target budget; defaults to H * W
            when falsy. Previously None crashed at `None // 2`.
        target_all: If True, ``num_tar = max_num_points - num_ctx``;
            otherwise ``num_tar`` is drawn from [3, max_num_points - num_ctx).
        device: Device for the sampled indices and outputs; defaults to
            ``img.device``.

    Returns:
        AttrDict with:

        - ``xc``/``xt``: (B, N, 2 + C) inputs.
        - ``yc``/``yt``: (B, N, 1) labels.
        - ``x``/``y``: context followed by target along dim 1.
    """
    B, C, H, W = img.shape
    num_pixels = H * W
    device = device or img.device
    max_num_points = max_num_points or num_pixels

    # Move inputs to `device` so gather sees index and source on one device.
    img_flat = img.reshape(B, C, -1).to(device)  # (B, C, H*W)
    labels_flat = labels.reshape(B, -1, labels.shape[-1]).to(device)  # (B, H*W, K) presumably
    num_classes = labels.shape[-1]

    batch = AttrDict()
    num_ctx = num_ctx or torch.randint(low=3, high=max_num_points // 2, size=[1]).item()
    if target_all:
        num_tar = max_num_points - num_ctx
    else:
        num_tar = torch.randint(low=3, high=max_num_points - num_ctx, size=[1]).item()

    # One random pixel permutation, shared across the whole batch.
    idxs = torch.randperm(num_pixels, device=device).unsqueeze(0).repeat(B, 1)

    ctx_idxs = idxs[:, :num_ctx]                   # (B, num_ctx)
    tar_idxs = idxs[:, num_ctx:num_ctx + num_tar]  # (B, num_tar)

    ctx_x1, ctx_x2 = ctx_idxs // W, ctx_idxs % W
    tar_x1, tar_x2 = tar_idxs // W, tar_idxs % W

    ctx_coords = torch.stack([
        2 * ctx_x1.float() / (H - 1) - 1,
        2 * ctx_x2.float() / (W - 1) - 1
    ], dim=-1)  # (B, num_ctx, 2)

    tar_coords = torch.stack([
        2 * tar_x1.float() / (H - 1) - 1,
        2 * tar_x2.float() / (W - 1) - 1
    ], dim=-1)  # (B, num_tar, 2)

    # (B, num_ctx, C)
    ctx_rgb_values = torch.gather(img_flat, -1, ctx_idxs.unsqueeze(1).repeat(1, C, 1)).transpose(1, 2)
    # (B, num_tar, C)
    tar_rgb_values = torch.gather(img_flat, -1, tar_idxs.unsqueeze(1).repeat(1, C, 1)).transpose(1, 2)

    # (B, num_ctx, K)
    ctx_labels = torch.gather(labels_flat, 1, ctx_idxs.unsqueeze(-1).expand(-1, -1, num_classes))
    # (B, num_tar, K)
    tar_labels = torch.gather(labels_flat, 1, tar_idxs.unsqueeze(-1).expand(-1, -1, num_classes))

    # NOTE(review): argmax returns the first maximal index, so a pixel whose
    # label vector is all zeros maps to class 0 — confirm that is intended.
    ctx_labels = ctx_labels.argmax(dim=-1).float()
    tar_labels = tar_labels.argmax(dim=-1).float()

    all_coords = torch.cat([ctx_coords, tar_coords], dim=1)
    all_rgb_values = torch.cat([ctx_rgb_values, tar_rgb_values], dim=1)
    all_labels = torch.cat([ctx_labels, tar_labels], dim=1)

    batch.x = torch.cat([all_coords, all_rgb_values], dim=-1)  # (B, num_ctx + num_tar, 2 + C)
    batch.y = all_labels.unsqueeze(-1)  # (B, num_ctx + num_tar, 1)

    batch.xc = torch.cat([ctx_coords, ctx_rgb_values], dim=-1)
    batch.xt = torch.cat([tar_coords, tar_rgb_values], dim=-1)

    batch.yc = ctx_labels.unsqueeze(-1)
    batch.yt = tar_labels.unsqueeze(-1)

    return batch


def img_to_task_landmark2(
    img, 
    labels, 
    num_ctx=None, 
    max_num_points=None, 
    target_all=False, 
    device=None
):
    """Build a per-pixel classification task with label-balanced sampling.

    For each batch element, about half of the context (and half of the
    target) points are drawn from "labelled" pixels, whose label channels sum
    to a positive value. The rest are drawn from pixels whose label channels
    sum to exactly 0. Each group is capped by how many such pixels exist.
    Context and target are sampled independently, so they may share pixels.

    Args:
        img: Tensor of shape (B, C, H, W).
        labels: Per-pixel label scores/one-hots. Assumed to be reshapeable to
            (B, H*W, K) — TODO confirm against callers.
        num_ctx: Number of context points. If falsy it is drawn uniformly
            from [3, max_num_points // 2).
        max_num_points: Total context + target budget; defaults to H * W
            when falsy. Previously None crashed at `None // 2`.
        target_all: If True, ``num_tar = max_num_points - num_ctx``;
            otherwise ``num_tar`` is drawn from [3, max_num_points - num_ctx).
        device: Device for the sampled indices and outputs; defaults to
            ``img.device``.

    Returns:
        AttrDict with:

        - ``xc``/``xt``: (B, N, 2 + C) inputs, i.e. coordinates in [-1, 1]
          plus pixel values.
        - ``yc``/``yt``: (B, N, 1) argmax label indices as floats.
        - ``x``/``y``: context followed by target along dim 1.

    Raises:
        RuntimeError (from ``torch.stack``): if the per-sample counts differ,
        i.e. some image has too few unlabelled pixels to fill its quota.
    """
    B, C, H, W = img.shape
    num_pixels = H * W
    device = device or img.device
    max_num_points = max_num_points or num_pixels

    # Move inputs to `device` so gather/indexing see a single device.
    img_flat = img.reshape(B, C, -1).to(device)
    labels_flat = labels.reshape(B, -1, labels.shape[-1]).to(device)
    num_classes = labels.shape[-1]

    batch = AttrDict()

    num_ctx = num_ctx or torch.randint(low=3, high=max_num_points // 2, size=[1]).item()
    num_tar = max_num_points - num_ctx if target_all else torch.randint(low=3, high=max_num_points - num_ctx, size=[1]).item()

    ctx_half = num_ctx // 2
    tar_half = num_tar // 2

    ctx_idxs = []
    tar_idxs = []

    for b in range(B):
        # NOTE: pixels whose label sum is negative fall in neither group.
        label_sum = labels_flat[b].sum(dim=-1)
        label_idx = (label_sum > 0).nonzero(as_tuple=True)[0]
        no_label_idx = (label_sum == 0).nonzero(as_tuple=True)[0]

        available_ctx_label = min(ctx_half, len(label_idx))
        available_ctx_no_label = min(num_ctx - available_ctx_label, len(no_label_idx))

        available_tar_label = min(tar_half, len(label_idx))
        available_tar_no_label = min(num_tar - available_tar_label, len(no_label_idx))

        ctx_label = label_idx[torch.randperm(len(label_idx), device=device)[:available_ctx_label]]
        ctx_no_label = no_label_idx[torch.randperm(len(no_label_idx), device=device)[:available_ctx_no_label]]
        tar_label = label_idx[torch.randperm(len(label_idx), device=device)[:available_tar_label]]
        tar_no_label = no_label_idx[torch.randperm(len(no_label_idx), device=device)[:available_tar_no_label]]

        ctx_idxs.append(torch.cat([ctx_label, ctx_no_label], dim=0))
        tar_idxs.append(torch.cat([tar_label, tar_no_label], dim=0))

    # Requires equal per-sample lengths (see Raises).
    ctx_idxs = torch.stack(ctx_idxs)  # (B, num_ctx)
    tar_idxs = torch.stack(tar_idxs)  # (B, num_tar)

    ctx_x1, ctx_x2 = ctx_idxs // W, ctx_idxs % W
    tar_x1, tar_x2 = tar_idxs // W, tar_idxs % W

    ctx_coords = torch.stack([
        2 * ctx_x1.float() / (H - 1) - 1,
        2 * ctx_x2.float() / (W - 1) - 1
    ], dim=-1)  # (B, num_ctx, 2)

    tar_coords = torch.stack([
        2 * tar_x1.float() / (H - 1) - 1,
        2 * tar_x2.float() / (W - 1) - 1
    ], dim=-1)  # (B, num_tar, 2)

    ctx_rgb_values = torch.gather(img_flat, -1, ctx_idxs.unsqueeze(1).repeat(1, C, 1)).transpose(1, 2)
    tar_rgb_values = torch.gather(img_flat, -1, tar_idxs.unsqueeze(1).repeat(1, C, 1)).transpose(1, 2)

    ctx_labels = torch.gather(labels_flat, 1, ctx_idxs.unsqueeze(-1).expand(-1, -1, num_classes))
    tar_labels = torch.gather(labels_flat, 1, tar_idxs.unsqueeze(-1).expand(-1, -1, num_classes))

    # argmax returns the first maximal index; all-zero label rows map to 0.
    ctx_labels = ctx_labels.argmax(dim=-1).float()
    tar_labels = tar_labels.argmax(dim=-1).float()

    all_coords = torch.cat([ctx_coords, tar_coords], dim=1)
    all_rgb_values = torch.cat([ctx_rgb_values, tar_rgb_values], dim=1)
    all_labels = torch.cat([ctx_labels, tar_labels], dim=1)

    batch.x = torch.cat([all_coords, all_rgb_values], dim=-1)
    batch.y = all_labels.unsqueeze(-1)

    batch.xc = torch.cat([ctx_coords, ctx_rgb_values], dim=-1)
    batch.xt = torch.cat([tar_coords, tar_rgb_values], dim=-1)

    batch.yc = ctx_labels.unsqueeze(-1)
    batch.yt = tar_labels.unsqueeze(-1)

    return batch


def coord_to_img(x, y, shape):
    """Paint point values onto a fixed-colour RGB canvas.

    Args:
        x: Point coordinates in [-1, 1], last dim (row, col); presumably
            (B, N, 2).
        y: Values at those points; presumably (B, N, C'). Output channel c
            takes channel ``min(c, C - 1)`` of ``y``, where C is
            ``shape[0]``. Single-channel values are therefore replicated
            across RGB.
        shape: (C, H, W) of the image the points came from.

    Returns:
        CPU float tensor (B, 3, H, W). Pixels not hit by any point keep the
        background colour (0.61, 0.55, 0.71).
    """
    coords = x.cpu()
    values = y.cpu()
    batch_size = coords.shape[0]
    num_channels, height, width = shape

    canvas = torch.zeros(batch_size, 3, height, width)
    for ch, shade in enumerate((0.61, 0.55, 0.71)):
        canvas[:, ch, :, :] = shade

    # Map [-1, 1] coordinates back to integer pixel indices.
    rows = ((coords[..., 0] + 1) * (height - 1) / 2).round().long()
    cols = ((coords[..., 1] + 1) * (width - 1) / 2).round().long()

    src_channels = [min(ch, num_channels - 1) for ch in range(3)]
    for b in range(batch_size):
        for ch, src in enumerate(src_channels):
            canvas[b, ch, rows[b], cols[b]] = values[b, :, src]

    return canvas

def task_to_img(xc, yc, xt, yt, shape):
    """Render a context/target task as a pair of RGB images.

    Args:
        xc, xt: Context/target coordinates in [-1, 1], last dim (row, col).
        yc, yt: Context/target values. They are shifted by +0.5 before
            painting, undoing the -0.5 offset. Output channel c uses value
            channel ``min(c, C - 1)``.
        shape: (C, H, W) of the original image.

    Returns:
        ``(task_img, completed_img)``, both CPU tensors (B, 3, H, W) clamped
        to [0, 1].

        - ``task_img``: the context points on a (0, 0.4, 1.0) background.
        - ``completed_img``: ``task_img`` with the target points painted on
          top.
    """
    num_channels, height, width = shape
    batch_size = xc.shape[0]

    def to_pixels(points):
        # [-1, 1] (row, col) coordinates -> integer pixel indices.
        points = points.cpu()
        rows = ((points[..., 0] + 1) * (height - 1) / 2).round().long()
        cols = ((points[..., 1] + 1) * (width - 1) / 2).round().long()
        return rows, cols

    def paint(canvas, rows, cols, values):
        # Write values (+0.5) at the given pixels, then clamp to [0, 1].
        for b in range(batch_size):
            for ch in range(3):
                canvas[b, ch, rows[b], cols[b]] = values[b, :, min(ch, num_channels - 1)] + 0.5
        return canvas.clamp(0, 1)

    ctx_rows, ctx_cols = to_pixels(xc)
    tar_rows, tar_cols = to_pixels(xt)

    background = torch.zeros(batch_size, 3, height, width)
    background[:, 1, :, :] = 0.4
    background[:, 2, :, :] = 1.0

    task_img = paint(background, ctx_rows, ctx_cols, yc.cpu())
    completed_img = paint(task_img.clone(), tar_rows, tar_cols, yt.cpu())

    return task_img, completed_img
