#!/usr/bin/env python
import argparse

from nps.evaluate import evaluate_model, add_beam_size_arg, add_eval_args
from nps.network import add_model_cli_args
from nps.utils import add_common_arg, s2intL
import os
import torch
from pathlib import Path
from nps.network import IOs2Seq

def load_checkpoint(model, ckpt_path):
    """Load the state dict stored under ``'model'`` in ``ckpt_path`` into ``model``.

    If ``model`` exposes a ``.module`` attribute (e.g. it is wrapped in
    ``DistributedDataParallel``), the weights are loaded into that inner
    module, so the checkpoint keys are expected *without* DDP's ``module.``
    prefix.

    Args:
        model: The model, optionally wrapped in a container with ``.module``.
        ckpt_path: Path (``str`` or ``os.PathLike``) of a file written by
            ``torch.save`` containing a dict with a ``'model'`` state dict.

    Returns:
        The same ``model`` object that was passed in (wrapper included),
        with its parameters updated in place.
    """
    # Deserialise on the CPU: otherwise each process would first materialise
    # the tensors on the device they were saved from (typically cuda:0).
    # load_state_dict then copies them onto the model's own device.
    weight_ckpt = torch.load(ckpt_path, map_location='cpu')

    raw_model = model.module if hasattr(model, "module") else model
    raw_model.load_state_dict(weight_ckpt['model'])
    return model

def init_distributed_mode():
    """Initialise ``torch.distributed`` from launcher environment variables.

    Two launch styles are recognised:

    * torchrun / ``torch.distributed.launch``: ``RANK``, ``WORLD_SIZE`` and
      ``LOCAL_RANK`` are read from the environment.
    * SLURM: ``SLURM_PROCID`` gives the global rank, ``SLURM_NTASKS`` the
      world size, and the local GPU is ``rank % torch.cuda.device_count()``.

    When neither is present, a message is printed and ``None`` is returned
    without touching CUDA or the process group.

    Otherwise the current CUDA device is set to the local GPU, an NCCL
    process group is initialised with ``init_method='env://'`` (which also
    requires ``MASTER_ADDR``/``MASTER_PORT`` in the environment), and
    ``setup_for_distributed(rank == 0)`` is called so that only rank 0 keeps
    printing.

    Returns:
        The local GPU index (int), or ``None`` when not running distributed.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        # torchrun / torch.distributed.launch style environment.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        # SLURM: global rank from SLURM_PROCID, total task count from
        # SLURM_NTASKS (init_process_group below needs world_size).
        rank = int(os.environ['SLURM_PROCID'])
        world_size = int(os.environ['SLURM_NTASKS'])
        gpu = rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        return None

    torch.cuda.set_device(gpu)
    dist_backend = 'nccl'
    dist_url = 'env://'
    print('| distributed init (rank {}): {}'.format(
        rank, dist_url), flush=True)
    torch.distributed.init_process_group(backend=dist_backend, init_method=dist_url,
                                         world_size=world_size, rank=rank)
    # From here on only rank 0 prints (unless print(..., force=True) is used).
    setup_for_distributed(rank == 0)
    return gpu
    
def setup_for_distributed(is_master):
    """Silence ``print`` on every process except the master.

    ``builtins.print`` is replaced by a wrapper that forwards to the
    previously installed ``print`` only when ``is_master`` is true or the
    call passes ``force=True``. The ``force`` keyword is consumed by the
    wrapper and never forwarded.
    """
    import builtins

    original_print = builtins.print

    def gated_print(*args, **kwargs):
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    builtins.print = gated_print


parser = argparse.ArgumentParser(description='Evaluate a Seq2Seq model on type prediction')

# What we need to run
parser.add_argument("--model_weights", type=str,
                    default="exps/fake_run/Weights/weights_1.model",
                    help="Weights of the model to evaluate")
parser.add_argument("--vocabulary", type=str,
                    default="data/1m_6ex_karel/new_vocab.vocab",
                    help="Vocabulary of the trained model")
parser.add_argument("--dataset", type=str,
                    default="data/1m_6ex_karel/val.json",
                    help="Dataset to evaluate against")
parser.add_argument("--dump_programs", action="store_true")
parser.add_argument("--random_test", action="store_true")

add_model_cli_args(parser)
add_beam_size_arg(parser)
add_eval_args(parser)
add_common_arg(parser)

args = parser.parse_args()

kernel_size = args.kernel_size
conv_stack = s2intL(args.conv_stack)
fc_stack = s2intL(args.fc_stack)
# NOTE(review): hard-coded rather than derived from args.vocabulary; it
# presumably has to match the vocabulary the checkpoint was trained with.
vocabulary_size = 52
tgt_embedding_size = args.tgt_embedding_size
lstm_hidden_size = args.lstm_hidden_size
nb_lstm_layers = args.nb_lstm_layers
learn_syntax = args.learn_syntax
batch_size = 32

# Local GPU index when launched via torchrun/SLURM, None otherwise.
gpu = init_distributed_mode()
distributed = gpu is not None

# Weights and evaluation outputs live in the directory containing the dataset.
save_dir = Path(args.dataset).parent
result_dir = save_dir

models_dir = result_dir / "Weights"
path_to_weight_dump = models_dir / "best.model"
print(path_to_weight_dump)

model = IOs2Seq(kernel_size, conv_stack, fc_stack,
                vocabulary_size, tgt_embedding_size,
                lstm_hidden_size, nb_lstm_layers,
                learn_syntax)
if distributed:
    # BatchNorm layers are converted *before* wrapping in DDP, as documented
    # for SyncBatchNorm.convert_sync_batchnorm.
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
device = torch.cuda.current_device()
model = model.to(device)
if distributed:
    # DDP needs an initialised process group, which only exists when
    # init_distributed_mode() returned a GPU index.
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[gpu], find_unused_parameters=True)

# NOTE: the checkpoint is taken from <dataset dir>/Weights/best.model;
# --model_weights is not used for loading here, only forwarded to evaluate_model.
if path_to_weight_dump.exists():
    model = load_checkpoint(model, path_to_weight_dump)
    print('checkpoint loaded')
else:
    print('no checkpoint found at {} - evaluating freshly initialised weights'
          .format(path_to_weight_dump))

print(args.dataset)

for nb_ios in range(1, 6):
    print('nb_ios:', nb_ios)
    output_file = result_dir / "eval" / str(nb_ios) / "val_.txt"
    # Create eval/<nb_ios>/ up front; exist_ok makes this safe on every rank.
    output_file.parent.mkdir(parents=True, exist_ok=True)
    output_path = str(output_file)
    # NOTE(review): the bare positional literals below are presumably
    # nb_samples (0), beam size (100), top-k (50) and use_cuda (True); the
    # option added by add_beam_size_arg is not read here. Confirm against
    # nps.evaluate.evaluate_model.
    evaluate_model(model,
                   args.model_weights,
                   args.vocabulary,
                   args.dataset,
                   nb_ios,
                   0,
                   args.use_grammar,
                   output_path,
                   100,
                   50,
                   batch_size,
                   True,
                   args.dump_programs,
                   args.random_test)
