import torch
import torch.nn.functional as F
import torchvision
import random
from einops import rearrange

from .utils import randomly_sample_triplets, compute_distances


### video similarity loss
def update_video_corr(corr):
    """Store ``corr`` as the module-level ``video_corr`` target.

    ``update_video_corr_loss`` compares its input against this value, so it
    must be set before that function is called.
    """
    global video_corr
    video_corr = corr

def update_masking(mask=None):
    """Set or clear the module-level ``corr_mask``.

    ``update_video_corr_loss`` uses ``corr_mask`` (when it is not None) to
    weight its element-wise MSE.

    Args:
        mask: Weighting value to store in ``corr_mask``. The default, None,
            clears the mask, which is what a call with no arguments has
            always done.
    """
    global corr_mask
    corr_mask = mask

def update_video_corr_loss(input_corr, normalize=False):
    """Add the MSE between ``input_corr`` and ``video_corr`` to ``video_corr_loss``.

    Three modes, checked in this order:

    * A module-level ``corr_mask`` exists and is not None: the element-wise
      squared error is weighted by ``corr_mask / corr_mask.mean()`` and then
      averaged. ``normalize`` is ignored in this mode.
    * ``normalize`` is true: the mean squared error is divided by its own
      detached value (plus 1e-7), so the added term is ~1 in value while its
      gradient is the MSE gradient scaled by ``1 / loss``.
    * Otherwise: the plain mean squared error is added.

    Args:
        input_corr: Tensor compared against the stored ``video_corr``.
        normalize: Whether to rescale the unmasked loss by its detached value.

    Note:
        ``video_corr`` and ``video_corr_loss`` must already be defined at
        module level (see ``update_video_corr`` / ``reset_video_corr_loss``).
    """
    global video_corr_loss

    # ``corr_mask`` may not have been defined at module level yet, so look it
    # up defensively instead of referencing the bare name.
    mask = globals().get('corr_mask')
    if mask is not None:
        weight = mask / mask.mean()
        video_corr_loss += (weight * F.mse_loss(input_corr, video_corr, reduction='none')).mean()
        return

    # Compute the MSE once; the normalised variant reuses it for the scale.
    loss = F.mse_loss(input_corr, video_corr)
    if normalize:
        loss = loss / (loss.detach() + 1e-7)
    video_corr_loss += loss

def reset_video_corr_loss():
    """Reset the module-level ``video_corr_loss`` accumulator to the integer 0."""
    global video_corr_loss
    video_corr_loss = 0


### optical flow consistency and smoothness loss
def update_transformation_consistency_loss(motion, reverse_motion):
    """Penalise ``motion + reverse_motion`` for deviating from zero.

    The mean squared value of the sum is added to the module-level
    ``transformation_consistency_loss`` accumulator.
    """
    global transformation_consistency_loss
    round_trip = motion + reverse_motion
    zero_target = torch.zeros_like(motion)
    transformation_consistency_loss += F.mse_loss(round_trip, zero_target)

def reset_transformation_consistency_loss():
    """Reset the module-level ``transformation_consistency_loss`` accumulator to the integer 0."""
    global transformation_consistency_loss
    transformation_consistency_loss = 0

def update_transformation_smoothness_loss(motion):
    """Add the mean squared value of ``motion`` to ``transformation_smoothness_loss``.

    The penalty is the MSE between ``motion`` and an all-zero tensor of the
    same shape, i.e. it pulls ``motion`` towards zero.
    """
    global transformation_smoothness_loss
    zero_target = torch.zeros_like(motion)
    transformation_smoothness_loss += F.mse_loss(motion, zero_target)

def reset_transformation_smoothness_loss():
    """Reset the module-level ``transformation_smoothness_loss`` accumulator to the integer 0."""
    global transformation_smoothness_loss
    transformation_smoothness_loss = 0


### feature similarity ranking loss
def reset_feature_ranking_loss():
    """Reset the module-level ``feature_ranking_loss`` accumulator to the integer 0."""
    global feature_ranking_loss
    feature_ranking_loss = 0

def update_ranking_reference(rgb, depth):
    """Store the references consumed by ``update_feature_ranking_loss``.

    Args:
        rgb: Value stored as the module-level ``rgb_reference``.
        depth: Value stored as the module-level ``depth_reference``; the
            ranking loss skips its depth terms when this is None.
    """
    global rgb_reference, depth_reference
    rgb_reference, depth_reference = rgb, depth
    
def _resize_and_blur_reference(reference, b, f, h, w):
    """Resize a ``(b, f, c, h, w)`` reference to ``(h, w)``, blur it, return ``(b, c, f, h, w)``."""
    assert reference.shape[1] == f, "reference frame count must match the feature frame count"
    frames = rearrange(reference, "b f c h w -> (b f) c h w")
    frames = F.interpolate(frames, (h, w), mode='bilinear')
    frames = torchvision.transforms.functional.gaussian_blur(frames, (5, 5))
    return rearrange(frames, "(b f) c h w -> b c f h w", b=b)


def _frame_difference_weights(reference):
    """Return ``(b, f, h, w)`` weights from consecutive-frame differences, or None if there are none."""
    # Channel-wise norm of the difference between each frame and its predecessor.
    diff = torch.norm(reference[:, :, 1:] - reference[:, :, :-1], dim=1)
    positive = diff[diff > 0]
    if positive.numel() == 0:
        # Single frame or a completely static reference: there is no positive
        # difference to use as a floor (torch.min would raise on the empty
        # selection), so signal the caller to sample without weights.
        return None
    floor = positive.min()
    # The first frame has no predecessor, so it gets a zero slice; every zero
    # weight is then lifted to the smallest positive difference.
    weights = torch.cat([torch.zeros_like(diff[:, :1]), diff], dim=1)
    weights[weights == 0] = floor
    return weights


def update_feature_ranking_loss(feature, sampling=False):
    """Add a triplet ranking loss for ``feature`` to ``feature_ranking_loss``.

    Point triplets (A, B, C) are sampled, and the A-B / A-C distances are
    measured in the RGB reference, the depth reference (if set) and
    ``feature``. Wherever the references agree that ``d(A,B) <= d(A,C)``, the
    feature is penalised by ``relu(d_AB - d_AC)``; the mirrored term handles
    ``>=``. Each term is averaged over the triplets that satisfy its
    condition. Ties satisfy both conditions.

    Args:
        feature: Tensor unpacked as ``(b, c, f, h, w)``.
        sampling: If true, weight triplet sampling by the frame-to-frame
            difference of the (resized, blurred) references. Falls back to
            unweighted sampling when the references show no change at all.

    Raises:
        ValueError: If the module-level ``rgb_reference`` is None; the RGB
            distances are always required.

    Note:
        ``rgb_reference`` / ``depth_reference`` are read from module globals
        (see ``update_ranking_reference``) and are rearranged from
        ``(b, f, c, h, w)``.
    """
    global feature_ranking_loss
    b, _, f, h, w = feature.shape

    if rgb_reference is None:
        raise ValueError(
            "update_feature_ranking_loss needs an RGB reference; "
            "call update_ranking_reference with a non-None rgb first"
        )

    # interpolate and blur reference features
    rgb_ref = _resize_and_blur_reference(rgb_reference, b, f, h, w)
    depth_ref = None
    if depth_reference is not None:
        depth_ref = _resize_and_blur_reference(depth_reference, b, f, h, w)

    # frame difference as sampling weights
    sampling_weights = None
    if sampling:
        combined = rgb_ref if depth_ref is None else torch.cat([rgb_ref, depth_ref], dim=1)
        sampling_weights = _frame_difference_weights(combined)

    # sample point triplets from reference
    # NOTE(review): the 0.3*b*f*h*w and 0.2*h arguments are presumably the
    # triplet count and a spatial sampling range -- confirm against
    # utils.randomly_sample_triplets.
    batch_idx, i, x_A, y_A, x_B, y_B, x_C, y_C = randomly_sample_triplets(
        b, f, h, w, int(0.3*b*f*h*w), int(0.2*h), sampling_weights=sampling_weights
    )
    points = (batch_idx, i, x_A, y_A, x_B, y_B, x_C, y_C)

    # distances
    d1_AB, d1_AC = compute_distances(rgb_ref, *points)
    if depth_ref is not None:
        d2_AB, d2_AC = compute_distances(depth_ref, *points)
    d3_AB, d3_AC = compute_distances(feature, *points)

    # Triplets on which every available reference agrees about the ordering.
    closer_b = (d1_AB <= d1_AC).float()
    closer_c = (d1_AB >= d1_AC).float()
    if depth_ref is not None:
        closer_b = closer_b * (d2_AB <= d2_AC).float()
        closer_c = closer_c * (d2_AB >= d2_AC).float()

    # loss: hinge on the feature ordering, averaged per agreeing subset
    # (1e-6 guards against an empty subset).
    loss = (closer_b * F.relu(d3_AB - d3_AC)).sum() / (closer_b.sum() + 1e-6) \
        + (closer_c * F.relu(d3_AC - d3_AB)).sum() / (closer_c.sum() + 1e-6)

    # update loss
    feature_ranking_loss += loss


def update_global_random():
    """Draw a fresh float from the ``random`` module and store it in the module-level ``randnum``."""
    global randnum
    # Same draw as uniform(0, 1): 0 + (1 - 0) * random().
    randnum = random.random()