import torch
from karel.world import World
from karel.consistency import Simulator
import sys
import os
from tqdm import tqdm
import json
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

# Dense grid tensors have shape (IMG_CHANNELS, IMG_DIM, IMG_DIM); grid
# descriptions index into the flattened form, which has IMG_FEAT elements.
IMG_CHANNELS = 16
IMG_DIM = 18
IMG_SIZE = torch.Size((IMG_CHANNELS, IMG_DIM, IMG_DIM))
# Derived (= 5184) instead of hard-coded so it cannot drift from IMG_SIZE:
# grid_desc_to_tensor builds an IMG_FEAT vector and views it as IMG_SIZE.
IMG_FEAT = IMG_SIZE.numel()

def write_program(tkn_idxs, vocab, sp_idx):
    """Pretty-print a tokenized program to stdout.

    Each index in ``tkn_idxs`` is looked up in ``vocab`` before anything is
    printed. The listing starts with a "program <sp_idx>:" header line.
    Opening tokens ("m(", "w(", "i(", "e(", "r(") start a new line one level
    (4 spaces) deeper. Closing tokens ("m)", "w)", "i)", "e)", "r)") are
    appended, lower the indent by one level (never below 0), and make the next
    non-opening token start a fresh line at the current indent. The listing
    ends with a blank line.
    """
    openers = ("m(", "w(", "i(", "e(", "r(")
    closers = ("m)", "w)", "i)", "e)", "r)")

    def emit(text):
        print(text, end='')

    tokens = [vocab[i] for i in tkn_idxs]

    emit("program %d:\n" % sp_idx)
    depth = 0
    pending_break = False
    for token in tokens:
        if token in openers:
            depth += 4
            emit("\n" + " " * depth)
            emit(token + " ")
            pending_break = False
            continue

        # A closer was printed last: start this token on its own line,
        # using the indent in effect *before* any decrement below.
        if pending_break:
            emit("\n" + " " * depth)

        if token in closers:
            depth = max(depth - 4, 0)
            emit(token)
            pending_break = True
        else:
            # "REPEAT" and every other plain token are printed the same way.
            emit(token + " ")
            pending_break = False
    emit("\n")
    emit("\n")


def grid_desc_to_tensor(grid_desc):
    """Expand a sparse grid description into a dense 0/1 grid tensor.

    ``grid_desc`` is a tensor of flat indices (cast with ``.long()``). Each of
    those positions is set to 1 in an otherwise-zero vector of IMG_FEAT
    elements, and the vector is returned viewed as IMG_SIZE.
    """
    flat = torch.zeros(IMG_FEAT)
    flat.index_fill_(0, grid_desc.long(), 1)
    return flat.view(IMG_SIZE)


def translate(seq,
              vocab):
    """Map each element of ``seq`` to ``vocab[str(element)]``.

    Elements are converted with ``str`` before lookup, so ``vocab`` must be
    keyed by strings. A missing key raises KeyError.
    """
    translated = []
    for element in seq:
        translated.append(vocab[str(element)])
    return translated


def _load_vocab(path_to_vocab):
    '''Read a one-token-per-line vocabulary file into index maps.

    Index 0 is reserved for '<pad>'; the token on line k (1-based) gets
    index k. Returns {"idx2tkn": {idx: token}, "tkn2idx": {token: idx}}.
    '''
    tkn2idx = {'<pad>': 0}
    with open(path_to_vocab, 'r') as vocab_file:
        for idx, line in enumerate(vocab_file, start=1):
            tkn2idx[line.strip()] = idx
    idx2tkn = {idx: tkn for tkn, idx in tkn2idx.items()}
    return {"idx2tkn": idx2tkn, "tkn2idx": tkn2idx}


def check_branch(path_to_dataset='datasets/karel/1m_6ex_karel/train.0.6.thdump',
                 path_to_vocab='datasets/karel/1m_6ex_karel/new_vocab.vocab',
                 output_png='cover_rate_0.6.png',
                 shard=0,
                 num_shards=8):
    '''Report how well each sample's input grids cover its program's branches.

    For every sample in the selected shard, the target program is run on all
    of the sample's input grids via ``run_grids``. A sample counts as covered
    when ``run_grids`` returns True. The function prints the fraction of
    covered samples and saves a distribution plot of the per-sample rates
    returned by ``run_grids``.

    path_to_dataset: torch.save'd cache holding "sources" (per sample, a list
        of (input grid desc, output grid desc) pairs) and "targets" lists
    path_to_vocab: File containing the vocabulary, one token per line
    output_png: where the distribution plot is written
    shard, num_shards: samples [shard*step, (shard+1)*step) are processed,
        with step = len(sources) // num_shards (default: the first eighth)
    '''
    vocab = _load_vocab(path_to_vocab)
    simulator = Simulator(vocab["idx2tkn"])

    # NOTE: torch.load unpickles the file -- only point this at trusted caches.
    dataset = torch.load(path_to_dataset)

    # The samples are split into num_shards equal slices (any remainder is
    # skipped), and only slice `shard` is processed.
    step = len(dataset['sources']) // num_shards
    print(step)

    branch_status = []
    single_rates = []
    for idx in tqdm(range(shard * step, (shard + 1) * step)):
        # run_grids only takes the input grid of each (input, output) pair.
        inp_grids = [grid_desc_to_tensor(inp_grid_desc)
                     for inp_grid_desc, _out_grid_desc in dataset['sources'][idx]]
        covered, single_rate = run_grids(inp_grids, dataset['targets'][idx],
                                         simulator)
        branch_status.append(covered)
        single_rates.append(single_rate)

    if not branch_status:
        # Fewer samples than num_shards: nothing to average or plot.
        print('No samples to check in', path_to_dataset)
        return

    print('Branch cover rate:', sum(branch_status) / len(branch_status))
    print(len(single_rates))
    plt.title('cover_rate')
    # NOTE: sns.distplot is deprecated in recent seaborn releases; switch to
    # histplot/displot if it is removed from the installed version.
    pt = sns.distplot(np.array(single_rates),
                      label=path_to_dataset.split('/')[-1])
    pt.legend()
    plt.savefig(output_png)


def run_grids(inp_grids, out_tgt_seq, simulator):
    """Run a program on every input grid and merge the branch flags it leaves.

    For each grid, ``out_tgt_seq`` is parsed with ``simulator.get_prog_ast``
    and executed with ``simulator.run_prog``. The values in
    ``simulator.emulator.branch_flag`` after each run are bitwise-OR-ed
    together across grids.

    Args:
        inp_grids: non-empty iterable of grid tensors (as produced by
            grid_desc_to_tensor).
        out_tgt_seq: token sequence of the program to run.
        simulator: object exposing get_prog_ast, run_prog and
            emulator.branch_flag.

    Returns:
        (covered, single_rate). covered is True when no merged flag equals
        0, 1 or 2, in which case single_rate is 1.0. Otherwise single_rate
        is the fraction of merged flags equal to 3.

    Raises:
        ValueError: if inp_grids is empty (no flags would be collected).
        SystemExit: with status 1 if the program fails to parse.
    """
    branch_checker = None
    for inp_grid in inp_grids:
        inp_world = World.fromPytorchTensor(inp_grid.long())
        # Re-parsed for every grid, as before. Hoisting this out of the loop
        # would be cheaper -- TODO confirm run_prog never mutates the AST.
        parse_success, cand_prog = simulator.get_prog_ast(out_tgt_seq)
        if not parse_success:
            print('parse error!')
            print(out_tgt_seq)
            # Non-zero status: a bare sys.exit() would report success.
            sys.exit(1)
        # Only the side effect on simulator.emulator.branch_flag is used.
        simulator.run_prog(cand_prog, inp_world)
        flags = torch.tensor(simulator.emulator.branch_flag).int()
        if branch_checker is None:
            branch_checker = torch.zeros(len(flags)).int()
        branch_checker |= flags

    if branch_checker is None:
        raise ValueError('run_grids needs at least one input grid')

    # Only merged flags equal to 3 count as fully covered. Presumably bits 1
    # and 2 mark the two outcomes of a branch -- TODO confirm in the emulator.
    if 0 in branch_checker or 1 in branch_checker or 2 in branch_checker:
        single_rate = (branch_checker == 3).float().mean().item()
        return False, single_rate
    return True, 1.


if __name__ == '__main__':
    # Run the branch-coverage check with its default dataset/vocab paths.
    check_branch()
