from __future__ import division
# External imports
import json
import os
import torch

from torch.autograd import Variable
from tqdm import tqdm

from nps.data import load_input_file, get_minibatch, shuffle_dataset, KarelDataset, KarelDatasetWithRandom
from karel.consistency import Simulator
from syntax.checker import PySyntaxChecker

import torch.distributed as dist

import torch.distributed as dist
from torch.utils.data.dataloader import DataLoader
from torch.nn.utils.rnn import pad_sequence
import random

def add_eval_args(parser):
    """Register the evaluation-related command line options on ``parser``.

    Adds ``--eval_nb_ios`` (int, default 5), ``--use_grammar`` (flag) and
    ``--val_nb_samples`` (int, default 0).

    Args:
        parser: an ``argparse``-style parser exposing ``add_argument``.
    """
    samples_help = ("How many samples to use to compute the accuracy."
                    "Default: %(default)s, for all the dataset")
    # One (flag, keyword-settings) entry per option, registered in order.
    option_specs = (
        ('--eval_nb_ios', dict(type=int, default=5)),
        ('--use_grammar', dict(action="store_true")),
        ("--val_nb_samples", dict(type=int, default=0, help=samples_help)),
    )
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)

def add_beam_size_arg(parser):
    """Register the beam-search command line options on ``parser``.

    Adds ``--eval_batch_size`` (default 8), ``--beam_size`` (default 10) and
    ``--top_k`` (default 5); all three are parsed as integers.

    Args:
        parser: an ``argparse``-style parser exposing ``add_argument``.
    """
    # (flag, default value, help text or None) for each integer option.
    int_options = (
        ("--eval_batch_size", 8, None),
        ("--beam_size", 10, "Size of the beam search. Default %(default)s"),
        ("--top_k", 5, "How many candidates to return. Default %(default)s"),
    )
    for flag, default_value, help_text in int_options:
        settings = {"type": int, "default": default_value}
        if help_text is not None:
            settings["help"] = help_text
        parser.add_argument(flag, **settings)

def load_checkpoint(model, optimizer, ckpt_path):
    """Restore model and optimizer state from a training checkpoint.

    Args:
        model: module to restore. When it exposes a ``module`` attribute (as
            parallel wrappers do) the weights are loaded into that inner
            module; otherwise into ``model`` itself.
        optimizer: optimizer whose state is restored in place.
        ckpt_path: path of a file readable by ``torch.load`` that holds a
            mapping with the keys ``'model'``, ``'optimizer'``, ``'epoch'``
            and ``'best_val_acc'``.

    Returns:
        Tuple ``(model, optimizer, start_epoch, best_val_acc)``. ``model`` and
        ``optimizer`` are the objects passed in; ``start_epoch`` is the stored
        epoch plus one.
    """
    # Deserialize onto the CPU instead of the device the checkpoint was saved
    # from: this keeps CPU-only hosts working and stops every process of a
    # multi-GPU job from materialising a copy on the saving GPU.
    # ``load_state_dict`` then copies the values into the existing parameters
    # and optimizer state, which stay on their current device.
    weight_ckpt = torch.load(ckpt_path, map_location="cpu")

    raw_model = model.module if hasattr(model, "module") else model
    raw_model.load_state_dict(weight_ckpt['model'])
    optimizer.load_state_dict(weight_ckpt['optimizer'])

    start_epoch = weight_ckpt['epoch'] + 1
    best_val_acc = weight_ckpt['best_val_acc']
    return model, optimizer, start_epoch, best_val_acc

def _ensure_dir_synced(path):
    """Create ``path`` from rank 0 if it is missing, then wait for all ranks.

    An empty ``path`` (``os.path.dirname`` of a bare file name) is skipped so
    that ``os.makedirs("")`` is never attempted.
    """
    if path and not os.path.exists(path) and dist.get_rank() == 0:
        os.makedirs(path)
    dist.barrier()


def _reproduces_ios(simulator, cand_prog, input_worlds, output_worlds):
    """Return True when ``cand_prog`` maps every input world to its output."""
    for input_world, output_world in zip(input_worlds, output_worlds):
        res_emu = simulator.run_prog(cand_prog, input_world)
        if (res_emu.status != 'OK') or res_emu.crashed or (res_emu.outgrid != output_world):
            return False
    return True


def _credit_from_rank(counters, rank):
    """Count a hit at ``rank`` for that rank and every larger top-k bucket."""
    for top_idx in range(rank, len(counters)):
        counters[top_idx] += 1


def _score_first_passing(simulator, sp_decoded, counters, io_sets):
    """Credit ``counters`` for the best-ranked candidate passing ``io_sets``.

    Candidates that do not parse are skipped. ``io_sets`` is a list of
    ``(input_worlds, output_worlds)`` pairs that must all be reproduced; an
    empty list therefore scores the first candidate that merely parses.
    """
    for rank, dec in enumerate(sp_decoded):
        parse_success, cand_prog = simulator.get_prog_ast(dec[-1])
        if not parse_success:
            continue
        if all(_reproduces_ios(simulator, cand_prog, ins, outs)
               for ins, outs in io_sets):
            _credit_from_rank(counters, rank)
            break


def _dump_syntax_diagnostics(model, raw_model, syntax_checker,
                             inp_grids, out_grids, in_tgt_seq, in_tgt_seq_list,
                             program_dump_path, sp_idx):
    """Save learned vs. manual syntax masks for each sequence of a batch.

    Nothing is written when the model returns no syntax logits or has no
    learned syntax checker.
    """
    import numpy as np

    _, syntax_logit = model(inp_grids, out_grids, in_tgt_seq, in_tgt_seq_list)
    if syntax_logit is None or raw_model.decoder.learned_syntax_checker is None:
        return

    syntax_logit = syntax_logit.cpu().data.numpy()
    sequences = in_tgt_seq.cpu().data.numpy()
    for n in range(in_tgt_seq.size(0)):
        decoded_dump_dir = os.path.join(program_dump_path, str(n + sp_idx))
        _ensure_dir_synced(decoded_dump_dir)
        seq = sequences[n].tolist()
        # NOTE(review): the sequence is cut at the first token with index 0,
        # which presumably is the padding index -- confirm against "<pad>".
        seq_len = seq.index(0) if 0 in seq else len(seq)

        norm_logit = syntax_logit[n, :seq_len]
        norm_logit = np.log(-norm_logit)
        norm_logit = 1 / (1 + np.exp(-norm_logit))
        np.save(os.path.join(decoded_dump_dir, str(n) + "_learned_syntax"), norm_logit)

        ini_state = syntax_checker.get_initial_checker_state()
        mask = syntax_checker.get_sequence_mask(ini_state, seq).squeeze().cpu().numpy()[:seq_len]
        np.save(os.path.join(decoded_dump_dir, str(n) + "_manual_syntax"), mask)

        diff = mask.astype(float) - norm_logit
        diff = (diff + 1) / 2  # remap to [0,1]
        np.save(os.path.join(decoded_dump_dir, str(n) + "_diff"), diff)


def evaluate_model(model,
                   model_weights,
                   vocabulary_path,
                   dataset_path,
                   nb_ios,
                   nb_samples,
                   use_grammar,
                   output_path,
                   beam_size,
                   top_k,
                   batch_size,
                   use_cuda,
                   dump_programs,
                   random_test=False):
    """Beam-search decode a dataset and score the candidates on every rank.

    For each sample the ``top_k`` decoded programs are scored for exact match
    with the target, for reproducing the sample's IO worlds ("semantic"),
    for additionally reproducing the held-out test worlds ("fullgeneralize"),
    optionally for additionally reproducing the random worlds ("random"),
    and for parsing at all ("syntax"). Per-rank percentages are summed with
    ``all_reduce`` and divided by the number of ranks, then written to
    ``<output_path><metric>_top<k>.txt``.

    The function calls ``dist.get_rank``, ``dist.barrier`` and
    ``dist.all_reduce``, so ``torch.distributed`` must be initialised.

    Args:
        model: model to evaluate; called directly for the syntax dump and
            through ``model.module`` (when present) for ``beam_sample`` and
            ``set_syntax_checker``.
        model_weights: unused; kept for interface compatibility.
        vocabulary_path: vocabulary file passed to ``load_input_file``.
        dataset_path: dataset file passed to ``load_input_file``; with
            ``random_test`` it must contain ``'val'`` or ``'test'``.
        nb_ios: forwarded to the dataset constructor.
        nb_samples: unused; kept for interface compatibility.
        use_grammar: install the ``PySyntaxChecker`` on the model.
        output_path: prefix of the result files; its directory also receives
            the ``generated`` dump folder.
        beam_size: beam width; a value of 1 forces ``top_k`` to 1.
        top_k: number of candidates scored per sample.
        batch_size: evaluation batch size.
        use_cuda: move the batch tensors to the GPU.
        dump_programs: write targets, decoded programs and syntax masks
            under ``<dirname(output_path)>/generated``.
        random_test: also score against the random IO worlds.

    Returns:
        Tuple ``(semantic_at_one, exact_at_one)``: fractions (not
        percentages) computed from counts summed over all ranks.

    Raises:
        ValueError: ``random_test`` is set but ``dataset_path`` names neither
            a 'val' nor a 'test' split.
        ZeroDivisionError: no sample was evaluated on this rank.
    """
    res_dir = os.path.dirname(output_path)
    _ensure_dir_synced(res_dir)
    program_dump_path = os.path.join(res_dir, "generated")

    # Load the vocabulary of the trained model
    dataset, vocab = load_input_file(dataset_path, vocabulary_path)
    tgt_start = vocab["tkn2idx"]["<s>"]
    tgt_end = vocab["tkn2idx"]["m)"]
    tgt_pad = vocab["tkn2idx"]["<pad>"]

    simulator = Simulator(vocab["idx2tkn"])
    model.eval()
    # Same unwrapping rule as load_checkpoint: use the inner module if wrapped.
    raw_model = model.module if hasattr(model, "module") else model

    syntax_checker = PySyntaxChecker(vocab["tkn2idx"], use_cuda)
    if use_grammar:
        raw_model.set_syntax_checker(syntax_checker)

    if beam_size == 1:
        top_k = 1
    nb_correct = [0] * top_k
    nb_semantic_correct = [0] * top_k
    nb_syntax_correct = [0] * top_k
    nb_generalize_correct = [0] * top_k
    nb_random_correct = [0] * top_k
    total_nb = 0

    def my_collate(batch):
        def column(idx):
            return [item[idx] for item in batch]

        inp_grids = torch.stack(column(0), 0)
        out_grids = torch.stack(column(1), 0)
        # Pad to a common length, then drop the last position of the inputs.
        in_tgt_seq = pad_sequence(column(2), batch_first=True, padding_value=tgt_pad)
        in_tgt_seq = in_tgt_seq[:, :-1]
        out_tgt_seq = pad_sequence(column(4), batch_first=True, padding_value=tgt_pad)
        # The list form is rebuilt from the padded tensor; item[3] is ignored.
        input_lines = in_tgt_seq.tolist()

        return [inp_grids, out_grids, in_tgt_seq, input_lines, out_tgt_seq,
                column(5), column(6), column(7), column(8), column(9),
                column(10), column(11)]

    if random_test:
        if 'val' in dataset_path:
            random_path = 'datasets/karel/1m_6ex_karel/val.random.json'
        elif 'test' in dataset_path:
            random_path = 'datasets/karel/1m_6ex_karel/test.random.json'
        else:
            # Previously this fell through and failed on an unbound name.
            raise ValueError(
                "random_test needs a dataset_path containing 'val' or 'test', "
                "got %r" % (dataset_path,))
        print('random:', random_path)
        random_data, _ = load_input_file(random_path, vocabulary_path)
        val_dataset = KarelDatasetWithRandom(dataset, random_data, tgt_start, tgt_end, tgt_pad, nb_ios, shuffle=False)
    else:
        val_dataset = KarelDataset(dataset, tgt_start, tgt_end, tgt_pad, nb_ios, shuffle=False)

    # NOTE(review): per the PyTorch docs DistributedSampler shuffles by default
    # and pads the index list so every rank gets the same number of samples;
    # a few samples may therefore be scored twice -- confirm this is intended.
    val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
    val_loader = DataLoader(val_dataset, shuffle=False, pin_memory=True,
                            batch_size=batch_size, sampler=val_sampler,
                            num_workers=0, collate_fn=my_collate)

    pbar = tqdm(enumerate(val_loader), total=len(val_loader))
    # Pure inference: do not record autograd graphs.
    with torch.no_grad():
        for batch_idx, (inp_grids, out_grids,
                        in_tgt_seq, in_tgt_seq_list, out_tgt_seq,
                        inp_worlds, out_worlds,
                        random_inp_worlds, random_out_worlds,
                        targets,
                        inp_test_worlds, out_test_worlds) in pbar:
            # Index of the first sample of this batch within this rank's shard.
            # NOTE(review): the index is rank-local, so with several ranks the
            # dump directories below look shared between ranks -- confirm.
            sp_idx = batch_idx * batch_size

            max_len = out_tgt_seq.size(1) + 10
            if use_cuda:
                inp_grids, out_grids = inp_grids.cuda(), out_grids.cuda()
                in_tgt_seq, out_tgt_seq = in_tgt_seq.cuda(), out_tgt_seq.cuda()

            if dump_programs:
                _dump_syntax_diagnostics(model, raw_model, syntax_checker,
                                         inp_grids, out_grids,
                                         in_tgt_seq, in_tgt_seq_list,
                                         program_dump_path, sp_idx)

            decoded = raw_model.beam_sample(inp_grids, out_grids,
                                            tgt_start, tgt_end, max_len,
                                            beam_size, top_k)
            for sample_idx, (target, sp_decoded,
                             sp_input_worlds, sp_output_worlds,
                             sp_test_input_worlds, sp_test_output_worlds,
                             sp_random_input_worlds, sp_random_output_worlds) in \
                enumerate(zip(out_tgt_seq.chunk(out_tgt_seq.size(0)), decoded,
                              inp_worlds, out_worlds,
                              inp_test_worlds, out_test_worlds,
                              random_inp_worlds, random_out_worlds)):

                total_nb += 1
                # reshape(-1) keeps a one-token target a list (squeeze would
                # collapse it to a scalar).
                target = target.cpu().data.reshape(-1).numpy().tolist()
                target = [tkn_idx for tkn_idx in target if tkn_idx != tgt_pad]

                if dump_programs:
                    decoded_dump_dir = os.path.join(program_dump_path, str(sample_idx + sp_idx))
                    _ensure_dir_synced(decoded_dump_dir)
                    write_program(os.path.join(decoded_dump_dir, "target"), target, vocab["idx2tkn"])
                    for rank, dec in enumerate(sp_decoded):
                        pred = dec[1]
                        ll = dec[0]
                        file_name = str(rank) + " - " + str(ll)
                        write_program(os.path.join(decoded_dump_dir, file_name), pred, vocab["idx2tkn"])

                train_ios = (sp_input_worlds, sp_output_worlds)
                test_ios = (sp_test_input_worlds, sp_test_output_worlds)
                random_ios = (sp_random_input_worlds, sp_random_output_worlds)

                # Exact matches: the first matching rank scores for itself and
                # every larger top-k bucket.
                for rank, dec in enumerate(sp_decoded):
                    if dec[-1] == target:
                        _credit_from_rank(nb_correct, rank)
                        break

                # Semantic matches
                _score_first_passing(simulator, sp_decoded,
                                     nb_semantic_correct, [train_ios])
                # Generalization
                _score_first_passing(simulator, sp_decoded,
                                     nb_generalize_correct, [train_ios, test_ios])
                # Random
                if random_test:
                    _score_first_passing(simulator, sp_decoded,
                                         nb_random_correct, [train_ios, random_ios])
                # Correct syntaxes: parsing alone is enough.
                _score_first_passing(simulator, sp_decoded,
                                     nb_syntax_correct, [])

    # The reduction tensors used to be created on 'cuda' unconditionally;
    # fall back to the CPU when no GPU is available.
    reduce_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    metrics = [
        ("exactmatch", nb_correct),
        ("semantic", nb_semantic_correct),
        ("syntax", nb_syntax_correct),
        ("fullgeneralize", nb_generalize_correct),
        ("random", nb_random_correct),
    ]
    for k in range(top_k):
        # Each tensor is [local percentage, 1]; after the sum-reduce the second
        # entry is the number of ranks, giving the mean percentage over ranks.
        reduced = [
            torch.tensor([100 * counts[k] / total_nb, 1],
                         dtype=torch.float64, device=reduce_device)
            for _, counts in metrics
        ]
        dist.barrier()
        for tensor in reduced:
            dist.all_reduce(tensor)
        for (metric_name, _), tensor in zip(metrics, reduced):
            summed_pct, nb_ranks = tensor.tolist()
            result_path = output_path + "%s_top%d.txt" % (metric_name, k + 1)
            with open(result_path, "w") as res_file:
                res_file.write(str(summed_pct / nb_ranks))

    totals = [
        torch.tensor([value], dtype=torch.float64, device=reduce_device)
        for value in (total_nb, nb_semantic_correct[0],
                      nb_correct[0], nb_random_correct[0])
    ]
    dist.barrier()
    for tensor in totals:
        dist.all_reduce(tensor)
    global_total, global_semantic, global_exact, global_random = \
        [tensor.tolist()[0] for tensor in totals]

    print('total_nb:', global_total)
    print('nb_semantic_correct[0]:', global_semantic)
    print('nb_correct[0]:', global_exact)
    print('random_correct[0]:', global_random)
    semantic_at_one = global_semantic / global_total
    exact_at_one = global_exact / global_total
    random_at_one = global_random / global_total
    print(semantic_at_one)
    print(exact_at_one)
    print(random_at_one)
    return semantic_at_one, exact_at_one


def write_program(path, tkn_idxs, vocab):
    """Pretty-print a tokenised program to ``path``.

    Each index in ``tkn_idxs`` is looked up in ``vocab``. Block openers start
    a new line indented four columns deeper, block closers are written
    without a trailing space, and any token that follows a closer starts on
    a fresh line. The file always ends with a newline.

    Args:
        path: destination file, overwritten if it exists.
        tkn_idxs: iterable of keys/indices into ``vocab``.
        vocab: mapping or sequence from index to token string.
    """
    openers = ("m(", "w(", "i(", "e(", "r(")
    closers = ("m)", "w)", "i)", "e)", "r)")
    step = 4

    tokens = [vocab[idx] for idx in tkn_idxs]

    pieces = []
    depth = 0
    pending_break = False  # True right after a closer has been emitted
    for token in tokens:
        if token in openers:
            depth += step
            pieces.append("\n" + " " * depth)
            pieces.append(token + " ")
            pending_break = False
            continue

        if pending_break:
            # The break uses the indentation in force *before* a closer
            # reduces it.
            pieces.append("\n" + " " * depth)

        if token in closers:
            # Unbalanced closers must never push the indentation below zero.
            depth = max(depth - step, 0)
            pieces.append(token)
            pending_break = True
        else:
            pieces.append(token + " ")
            pending_break = False
    pieces.append("\n")

    with open(path, "w") as target_file:
        target_file.write("".join(pieces))
