
import torch
import random 
import numpy as np
import torch
import argparse
from dataclasses import dataclass, fields, asdict

import uuid

from pathlib import Path


def get_closest(preds, query, masks):
    """
    Running "closest prediction so far" along dimension 1 of ``preds``.

    For every index i along dim 1, the output holds the entry among
    preds[:, 0..i] with the smallest mean squared difference to ``query``,
    together with the matching entry of ``masks``. A later entry only
    replaces the current best when it is strictly closer.

    NOTE(review): the indexing looks like it expects preds of shape
    (batch, steps, dim) and query of shape (batch, dim) -- confirm with callers.

    Returns (closest_points, closest_masks), shaped like preds and masks.
    """
    def sq_dist(points):
        # Mean squared difference to the query over the last dimension.
        return (points - query).square().mean(-1)

    best_points = torch.zeros_like(preds)
    best_masks = torch.zeros_like(masks)
    best_points[:, 0], best_masks[:, 0] = preds[:, 0], masks[:, 0]

    for step in range(1, preds.size(1)):
        prev = step - 1
        # Start from the previous best ...
        best_points[:, step] = best_points[:, prev]
        best_masks[:, step] = best_masks[:, prev]
        # ... and overwrite the batch rows where the new candidate is strictly closer.
        improved = sq_dist(best_points[:, prev]) > sq_dist(preds[:, step])
        best_points[improved, step] = preds[improved, step]
        best_masks[improved, step] = masks[improved, step]

    return best_points, best_masks

def deterministic_NeuralSort(s, tau, hard=False):
    """
    Relaxed (NeuralSort-style) permutation matrix for sorting ``s``.

    s: input elements to be sorted. Shape: batch_size x n x 1
    tau: temperature for relaxation. Scalar.
    hard: if True, every row is replaced by the one-hot vector of its argmax.
          The one-hot matrix is built from indices only, so it carries no
          gradient back to ``s``.

    Returns a batch_size x n x n tensor in the dtype and on the device of ``s``.
    Each row is a softmax over the n input positions; row i puts most of its
    mass on the position of the i-th largest element.
    """
    n = s.size(1)
    # |s_i - s_j| for every pair of elements: batch x n x n.
    pairwise_abs = torch.abs(s - s.permute(0, 2, 1))
    # sum_k |s_i - s_k|; kept as batch x n x 1 so it broadcasts across columns
    # (equivalent to multiplying by an all-ones n x n matrix, without the matmul).
    row_sums = pairwise_abs.sum(dim=-1, keepdim=True)
    # (n + 1 - 2i) for i = 1..n, created directly in s's dtype/device so that
    # non-float32 inputs work and no host-to-device copy is needed.
    scaling = n + 1 - 2 * torch.arange(1, n + 1, dtype=s.dtype, device=s.device)
    # Entry [b, i, j] is scaling_i * s_j - sum_k |s_j - s_k|.
    scores = (s * scaling - row_sums).permute(0, 2, 1)
    P_hat = torch.softmax(scores / tau, dim=-1)
    if hard:
        index = P_hat.max(-1, keepdim=True)[1]
        P_hat = torch.zeros_like(P_hat, memory_format=torch.legacy_contiguous_format).scatter_(-1, index, 1.0)

    return P_hat

def sample_gumbel(shape, eps=1e-20):
    """
    Draw Gumbel noise of the given shape as -log(-log(U)), U ~ Uniform[0, 1).

    ``eps`` is added inside both logarithms to keep their arguments away
    from zero. The sample is created with torch's default dtype and device.
    """
    uniform = torch.rand(shape)
    neg_log_uniform = -torch.log(uniform + eps)
    return -torch.log(neg_log_uniform + eps)

def gumbel_softmax_sample(logits, temperature, noise_scale):
    """
    Soft sample: softmax over the last dimension of Gumbel-perturbed logits.

    logits: unnormalised scores, classes on the last dimension.
    temperature: divisor applied to the perturbed logits before the softmax.
    noise_scale: multiplier on the Gumbel noise (0 removes its contribution).
    """
    # The noise is drawn by sample_gumbel and then moved to the logits' device.
    noise = sample_gumbel(logits.size()).to(logits.device)
    perturbed = logits + noise_scale * noise
    return torch.softmax(perturbed / temperature, dim=-1)

def gumbel_softmax(logits, temperature, hard=False, noise_scale=1.):
    """
    Straight-through Gumbel-softmax.

    input: [*, n_class]
    return: [*, n_class]; the soft sample when ``hard`` is False, otherwise a
            one-hot vector (argmax of the soft sample) whose gradient is the
            gradient of the soft sample.
    """
    soft = gumbel_softmax_sample(logits, temperature, noise_scale)
    if hard:
        winners = soft.max(dim=-1, keepdim=True).indices
        one_hot = torch.zeros_like(soft).scatter_(-1, winners, 1)
        # Straight-through trick: the forward value is the one-hot vector,
        # while the backward pass only sees the soft sample.
        return (one_hot - soft).detach() + soft
    return soft

def soft_nn(query, candidates, temp=1.0):
    """
    Soft nearest neighbour of ``query`` among ``candidates``.

    Candidates are weighted by a softmax over their negative mean squared
    distance to the query (divided by ``temp``), and the weighted combination
    is returned together with the weights.

    NOTE(review): the transpose/matmul looks like it expects query of shape
    (batch, dim) and candidates of shape (batch, n, dim) -- confirm with callers.

    Returns (soft_neighbour, soft_scores).
    """
    offsets = candidates - query[:, None]
    neg_dists = -offsets.square().mean(-1)
    soft_scores = torch.softmax(neg_dists / temp, -1)
    # (batch, dim, n) @ (batch, n, 1) -> (batch, dim, 1); drop the trailing axis.
    blended = torch.matmul(candidates.transpose(1, 2), soft_scores.unsqueeze(-1))
    return blended.squeeze(-1), soft_scores
    
def get_permuted_idx(perms, ids):
    """
    For each batch item, find the first row of ``perms`` whose argmax is ``ids``.

    perms.argmax(-1) turns every row into the index it selects; the result is
    the position of the first row that selects the requested id (0 if no row
    does, since the argmax of an all-zero row is 0).
    """
    selected = perms.argmax(-1)
    hits = selected == ids[:, None]
    return hits.float().argmax(-1)

def noisy_smax(logits, temp=1., noise_rate=2.):
    """
    Softmax over the last dimension of logits perturbed by Gaussian noise.

    Standard-normal noise scaled by ``noise_rate`` is added to the logits,
    and the sum is divided by ``temp`` before the softmax.
    """
    noise = torch.randn_like(logits) * noise_rate
    return torch.softmax((logits + noise) / temp, dim=-1)

def set_seed(seed):
    """
    Seed the Python, NumPy and PyTorch (CPU and CUDA) random number
    generators with ``seed`` and put cuDNN into deterministic mode.
    """
    # Random number generators.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # cuDNN: deterministic algorithms on, benchmark-based auto-tuning off.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def compute_acc(logits, labels):
    """
    Fraction of positions where the argmax of ``logits`` over its last
    dimension equals ``labels``.

    ``labels`` must have the shape of ``logits`` without its last dimension.
    Returns a scalar float tensor.
    """
    assert logits.shape[:-1] == labels.shape
    predictions = logits.argmax(dim=-1)
    correct = predictions == labels
    return correct.float().mean()

@dataclass
class DefaultArgs:
    """
    Base run configuration.

    After construction ``out_dir`` no longer holds the raw value that was
    passed in: ``__post_init__`` replaces it with the full output path
    ``./run_data[/<out_dir>]/<run_id>`` as a ``pathlib.Path``.

    NOTE: the default of ``run_id`` is computed once, when the class body is
    executed. Every instance created in the same process without an explicit
    ``run_id`` therefore shares the same id.
    """
    seed: int = 0
    device: str = 'cuda:0'

    # Random UUID string, generated once per process (see the class note).
    run_id: str = str(uuid.uuid4())
    # Sub-directory below ./run_data; '' puts the run directly under ./run_data.
    out_dir: str = 'runs'

    def __post_init__(self):
        """Expand ``out_dir`` into the full output path of this run."""
        path_prefix = Path('./run_data')
        if self.out_dir == '':
            self.out_dir = path_prefix / self.run_id
        else:
            self.out_dir = path_prefix / self.out_dir / self.run_id

    def __str__(self):
        """One ``name: value`` line per field, each terminated by a newline."""
        return ''.join(f'{k}: {v}\n' for k, v in asdict(self).items())

    def generate_name(self, fields=None):
        """
        Join ``name=value`` pairs with '_' into a single string.

        fields: iterable of field names to include, in that order; None
                (the default) includes every field in declaration order.
        Raises KeyError if a requested name is not a field.
        """
        values = asdict(self)
        # Identity check: only None means "all fields"; an empty selection
        # yields an empty name.
        if fields is None:
            selected = values
        else:
            selected = {name: values[name] for name in fields}
        return '_'.join(f'{k}={v}' for k, v in selected.items())

def parse_args(args_class = DefaultArgs):
    """
    Build a command-line parser from the fields of a dataclass, parse the
    process arguments and return an instance of that dataclass.

    Every field becomes a ``--<field name>`` option:
      * fields annotated ``bool`` accept yes/no style values; a bare flag
        means True;
      * fields annotated ``tuple`` accept one or more values, converted with
        the type of the default's first element (``str`` for an empty default)
        and stored as a tuple;
      * any other field is converted with the type of its default value. A
        default of None means the raw string is kept.

    Fields without a static default (``default_factory`` or required fields)
    are parsed as plain strings; when the option is not given, the field is
    left out of the constructor call so the dataclass applies its own default
    (or reports the missing argument).

    args_class: the dataclass type to build the parser from and to instantiate.
    """
    # Function-scope import: MISSING is the dataclasses sentinel for
    # "this field has no static default".
    from dataclasses import MISSING

    def str2bool(v):
        if isinstance(v, bool):
            return v
        lowered = v.lower()
        if lowered in ('yes', 'true', 't', 'y', '1'):
            return True
        if lowered in ('no', 'false', 'f', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('Boolean value expected.')

    parser = argparse.ArgumentParser()
    deferred = set()  # fields whose default is resolved by the dataclass itself
    for field in fields(args_class):
        flag = f'--{field.name}'
        default = field.default
        if default is MISSING:
            deferred.add(field.name)
            default = None
        if field.type == bool:
            parser.add_argument(flag, type=str2bool, nargs='?', const=True, default=default)
        elif field.type == tuple:
            element_type = type(default[0]) if default else str
            parser.add_argument(flag, nargs='+', type=element_type, default=default)
        else:
            # type(None) cannot convert a string, so a None default falls back
            # to argparse's pass-through (plain str) behaviour.
            converter = None if default is None else type(default)
            parser.add_argument(flag, default=default, type=converter)

    parsed = parser.parse_args()
    kwargs = {}
    for key, value in vars(parsed).items():
        if value is None and key in deferred:
            continue
        # nargs='+' yields lists; the dataclass fields are declared as tuples.
        kwargs[key] = tuple(value) if isinstance(value, list) else value
    return args_class(**kwargs)
