import pyiqa
import torch
import cv2
import numpy as np
import torchvision.transforms as transforms
from kornia.contrib import extract_tensor_patches, compute_padding, combine_tensor_patches
import kornia


# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build the pyiqa 'dists' metric once at import time (default settings),
# placed on `device`, so every call to `method` reuses the same model.
iqa_metric = pyiqa.create_metric('dists', device=device)


def _load_rgb_tensor(path):
    """Read an image with OpenCV and return it as a (1, C, H, W) float tensor in RGB order.

    ``cv2.imread`` decodes to BGR ``uint8``; the image is converted to RGB and
    ``transforms.ToTensor`` rescales it to floats in [0, 1].

    Raises:
        FileNotFoundError: if OpenCV cannot read/decode ``path`` (``cv2.imread``
            signals this by returning ``None`` rather than raising).
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return transforms.ToTensor()(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)).unsqueeze(0)


def method(hr_path, sr_path, rf_path=None, *, block_size=16, batch_size=1000):
    """Compute a per-pixel map of patch-wise ``iqa_metric`` scores between two images.

    Both images are cut into non-overlapping ``block_size`` x ``block_size``
    patches (padded as computed by ``compute_padding`` so the grid covers the
    whole image). Each aligned pair of patches is scored with the module-level
    ``iqa_metric``, and that score is painted over the patch's area. The padding
    is then removed, giving a piecewise-constant map with the input's height
    and width.

    Args:
        hr_path: path to the ground-truth (high-resolution) image.
        sr_path: path to the super-resolved image; must have the same size.
        rf_path: unused here; presumably kept so the signature matches other
            ``method`` callables used by an evaluation harness -- confirm.
        block_size: side length, in pixels, of the square scoring patches.
        batch_size: number of patch pairs scored per metric call; lower it to
            reduce peak memory.

    Returns:
        ``numpy.ndarray`` of shape (H, W) with the patch scores.

    Raises:
        FileNotFoundError: if either image cannot be read.
        ValueError: if the images differ in shape, or if ``block_size`` or
            ``batch_size`` is not positive.
    """
    if block_size < 1 or batch_size < 1:
        raise ValueError(
            f"block_size and batch_size must be positive, got {block_size} and {batch_size}"
        )

    gt_image = _load_rgb_tensor(hr_path)
    ai_image = _load_rgb_tensor(sr_path)
    # Mismatched sizes would either crash inside the metric or silently
    # compare misaligned patches, so fail early with a clear message.
    if gt_image.shape != ai_image.shape:
        raise ValueError(
            f"image shape mismatch: {hr_path} is {tuple(gt_image.shape)}, "
            f"{sr_path} is {tuple(ai_image.shape)}"
        )

    image_size = gt_image.shape[-2:]
    padding = compute_padding(image_size, (block_size, block_size))
    # [0] drops the leading batch dimension, leaving one entry per patch.
    gt_patches = extract_tensor_patches(
        gt_image, window_size=block_size, stride=block_size, padding=padding
    )[0]
    ai_patches = extract_tensor_patches(
        ai_image, window_size=block_size, stride=block_size, padding=padding
    )[0]

    tiles = []
    for start in range(0, len(gt_patches), batch_size):
        stop = start + batch_size
        batch_scores = iqa_metric(gt_patches[start:stop], ai_patches[start:stop])
        # One score per patch pair: lay the scores out as (1, n, 1, 1, 1) and
        # broadcast each over a block_size x block_size single-channel tile.
        # reshape accepts (n,), (n, 1) or a 0-d score alike.
        tiles.append(
            batch_scores.reshape(1, -1, 1, 1, 1).expand(-1, -1, -1, block_size, block_size)
        )
    score_patches = torch.cat(tiles, dim=1)

    # Reassemble the tiles into an image and strip the padding added above.
    score_map = combine_tensor_patches(
        score_patches, image_size, window_size=block_size, stride=block_size, unpadding=padding
    )
    return score_map[0, 0].detach().cpu().numpy()




if __name__ == '__main__':
    # Smoke run: score one local HR / SR pair and show the resulting map.
    hr_path = r'C:\Users\kir\Documents\sr_artifacts\desra_images\00026063\00026063.png'
    sr_path = r'C:\Users\kir\Documents\sr_artifacts\desra_images\00026063\00026063@RF@EDSR_x4.png'
    score_map = method(hr_path, sr_path)
    print(score_map)
    print(score_map.shape)

    # Threshold search over a labelled dataset (disabled; the get_* helpers
    # used below are not imported in this module):
    #conf_dict = get_conf_dict('gt_conf.csv')
    #classes_info = get_classes_info('classes.json')
    #img_pathes = get_img_pathes(r'C:\Users\kir\Documents\sr_artifacts\my_dataset_mask\images', conf_dict=conf_dict)
    #best_threshold_conf, fscore = get_best_threshold_fscore(method, img_pathes, classes_info, conf_dict, min_area=10000)
    #print(best_threshold_conf, fscore)