from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import json
import os
import random
import multiprocessing
import time 
from pathlib import Path
from tqdm import tqdm
import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import Variable
from torch import nn
import torch.nn.functional as F

import params
from model.model import PCCoder
from model.query import Query
from model.combinar import Combinar
from cuda import use_cuda, LongTensor, FloatTensor
from env.env import ProgramEnv
from env.operator import Operator, operator_to_index
from env.statement import Statement, statement_to_index
from dsl.program import Program
from dsl.example import Example

from utils_hd import *

from sklearn import preprocessing

def generate_prog_data(line):
    """Turn one JSON-encoded dataset line into per-statement training rows.

    The program in ``data['program']`` is replayed statement by statement in a
    ``ProgramEnv`` built from the line's examples. For every statement one
    entry is appended to each returned list:

    * ``inputs``     - ``env.get_encoding()`` taken *before* the statement runs.
    * ``statements`` - ``statement_to_index`` of the statement after its integer
      args are rewritten from absolute variable indices to the env's post-drop
      indices (``env.real_var_idxs.index(arg)``).
    * ``drop``       - 0/1 list of length ``params.max_program_vars``; 1 marks a
      slot that is empty (``j >= env.num_vars``) or whose variable is not used
      by this or any later statement.
    * ``operators``  - ``operator_to_index`` of the statement's operator.
    * ``programs``   - ``str(program)`` (same value on every row).
    * ``steps``      - the statement's position in the program.
    * ``input_nums`` - ``len(program.input_types)``.
    * ``types``      - one-hot type tokens (INT=[1,0,0], LIST=[0,1,0],
      NULL=[0,0,1]) for the program inputs padded with NULL to 3 entries,
      followed by the last entry of ``program.var_types`` (presumably the
      output type - confirm), repeated 5 times (one copy per IO example).

    Args:
        line: one line of the dataset file (JSON).

    Returns:
        Tuple ``(inputs, statements, drop, operators, programs, steps,
        input_nums, types)`` of equal-length lists.

    Raises:
        ValueError: if a type is not LIST, INT or NULL.
        IndexError: if a variable must be dropped but every slot is still in
            use (``random.choice`` on an empty candidate list).
    """
    data = json.loads(line.rstrip())
    examples = Example.from_line(data)
    env = ProgramEnv(examples)
    program = Program.parse(data['program'])

    inputs = []
    statements = []
    drop = []
    operators = []
    programs = []
    steps = []
    input_nums = []
    types = []

    # Everything below depends only on the program, not on the step, so it is
    # computed once instead of once per statement.
    program_str = str(program)
    num_inputs = len(program.input_types)
    type_one_hot = {'INT': [1, 0, 0], 'LIST': [0, 1, 0], 'NULL': [0, 0, 1]}
    type_names = program.input_types + (3 - num_inputs) * ['NULL'] + program.var_types[-1:]
    type_tokens = []
    for item in type_names:
        try:
            type_tokens.append(type_one_hot[str(item)])
        except KeyError:
            raise ValueError('bad type {}'.format(item)) from None

    for i, statement in enumerate(program.statements):
        inputs.append(env.get_encoding())

        # Translate absolute indices to post-drop indices
        f, args = statement.function, list(statement.args)
        for j, arg in enumerate(args):
            if isinstance(arg, int):
                args[j] = env.real_var_idxs.index(arg)

        statement = Statement(f, args)
        statements.append(statement_to_index[statement])

        # Absolute indices of variables needed by this or any later statement;
        # a set gives O(1) membership tests below.
        used_args = {arg for later in program.statements[i:]
                     for arg in later.args if isinstance(arg, int)}

        # A slot may be dropped if it is empty or its variable is never used again.
        to_drop = [
            0 if j < env.num_vars and env.real_var_idxs[j] in used_args else 1
            for j in range(params.max_program_vars)
        ]
        drop.append(to_drop)

        operator = Operator.from_statement(statement)
        operators.append(operator_to_index[operator])

        if env.num_vars < params.max_program_vars:
            env.step(statement)
        else:
            # All slots are full: choose a random var (that is not used anymore)
            # to drop. Only drawn here, where it is actually needed.
            drop_candidates = [j for j, flag in enumerate(to_drop) if flag > 0]
            env.step(statement, random.choice(drop_candidates))

        programs.append(program_str)
        steps.append(i)
        input_nums.append(num_inputs)
        types.append([type_tokens] * 5)  # the 5 IO examples share the same types
    return inputs, statements, drop, operators, programs, steps, input_nums, types


def load_data(fileobj, max_len, processes=1):
    """Read a dataset file and expand every program into per-statement rows.

    Each line is processed by ``generate_prog_data`` in a worker pool, and the
    per-line lists are concatenated column by column.

    Args:
        fileobj: open text file with one JSON-encoded program per line.
        max_len: if not None, use a random sample of this many lines. If the
            file has fewer lines, all of them are used (in random order)
            instead of ``random.sample`` raising ``ValueError``.
        processes: number of worker processes running ``generate_prog_data``.

    Returns:
        Tuple of eight numpy arrays, in ``generate_prog_data``'s order:
        encodings, statement targets, drop masks, operator targets, program
        strings, step indices, input counts, type tokens.
    """
    lines = fileobj.read().splitlines()
    if max_len is not None:
        lines = random.sample(lines, min(max_len, len(lines)))

    # The context manager shuts the pool down once all results are collected.
    # Previously the pool was never closed, leaving its worker process alive.
    with multiprocessing.Pool(processes=processes) as pool:
        results = list(tqdm(pool.imap(generate_prog_data, lines), total=len(lines)))

    columns = [[] for _ in range(8)]
    for row in results:
        for column, values in zip(columns, row):
            column.extend(values)

    return tuple(np.array(column) for column in columns)


def _batch_to_device(batch):
    """Cast one ``TensorDataset`` batch to the tensor types used by the model.

    The eight columns are returned in the order they were packed into the
    dataset: encoded inputs, statement target, drop target, operator target,
    program label, step index, number of inputs, type tokens. All columns are
    cast to ``LongTensor`` except the drop target, which is cast to
    ``FloatTensor`` because it is fed to ``BCELoss``.
    """
    x = Variable(batch[0].type(LongTensor))
    y = Variable(batch[1].type(LongTensor))
    z = Variable(batch[2].type(FloatTensor))
    w = Variable(batch[3].type(LongTensor))
    p = Variable(batch[4].type(LongTensor))
    s = Variable(batch[5].type(LongTensor))
    n = Variable(batch[6].type(LongTensor))
    t = Variable(batch[7].type(LongTensor))
    return x, y, z, w, p, s, n, t


def train():
    """Train the ``Combinar`` model with DistributedDataParallel.

    1. Initialise torch.distributed and pick the output directory: a fresh
       timestamped one, or the one named by ``params.load_from_checkpoint``
       when resuming.
    2. Load train/val data with ``load_data`` and label-encode the program
       strings jointly, so train and val share one label vocabulary.
    3. Each epoch, run ``query_num`` (random in [2, 5]) query rounds per
       training batch. Average the per-round statement / drop / operator /
       latent losses over the rounds.
    4. Validate with 5 rounds, scoring only the final round's predictions.
    5. Sum the metrics across ranks, log to TensorBoard on rank 0, and save
       ``model-best`` / ``model-latest`` checkpoints. Stop early after
       ``params.patience`` epochs without improvement.
    """
    params.gpus = init_distributed_mode()
    print(params.gpus)
    rank = dist.get_rank()

    # Define paths for storing tensorboard logs
    if not params.load_from_checkpoint:
        # NOTE(review): each rank computes its own timestamp. Ranks starting on
        # different sides of a minute boundary would pick different
        # directories; consider broadcasting the name from rank 0.
        date = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
        save_dir = Path(params.model_output_path + date)
    else:
        save_dir = Path(params.model_output_path + params.load_from_checkpoint)
    print('saved to:', save_dir)
    # Every rank creates the directory (exist_ok makes this race-free), and every
    # rank then enters the barrier. Previously only ranks that saw the
    # directory missing entered the barrier. A rank checking after rank 0 had
    # created it skipped the barrier, so the collectives went out of step.
    save_dir.mkdir(parents=True, exist_ok=True)
    dist.barrier()
    # Only rank 0 logs scalars, so only rank 0 needs an event writer.
    tb_writer = SummaryWriter(save_dir) if rank == 0 else None

    # Load train and val data
    with open(params.train_path, 'r') as f:
        train_data, train_statement_target, train_drop_target, train_operator_target, train_program, train_step, train_input_num, train_typ = load_data(f, params.max_len)

    with open(params.val_path, 'r') as f:
        val_data, val_statement_target, val_drop_target, val_operator_target, val_program, val_step, val_input_num, val_typ = load_data(f, params.max_len)

    print(train_program[0:1])

    # Encode program strings over train+val together so both splits share one
    # label vocabulary, then split back.
    n_train_programs = len(train_program)
    le = preprocessing.LabelEncoder()
    programs = le.fit_transform(np.concatenate([train_program, val_program]))
    train_program = programs[:n_train_programs]
    val_program = programs[n_train_programs:]

    print(train_program[0:1])
    print(le.inverse_transform(train_program[0:1]))

    # Define model
    model = Combinar(le)
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    device = torch.cuda.current_device()
    model = model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[params.gpus], find_unused_parameters=True)

    # Define optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), lr=params.learn_rate)

    models = [model]
    models_name = ['model']
    optims = [optimizer]
    optims_name = ['optimizer']

    start_epoch = 0
    best_val_error = np.inf
    if params.load_from_checkpoint:
        print("=> loading checkpoint '{}'".format(params.load_from_checkpoint))
        loaded_models, loaded_opts, start_epoch, best_val_error = \
            load_checkpoint(models, optims, models_name, optims_name, params.load_from_checkpoint)
        model = loaded_models[0]
        optimizer = loaded_opts[0]

    lr_sched = torch.optim.lr_scheduler.StepLR(optimizer, step_size=params.lr_scheduler_step_size, gamma=params.gamma)

    statement_criterion = nn.CrossEntropyLoss()
    drop_criterion = nn.BCELoss()
    operator_criterion = nn.CrossEntropyLoss()
    latent_criterion = nn.CrossEntropyLoss()

    # Convert to appropriate types.
    # The cuda types are not used here on purpose - most GPUs can't handle so much
    # memory. This applies to the validation set too: each batch is cast to the
    # device in _batch_to_device rather than moving the whole split up front.
    train_data, train_statement_target, train_drop_target, train_operator_target, train_program, train_step, train_input_num, train_typ = \
        torch.LongTensor(train_data), torch.LongTensor(train_statement_target), \
        torch.FloatTensor(train_drop_target), torch.LongTensor(train_operator_target), \
        torch.LongTensor(train_program).view(-1, 1), torch.LongTensor(train_step), torch.LongTensor(train_input_num), torch.LongTensor(train_typ)
    val_data, val_statement_target, val_drop_target, val_operator_target, val_program, val_step, val_input_num, val_typ = \
        torch.LongTensor(val_data), torch.LongTensor(val_statement_target), \
        torch.FloatTensor(val_drop_target), torch.LongTensor(val_operator_target), \
        torch.LongTensor(val_program).view(-1, 1), torch.LongTensor(val_step), torch.LongTensor(val_input_num), torch.LongTensor(val_typ)

    train_dataset = TensorDataset(train_data, train_statement_target, train_drop_target, train_operator_target, train_program, train_step, train_input_num, train_typ)
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_data_loader = DataLoader(train_dataset, batch_size=params.batch_size,
                                   shuffle=False, pin_memory=False, sampler=train_sampler, num_workers=0)

    val_dataset = TensorDataset(val_data, val_statement_target, val_drop_target, val_operator_target, val_program, val_step, val_input_num, val_typ)
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_data_loader = DataLoader(val_dataset, batch_size=params.batch_size,
                                 shuffle=False, pin_memory=False, sampler=val_sampler, num_workers=0)

    # Starting model input: every entry set to params.integer_range, shaped
    # (batch_size, 1, 4, max_list_len). It is one-hot encoded per batch below.
    x_s = torch.zeros(params.batch_size, 1, 4, params.max_list_len, device='cuda').long() + params.integer_range
    # Must exist before the first non-improving epoch (e.g. right after resuming).
    patience_ctr = 0
    for epoch in range(start_epoch, params.num_epochs):
        train_sampler.set_epoch(epoch)
        model.train()
        print("Epoch %d" % epoch)

        train_statement_losses, train_drop_losses, train_operator_losses = [], [], []
        train_latent_losses = []

        for batch in tqdm(train_data_loader):
            query_num = random.randint(2, 5)
            x, y, z, w, p, s, _, t = _batch_to_device(batch)

            optimizer.zero_grad()

            statement_loss = torch.tensor([0.]).cuda()
            drop_loss = torch.tensor([0.]).cuda()
            operator_loss = torch.tensor([0.]).cuda()
            latent_loss = torch.tensor([0.]).cuda()
            var = None
            var_types = None
            x = x_s[:x.shape[0]].clone()
            x = F.one_hot(x, params.integer_range + 1).float()
            for i in range(query_num):
                (pred_act, pred_drop, pred_operator, latent_pred), x, var, var_types = model(x, var, var_types, p, s, z, t, i)
                statement_loss += statement_criterion(pred_act, y)
                drop_loss += drop_criterion(pred_drop, z)
                operator_loss += operator_criterion(pred_operator, w)
                # The latent head is trained to predict the current round index.
                latent_targ = torch.full((x.shape[0],), i, dtype=torch.long, device=latent_pred.device)
                # Accumulate like the other losses. This was previously "=", so
                # only the last round counted and was then divided by query_num.
                latent_loss += latent_criterion(latent_pred, latent_targ)
            statement_loss /= query_num
            drop_loss /= query_num
            operator_loss /= query_num
            latent_loss /= query_num
            loss = statement_loss + operator_loss + drop_loss + latent_loss

            # Store batch-size-weighted sums; they are normalised by the dataset
            # size after the cross-rank reduction.
            train_statement_losses.append(statement_loss.item() * y.shape[0])
            train_drop_losses.append(drop_loss.item() * z.shape[0])
            train_operator_losses.append(operator_loss.item() * w.shape[0])
            train_latent_losses.append(latent_loss.item() * w.shape[0])

            loss.backward()
            optimizer.step()

        avg_statement_train_loss = np.array(train_statement_losses).sum()
        avg_drop_train_loss = np.array(train_drop_losses).sum()
        avg_operator_train_loss = np.array(train_operator_losses).sum()
        avg_latent_train_loss = np.array(train_latent_losses).sum()

        model.eval()

        with torch.no_grad():
            query_num = 5

            total_statement_loss = 0.
            total_drop_loss = 0.
            total_operator_loss = 0.
            total_val_error = 0.
            for batch in tqdm(val_data_loader):
                x, y, z, w, p, s, _, t = _batch_to_device(batch)

                var = None
                var_types = None
                x = x_s[:x.shape[0]].clone()
                x = F.one_hot(x, params.integer_range + 1).float()
                # Only the predictions of the final query round are scored.
                for i in range(query_num):
                    output, x, var, var_types = model(x, var, var_types, p, s, z, t, i)
                val_statement_pred = output[0]
                val_drop_pred = output[1]
                val_operator_pred = output[2]

                val_statement_loss = statement_criterion(val_statement_pred, y)
                val_drop_loss = drop_criterion(val_drop_pred, z)
                val_operator_loss = operator_criterion(val_operator_pred, w)

                total_statement_loss += (val_statement_loss * y.shape[0]).item()
                total_drop_loss += (val_drop_loss * z.shape[0]).item()
                total_operator_loss += (val_operator_loss * w.shape[0]).item()

                predict = val_statement_pred.data.max(1)[1]
                total_val_error += (predict != y.data).sum().item()

            lr_sched.step()

            # Sum every per-rank total across ranks with a single collective.
            totals = torch.tensor(
                [avg_statement_train_loss, avg_drop_train_loss, avg_operator_train_loss,
                 avg_latent_train_loss, total_statement_loss, total_drop_loss,
                 total_operator_loss, total_val_error],
                dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(totals)
            (train_statement_sum, train_drop_sum, train_operator_sum, train_latent_sum,
             val_statement_sum, val_drop_sum, val_operator_sum, val_error_sum) = totals.tolist()

            n_train, n_val = len(train_data), len(val_data)
            avg_statement_train_loss = train_statement_sum / n_train
            avg_drop_train_loss = train_drop_sum / n_train
            avg_operator_train_loss = train_operator_sum / n_train
            avg_latent_train_loss = train_latent_sum / n_train
            val_statement_loss = val_statement_sum / n_val
            val_drop_loss = val_drop_sum / n_val
            val_operator_loss = val_operator_sum / n_val
            val_error = val_error_sum / n_val

            print("Train loss: S %f" % avg_statement_train_loss, "D %f" % avg_drop_train_loss,
                  "F %f" % avg_operator_train_loss)
            print("Val loss: S %f" % val_statement_loss, "D %f" % val_drop_loss,
                  "F %f" % val_operator_loss)

            print("Val classification error: %f" % val_error)

            if rank == 0:
                tb_writer.add_scalar("loss/train_statement_loss", avg_statement_train_loss, epoch)
                tb_writer.add_scalar("loss/train_drop_loss", avg_drop_train_loss, epoch)
                tb_writer.add_scalar("loss/train_operator_loss", avg_operator_train_loss, epoch)
                tb_writer.add_scalar("loss/train_latent_loss", avg_latent_train_loss, epoch)

                tb_writer.add_scalar("loss/val_statement_loss", val_statement_loss, epoch)
                tb_writer.add_scalar("loss/val_drop_loss", val_drop_loss, epoch)
                tb_writer.add_scalar("loss/val_operator_loss", val_operator_loss, epoch)

                tb_writer.add_scalar("error/val_error", val_error, epoch)
                tb_writer.add_scalar("lr/lr", optimizer.state_dict()['param_groups'][0]['lr'], epoch)

            # val_error is identical on all ranks after the all_reduce, so every
            # rank takes the same branch (and breaks out together).
            # NOTE(review): save_checkpoint is called on every rank with the same
            # path. Unless utils_hd.save_checkpoint guards on rank itself, the
            # ranks write the same file concurrently - confirm.
            print('best_val_error:', best_val_error)
            if val_error < best_val_error:
                ckpt_path = save_dir / 'model-best'
                print("Found new best model")
                best_val_error = val_error
                save_checkpoint(models, optims, models_name, optims_name, epoch, best_val_error, ckpt_path)
                patience_ctr = 0
            else:
                patience_ctr += 1
                if patience_ctr == params.patience:
                    print("Ran out of patience. Stopping training early...")
                    break
            ckpt_path = save_dir / 'model-latest'
            save_checkpoint(models, optims, models_name, optims_name, epoch, best_val_error, ckpt_path)

if __name__ == "__main__":
    # Script entry point; train() sets up torch.distributed itself.
    train()



#def train():
#    params.gpus = init_distributed_mode() 
#    print(params.gpus)
#    # Define paths for storing tensorboard logs
##    date = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
##    save_dir = params.model_output_path + '/' + date + '/PE_model/'
#    if not params.load_from_checkpoint:
#        date = time.strftime('%Y%m%d%H%M', time.localtime(time.time()))
#        save_dir = Path(params.model_output_path + date)
#    else:
#        save_dir = Path(params.model_output_path + params.load_from_checkpoint)
#    print('saved to:', save_dir)
#    if not save_dir.exists():
#        if dist.get_rank() == 0:
#            os.makedirs(str(save_dir))
#            tb_writer = SummaryWriter(save_dir)
#        dist.barrier()
#
#    # Load train and val data
#    with open(params.train_path, 'r') as f:
#        train_data, train_statement_target, train_drop_target, train_operator_target = load_data(f, params.max_len)
#
#    with open(params.val_path, 'r') as f:
#        val_data, val_statement_target, val_drop_target, val_operator_target = load_data(f, params.max_len)
#
#    # Define model
#    
#    model = PCCoder()
#    query = Query()
#    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
#    query = torch.nn.SyncBatchNorm.convert_sync_batchnorm(query)
#
#    device = torch.cuda.current_device()
#    model = model.to(device)
#    query = query.to(device)
#    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[params.gpus], find_unused_parameters=True)
#    query = torch.nn.parallel.DistributedDataParallel(query, device_ids=[params.gpus], find_unused_parameters=True)
#
#    #Define optimizer and loss
#    optimizer_m = torch.optim.Adam(model.parameters(), lr=params.learn_rate)
#    optimizer_q = torch.optim.Adam(query.parameters(), lr=params.learn_rate)
#
#    models = [model, query]
#    models_name = ['model', 'query']
#    optims = [optimizer_m, optimizer_q]
#    optims_name = ['optimizer_m', 'optimizer_q']
#
#    start_epoch = 0
#    best_val_error = np.inf
#    if params.load_from_checkpoint:
#        print("=> loading checkpoint '{}'".format(params.checkpoint_dir))
#        loaded_models, loaded_opts, start_epoch, best_val_error = \
#            load_checkpoint(models, optims, models_name, optims_name, params.load_from_checkpoint)
#        model, query = loaded_models
#        optimizer_m, optimizer_q = loaded_opts
#
#    lr_sched_m = torch.optim.lr_scheduler.StepLR(optimizer_m, step_size=params.lr_scheduler_step_size)
#    lr_sched_q = torch.optim.lr_scheduler.StepLR(optimizer_q, step_size=params.lr_scheduler_step_size)
#
#    statement_criterion = nn.CrossEntropyLoss()
#    drop_criterion = nn.BCELoss()
#    operator_criterion = nn.CrossEntropyLoss()
#
#    #Convert to appropriate types
#    # The cuda types are not used here on purpose - most GPUs can't handle so much memory
#    train_data, train_statement_target, train_drop_target, train_operator_target = torch.LongTensor(train_data), torch.LongTensor(train_statement_target), \
#                                                    torch.FloatTensor(train_drop_target), torch.LongTensor(train_operator_target)
#    val_data, val_statement_target, val_drop_target, val_operator_target = torch.LongTensor(val_data), torch.LongTensor(val_statement_target), \
#                                                    torch.FloatTensor(val_drop_target), torch.LongTensor(val_operator_target)
#
#
#    val_data = Variable(val_data.type(LongTensor))
#    val_statement_target = Variable(val_statement_target.type(LongTensor))
#    val_drop_target = Variable(val_drop_target.type(FloatTensor))
#    val_operator_target = Variable(val_operator_target.type(LongTensor))
#
#    train_dataset = TensorDataset(train_data, train_statement_target, train_drop_target, train_operator_target)
#    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
#    train_data_loader = DataLoader(train_dataset, batch_size=params.batch_size, 
#                            shuffle=False, pin_memory=False, sampler=train_sampler, num_workers=0)
#
#    val_dataset = TensorDataset(val_data, val_statement_target, val_drop_target, val_operator_target)
#    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
#    val_data_loader = DataLoader(val_dataset, batch_size=params.batch_size, 
#                            shuffle=False, pin_memory=False, sampler=val_sampler, num_workers=0)
#
##    start_x = torch.zeros(batch_size, 1, 4, encoding.L, device='cuda').long() + constants.NULL
#    for epoch in range(start_epoch, params.num_epochs):
#        model.train()
#        query.train()
#        print("Epoch %d" % epoch)
#
#        train_statement_losses, train_drop_losses, train_operator_losses = [], [], []
#
#        for batch in tqdm(train_data_loader):
#            query_num = random.randint(1, int(min(5, epoch / 2 + 1)))
#            print(x)
#            x = Variable(batch[0].type(LongTensor))
#            y = Variable(batch[1].type(LongTensor))
#            z = Variable(batch[2].type(FloatTensor))
#            w = Variable(batch[3].type(LongTensor))
#
#            optimizer_m.zero_grad()
#            optimizer_q.zero_grad()
#
#            x = start_x[:x.shape[0]].clone()
##            x_index = x
##            x_out_index = get_query_output(program, x_index)
##            x_out = F.one_hot(x_out_index, constants.NULL+1)
##            x = torch.cat([x, x_out], axis=2)
#            loss = torch.tensor([0.]).cuda()
#            distance = [] 
#            for i in range(query_num):
##                print('query_inp:', query_inp.shape)
##                print(x.argmax(-1))
#                query_inp, query_index = query(typ)
##                if i == 0:
##                    query_inp, query_index = query(typ, x, emb)
##                else:
##                    query_inp, query_index = query(typ, x[:, 1:], emb)
#                if i >= 1:
#                    distance.append((query_inp - x[:, 1:, :3])**2)
#                query_out_index = get_query_output(program, query_index)
#                query_out = F.one_hot(query_out_index, constants.NULL+1)
#                query_io = torch.cat([query_inp, query_out], axis=2)
#                x = torch.cat([x, query_io], axis=1)
##                print(typ.shape)
#                query_io_index = torch.cat([query_index, query_out_index], axis=2)
#                query_io_index2 = query_io.argmax(-1) 
#                #print(query_io_index2)
#                pred = model(typ[:,:(i+2)*4,:], x[:, 1:])
#                l = -torch.sum(y * torch.log(pred + 1e-8)) / pred.shape[0]
#                loss += l
#
#            pred_act, pred_drop, pred_operator = model(x)
#            statement_loss = statement_criterion(pred_act, y)
#            drop_loss = drop_criterion(pred_drop, z)
#            operator_loss = operator_criterion(pred_operator, w)
#            loss = statement_loss + operator_loss + drop_loss
#
#            train_statement_losses.append(statement_loss.item() * y.shape[0])
#            train_drop_losses.append(drop_loss.item() * z.shape[0])
#            train_operator_losses.append(operator_loss.item() * w.shape[0])
#
#            loss.backward()
#            optimizer_m.step()
#            optimizer_q.step()
#
#
#        avg_statement_train_loss = np.array(train_statement_losses).sum()
#        avg_drop_train_loss = np.array(train_drop_losses).sum()
#        avg_operator_train_loss = np.array(train_operator_losses).sum()
#
#        model.eval()
#        query.eval()
#
#        with torch.no_grad():
#
#            total_statement_loss = 0.
#            total_drop_loss = 0.
#            total_operator_loss = 0.
#            total_val_error = 0.
#            for batch in val_data_loader:
#                x = Variable(batch[0].type(LongTensor))
#                y = Variable(batch[1].type(LongTensor))
#                z = Variable(batch[2].type(FloatTensor))
#                w = Variable(batch[3].type(LongTensor))
#                output = model(x)
#                val_statement_pred = output[0]
#                val_drop_pred = output[1]
#                val_operator_pred = output[2]
#
#                val_statement_loss = statement_criterion(val_statement_pred, y)
#                val_drop_loss = drop_criterion(val_drop_pred, z)
#                val_operator_loss = operator_criterion(val_operator_pred, w)
#
#                total_statement_loss += (val_statement_loss * y.shape[0]).item()
#                total_drop_loss += (val_drop_loss * z.shape[0]).item()
#                total_operator_loss += (val_operator_loss * w.shape[0]).item()
#
#                predict = val_statement_pred.data.max(1)[1]
#                total_val_error += (predict != y.data).sum().item()
#            lr_sched_m.step()
#            lr_sched_q.step()
#
#            t1 = torch.tensor([avg_statement_train_loss], dtype=torch.float64, device='cuda')
#            t2 = torch.tensor([avg_drop_train_loss], dtype=torch.float64, device='cuda')
#            t3 = torch.tensor([avg_operator_train_loss], dtype=torch.float64, device='cuda')
#            t4 = torch.tensor([total_statement_loss], dtype=torch.float64, device='cuda')
#            t5 = torch.tensor([total_drop_loss], dtype=torch.float64, device='cuda')
#            t6 = torch.tensor([total_operator_loss], dtype=torch.float64, device='cuda')
#            t7 = torch.tensor([total_val_error], dtype=torch.float64, device='cuda')
#            dist.barrier()
#            dist.all_reduce(t1)
#            dist.all_reduce(t2)
#            dist.all_reduce(t3)
#            dist.all_reduce(t4)
#            dist.all_reduce(t5)
#            dist.all_reduce(t6)
#            dist.all_reduce(t7)
#            avg_statement_train_loss, avg_drop_train_loss, avg_operator_train_loss = \
#                t1.tolist()[0]/len(train_data), t2.tolist()[0]/len(train_data), t3.tolist()[0]/len(train_data)
#            val_statement_loss, val_drop_loss, val_operator_loss = \
#                t4.tolist()[0]/len(val_data), t5.tolist()[0]/len(val_data), t6.tolist()[0]/len(val_data)
#            val_error = t7.tolist()[0] / len(val_data)
#
#            print("Train loss: S %f" % avg_statement_train_loss, "D %f" % avg_drop_train_loss,
#                  "F %f" % avg_operator_train_loss)
#            print("Val loss: S %f" % val_statement_loss, "D %f" % val_drop_loss,
#                  "F %f" % val_operator_loss)
#
#            print("Val classification error: %f" % val_error)
#
#            if dist.get_rank() == 0:
#                tb_writer.add_scalar("loss/train_statement_loss", avg_statement_train_loss, epoch)
#                tb_writer.add_scalar("loss/train_drop_loss", avg_drop_train_loss, epoch)
#                tb_writer.add_scalar("loss/train_operator_loss", avg_operator_train_loss, epoch)
#
#                tb_writer.add_scalar("loss/val_statement_loss", val_statement_loss, epoch)
#                tb_writer.add_scalar("loss/val_drop_loss", val_drop_loss, epoch)
#                tb_writer.add_scalar("loss/val_operator_loss", val_operator_loss, epoch)
#
#                tb_writer.add_scalar("error/val_error", val_error, epoch)
#                tb_writer.add_scalar("lr/lr", optimizer_m.state_dict()['param_groups'][0]['lr'], epoch)
#
#            if val_error < best_val_error:
#                ckpt_path = save_dir / 'model-best'
#                print("Found new best model")
#                best_val_error = val_error
##                save(model, optimizer, epoch, params, save_dir)
#                save_checkpoint(models, optims, models_name, optims_name, epoch, best_val_error, ckpt_path)
#                patience_ctr = 0
#            else:
#                patience_ctr += 1
#                if patience_ctr == params.patience:
#                    print("Ran out of patience. Stopping training early...")
#                    break
#            ckpt_path = save_dir / 'model-latest'
#            save_checkpoint(models, optims, models_name, optims_name, epoch, best_val_error, ckpt_path)
#
#if __name__ == '__main__':
#    train()