import torch
import torch.nn.functional as F
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import cv2
import kornia.filters as K
import pickle

from pathlib import Path

def get_quantile(img, x=0.1):
    '''
    Return the largest uint8 level ``v`` such that at least ``x`` times the
    number of histogrammed pixels of ``img`` (channel 0) have a value >= ``v``.

    Scans levels from 255 downwards, accumulating histogram counts; returns 0
    if the running count never reaches the target.
    '''
    hist = cv2.calcHist([img], [0], None, [256], [0, 256])
    target = np.sum(hist) * x
    counts = hist.ravel()
    accumulated = 0
    level = 255
    for level in range(255, -1, -1):
        accumulated += counts[level]
        if accumulated >= target:
            break
    return level

def heatmap2bboxes(heatmap):
    '''
    Return ``[x1, y1, x2, y2]`` boxes around the external contours of a
    binarised heatmap, ordered by the longer side of each box (largest first).
    '''
    found = cv2.findContours(heatmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 4.x returns (contours, hierarchy); OpenCV 3.x returns (image, contours, hierarchy).
    contours = found[1] if len(found) != 2 else found[0]
    boxes = [
        [left, top, left + width, top + height]
        for left, top, width, height in map(cv2.boundingRect, contours)
    ]
    boxes.sort(key=lambda box: max(box[2] - box[0], box[3] - box[1]), reverse=True)
    return boxes



class Structure_Similarity:
    '''
    Structure-similarity artifact detector.

    For both the AI-codec image and the traditional-codec image it builds an
    artifact map from the residual against the ground truth (image-level
    residual variance ** (1/5) times a local per-pixel residual variance),
    blurs each map with a Gaussian and returns their difference: positive where
    the blurred AI artifact map is larger than the traditional-codec one.
    '''

    def __init__(self, LocalWeights_ksize=33, GaussianBlur_ksize=103, GaussianBlur_sigma=33, device="cpu"):
        '''
        LocalWeights_ksize -- side of the square window for the per-pixel residual variance (expected odd)
        GaussianBlur_ksize -- side of the Gaussian kernel applied to each artifact map
        GaussianBlur_sigma -- standard deviation of that Gaussian kernel
        device -- "cpu", a CUDA index (0, "1", ...), a "cuda[:N]" string or a torch.device;
                  falls back to CPU when CUDA is not available
        '''
        self.GaussianBlur_ksize = (GaussianBlur_ksize, GaussianBlur_ksize)
        self.GaussianBlur_sigma = (GaussianBlur_sigma, GaussianBlur_sigma)
        self.LocalWeights_ksize = LocalWeights_ksize

        # Indexed by ``nchannels == 3`` in get_heatmap (False -> 0, True -> 1).
        self.val_transform = [
            ToTensorV2(),  # for one- or two-channel images
            A.Compose([A.Normalize(), ToTensorV2()]),  # for 3-channel images
        ]
        self.device = self._resolve_device(device)

    @staticmethod
    def _resolve_device(device):
        '''Map the user-facing ``device`` argument to a ``torch.device``.'''
        name = str(device)
        if name == "cpu" or not torch.cuda.is_available():
            return torch.device("cpu")
        # Accept full device strings ("cuda", "cuda:1") as well as bare indices (0, "1").
        if name.startswith("cuda"):
            return torch.device(name)
        return torch.device(f"cuda:{name}")

    @staticmethod
    def _load_image(image):
        '''Return ``image`` unchanged if it is a numpy array, otherwise read it from disk as RGB.'''
        if isinstance(image, np.ndarray):
            return image
        bgr = cv2.imread(str(image))
        if bgr is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def __call__(self, hr_path, sr_path, rf_path):
        '''
        hr_path -- original (ground-truth) image
        sr_path -- image the artifact is searched on (AI codec)
        rf_path -- reference image (traditional codec)

        Each argument is either a path readable by cv2.imread (the file is
        loaded and converted BGR -> RGB) or an already-loaded numpy array,
        used as-is (RGB/Grayscale/UV/AB format).

        returns 2-D numpy heatmap of structure similarity w/o thresholding
        '''
        gt_image = self._load_image(hr_path)
        ai_image = self._load_image(sr_path)
        tc_image = self._load_image(rf_path)
        assert gt_image.shape == ai_image.shape == tc_image.shape, "Input images should be the sampe shape"
        heatmap = self.get_heatmap(gt_image, ai_image, tc_image)[0, 0].cpu().numpy()
        return heatmap

    def get_local_weights(self, residual):
        '''
        Per-pixel unbiased variance of ``residual`` over a
        LocalWeights_ksize x LocalWeights_ksize window.

        residual -- 4-D tensor (N, C, H, W); it is reflect-padded so that, for an
                    odd window size, the output keeps the input's spatial size

        returns tensor of shape (N, C, H, W)
        '''
        ksize = self.LocalWeights_ksize
        pad = (ksize - 1) // 2
        residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect')
        # (N, C, H, W, ksize, ksize): one sliding window per output pixel.
        windows = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1)
        return torch.var(windows, dim=(-1, -2), unbiased=True)

    def get_refined_artifact_map(self, img_gt, img_output, dim=(-1, -2, -3)):
        '''
        Artifact map of ``img_output`` against ``img_gt`` (both (N, C, H, W)).

        The absolute residual is summed over channels; the map is the product of
        the residual variance over ``dim`` raised to 1/5 (one value per image)
        and the local per-pixel residual variance.

        returns tensor of shape (N, 1, H, W)
        '''
        residual = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)
        # Neither torch.var nor get_local_weights modifies its input, so no copies are needed.
        patch_level_weight = torch.var(residual, dim=dim, keepdim=True) ** (1 / 5)
        pixel_level_weight = self.get_local_weights(residual)
        return patch_level_weight * pixel_level_weight

    def get_heatmap(self, gt_image, ai_image, tc_image):
        '''
        gt_image, ai_image, tc_image -- numpy images of identical shape, (H, W)
            or (H, W, C); 3-channel images go through A.Normalize before tensor
            conversion, other channel counts are only converted to tensors

        returns tensor of shape (1, 1, H, W): blurred AI artifact map minus
        blurred traditional-codec artifact map
        '''
        nchannels = gt_image.shape[-1] if len(gt_image.shape) > 2 else 1
        transform = self.val_transform[nchannels == 3]

        def to_batch(image):
            return transform(image=image)['image'].to(device=self.device).unsqueeze(0)

        gt_transformed = to_batch(gt_image)
        ai_transformed = to_batch(ai_image)
        tc_transformed = to_batch(tc_image)

        # The residual map always has a single channel, so both choices cover the same elements.
        dim = (-1, -2) if nchannels == 1 else (-1, -2, -3)
        ai_heatmap = self.get_refined_artifact_map(ai_transformed, gt_transformed, dim)
        tc_heatmap = self.get_refined_artifact_map(tc_transformed, gt_transformed, dim)

        ai_heatmap_blurred = K.gaussian_blur2d(ai_heatmap, kernel_size=self.GaussianBlur_ksize, sigma=self.GaussianBlur_sigma)
        tc_heatmap_blurred = K.gaussian_blur2d(tc_heatmap, kernel_size=self.GaussianBlur_ksize, sigma=self.GaussianBlur_sigma)
        return ai_heatmap_blurred - tc_heatmap_blurred
    

class SS_sep_chanels:
    '''
    Run a Structure_Similarity detector on separate luma / chroma planes.

    Texture artifacts require agreement between the YUV "Y" plane and the LAB
    "L" plane; color artifacts between the YUV "UV" and LAB "AB" planes.
    '''

    def __init__(self,
                 ssad: "Structure_Similarity",
                 texture_quantile: float = 0.005,
                 color_quantile: float = 0.005,
                 global_threshold: float = 0.00015 * 0.01):
        '''
        ssad -- detector whose get_heatmap produces the raw per-plane heatmap
        texture_quantile -- quantile passed to get_quantile for the Y / L planes
        color_quantile -- quantile passed to get_quantile for the UV / AB planes
        global_threshold -- lower bound applied to every per-plane threshold
        '''
        self.ssad = ssad
        self.texture_quantile = texture_quantile
        self.color_quantile = color_quantile
        self.global_threshold = global_threshold

    @staticmethod
    def extend_heatmap(heatmap_bin):
        '''
        Dilate a binary mask with a 50x50 kernel, 4 iterations (presumably so
        that nearby detections merge into a single box).

        returns uint8 mask of the same shape
        '''
        kernel_dilate = np.ones((50, 50), np.uint8)
        return cv2.dilate(heatmap_bin.astype(np.uint8), kernel_dilate, iterations=4)

    @staticmethod
    def _as_image(source):
        '''Return ``source`` if it is a numpy array, otherwise read it from disk as RGB.'''
        if isinstance(source, np.ndarray):
            return source
        bgr = cv2.imread(str(source))
        if bgr is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {source}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _cache_file(directory, file_name):
        '''Return ``directory / file_name`` (str or Path directory), or None when no directory is given.'''
        return None if directory is None else Path(directory) / file_name

    @staticmethod
    def _color_planes(image_bgr):
        '''Split a BGR image into Y, L, UV and AB planes, each divided by 255.'''
        yuv = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2YUV) / 255
        lab = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2LAB) / 255
        return {"y": yuv[..., 0], "l": lab[..., 0], "uv": yuv[..., 1:], "ab": lab[..., 1:]}

    def __call__(self, gt_image : np.ndarray, ai_image : np.ndarray, tc_image : np.ndarray, cached_heatmap_path : Path = None, save_heatmap_to : Path = None, return_heatmap=False, include_color=False):
        '''
        gt_image -- original image, BGR format
        ai_image -- image artifact is searched on (AI codec), BGR format
        tc_image -- target image (traditional codec), BGR format
        cached_heatmap_path -- directory (str or Path) with y/l/uv/ab.pickle heatmaps
            from an earlier run; when given, the images are not used and may be None
        save_heatmap_to -- directory (str or Path) to store freshly computed
            y/l/uv/ab.pickle heatmaps in
        return_heatmap -- currently unused
        include_color -- also report color-artifact boxes (off by default)

        returns list of detections, texture boxes first, each group sorted by
        box size (largest first):
        {
            "bbox": [x1, y1, x2, y2],
            "art_t": "texture" or "color",
            "metric_value": sum of the two per-plane heatmap maxima
        }
        '''
        if cached_heatmap_path is None:
            assert gt_image.shape == ai_image.shape == tc_image.shape, "Input images should be the sampe shape"
            gt_planes, ai_planes, tc_planes = (self._color_planes(img) for img in (gt_image, ai_image, tc_image))
        else:
            # Every heatmap comes from the cache, so no image planes are needed.
            gt_planes = ai_planes = tc_planes = dict.fromkeys(("y", "l", "uv", "ab"))

        plane_quantiles = (
            ("y", self.texture_quantile),
            ("l", self.texture_quantile),
            ("uv", self.color_quantile),
            ("ab", self.color_quantile),
        )
        masks, stats = {}, {}
        for name, quantile in plane_quantiles:
            file_name = f"{name}.pickle"
            masks[name], stats[name] = self.get_thresholded_heatmap(
                gt_planes[name], ai_planes[name], tc_planes[name], quantile,
                cached_heatmap_path=self._cache_file(cached_heatmap_path, file_name),
                save_heatmap_to=self._cache_file(save_heatmap_to, file_name),
            )

        ###
        # Texture bboxes
        ###
        texture_heatmap = np.logical_and(masks["y"], masks["l"])
        texture_score = stats["y"]["max"] + stats["l"]["max"]
        result = [
            {'bbox': bbox, "art_t": 'texture', "metric_value": texture_score}
            for bbox in heatmap2bboxes(self.extend_heatmap(texture_heatmap))
        ]

        ###
        # Color bboxes (opt-in)
        ###
        if include_color:
            # findContours needs uint8 input, not bool.
            color_heatmap = np.logical_and(masks["uv"], masks["ab"]).astype(np.uint8)
            color_score = stats["uv"]["max"] + stats["ab"]["max"]
            result.extend(
                {'bbox': bbox, "art_t": 'color', "metric_value": color_score}
                for bbox in heatmap2bboxes(color_heatmap)
            )
        return result

    def get_thresholded_heatmap(self, hr_path, sr_path, rf_path, quantile, cached_heatmap_path=None, save_heatmap_to=None):
        '''
        Compute (or load) one plane's raw heatmap and binarise it.

        hr_path, sr_path, rf_path -- ground-truth, AI-codec and traditional-codec
            inputs: numpy arrays are used as-is, anything else is read with
            cv2.imread and converted BGR -> RGB. Ignored when cached_heatmap_path is given.
        quantile -- passed to get_quantile to derive the data-driven threshold
        cached_heatmap_path -- pickle file to load the raw heatmap from instead of computing it
        save_heatmap_to -- pickle file to store a freshly computed raw heatmap in

        returns (mask, stats): boolean ``heatmap >= max(q, global_threshold)`` and
        {"max": ..., "mean": ...} of the raw heatmap, for simplified logging
        '''
        if cached_heatmap_path is not None:
            # NOTE(review): pickle.load can execute arbitrary code; only load caches you wrote yourself.
            with open(cached_heatmap_path, 'rb') as f:
                heatmap = pickle.load(f)
        else:
            gt, ai, tc = (self._as_image(src) for src in (hr_path, sr_path, rf_path))
            assert gt.shape == ai.shape == tc.shape, "Input images should be the sampe shape"
            # get_heatmap works on in-memory arrays with any channel count (1, 2 or 3).
            heatmap = self.ssad.get_heatmap(gt, ai, tc)[0, 0].cpu().numpy()
            if save_heatmap_to is not None:
                with open(save_heatmap_to, 'wb') as f:
                    pickle.dump(heatmap, f)

        peak = heatmap.max()
        if peak == 0:
            # Rescaling by the peak is undefined; the threshold reduces to max(0, global_threshold).
            q = 0.0
        else:
            scaled = (np.clip(heatmap / peak, 0, 1) * 255).astype(np.uint8)
            # Map the uint8 quantile level back into the heatmap's own value range.
            q = get_quantile(scaled, x=quantile) * peak / 255

        # TODO: also consider the mean inside each found bbox, or the mean of heatmap[heatmap != 0].
        mask = heatmap >= max(q, self.global_threshold)
        return mask, {"max": peak, "mean": heatmap.mean()}

# from base.metric_abs import MetricABS
# from base.image_abs import SRImage
# from functools import partial

# class SSM(MetricABS):

#     def __init__(self, *args, **kwargs):
#         device = kwargs['device']
#         kargs = {
#             'global_threshold' : kwargs['global_threshold'],
#             'texture_quantile' : kwargs['texture_quantile'],
#             'color_quantile' : kwargs['color_quantile']
#         }
#         self.metric = SS_sep_chanels(Structure_Similarity(device=device), **kargs)
        
#     def __str__(self):
#         return "SSM"

#     def __call__(self, img_info : SRImage, *args, **kwargs):

#         mat_SR = img_info.mat_SR
#         mat_HR = img_info.mat_HR
#         mat_RF = img_info.mat_RF

#         if img_info.format != 'rgb':
#             raise TypeError('Wrong image format provided, for SS metric it must be RGB')

#         convert = partial(SRImage.convert_image, format_from='rgb', format_to='bgr')

#         res = self.metric(gt_image=convert(mat_HR), ai_image=convert(mat_SR), tc_image=convert(mat_RF))
#         texture_bboxes, conf_texture = res["texture"]
#         color_bboxes, conf_color = res["color"]
        
#         det = (
#             [(bbox, 'texture', conf_texture) for bbox in texture_bboxes] + 
#             [(bbox, 'color', conf_color) for bbox in color_bboxes]
#         )

#         return det, res