import torch
import numpy as np
import torch.nn.functional as F
import cv2
from pipeline.metrics.bd_jup.erqa_torch import ERQA
import kornia as K

def read_image(path):
    """Load an image file and return it as an RGB array.

    Args:
        path: filesystem path (``str`` or ``pathlib.Path``) of the image.

    Returns:
        The image converted from OpenCV's BGR channel order to RGB.

    Raises:
        OSError: if OpenCV cannot read ``path`` (missing file, unreadable or
            unsupported format).
    """
    image = cv2.imread(str(path))
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # without this check the error surfaces later as an opaque cv2.error.
        raise OSError(f"cv2.imread could not read image: {path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def process_image(image: np.array, device):
    """Turn an H x W x C image array into a C x H x W float32 tensor.

    Pixel values are rescaled linearly so that 0 maps to -1 and 255 maps to 1
    (the arithmetic is done in float64 before casting to float32).

    Args:
        image: array indexed as (height, width, channel).
        device: torch device the resulting tensor is moved to.

    Returns:
        float32 tensor of shape (C, H, W) on ``device``.
    """
    scaled = image.astype(np.float64)
    scaled = scaled / 255 * 2 - 1
    channels_first = np.transpose(scaled, (2, 0, 1))
    tensor = torch.from_numpy(channels_first)
    return tensor.float().to(device)


def get_padding(side, patch_size, stride):
    """Return the total padding needed for patches to tile one dimension.

    Computes the smallest non-negative ``pad`` such that ``side + pad`` is at
    least ``patch_size`` and ``side + pad - patch_size`` is a multiple of
    ``stride``. Patches of ``patch_size`` taken every ``stride`` pixels then
    start at 0 and end exactly at the padded border.

    Args:
        side: length of the dimension (height or width).
        patch_size: side length of a patch.
        stride: step between consecutive patches.

    Returns:
        Total padding (``int``) to distribute over both ends of the dimension.
    """
    overhang = side - patch_size
    if overhang <= 0:
        # The dimension is no longer than a patch: pad it up to exactly one
        # patch. (The ceil formula below under-pads here when stride is small.)
        return -overhang
    # Integer ceil-division avoids float rounding for large sizes.
    n_steps = -(-overhang // stride)
    return n_steps * stride - overhang


def smart_unfold(image, patch_size, stride=None):
    """Split an image into (possibly overlapping) square patches.

    The image is zero-padded by ``F.unfold`` with half of ``get_padding`` on
    each side, so that patches of ``patch_size`` taken every ``stride`` pixels
    span it.

    NOTE: when the total padding for a dimension is odd, halving it with ``//``
    drops one pixel of padding. Up to ``stride - 1`` trailing rows/columns may
    then not be covered by any patch.

    Args:
        image: tensor of shape (C, H, W) or (N, C, H, W).
        patch_size: side length of each square patch.
        stride: step between neighbouring patches; defaults to ``patch_size``
            (non-overlapping patches).

    Returns:
        Tensor of shape (N, C, patch_size, patch_size, L), where L is the
        number of patches (N == 1 for a 3-D input).
    """
    if image.dim() == 3:
        image = image.unsqueeze(0)
    if stride is None:
        stride = patch_size
    n, c, h, w = image.shape
    padding = (
        get_padding(h, patch_size, stride) // 2,
        get_padding(w, patch_size, stride) // 2,
    )
    patches = F.unfold(image, patch_size, stride=stride, padding=padding)
    # F.unfold yields (N, C * k * k, L), channel-major; split the middle axis
    # back into (C, k, k) using the real batch/channel sizes instead of
    # assuming a single 3-channel image.
    return patches.view(n, c, patch_size, patch_size, -1)


def get_fold_divisor(image: torch.Tensor, patch_size, stride, padding):
    """Count how many patches cover each pixel for an unfold/fold setup.

    Unfolding a tensor of ones and folding it back sums the overlapping
    patches, so every output pixel holds its patch-coverage count. Pixels that
    no patch covers get 0.

    Args:
        image: tensor whose last two dimensions give the spatial size (H, W);
            only its shape, dtype and device are used.
        patch_size: patch side length passed to ``F.unfold``/``F.fold``.
        stride: patch stride passed to ``F.unfold``/``F.fold``.
        padding: padding passed to ``F.unfold``/``F.fold``.

    Returns:
        Tensor of shape (1, 1, H, W) with ``image``'s dtype, allocated on
        ``image``'s device (avoids a CPU round-trip for GPU inputs).
    """
    height, width = image.shape[-2:]
    # Only the spatial size matters, so a single sample/channel suffices.
    ones = torch.ones((1, 1, height, width), dtype=image.dtype, device=image.device)
    patches = F.unfold(ones, patch_size, stride=stride, padding=padding)
    return F.fold(patches, (height, width), patch_size, stride=stride, padding=padding)


def probs_fold(probs, output_size, patch_size, stride, device):
    """Spread one score per patch back over the image, averaging overlaps.

    Every patch score is painted over its ``patch_size`` x ``patch_size``
    footprint with ``F.fold`` (which sums overlapping contributions). The sum
    is then divided by the per-pixel patch count, giving the mean score of all
    patches covering each pixel. Pixels covered by no patch divide 0 by 0 and
    come out as NaN.

    Args:
        probs: tensor of shape (B, L) holding one score per patch.
        output_size: (H, W) of the reconstructed map.
        patch_size: patch side length used when the patches were extracted.
        stride: patch stride used when the patches were extracted.
        device: device the result is moved to.

    Returns:
        Tensor of shape (B, 1, H, W).
    """
    n_rows, n_blocks = probs.shape[0], probs.shape[1]
    # (B, L) -> (B, k, k, L) -> (B, k*k, L): the layout F.fold expects for C=1.
    per_pixel = probs[:, None, None].expand(-1, patch_size, patch_size, -1).reshape(n_rows, -1, n_blocks)

    height, width = output_size
    padding = tuple(get_padding(side, patch_size, stride) // 2 for side in (height, width))

    summed = F.fold(per_pixel, output_size, patch_size, stride=stride, padding=padding).to(device)
    coverage = get_fold_divisor(summed, patch_size, stride, padding).to(device)
    return summed / coverage

@torch.no_grad()
def patch_lpips(gt, sample, patch_size, metric, device, stride=None):
    """Compute a per-pixel LPIPS-style map by scoring image patches.

    Both images are cut into patches with ``smart_unfold``. All patches are
    scored by ``metric`` in a single call, and the per-patch scores are folded
    back into a map where each pixel holds the mean score of the patches
    covering it.

    Side effects: moves ``metric`` to ``device`` and switches it to eval mode.

    Args:
        gt: reference image tensor, (C, H, W) or (N, C, H, W).
        sample: distorted image tensor with the same shape as ``gt``.
        patch_size: side length of the square patches.
        metric: module called as ``metric(gt_patches, sample_patches)``; its
            output is reshaped to one value per patch (presumably an LPIPS
            model returning (num_patches, 1, 1, 1) — TODO confirm).
        device: device for the metric and the result.
        stride: patch step; defaults to ``patch_size``.

    Returns:
        Tensor of shape (H, W) for a single image (N == 1).
    """
    if stride is None:
        stride = patch_size

    gt_blocks = smart_unfold(gt, patch_size, stride)
    sample_blocks = smart_unfold(sample, patch_size, stride)

    # (N, C, k, k, L) -> (N, L, C, k, k) so each patch becomes a batch item.
    gt_blocks = torch.movedim(gt_blocks, -1, 1)
    sample_blocks = torch.movedim(sample_blocks, -1, 1)

    batch_size, n_blocks, *shape = gt_blocks.shape

    # movedim returns a non-contiguous view; reshape (unlike view) still works
    # when N > 1 by copying where needed.
    gt_blocks = gt_blocks.reshape(batch_size * n_blocks, *shape)
    sample_blocks = sample_blocks.reshape(batch_size * n_blocks, *shape)

    metric.to(device)
    metric.eval()

    probs = metric(gt_blocks, sample_blocks)
    probs = probs.reshape(batch_size, n_blocks)

    probs = probs_fold(probs, gt.shape[-2:], patch_size, stride=stride, device=device)
    return probs[:, 0].squeeze(0)


def patch_erqa(gt, sample, patch_size, metric, device, stride=None, batch_size=1024):
    """Compute a per-pixel ERQA map by scoring image patches in chunks.

    Patches are extracted with ``smart_unfold`` (first batch element only).
    They are scored ``batch_size`` at a time by ``metric``, and the scores are
    folded back so each pixel holds the mean score of the patches covering it.

    Args:
        gt: reference image tensor, (C, H, W) or (1, C, H, W).
        sample: distorted image tensor with the same shape as ``gt``.
        patch_size: side length of the square patches.
        metric: callable invoked as ``metric(sample_patches, gt_patches)``;
            its outputs are concatenated and flattened to one value per patch
            (assumed to be one score per patch — TODO confirm).
        device: device the scores and the result are moved to.
        stride: patch step; defaults to ``patch_size``.
        batch_size: maximum number of patches per ``metric`` call.

    Returns:
        Tensor of shape (H, W).
    """
    if stride is None:
        stride = patch_size

    def _patches(image):
        # (1, C, k, k, L) -> (L, C, k, k): one patch per leading index.
        return torch.movedim(smart_unfold(image, patch_size, stride), -1, 1)[0]

    gt_patches = _patches(gt)
    sample_patches = _patches(sample)

    chunk_scores = []
    for sample_chunk, gt_chunk in zip(
        torch.split(sample_patches, batch_size, dim=0),
        torch.split(gt_patches, batch_size, dim=0),
    ):
        chunk_scores.append(metric(sample_chunk, gt_chunk))

    scores = torch.cat(chunk_scores).float().to(device).view(1, -1)
    heatmap = probs_fold(scores, gt.shape[-2:], patch_size, stride=stride, device=device)
    return heatmap[:, 0].squeeze(0)

def binarize_heatmap(float_heatmap, bin_thres, bin_mode):
    """Threshold a float heatmap tensor into a uint8 mask of 0s and 1s.

    Modes:
        'gauss': apply a 31x31 median blur (kornia) to the heatmap, then
            threshold. Despite the mode name, no Gaussian filter is used.
        'dilation': threshold first, then erode once with a 25x25 kernel and
            dilate four times with a 50x50 kernel (OpenCV).

    Args:
        float_heatmap: 2-D float tensor (may live on any device).
        bin_thres: values strictly greater than this become 1.
        bin_mode: either 'gauss' or 'dilation'.

    Returns:
        numpy ``uint8`` array of 0/1 values.

    Raises:
        ValueError: for an unknown ``bin_mode``.
    """
    if bin_mode == 'gauss':
        ksize = 31
        blurred = K.filters.median_blur(float_heatmap[None, None, ...], (ksize, ksize)).squeeze()
        return (np.array(blurred.detach().cpu()) > bin_thres).astype(np.uint8)

    if bin_mode == 'dilation':
        as_float = np.array(float_heatmap.detach().cpu(), dtype=np.float32)
        mask = (as_float > bin_thres).astype(np.uint8)
        mask = cv2.erode(mask, np.ones((25, 25), np.uint8), iterations=1)
        mask = cv2.dilate(mask, np.ones((50, 50), np.uint8), iterations=4)
        return np.array(mask)

    raise ValueError(f'No such bin_mode {bin_mode}')

@torch.no_grad()
def get_heatmap(
        sample, gt,
        lpips_metric=None,
        erqa_block_size=16, lpips_block_size=64,
        threshold1=0.1, threshold2=0.2,
        erqa_stride=16, lpips_stride=16,
        erqa_weight=0.5, lpips_weight=0.5,
        device=torch.device("cpu")
    ):
    """Build a per-pixel distortion heatmap from patch-wise LPIPS and ERQA.

    The result is ``erqa_weight * (1 - erqa_map) + lpips_weight * lpips_map``.
    A component whose weight is 0 is not computed and contributes zeros.

    Args:
        sample: distorted image as an (H, W, C) array (see ``process_image``).
        gt: reference image as an (H, W, C) array.
        lpips_metric: LPIPS-style module; required when ``lpips_weight != 0``.
        erqa_block_size, lpips_block_size: patch side lengths per metric.
        threshold1, threshold2: edge thresholds forwarded to ``ERQA``.
        erqa_stride, lpips_stride: patch strides per metric.
        erqa_weight, lpips_weight: blend weights of the two maps.
        device: torch device used for all computation.

    Returns:
        Tensor of shape (H, W). NaN entries (e.g. pixels that no patch
        covered) are replaced by the mean of the non-NaN entries.

    Raises:
        ValueError: if ``lpips_weight`` is non-zero but ``lpips_metric`` is
            None (previously this failed later with an AttributeError).
    """
    if lpips_weight != 0 and lpips_metric is None:
        raise ValueError("lpips_metric must be provided when lpips_weight != 0")

    gt_torch = process_image(gt, device)
    sample_torch = process_image(sample, device)
    spatial_shape = gt_torch.shape[1:]

    if lpips_weight != 0:
        heatmap_lpips = patch_lpips(
            gt_torch, sample_torch, lpips_block_size, lpips_metric,
            stride=lpips_stride, device=device,
        )
    else:
        heatmap_lpips = torch.zeros(spatial_shape, device=device)

    if erqa_weight != 0:
        erqa = ERQA(
            threshold1=threshold1, threshold2=threshold2,
            shift_compensation=True,
            penalize_wider_edges=None,
            global_compensation=False,
        )
        # The whole-image ERQA call (score and tp/fp/fn maps) was removed: its
        # results were never used. Only the patch-wise scores feed the heatmap.
        erqa_blocks = patch_erqa(
            gt_torch, sample_torch, erqa_block_size, metric=erqa,
            stride=erqa_stride, device=device,
        )
        # Presumably ERQA is higher-is-better; invert it so the heatmap grows
        # with distortion, like the LPIPS map.
        heatmap_erqa = 1 - erqa_blocks
    else:
        heatmap_erqa = torch.zeros(spatial_shape, device=device)

    final_map = erqa_weight * heatmap_erqa + lpips_weight * heatmap_lpips

    nan_mask = torch.isnan(final_map)
    final_map[nan_mask] = torch.mean(final_map[~nan_mask])
    return final_map

def heatmap2bboxes(heatmap):
    """Return bounding boxes of the blobs in a binarised heatmap.

    Args:
        heatmap: single-channel 8-bit mask accepted by ``cv2.findContours``
            (e.g. the output of ``binarize_heatmap``).

    Returns:
        List of dicts with keys "x1", "x2", "y1", "y2", where x2 = x1 + width
        and y2 = y1 + height. The list is sorted by the box's longer side,
        largest first.
    """
    found = cv2.findContours(heatmap, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    # OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns
    # (image, contours, hierarchy).
    contours = found[0] if len(found) == 2 else found[1]

    def _to_bbox(contour):
        x, y, w, h = cv2.boundingRect(contour)
        return {"x1": x, "x2": x + w, "y1": y, "y2": y + h}

    def _longest_side(box):
        return max(box["x2"] - box["x1"], box["y2"] - box["y1"])

    return sorted((_to_bbox(c) for c in contours), key=_longest_side, reverse=True)
