from __future__ import division
# External imports
import json
import os
import torch

from torch.autograd import Variable
from tqdm import tqdm

from nps.data import load_input_file, get_minibatch, shuffle_dataset, KarelDataset, KarelDatasetWithRandom
from karel.consistency import Simulator
from syntax.checker import PySyntaxChecker

import torch.distributed as dist

import torch.distributed as dist
from torch.utils.data.dataloader import DataLoader
from torch.nn.utils.rnn import pad_sequence

from karel.world import World
import numpy as np

import random

from nps.query_net import *
from nps.utils import load_checkpoint, save_checkpoint, init_distributed_mode, expand_img, random_IO, select_grids, my_collate 
#-------------------------------
def evaluate_query(query,
                   model,
                   vocabulary_path,
                   dataset_path,
                   nb_ios,
                   nb_samples,
                   use_grammar,
                   output_path,
                   beam_size,
                   top_k,
                   batch_size,
                   use_cuda,
                   dump_programs,
                   random_ios,
                   query_num,
                   firstio,
                   random_test=False,
                   random_dataset_path='datasets/karel/1m_6ex_karel/val.random.json'):
    """Evaluate ``model`` on a Karel dataset across all ranks and dump accuracy files.

    Each rank decodes its shard of the dataset (via ``DistributedSampler``) with
    beam search, accumulates top-k exact-match / semantic / syntax / generalization /
    query counters (plus random-world counters when ``random_test`` is true) through
    ``eval_func``, and the counters are reduced and written by ``dump_results`` under
    ``<output_path>/Result_<random_ios>_<query_num>/``.

    Notes on parameters: ``nb_samples`` and ``firstio`` are not used here; ``query``
    is only switched to eval mode; ``use_cuda`` is only forwarded to the syntax
    checker (batches are always moved with ``.cuda()``). ``random_dataset_path`` is
    the dataset holding the random IO worlds, loaded only when ``random_test`` is true.

    Returns:
        The global top-1 semantic accuracy as a fraction (not a percentage), as
        returned by ``dump_results`` -- on both the random and non-random paths.
    """
    output_path = output_path + '/Result_' + str(random_ios) + '_' + str(query_num) + '/'
    (all_outputs_path, all_semantic_output_path,
     all_syntax_output_path, all_generalize_output_path,
     all_query_output_path, all_random_output_path,
     program_dump_path) = get_path(output_path, top_k)

    # Load the vocabulary of the trained model
    dataset, vocab = load_input_file(dataset_path, vocabulary_path)
    if random_test:
        randomset, _ = load_input_file(random_dataset_path, vocabulary_path)
    tgt_start = vocab["tkn2idx"]["<s>"]
    tgt_end = vocab["tkn2idx"]["m)"]
    tgt_pad = vocab["tkn2idx"]["<pad>"]

    simulator = Simulator(vocab["idx2tkn"])

    # And put it into evaluation mode
    model.eval()
    query.eval()

    syntax_checker = PySyntaxChecker(vocab["tkn2idx"], use_cuda)
    if use_grammar:
        model.module.set_syntax_checker(syntax_checker)

    if beam_size == 1:
        top_k = 1
    # Running counters, mutated in place by eval_func.
    nb_correct = [0] * top_k
    nb_semantic_correct = [0] * top_k
    nb_syntax_correct = [0] * top_k
    nb_generalize_correct = [0] * top_k
    nb_random_correct = [0] * top_k
    nb_query_correct = [0] * top_k
    total_nb_list = [0]

    if random_test:
        print('test with random')
        val_dataset = KarelDatasetWithRandom(dataset, randomset, tgt_start, tgt_end,
                                             tgt_pad, nb_ios, shuffle=False)
        collate_fn = random_collate
    else:
        print('test w/o random')
        val_dataset = KarelDataset(dataset, tgt_start, tgt_end, tgt_pad, nb_ios, shuffle=False)
        collate_fn = my_collate
    # DistributedSampler shuffles by default; evaluation is meant to be in order
    # (the dataset and DataLoader are both built with shuffle=False).
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False)
    val_loader = DataLoader(val_dataset, shuffle=False, pin_memory=True,
                            batch_size=batch_size, sampler=val_sampler,
                            num_workers=0, collate_fn=collate_fn)
    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    with torch.no_grad():
        for batch, fields in pbar:
            if random_test:
                (inp_grids, out_grids,
                 in_tgt_seq, in_tgt_seq_list, out_tgt_seq,
                 inp_worlds, out_worlds,
                 random_inp_worlds, random_out_worlds,
                 _targets,
                 inp_test_worlds, out_test_worlds) = fields
            else:
                (inp_grids, out_grids,
                 in_tgt_seq, in_tgt_seq_list, out_tgt_seq,
                 inp_worlds, out_worlds,
                 _targets,
                 inp_test_worlds, out_test_worlds) = fields
                # eval_func zips over the random worlds unconditionally, so they need
                # one (unused) placeholder per sample; they are only read when
                # random_test is true.
                random_inp_worlds = list(range(len(inp_worlds)))
                random_out_worlds = list(range(len(out_worlds)))
            sp_idx = batch * batch_size

            inp_grids = inp_grids.cuda()
            out_grids = out_grids.cuda()
            in_tgt_seq = in_tgt_seq.cuda()
            out_tgt_seq = out_tgt_seq.cuda()

            # Optional IO-set manipulation (disabled):
            # inp_grids, out_grids = select_grids(inp_grids, out_grids, out_tgt_seq, simulator, random_ios, firstio=firstio)
            # inp_grids, out_grids = make_query(query, inp_grids, out_grids, out_tgt_seq, simulator, query_num)

            eval_func(
                model,
                dump_programs,
                inp_grids,
                out_grids,
                in_tgt_seq,
                in_tgt_seq_list,
                out_tgt_seq,
                inp_worlds,
                out_worlds,
                random_inp_worlds,
                random_out_worlds,
                inp_test_worlds,
                out_test_worlds,
                tgt_pad,
                tgt_start,
                tgt_end,
                beam_size,
                top_k,
                simulator,
                program_dump_path,
                nb_correct,
                nb_semantic_correct,
                nb_syntax_correct,
                nb_generalize_correct,
                nb_random_correct,
                nb_query_correct,
                total_nb_list,
                vocab,
                sp_idx,
                random_test
                )

    semantic_at_one = dump_results(
                        top_k,
                        nb_correct,
                        nb_semantic_correct,
                        nb_syntax_correct,
                        nb_generalize_correct,
                        nb_random_correct,
                        nb_query_correct,
                        total_nb_list[0],
                        all_outputs_path,
                        all_semantic_output_path,
                        all_syntax_output_path,
                        all_generalize_output_path,
                        all_random_output_path,
                        all_query_output_path,
                        random_test
                        )
    # Returned on both paths (previously only the non-random path returned a value).
    return semantic_at_one


def add_eval_args(parser):
    """Register evaluation options on ``parser``.

    Adds ``--eval_nb_ios`` (int, default 5), ``--use_grammar`` (store_true flag)
    and ``--val_nb_samples`` (int, default 0, meaning the whole dataset).
    """
    parser.add_argument('--eval_nb_ios', type=int,
                        default=5)
    parser.add_argument('--use_grammar', action="store_true")
    # Note the trailing space: the two literals are concatenated.
    parser.add_argument("--val_nb_samples", type=int,
                        default=0,
                        help="How many samples to use to compute the accuracy. "
                        "Default: %(default)s, for all the dataset")

def add_beam_size_arg(parser):
    """Register the decoding options (batch size, beam width, top-k) on ``parser``.

    All three options are integers: ``--eval_batch_size`` (default 8),
    ``--beam_size`` (default 10) and ``--top_k`` (default 5).
    """
    option_specs = (
        ("--eval_batch_size", 8, None),
        ("--beam_size", 10, "Size of the beam search. Default %(default)s"),
        ("--top_k", 5, "How many candidates to return. Default %(default)s"),
    )
    for flag, default_value, help_text in option_specs:
        options = {"type": int, "default": default_value}
        if help_text is not None:
            options["help"] = help_text
        parser.add_argument(flag, **options)

def get_path(output_path, top_k):
    """Build per-metric result file paths and reset the dump logs for one evaluation.

    ``output_path`` is used as a plain string prefix for the result files, so callers
    pass it with a trailing separator; its directory part is created if missing.

    This is a collective call: every rank must reach the ``dist.barrier()`` at the
    end. Rank 0 alone creates the directory and truncates the append-only dump logs
    (``log.wrong_programs`` and its siblings) before that barrier, so no rank can
    start appending to a log that another rank is still truncating.

    Returns:
        ``(all_outputs_path, all_semantic_output_path, all_syntax_output_path,
        all_generalize_output_path, all_query_output_path, all_random_output_path,
        program_dump_path)``: six lists of ``top_k`` result-file names followed by
        the path of ``log.wrong_programs``.
    """
    def metric_paths(stem):
        # One file per top-k bucket: "<output_path><stem>_top<k>.txt", k = 1..top_k.
        return [output_path + "%s_top%d.txt" % (stem, k + 1) for k in range(top_k)]

    print('output path:', output_path)
    all_outputs_path = metric_paths("exactmatch")
    all_semantic_output_path = metric_paths("semantic")
    all_syntax_output_path = metric_paths("syntax")
    all_generalize_output_path = metric_paths("fullgeneralize")
    all_random_output_path = metric_paths("fullrandom")
    all_query_output_path = metric_paths("query")

    res_dir = os.path.dirname(output_path)
    program_dump_path = os.path.join(res_dir, "log.wrong_programs")

    if dist.get_rank() == 0:
        if res_dir:
            os.makedirs(res_dir, exist_ok=True)
        # These derived names mirror the str.replace() calls used when the logs
        # are appended to (eval_func / write_grids).
        dump_logs = [
            program_dump_path,
            program_dump_path.replace("wrong", "true"),
            program_dump_path.replace("wrong", "qright_swrong"),
        ] + [program_dump_path.replace("wrong_programs", name)
             for name in ("inp_worlds", "out_worlds", "q_inp_worlds",
                          "q_out_worlds", "wrong_worlds")]
        for log_path in dump_logs:
            # Opening in "w" mode creates the file or truncates it to empty.
            with open(log_path, 'w'):
                pass
    dist.barrier()

    return all_outputs_path, all_semantic_output_path, \
        all_syntax_output_path, all_generalize_output_path, \
        all_query_output_path, all_random_output_path, program_dump_path

def dump_results(
    top_k,
    nb_correct,
    nb_semantic_correct,
    nb_syntax_correct,
    nb_generalize_correct,
    nb_random_correct,
    nb_query_correct,
    total_nb,
    all_outputs_path,
    all_semantic_output_path,
    all_syntax_output_path,
    all_generalize_output_path,
    all_random_output_path,
    all_query_output_path,
    random_test=False
    ):
    """Reduce per-rank accuracy counters over the process group and write them out.

    Collective: every rank must call this with the same ``top_k`` and
    ``random_test``. The raw counts and ``total_nb`` are summed across ranks with
    a single ``all_reduce``. Each accuracy is then ``100 * sum(count) / sum(total)``,
    which is weighted by the number of samples each rank actually evaluated.
    Rank 0 writes one file per (metric, k) for exact match, semantic, syntax,
    generalization and query accuracy, plus the random-world metric when
    ``random_test`` is true.

    Returns:
        The global top-1 semantic accuracy as a fraction (not a percentage).

    Raises:
        ZeroDivisionError: if no rank evaluated any sample.
    """
    # NCCL can only reduce CUDA tensors; other backends (e.g. gloo) take CPU tensors.
    device = 'cuda' if dist.get_backend() == 'nccl' else 'cpu'

    metrics = [
        (nb_correct, all_outputs_path),
        (nb_semantic_correct, all_semantic_output_path),
        (nb_syntax_correct, all_syntax_output_path),
        (nb_generalize_correct, all_generalize_output_path),
        (nb_query_correct, all_query_output_path),
    ]
    if random_test:
        metrics.append((nb_random_correct, all_random_output_path))

    # Flat layout: metric-major top-k counts, then the top-1 semantic count,
    # then total_nb -- reduced together in one collective.
    local = [float(counts[k]) for counts, _ in metrics for k in range(top_k)]
    local.append(float(nb_semantic_correct[0]))
    local.append(float(total_nb))
    buf = torch.tensor(local, dtype=torch.float64, device=device)
    dist.barrier()
    dist.all_reduce(buf)
    reduced = buf.tolist()
    global_sem_top1, global_total = reduced[-2], reduced[-1]

    # All ranks now hold identical values; only rank 0 writes, so ranks do not
    # race on truncating and writing the same files.
    if dist.get_rank() == 0:
        for metric_idx, (_, paths) in enumerate(metrics):
            for k in range(top_k):
                accuracy = 100 * reduced[metric_idx * top_k + k] / global_total
                with open(str(paths[k]), "w") as res_file:
                    res_file.write(str(accuracy))

    print('total_nb:', global_total)
    print('nb_semantic_correct[0]:', global_sem_top1)
    return global_sem_top1 / global_total


def _credit_from_rank(counter, rank, top_k):
    """Count a hit first found at beam position ``rank`` in every top-k bucket from ``rank`` on."""
    for top_idx in range(rank, top_k):
        counter[top_idx] += 1


def _first_mismatch(simulator, cand_prog, input_worlds, output_worlds):
    """Return ``(input_world, produced_grid)`` for the first pair ``cand_prog`` gets wrong.

    Runs ``cand_prog`` on the paired worlds in order and stops at the first pair
    whose produced ``outgrid`` differs from the expected output world. Returns
    ``None`` when every pair matches (including when there are no pairs).
    """
    for input_world, output_world in zip(input_worlds, output_worlds):
        res_emu = simulator.run_prog(cand_prog, input_world)
        if res_emu.outgrid != output_world:
            return input_world, res_emu.outgrid
    return None


def eval_func(
            model,
            dump_programs,
            inp_grids,
            out_grids,
            in_tgt_seq,
            in_tgt_seq_list,
            out_tgt_seq,
            inp_worlds,
            out_worlds,
            random_inp_worlds,
            random_out_worlds,
            inp_test_worlds,
            out_test_worlds,
            tgt_pad,
            tgt_start,
            tgt_end,
            beam_size,
            top_k,
            simulator,
            program_dump_path,
            nb_correct,
            nb_semantic_correct,
            nb_syntax_correct,
            nb_generalize_correct,
            nb_random_correct,
            nb_query_correct,
            total_nb_list,
            vocab,
            sp_idx,
            random_test
            ):
    """Score one batch of beam-search predictions, updating the counters in place.

    Decodes with ``model.module.beam_sample``. For each sample, it credits the first
    beam rank (and every larger top-k bucket) that satisfies each criterion:

    * exact: the predicted token list equals the target with padding removed;
    * semantic: the program parses and reproduces every (input, output) world pair;
    * random (only when ``random_test``): the same check on the random world pairs;
    * generalization: the semantic check plus the held-out test world pairs;
    * query: reproduces the IO pairs decoded from ``inp_grids``/``out_grids``;
    * syntax: the program parses.

    The ``nb_*`` lists and ``total_nb_list[0]`` are mutated in place. ``sp_idx`` is
    the index of the batch's first sample and is only used to label dump entries.
    ``in_tgt_seq`` and ``in_tgt_seq_list`` are accepted but not used.

    When ``dump_programs`` is true and the top-1 prediction parses but fails
    generalization, the prediction, the target and the worlds are appended to the
    dump logs. The logged "wrong" I/O is the first pair the prediction fails on.
    """
    max_len = out_tgt_seq.size(1) + 10
    decoded = model.module.beam_sample(inp_grids, out_grids,
                                       tgt_start, tgt_end, max_len,
                                       beam_size, top_k)
    for batch_idx, (target, sp_decoded,
                    sp_input_worlds, sp_output_worlds,
                    sp_test_input_worlds, sp_test_output_worlds,
                    sp_random_input_worlds, sp_random_output_worlds,
                    sp_inp_grids, sp_out_grids) in \
        enumerate(zip(out_tgt_seq.chunk(out_tgt_seq.size(0)), decoded,
                      inp_worlds, out_worlds,
                      inp_test_worlds, out_test_worlds,
                      random_inp_worlds, random_out_worlds,
                      inp_grids.cpu(), out_grids.cpu())):

        sample_idx = sp_idx + batch_idx
        total_nb_list[0] += 1
        # reshape(-1) rather than squeeze(): a one-token target must stay a list.
        target = [tkn_idx for tkn_idx in target.cpu().reshape(-1).tolist()
                  if tkn_idx != tgt_pad]

        # Exact matches
        for rank, dec in enumerate(sp_decoded):
            if dec[-1] == target:
                _credit_from_rank(nb_correct, rank, top_k)
                break

        # Semantic matches
        for rank, dec in enumerate(sp_decoded):
            parse_success, cand_prog = simulator.get_prog_ast(dec[-1])
            if not parse_success:
                continue
            if _first_mismatch(simulator, cand_prog,
                               sp_input_worlds, sp_output_worlds) is None:
                _credit_from_rank(nb_semantic_correct, rank, top_k)
                break

        # Random test
        if random_test:
            for rank, dec in enumerate(sp_decoded):
                parse_success, cand_prog = simulator.get_prog_ast(dec[-1])
                if not parse_success:
                    continue
                if _first_mismatch(simulator, cand_prog,
                                   sp_random_input_worlds, sp_random_output_worlds) is None:
                    _credit_from_rank(nb_random_correct, rank, top_k)
                    break

        # Worlds decoded from the (possibly augmented) IO grids fed to the model.
        q_inp_worlds = [World.fromPytorchTensor(grid) for grid in sp_inp_grids]
        q_out_worlds = [World.fromPytorchTensor(grid) for grid in sp_out_grids]

        # Generalization: training IO pairs first, then the held-out test pairs.
        top1_generalizes = True
        for rank, dec in enumerate(sp_decoded):
            pred = dec[-1]
            parse_success, cand_prog = simulator.get_prog_ast(pred)
            if not parse_success:
                continue
            mismatch = _first_mismatch(simulator, cand_prog,
                                       sp_input_worlds, sp_output_worlds)
            if mismatch is None:
                mismatch = _first_mismatch(simulator, cand_prog,
                                           sp_test_input_worlds, sp_test_output_worlds)
            if mismatch is None:
                _credit_from_rank(nb_generalize_correct, rank, top_k)
                break

            if dump_programs and rank == 0:
                top1_generalizes = False
                wrong_inp_world, wrong_out_grid = mismatch
                write_program(program_dump_path, pred, vocab["idx2tkn"], sample_idx)
                write_program(program_dump_path.replace("wrong", "true"), target,
                              vocab["idx2tkn"], sample_idx)
                write_grids(program_dump_path, sp_input_worlds + sp_test_input_worlds,
                            sp_output_worlds + sp_test_output_worlds, wrong_inp_world,
                            wrong_out_grid, q_inp_worlds, q_out_worlds, sample_idx)

        # Query: the candidate must reproduce the IO pairs decoded from the grids.
        grid_size = sp_inp_grids.shape[-1]
        for rank, dec in enumerate(sp_decoded):
            parse_success, cand_prog = simulator.get_prog_ast(dec[-1])
            if not parse_success:
                continue
            satisfy_query = True
            for input_world, output_world in zip(q_inp_worlds, q_out_worlds):
                res_emu = simulator.run_prog(cand_prog, input_world)
                # Round-trip the produced grid through the tensor encoding --
                # presumably so it is compared in the same representation as the
                # grid-decoded q_out_worlds; TODO confirm.
                tmp_out_grid = res_emu.outgrid.toPytorchTensor2(grid_size)
                tmp_out_world = World.fromPytorchTensor2(tmp_out_grid.cpu())
                if tmp_out_world != output_world:
                    satisfy_query = False
                    break
            if dump_programs and satisfy_query and rank == 0 and not top1_generalizes:
                write_program(program_dump_path.replace("wrong", "qright_swrong"), target,
                              vocab["idx2tkn"], sample_idx)
            if satisfy_query:
                _credit_from_rank(nb_query_correct, rank, top_k)
                break

        # Correct syntaxes
        for rank, dec in enumerate(sp_decoded):
            parse_success, _ = simulator.get_prog_ast(dec[-1])
            if parse_success:
                _credit_from_rank(nb_syntax_correct, rank, top_k)
                break


def make_query(query, inp_grids, out_grids, out_tgt_seq, simulator, query_num, vocab=None):
    """Append ``query_num`` query-generated IO pairs to every sample's grids.

    Each round calls ``query(inp_grids, out_grids)`` to propose new input grids.
    It then runs each sample's ground-truth program on its proposed input with
    ``simulator``. The ground truth is ``out_tgt_seq[idx]`` with zeros trimmed
    from both ends by ``np.trim_zeros``. The new input/output grids are
    concatenated to ``inp_grids``/``out_grids`` along dim 1.

    Args:
        vocab: optional vocabulary dict with an ``"idx2tkn"`` entry. When given,
            a ground-truth program that fails to parse is appended to
            ``./log.tmp`` before the error is raised.

    Returns:
        ``(inp_grids.data, out_grids.data)`` with ``query_num`` extra IO pairs per sample.

    Raises:
        ValueError: if a ground-truth program cannot be parsed. Every sample needs
            an output grid, otherwise the batch sizes in the ``torch.cat`` below
            would not match.
    """
    for _ in range(query_num):
        query_inp = query(inp_grids, out_grids)
        # Unchanged by the dim-1 concatenation below, so it can be read once per round.
        grid_size = inp_grids.shape[-1]
        out_list = []
        for idx in range(out_tgt_seq.shape[0]):
            # Assumes query_inp is (batch, 1, ...) with one proposal per sample -- TODO confirm.
            inp_world = World.fromPytorchTensor(query_inp[idx][0].detach().cpu().long())
            out_tgt = np.trim_zeros(out_tgt_seq[idx].cpu().numpy())
            parse_success, cand_prog = simulator.get_prog_ast(out_tgt.tolist())
            if not parse_success:
                torch.set_printoptions(profile="full")
                print('parse error!')
                print(out_tgt_seq[idx])
                if vocab is not None:
                    write_program('./log.tmp', out_tgt.tolist(), vocab["idx2tkn"], 0)
                raise ValueError("ground-truth program %d of the batch failed to parse" % idx)
            res_emu = simulator.run_prog(cand_prog, inp_world)
            out_list.append(res_emu.outgrid.toPytorchTensor(grid_size))
        query_out = torch.stack(out_list, 0).unsqueeze(1).cuda()
        inp_grids = torch.cat([inp_grids, query_inp], 1)
        out_grids = torch.cat([out_grids, query_out], 1)
    return inp_grids.data, out_grids.data

def write_program(path, tkn_idxs, vocab, sp_idx):
    """Append a pretty-printed program to the text file at ``path``.

    ``vocab`` maps token indices to token strings. The output starts with a
    ``program <sp_idx>:`` header. Each opening bracket token (``m(``, ``w(``,
    ``i(``, ``e(``, ``r(``) starts a new line indented 4 spaces deeper. After a
    closing bracket token, the next token starts on a new line at the enclosing
    depth; the depth never drops below zero. The entry ends with a blank line.
    """
    openers = ("m(", "w(", "i(", "e(", "r(")
    closers = ("m)", "w)", "i)", "e)", "r)")
    tokens = [vocab[idx] for idx in tkn_idxs]

    pieces = ["program %d:\n" % sp_idx]
    depth = 0
    pending_break = False
    for tkn in tokens:
        if tkn in openers:
            depth += 4
            pieces.append("\n" + " " * depth + tkn + " ")
            pending_break = False
            continue
        # Any non-opener following a closer begins a fresh line.
        if pending_break:
            pieces.append("\n" + " " * depth)
        if tkn in closers:
            depth = max(depth - 4, 0)
            pieces.append(tkn)
            pending_break = True
        else:
            pieces.append(tkn + " ")
            pending_break = False
    pieces.append("\n\n")

    with open(path, "a") as target_file:
        target_file.write("".join(pieces))

def write_grids(path, inp_worlds, out_worlds, wrong_inp_worlds, wrong_out_worlds, q_inp_worlds, q_out_worlds, sp_idx):
    """Append the worlds involved in one mispredicted sample to the world dump logs.

    Each log name is derived from ``path`` by replacing ``"wrong_programs"``.
    The ``inp_worlds``, ``out_worlds``, ``q_inp_worlds`` and ``q_out_worlds``
    logs get a ``program <sp_idx>:`` header and one ``IO <i>:`` entry per world.
    The ``wrong_worlds`` log gets the single failing input world and the output
    produced for it. Every world is rendered with its ``toString()`` method.
    """
    header = "program %d:\n" % sp_idx
    world_lists = (
        ("inp_worlds", inp_worlds),
        ("out_worlds", out_worlds),
        ("q_inp_worlds", q_inp_worlds),
        ("q_out_worlds", q_out_worlds),
    )
    for log_name, worlds in world_lists:
        with open(path.replace("wrong_programs", log_name), "a") as dump_file:
            dump_file.write(header)
            for io_idx, world in enumerate(worlds):
                dump_file.write("IO %d:\n" % io_idx)
                dump_file.write(world.toString())
                dump_file.write("\n\n")

    with open(path.replace("wrong_programs", "wrong_worlds"), "a") as dump_file:
        dump_file.write(header)
        dump_file.write("wrong I:\n")
        dump_file.write(wrong_inp_worlds.toString())
        dump_file.write("\n")
        dump_file.write("wrong O:\n")
        dump_file.write(wrong_out_worlds.toString())
        dump_file.write("\n\n")

def random_collate(batch):
    """Collate samples that carry random IO worlds into one batch.

    Each sample is indexed positionally: 0/1 input/output grid tensors (stacked),
    2 and 4 decoder input/target token sequences (zero-padded), 5-11 per-sample
    objects passed through as lists. Index 3 is not read. The padded input
    sequence also has its final time step dropped.

    Returns a 12-element list: inp_grids, out_grids, in_tgt_seq, in_tgt_seq as
    nested Python lists, out_tgt_seq, inp_worlds, out_worlds, random_inp_worlds,
    random_out_worlds, targets, inp_test_worlds, out_test_worlds.
    """
    def column(pos):
        return [sample[pos] for sample in batch]

    stacked_inp = torch.stack(column(0), 0)
    stacked_out = torch.stack(column(1), 0)
    passthrough = [column(pos) for pos in range(5, 12)]

    padded_in = pad_sequence(column(2), batch_first=True, padding_value=0)[:, :-1]
    padded_out = pad_sequence(column(4), batch_first=True, padding_value=0)

    return [stacked_inp, stacked_out, padded_in, padded_in.tolist(), padded_out] + passthrough

