import os
import random
import numpy as np
import torch
from tqdm import tqdm
from difflib import SequenceMatcher
from torch import Tensor

# Public API for ``from <module> import *``: only ``Registry`` is exported.
# The helper functions defined below (set_seed, get_device, soft_quantile, ...)
# are not listed here and must be imported explicitly by name.
__all__ = ["Registry"]


class Registry:
    """Name -> object lookup table used to plug in custom modules.

    Objects (classes or functions) are stored under their ``__name__`` and can
    later be retrieved by that string.

    Creating a registry (e.g. one for backbones):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    Registering with the decorator form:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone(nn.Module):
            ...

    or with a direct call:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """

    def __init__(self, name):
        # Human-readable registry name, only used in error messages.
        self._name = name
        # Insertion-ordered mapping from registered name to object.
        self._obj_map = {}

    def _do_register(self, name, obj, force=False):
        """Store ``obj`` under ``name``.

        Raises:
            KeyError: if ``name`` is already taken and ``force`` is False.
        """
        name_taken = name in self._obj_map
        if name_taken and not force:
            template = 'An object named "{}" was already registered in "{}" registry'
            raise KeyError(template.format(name, self._name))
        self._obj_map[name] = obj

    def register(self, obj=None, force=False):
        """Register ``obj`` under its ``__name__``.

        Called with no object, returns a decorator that registers and returns
        the decorated class/function unchanged. Called with an object, it
        registers it directly and returns None.
        """

        def _add(target):
            self._do_register(target.__name__, target, force=force)
            return target

        if obj is not None:
            _add(obj)
            return None
        return _add

    def get(self, name):
        """Return the object registered under ``name``.

        Raises:
            KeyError: if nothing was registered under ``name``.
        """
        registered = self._obj_map
        if name in registered:
            return registered[name]
        template = 'Object name "{}" does not exist in "{}" registry'
        raise KeyError(template.format(name, self._name))

    def registered_names(self):
        """Return the registered names as a list, in registration order."""
        return [*self._obj_map]


def set_seed(seed):
    """Seed Python ``random``, NumPy and PyTorch, and make cuDNN deterministic.

    A ``seed`` of 0 acts as a "do not seed" sentinel: nothing is touched.

    Note: ``PYTHONHASHSEED`` is exported to the environment, but the running
    interpreter's hash seed is fixed at start-up, so this only affects child
    processes.

    Args:
        seed: seed value; 0 disables seeding.

    Returns:
        ``seed``, unchanged.
    """
    if seed == 0:
        return seed

    # Python / environment / NumPy RNGs.
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)

    # PyTorch CPU RNG, then the current and all CUDA device RNGs.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for reproducible kernels.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    return seed


def split_logits_labels(model, dataloader):
    """Run ``model`` over every batch of ``dataloader`` and collect the results.

    Args:
        model: callable mapping a batch of inputs to logits; it is called on
            the inputs exactly as yielded (no device transfer is done here).
        dataloader: iterable yielding ``(images, labels)`` pairs.

    Returns:
        tuple[Tensor, Tensor]: the logits and the labels, each concatenated
        along dim 0. Both are placed on the current CUDA device when CUDA is
        available, otherwise they stay on CPU.

    Raises:
        RuntimeError: from ``torch.cat`` if ``dataloader`` yields no batches.
    """
    # Fall back to CPU instead of crashing with a hard-coded ``.cuda()`` on
    # machines without a GPU; ``torch.device("cuda")`` resolves to the current
    # CUDA device, matching ``.cuda()``.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    logits_list = []
    labels_list = []
    with torch.no_grad():
        for images, labels in tqdm(dataloader):
            logits_list.append(model(images))
            labels_list.append(labels)

    logits = torch.cat(logits_list).to(device)
    labels = torch.cat(labels_list).to(device)
    return logits, labels


def get_device(model=None):
    """Return the device to run on.

    Args:
        model: optional module; when given, the device of its first parameter
            is returned.

    Returns:
        torch.device: the model's device, else the current CUDA device if CUDA
        is available, else CPU.
    """
    if model is not None:
        first_param = next(model.parameters())
        return first_param.device

    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")


def build_score(conformal):
    """Instantiate the conformal score selected by name.

    Args:
        conformal (str): score name. Only ``"aps"`` is currently supported.

    Returns:
        A new ``aps`` instance from the sibling ``.aps`` module (imported
        lazily, presumably to avoid an import cycle — confirm).

    Raises:
        NotImplementedError: if ``conformal`` is not a supported score name.
    """
    if conformal == "aps":
        from .aps import aps
        return aps()
    # Raise rather than return the exception class, so that callers cannot
    # silently receive a bogus "score" object.
    raise NotImplementedError(f"Unsupported conformal score: {conformal!r}")
    

def neural_sort(
    scores: Tensor,
    tau: float = 0.01,
    hard: bool = False,
) -> Tensor:
    """
    Soft sorts scores (descending) along the last dimension.

    Follows the implementation from
    https://github.com/ermongroup/neuralsort/blob/master/pytorch/neuralsort.py

    Grover, Wang et al., Stochastic Optimization of Sorting Networks via Continuous Relaxations

    Args:
        scores (Tensor): scores to sort, shape ``(..., n)``.
        tau (float, optional): smoothness factor (softmax temperature); smaller
            values give a sharper relaxation. Defaults to 0.01.
        hard (bool, optional): if True, return an exact 0/1 permutation matrix
            in the forward pass while back-propagating through the soft one
            (straight-through estimator). Defaults to False.

    Returns:
        Tensor: row-stochastic matrix ``P`` of shape ``(..., n, n)`` such that
        ``P @ scores[..., None]`` is (approximately) ``scores`` sorted in
        descending order.
    """
    n = scores.shape[-1]

    # Pairwise absolute differences |s_i - s_k|, shape (..., n, n).
    A = (scores[..., :, None] - scores[..., None, :]).abs()
    # Row sums B_i = sum_k |s_i - s_k|, shape (..., n, 1). A reduction keeps the
    # input dtype, whereas a matmul with a float32 ones-vector fails for float64.
    B = A.sum(dim=-1, keepdim=True)

    # C[i, j] = s_i * (n - 1 - 2j); the argmax over i of (C - B)[:, j] is the
    # index of the j-th largest score.
    scaling = n - 1 - 2 * torch.arange(n, device=scores.device, dtype=torch.float)
    C = scores[..., :, None] * scaling
    P_scores = (C - B).transpose(-2, -1)
    P_hat = torch.softmax(P_scores / tau, dim=-1)

    if hard:
        # Straight-through: forward value is the one-hot row argmax, gradient
        # is that of the soft relaxation.
        P = torch.nn.functional.one_hot(P_hat.argmax(dim=-1), num_classes=n).to(P_hat.dtype)
        P_hat = (P - P_hat).detach() + P_hat

    return P_hat


def soft_quantile(
    scores: Tensor,
    q: float,
    dim=-1,
    **kwargs
) -> Tensor:
    """
    Differentiable (soft) empirical quantile of ``scores`` along ``dim``.

    The scores are soft-sorted in descending order with ``neural_sort``. Each
    quantile is then read off by linear interpolation at the 1-based ascending
    rank ``q * (n + 1)``. Levels whose rank falls outside ``[1, n]`` are clamped
    to the smallest / largest score.

    Args:
        scores (Tensor): values to take the quantile of.
        q (float or sequence of floats): quantile level(s).
        dim (int, optional): dimension to reduce. Defaults to -1.
        **kwargs: forwarded to ``neural_sort`` (e.g. ``tau``).

    Returns:
        Tensor: for a scalar ``q``, ``scores`` with ``dim`` removed. Otherwise,
        ``dim`` is replaced by a dimension of size ``len(q)``.
    """
    # Swap the requested dim with the final dim. A swap is its own inverse, so
    # the same permutation restores the original order at the end.
    dims = list(range(scores.dim()))
    dims[-1], dims[dim] = dims[dim], dims[-1]
    scores = scores.permute(*dims)

    # Normalize each row along the last dimension to roughly unit scale, so
    # that `tau` behaves the same regardless of the scores' magnitude. The
    # clamp keeps constant (or single-element) rows from dividing by zero.
    centered = scores - scores.mean(dim=-1, keepdim=True)
    spread = centered.pow(2).mean(dim=-1, keepdim=True).sqrt()
    spread = spread.clamp_min(torch.finfo(spread.dtype).eps)
    scores_norm = centered / (3. * spread)

    # Obtain the (soft) permutation matrix and use it to sort the raw scores.
    P_hat = neural_sort(scores_norm, **kwargs)
    sorted_scores = (P_hat @ scores[..., None])[..., 0]

    # Turn quantile levels into (fractional) positions in the descending order.
    n = scores.shape[-1]
    q = torch.as_tensor(q, dtype=torch.float, device=scores.device)
    squeeze = q.dim() == 0
    q = q.reshape(-1)
    k = q.shape[0]

    # Clamp the positions, so that neither index runs past the end nor goes
    # negative (a negative index would silently wrap to the wrong element).
    indices = ((1 - q) * (n + 1) - 1).clamp(0, n - 1)
    indices_low = torch.floor(indices).long()
    indices_high = (indices_low + 1).clamp(max=n - 1)
    indices_frac = indices - indices_low

    # Select the neighbouring sorted values and interpolate between them.
    quantiles = sorted_scores[..., torch.cat([indices_low, indices_high])]
    low, high = quantiles[..., :k], quantiles[..., k:]
    quantiles = low + indices_frac * (high - low)

    # Restore the dimension order.
    if len(dims) > 1:
        quantiles = quantiles.permute(*dims)

    if squeeze:
        quantiles = quantiles.squeeze(dim)

    return quantiles
