import json
import logging
import os
import random
import time
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math

from tqdm import tqdm

# Project imports
import torch.distributed as dist
from torch.utils.data.dataloader import DataLoader
import sys
import params
import torch.nn.functional as F
import torch.distributions as D

from pytorch_metric_learning.distances import BaseDistance
from pytorch_metric_learning.losses import NTXentLoss

from utils import *

from data import *

from scripts.get_map import get_map as get_map_nosize 

def init_distributed_mode():
    """Initialise ``torch.distributed`` from the launcher's environment.

    Two launch styles are recognised:

    * ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK`` present in the environment.
    * ``SLURM_PROCID`` present; the world size is then read from
      ``SLURM_NTASKS`` (falling back to ``WORLD_SIZE``) and the GPU index is
      ``rank % torch.cuda.device_count()``.

    On success the NCCL process group is created with ``env://``
    initialisation and printing/logging is silenced on non-zero ranks via
    ``setup_for_distributed``.

    Returns:
        The selected GPU index, or ``None`` when neither set of variables is
        present (the process then runs non-distributed).

    Raises:
        RuntimeError: in the SLURM branch when no world size can be found.
    """
    env = os.environ
    if 'RANK' in env and 'WORLD_SIZE' in env:
        rank = int(env["RANK"])
        world_size = int(env['WORLD_SIZE'])
        gpu = int(env['LOCAL_RANK'])
    elif 'SLURM_PROCID' in env:
        rank = int(env['SLURM_PROCID'])
        # world_size used to be left unassigned in this branch, which crashed
        # at init_process_group below; take it from the SLURM job description.
        raw_world_size = env.get('SLURM_NTASKS', env.get('WORLD_SIZE'))
        if raw_world_size is None:
            raise RuntimeError(
                'SLURM_PROCID is set but neither SLURM_NTASKS nor WORLD_SIZE '
                'is available to determine the world size')
        world_size = int(raw_world_size)
        gpu = rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        return None

    torch.cuda.set_device(gpu)
    dist_backend = 'nccl'
    dist_url = 'env://'
    print('| distributed init (rank {}): {}'.format(
        rank, dist_url), flush=True)
    torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url,
                                         world_size=world_size, rank=rank)
    # From here on only rank 0 prints / logs unless force=True is passed.
    setup_for_distributed(rank == 0)
    return gpu
    
def setup_for_distributed(is_master):
    """Silence ``print`` and ``logging.info`` on non-master processes.

    Both callables are replaced globally by wrappers that forward to the
    previous implementation only when ``is_master`` is true or the call
    passes ``force=True``. The ``force`` keyword is consumed by the wrapper
    and never forwarded.
    """
    import builtins

    def _gated(original):
        # Build a drop-in replacement that honours the master/force rule.
        def wrapper(*args, **kwargs):
            forced = kwargs.pop('force', False)
            if is_master or forced:
                original(*args, **kwargs)
        return wrapper

    builtins.print = _gated(builtins.print)
    logging.info = _gated(logging.info)

def is_dist_avail_and_initialized():
    """Return True only when torch.distributed is available and a process group is initialised."""
    return dist.is_available() and dist.is_initialized()

def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0

def get_world_size():
    """Return the number of distributed workers.

    Falls back to 1 when torch.distributed is unavailable or no process group
    has been initialised.
    """
    backend = torch.distributed
    if not (backend.is_available() and backend.is_initialized()):
        return 1
    return backend.get_world_size()

def save_checkpoint(models, optimizers, model_name, optimizer_name, epoch, best_val_err, ckpt_path):
    """Write model/optimizer state plus training progress to ``ckpt_path``.

    Args:
        models: Sequence of modules; a ``.module`` attribute (as added by
            parallel wrappers) is unwrapped before taking the state dict.
        optimizers: Sequence of optimizers to snapshot.
        model_name: Checkpoint keys, one per entry of ``models``.
        optimizer_name: Checkpoint keys, one per entry of ``optimizers``.
        epoch: Stored under ``'epoch'``.
        best_val_err: Stored under ``'best_val_err'``.
        ckpt_path: Destination passed to ``torch.save``.

    Every process assembles the dictionary, but only rank 0 writes it.
    """
    checkpoint = {}
    for idx, model in enumerate(models):
        unwrapped = getattr(model, "module", model)
        checkpoint[model_name[idx]] = unwrapped.state_dict()
    for idx, optimizer in enumerate(optimizers):
        checkpoint[optimizer_name[idx]] = optimizer.state_dict()
    checkpoint['epoch'] = epoch
    checkpoint['best_val_err'] = best_val_err

    # Only the main process touches the filesystem.
    if get_rank() == 0:
        torch.save(checkpoint, ckpt_path)


def load_checkpoint(models, optimizers, model_name, optimizer_name, ckpt_path):
    """Restore models and optimizers in place from a checkpoint file.

    Args:
        models: Sequence of modules; a ``.module`` attribute (as added by
            parallel wrappers) is unwrapped before loading.
        optimizers: Sequence of optimizers to restore.
        model_name: Checkpoint keys, one per entry of ``models``.
        optimizer_name: Checkpoint keys, one per entry of ``optimizers``.
        ckpt_path: File previously written with ``torch.save``.

    Returns:
        ``(models, optimizers, start_epoch, best_val_err)`` where the first
        two are new lists holding the same (now restored) objects and
        ``start_epoch`` is the stored epoch plus one.
    """
    # Deserialise onto the CPU so that every process does not re-create the
    # tensors on the device they were saved from; load_state_dict copies the
    # values onto each parameter's own device afterwards.
    weight_ckpt = torch.load(ckpt_path, map_location='cpu')

    loaded_models = []
    for idx, model in enumerate(models):
        raw_model = model.module if hasattr(model, "module") else model
        raw_model.load_state_dict(weight_ckpt[model_name[idx]])
        loaded_models.append(model)

    loaded_optimizers = []
    for idx, optimizer in enumerate(optimizers):
        # Loaded once; the state used to be loaded twice in a row.
        optimizer.load_state_dict(weight_ckpt[optimizer_name[idx]])
        loaded_optimizers.append(optimizer)

    start_epoch = weight_ckpt['epoch'] + 1
    best_val_err = weight_ckpt['best_val_err']
    return loaded_models, loaded_optimizers, start_epoch, best_val_err


def gaussian_KL_loss(mus, logvars, eps=1e-8):
    """KL divergence of diagonal Gaussians from the unit normal.

    Args:
        mus: Tensor of means predicted by the encoder.
        logvars: Tensor of log-variances predicted by the encoder.
        eps: Small constant added to the first-dimension size in the divisor.

    Returns:
        The KL term summed over every element and divided by
        ``mus.size(0) + eps``.
    """
    per_element = 1 + logvars - mus.pow(2) - logvars.exp()
    KLD = -0.5 * torch.sum(per_element)
    kl_loss = KLD / (mus.size(0) + eps)

    # Diagnostic dump when the loss is suspiciously large.
    if kl_loss > 100:
        diagnostics = (
            (kl_loss,),
            (KLD,),
            (mus.min(), mus.max()),
            (logvars.min(), logvars.max()),
        )
        for row in diagnostics:
            print(*row)
    return kl_loss

def compute_two_gaussian_loss(mu1, logvar1, mu2, logvar2):
    """KL divergence KL(N(mu1, var1) || N(mu2, var2)) for diagonal Gaussians.

    Per element the code evaluates
        log(sigma_2 / sigma_1) + (sigma_1^2 + (mu_1 - mu_2)^2) / (2 sigma_2^2) - 0.5
    with ``sigma_i^2 = exp(logvar_i)``.

    Args:
        mu1: Means of the first distribution.
        logvar1: Log-variances of the first distribution.
        mu2: Means of the second distribution.
        logvar2: Log-variances of the second distribution.

    Returns:
        The element-wise KL summed over all elements and divided by
        ``mu1.size(0)`` (plus a 1e-8 guard).
    """
    squared_gap = (mu1 - mu2).pow(2)
    # 1e-8 keeps the division finite when exp(logvar2) underflows to zero.
    ratio = (logvar1.exp() + squared_gap) / (logvar2.exp() + 1e-8)
    total = torch.sum(logvar2 - logvar1 + ratio - 1)
    return 0.5 * total / (mu1.size(0) + 1e-8)

def generate_batch_ios(model, program_seq, simulator, vocab, random=False):
    """Collect (input, output) grid pairs for every program in the batch.

    With ``random=False`` the model's query network proposes
    ``params.query_num`` inputs one after another, each conditioned on the
    pairs gathered so far; with ``random=True`` five inputs per program come
    from ``get_map_nosize`` instead. Outputs are always obtained through
    ``model.env_step``.

    Args:
        model: Network exposing ``query`` and ``env_step``; a ``.module``
            wrapper is unwrapped first.
        program_seq: Batch of programs; only ``shape[0]`` is read here, the
            rest is handed to the helpers.
        simulator: Passed through to ``get_first_io`` / ``env_step``.
        vocab: Passed through to ``get_first_io``.
        random: Selects the random-input branch when true.

    Returns:
        A list with one entry per batch element; each entry is a list of
        ``(input_indices, output_indices)`` tuples holding the flat indices
        of the non-zero grid cells as CPU int16 tensors.
    """
    def sparse_indices(grid):
        # Flat positions of the non-zero cells, moved to the CPU as int16.
        return grid.view(-1).nonzero(as_tuple=False).view(-1).short().cpu().data

    batch_size = program_seq.shape[0]
    input_grids, output_grids = get_first_io(program_seq, simulator, vocab)
    net = getattr(model, "module", model)
    srcs = [[] for _ in range(batch_size)]

    def record(batch_inputs, batch_outputs):
        # Append this step's pair to every program's example list.
        for idx in range(batch_size):
            pair = (sparse_indices(batch_inputs[idx]), sparse_indices(batch_outputs[idx]))
            srcs[idx].append(pair)

    if random:
        for _ in range(5):
            maps = [torch.Tensor(get_map_nosize()).unsqueeze(0) for _ in range(batch_size)]
            inp = torch.stack(maps, 0).cuda()
            out = net.env_step(inp, program_seq, simulator, params.hard_softmax)
            record(inp, out)
        return srcs

    for query_step in range(params.query_num):
        query = net.query
        io_code = query.encode_io(input_grids, output_grids)
        mus_t, logvars_t = query.encode_into_t(io_code)
        mus_t, logvars_t = query.intersection(mus_t, logvars_t)
        embedding = torch.cat([mus_t, logvars_t], -1)

        if params.latent_code:
            # Append the step index (and optional uniform noise) as a latent code.
            rows = input_grids.shape[0]
            latent_emb = torch.Tensor([query_step] * rows).cuda().view(-1, 1)
            if params.noise:
                noise_targ = torch.Tensor(np.random.uniform(-1, 1, (rows, 5))).cuda()
                latent_emb = torch.cat([latent_emb, noise_targ], -1)
            embedding = torch.cat([embedding, latent_emb], -1)

        query_inp = query.decode_process(embedding, params.hard_softmax)
        query_out = net.env_step(query_inp, program_seq, simulator, params.hard_softmax)

        if query_step == 0:
            # The first generated pair replaces the grids from get_first_io.
            input_grids, output_grids = query_inp, query_out
        else:
            input_grids = torch.cat([input_grids, query_inp], 1)
            output_grids = torch.cat([output_grids, query_out], 1)

        record(query_inp, query_out)
    return srcs
    
def get_output_val(val, program):
    """Convert a raw output value to a plain Python value.

    Only the last entry of ``program.var_types`` is consulted:

    * ``NULL`` - nothing is emitted.
    * ``INT``  - ``val[0]`` shifted by ``params.integer_min``.
    * ``LIST`` - ``val`` shifted element-wise by ``params.integer_min``.

    Returns:
        A list with at most one element (the ``.tolist()`` of the shifted
        value).

    Raises:
        ValueError: for any other type name.
    """
    vals = []
    for typ in program.var_types[-1:]:
        kind = str(typ)
        if kind == 'NULL':
            continue
        if kind == 'INT':
            shifted = val[0] + params.integer_min
        elif kind == 'LIST':
            shifted = val + params.integer_min
        else:
            raise ValueError('bad type {}'.format(typ))
        vals.append(shifted.tolist())
    return vals

class KLDistance(BaseDistance):
    """Pairwise KL divergence between diagonal Gaussians.

    Each embedding is read as ``[mu, logvar]`` concatenated along the last
    dimension.
    """

    def __init__(self, **kwargs):
        super().__init__(is_inverted=False, normalize_embeddings=False, **kwargs)
        assert not self.is_inverted
        assert not self.normalize_embeddings

    def compute_mat(self, query_emb, ref_emb):
        """Return KL(query_i || ref_j) for every query/reference pair.

        For 2-D inputs of shape (Q, D) and (R, D) the result has shape (Q, R).
        """
        half = int(query_emb.shape[-1] / 2)
        # Broadcast queries along rows and references along columns.
        q_mu = query_emb[..., :half].unsqueeze(-2)
        q_logvar = query_emb[..., half:].unsqueeze(-2)
        r_mu = ref_emb[..., :half].unsqueeze(-3)
        r_logvar = ref_emb[..., half:].unsqueeze(-3)

        ratio = (q_logvar.exp() + (q_mu - r_mu).pow(2)) / (r_logvar.exp() + 1e-8)
        return 0.5 * torch.sum(r_logvar - q_logvar + ratio - 1, -1)

class NTXentDistLoss(NTXentLoss):
    """NT-Xent loss whose default distance is ``KLDistance``."""

    def __init__(self, temperature=0.07, **kwargs):
        super().__init__(temperature=temperature, **kwargs)

    def get_default_distance(self):
        """Use the pairwise Gaussian KL divergence as the distance."""
        return KLDistance()

class NSLoss(nn.Module):
    '''
    Negative Sampling Loss.

    Embeddings are read as ``[mu, logvar]`` concatenated along the last
    dimension. The pairwise KL(program_i || io_j) matrix is computed; diagonal
    entries are treated as positives and all other entries as negatives, with
    ``gamma`` acting as the margin inside the sigmoids.
    '''
    def __init__(self, gamma):
        super(NSLoss, self).__init__()
        self.gamma = gamma

    def forward(self, program_emb, io_emb):
        """Return the scalar negative-sampling loss for a batch.

        Args:
            program_emb: Program embeddings, last dimension ``[mu, logvar]``.
            io_emb: IO embeddings with the same layout.
        """
        half = int(program_emb.shape[-1] / 2)
        mu1 = program_emb[..., :half].unsqueeze(-2)
        logvar1 = program_emb[..., half:].unsqueeze(-2)
        mu2 = io_emb[..., :half].unsqueeze(-3)
        logvar2 = io_emb[..., half:].unsqueeze(-3)
        numerator = logvar1.exp() + torch.pow(mu1 - mu2, 2)
        fraction = torch.div(numerator, (logvar2.exp() + 1e-8))
        kl = 0.5 * torch.sum(logvar2 - logvar1 + fraction - 1, -1)

        # Build the mask on the same device as the data instead of a
        # hard-coded 'cuda', so CPU and non-default-GPU inputs work too.
        diagonal_mask = torch.eye(*kl.shape, device=kl.device)
        off_diagonal_mask = 1 - diagonal_mask

        # The masks multiply kl *inside* the sigmoid, so masked-out entries
        # still contribute a constant (-log sigmoid(+/-gamma)) to each sum.
        # That shifts the reported value but carries no gradient.
        positive_term = -torch.log(torch.sigmoid(self.gamma - kl * diagonal_mask)).sum() \
            / diagonal_mask.sum()
        negative_term = -torch.log(torch.sigmoid(kl * off_diagonal_mask - self.gamma)).sum() \
            / off_diagonal_mask.sum()
        return positive_term + negative_term

class NSLoss2(nn.Module):
    '''
    Negative Sampling Loss. Considering negative samples among programs and negative samples among IOs.

    The two batches are stacked as ``[program; io]`` against ``[io; program]``
    so the resulting 2N x 2N KL matrix contains program-io, program-program
    and io-io pairs. Diagonal entries are the positives; entries pairing an
    embedding with itself are excluded from the negatives.
    '''
    def __init__(self, gamma):
        super(NSLoss2, self).__init__()
        self.gamma = gamma

    def forward(self, program_emb, io_emb):
        """Return the scalar negative-sampling loss for a batch.

        Args:
            program_emb: Program embeddings, last dimension ``[mu, logvar]``.
            io_emb: IO embeddings with the same layout.
        """
        left = torch.cat([program_emb, io_emb])
        right = torch.cat([io_emb, program_emb])
        half = int(left.shape[-1] / 2)
        mu1 = left[..., :half].unsqueeze(-2)
        logvar1 = left[..., half:].unsqueeze(-2)
        mu2 = right[..., :half].unsqueeze(-3)
        logvar2 = right[..., half:].unsqueeze(-3)
        numerator = logvar1.exp() + torch.pow(mu1 - mu2, 2)
        fraction = torch.div(numerator, (logvar2.exp() + 1e-8))
        kl = 0.5 * torch.sum(logvar2 - logvar1 + fraction - 1, -1)

        # Masks live on the data's device instead of a hard-coded 'cuda',
        # so CPU and non-default-GPU inputs work too.
        device = kl.device
        pos_mask = torch.eye(*kl.shape, device=device)
        excluded = torch.eye(*kl.shape, device=device)
        N = int(kl.shape[0] / 2)
        # Off-diagonal blocks pair each embedding with itself; they are
        # neither positives nor negatives.
        excluded[:N, N:2*N] = torch.eye(N, N, device=device)
        excluded[N:2*N, :N] = torch.eye(N, N, device=device)
        neg_mask = 1 - excluded

        # The masks multiply kl *inside* the sigmoid, so masked-out entries
        # still add a constant to each sum (no gradient, value shift only).
        positive_term = -torch.log(torch.sigmoid(self.gamma - kl * pos_mask)).sum() / pos_mask.sum()
        negative_term = -torch.log(torch.sigmoid(kl * neg_mask - self.gamma)).sum() / neg_mask.sum()
        return positive_term + negative_term


from pytorch_metric_learning.utils import loss_and_miner_utils as lmu
from pytorch_metric_learning.utils import common_functions as c_f

class ProbDistance(BaseDistance):
    """Pairwise log-likelihood of program embeddings under IO distributions.

    The distribution family is chosen by ``params.distribution``; the
    ``distribution`` constructor argument is accepted but not used.
    """
    #TODO: detach from BaseDistance
    def __init__(self, distribution=None, **kwargs):
        super().__init__(is_inverted=False, normalize_embeddings=False, **kwargs)
        assert not self.is_inverted
        assert not self.normalize_embeddings

    def compute_mat(self, program_emb, io_emb):
        """Return ``log p(program_i | io_j)`` summed over the last dimension.

        ``io_emb`` is split in half along its last dimension to obtain the
        two distribution parameters:

        * ``'Normal'``: mean and log-variance (scale is ``exp(0.5 * logvar)``).
        * ``'Beta'``: both concentrations, clamped to ``[0.05, 1e9]``; the
          program embedding is passed through a sigmoid first.

        Raises:
            ValueError: if ``params.distribution`` is neither of the above.
        """
        half = int(io_emb.shape[-1] / 2)
        first_half = io_emb[..., :half]
        second_half = io_emb[..., half:]
        # Programs vary along dim 0, IO distributions along dim 1.
        samples = program_emb.unsqueeze(1)

        family = params.distribution
        if family == 'Normal':
            loc = first_half.unsqueeze(0)
            scale = (0.5 * second_half).exp().unsqueeze(0)
            density = D.normal.Normal(loc, scale)
        elif family == 'Beta':
            concentration1 = torch.clamp(first_half, 0.05, 1e9).unsqueeze(0)
            concentration0 = torch.clamp(second_half, 0.05, 1e9).unsqueeze(0)
            density = D.beta.Beta(concentration1, concentration0)
            samples = torch.sigmoid(samples)
        else:
            raise ValueError('bad type {}'.format(params.distribution))

        return density.log_prob(samples).sum(-1)

class NTXentProbLoss(NTXentLoss):
    """NT-Xent loss over a ``ProbDistance`` matrix between two embedding sets.

    Unlike the parent class, ``forward`` takes the program and IO embeddings
    separately and threads both through ``compute_loss``.
    """

    # Class-level default so get_default_distance() can resolve the attribute
    # even if the base-class constructor calls it before __init__ has finished.
    distribution = D.normal.Normal

    def __init__(self, temperature=0.07, **kwargs):
        super().__init__(temperature=temperature, **kwargs)
        self.distribution = D.normal.Normal

    def get_default_distance(self):
        return ProbDistance(self.distribution)

    def compute_loss(self, program_emb, io_emb, labels, indices_tuple):
        """Build the pairwise matrix and hand it to the pair-based loss.

        Previously this method took a single ``embeddings`` argument yet read
        the undefined names ``program_emb`` / ``io_emb``; it now receives both
        embedding sets, matching the call made by ``forward``.
        """
        indices_tuple = lmu.convert_to_pairs(indices_tuple, labels)
        if all(len(x) <= 1 for x in indices_tuple):
            return self.zero_losses()
        mat = self.distance(program_emb, io_emb)
        return self.loss_method(mat, labels, indices_tuple)

    def forward(self, program_emb, io_emb, labels, indices_tuple=None):
        """
        Args:
            program_emb: tensor of size (batch_size, embedding_size)
            io_emb: tensor holding the distribution parameters, one row per
                    sample
            labels: tensor of size (batch_size)
            indices_tuple: tuple of size 3 for triplets (anchors, positives, negatives)
                            or size 4 for pairs (anchor1, postives, anchor2, negatives)
                            Can also be left as None
        Returns: the loss
        """
        self.reset_stats()
        # Keep labels on the same device as the embeddings.
        labels = labels.to(program_emb.device)
        loss_dict = self.compute_loss(program_emb, io_emb, labels, indices_tuple)
        # NOTE(review): the undefined name `embeddings` was used below; the
        # program embeddings are used in its place - confirm this is the
        # intended target for embedding regularisation and the reducer.
        self.add_embedding_regularization_to_loss_dict(loss_dict, program_emb)
        return self.reducer(loss_dict, program_emb, labels)

class NTXentProbLoss2(nn.Module):
    """Cross-entropy contrastive loss over a ``ProbDistance`` score matrix."""

    def __init__(self, temperature=0.1):
        super().__init__()
        self.distribution = None
        self.distance = ProbDistance(self.distribution)
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def nt_xent_loss(self, program_emb, io_emb, eps=1e-6):
        """Contrast each program with its own IO embedding against all others.

        When a process group is initialised, both inputs are first gathered
        from every worker (with gradient support through ``GatherLayer``).
        The diagonal of the score matrix supplies the positive logit of each
        row; the remaining entries of that row are the negatives. ``eps`` is
        accepted but not used.

        Args:
            program_emb: per-process program embeddings, first dim is batch.
            io_emb: per-process IO embeddings, first dim is batch.
        """
        gather_needed = torch.distributed.is_available() and torch.distributed.is_initialized()
        if gather_needed:
            program_all = torch.cat(GatherLayer.apply(program_emb), 0)
            io_all = torch.cat(GatherLayer.apply(io_emb), 0)
        else:
            program_all, io_all = program_emb, io_emb

        scores = self.distance.compute_mat(program_all, io_all)
        rows = scores.shape[0]
        on_diagonal = torch.eye(rows, dtype=torch.bool).to(scores.device)

        positives = (scores[on_diagonal] / self.temperature).view(rows, 1)
        negatives = (scores[~on_diagonal] / self.temperature).view(rows, -1)

        # Column 0 holds the positive, so the target class is always 0.
        logits = torch.cat([positives, negatives], -1)
        targets = torch.zeros(positives.shape[0]).to(positives.device).long()
        return self.criterion(logits, targets)

    def forward(self, program_emb, io_emb):
        """Return the NT-Xent loss for the given embedding pair."""
        return self.nt_xent_loss(program_emb, io_emb)

class SyncFunction(torch.autograd.Function):
    """All-gather along dim 0 whose backward sums gradients across workers."""

    @staticmethod
    def forward(ctx, tensor):
        # Remember the local batch size so backward can slice out its share.
        ctx.batch_size = tensor.shape[0]

        world = torch.distributed.get_world_size()
        buffers = [torch.zeros_like(tensor) for _ in range(world)]
        torch.distributed.all_gather(buffers, tensor)
        return torch.cat(buffers, 0)

    @staticmethod
    def backward(ctx, grad_output):
        summed = grad_output.clone()
        torch.distributed.all_reduce(summed, op=torch.distributed.ReduceOp.SUM, async_op=False)

        # Keep only the rows that correspond to this worker's inputs.
        start = torch.distributed.get_rank() * ctx.batch_size
        stop = start + ctx.batch_size
        return summed[start:stop]

class GatherLayer(torch.autograd.Function):
    """Gather tensors from all processes while keeping backward support.

    ``forward`` returns a tuple with one tensor per worker; ``backward``
    returns the gradient of the entry that belongs to the calling rank.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        world = dist.get_world_size()
        gathered = [torch.zeros_like(input) for _ in range(world)]
        dist.all_gather(gathered, input)
        return tuple(gathered)

    @staticmethod
    def backward(ctx, *grads):
        (local_input,) = ctx.saved_tensors
        # Copy into a buffer shaped like the input.
        local_grad = torch.zeros_like(local_input)
        local_grad[:] = grads[dist.get_rank()]
        return local_grad

class HellingerLoss(nn.Module):
    """Similarity ``exp(-KL)`` between two diagonal Gaussians.

    NOTE: despite the class and method names, the value computed is the
    Gaussian KL divergence KL(N(alpha1, exp(beta1)) || N(alpha2, exp(beta2))),
    not the Hellinger distance.
    """

    def __init__(self):
        super().__init__()

    def hellinger_distance(self, alpha1, beta1, alpha2, beta2, eps=1e-6):
        """Return the KL divergence summed over elements, per first-dim entry.

        Args:
            alpha1: Means of the first distribution.
            beta1: Log-variances of the first distribution.
            alpha2: Means of the second distribution.
            beta2: Log-variances of the second distribution.
            eps: Unused; kept for interface compatibility.

        Returns:
            Scalar tensor: the element-wise KL summed over all elements and
            divided by ``alpha1.size(0)`` (plus a 1e-8 guard).
        """
        numerator = beta1.exp() + torch.pow(alpha1 - alpha2, 2)
        # 1e-8 keeps the division finite when exp(beta2) underflows to zero.
        fraction = torch.div(numerator, (beta2.exp() + 1e-8))
        kl = 0.5 * torch.sum(beta2 - beta1 + fraction - 1)
        return kl / (alpha1.size(0) + 1e-8)

    def forward(self, alpha1, beta1, alpha2, beta2):
        """Map the divergence to ``exp(-divergence)``; identical inputs give 1."""
        distance = self.hellinger_distance(alpha1, beta1, alpha2, beta2)
        return torch.exp(-distance)