import torch
from model.combinar_MI6 import CombinarMI
from utils import *
from utils_hd import *
from data.data import (load_input_file, load_input_file_seq, get_minibatch, shuffle_dataset,
                     KarelDataset, QueryDataset, KarelDatasetNoWorlds, KarelDatasetSeq, BucketingSampler)
from karel.consistency import Simulator
from torch.autograd import Variable
from cuda import use_cuda, LongTensor, FloatTensor
import sys

from scripts.get_map import get_map as get_map_nosize 

import params

torch.manual_seed(params.seed)
torch.cuda.manual_seed_all(params.seed)
np.random.seed(params.seed)
random.seed(params.seed)

# Regenerate the I/O examples ("sources") of the Karel train/val/test splits
# with generate_batch_ios and save each split to <save_dir>/<split>.thdump.
#
# Usage: python <this script> SAVE_DIR
#   SAVE_DIR must contain the 'model-best' checkpoint unless use_random is True.

# Forwarded to generate_batch_ios as its ``random`` keyword; when False the
# checkpoint at <save_dir>/model-best is loaded first. Named ``use_random`` so
# it no longer shadows the ``random`` module seeded just above.
use_random = False

train_file = 'datasets/karel/1m_6ex_karel/train.json'
val_file = 'datasets/karel/1m_6ex_karel/val.json'
test_file = 'datasets/karel/1m_6ex_karel/test.json'

if len(sys.argv) < 2:
    sys.exit('usage: {} SAVE_DIR'.format(sys.argv[0]))
save_dir = sys.argv[1]
print(save_dir, 'random =', use_random)

model = CombinarMI().cuda()
main_params = model.main_parameters()
main_learning_rate = 1e-4
# Never stepped in this script; presumably only needed so load_checkpoint can
# restore optimizer state alongside the model.
main_optimizer = torch.optim.Adam(main_params, lr=main_learning_rate)
models = [model]
models_name = ['model']
optims = [main_optimizer]
optims_name = ['main_optimizer']


train_file_dataset, vocab = load_input_file_seq(train_file, params.vocab_file)
val_file_dataset, _ = load_input_file_seq(val_file, params.vocab_file)
test_file_dataset, _ = load_input_file_seq(test_file, params.vocab_file)
vocabulary_size = len(vocab["tkn2idx"])
tgt_start = vocab["tkn2idx"]["<s>"]
# NOTE(review): the end marker is looked up as "m)" (not a dedicated "</s>"
# token) -- presumably the closing token of every Karel program; confirm.
tgt_end = vocab["tkn2idx"]["m)"]
tgt_pad = vocab["tkn2idx"]["<pad>"]
print(vocabulary_size)
# NOTE(review): 5 I/O pairs per program although the dataset is "6ex" --
# confirm the sixth example is intentionally held out.
nb_ios = 5
batch_size = 32
# shuffle=False everywhere (datasets and loaders): the generated ios are
# written back positionally into each split's 'sources', so batch order must
# match the order of the entries in the loaded dataset.
train_dataset = KarelDatasetSeq(train_file_dataset, tgt_start, tgt_end, tgt_pad, nb_ios, shuffle=False)
val_dataset = KarelDatasetSeq(val_file_dataset, tgt_start, tgt_end, tgt_pad, nb_ios, shuffle=False)
test_dataset = KarelDatasetSeq(test_file_dataset, tgt_start, tgt_end, tgt_pad, nb_ios, shuffle=False)
train_data_loader = DataLoader(train_dataset, shuffle=False, pin_memory=False,
                               batch_size=batch_size,
                               num_workers=0, collate_fn=query_collate)
val_data_loader = DataLoader(val_dataset, shuffle=False, pin_memory=False,
                             batch_size=batch_size,
                             num_workers=0, collate_fn=query_collate)
test_data_loader = DataLoader(test_dataset, shuffle=False, pin_memory=False,
                              batch_size=batch_size,
                              num_workers=0, collate_fn=query_collate)

simulator = Simulator(vocab["idx2tkn"])


def _generate_split_ios(data_loader, gen_model, sim, gen_vocab, random_mode):
    """Return the concatenated generate_batch_ios output over every batch.

    Batches are consumed in loader order, so with an unshuffled loader the
    result lines up with the dataset entries the loader was built from.

    Args:
        data_loader: DataLoader yielding collated batches (query_collate).
        gen_model: model handed to generate_batch_ios.
        sim: Simulator handed to generate_batch_ios.
        gen_vocab: vocabulary dict handed to generate_batch_ios.
        random_mode: forwarded as generate_batch_ios's ``random`` keyword.

    Returns:
        list with the extended per-batch results of generate_batch_ios.
    """
    split_ios = []
    for batch in tqdm(data_loader):
        # batch[2] is used as the program token sequence -- assumes
        # query_collate places the programs at index 2; TODO confirm.
        # (Variable() wrapping is a no-op since PyTorch 0.4, so it is omitted.)
        program_seq = batch[2].type(LongTensor)
        split_ios.extend(
            generate_batch_ios(gen_model, program_seq, sim, gen_vocab, random=random_mode))
    return split_ios


with torch.no_grad():
    if not use_random:
        ckpt_path = save_dir + '/model-best'
        print("=> loading checkpoint '{}'".format(ckpt_path))
        loaded_models, _, _, _ = \
            load_checkpoint(models, optims, models_name, optims_name, ckpt_path)
        print(len(loaded_models))
        model = loaded_models[0]
    model.eval()

    # Same order as before (train, val, test) so any RNG consumption inside
    # generate_batch_ios happens in the same sequence.
    splits = (
        ('train', train_data_loader, train_file_dataset),
        ('val', val_data_loader, val_file_dataset),
        ('test', test_data_loader, test_file_dataset),
    )
    for split_name, split_loader, split_dataset in splits:
        split_dataset['sources'] = _generate_split_ios(
            split_loader, model, simulator, vocab, use_random)
        torch.save(split_dataset, save_dir + '/' + split_name + '.thdump')
