import cv2
import torch
# Shared Euclidean (p=2) distance module with its default settings written out;
# get_circle uses it to measure the circle radius.
pdist = torch.nn.PairwiseDistance(p=2, eps=1e-6, keepdim=False)

from .depth_anything_v2.dpt import DepthAnythingV2

def get_rectangle(mask: torch.Tensor):
    """Return the axis-aligned bounding box of the non-zero entries of ``mask``.

    Only the last two dimensions are treated as spatial (rows ``y``, columns
    ``x``); any leading dimensions (e.g. ``N, C`` of an ``[N, C, H, W]`` mask)
    are ignored.

    Args:
        mask: tensor with at least two dimensions; every non-zero entry counts
            as part of the mask.

    Returns:
        ``(rect, left_top, left_bottom, right_top, right_bottom)``. Each corner
        is a ``(y, x)`` tensor of the default float dtype on ``mask.device``;
        ``rect`` stacks the four corners in that order into shape ``[4, 2]``.
        The names keep this module's historical convention:
        ``left_top = (min_y, min_x)``, ``left_bottom = (min_y, max_x)``,
        ``right_top = (max_y, min_x)``, ``right_bottom = (max_y, max_x)``.

    Raises:
        ValueError: if ``mask`` has fewer than two dimensions or contains no
            non-zero entry (the bounding box would be undefined).
    """
    if mask.dim() < 2:
        raise ValueError(
            f"get_rectangle expects a mask with >= 2 dims, got shape {tuple(mask.shape)}")
    coords = torch.nonzero(mask)[:, -2:]  # [K, 2] spatial indices as (y, x)
    if coords.shape[0] == 0:
        raise ValueError("get_rectangle: mask has no non-zero entries")

    float_dtype = torch.get_default_dtype()
    # Column-wise min/max yields (min_y, min_x) and (max_y, max_x) directly and
    # stays on the mask's device (no round trip through Python scalars).
    lo = coords.min(dim=0).values.to(float_dtype)
    hi = coords.max(dim=0).values.to(float_dtype)

    left_top = lo
    left_bottom = torch.stack((lo[0], hi[1]))
    right_top = torch.stack((hi[0], lo[1]))
    right_bottom = hi
    rect = torch.stack((left_top, left_bottom, right_top, right_bottom), dim=0)
    return rect, left_top, left_bottom, right_top, right_bottom

def get_circle(mask: torch.Tensor):
    """Circle circumscribing the bounding box of ``mask``.

    Returns:
        ``(center, radius)``. ``center`` is the ``(y, x)`` midpoint between the
        ``left_top`` and ``right_bottom`` corners reported by ``get_rectangle``,
        placed on ``mask.device``. ``radius`` is the module-level ``pdist``
        distance from that centre to ``left_top``.
    """
    corners = get_rectangle(mask=mask)
    top_left, bottom_right = corners[1], corners[4]
    center = ((top_left + bottom_right) / 2).to(device=mask.device)  # (y, x)
    return center, pdist(center, top_left)


def estimate_beta_from_depth(depth: torch.Tensor,
                              mask: torch.Tensor,
                              center_contrast=0.08,
                              slope=3.5,
                              min_beta=1.0,
                              max_beta=1.7,
                              fallback=1.8) -> float:
    """Map the relative depth spread inside ``mask`` to a falloff exponent.

    With ``contrast = std / mean`` of the masked depth values::

        beta = min_beta + (max_beta - min_beta) * sigmoid(-slope * (contrast - center_contrast))

    For a positive ``slope`` the sigmoid is decreasing, so flatter (low
    contrast) regions get a ``beta`` closer to ``max_beta``.

    Args:
        depth: depth map; ``depth[mask.bool()]`` must be a valid boolean index.
        mask: selects the pixels to analyse (non-zero = selected).
        center_contrast: contrast mapped to the midpoint of the output range.
        slope: steepness of the sigmoid.
        min_beta, max_beta: output range.
        fallback: value returned when the mask selects no pixel.

    Returns:
        The estimated exponent as a Python float.
    """
    masked_depth = depth[mask.bool()]
    if masked_depth.numel() == 0:
        return fallback

    if masked_depth.numel() > 1:
        std = torch.std(masked_depth)
    else:
        # torch.std uses Bessel's correction and yields NaN for one sample; a
        # lone pixel has no spread, so use 0 instead of returning NaN.
        std = masked_depth.new_zeros(())
    mean = torch.mean(masked_depth)
    contrast = std / (mean + 1e-6)
    norm_c = torch.sigmoid(-slope * (contrast - center_contrast))
    beta = min_beta + (max_beta - min_beta) * norm_c
    return beta.item()

def compute_space_weight(handle, mask, H, W, O, R, beta, device, mode='auto', depth=None):
    """Spatial falloff weight of every masked pixel around a drag handle.

    For each pixel ``p`` selected by ``mask`` the weight is
    ``1 - (|p - handle| / L) ** beta``. ``L`` approximates the distance from
    the handle to the circle ``(O, R)`` in the direction of ``p``:

    - ``R + |O - handle|`` when ``p`` lies on the centre side of the handle
      (positive cosine);
    - ``R`` otherwise. The ``R - |O - handle|`` candidate is lifted back to
      ``R`` by the clamp.

    Weights are not clamped, so pixels farther than ``L`` get negative values.

    Args:
        handle: ``(x, y)`` float tensor with the handle position.
        mask: 2-D ``[H, W]`` tensor; non-zero entries are the pixels to weight.
        H, W: size of the returned map.
        O: circle centre as ``(y, x)`` (the order ``get_circle`` returns).
        R: circle radius (float or 0-d tensor).
        beta: falloff exponent. In ``'auto'`` mode it is instead used as the
            sigmoid slope passed to ``estimate_beta_from_depth``.
        device: device of the returned map.
        mode: ``'auto'`` estimates ``beta`` from ``depth``; any other value
            uses ``beta`` as given.
        depth: depth map indexed with ``mask``; required when ``mode == 'auto'``.

    Returns:
        ``[H, W]`` float tensor holding the weights at masked pixels and 0
        elsewhere.

    Raises:
        TypeError: if ``mode == 'auto'`` and ``depth`` is None.
    """
    if mode == 'auto' and depth is None:
        raise TypeError("compute_space_weight: depth is required when mode='auto'")
    O = O.flip(0)  # (y, x) -> (x, y) to match handle
    index_1 = torch.nonzero(mask, as_tuple=False)  # [N, 2] as (y, x)
    grid = index_1[:, [1, 0]].float()  # [N, 2] as (x, y)

    delta = grid - handle[None, :]  # [N, 2], handle -> pixel
    delta_norm = delta.norm(dim=-1) + 1e-6  # [N]; eps avoids 0/0 at the handle
    OA = O - handle  # [2], handle -> circle centre
    OA_norm = OA.norm(dim=-1) + 1e-6
    cos_theta = (delta @ OA) / (delta_norm * OA_norm)  # [N]
    sign_map = torch.where(cos_theta > 0, torch.tensor(1.0, device=device), torch.tensor(-1.0, device=device))  # [N]
    L = R + sign_map * OA_norm  # [N]
    L = torch.clamp(L, min=R)
    if mode == 'auto':
        beta = estimate_beta_from_depth(depth, mask, center_contrast=0.5, slope=beta, min_beta=1.15, max_beta=1.8)
        print(f"Auto estimated beta: {beta}")
    weight = 1 - (delta_norm / L) ** beta  # [N]

    weight_full = torch.zeros((H, W), device=device)
    weight_full[index_1[:, 0], index_1[:, 1]] = weight
    return weight_full

def compute_space_weight_accurate(handle, mask, H, W, O, R, beta, device, mode='auto', depth=None):
    """Spatial falloff weight using the exact ray/circle distance.

    For each pixel ``p`` selected by ``mask``, ``L`` is the distance from the
    handle to the circle ``(O, R)`` along the ray handle -> ``p``. It is the
    larger root of the ray/circle intersection quadratic, then clamped to be
    at least ``R``. The weight is ``1 - (|p - handle| / L) ** beta``.

    Diagnostics are printed (not raised) when a root is negative or a weight
    falls outside ``[0, 1]``.

    Args:
        handle: ``(x, y)`` float tensor with the handle position.
        mask: 2-D ``[H, W]`` tensor; non-zero entries are the pixels to weight.
        H, W: size of the returned map.
        O: circle centre as ``(y, x)`` (the order ``get_circle`` returns).
        R: circle radius (float or 0-d tensor).
        beta: falloff exponent. In ``'auto'`` mode it is instead used as the
            sigmoid slope passed to ``estimate_beta_from_depth``.
        device: device of the returned map.
        mode: ``'auto'`` estimates ``beta`` from ``depth``; any other value
            uses ``beta`` as given.
        depth: depth map indexed with ``mask``; required when ``mode == 'auto'``.

    Returns:
        ``[H, W]`` float tensor holding the weights at masked pixels and 0
        elsewhere.

    Raises:
        TypeError: if ``mode == 'auto'`` and ``depth`` is None.
    """
    if mode == 'auto' and depth is None:
        raise TypeError("compute_space_weight_accurate: depth is required when mode='auto'")
    O = O.flip(0)  # (y, x) -> (x, y) to match handle
    index_1 = torch.nonzero(mask, as_tuple=False)  # [N, 2] as (y, x)
    grid = index_1[:, [1, 0]].float()  # [N, 2] as (x, y)
    delta = grid - handle[None, :]  # [N, 2], handle -> pixel
    delta_norm = delta.norm(dim=-1) + 1e-6  # [N]; eps avoids 0/0 at the handle
    OA = handle - O  # [2], circle centre -> handle
    OA_norm = OA.norm(dim=-1)
    direction = delta / delta_norm[:, None]  # [N, 2] unit rays from the handle
    # |handle + t*direction - O|^2 = R^2  =>  t^2 + 2*b*t + c = 0
    # with b = direction . OA and c = |OA|^2 - R^2.
    b = (direction * OA).sum(dim=-1)  # [N]
    c = OA_norm ** 2 - R ** 2  # scalar
    discriminant = torch.clamp(b ** 2 - c, min=0.0)
    L = -b + torch.sqrt(discriminant)  # [N], larger root = distance to boundary
    if torch.any(L < 0):
        print('error L < 0')
    L = torch.clamp(L, min=R)
    if mode == 'auto':
        beta = estimate_beta_from_depth(depth, mask, center_contrast=0.5, slope=beta, min_beta=1.15, max_beta=1.8)
        print(f"Auto estimated beta: {beta}")
    weight = 1 - (delta_norm / L) ** beta  # [N]
    if torch.any(weight < 0) or torch.any(weight > 1):
        print('error weight')
    weight_full = torch.zeros((H, W), device=device)
    weight_full[index_1[:, 0], index_1[:, 1]] = weight
    return weight_full

def estimate_alpha_from_depth(depth, mask, beta=4, center_contrast=1.0, min_alpha=0.5, max_alpha=1.2, fallback=1.0):
    """Map the relative depth spread inside ``mask`` to a depth-ratio exponent.

    With ``contrast = std / mean`` of the masked depth values::

        alpha = min_alpha + (max_alpha - min_alpha) * sigmoid(beta * (contrast - center_contrast))

    For a positive ``beta`` the sigmoid is increasing, so higher depth contrast
    gives a larger ``alpha``.

    Args:
        depth: depth map; ``depth[mask.bool()]`` must be a valid boolean index.
        mask: selects the pixels to analyse (non-zero = selected).
        beta: sigmoid slope (a steepness, not an exponent, despite the name).
        center_contrast: contrast mapped to the midpoint of the output range.
        min_alpha, max_alpha: output range.
        fallback: value returned when the mask selects no pixel.

    Returns:
        The estimated exponent as a Python float.
    """
    masked_depth = depth[mask.bool()]
    if masked_depth.numel() == 0:
        return fallback
    if masked_depth.numel() > 1:
        depth_std = torch.std(masked_depth)
    else:
        # torch.std uses Bessel's correction and yields NaN for one sample; a
        # lone pixel has no spread, so use 0 instead of returning NaN.
        depth_std = masked_depth.new_zeros(())
    depth_mean = torch.mean(masked_depth)
    contrast = depth_std / (depth_mean + 1e-6)  # depth variation ratio
    # Squash the contrast into (0, 1) around center_contrast.
    norm_c = torch.sigmoid(beta * (contrast - center_contrast))
    alpha = min_alpha + (max_alpha - min_alpha) * norm_c

    return alpha.item()

def compute_depth_weight_mask_based(depth, mask, handle, upper_scale=1.5, lower_scale=0.5, mode='linear', alpha=2.0):
    """Depth-ratio weight of every pixel relative to the depth at the handle.

    ``weight = clamp((depth / depth_at_handle) ** alpha, lower_scale, upper_scale)``
    inside ``mask`` and ``0`` outside it. The handle depth is floored at
    ``1e-4`` to avoid dividing by zero.

    Args:
        depth: depth map indexed as ``depth[y, x]``.
        mask: boolean-indexable against ``depth``; zero entries are outside the region.
        handle: ``(x, y)`` tensor; coordinates are converted with ``.long()``.
        upper_scale, lower_scale: clamp range for the ratio term.
        mode: ``'auto'`` re-estimates the exponent with
            ``estimate_alpha_from_depth``, using ``alpha`` as its sigmoid slope;
            any other value uses ``alpha`` directly.
        alpha: exponent (or sigmoid slope in ``'auto'`` mode).

    Returns:
        Float tensor shaped like ``depth``.
    """
    col, row = handle.long()
    handle_depth = torch.clamp(depth[row, col], min=1e-4)

    exponent = alpha
    if mode == 'auto':
        exponent = estimate_alpha_from_depth(depth, mask, beta=alpha)
        print(f"Auto estimated alpha: {exponent}")

    weight = torch.pow(depth / handle_depth, exponent)
    weight = weight.clamp(min=lower_scale, max=upper_scale)
    weight[mask == 0] = 0.0
    return weight

# sota
def drag_image(source_image,
               mask,
               handle_points,
               target_points,
               gamma_ratio=0.5,
               lambda_mix=None,
               upper_scale=2,
               lower_scale=0.5,
               alpha=2.0,
               beta=2.0,
               scale=1,
               **kwargs):
    """Compute a dense drag flow on the latent grid for handle/target pairs.

    Steps:

    1. Estimate depth for ``source_image`` with Depth-Anything-V2 and resize it
       to the mask's spatial size.
    2. Split the masked region among the handles; each masked pixel goes to
       its nearest handle.
    3. For every handle that actually moves, displace the pixels of its
       sub-mask by ``influence * (target - handle) * scale``. ``influence``
       blends a spatial falloff (``compute_space_weight``) with a depth-ratio
       weight (``compute_depth_weight_mask_based``).

    Args:
        source_image: image converted with ``cv2.COLOR_RGB2BGR`` before depth
            inference -- presumably an HxWx3 RGB numpy array (TODO confirm).
        mask: tensor of shape ``[1, 1, H, W]`` or similar; only ``mask[0, 0]``
            is used. It defines the latent size and the device.
        handle_points, target_points: sequences of ``(x, y)`` tensors, matched
            by index.
        gamma_ratio: blend length as a fraction of each sub-mask circle's
            diameter. Values below ``1e-3`` force ``lambda_mix = 1`` (depth
            weight only).
        lambda_mix: fixed blend between space weight (0) and depth weight (1).
            ``None`` uses the distance-based blend ``dist / (dist + gamma)``.
        upper_scale, lower_scale: clamp range of the depth weight.
        alpha: depth-ratio exponent.
        beta: spatial falloff exponent.
        scale: global multiplier on the displacement.
        **kwargs: ``encoder`` selects the Depth-Anything-V2 variant
            (``'vits'``, ``'vitb'``, ``'vitl'`` (default), ``'vitg'``).
            Other keys are ignored.

    Returns:
        ``[H, W, 2]`` float tensor of ``(dx, dy)`` displacements. It is zero
        outside the mask, for handles that do not move, and for handles that
        own no masked pixel.
    """
    device = mask.device
    latent_H, latent_W = mask.shape[2], mask.shape[3]
    flow = torch.zeros(latent_H, latent_W, 2, device=device)
    mask_bool = mask[0, 0] > 0
    N = len(handle_points)
    # Nothing can move: skip the expensive depth model entirely.
    if N == 0 or not bool(mask_bool.any()):
        return flow

    depth = _estimate_latent_depth(source_image, latent_H, latent_W, device,
                                   encoder=kwargs.get('encoder', 'vitl'))

    yy, xx = torch.meshgrid(torch.arange(latent_H, device=device), torch.arange(latent_W, device=device), indexing='ij')
    grid = torch.stack([xx, yy], dim=-1).float()  # [H, W, 2] as (x, y)
    handles_xy = torch.stack([pt.float().to(device) for pt in handle_points])  # [N, 2]
    sub_masks = _partition_mask_by_handle(mask_bool, handles_xy)

    # A vanishing gamma would make the distance blend ~1 everywhere except at the handle.
    fixed_lambda = 1 if gamma_ratio < 1e-3 else lambda_mix
    for i, sub_mask in enumerate(sub_masks):
        handle = handles_xy[i]  # [x, y]
        target = target_points[i].float().to(device)  # [x, y]
        direction = target - handle  # [dx, dy]
        # Static handles and handles that own no pixel contribute nothing
        # (an empty sub-mask has no bounding circle).
        if direction.norm() < 1e-4 or not bool(sub_mask.any()):
            continue

        O, R = get_circle(mask=sub_mask[None, None])  # get_circle expects [1, 1, H, W]
        space_weight = compute_space_weight(handle, sub_mask, latent_H, latent_W, O, R, beta, device, mode='linear', depth=depth)
        depth_weight = compute_depth_weight_mask_based(depth, sub_mask, handle, mode='linear',
                                                       upper_scale=upper_scale, lower_scale=lower_scale,
                                                       alpha=alpha)

        if fixed_lambda is None:
            gamma = gamma_ratio * 2 * R
            dist = (grid - handle).norm(dim=-1)  # [H, W]
            lambda_mix_i = dist / (dist + gamma + 1e-6)  # [H, W]
        else:
            lambda_mix_i = fixed_lambda
        influence = (1 - lambda_mix_i) * space_weight + lambda_mix_i * depth_weight  # [H, W]
        move_vector = influence.unsqueeze(-1) * direction.view(1, 1, 2) * scale  # [H, W, 2]
        # Sub-masks are disjoint, so each handle writes only its own pixels.
        flow[sub_mask] = move_vector[sub_mask]
    return flow


def _estimate_latent_depth(source_image, latent_H, latent_W, device, encoder='vitl'):
    """Estimate depth with Depth-Anything-V2, min-max normalised, as a ``[latent_H, latent_W]`` tensor on ``device``."""
    model_configs = {
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
    }
    model = DepthAnythingV2(**model_configs[encoder])
    # NOTE(review): the checkpoint is reloaded on every call (consider caching
    # the model), and torch.load unpickles it -- only load trusted files.
    model.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location='cpu'))
    model = model.to(device).eval()
    raw_img = cv2.cvtColor(source_image, cv2.COLOR_RGB2BGR)  # HxWxC
    depth = model.infer_image(raw_img)  # HxW raw depth map in numpy
    depth = cv2.resize(depth, (latent_W, latent_H), interpolation=cv2.INTER_LINEAR)
    depth = torch.from_numpy(depth).to(device)  # HxW
    # Min-max normalise; flooring the range keeps a constant map at 0 instead of NaN.
    depth_min = depth.min()
    depth_range = (depth.max() - depth_min).clamp(min=1e-8)
    return (depth - depth_min) / depth_range


def _partition_mask_by_handle(mask_bool, handles_xy):
    """Split ``mask_bool`` ([H, W] bool) into one disjoint sub-mask per handle by nearest-handle assignment."""
    masked_coords = mask_bool.nonzero(as_tuple=False)  # [M, 2] as (y, x)
    coords_xy = masked_coords[:, [1, 0]].float()  # [M, 2] as (x, y)
    distances = torch.norm(coords_xy[:, None, :] - handles_xy[None, :, :], dim=-1)  # [M, N]
    assignments = torch.argmin(distances, dim=-1)  # [M]

    sub_masks = []
    for i in range(handles_xy.shape[0]):
        sub_mask = torch.zeros_like(mask_bool)
        owned = masked_coords[assignments == i]
        sub_mask[owned[:, 0], owned[:, 1]] = True
        sub_masks.append(sub_mask)
    return sub_masks