import torch
import torch.nn.functional as F
from torchvision import transforms
import cv2

def get_local_weights(residual, ksize):
    """Return the per-pixel local variance of ``residual`` over a square window.

    Each output element is the unbiased variance (``unbiased=True``) of the
    ``ksize x ksize`` neighbourhood around the corresponding input element,
    computed independently for every index of dims 0 and 1. The input is
    reflect-padded first so the output has the same shape as the input.

    Args:
        residual: 4-D tensor whose dims 2 and 3 are the spatial (height,
            width) axes.
        ksize: Side length of the square window. An odd size centres the
            window on the pixel. An even size places the extra row/column
            after it (below/right).

    Returns:
        Tensor with the same shape as ``residual``.
    """
    # Split the padding so the output keeps the input's spatial size for any
    # ksize: odd -> (k-1)/2 on both sides (the original behaviour); even ->
    # one more pixel on the trailing side. Symmetric (k-1)//2 padding would
    # shrink the output by one pixel per spatial dim for even k.
    pad_before = (ksize - 1) // 2
    pad_after = ksize // 2
    # F.pad lists pads from the last dim backwards: (left, right, top, bottom).
    residual_pad = F.pad(
        residual,
        pad=[pad_before, pad_after, pad_before, pad_after],
        mode='reflect',
    )

    # Shape (..., H, W, ksize, ksize): one window per output position.
    windows = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
    pixel_level_weight = torch.var(windows, dim=(-1, -2), unbiased=True)

    return pixel_level_weight

def _load_rgb_tensor(path):
    """Read an image file and return it as a (1, 3, H, W) float RGB tensor.

    ``cv2.imread`` reports a missing or unreadable file by returning ``None``
    instead of raising. Left unchecked, that only shows up later as an opaque
    ``cv2.error`` from ``cvtColor``, so fail early with the offending path.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"could not read image: {path}")
    # OpenCV loads BGR; convert so channel order is RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # ToTensor converts HWC -> CHW float; unsqueeze adds the batch dim.
    return transforms.ToTensor()(img).unsqueeze(0)


def get_refined_artifact_map(hr_path, sr_path, rf_path, ksize=7):
    """Compute a per-pixel artifact weight map for a super-resolved image.

    Both the SR image and a reference image are compared against the
    ground-truth HR image using a per-pixel L1 error summed over the RGB
    channels. The weight is then built from the SR residual:
    ``var(whole residual map) ** (1/5) * local ksize x ksize variance``.
    It is set to 0 wherever the SR residual is smaller than the reference
    residual.

    Args:
        hr_path: Path to the ground-truth image.
        sr_path: Path to the output image being evaluated.
        rf_path: Path to the reference image. It is loaded as ``img_ema``;
            presumably the output of an EMA model (confirm with callers).
        ksize: Window size passed to ``get_local_weights``.

    Returns:
        2-D numpy array (H, W) of non-negative weights. All three images must
        have matching dimensions for the subtractions to succeed.

    Raises:
        FileNotFoundError: if any of the three images cannot be read.
    """
    img_output = _load_rgb_tensor(sr_path)
    img_gt = _load_rgb_tensor(hr_path)
    img_ema = _load_rgb_tensor(rf_path)

    # Per-pixel L1 error summed over RGB -> shape (1, 1, H, W).
    residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
    residual_SR = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)

    # Global scale: variance of the entire SR residual map, damped by a 1/5 power.
    # No clone needed: neither torch.var nor get_local_weights mutates its input.
    patch_level_weight = torch.var(residual_SR, dim=(-1, -2, -3), keepdim=True) ** (1/5)
    # Local scale: variance within a ksize x ksize neighbourhood of each pixel.
    pixel_level_weight = get_local_weights(residual_SR, ksize)
    overall_weight = patch_level_weight * pixel_level_weight

    # Zero out pixels where the SR output is already closer to the ground
    # truth than the reference image is.
    overall_weight[residual_SR < residual_ema] = 0

    return overall_weight.detach().cpu().numpy()[0, 0]