from pathlib import Path
from pipeline.metrics.bd_jup.utils import get_heatmap, heatmap2bboxes, read_image, binarize_heatmap
from pipeline.metrics.bd_jup.head import CustomNetwork,INPUT_SIZE
import lpips
import warnings
import torch
import cv2 
import numpy as np
from torchvision.transforms.functional import resize

# from base.metric_abs import MetricABS
# from base.image_abs import SRImage
from functools import partial

class CombDet:
    """Artifact detector that turns a (distorted, ground-truth) image pair into a heatmap.

    The heatmap itself is produced by ``get_heatmap`` using an LPIPS network and
    the ERQA/LPIPS block-size, stride and weight settings stored on the instance.
    The detection decision is then made according to ``version``:

    * ``'0.0'`` -- score is the heatmap maximum.
    * ``'1.0'`` -- score is the output of a learned head (``CustomNetwork``).

    In both cases an artifact is reported when the score exceeds
    ``final_threshold``.
    """

    # Images whose ground-truth array has at least this many elements get
    # their heatmap computed on CPU instead of ``self.device`` (presumably to
    # avoid GPU out-of-memory; 7M pixels x 3 channels).
    _CPU_FALLBACK_ELEMENTS = 7_000_000 * 3
    _SUPPORTED_VERSIONS = ('0.0', '1.0')

    def __init__(
        self,
        threshold1=0.9,
        threshold2=0.99,
        erqa_block_size=16,
        lpips_block_size=96,
        erqa_stride=None,
        lpips_stride=None,
        erqa_weight=0.5,
        final_threshold=0.8,
        version='0.0',
        n_device=0,
        model_path="model_1.pth",
        bin_mode='gauss'
    ):
        """Configure the detector and load the networks it needs.

        Args:
            threshold1, threshold2: forwarded to ``get_heatmap``.
            erqa_block_size, lpips_block_size: block sizes forwarded to ``get_heatmap``.
            erqa_stride, lpips_stride: strides forwarded to ``get_heatmap``; a falsy
                value defaults to ``block_size // 4`` (ERQA) / ``block_size // 8`` (LPIPS).
            erqa_weight: ERQA weight; the LPIPS weight is ``1 - erqa_weight``.
            final_threshold: decision threshold on the score, also used for heatmap
                binarization when extracting bounding boxes.
            version: ``'0.0'`` (heatmap max) or ``'1.0'`` (learned head).
            n_device: CUDA device index; any non-int value (or no CUDA) selects CPU.
            model_path: checkpoint for ``CustomNetwork`` (only loaded for ``'1.0'``).
            bin_mode: binarization mode forwarded to ``binarize_heatmap``.
        """
        # (h_low, h_high), (w_low, w_high)
        self.size_limits = ((150, 500), (150, 500))

        # Attribute names keep the original (misspelled) "treshold" for
        # backward compatibility with external code that may read them.
        self.treshold1 = threshold1
        self.treshold2 = threshold2
        self.erqa_block_size = erqa_block_size
        self.lpips_block_size = lpips_block_size
        self.erqa_stride = erqa_stride if erqa_stride else self.erqa_block_size // 4
        self.lpips_stride = lpips_stride if lpips_stride else self.lpips_block_size // 8
        self.erqa_weight = erqa_weight
        self.lpips_weight = 1 - erqa_weight
        self.final_threshold = final_threshold
        self.version = version
        self.n_device = n_device
        self.device = torch.device(
            f"cuda:{n_device}" if (torch.cuda.is_available() and isinstance(n_device, int)) else "cpu"
        )
        self.bin_mode = bin_mode

        if self.version not in self._SUPPORTED_VERSIONS:
            # Unknown versions never report an artifact; make that visible.
            warnings.warn(
                f"Unknown CombDet version {self.version!r}; "
                f"supported: {self._SUPPORTED_VERSIONS}. Decisions will always be False."
            )

        print(f"LPIPS running on {self.device}")

        self.lpips_metric = lpips.LPIPS().to(self.device)
        self.lpips_metric.eval()

        if self.version == '1.0':
            self.model = CustomNetwork().to(self.device)
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            self.model.load_state_dict(torch.load(model_path, map_location=self.device))
            self.model.eval()

    def _call(self, sr_mat, gt_mat=None, tc_path: Path = None):
        """Compute the artifact heatmap and the detection decision for one image pair.

        Args:
            sr_mat: image in which the artifact is searched (RGB array as read by
                ``__call__``).
            gt_mat: original / ground-truth image (same layout). Required.
            tc_path: optional path to an image compressed with a traditional codec.
                Unused; accepted only for compatibility with other methods.

        Returns:
            ``(decision, score, heatmap)``. ``decision`` is True when an artifact is
            detected. ``heatmap`` is the full-resolution heatmap from ``get_heatmap``.
            For an unknown ``version`` the result is ``(False, 0, heatmap)``.

        Raises:
            NameError: if ``gt_mat`` is None (kept for backward compatibility).
        """
        if gt_mat is None:
            # TODO: support no-reference operation.
            raise NameError('GT path was not found')
        if tc_path:
            warnings.warn("This variable is not used")

        heatmap_device = self.device if gt_mat.size < self._CPU_FALLBACK_ELEMENTS else "cpu"
        heatmap = get_heatmap(
            gt=gt_mat,
            sample=sr_mat,
            lpips_metric=self.lpips_metric,
            threshold1=self.treshold1,
            threshold2=self.treshold2,
            erqa_block_size=self.erqa_block_size,
            lpips_block_size=self.lpips_block_size,
            erqa_stride=self.erqa_stride,
            lpips_stride=self.lpips_stride,
            erqa_weight=self.erqa_weight,
            lpips_weight=self.lpips_weight,
            device=heatmap_device,
        )

        if self.version == '1.0':
            score = self._model_score(heatmap)
        elif self.version == '0.0':
            score = heatmap.max().item()
        else:
            # No scoring rule for this version: report "no artifact".
            return False, 0, heatmap

        return score > self.final_threshold, score, heatmap

    def _model_score(self, heatmap):
        """Score a heatmap with the learned head (version '1.0') and return a float."""
        # The head consumes a 3-channel image of size INPUT_SIZE. Build that input
        # separately so the caller keeps the full-resolution heatmap for bbox
        # extraction. Assumes heatmap is a 2-D (H, W) tensor -- TODO confirm.
        model_input = resize(torch.stack((heatmap, heatmap, heatmap)), INPUT_SIZE)
        # Large images produce the heatmap on CPU; the model lives on self.device.
        model_input = model_input[None, ...].to(self.device)
        with torch.no_grad():
            return self.model(model_input).item()

    def bbox_filter(self, bbox):
        """Return True if ``bbox`` is large enough to be reported.

        ``bbox`` is a mapping with keys 'x1', 'x2', 'y1', 'y2'. Its height and width
        must be strictly greater than the lower limits in ``self.size_limits``. The
        upper limits stored there are currently not enforced.
        """
        (h_low, _h_high), (w_low, _w_high) = self.size_limits
        h = np.abs(bbox['y1'] - bbox['y2'])
        w = np.abs(bbox['x1'] - bbox['x2'])
        return bool(h > h_low and w > w_low)

    @staticmethod
    def _read_rgb(path):
        """Read an image with OpenCV and convert it from BGR to RGB."""
        mat = cv2.imread(str(path))
        if mat is None:
            # cv2.imread signals a missing/unreadable file by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(mat, cv2.COLOR_BGR2RGB)

    def _detect_bboxes(self, heatmap):
        """Binarize ``heatmap`` and return the artifact boxes that pass ``bbox_filter``."""
        heatmap_bin = binarize_heatmap(heatmap, self.final_threshold, self.bin_mode)
        res = []
        for bbox in heatmap2bboxes(heatmap_bin):
            if not self.bbox_filter(bbox):
                continue
            x1, x2, y1, y2 = bbox['x1'], bbox['x2'], bbox['y1'], bbox['y2']
            res.append({
                'bbox': [x1, y1, x2, y2],
                "art_t": 'texture',
                "metric_value": int(heatmap_bin[y1:y2, x1:x2].max()),
            })
        return res

    def __call__(self, hr_path, sr_path, tc_path: Path = None, *, return_bboxes: bool = False):
        """Run the detector on two image files and return the heatmap as a numpy array.

        Args:
            hr_path: path to the ground-truth image.
            sr_path: path to the image in which artifacts are searched.
            tc_path: unused; forwarded to ``_call`` for compatibility.
            return_bboxes: if True, also return the list of filtered artifact
                bounding boxes (dicts with 'bbox' as [x1, y1, x2, y2], 'art_t',
                'metric_value').

        Returns:
            ``heatmap`` (numpy array), or ``(heatmap, bboxes)`` when
            ``return_bboxes`` is True.

        Raises:
            FileNotFoundError: if either image cannot be read.
        """
        sr_mat = self._read_rgb(sr_path)
        gt_mat = self._read_rgb(hr_path)
        _decision, _score, heatmap = self._call(sr_mat, gt_mat, tc_path)
        heatmap_np = heatmap.detach().cpu().numpy()
        if not return_bboxes:
            return heatmap_np
        return heatmap_np, self._detect_bboxes(heatmap)

        
# from base.metric_abs import MetricABS
# from base.image_abs import SRImage

# class BlockingDet(MetricABS):

#     def __init__(self, *args, **kwargs):
#         self.metric = CombDet(**kwargs)
        
#     def __str__(self):
#         return "BlockingDet"

#     def __call__(self, img_info : SRImage, *args, **kwargs):

#         mat_SR = img_info.mat_SR
#         mat_HR = img_info.mat_HR
#         mat_RF = img_info.mat_RF

#         if img_info.format != 'rgb':
#             raise TypeError('Wrong image format provided, for SS metric it must be RGB')

#         det, res = self.metric(sr_mat=mat_SR, gt_mat=mat_HR)

#         return det, res