import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import numpy as np


# Ground-truth flow magnitude cut-off: pixels whose GT flow magnitude is
# >= this value are excluded from losses and metrics in this module.
MAX_FLOW = 400


def flow_sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW):
    """Exponentially weighted L1 loss over a sequence of flow predictions.

    Prediction ``i`` (of ``n``) is weighted by ``gamma ** (n - i - 1)``, so the
    last prediction has weight 1 and earlier ones progressively less.

    Args:
        flow_preds: sequence of predicted flows, each shaped like ``flow_gt``.
        flow_gt: ground-truth flow, [B, C, H, W] (presumably C == 2).
        valid: per-pixel validity, [B, H, W]; entries >= 0.5 count as valid.
        gamma: decay factor for the per-prediction weights.
        max_flow: pixels whose GT flow magnitude is >= this are ignored.

    Returns:
        The weighted sum of losses (a tensor), or ``0.0`` if ``flow_preds`` is
        empty. Each loss is the masked absolute error averaged over ALL
        elements, so invalid pixels contribute zero but still count in the
        denominator.
    """
    gt_magnitude = flow_gt.pow(2).sum(dim=1).sqrt()
    keep = (valid >= 0.5) & (gt_magnitude < max_flow)
    keep = keep.unsqueeze(1)  # broadcast over the flow-component axis

    count = len(flow_preds)
    total = 0.0
    for idx, pred in enumerate(flow_preds):
        weight = gamma ** (count - idx - 1)
        total += weight * (keep * (pred - flow_gt).abs()).mean()
    return total


def trimmed_flow_sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, max_flow=MAX_FLOW, trim=0.5):
    """Sequence L1 flow loss that discards the largest per-pixel errors.

    For each prediction, the per-pixel L1 error (averaged over the flow
    components) is taken on valid pixels. Per sample, only the
    ``int(n_valid * (1 - trim))`` smallest errors are kept and averaged.
    Per-sample values are averaged over the batch. Predictions are combined
    with weights ``gamma ** (n_predictions - i - 1)``.

    Samples that keep no pixel (no valid pixels, or too few for the given
    ``trim``) are skipped instead of contributing NaN. If every sample of a
    prediction is skipped, that prediction contributes a zero that stays
    connected to the autograd graph.

    Args:
        flow_preds: sequence of predicted flows, each shaped like ``flow_gt``.
        flow_gt: ground-truth flow, [B, C, H, W] (presumably C == 2).
        valid: per-pixel validity, [B, H, W]; entries >= 0.5 count as valid.
        gamma: decay factor for the per-prediction weights.
        max_flow: pixels whose GT flow magnitude is >= this are ignored.
        trim: fraction of the largest errors to drop, expected in [0, 1).

    Returns:
        The weighted sum of trimmed losses (a tensor), or ``0.0`` if
        ``flow_preds`` is empty.
    """
    n_predictions = len(flow_preds)
    flow_loss = 0.0
    batch_size = flow_gt.shape[0]

    # exclude invalid pixels and large flows
    mag = torch.sum(flow_gt**2, dim=1).sqrt()
    valid = (valid >= 0.5) & (mag < max_flow)  # [B, H, W]

    for i, pred in enumerate(flow_preds):
        i_weight = gamma ** (n_predictions - i - 1)

        # L1 error averaged over flow components: [B, H, W]
        error = (pred - flow_gt).abs().mean(dim=1)

        trimmed_losses = []
        for b in range(batch_size):
            e_valid = error[b][valid[b]]  # [N_valid_b]
            n_valid = e_valid.numel()
            # Clamp so a negative trim cannot request more than is available.
            k = min(int(n_valid * (1.0 - trim)), n_valid)
            if k <= 0:
                # Averaging an empty selection yields NaN, which would poison
                # the entire loss; skip samples with nothing to keep.
                continue
            smallest = torch.topk(e_valid, k, largest=False, sorted=False).values
            trimmed_losses.append(smallest.mean())

        if trimmed_losses:
            loss_i = torch.stack(trimmed_losses).mean()
        else:
            # Zero that keeps a grad_fn so backward() still works.
            loss_i = error.sum() * 0.0
        flow_loss += i_weight * loss_i

    return flow_loss


def calculate_epe_batch(flow_pred, flow_gt, valid, max_flow=MAX_FLOW):
    """Mean end-point error (EPE) over all valid pixels of a batch.

    Pixels are used when ``valid >= 0.5`` and their ground-truth flow
    magnitude is below ``max_flow``; all such pixels of the batch are pooled
    before averaging.

    Args:
        flow_pred: predicted flow, [B, C, H, W] (presumably C == 2).
        flow_gt: ground-truth flow, same shape as ``flow_pred``.
        valid: per-pixel validity, [B, H, W].
        max_flow: pixels whose GT flow magnitude is >= this are excluded.

    Returns:
        The mean EPE as a Python float (NaN when no pixel is kept).
    """
    gt_norm = (flow_gt ** 2).sum(dim=1).sqrt()
    keep = ((valid >= 0.5) & (gt_norm < max_flow)).reshape(-1)

    per_pixel_epe = ((flow_pred - flow_gt) ** 2).sum(dim=1).sqrt().reshape(-1)
    return per_pixel_epe[keep].mean().item()


def calulate_metric_val(flow_pred, flow_gt, valid, max_flow=MAX_FLOW):
    """Per-sample EPE plus outlier statistics for validation.

    (The misspelled name is kept so existing callers keep working.)

    A valid pixel is an outlier when its EPE exceeds 3 px AND 5% of the
    ground-truth flow magnitude (the KITTI Fl outlier definition).

    Args:
        flow_pred: predicted flow, [B, C, H, W] (presumably C == 2).
        flow_gt: ground-truth flow, same shape as ``flow_pred``.
        valid: per-pixel validity, [B, H, W]; entries >= 0.5 count as valid.
        max_flow: pixels whose GT flow magnitude is >= this are excluded.

    Returns:
        epe_list: float32 NumPy array [B], the mean EPE of each sample over
            its valid pixels (NaN for a sample without valid pixels).
        out_valid_pixels: int, number of valid pixels that are outliers.
        num_valid_pixels: int, number of valid pixels in the batch.
    """
    # Metrics are never back-propagated. Detaching avoids building a graph,
    # and the conversions below then work even outside torch.no_grad().
    flow_pred = flow_pred.detach()
    flow_gt = flow_gt.detach()

    mag = torch.sum(flow_gt**2, dim=1).sqrt()
    val = (valid >= 0.5) & (mag < max_flow)  # [B, H, W]

    epe = torch.sum((flow_pred - flow_gt)**2, dim=1).sqrt()  # [B, H, W]
    # `epe > 0.05 * mag` is `epe / mag > 0.05` without dividing by zero where
    # the GT flow vanishes (same outcome once combined with `epe > 3`).
    out = (epe > 3.0) & (epe > 0.05 * mag)

    epe_list = np.array(
        [epe[b][val[b]].mean().item() for b in range(epe.shape[0])],
        dtype=np.float32,
    )
    # Integer counts stay exact, unlike accumulating float32 sums per sample.
    out_valid_pixels = int(out[val].sum().item())
    num_valid_pixels = int(val.sum().item())
    return epe_list, out_valid_pixels, num_valid_pixels


def syn_flow_data(img_source, dpt_source, K_source, K_target, T_s2t, thr_dpt=1e-5, thr_valid=10):
    """Synthesize source->target optical flow from a source RGB-D frame.

    Steps:
    1. Every source pixel is back-projected with its depth and ``K_source``.
    2. The point is moved by ``T_s2t`` and projected with ``K_target``.
    3. The pixel displacement is the flow.
    4. The source image is forward-warped (splatted) into ``img_target``.
    5. ``img_target`` is sampled back at the projected locations.
    6. A pixel is valid when this round trip reproduces its colour within
       ``thr_valid``, its depth is positive, and it lies in front of the
       target camera.

    Pixels without depth or behind the target camera are not splatted. Their
    fabricated projections would otherwise pile up at a single image location
    and corrupt ``img_target``.

    Args:
        img_source: NumPy image, [H, W, C].
        dpt_source: NumPy depth with H * W values (e.g. [1, H, W]); values
            <= 0 are treated as missing. The caller's array is not modified.
        K_source: source intrinsics (NumPy); only fx, fy, cx, cy are read.
        K_target: target intrinsics (NumPy); only fx, fy, cx, cy are read.
        T_s2t: source->target rigid transform, 3x4 ``[R|t]`` or 4x4 (NumPy).
        thr_dpt: substitute for non-positive source depths, and lower clamp
            for target depths before perspective division.
        thr_valid: max mean absolute colour difference for a valid pixel
            (presumably in the image's intensity units, e.g. 0-255 — confirm).

    Returns:
        flow_s2t: [2, H, W] float tensor, (u, v) displacement in pixels.
        valid_mask: [H, W] float tensor of 0/1.
        img_target: [H, W, C] forward-warped image.
        img_source_warp: [H, W, C] ``img_target`` sampled at the projections.
    """
    H, W, _ = img_source.shape
    img_source = torch.from_numpy(img_source).float()  # [H, W, C]

    # from_numpy(...).float() shares memory with a float32 caller array, so
    # the depth must not be edited in place; masked_fill returns a new tensor.
    depth = torch.from_numpy(dpt_source).float().reshape(-1)
    z = depth.masked_fill(depth <= 0, thr_dpt)

    K_source = torch.from_numpy(K_source).float()
    K_target = torch.from_numpy(K_target).float()
    T_s2t = torch.from_numpy(T_s2t).float()
    if T_s2t.shape[0] == 3:
        # Promote a 3x4 [R|t] to a homogeneous 4x4 matrix.
        T_s2t = torch.cat([T_s2t, torch.tensor([[0.0, 0.0, 0.0, 1.0]])], dim=0)

    # mesh grid of source pixel coordinates, flattened row-major
    y, x = torch.meshgrid(torch.arange(H), torch.arange(W), indexing='ij')
    x, y = x.reshape(-1).float(), y.reshape(-1).float()

    fx_s, fy_s = K_source[0, 0], K_source[1, 1]
    cx_s, cy_s = K_source[0, 2], K_source[1, 2]
    X_s = (x - cx_s) / fx_s * z
    Y_s = (y - cy_s) / fy_s * z
    pts_source = torch.stack([X_s, Y_s, z, torch.ones_like(z)], dim=0)  # [4, N]

    # source -> target
    pts_target = T_s2t @ pts_source
    Z_t_raw = pts_target[2]
    X_t, Y_t, Z_t = pts_target[0], pts_target[1], Z_t_raw.clamp(min=thr_dpt)

    fx_t, fy_t = K_target[0, 0], K_target[1, 1]
    cx_t, cy_t = K_target[0, 2], K_target[1, 2]
    u_t = (fx_t * X_t / Z_t) + cx_t
    v_t = (fy_t * Y_t / Z_t) + cy_t

    flow_s2t = torch.stack([(u_t - x).reshape(H, W), (v_t - y).reshape(H, W)], dim=0)  # [2, H, W]

    # Projections are only meaningful for real depth in front of the target
    # camera. Missing-depth points (depth forced to ~0) all map near the image
    # of the source camera centre and would be splatted there.
    geom_ok = (depth > 0) & (Z_t_raw > 0)
    # u = -2 puts both bilinear neighbours (-2, -1) outside the image, so
    # forward_warp_bilinear drops these points.
    u_splat = torch.where(geom_ok, u_t, torch.full_like(u_t, -2.0))
    img_target, _ = forward_warp_bilinear(img_source, u_splat, v_t)

    # Normalized coordinates for grid_sample with align_corners=True.
    u_t_norm = (2 * u_t / (W - 1) - 1).reshape(H, W, 1)
    v_t_norm = (2 * v_t / (H - 1) - 1).reshape(H, W, 1)
    coords_t_norm = torch.cat([u_t_norm, v_t_norm], dim=-1)  # [H, W, 2]
    img_source_warp = F.grid_sample(
        img_target.permute(2, 0, 1).unsqueeze(0),
        coords_t_norm.unsqueeze(0),
        mode='bilinear',
        align_corners=True,
        padding_mode='zeros',
    ).squeeze(0).permute(1, 2, 0)

    photo_ok = (img_source - img_source_warp).abs().mean(dim=2) <= thr_valid
    valid_mask = (photo_ok & geom_ok.reshape(H, W)).float()
    return flow_s2t, valid_mask, img_target, img_source_warp


def forward_warp_bilinear(img_source, u_t, v_t):
    """Forward-warp (splat) an image to per-pixel target coordinates.

    Source pixel ``i`` (row-major order) is splatted onto the four integer
    pixels around ``(u_t[i], v_t[i])`` with bilinear weights. Contributions
    falling outside the image are dropped. Each target pixel becomes the
    weighted average of what it received.

    Args:
        img_source: image tensor, [H, W, C].
        u_t: target x-coordinates, H * W values in row-major source order.
        v_t: target y-coordinates, same layout as ``u_t``.

    Returns:
        img_target: [H, W, C] splatted image; pixels that received no weight
            are 0. Its dtype is the promotion of the input dtypes.
        valid_mask: [H, W] float tensor, 1 where any weight was received.
    """
    H, W, C = img_source.shape
    device = img_source.device
    # Accumulate in the promoted dtype of image and coordinates so index_put_
    # sees matching dtypes; a fixed float32 buffer raised for float64 inputs.
    dtype = torch.promote_types(torch.promote_types(img_source.dtype, u_t.dtype), v_t.dtype)

    src = img_source.reshape(-1, C).to(dtype)  # [H*W, C]

    x0 = u_t.floor().long()
    y0 = v_t.floor().long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Fractional position inside the cell (clamped against rounding noise).
    wx = (u_t - x0.to(u_t.dtype)).clamp(0, 1)
    wy = (v_t - y0.to(v_t.dtype)).clamp(0, 1)

    corners = (
        (x0, y0, (1 - wx) * (1 - wy)),  # top-left
        (x0, y1, (1 - wx) * wy),        # bottom-left
        (x1, y0, wx * (1 - wy)),        # top-right
        (x1, y1, wx * wy),              # bottom-right
    )

    img_target = torch.zeros(H, W, C, device=device, dtype=dtype)
    count = torch.zeros(H, W, 1, device=device, dtype=dtype)

    for xg, yg, w in corners:
        inside = (xg >= 0) & (xg <= W - 1) & (yg >= 0) & (yg <= H - 1)
        xi, yi = xg[inside], yg[inside]
        ww = w[inside].unsqueeze(1).to(dtype)  # [N, 1]
        # Indexing [H, W, C] with (yi, xi) addresses N rows of C channels,
        # so all channels accumulate in a single call.
        img_target.index_put_((yi, xi), src[inside] * ww, accumulate=True)
        count.index_put_((yi, xi), ww, accumulate=True)

    hit = count > 0  # [H, W, 1]
    # Divide only where weight arrived; elsewhere the accumulator is 0.
    img_target = img_target / torch.where(hit, count, torch.ones_like(count))
    valid_mask = hit.squeeze(-1).float()

    return img_target, valid_mask
