from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import json
import os
import random
import multiprocessing
import time 
from pathlib import Path
from tqdm import tqdm
import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F

import params
from model.combinar_MI6 import CombinarMI
from cuda import use_cuda, LongTensor, FloatTensor
from env.env import ProgramEnv
from env.operator import Operator, operator_to_index
from env.statement import Statement, statement_to_index
from dsl.program import Program
from dsl.example import Example

from utils_hd import *

from sklearn import preprocessing

from pytorch_metric_learning.losses import NTXentLoss

from dsl import constraint
from dsl.value import NULLVALUE

# Seed every random number generator this script draws from (PyTorch CPU,
# all CUDA devices, NumPy and Python's `random`) with the configured seed.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
    _seed_fn(params.seed)
del _seed_fn


def generate_prog_data_back(line):
    """Legacy variant of the dataset encoder that emits one sample per program step.

    Parses one JSON dataset line, replays its program statement by statement in
    a ``ProgramEnv``, and records for every step the environment encoding taken
    *before* the step together with that step's targets.

    Args:
        line: one JSON-encoded dataset line holding the IO examples and a
            ``'program'`` string.

    Returns:
        A 10-tuple of per-step lists ``(inputs, statements, drop, operators,
        programs, steps, input_nums, types, program_seq, program_lengths)``.
        ``program_seq`` repeats the whole-program token sequence
        ``[1] + [statement_index + 3, ...] + [2]`` followed by ``0`` padding up
        to 4 statements, once per step. Programs longer than 4 statements are
        neither padded nor truncated. ``program_lengths`` repeats
        ``len(statements) + 2`` once per step. All lists are empty when the
        program has no statements.

    Raises:
        ValueError: if a program input/output type is not LIST, INT or NULL.
    """
    data = json.loads(line.rstrip())
    examples = Example.from_line(data)
    env = ProgramEnv(examples)
    program = Program.parse(data['program'])

    type_onehot = {'INT': [1, 0, 0], 'LIST': [0, 1, 0], 'NULL': [0, 0, 1]}

    inputs = []
    statements = []
    drop = []
    operators = []
    programs = []
    steps = []
    input_nums = []
    types = []

    for i, statement in enumerate(program.statements):
        # Environment encoding seen before this step is executed.
        inputs.append(env.get_encoding())

        # Translate absolute indices to post-drop indices
        f, args = statement.function, list(statement.args)
        for j, arg in enumerate(args):
            if isinstance(arg, int):
                args[j] = env.real_var_idxs.index(arg)

        statement = Statement(f, args)
        statements.append(statement_to_index[statement])

        # Variables referenced by this or any later statement must be kept.
        used_args = []
        for next_statement in program.statements[i:]:
            used_args += [x for x in next_statement.args if isinstance(x, int)]

        # 1 marks a slot that is empty or no longer referenced (droppable).
        to_drop = []
        for j in range(params.max_program_vars):
            if j >= env.num_vars or env.real_var_idxs[j] not in used_args:
                to_drop.append(1)
            else:
                to_drop.append(0)

        drop.append(to_drop)
        # Drawn on every step, even when the env still has room, so the global
        # RNG is consumed the same way on each step.
        rand_idx = random.choice([j for j in range(len(to_drop)) if to_drop[j] > 0])

        operator = Operator.from_statement(statement)
        operators.append(operator_to_index[operator])

        if env.num_vars < params.max_program_vars:
            env.step(statement)
        else:
            # Choose a random var (that is not used anymore) to drop.
            env.step(statement, rand_idx)

        programs.append(str(program))
        steps.append(i)
        input_nums.append(len(program.input_types))
        typ = program.input_types + (3 - len(program.input_types)) * ['NULL'] + program.var_types[-1:]
        typ_token = []
        for item in typ:
            key = str(item)
            if key not in type_onehot:
                raise ValueError('bad type {}'.format(item))
            typ_token.append(list(type_onehot[key]))
        types.append([typ_token] * 5)  # 5 ios same types

    # Use the number of recorded steps rather than the leaked loop variable `i`,
    # which is undefined when the program has no statements.
    num_steps = len(statements)
    seq = [1] + [idx + 3 for idx in statements] + [2] + [0] * (4 - num_steps)
    # Independent copy per step; `[seq] * n` would alias a single list object.
    program_seq = [list(seq) for _ in range(num_steps)]
    program_lengths = [num_steps + 2] * num_steps
    return inputs, statements, drop, operators, programs, steps, input_nums, types, program_seq, program_lengths

def generate_prog_data(line):
    """Turn one JSON dataset line into a single whole-program training sample.

    The program is replayed statement by statement inside a ``ProgramEnv``.
    Per-step targets are gathered along the way, but only the entries of the
    first step are returned. They are paired with the environment encoding
    taken *after* the final statement has executed.

    Args:
        line: JSON-encoded dataset line holding the IO examples and a
            ``'program'`` string.

    Returns:
        A 10-tuple ``(inputs, statements, drop, operators, programs, steps,
        input_nums, types, program_seq, program_lengths)``. The first eight
        lists are sliced to at most ``len(program_seq)`` (== 1) entries.
        ``program_seq`` holds one token sequence
        ``[1] + [statement_index + 3, ...] + [2]`` zero-padded up to
        ``params.max_prog_len`` statements. ``program_lengths`` holds that
        sequence's unpadded length (statement count + 2).

    Raises:
        ValueError: if a program input/output type is not LIST, INT or NULL.
    """
    record = json.loads(line.rstrip())
    env = ProgramEnv(Example.from_line(record))
    program = Program.parse(record['program'])

    onehot = {'INT': [1, 0, 0], 'LIST': [0, 1, 0], 'NULL': [0, 0, 1]}

    pre_step_encodings = []  # collected every step but not returned
    stmt_ids, drop_masks, op_ids = [], [], []
    program_strs, step_ids, input_counts, type_tokens = [], [], [], []

    for step, raw_stmt in enumerate(program.statements):
        pre_step_encodings.append(env.get_encoding())

        # Re-express variable references relative to the env's current slots.
        func = raw_stmt.function
        remapped = [env.real_var_idxs.index(a) if isinstance(a, int) else a
                    for a in raw_stmt.args]
        stmt = Statement(func, remapped)
        stmt_ids.append(statement_to_index[stmt])

        # Every variable still read by this or a later statement.
        still_needed = [a
                        for later in program.statements[step:]
                        for a in later.args
                        if isinstance(a, int)]

        mask = [1 if (k >= env.num_vars or env.real_var_idxs[k] not in still_needed) else 0
                for k in range(params.max_program_vars)]
        drop_masks.append(mask)
        victim = random.choice([k for k, flag in enumerate(mask) if flag > 0])

        op_ids.append(operator_to_index[Operator.from_statement(stmt)])

        if env.num_vars >= params.max_program_vars:
            # No free slot left: overwrite a variable nothing later reads.
            env.step(stmt, victim)
        else:
            env.step(stmt)

        program_strs.append(str(program))
        step_ids.append(step)
        input_counts.append(len(program.input_types))
        padded_types = (program.input_types
                        + ['NULL'] * (3 - len(program.input_types))
                        + program.var_types[-1:])
        tokens = []
        for item in padded_types:
            key = str(item)
            if key not in onehot:
                raise ValueError('bad type {}'.format(item))
            tokens.append(list(onehot[key]))
        type_tokens.append([tokens] * 5)  # identical types for all 5 IO examples

    final_encoding = [env.get_encoding()]
    seq = [1] + [sid + 3 for sid in stmt_ids] + [2] + [0] * (params.max_prog_len - len(stmt_ids))
    program_seq = [seq]
    program_lengths = [len(stmt_ids) + 2]
    keep = len(program_seq)
    return (final_encoding[:keep], stmt_ids[:keep], drop_masks[:keep], op_ids[:keep],
            program_strs[:keep], step_ids[:keep], input_counts[:keep], type_tokens[:keep],
            program_seq, program_lengths)

def generate_prog_data2(line):
    """Encode one dataset line with one program token sequence per statement prefix.

    Used for data augmentation. The program is replayed step by step exactly as
    in the other encoders. For every step ``i`` the returned ``program_seq``
    contains the sequence of the first ``i + 1`` statements, i.e. every
    program prefix.

    Args:
        line: JSON-encoded dataset line holding the IO examples and a
            ``'program'`` string.

    Returns:
        A 10-tuple ``(inputs, statements, drop, operators, programs, steps,
        input_nums, types, program_seq, program_lengths)``. All are per-step
        lists. ``program_seq[i]`` is ``[1] + [idx + 3 for idx in prefix] + [2]``
        zero-padded up to ``params.max_prog_len`` statements.
        ``program_lengths[i]`` is ``len(prefix) + 2``.

    Raises:
        ValueError: if a program input/output type is not LIST, INT or NULL.
    """
    data = json.loads(line.rstrip())
    examples = Example.from_line(data)
    env = ProgramEnv(examples)
    program = Program.parse(data['program'])

    type_onehot = {'INT': [1, 0, 0], 'LIST': [0, 1, 0], 'NULL': [0, 0, 1]}

    inputs = []
    statements = []
    drop = []
    operators = []
    programs = []
    steps = []
    input_nums = []
    types = []

    for i, statement in enumerate(program.statements):
        inputs.append(env.get_encoding())

        # Translate absolute indices to post-drop indices
        f, args = statement.function, list(statement.args)
        for j, arg in enumerate(args):
            if isinstance(arg, int):
                args[j] = env.real_var_idxs.index(arg)

        statement = Statement(f, args)
        statements.append(statement_to_index[statement])

        # Variables referenced by this or any later statement must be kept.
        used_args = []
        for next_statement in program.statements[i:]:
            used_args += [x for x in next_statement.args if isinstance(x, int)]

        # 1 marks a slot that is empty or no longer referenced (droppable).
        to_drop = []
        for j in range(params.max_program_vars):
            if j >= env.num_vars or env.real_var_idxs[j] not in used_args:
                to_drop.append(1)
            else:
                to_drop.append(0)

        drop.append(to_drop)
        rand_idx = random.choice([j for j in range(len(to_drop)) if to_drop[j] > 0])

        operator = Operator.from_statement(statement)
        operators.append(operator_to_index[operator])

        if env.num_vars < params.max_program_vars:
            env.step(statement)
        else:
            # Choose a random var (that is not used anymore) to drop.
            env.step(statement, rand_idx)

        programs.append(str(program))
        steps.append(i)
        input_nums.append(len(program.input_types))
        typ = program.input_types + (3 - len(program.input_types)) * ['NULL'] + program.var_types[-1:]
        typ_token = []
        for item in typ:
            key = str(item)
            if key not in type_onehot:
                raise ValueError('bad type {}'.format(item))
            typ_token.append(list(type_onehot[key]))
        types.append([typ_token] * 5)  # 5 ios same types

    # One sequence per prefix. Slicing gives each prefix its own list. The old
    # code appended the same (growing) list every step, and then encoded the
    # full program instead of the prefix. Padding uses params.max_prog_len to
    # match generate_prog_data (previously a hard-coded 12).
    program_seq = []
    program_lengths = []
    for k in range(1, len(statements) + 1):
        prefix = statements[:k]
        seq = [1] + [idx + 3 for idx in prefix] + [2] + [0] * (params.max_prog_len - len(prefix))
        program_seq.append(seq)
        # Unpadded length: prefix statements plus start/end tokens.
        program_lengths.append(len(prefix) + 2)
    return inputs, statements, drop, operators, programs, steps, input_nums, types, program_seq, program_lengths

def load_data(fileobj, max_len):
    """Read a dataset file and convert every line into stacked numpy arrays.

    Args:
        fileobj: open text file with one JSON program record per line.
        max_len: if not ``None``, randomly sample this many lines without
            replacement. A value larger than the number of lines uses every
            line (in random order) instead of raising ``ValueError``.

    Returns:
        Ten numpy arrays ``(X, Y, Z, W, P, S, N, T, SEQ, L)``: environment
        encodings, statement targets, drop masks, operator targets, program
        strings, step indices, input counts, type one-hots, program token
        sequences and program lengths. Each is the concatenation, over all
        selected lines, of the matching list returned by ``generate_prog_data``.
    """
    X, Y, Z, W, P, S, N, T, SEQ, L = ([] for _ in range(10))

    lines = fileobj.read().splitlines()
    if max_len is not None:
        # Clamp: random.sample raises ValueError when k exceeds the population.
        lines = random.sample(lines, min(max_len, len(lines)))

    for line in tqdm(lines, total=len(lines)):
        (inputs, targets, to_drop, operators, programs, steps,
         input_nums, types, program_seq, program_lengths) = generate_prog_data(line)
        X += inputs
        Y += targets
        Z += to_drop
        W += operators
        P += programs
        S += steps
        N += input_nums
        T += types
        SEQ += program_seq
        L += program_lengths

    return np.array(X), np.array(Y), np.array(Z), np.array(W), np.array(P), np.array(S), np.array(N), np.array(T), np.array(SEQ), np.array(L)


def train():
    params.gpus = init_distributed_mode() 
    print(params.gpus)
    # Define paths for storing tensorboard logs
#    date = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
#    save_dir = params.model_output_path + '/' + date + '/PE_model/'
    if not params.load_from_checkpoint:
        date = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
        save_dir = Path(params.model_output_path + date)
    else:
        save_dir = Path('/'.join(params.load_from_checkpoint.split('/')[:-1]))
    print('saved to:', save_dir)
    if not save_dir.exists():
        if dist.get_rank() == 0:
            os.makedirs(str(save_dir))
        dist.barrier()
    if dist.get_rank() == 0:
        tb_writer = SummaryWriter(save_dir)

    print('here1')
    # Load train and val data
    with open(params.train_path, 'r') as f:
        train_data, train_statement_target, train_drop_target, train_operator_target, train_program, train_step, train_input_num, train_typ, train_program_seq, train_plengths = load_data(f, params.max_len)
    print('here2')
    with open(params.val_path, 'r') as f:
        val_data, val_statement_target, val_drop_target, val_operator_target, val_program, val_step, val_input_num, val_typ, val_program_seq, val_plengths = load_data(f, params.max_len)

    print(train_program[0:1])

    le = preprocessing.LabelEncoder()
    programs = le.fit_transform(np.concatenate([train_program, val_program]))
    train_program = programs[:len(train_program)]
    val_program = programs[len(train_program):]

    print(train_program[0:1])
    print(le.inverse_transform(train_program[0:1]))
    # Define model
    
    model = CombinarMI(le)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    device = torch.cuda.current_device()
    model = model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[params.gpus], find_unused_parameters=False)

    #Define optimizer and loss
#    optimizer = torch.optim.Adam(model.parameters(), lr=params.learn_rate)
    main_params = model.module.main_parameters()
    info_params = model.module.info_parameters()
    main_learning_rate = params.learn_rate
    info_learning_rate = params.learn_rate
    main_optimizer = torch.optim.Adam(main_params, lr=main_learning_rate)
    info_optimizer = torch.optim.Adam(info_params, lr=info_learning_rate)
#    scheduler = ReduceLROnPlateau(optimizer=main_optimizer, mode='min',
#                                  factor=0.1, patience=args.patience,
#                                  verbose=True, min_lr=1e-7)
#    info_scheduler = ReduceLROnPlateau(optimizer=info_optimizer, mode='min',
#                                       factor=0.1, patience=args.patience,
#                                       verbose=True, min_lr=1e-7)

    models = [model]
    models_name = ['model']
    optims = [main_optimizer]
    optims_name = ['main_optimizer']
#    optims = [main_optimizer, info_optimizer]
#    optims_name = ['main_optimizer', 'info_optimizer']

    print(main_optimizer.state_dict()['param_groups'][0]['lr'])
    start_epoch = 0
    best_val_loss = np.inf
    if params.load_from_checkpoint:
        print("=> loading checkpoint '{}'".format(params.load_from_checkpoint))
        loaded_models, loaded_opts, start_epoch, best_val_loss = \
            load_checkpoint(models, optims, models_name, optims_name, params.load_from_checkpoint)
        model = loaded_models[0]
        main_optimizer = loaded_opts[0]

    print(main_optimizer.state_dict()['param_groups'][0]['lr'])

    main_lr_sched = torch.optim.lr_scheduler.StepLR(main_optimizer, step_size=params.lr_scheduler_step_size)
    info_lr_sched = torch.optim.lr_scheduler.StepLR(info_optimizer, step_size=params.lr_scheduler_step_size)

    statement_criterion = nn.CrossEntropyLoss()
    drop_criterion = nn.BCELoss()
    operator_criterion = nn.CrossEntropyLoss()
    l2_criterion = nn.MSELoss()
#    info_citerion = NTXentLoss(temperature=0.07)
    info_citerion = NTXentDistLoss(temperature=0.07)
    ns_citerion = NSLoss(gamma=params.ns_gamma)
    ns_citerion_base = NSLoss(gamma=12)
    prob_citerion = NTXentProbLoss2(temperature=0.07)
    latent_criterion = nn.CrossEntropyLoss()
    hellinger_criterion = HellingerLoss()


    #Convert to appropriate types
    # The cuda types are not used here on purpose - most GPUs can't handle so much memory
    train_data, train_statement_target, train_drop_target, train_operator_target, train_program, train_step, train_input_num, train_typ, train_program_seq, train_plengths = \
        torch.LongTensor(train_data), torch.LongTensor(train_statement_target), \
        torch.FloatTensor(train_drop_target), torch.LongTensor(train_operator_target), \
        torch.LongTensor(train_program).view(-1, 1), torch.LongTensor(train_step), torch.LongTensor(train_input_num), torch.LongTensor(train_typ), torch.LongTensor(train_program_seq), torch.LongTensor(train_plengths)
    val_data, val_statement_target, val_drop_target, val_operator_target, val_program, val_step, val_input_num, val_typ, val_program_seq, val_plengths = \
        torch.LongTensor(val_data), torch.LongTensor(val_statement_target), \
        torch.FloatTensor(val_drop_target), torch.LongTensor(val_operator_target), \
        torch.LongTensor(val_program).view(-1, 1), torch.LongTensor(val_step), torch.LongTensor(val_input_num), torch.LongTensor(val_typ), torch.LongTensor(val_program_seq), torch.LongTensor(val_plengths)


    val_data = Variable(val_data.type(LongTensor))
    val_statement_target = Variable(val_statement_target.type(LongTensor))
    val_drop_target = Variable(val_drop_target.type(FloatTensor))
    val_operator_target = Variable(val_operator_target.type(LongTensor))

    train_dataset = TensorDataset(train_data, train_statement_target, train_drop_target, train_operator_target, train_program, train_step, train_input_num, train_typ, train_program_seq, train_plengths)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_data_loader = DataLoader(train_dataset, batch_size=params.batch_size,
                            shuffle=False, pin_memory=False, sampler=train_sampler, num_workers=0)

    val_dataset = TensorDataset(val_data, val_statement_target, val_drop_target, val_operator_target, val_program, val_step, val_input_num, val_typ, val_program_seq, val_plengths)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_data_loader = DataLoader(val_dataset, batch_size=params.batch_size,
                            shuffle=False, pin_memory=False, sampler=val_sampler, num_workers=0)

    x_s = torch.zeros(params.batch_size, 1, 4, params.max_list_len, device='cuda').long() + params.integer_range
    patience_ctr = 0
    for epoch in range(start_epoch, params.num_epochs):
        train_sampler.set_epoch(epoch)
        model.train()
        print("Epoch %d" % epoch)

        train_statement_losses, train_drop_losses, train_operator_losses = [], [], []
        train_z_n_kl_loss, train_t_n_kl_loss, train_z_t_kl_loss, train_recon_io_loss, train_recon_p_loss = [], [], [], [], []
        batch_idx = 0
        for batch in tqdm(train_data_loader):
            global_step = int(batch_idx + epoch * (len(train_data) / params.batch_size / dist.get_world_size()))
            if dist.get_rank() == 0:
                #tb_writer.add_scalar("MI/infonce", loss.item(), global_step)
                if global_step % 20000 == 0:
                    programs = [
                        'LIST|REVERSE,0',
                        'LIST|REVERSE,0|REVERSE,1',
                        'LIST|REVERSE,0|REVERSE,1|REVERSE,2',
                        'LIST|COUNT,<0,0',
                        'LIST|MAP,*-1,0|COUNT,>0,1',
                        'LIST|COUNT,<0,0|TAKE,2,0',
                        'LIST|MINIMUM,0',
                        'LIST|SORT,0|HEAD,1',
                        'LIST|SORT,0',
                    ]
#                    tb_writer.add_embedding(io_features, global_step=global_step)
                    program_seq = []
                    plengths = []
                    for p in programs:
                        program = Program.parse(p)

                        statements = []
                        program_lengths = []

                        for i, statement in enumerate(program.statements):
                            f, args = statement.function, list(statement.args)
                            statement = Statement(f, args)
                            statements.append(statement_to_index[statement])

                        seq = [1] + [idx + 3 for idx in statements] + [2] + [0] * (4 - len(statements))
                        lengths = len(statements) + 2
                        program_seq.append(seq)
                        plengths.append(lengths)
                    program_seq = torch.LongTensor(program_seq).cuda()
                    plengths = torch.LongTensor(plengths).cuda()

                    model.eval()
                    with torch.no_grad():
                        program_features = model.module.query.encode_program(program_seq, plengths)
                        tb_writer.add_embedding(program_features, global_step=global_step, metadata=programs)
                    #    for batch in val_data_loader:
                    #        x = Variable(batch[0].type(LongTensor))
                    #        y = Variable(batch[1].type(LongTensor))
                    #        z = Variable(batch[2].type(FloatTensor))
                    #        w = Variable(batch[3].type(LongTensor))
                    #        p = Variable(batch[4].type(LongTensor))
                    #        s = Variable(batch[5].type(LongTensor))
                    #        n = Variable(batch[6].type(LongTensor))
                    #        t = Variable(batch[7].type(LongTensor))
                    #        program_seq = Variable(batch[8].type(LongTensor))
                    #        plengths = Variable(batch[9].type(LongTensor))
                    #        program_features = model.module.query.encode_program(program_seq, plengths)
                    #        break
                    #    tb_writer.add_embedding(program_features, global_step=global_step, metadata=le.inverse_transform(p.tolist()))
                    model.train()

            query_num = random.randint(1, int(min(5, epoch / 2 + 1)))
#            query_num = 1
#            query_num = random.randint(1, 5)
            x = Variable(batch[0].type(LongTensor))
            y = Variable(batch[1].type(LongTensor))
            z = Variable(batch[2].type(FloatTensor))
            w = Variable(batch[3].type(LongTensor))
            p = Variable(batch[4].type(LongTensor))
            s = Variable(batch[5].type(LongTensor))
            n = Variable(batch[6].type(LongTensor))
            t = Variable(batch[7].type(LongTensor))
            program_seq = Variable(batch[8].type(LongTensor))
            plengths = Variable(batch[9].type(LongTensor))

#            torch.set_printoptions(profile="full")            
#            print(x)

            main_optimizer.zero_grad()
            info_optimizer.zero_grad()
            # mutual information loss
            z_n_kl_loss = torch.tensor([0.]).cuda()
            t_n_kl_loss = torch.tensor([0.]).cuda()
            z_t_kl_loss = torch.tensor([0.]).cuda()
            recon_io_loss = torch.tensor([0.]).cuda()
            recon_p_loss = torch.tensor([0.]).cuda()
            # ps loss
            statement_loss = torch.tensor([0.]).cuda()
            drop_loss = torch.tensor([0.]).cuda()
            operator_loss = torch.tensor([0.]).cuda()

            loss = torch.tensor([0.]).cuda()
            info_loss = torch.tensor([0.]).cuda()
            #TODO: distance loss
            distance = []
            var = None
            var_types = None

            random_io = False 
            if random_io:
                example_batch = []
                p_str = le.inverse_transform(p.view(-1).tolist())
                for prog_sp in p_str:
                    prog_sp = Program.parse(prog_sp.rstrip())
                    input_output_examples = constraint.get_input_output_examples(prog_sp, num_examples=5,
                                                                                num_tries=1000)
                    example = []
                    if input_output_examples is None:
                       input_output_examples = [[[NULLVALUE], NULLVALUE]] * 5
                    for exp in input_output_examples:
                        inp, out = exp
                        encoded_inp = [var.encoded for var in inp]
                        encoded_out = out.encoded
                        if len(encoded_inp) < 3:
                            encoded_inp.extend([NULLVALUE.encoded] * (3 - len(encoded_inp)))
                        example.append(encoded_inp + [encoded_out])
                    example = torch.Tensor(example)
                    example_batch.append(example)
                x = torch.stack(example_batch, 0).cuda().long()

                num_examples = random.randint(1, 5) 
                typ = torch.cat([x[:,:num_examples,:3,:2], x[:,:num_examples,-1:,:2]], -2)
                x = torch.cat([x[:,:num_examples,:3,2:], x[:,:num_examples,-1:,2:]], -2)

            typ = t
            x = x_s[:x.shape[0]].clone() #torch.cat([t.unsqueeze(1)[:, :, :, :-1], x_s[:x.shape[0]].clone()], axis=-1)
            x = F.one_hot(x, params.integer_range + 1).float()
            batch_loss_infonce = 0.
            batch_loss_ns = 0.
            batch_loss_var_z = 0.
            batch_loss_var_t = 0.
            batch_loss_latent = 0.
            batch_loss_hellinger = 0.
            for query_step in range(query_num):
                #x = x.clone().detach()
                embedding = model.module.query.encode_io(x, typ[:, :int(max(1, query_step)), :, :2])
                mus_t, logvars_t = model.module.query.encode_into_t(embedding)
                mus_t, logvars_t = model.module.query.intersection(mus_t, logvars_t)
                #print(query_step)
                #print(mus_t)
                #print(logvars_t)
                #embedding = model.module.query.reparameterize(mus_t, logvars_t)
                embedding = torch.cat([mus_t, logvars_t], -1)
                #embedding = mus_t
                #embedding = embedding.mean(1)
                ################## latent code ###############
                if params.latent_code:
                    latent_targ = torch.Tensor([query_step] * x.shape[0]).cuda().view(-1, 1)
                    embedding = torch.cat([embedding, latent_targ], -1)
                ##############################################
                query_inp, query_index = model.module.query.decode_process(embedding, typ, params.hard_softmax)
                query_io, var_encoded, var_typ = model.module.env_step(query_index, query_inp, p, s, z)
                if query_step > 0:
                    x = torch.cat([x, query_io], 1)
                    var = torch.cat([var, var_encoded], 1)
                    var_types = torch.cat([var_types, var_typ], 1)
                else:
                    x = query_io
                    var = var_encoded
                    var_types = var_typ
                
                io_features = model.module.query.encode_io(x, typ[:, :int(max(1, query_step + 1)), :, :2])
                #io_features = model.module.query.encode_io(x, typ))
                #print('io_featrues', io_features)
                #print(latent_pred)
                #print(latent_targ)
                #print(latent_pred.shape)
                #print(latent_targ.shape)
                ############### f-space distribution ############
                if params.hellinger:# and query_step >= 1:
                    #mus_t_old, logvars_t_old = mus_t.clone().unsqueeze(1), logvars_t.clone().unsqueeze(1)
                    mus_t_old, logvars_t_old = mus_t.clone().detach().unsqueeze(1), logvars_t.clone().detach().unsqueeze(1)
                    mus_t, logvars_t = model.module.query.encode_into_t(io_features)
                    loss_hellinger = hellinger_criterion(mus_t[:, -1:], logvars_t[:, -1:], mus_t_old, logvars_t_old)
                    #loss_hellinger = -((logvars_t - logvars_t_old)**2).mean()
                    loss += loss_hellinger
                    batch_loss_hellinger += loss_hellinger.item()
                    mus_t, logvars_t = model.module.query.intersection(mus_t, logvars_t)
                else:
                    mus_t, logvars_t = model.module.query.encode_into_t(io_features)
                    mus_t, logvars_t = model.module.query.intersection(mus_t, logvars_t)
                #loss_hellinger = logvars_t.mean()
                #loss += loss_hellinger
                #################################################
                ################## latent code ################
                if params.latent_code:
                    #io_features = model.module.query.reparameterize(mus_t, logvars_t)
                    #io_features = mus_t
                    io_features = torch.cat([mus_t, logvars_t], -1)
                    latent_pred = model.module.query.latent_decoder(io_features)
                    loss_latent = latent_criterion(latent_pred, latent_targ.view(-1).long())
                    loss += loss_latent
                    batch_loss_latent += loss_latent.item()
                #################################################
                program_features = model.module.query.encode_program(program_seq, plengths)
                #print('program_features', program_features)
                #mus_z, logvars_z = model.module.query.encode_into_z2(program_features)

                io_features = torch.cat([mus_t, logvars_t], -1)
                #program_features = torch.cat([mus_z, logvars_z], -1)
                io_features = io_features.view(-1, params.dist_dim * 2)
                program_features = program_features.view(-1, params.dist_dim)

                ############### negative sampling loss ##############
                #io_idx = torch.cat([indices + (2 * offset) * program_features.shape[0] for offset in range(dist.get_world_size())])
                #program_idx = torch.cat([indices + (2 * offset + 1) * program_features.shape[0] for offset in range(dist.get_world_size())])
                ##loss_ns = ns_citerion(embeddings[io_idx], embeddings[program_idx])
                #loss_ns = ns_citerion(embeddings[program_idx], embeddings[io_idx])
                #loss_ns_base = ns_citerion_base(embeddings[program_idx], embeddings[io_idx])
                #####################################################

                ############### prob loss ##############
                loss_infonce = prob_citerion(program_features, io_features)
                print('loss_infonce', loss_infonce.item())
                ########################################

                loss_var_t = torch.mean(logvars_t.exp())
                #loss = 0.0001 * t_n_kl_loss + loss_infonce
                #print(loss_latent)
                #print(loss_infonce)
                #print(loss_var)
                loss += loss_infonce
                #loss += loss_ns
                #loss += 1 * loss_var_z
                batch_loss_infonce += loss_infonce.item()
                #batch_loss_ns += loss_ns_base.item()
                batch_loss_var_t += loss_var_t.item()

            
            if params.latent_code:
                batch_loss_latent /= query_num
            if params.hellinger:# and query_num > 1:
                batch_loss_hellinger /= query_num
            batch_loss_infonce /= query_num
            #batch_loss_ns /= query_num
            batch_loss_var_t /= query_num

            loss /= query_num
            loss.backward()
            main_optimizer.step()

            if dist.get_rank() == 0:
                tb_writer.add_scalar("MI/infonce", batch_loss_infonce, global_step)
                #tb_writer.add_scalar("MI/ns", batch_loss_ns, global_step)
                tb_writer.add_scalar("MI/var_t", batch_loss_var_t, global_step)
                #tb_writer.add_scalar("MI/kl", t_n_kl_loss.item(), global_step)
                tb_writer.add_scalar("MI/total", loss.item(), global_step)
                tb_writer.add_scalar("lr/lr", main_optimizer.state_dict()['param_groups'][0]['lr'], global_step)
                if params.latent_code:
                    tb_writer.add_scalar("MI/latent", batch_loss_latent, global_step)
                if params.hellinger and query_num > 1:
                    tb_writer.add_scalar("MI/hellinger", batch_loss_hellinger, global_step)
                if global_step % 20000 == 0:
                    x_index = x.argmax(-1)
                    img_batch = x_index.view(x_index.shape[0], 1, -1, x_index.shape[-1])
                    tb_writer.add_images('query', img_batch, global_step)

            batch_idx += 1
            #break
        #break

        main_lr_sched.step()

        model.eval()

        with torch.no_grad():
            query_num = 5

            total_statement_loss = 0.
            total_drop_loss = 0.
            total_operator_loss = 0.
            total_val_error = 0.
            total_loss = 0.
            for batch in tqdm(val_data_loader):
                x = Variable(batch[0].type(LongTensor))
                y = Variable(batch[1].type(LongTensor))
                z = Variable(batch[2].type(FloatTensor))
                w = Variable(batch[3].type(LongTensor))
                p = Variable(batch[4].type(LongTensor))
                s = Variable(batch[5].type(LongTensor))
                n = Variable(batch[6].type(LongTensor))
                t = Variable(batch[7].type(LongTensor))
                program_seq = Variable(batch[8].type(LongTensor))
                plengths = Variable(batch[9].type(LongTensor))


                typ = t
                x = x_s[:x.shape[0]].clone() #torch.cat([t.unsqueeze(1)[:, :, :, :-1], x_s[:x.shape[0]].clone()], axis=-1)
                x = F.one_hot(x, params.integer_range + 1).float()
                for query_step in range(query_num):
                    embedding = model.module.query.encode_io(x, typ[:, :int(max(1, query_step)), :, :2])
                    mus_t, logvars_t = model.module.query.encode_into_t(embedding)
                    mus_t, logvars_t = model.module.query.intersection(mus_t, logvars_t)
                    #embedding = model.module.query.reparameterize(mus_t, logvars_t)
                    #embedding = mus_t
                    embedding = torch.cat([mus_t, logvars_t], -1)
                    ############## latent code ###########
                    if params.latent_code:
                        latent_targ = torch.Tensor([query_step] * x.shape[0]).cuda().view(-1, 1)
                        embedding = torch.cat([embedding, latent_targ], -1)
                    ######################################
                    query_inp, query_index = model.module.query.decode_process(embedding, typ, params.hard_softmax)

                    query_io, var_encoded, var_typ = model.module.env_step(query_index, query_inp, p, s, z)
                    if query_step > 0:
                        x = torch.cat([x, query_io], 1)
                        var = torch.cat([var, var_encoded], 1)
                        var_types = torch.cat([var_types, var_typ], 1)
                    else:
                        x = query_io
                        var = var_encoded
                        var_types = var_typ
                

                io_features = model.module.query.encode_io(x, typ[:, :int(max(1, query_num)), :, :2])
                ############### f-space distribution ############
                mus_t, logvars_t = model.module.query.encode_into_t(io_features)
                mus_t, logvars_t = model.module.query.intersection(mus_t, logvars_t)
                #################################################
                program_features = model.module.query.encode_program(program_seq, plengths)
                #print('program_featrues', program_features)

                io_features = torch.cat([mus_t, logvars_t], -1)
                io_features = io_features.view(-1, params.dist_dim * 2)
                program_features = program_features.view(-1, params.dist_dim)

                #io_features = io_features.view(-1, params.dense_output_size)
                #program_features = program_features.repeat_interleave(query_step + 1, 0)

                ############### prob loss ##############
                loss_infonce = prob_citerion(program_features, io_features)
                ########################################
                total_loss += (loss_infonce * y.shape[0]).item()

            t1 = torch.tensor([total_loss], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(t1)
            total_loss = \
                t1.tolist()[0] / len(val_data)
            if dist.get_rank() == 0:
                tb_writer.add_scalar("val/total_loss", total_loss, epoch)

            if total_loss < best_val_loss:
                ckpt_path = save_dir / 'model-best'
                print("Found new best model")
                best_val_loss = total_loss
#                save(model, optimizer, epoch, params, save_dir)
                save_checkpoint(models, optims, models_name, optims_name, epoch, best_val_loss, ckpt_path)
                patience_ctr = 0
            else:
                patience_ctr += 1
                if patience_ctr == params.patience:
                    print("Ran out of patience. Stopping training early...")
                    break
            ckpt_path = save_dir / 'model-latest'
            save_checkpoint(models, optims, models_name, optims_name, epoch, best_val_loss, ckpt_path)

        ############## dump dataset #################
    model.eval()
    with torch.no_grad():
        if params.dump_dataset:
            if dist.get_rank() == 0:
                ckpt_path = save_dir / 'model-best'
                print("=> loading checkpoint '{}'".format(ckpt_path))
                loaded_models, loaded_opts, start_epoch, best_val_loss = \
                    load_checkpoint(models, optims, models_name, optims_name, ckpt_path)
                model = loaded_models[0]
                model.eval()
                ios = generate_ios(train_program, train_typ, model, train_step, train_drop_target)
                f = open(save_dir / 'train_gps', 'w')
                for item in ios:
                    problem = dict(program=item['program'], examples=item['examples'])
                    f.write(json.dumps(problem) + '\n')
                f.close()

                ios = generate_ios(val_program, val_typ, model, val_step, val_drop_target)
                f = open(save_dir / 'val_gps', 'w')
                for item in ios:
                    problem = dict(program=item['program'], examples=item['examples'])
                    f.write(json.dumps(problem) + '\n')
                f.close()
            dist.barrier()
        #############################################


if __name__ == '__main__':
    # Script entry point. Training only starts when this file is executed
    # directly (e.g. via a distributed launcher); importing the module for
    # its helper functions does not trigger a training run.
    train()