import json
import logging
import os
import random
import time
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math

from tqdm import tqdm

# Project imports
import torch.distributed as dist
from torch.utils.data.dataloader import DataLoader
import sys
import params
from dsl.program import Program
import torch.nn.functional as F
import torch.distributions as D

from pytorch_metric_learning.distances import BaseDistance
from pytorch_metric_learning.losses import NTXentLoss

def init_distributed_mode():
    """Initialise ``torch.distributed`` from environment variables.

    Two launch styles are recognised:

    * ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` are all read when both
      ``RANK`` and ``WORLD_SIZE`` are present.
    * ``SLURM_PROCID`` is used as the rank otherwise; the world size is then
      taken from ``WORLD_SIZE`` if set, else from ``SLURM_NTASKS``, and the
      GPU index is ``rank % torch.cuda.device_count()``.

    On success the process is bound to its GPU, the NCCL process group is
    created with the ``env://`` init method, and printing/logging is
    restricted to rank 0 via ``setup_for_distributed``.

    Returns:
        The GPU index selected for this process, or ``None`` when neither
        launch style is detected (a message is printed in that case).

    Raises:
        KeyError: if a required environment variable is missing.
    """
    env = os.environ
    if 'RANK' in env and 'WORLD_SIZE' in env:
        rank = int(env['RANK'])
        world_size = int(env['WORLD_SIZE'])
        gpu = int(env['LOCAL_RANK'])
    elif 'SLURM_PROCID' in env:
        rank = int(env['SLURM_PROCID'])
        # world_size used to be left undefined on this path, which made the
        # init_process_group call below fail with a NameError.
        world_size = int(env['WORLD_SIZE'] if 'WORLD_SIZE' in env else env['SLURM_NTASKS'])
        gpu = rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        return None

    torch.cuda.set_device(gpu)
    dist_backend = 'nccl'
    dist_url = 'env://'
    print('| distributed init (rank {}): {}'.format(
        rank, dist_url), flush=True)
    torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url,
                                         world_size=world_size, rank=rank)
    setup_for_distributed(rank == 0)
    return gpu
    
def setup_for_distributed(is_master):
    """Silence ``print`` and ``logging.info`` on non-master processes.

    Both callables are replaced process-wide by wrappers that forward to the
    previous implementation only when ``is_master`` is true or the call
    passes ``force=True``. The ``force`` keyword is consumed by the wrapper
    and never reaches the wrapped callable.

    Args:
        is_master: whether this process should keep emitting output.
    """
    import builtins

    def _gated(wrapped):
        # Build a wrapper that drops the call unless master or forced.
        def wrapper(*args, **kwargs):
            forced = kwargs.pop('force', False)
            if forced or is_master:
                wrapped(*args, **kwargs)
        return wrapper

    builtins.print = _gated(builtins.print)
    logging.info = _gated(logging.info)

def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialised."""
    # Short-circuit keeps is_initialized() from being called when the
    # distributed package is not available.
    return dist.is_available() and dist.is_initialized()

def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0

def get_world_size():
    """Return the number of distributed workers.

    Falls back to 1 when torch.distributed is unavailable or has not been
    initialised.
    """
    if not torch.distributed.is_available():
        return 1
    if not torch.distributed.is_initialized():
        return 1
    return torch.distributed.get_world_size()

def save_checkpoint(models, optimizers, model_name, optimizer_name, epoch, best_val_err, ckpt_path):
    """Write model/optimizer state plus training progress to ``ckpt_path``.

    The checkpoint dictionary is assembled on every process, but only the
    process whose ``get_rank()`` is 0 actually calls ``torch.save``.

    Args:
        models: sequence of models; objects exposing a ``module`` attribute
            are unwrapped before their ``state_dict()`` is taken.
        optimizers: sequence of optimizers.
        model_name: checkpoint keys, indexed in parallel with ``models``.
        optimizer_name: checkpoint keys, indexed in parallel with ``optimizers``.
        epoch: stored under ``'epoch'``.
        best_val_err: stored under ``'best_val_err'``.
        ckpt_path: destination passed to ``torch.save``.
    """
    state = {}
    for position, net in enumerate(models):
        unwrapped = getattr(net, "module", net)
        state[model_name[position]] = unwrapped.state_dict()
    for position, opt in enumerate(optimizers):
        state[optimizer_name[position]] = opt.state_dict()
    state['epoch'] = epoch
    state['best_val_err'] = best_val_err

    if get_rank() == 0:
        torch.save(state, ckpt_path)


def load_checkpoint(models, optimizers, model_name, optimizer_name, ckpt_path, map_location='cpu'):
    """Restore models and optimizers from a checkpoint file.

    The checkpoint is deserialised with ``map_location`` (CPU by default) so
    that its tensors are not materialised on whichever device they were
    saved from; ``load_state_dict`` then copies the values into the existing
    model/optimizer state.

    Args:
        models: sequence of models to restore in place; objects exposing a
            ``module`` attribute are unwrapped before loading.
        optimizers: sequence of optimizers to restore in place (may be empty).
        model_name: checkpoint keys, indexed in parallel with ``models``.
        optimizer_name: checkpoint keys, indexed in parallel with ``optimizers``.
        ckpt_path: path handed to ``torch.load``.
        map_location: forwarded to ``torch.load``.

    Returns:
        ``(loaded_models, loaded_optimizers, start_epoch, best_val_err)``
        where the lists hold the same objects that were passed in and
        ``start_epoch`` is the stored epoch plus one.

    Raises:
        KeyError: if an expected key is missing from the checkpoint.
    """
    weight_ckpt = torch.load(ckpt_path, map_location=map_location)

    loaded_models = []
    for idx, model in enumerate(models):
        raw_model = model.module if hasattr(model, "module") else model
        raw_model.load_state_dict(weight_ckpt[model_name[idx]])
        # Hand back the object the caller passed in (wrapper included).
        loaded_models.append(model)

    loaded_optimizers = []
    for idx, optimizer in enumerate(optimizers):
        optimizer.load_state_dict(weight_ckpt[optimizer_name[idx]])
        loaded_optimizers.append(optimizer)

    start_epoch = weight_ckpt['epoch'] + 1
    best_val_err = weight_ckpt['best_val_err']
    return loaded_models, loaded_optimizers, start_epoch, best_val_err

def gaussian_KL_loss(mus, logvars, eps=1e-8):
    """KL divergence of the predicted Gaussians from the unit normal.

    Args:
        mus: Tensor of means.
        logvars: Tensor of log-variances, same shape as ``mus``.
        eps: Small constant added to the divisor ``mus.size(0)``.
    Returns:
        The divergence summed over all elements and divided by the size of
        the first dimension of ``mus``.
    """
    elementwise = 1 + logvars - mus.pow(2) - logvars.exp()
    KLD = -0.5 * torch.sum(elementwise)
    kl_loss = KLD / (mus.size(0) + eps)
    # Diagnostic dump for unusually large values.
    if kl_loss > 100:
        for fields in (
            (kl_loss,),
            (KLD,),
            (mus.min(), mus.max()),
            (logvars.min(), logvars.max()),
        ):
            print(*fields)
    return kl_loss

def compute_two_gaussian_loss(mu1, logvar1, mu2, logvar2):
    """KL divergence KL(N1 || N2) between two diagonal Gaussians.

    Per element this evaluates
        0.5 * (logvar2 - logvar1 + (var1 + (mu1 - mu2)^2) / var2 - 1)
    with ``var = exp(logvar)`` and a 1e-8 guard added to ``var2``. The result
    is summed over all elements and divided by ``mu1.size(0)``.

    Args:
        mu1: Means of the first distribution.
        logvar1: Log variances of the first distribution.
        mu2: Means of the second distribution.
        logvar2: Log variances of the second distribution.
    Returns:
        Scalar tensor with the normalised divergence.
    """
    guarded_var2 = logvar2.exp() + 1e-8
    squared_gap = (mu1 - mu2) ** 2
    ratio = (logvar1.exp() + squared_gap) / guarded_var2
    summed = torch.sum(logvar2 - logvar1 + ratio - 1)
    return 0.5 * summed / (mu1.size(0) + 1e-8)


def generate_ios(programs, typs, model, step, drop_target, random=False):
    """Generate five input/output examples for every program.

    Programs are processed in batches of 32 on CUDA. For each of five query
    steps the inputs are either produced by the model's query network
    (``random=False``) or sampled uniformly with ``np.random.randint``
    (``random=True``); the program is then executed on those inputs and the
    resulting example is recorded.

    Args:
        programs: tensor of encoded programs; sliced per batch and reshaped
            to ``(-1, 1)``.
        typs: tensor of input type codes, sliced in parallel with
            ``programs``. In random mode each ``typ[b][step][i]`` is compared
            against ``[1, 0, 0]``, ``[0, 1, 0]`` and ``[0, 0, 1]``
            (presumably one-hot INT / LIST / NULL - verify against the caller).
        model: the model, optionally wrapped so that the real network sits
            under ``.module``.
        step: tensor sliced per batch and passed to ``env_step``.
        drop_target: tensor sliced per batch and passed to ``env_step``.
        random: when true, sample inputs randomly instead of querying the
            model. (The name shadows the ``random`` module inside this
            function; it is kept for interface compatibility.)

    Returns:
        A list with one dict per program:
        ``{'program': <decoded program string>, 'examples': [...]}`` where
        each example is ``{'inputs': ..., 'output': ...}``.

    Raises:
        ValueError: in random mode, if a type code matches none of the three
            recognised patterns.
    """
    batch_size = 32
    # Template for the initial, empty example: every slot holds the value
    # params.integer_range before one-hot encoding.
    x_s = torch.zeros(batch_size, 1, 4, params.max_list_len, device='cuda').long() + params.integer_range
    ios = [{'program': None, 'examples': []} for _ in range(len(programs))]
    # Always go through raw_model so that both wrapped and unwrapped models work.
    raw_model = model.module if hasattr(model, "module") else model
    for batch in tqdm(range(0, len(programs), batch_size)):
        program_batch = programs[batch:batch+batch_size].view(-1, 1).cuda()
        typ = typs[batch:batch+batch_size].cuda()
        s = step[batch:batch+batch_size].cuda()
        z = drop_target[batch:batch+batch_size].cuda()
        x = x_s[:len(program_batch)].clone()
        x = F.one_hot(x, params.integer_range + 1).float()
        for query_step in range(5):
            if not random:
                embedding = raw_model.query.encode_io(x, typ[:, :int(max(1, query_step)), :, :2])
                if params.distribution:
                    if params.distribution == 'GMM':
                        gmm_embedding = raw_model.query.encode_into_t_GMM(embedding)
                        embedding = raw_model.query.intersection(gmm_embedding)
                    else:
                        mus_t, logvars_t = raw_model.query.encode_into_t(embedding)
                        mus_t, logvars_t = raw_model.query.intersection(mus_t, logvars_t)
                        embedding = torch.cat([mus_t, logvars_t], -1)
                ############## latent code ###########
                if params.latent_code:
                    # Append the current query step as an extra feature column.
                    latent_targ = torch.Tensor([query_step] * x.shape[0]).cuda().view(-1, 1)
                    embedding = torch.cat([embedding, latent_targ], -1)
                ######################################
                # query_index [batch size, 1, 3, input length]
                query_inp, query_index = raw_model.query.decode_process(embedding, typ, params.hard_softmax)

                query_io, _var_encoded, _var_typ = raw_model.env_step(query_index, query_inp, program_batch, s, z)
                # The first query replaces the placeholder example; later
                # queries are appended along dim 1.
                if query_step > 0:
                    x = torch.cat([x, query_io], 1)
                else:
                    x = query_io
            else:
                query_index = []
                for batch_idx in range(program_batch.shape[0]):
                    inputs = []
                    for input_idx in range(3):
                        type_code = typ[batch_idx][query_step][input_idx].tolist()
                        if type_code == [0, 1, 0]:
                            value = [np.random.randint(0, params.integer_range - 1) for _ in range(params.max_list_len)]
                        elif type_code == [1, 0, 0]:
                            value = [np.random.randint(0, params.integer_range - 1)]
                            value.extend([params.integer_range] * (params.max_list_len - 1))
                        elif type_code == [0, 0, 1]:
                            value = [params.integer_range] * params.max_list_len
                        else:
                            # Previously an unmatched code silently reused the
                            # value of the preceding input (or was unbound).
                            raise ValueError('bad type {}'.format(type_code))
                        inputs.append(value)
                    query_index.append(inputs)
                query_index = torch.Tensor(query_index).cuda().long().unsqueeze(1)
            for batch_idx in range(program_batch.shape[0]):
                input_index = query_index[batch_idx, 0]
                program = raw_model.le.inverse_transform(program_batch[batch_idx].cpu())[0]
                ios[batch+batch_idx]['program'] = program
                program = Program.parse(program.rstrip())
                input_vals = raw_model.get_input_val(input_index, program)
                output = raw_model.get_query_output(program, input_vals)
                output_val = get_output_val(output, program)
                ios[batch+batch_idx]['examples'].append({'inputs': input_vals, 'output': output_val[0]})
    return ios
    
def get_output_val(val, program):
    """Convert a raw program output into plain Python values.

    Only the last entry of ``program.var_types`` is consulted. ``NULL``
    yields an empty result, ``INT`` takes ``val[0]``, ``LIST`` takes ``val``
    as is; in the latter two cases ``params.integer_min`` is added before
    calling ``.tolist()``.

    Returns:
        A list containing at most one converted value.

    Raises:
        ValueError: for any other type name.
    """
    converted = []
    for typ in program.var_types[-1:]:
        kind = str(typ)
        if kind == 'NULL':
            continue
        if kind == 'INT':
            shifted = val[0] + params.integer_min
        elif kind == 'LIST':
            shifted = val + params.integer_min
        else:
            raise ValueError('bad type {}'.format(typ))
        converted.append(shifted.tolist())
    return converted

class KLDistance(BaseDistance):
    """Pairwise KL divergence between diagonal Gaussians.

    Each embedding is split in half along the last dimension: the first half
    is read as the mean, the second half as the log-variance.
    """

    def __init__(self, **kwargs):
        super().__init__(is_inverted=False, normalize_embeddings=False, **kwargs)
        assert not self.is_inverted
        assert not self.normalize_embeddings

    def compute_mat(self, query_emb, ref_emb):
        """Return the matrix whose (i, j) entry is KL(query_i || ref_j)."""
        half = int(query_emb.shape[-1] / 2)
        # Queries broadcast along the second-to-last axis, references along
        # the third-to-last one.
        q_mu = query_emb[..., :half].unsqueeze(-2)
        q_logvar = query_emb[..., half:].unsqueeze(-2)
        r_mu = ref_emb[..., :half].unsqueeze(-3)
        r_logvar = ref_emb[..., half:].unsqueeze(-3)

        ratio = (q_logvar.exp() + (q_mu - r_mu) ** 2) / (r_logvar.exp() + 1e-8)
        return 0.5 * torch.sum(r_logvar - q_logvar + ratio - 1, -1)

class NTXentDistLoss(NTXentLoss):
    """NT-Xent loss whose default distance is ``KLDistance``."""

    def __init__(self, temperature=0.07, **kwargs):
        options = dict(kwargs, temperature=temperature)
        super().__init__(**options)

    def get_default_distance(self):
        """Return a fresh ``KLDistance`` instance."""
        default = KLDistance()
        return default

class NSLoss(nn.Module):
    '''
    Negative Sampling Loss.

    Embeddings are split in half along the last dimension into mean and
    log-variance. The pairwise KL(program_i || io_j) matrix is computed;
    diagonal entries are treated as positives and off-diagonal entries as
    negatives, each scored through a sigmoid with margin ``gamma``.
    '''
    def __init__(self, gamma):
        super(NSLoss, self).__init__()
        self.gamma = gamma

    def forward(self, program_emb, io_emb):
        """Return the scalar loss for a batch of program / IO embeddings.

        NOTE(review): masking is done by multiplying ``kl`` with the mask, so
        masked-out entries still contribute the constant terms
        ``-log(sigmoid(gamma))`` / ``-log(sigmoid(-gamma))`` to the sums.
        They carry no gradient but do offset the reported value; left as is
        to keep the loss values unchanged.
        """
        dim_size = program_emb.shape[-1]
        half = dim_size // 2
        mu1 = program_emb[..., :half].unsqueeze(-2)
        logvar1 = program_emb[..., half:].unsqueeze(-2)
        mu2 = io_emb[..., :half].unsqueeze(-3)
        logvar2 = io_emb[..., half:].unsqueeze(-3)
        numerator = logvar1.exp() + torch.pow(mu1 - mu2, 2)
        fraction = torch.div(numerator, (logvar2.exp() + 1e-8))
        kl = 0.5 * torch.sum(logvar2 - logvar1 + fraction - 1, -1)
        # Build the mask on the same device as the data instead of a
        # hard-coded 'cuda', so CPU and non-default GPUs work as well.
        diagonal_mask = torch.eye(*kl.shape, device=kl.device)
        off_diagonal_mask = 1 - diagonal_mask
        positive_term = -torch.log(torch.sigmoid(self.gamma - kl * diagonal_mask)).sum() / diagonal_mask.sum()
        negative_term = -torch.log(torch.sigmoid(kl * off_diagonal_mask - self.gamma)).sum() / off_diagonal_mask.sum()
        return positive_term + negative_term
class NSLoss2(nn.Module):
    '''
    Negative Sampling Loss. Considering negative samples among programs and negative samples among IOs.

    The two inputs are concatenated in both orders ([program; io] against
    [io; program]) so that the 2N x 2N KL matrix also contains
    program-program and io-io pairs.
    '''
    def __init__(self, gamma):
        super(NSLoss2, self).__init__()
        self.gamma = gamma

    def forward(self, program_emb, io_emb):
        """Return the scalar loss for a batch of program / IO embeddings.

        Positives are the diagonal of the 2N x 2N matrix. Negatives are every
        entry except the diagonal and the two off-diagonal N x N identity
        blocks (which pair an embedding with itself).
        """
        first = torch.cat([program_emb, io_emb])
        second = torch.cat([io_emb, program_emb])
        half = first.shape[-1] // 2
        mu1 = first[..., :half].unsqueeze(-2)
        logvar1 = first[..., half:].unsqueeze(-2)
        mu2 = second[..., :half].unsqueeze(-3)
        logvar2 = second[..., half:].unsqueeze(-3)
        numerator = logvar1.exp() + torch.pow(mu1 - mu2, 2)
        fraction = torch.div(numerator, (logvar2.exp() + 1e-8))
        kl = 0.5 * torch.sum(logvar2 - logvar1 + fraction - 1, -1)
        # Masks follow the data's device rather than a hard-coded 'cuda'.
        device = kl.device
        pos_mask = torch.eye(*kl.shape, device=device)
        tmp_mask = torch.eye(*kl.shape, device=device)
        N = int(kl.shape[0] / 2)
        tmp_mask[:N, N:2*N] = torch.eye(N, N, device=device)
        tmp_mask[N:2*N, :N] = torch.eye(N, N, device=device)
        neg_mask = 1 - tmp_mask

        loss = -torch.log(torch.sigmoid(self.gamma - kl * pos_mask)).sum() / pos_mask.sum() \
            - torch.log(torch.sigmoid(kl * neg_mask - self.gamma)).sum() / neg_mask.sum()
        return loss


from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
from pytorch_metric_learning.utils import common_functions as c_f
class ProbDistance(BaseDistance):
    """Log-likelihood of program embeddings under IO-derived distributions.

    The distribution family is selected by ``params.distribution``
    ('Normal', 'Beta' or 'GMM'); the ``distribution`` constructor argument
    is accepted but not used.
    """
    #TODO: detach from BaseDistance
    def __init__(self, distribution=None, **kwargs):
        super().__init__(is_inverted=False, normalize_embeddings=False, **kwargs)
        assert not self.is_inverted
        assert not self.normalize_embeddings

    def compute_mat(self, program_emb, io_emb):
        """Return log p(program_i | distribution_j) for every pair (i, j).

        For 'Normal' and 'Beta' the IO embedding is split in half along the
        last dimension into the two distribution parameters and the
        per-dimension log-probabilities are summed. For 'GMM' it is reshaped
        to ``(-1, params.clusters, 2 * params.dist_dim + 1)`` holding, per
        component, a mixture weight, the means and the log-variances.

        Raises:
            ValueError: for any other ``params.distribution`` value.
        """
        family = params.distribution
        half = int(io_emb.shape[-1] / 2)
        if family == 'Normal':
            loc = io_emb[..., :half].unsqueeze(0)
            # Second half is a log-variance; exp(0.5 * logvar) is the std.
            scale = (0.5 * io_emb[..., half:]).exp().unsqueeze(0)
            density = D.normal.Normal(loc, scale)
        elif family == 'Beta':
            first = torch.clamp(io_emb[..., :half], 0.05, 1e9).unsqueeze(0)
            second = torch.clamp(io_emb[..., half:], 0.05, 1e9).unsqueeze(0)
            density = D.beta.Beta(first, second)
            # Squash the query points through a sigmoid before scoring them
            # under the Beta density.
            program_emb = torch.sigmoid(program_emb)
        elif family == 'GMM':
            components = io_emb.view(-1, params.clusters, 2 * params.dist_dim + 1)
            weights = components[..., 0].unsqueeze(0)
            means = components[..., 1:params.dist_dim+1].unsqueeze(0)
            stds = (0.5 * components[..., params.dist_dim+1:]).exp().unsqueeze(0)
            mixture = D.Categorical(weights)
            parts = D.Independent(D.Normal(means, stds), 1)
            density = D.mixture_same_family.MixtureSameFamily(mixture, parts)
        else:
            raise ValueError('bad type {}'.format(params.distribution))

        log_prob = density.log_prob(program_emb.unsqueeze(1))
        if family == 'GMM':
            return log_prob
        return log_prob.sum(-1)
class NTXentProbLoss(NTXentLoss):
    """NT-Xent loss over a program-vs-IO ``ProbDistance`` matrix.

    Unlike the parent class, ``forward`` takes two embedding tensors
    (programs and IOs) and the pairwise matrix is computed between them.

    NOTE(review): ``ProbDistance`` returns log-probabilities but declares
    ``is_inverted=False``; confirm the parent loss treats the matrix with the
    intended sign.
    """

    # Kept at class level so the attribute already exists if the base-class
    # constructor calls get_default_distance() before __init__ below has
    # finished (assumption about the library's init order - TODO confirm).
    distribution = D.normal.Normal

    def __init__(self, temperature=0.07, **kwargs):
        super().__init__(temperature=temperature, **kwargs)
        self.distribution = D.normal.Normal

    def get_default_distance(self):
        """Return the ``ProbDistance`` used when no distance is supplied."""
        return ProbDistance(self.distribution)

    def compute_loss(self, program_emb, io_emb, labels, indices_tuple):
        """Compute the per-pair loss dictionary.

        The previous signature took a single ``embeddings`` argument while
        the body referenced ``program_emb`` / ``io_emb``, so it could never
        run; both embeddings are now explicit parameters.
        """
        indices_tuple = lmu.convert_to_pairs(indices_tuple, labels)
        if all(len(x) <= 1 for x in indices_tuple):
            return self.zero_losses()
        mat = self.distance(program_emb, io_emb)
        return self.loss_method(mat, labels, indices_tuple)

    def forward(self, program_emb, io_emb, labels, indices_tuple=None):
        """
        Args:
            program_emb: tensor of size (batch_size, embedding_size)
            io_emb: tensor of IO embeddings, one row per sample
            labels: tensor of size (batch_size)
            indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives)
                            or size 4 for pairs (anchor1, postives, anchor2, negatives)
                            Can also be left as None
        Returns: the loss
        """
        self.reset_stats()
        # Move labels to the device of the embeddings (value first, then the
        # reference tensor).
        labels = c_f.to_device(labels, program_emb)
        loss_dict = self.compute_loss(program_emb, io_emb, labels, indices_tuple)
        # The program embeddings stand in for the single ``embeddings``
        # tensor that the parent-class helpers expect.
        self.add_embedding_regularization_to_loss_dict(loss_dict, program_emb)
        return self.reducer(loss_dict, program_emb, labels)

class NTXentProbLoss2(nn.Module):
    """Cross-entropy contrastive loss over a ``ProbDistance`` score matrix."""

    def __init__(self, temperature=0.1):
        super().__init__()
        self.distribution = None
        self.distance = ProbDistance(self.distribution)
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def nt_xent_loss(self, program_emb, io_emb, eps=1e-6):
        """Return the contrastive loss for a batch of embedding pairs.

        When torch.distributed is initialised, embeddings from all processes
        are gathered first (through ``GatherLayer``) so that the score matrix
        covers the global batch. Row i of the score matrix uses its diagonal
        entry as the positive logit (placed in column 0) and all other
        entries of the row as negatives.

        Args:
            program_emb: program embeddings, one row per sample.
            io_emb: IO embeddings, one row per sample.
            eps: unused.
        """
        distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
        if distributed:
            all_programs = torch.cat(GatherLayer.apply(program_emb), 0)
            all_ios = torch.cat(GatherLayer.apply(io_emb), 0)
        else:
            all_programs, all_ios = program_emb, io_emb

        scores = self.distance.compute_mat(all_programs, all_ios)
        rows = scores.shape[0]
        on_diagonal = torch.eye(rows, dtype=torch.bool, device=scores.device)

        positives = (scores[on_diagonal] / self.temperature).view(rows, 1)
        negatives = (scores[~on_diagonal] / self.temperature).view(rows, -1)
        logits = torch.cat([positives, negatives], -1)
        # The positive always sits in column 0.
        targets = torch.zeros(positives.shape[0], dtype=torch.long, device=positives.device)
        return self.criterion(logits, targets)

    def forward(self, program_emb, io_emb):
        """Alias for ``nt_xent_loss``."""
        return self.nt_xent_loss(program_emb, io_emb)

class SyncFunction(torch.autograd.Function):
    """All-gather along dim 0 with a backward pass that sums gradients across processes."""

    @staticmethod
    def forward(ctx, tensor):
        # Remember the local batch size so backward can slice out this
        # process's share of the gradient.
        ctx.batch_size = tensor.shape[0]

        world = torch.distributed.get_world_size()
        buffers = [torch.zeros_like(tensor) for _ in range(world)]
        torch.distributed.all_gather(buffers, tensor)
        return torch.cat(buffers, 0)

    @staticmethod
    def backward(ctx, grad_output):
        summed = grad_output.clone()
        torch.distributed.all_reduce(summed, op=torch.distributed.ReduceOp.SUM, async_op=False)

        rank = torch.distributed.get_rank()
        start = rank * ctx.batch_size
        stop = (rank + 1) * ctx.batch_size
        return summed[start:stop]

class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes while keeping the op differentiable.

    Forward returns a tuple with one tensor per process. Backward hands back
    only the gradient that arrived for this process's own slot; gradients
    for the other slots are not reduced across processes here.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        world = dist.get_world_size()
        gathered = [torch.zeros_like(input) for _ in range(world)]
        dist.all_gather(gathered, input)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        (saved,) = ctx.saved_tensors
        own_grad = grads[dist.get_rank()]
        # Copy into a buffer shaped like the saved input.
        result = torch.zeros_like(saved)
        result[:] = own_grad
        return result

class HellingerLoss(nn.Module):
    """Inverse-divergence loss between two parameterised distributions.

    Despite the name, the value computed is a KL divergence (selected by
    ``params.distribution``), and ``forward`` returns its reciprocal.
    """

    def __init__(self):
        super().__init__()

    def hellinger_distance(self, alpha1, beta1, alpha2, beta2, eps=1e-6):
        """Return KL(dist1 || dist2) divided by the size of the first dimension.

        Args:
            alpha1, beta1: parameters of the first distribution. For 'Normal'
                they are read as mean and log-variance; for 'Beta' their
                exponentials are the two concentration parameters.
            alpha2, beta2: same, for the second distribution.
            eps: unused.

        Raises:
            ValueError: if ``params.distribution`` is neither 'Normal' nor
                'Beta' (previously this returned ``None`` and failed later
                in ``forward`` with a TypeError).
        """
        if params.distribution == 'Normal':
            mu1, logvar1 = alpha1, beta1
            mu2, logvar2 = alpha2, beta2
            numerator = logvar1.exp() + torch.pow(mu1 - mu2, 2)
            fraction = torch.div(numerator, (logvar2.exp() + 1e-8))
            kl = 0.5 * torch.sum(logvar2 - logvar1 + fraction - 1)
            return kl / (mu1.size(0) + 1e-8)

        if params.distribution == 'Beta':
            # Capture the batch size before the leading axis is added below.
            batch_size = alpha1.shape[0]
            dist1 = D.beta.Beta(alpha1.exp().unsqueeze(0), beta1.exp().unsqueeze(0))
            dist2 = D.beta.Beta(alpha2.exp().unsqueeze(0), beta2.exp().unsqueeze(0))
            return torch.distributions.kl.kl_divergence(dist1, dist2).sum() / batch_size

        raise ValueError('bad type {}'.format(params.distribution))

    def forward(self, alpha1, beta1, alpha2, beta2):
        """Return ``1 / (divergence + 1e-8)``."""
        distance = self.hellinger_distance(alpha1, beta1, alpha2, beta2)
        return 1 / (distance + 1e-8)