import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

def postprocess_hyper(hyper):
    """Guess missing "valid"/"test" dataset paths in ``hyper`` (modified in place).

    For each of ``hyper["valid"]`` and ``hyper["test"]`` that is the empty
    string, the value first defaults to ``hyper["train"]``. If the training
    path contains the substring "train", every occurrence of it is replaced by
    the split name ("valid" / "test"). That candidate is used instead when it
    exists on disk. Each guess is printed. Returns None.
    """
    import os.path

    pending = (
        ("valid", "guessing the validation dataset location..."),
        ("test", "guessing the test dataset location..."),
    )
    for split, banner in pending:
        if hyper[split] != "":
            continue
        print(banner)
        train_path = hyper["train"]
        hyper[split] = train_path
        if "train" in train_path:
            candidate = train_path.replace("train", split)
            if os.path.exists(candidate):
                hyper[split] = candidate
        print("->", hyper[split])


def anneal_function(name, initial_value, final_value, start_epoch, current_epoch, end_epoch):
    """Return the value of an annealing schedule at ``current_epoch``.

    The schedule moves from ``initial_value`` (at or before ``start_epoch``)
    to ``final_value`` (at or after ``end_epoch``). Progress is clamped to
    [0, 1], so every schedule holds ``final_value`` once ``end_epoch`` is
    reached. A zero-length schedule (``end_epoch <= start_epoch``) switches
    to its end state at ``start_epoch``.

    Schedules (``name``):
      'exponential'        geometric interpolation (values must be > 0)
      'exponential_steps'  like 'exponential', but progress is quantized to
                           {0, 0.2, 0.4, 0.6, 0.8, 1.0}
      'cosine'             half-cosine from initial to final
      'linear'             linear interpolation
      'step'               initial until start_epoch, final afterwards
      'static'             always final_value

    Raises:
        AssertionError: for an unknown ``name``. It is raised explicitly,
        so it also fires under ``python -O``.
    """
    if current_epoch < start_epoch:
        progress = 0.0
    elif end_epoch <= start_epoch:
        # Zero-length (or inverted) window: avoid dividing by zero.
        progress = 1.0
    else:
        progress = min((current_epoch - start_epoch) / (end_epoch - start_epoch), 1.0)

    if name == 'exponential':
        # let t0 := start_epoch , t1 := end_epoch, t := current_epoch
        # let f(t) := Ae^(R(t-t0)/(t1-t0))
        # initial_value == f(t0) == Ae^(R(t0-t0)) == A
        # final_value   == f(t1) == Ae^(R(t1-t0)/(t1-t0)) = Ae^R
        # therefore
        # A == initial_value
        # R == log(f(t1) / A) == log(f(t1)) - log(f(t0))
        R = math.log(final_value) - math.log(initial_value)
        return initial_value * math.exp(R * progress)
    if name == 'exponential_steps':
        # Quantize to multiples of 1/5. floor(progress * 5) avoids float
        # floor-division artifacts such as 1.0 // 0.2 == 4.0 and 0.6 // 0.2 == 2.0.
        progress = math.floor(progress * 5) / 5
        R = math.log(final_value) - math.log(initial_value)
        return initial_value * math.exp(R * progress)
    if name == 'cosine':
        # cos(0) == 1 -> initial_value ; cos(pi) == -1 -> final_value
        middle_value = (initial_value + final_value) / 2
        radius = (initial_value - final_value) / 2
        return middle_value + radius * math.cos(math.pi * progress)
    if name == 'linear':
        return initial_value + (final_value - initial_value) * progress
    if name == 'step':
        return final_value if progress > 0.0 else initial_value
    if name == 'static':
        return final_value
    raise AssertionError(f"unknown anneal function: {name!r}")



class Variational(nn.Module):
    """Base module for stochastic layers that accumulate a regularization loss.

    Subclasses must define ``loss(*args, **kwargs)``; this base class does not.
    Each ``forward`` call evaluates ``self.loss`` with the same arguments and
    adds the result into ``self.losses``. The attribute is created on the
    first call. ``forward`` itself returns None. Subclasses call
    ``super().forward(...)`` for this side effect.

    NOTE: for tensor losses, ``+=`` updates the first accumulated tensor in place.
    """

    def __init__(self, beta=1.0):
        super().__init__()
        self.beta = beta  # scale factor subclasses apply to their loss term

    def forward(self, *args, **kwargs):
        if not hasattr(self, "losses"):
            self.losses = self.loss(*args, **kwargs)
            return
        self.losses += self.loss(*args, **kwargs)

    def clear(self):
        """Drop the accumulated loss; a no-op when nothing has been accumulated."""
        if not hasattr(self, "losses"):
            return
        del self.losses

class BinConcrete(Variational):
    """Binary Concrete (relaxed Bernoulli) layer.

    ``forward`` returns ``sigmoid((logits + L) / temperature)``, where ``L`` is
    logistic noise ``log(u) - log(1 - u)`` with ``u ~ U[0, 1)``. When
    ``noise=False``, ``L`` is 0. Each call also accumulates ``self.loss(...)``
    through ``Variational.forward``.
    """

    def forward(self, logits, temperature, noise=True):
        super().forward(logits, temperature, noise)
        y = logits
        if noise:
            eps = 1e-20  # keeps log() away from log(0)
            u = torch.rand_like(logits)
            y = y + (torch.log(u + eps) - torch.log(1 - u + eps))
        return torch.sigmoid(y / temperature)

    def loss(self, logits, temperature, noise=True):
        """Negative Bernoulli entropy ``sum(q log q + (1-q) log(1-q))``, scaled by ``beta``.

        Here ``q = sigmoid(logits)``, i.e. the first-class probability when the
        layer is viewed as a 2-class Gumbel-Softmax. ``temperature`` and
        ``noise`` are unused. They are accepted because ``Variational.forward``
        passes the same arguments that ``forward`` received.
        """
        q0 = torch.sigmoid(logits)
        q1 = 1 - q0  # equals sigmoid(-logits)
        # logsigmoid instead of log(sigmoid(.)) for numerical stability
        neg_entropy = q0 * F.logsigmoid(logits) + q1 * F.logsigmoid(-logits)
        return neg_entropy.sum() * self.beta


class STBinConcrete(BinConcrete):
    """Straight-through Binary Concrete.

    The forward value is the relaxed sample rounded to 0/1. Gradients are
    those of the relaxed (soft) sample, because the rounding correction is
    detached.
    """

    def forward(self, logits, temperature, noise=True):
        soft = super().forward(logits, temperature, noise)
        hard = torch.round(soft)
        return soft + (hard - soft).detach()


class GumbelSoftmax(Variational):
    """Gumbel-Softmax (Concrete) relaxation of a categorical variable over the last dim.

    ``forward`` returns ``softmax((logits + g) / temperature, dim=-1)``, where
    ``g`` is Gumbel noise ``-log(-log(u))`` with ``u ~ U[0, 1)``. When
    ``noise=False``, ``g`` is 0. Each call also accumulates ``self.loss(...)``
    through ``Variational.forward``.
    """

    def forward(self, logits, temperature, noise=True):
        super().forward(logits, temperature, noise)
        # reference: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
        if noise:
            eps = 1e-20  # keeps both log() calls away from log(0)
            U = torch.rand_like(logits)
            gumbel_sample = -torch.log(-torch.log(U + eps) + eps)
            y = logits + gumbel_sample
        else:
            y = logits
        return torch.softmax(y / temperature, dim=-1)

    def loss(self, logits, temperature, noise=True):
        """Negative categorical entropy ``sum(q * log q)``, scaled by ``beta``.

        ``q = softmax(logits)`` over the last dim, the same axis ``forward``
        normalizes over. The result is summed over all elements to a scalar,
        so the accumulated loss can be back-propagated directly.
        ``temperature`` and ``noise`` are unused. They are accepted because
        ``Variational.forward`` passes the same arguments that ``forward``
        received.
        """
        q = torch.softmax(logits, dim=-1)
        # log_softmax is the numerically stable log(softmax(.))
        log_q = torch.log_softmax(logits, dim=-1)
        return torch.sum(q * log_q) * self.beta


class STGumbelSoftmax(GumbelSoftmax):
    """Straight-through Gumbel-Softmax.

    The forward value is a one-hot vector marking the argmax of the relaxed
    sample along the last dim. Gradients are those of the relaxed (soft)
    sample, because the one-hot correction is detached.
    """

    def forward(self, logits, temperature, noise=True):
        soft = super().forward(logits, temperature, noise)
        _, winner = soft.max(dim=-1, keepdim=True)
        hard = torch.zeros_like(soft).scatter_(-1, winner, 1)  # 1 at the argmax
        return soft + (hard - soft).detach()

### Loss functions

# function used by the original negative sampling
def logsigmoiddot(x, y, reduction="mean"):
    """Negative-sampling loss ``-log(sigmoid(<x_i, y_i>))`` for each row pair.

    The dot product sums over dim 1, so ``x`` and ``y`` are presumably
    (batch, dim) — TODO confirm at call sites.

    Args:
        reduction: "mean" or "sum" over dim 0, or "none" for per-row values.

    Raises:
        ValueError: for any other ``reduction``. Previously this silently
        returned None.
    """
    result = -F.logsigmoid((x * y).sum(1))
    if reduction == "none":
        return result
    if reduction == "sum":
        return result.sum(0)
    if reduction == "mean":
        return result.mean(0)
    raise ValueError(f"unsupported reduction: {reduction!r}")

def cosine_distance(x, y, reduction="mean"):
    """Cosine distance ``1 - cos(x_i, y_i)`` for each row pair.

    Rows are L2-normalized along dim 1 with ``F.normalize``, so the inputs are
    presumably (batch, dim) — TODO confirm. For unit vectors,
    ``||nx - ny||^2 = 2 - 2 cos``, so half the per-row squared distance is
    ``1 - cos``.

    Args:
        reduction: "mean" or "sum" over rows, or "none" for per-row values.
            "mean" averages over rows, not over every element, so it equals
            mean(1 - cos).

    Raises:
        ValueError: for any other ``reduction``.
    """
    nx = F.normalize(x)
    ny = F.normalize(y)
    per_row = (nx - ny).pow(2).sum(1) / 2  # = 1 - cos(x_i, y_i)
    if reduction == "none":
        return per_row
    if reduction == "sum":
        return per_row.sum(0)
    if reduction == "mean":
        return per_row.mean(0)
    raise ValueError(f"unsupported reduction: {reduction!r}")

def js_divergence(x_prob, y_prob, reduction="mean"):
    """Jensen-Shannon divergence between two batches of probability vectors.

    ``JS(x, y) = 0.5 * KL(x || m) + 0.5 * KL(y || m)`` with ``m = (x + y) / 2``.

    ``kl_divergence(a, b, ...)`` computes ``sum(b * (log b - log a))``,
    i.e. ``KL(b || a)``. The mixture ``m`` must therefore be its FIRST
    argument here. ``reduction`` is forwarded to ``kl_divergence`` unchanged.
    """
    # Range checks on the inputs (values within [0, 1]) are intentionally
    # omitted because they are too slow.
    m_prob = 0.5 * (x_prob + y_prob)
    return 0.5 * (kl_divergence(m_prob, x_prob, reduction)
                  + kl_divergence(m_prob, y_prob, reduction))

def kl_divergence(x_prob, y_prob, reduction="mean"):
    """KL divergence ``KL(y || x) = sum(y * (log y - log x))`` summed over dim 1.

    Note the argument order: the second argument is the reference
    distribution. This matches ``F.kl_div(input, target)``, whose target also
    comes second. A tiny ``eps`` is added inside the logs so that zero
    probabilities do not produce ``log(0)``.

    Args:
        reduction: "mean" or "sum" over dim 0, or "none" for per-row values.

    Raises:
        ValueError: for any other ``reduction``. Previously this silently
        returned None.
    """
    # Range checks on the inputs (values within [0, 1]) are intentionally
    # omitted because they are too slow.
    eps = 1e-20
    kld = (y_prob * ((y_prob + eps).log() - (x_prob + eps).log())).sum(1)
    if reduction == "none":
        return kld
    if reduction == "mean":
        return kld.mean(0)
    if reduction == "sum":
        return kld.sum(0)
    raise ValueError(f"unsupported reduction: {reduction!r}")


# Registry of loss functions keyed by name. The order of entries matches the
# order in which they were historically registered.
losses = {
    logsigmoiddot.__name__: logsigmoiddot,
    cosine_distance.__name__: cosine_distance,
    "l1": F.l1_loss,
    "l2": F.mse_loss,
    "bce": F.binary_cross_entropy,
    "bce_logits": F.binary_cross_entropy_with_logits,
    "smooth_l1": F.smooth_l1_loss,
    "kl": kl_divergence,
    "js": js_divergence,
}
# note: there is F.cosine_embedding_loss, which is not used because it requires an additional input and is incompatible

def compute_model_memory(model):
    """Return the memory taken by ``model``'s parameters, in MB (1 MB = 1024 * 1024 bytes).

    Only ``model.parameters()`` are counted; buffers are not included.
    """
    n_bytes = sum(p.element_size() * p.nelement() for p in model.parameters())
    return n_bytes / 1024.0 / 1024.0

def training_timer(file_name, func, epochs, **kwargs):
    """Time one call of ``func()`` and save the timing to ``file_name`` as JSON.

    :param file_name: filename where timing information is saved (overwritten)
    :param func: zero-argument callable to be timed; its return value is discarded
    :param epochs: if truthy, ``epoch_runtime = total_runtime / epochs`` is also recorded
    :param kwargs: extra key/value pairs copied verbatim into the saved JSON object
    :return: None
    """
    import json
    import time

    stat = dict(kwargs)
    began = time.perf_counter()
    func()
    elapsed = time.perf_counter() - began
    stat["total_runtime"] = elapsed
    if epochs:
        stat["epoch_runtime"] = elapsed / epochs
    with open(file_name, 'w') as f:
        json.dump(stat, f)


def register(dic, *args):
    """Set a value in a nested dict, creating intermediate dicts as needed.

    ``register(d, k1, k2, ..., kn, value)`` performs ``d[k1][k2]...[kn] = value``.
    Any missing intermediate level is created as ``{}``. At least one key and
    the value are required. Returns ``dic`` itself, mutated in place.
    """
    assert len(args)>=2
    *path, leaf_key, value = args
    node = dic
    for key in path:
        node = node.setdefault(key, {})
    node[leaf_key] = value
    return dic
    

def print_all_gpu_variables():
    """Print the type and size of every live tensor known to the garbage collector.

    Despite the name, no device filter is applied, so CPU tensors are listed
    too. Objects whose ``.data`` attribute is a tensor are also included.
    Inspection is best-effort: objects that raise while being probed are skipped.
    """
    import gc
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
                print(type(obj), obj.size())
        except Exception:
            # Probing arbitrary gc-tracked objects can raise (custom __getattr__,
            # no .size(), ...). Skip those, but unlike a bare `except:` let
            # KeyboardInterrupt / SystemExit stop this potentially long scan.
            pass

def opt_cuda(x):
    """Return ``x.cuda()`` when CUDA is available, otherwise ``x`` unchanged."""
    return x.cuda() if torch.cuda.is_available() else x


def _append_save(value, path):
    """Merge the key/value pairs of ``value`` into the JSON object stored at ``path``.

    A missing, unreadable or malformed file is treated as an empty object, so
    the first call creates the file. Keys in ``value`` overwrite existing ones.

    The merged object is first written to a sibling temporary file and then
    moved over ``path`` with ``os.replace``. A crash mid-write therefore
    cannot leave ``path`` truncated. A truncated file would otherwise be read
    back as empty by the next call, silently discarding every saved entry.

    Not safe for concurrent writers by itself; callers presumably serialize
    through a lock.
    """
    import os
    import json
    try:
        with open(path, 'r') as f:
            data = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file. ValueError includes json.JSONDecodeError.
        data = {}
    data.update(value)

    tmp_path = "{}.{}.tmp".format(path, os.getpid())
    try:
        with open(tmp_path, "w") as f:
            json.dump(data, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Do not leave a stray temp file behind; the original error is re-raised.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise

def call_with_lock(path, fn):
    """Run ``fn()`` while holding the lock file ``path + ".lock"`` and return its result.

    The lock is taken by exclusively creating the lock file (``open(..., "x")``).
    While another holder's lock file exists, this polls once per second.
    The lock file is removed after ``fn`` returns or raises, and exceptions
    from ``fn`` propagate unchanged.

    NOTE: if the holder is killed before its ``finally`` runs, the lock file
    stays behind and later callers wait until it is removed manually.
    """
    import os
    import time

    lock = path + ".lock"
    while True:
        try:
            handle = open(lock, "x")
        except FileExistsError:
            print("waiting for lock...")
            time.sleep(1)
        else:
            break
    # Only the acquisition above is retried. An exception raised by fn, even a
    # FileExistsError, must propagate instead of being mistaken for contention.
    handle.close()
    try:
        return fn()
    finally:
        try:
            os.remove(lock)
        except FileNotFoundError:
            pass  # already removed by hand; nothing left to release

def append_save(value,path):
    """Merge ``value`` into the JSON object at ``path`` while holding ``path + ".lock"``."""
    def _merge():
        return _append_save(value, path)
    return call_with_lock(path, _merge)
