import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import os


class PostProcess(nn.Module):
    """Turn raw line-segment predictions into NMS-filtered numpy results.

    ``forward`` re-scores (phase I) or directly reads (phase II) the predicted
    lines, sorts them by score, rescales them by ``scale_fct`` and removes
    near-duplicate segments with :meth:`nms`.

    NOTE: ``torch.no_grad`` is applied to ``forward`` rather than to the class.
    Decorating the class only wrapped construction (and turned ``PostProcess``
    into a plain function), so inference itself was never run under no-grad.
    """

    def __init__(self, args):
        """Store post-processing settings.

        Args:
            args: object exposing a boolean-like ``phaseII`` attribute that
                selects which prediction heads ``forward`` reads.
        """
        super().__init__()
        self.debug = True  # when True, forward() also returns the nms_debug output
        self.scale_fct = 128  # factor applied to predicted line coordinates
        self.nms_th = 5  # squared-endpoint-distance threshold used by nms()
        self.grid = self.generate_grid()
        self.phaseII = args.phaseII

    def nms(self, scores, lines, max_num=500, n_iter=3):
        """Suppress near-duplicate line segments (vectorised greedy NMS).

        Two segments are considered duplicates when the summed squared
        distance between their matched endpoints (taking the better of the two
        endpoint orderings) is below ``self.nms_th``.  Rows are expected to be
        ordered from best to worst score (``forward`` sorts them), so a
        segment can only be suppressed by a segment with a smaller index.

        The greedy decision is propagated for ``n_iter`` rounds: a segment is
        suppressed once one of its earlier duplicates is kept, and kept once
        all of its earlier duplicates are suppressed.  Segments still
        undecided after ``n_iter`` rounds are dropped.

        Args:
            scores: 1-D tensor of per-segment scores.
            lines: tensor holding 4 values per segment (two 2-D endpoints).
            max_num: maximum number of segments returned.
            n_iter: number of propagation rounds.

        Returns:
            Tuple ``(scores, lines)`` of numpy arrays restricted to the kept
            segments (at most ``max_num`` of them).
        """
        n_lines = lines.shape[0]
        # Explicit leading size (instead of -1) keeps reshape valid for empty input.
        junctions = lines.reshape(n_lines, 2, 2)

        # diff[i, j, a, b] = squared distance between endpoint a of i and endpoint b of j.
        diff = ((junctions[:, None, :, None] - junctions[:, None]) ** 2).sum(-1)
        diff = torch.minimum(
            diff[:, :, 0, 0] + diff[:, :, 1, 1],
            diff[:, :, 0, 1] + diff[:, :, 1, 0],
        )

        # close[i, j] is True when j < i (strictly lower triangle) and j duplicates i.
        idx = torch.arange(n_lines, device=lines.device)
        close = (diff < self.nms_th) & (idx[:, None] > idx[None, :])

        # Segments without any earlier duplicate are kept unconditionally.
        keep = ~close.any(-1)
        for _ in range(n_iter):
            # The previous implementation edited the adjacency mask in place
            # through an alias, so earlier rounds leaked into later ones and
            # segments whose duplicates had all been suppressed were wrongly
            # discarded.  The adjacency matrix is now left untouched.
            suppressed = (close & keep[None, :]).any(-1) & ~keep
            keep = ~(close & ~suppressed[None, :]).any(-1) & ~suppressed

        return scores[keep][:max_num].cpu().numpy(), lines[keep][:max_num].cpu().numpy()

    def nms_debug(self, scores, lines, gt_lines, gt_indices):
        """Select, for every ground-truth segment, its closest predicted segment.

        Args:
            scores: unused; kept for interface compatibility.
            lines: numpy array with 4 values per predicted segment.
            gt_lines: numpy array with 4 values per ground-truth segment.
            gt_indices: unused; kept for interface compatibility.

        Returns:
            Tuple ``(ones, lines[choice])`` where ``choice`` holds the unique
            indices of the nearest prediction of each ground-truth segment and
            ``ones`` is an integer array of ones with the same length.
        """
        pre_lines = lines.reshape(-1, 2, 2)
        gt_lines = gt_lines.reshape(-1, 2, 2)
        diff = ((pre_lines[:, None, :, None] - gt_lines[:, None]) ** 2).sum(-1)
        diff = np.minimum(
            diff[:, :, 0, 0] + diff[:, :, 1, 1], diff[:, :, 0, 1] + diff[:, :, 1, 0]
        )
        choice = np.unique(np.argmin(diff, 0))
        return np.ones_like(choice), lines[choice]

    def get_len(self, lines):
        """Return segment lengths, multiplied by ``self.scale_fct``.

        Args:
            lines: tensor indexed as ``lines[:, :, k]`` with ``k`` in 0..3,
                the two endpoints being ``(0, 1)`` and ``(2, 3)``.
        """
        dx = lines[:, :, 0] - lines[:, :, 2]
        dy = lines[:, :, 1] - lines[:, :, 3]
        return (dy ** 2 + dx ** 2) ** 0.5 * self.scale_fct

    def generate_grid(self):
        """Return a ``(scale_fct**2, 2)`` array of integer-valued (x, y) grid points."""
        axis = np.linspace(0, self.scale_fct, self.scale_fct, endpoint=False)
        X, Y = np.meshgrid(axis, axis)
        return np.dstack((X, Y)).reshape((-1, 2))

    def finetune_scores(self, outputs, targets=None, N=2000):
        """Re-score the top-``N`` predicted lines using the predicted heat maps.

        The classification score of every candidate is combined with the mean
        junction-map response at its two endpoints (weight 0.4), the mean
        line-map response along 30 points interpolated between the endpoints
        (weight 0) and the log of its scaled length (weight 0.2).

        Args:
            outputs: dict with ``'pred_cls'``, ``'pred_lines'``,
                ``'pred_line_map'`` and ``'pred_junction_map'`` entries.
                ``'pred_cls'`` and ``'pred_lines'`` are overwritten in place.
            targets: unused; kept for interface compatibility.
            N: number of top-scoring candidates to keep; clamped to the number
                of available candidates.

        Returns:
            The same ``outputs`` dict with updated ``'pred_cls'``/``'pred_lines'``.
        """
        pred_cls = outputs['pred_cls']
        # topk raises when k exceeds the number of candidates, so clamp it.
        k = min(N, pred_cls.shape[-1])
        scores, indices = torch.topk(pred_cls, k)
        lines = torch.stack([l[i] for i, l in zip(indices, outputs['pred_lines'])], dim=0)
        bs, n = lines.shape[:2]

        def upsample(t):
            return F.interpolate(t.unsqueeze(1), size=(256, 256), mode='bilinear', align_corners=False)

        l_map = sum([upsample(t) for t in outputs['pred_line_map'][0:1]])
        j_map = sum([upsample(t) for t in outputs['pred_junction_map'][1:3]])

        # grid_sample expects coordinates in [-1, 1]; this mapping assumes the
        # predicted lines are normalised to [0, 1] -- TODO confirm with the model.
        j_grid = lines.reshape(bs, n, 2, 2) * 2 - 1
        # Interpolate 30 sample points between the two endpoints of each line.
        l_grid = F.interpolate(j_grid, size=(30, 2), mode="bilinear", align_corners=True)

        # squeeze(1) drops only the channel axis; a bare squeeze() would also
        # drop the batch axis when bs == 1.
        j_samples = F.grid_sample(j_map, j_grid, align_corners=False).squeeze(1)
        l_samples = F.grid_sample(l_map, l_grid, align_corners=False).squeeze(1)
        scores_d = torch.log(self.get_len(lines) + 1)  # length already includes scale_fct
        scores_j = j_samples.mean(dim=-1)
        scores_l = l_samples.mean(dim=-1)

        scores = scores + scores_j * 0.4 + scores_l * 0 + scores_d * 0.2  # 0.4; 0; 0.2
        outputs['pred_cls'], outputs['pred_lines'] = scores, lines
        return outputs

    @torch.no_grad()
    def forward(self, outputs, targets):
        """Post-process a batch of predictions.

        Args:
            outputs: model output dict.  Phase I reads (and re-scores)
                ``'pred_cls'``/``'pred_lines'``; phase II reads
                ``'pred_cls_p2'``/``'pred_lines_p2'`` as they are.
            targets: dict with a ``'lines'`` sequence of per-image ground-truth
                tensors and, when ``self.debug`` is set, a ``'label_cls'``
                sequence of per-image label tensors.

        Returns:
            ``(pred, gt)`` when ``self.debug`` is False, otherwise
            ``((dpred, pred), gt)``.  ``pred`` and ``dpred`` are per-image
            lists of ``(scores, lines)`` numpy pairs produced by :meth:`nms`
            and :meth:`nms_debug`; ``gt`` is the per-image list of
            ground-truth lines divided by 4.
        """
        if not self.phaseII:
            outputs = self.finetune_scores(outputs, targets, 2000)
            pred_cls = outputs['pred_cls']
            pred_lines = outputs['pred_lines']
        else:
            pred_cls = outputs['pred_cls_p2']
            pred_lines = outputs['pred_lines_p2']

        scores, indices = torch.sort(pred_cls, descending=True)
        lines = torch.stack([l[i] for i, l in zip(indices, pred_lines)], dim=0)
        lines = lines * self.scale_fct

        pred = [self.nms(s, l) for s, l in zip(scores, lines)]
        # NOTE(review): the division by 4 presumably maps ground truth onto the
        # scale_fct coordinate range -- confirm against the dataset resolution.
        gt_np = [t.detach().cpu().numpy() / 4 for t in targets['lines']]
        if not self.debug:
            return pred, gt_np

        # Labels are only needed by the debug path, so gather them here.
        gt_cls = torch.stack([c[i] for i, c in zip(indices.cpu(), targets['label_cls'])], dim=0)
        dpred = [
            self.nms_debug(s.cpu().numpy(), l.cpu().numpy(), t, i.cpu().numpy())
            for s, l, t, i in zip(scores, lines, gt_np, gt_cls)
        ]
        return (dpred, pred), gt_np