# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use


import os, pdb
from PIL import Image
import numpy as np
import torch

from tools import common
from tools.dataloader import norm_RGB
from nets.patchnet import *
from nets.patchnet_equivariant import *


def _strip_module_prefix(key):
    """Remove leading 'module.' prefixes from a state-dict key (other occurrences are kept)."""
    while key.startswith('module.'):
        key = key[len('module.'):]
    return key


def load_network(model_fn):
    """Build a network from a checkpoint file and load its weights.

    The checkpoint must contain a ``'net'`` entry, a Python expression string
    that is ``eval``-ed in this module's namespace to construct the network
    (presumably a constructor from the ``nets.patchnet*`` star imports), and a
    ``'state_dict'`` entry holding the weights.

    Args:
        model_fn (str): path to the checkpoint file.

    Returns:
        torch.nn.Module: the network with its weights loaded, switched to eval mode.
    """
    checkpoint = torch.load(model_fn, map_location="cpu")
    print("\n>> Creating net = " + checkpoint['net'])
    # NOTE(security): torch.load (pickle) and eval both run code taken from the
    # checkpoint file -- only load checkpoints from trusted sources.
    net = eval(checkpoint['net'])
    nb_of_weights = common.model_size(net)
    print(f" ( Model size: {nb_of_weights/1000:.0f}K parameters )")

    # initialization
    # Keys saved from a wrapped model (presumably torch.nn.DataParallel) start
    # with 'module.'. Strip only that leading prefix: str.replace would also
    # mangle keys that merely contain 'module.' somewhere in the middle.
    weights = checkpoint['state_dict']
    net.load_state_dict({_strip_module_prefix(k): v for k, v in weights.items()})
    return net.eval()


class NonMaxSuppression (torch.nn.Module):
    """Keypoint detector based on non-maximum suppression.

    A pixel is kept when it is a 3x3 local maximum of the repeatability map
    and both its repeatability and its reliability reach their thresholds.

    Args:
        rel_thr (float): minimum reliability of a kept pixel (inclusive).
        rep_thr (float): minimum repeatability of a kept pixel (inclusive).
    """

    def __init__(self, rel_thr=0.7, rep_thr=0.7):
        super().__init__()
        # stride 1 + padding 1 keeps the map size, so every pixel can be
        # compared against the maximum of its own 3x3 neighbourhood
        self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.rel_thr = rel_thr
        self.rep_thr = rep_thr

    def forward(self, reliability, repeatability, **kw):
        """Return the coordinates of the detected keypoints.

        Args:
            reliability: one-element sequence holding the reliability map.
            repeatability: one-element sequence holding the repeatability map.
                Both maps are indexed as 4-D tensors (presumably
                batch, channel, y, x).
            **kw: ignored, so a whole network-output dict can be splatted in.

        Returns:
            torch.Tensor: shape (2, N); row 0 holds y indices, row 1 x indices.
        """
        assert len(reliability) == len(repeatability) == 1
        (rel_map,), (rep_map,) = reliability, repeatability

        is_peak = rep_map == self.max_filter(rep_map)
        strong = (rep_map >= self.rep_thr) & (rel_map >= self.rel_thr)
        keep = is_peak & strong

        # nonzero() yields one row of 4 indices per kept pixel; after the
        # transpose, rows 2 and 3 are the spatial (y, x) coordinates
        coords = keep.nonzero().t()
        return coords[2:4]


def extract_multiscale( net, img, detector, scale_f=2**0.25, 
                        min_scale=0.0, max_scale=1, 
                        min_size=256, max_size=1024, 
                        verbose=False):
    """Detect and describe keypoints over a multi-scale image pyramid.

    The image is run through ``net`` at scales s = 1, 1/scale_f, 1/scale_f**2, ...
    A scale is processed only when it lies in
    ``[max(min_scale, min_size/max(H,W)), min(max_scale, max_size/max(H,W))]``
    (with a 0.001 tolerance on both ends). Iteration stops once s drops below
    the lower bound. Between scales, the current image is bilinearly resized to
    ``(round(H*s), round(W*s))``, so each level is resampled from the previous one.

    Args:
        net: callable invoked as ``net(imgs=[img])``. It must return a dict
            with 'descriptors', 'reliability' and 'repeatability' entries,
            each a sequence whose first element is indexed as a 4-D tensor.
        img (torch.Tensor): batch holding a single RGB image, shape (1, 3, H, W).
        detector: callable invoked with the network output as keyword
            arguments. It must return a pair ``(y, x)`` of index tensors.
        scale_f (float): ratio between successive pyramid scales.
        min_scale, max_scale (float): bounds on the scale factor (max_scale <= 1).
        min_size, max_size (int): bounds on the longest image side in pixels;
            these are converted into additional scale bounds.
        verbose (bool): print the image size used at each processed scale.

    Returns:
        tuple: ``(XYS, D, scores)``, where
            XYS (N, 3): x and y rescaled to original-image pixel coordinates,
                plus 32/s for the scale at which the point was found;
            D (N, C): descriptor vectors at each keypoint;
            scores (N,): reliability * repeatability at each keypoint.

    Raises:
        AssertionError: if ``img`` is not a single RGB image or ``max_scale > 1``.
        RuntimeError: presumably raised by ``torch.cat`` when no scale lies
            inside the allowed range (nothing to concatenate).
    """
    # explicit import: do not rely on the nets.* star imports to provide F
    import torch.nn.functional as F

    old_bm = torch.backends.cudnn.benchmark
    # presumably disabled because the input size changes at every pyramid
    # level, and benchmark mode re-tunes its kernels for each new shape
    torch.backends.cudnn.benchmark = False
    try:
        # extract keypoints at multiple scales
        B, three, H, W = img.shape
        assert B == 1 and three == 3, "should be a batch with a single RGB image"
        assert max_scale <= 1

        # scale bounds do not change inside the loop
        longest = max(H, W)
        lowest_scale = max(min_scale, min_size / longest)
        highest_scale = min(max_scale, max_size / longest)

        s = 1.0 # current scale factor
        X,Y,S,C,Q,D = [],[],[],[],[],[]
        while s + 0.001 >= lowest_scale:
            if s - 0.001 <= highest_scale:
                nh, nw = img.shape[2:]
                if verbose: print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}")
                # run the network (inference only)
                with torch.no_grad():
                    res = net(imgs=[img])

                # get descriptors, reliability and repeatability maps
                descriptors = res['descriptors'][0]
                reliability = res['reliability'][0]
                repeatability = res['repeatability'][0]

                # non-maxima suppression, then gather scores and descriptors
                y,x = detector(**res)
                c = reliability[0,0,y,x]
                q = repeatability[0,0,y,x]
                d = descriptors[0,:,y,x].t()
                n = d.shape[0]

                # accumulate multiple scales (coordinates mapped back to the
                # original image size)
                X.append(x.float() * W/nw)
                Y.append(y.float() * H/nh)
                S.append((32/s) * torch.ones(n, dtype=torch.float32, device=d.device))
                C.append(c)
                Q.append(q)
                D.append(d)
            s /= scale_f

            # down-scale the image for next iteration
            nh, nw = round(H*s), round(W*s)
            img = F.interpolate(img, (nh,nw), mode='bilinear', align_corners=False)
    finally:
        # restore the caller's setting even if an assert or the network raised
        torch.backends.cudnn.benchmark = old_bm

    Y = torch.cat(Y)
    X = torch.cat(X)
    S = torch.cat(S) # scale
    scores = torch.cat(C) * torch.cat(Q) # scores = reliability * repeatability
    XYS = torch.stack([X,Y,S], dim=-1)
    D = torch.cat(D)
    return XYS, D, scores


def extract_keypoints(args, return_output=False):
    """Extract keypoints for the images listed in ``args.images``, then save or return them.

    Entries ending in '.txt' are read as lists of image paths (one per line),
    which are inserted at the front of the queue. Each image is processed with
    ``extract_multiscale``, and the ``args.top_k`` highest-scoring keypoints are
    kept (all of them when ``args.top_k`` is 0).

    Args:
        args: namespace with the attributes defined by this script's argument
            parser (gpu, model, images, tag, top_k, scale_f, min_scale,
            max_scale, min_size, max_size, reliability_thr, repeatability_thr).
            ``args.images`` is consumed (mutated) in place.
        return_output (bool): if True, return the result for the FIRST image
            instead of saving anything.

    Returns:
        dict | None: when ``return_output`` is set, a dict with keys
        'keypoints', 'descriptors', 'scores' and 'imsize' (W, H) for the first
        image. Otherwise None, and each result is written with ``np.savez`` to
        '<image path>.<args.tag>'.
    """
    iscuda = common.torch_set_gpu(args.gpu)

    # load the network...
    net = load_network(args.model)
    if iscuda: net = net.cuda()

    # create the non-maxima detector
    detector = NonMaxSuppression(
        rel_thr = args.reliability_thr, 
        rep_thr = args.repeatability_thr)

    while args.images:
        img_path = args.images.pop(0)

        if img_path.endswith('.txt'):
            # a .txt entry lists image paths; queue them ahead of the rest
            with open(img_path) as f:
                args.images = f.read().splitlines() + args.images
            continue

        print(f"\nExtracting features for {img_path}")
        img = Image.open(img_path).convert('RGB')
        W, H = img.size
        img = norm_RGB(img)[None] 
        if iscuda: img = img.cuda()

        # extract keypoints/descriptors for a single image
        xys, desc, scores = extract_multiscale(net, img, detector,
            scale_f   = args.scale_f, 
            min_scale = args.min_scale, 
            max_scale = args.max_scale,
            min_size  = args.min_size, 
            max_size  = args.max_size, 
            verbose = True)

        xys = xys.cpu().numpy()
        desc = desc.cpu().numpy()
        scores = scores.cpu().numpy()
        # ascending argsort: the last top_k indices are the best ones;
        # top_k == 0 gives `-0 or None` -> None, i.e. keep everything
        idxs = scores.argsort()[-args.top_k or None:]

        if return_output:
            outdict = {
                "keypoints": xys[idxs],
                "descriptors": desc[idxs],
                "scores": scores[idxs],
                "imsize": (W,H),
            }
            return outdict
        else:
            outpath = img_path + '.' + args.tag
            print(f"Saving {len(idxs)} keypoints to {outpath}")
            # pass an open file object: given a str path, np.savez would
            # append '.npz' to the output name
            with open(outpath, 'wb') as f:
                np.savez(f, 
                    imsize = (W,H),
                    keypoints = xys[idxs], 
                    descriptors = desc[idxs], 
                    scores = scores[idxs])



def extract_keypoints_modified(
        images: list,
        model: str,
        top_k=5000,
        scale_f = 2**0.25,
        min_scale=0.0,
        max_scale=1,
        min_size=256,
        max_size=1024,
        reliability_thr=0.7,
        repeatability_thr=0.7,
        gpu=None,
        verbose=False,
    ):
    """Extract keypoints and descriptors for each image of a list.

    Args:
        images (list): PIL.Image.Image objects, or paths (str / os.PathLike)
            that are opened and converted to RGB.
        model (str | torch.nn.Module): path to a checkpoint file (loaded with
            ``load_network``) or an already-built network, used as-is.
        top_k (int | None): number of highest-scoring keypoints kept per
            image; 0 or None keeps all of them. Defaults to 5000.
        scale_f, min_scale, max_scale, min_size, max_size: multi-scale
            pyramid parameters forwarded to ``extract_multiscale``.
        reliability_thr (float, optional): minimum reliability of a detected
            keypoint. Defaults to 0.7.
        repeatability_thr (float, optional): minimum repeatability of a
            detected keypoint. Defaults to 0.7.
        gpu (list, optional): GPU ids passed to ``common.torch_set_gpu``.
            Defaults to [0]; replaced by -1 (CPU) when CUDA is unavailable.
        verbose (bool, optional): whether to print outputs. Defaults to False.

    Returns:
        list: one dict per image, in input order, with keys "keypoints"
        (N x 3 array of x, y, scale), "descriptors", "scores" and
        "imsize" (W, H).

    Raises:
        AssertionError: if ``model`` is a path that does not exist.
    """
    # None instead of a mutable [0] default; same effective default
    if gpu is None:
        gpu = [0]

    # if no GPU is available, use CPU
    if not torch.cuda.is_available():
        gpu = -1

    iscuda = common.torch_set_gpu(gpu, verbose=verbose)

    # load the network...
    if isinstance(model, str):
        assert os.path.exists(model), (
            f"model does not exist at {model}. "
            "Note that you have passed model as a path to the model file."
        )
        net = load_network(model)
    else:
        # assume it is a network.
        # NOTE(review): it is used as-is and not switched to eval() mode here.
        net = model

    if iscuda: net = net.cuda()

    # create the non-maxima detector
    detector = NonMaxSuppression(
        rel_thr = reliability_thr, 
        rep_thr = repeatability_thr,
    )

    results = []
    for image in images:
        if isinstance(image, (str, os.PathLike)):
            with Image.open(image) as opened:
                image = opened.convert('RGB')

        W, H = image.size
        batch = norm_RGB(image)[None]
        if iscuda: batch = batch.cuda()

        # extract keypoints/descriptors for a single image
        xys, desc, scores = extract_multiscale(net, batch, detector,
            scale_f   = scale_f, 
            min_scale = min_scale, 
            max_scale = max_scale,
            min_size  = min_size, 
            max_size  = max_size, 
            verbose = verbose)

        xys = xys.cpu().numpy()
        desc = desc.cpu().numpy()
        scores = scores.cpu().numpy()

        # ascending order: the last top_k indices are the best-scoring ones
        order = scores.argsort()
        idxs = order[-top_k:] if top_k else order

        results.append({
            "keypoints": xys[idxs],
            "descriptors": desc[idxs],
            "scores": scores[idxs],
            "imsize": (W,H),
        })

    return results



if __name__ == '__main__':
    # Command-line entry point: parse options and extract/save keypoints.
    import argparse

    cli = argparse.ArgumentParser("Extract keypoints for a given image")
    option = cli.add_argument

    # network and inputs/outputs
    option("--model", type=str, required=True, help='model path')
    option("--images", type=str, required=True, nargs='+', help='images / list')
    option("--tag", type=str, default='r2d2', help='output file tag')
    option("--top-k", type=int, default=5000, help='number of keypoints')

    # multi-scale pyramid
    option("--scale-f", type=float, default=2**0.25)
    option("--min-size", type=int, default=256)
    option("--max-size", type=int, default=1024)
    option("--min-scale", type=float, default=0)
    option("--max-scale", type=float, default=1)

    # detection thresholds
    option("--reliability-thr", type=float, default=0.7)
    option("--repeatability-thr", type=float, default=0.7)

    # device selection
    option("--gpu", type=int, nargs='+', default=[0], help='use -1 for CPU')

    extract_keypoints(cli.parse_args())

