import os
import sys
import torch
import torch.distributed as dist
from torch.nn.utils.rnn import pad_sequence
import torch.autograd as autograd
from karel.world import World
import time
from pathlib import Path
import json
import logging
import numpy as np
import torch.nn.functional as F
import random

#from nps.evaluate import evaluate_enc, evaluate_finetune
def get_first_io(out_tgt_seq, simulator, vocab, *, device="cuda"):
    """Build one blank 18x18 world per program and return it as both input and output.

    A tensor of shape ``(batch, 1, 16, 18, 18)`` is created, where ``batch`` is
    ``out_tgt_seq.shape[0]``. Channel 5 is set to 1 along the four borders of
    the grid (presumably the wall channel -- TODO confirm against the World
    encoding) and channel 0 is set to 1 at cell ``(9, 9)`` (presumably the
    agent marker -- TODO confirm).

    NOTE: this function used to contain a permanently disabled (``if False:``)
    branch that parsed every program with ``simulator`` and ran it on the blank
    world to obtain real output grids. That branch never executed, so it has
    been removed; the output grids are simply the input grids.

    Args:
        out_tgt_seq: Tensor whose first dimension is the batch size. Only its
            ``shape[0]`` is read.
        simulator: Unused; kept so existing callers keep working.
        vocab: Unused; kept so existing callers keep working.
        device: Device on which the grids are allocated. Defaults to
            ``"cuda"``, which matches the previous hard-coded behaviour.

    Returns:
        Tuple ``(inp_grids, out_grids)``. Both names refer to the *same*
        tensor object, so an in-place change to one is visible in the other.
    """
    batch_size = out_tgt_seq.shape[0]
    inp_grids = torch.zeros(batch_size, 1, 16, 18, 18, device=device)

    # Mark the outer frame of the grid in channel 5.
    inp_grids[:, :, 5, 0, :] = 1.
    inp_grids[:, :, 5, -1, :] = 1.
    inp_grids[:, :, 5, :, 0] = 1.
    inp_grids[:, :, 5, :, -1] = 1.

    # Mark the centre cell in channel 0.
    inp_grids[:, :, 0, 9, 9] = 1.

    # Deliberately an alias, not a copy: this is what the original returned.
    out_grids = inp_grids
    return inp_grids, out_grids

def my_collate(batch):
    """Collate a list of dataset samples into one batch.

    Each sample is indexed positionally. Fields 0 and 1 are stacked into
    tensors, fields 2 and 4 are zero-padded to a common length, and fields
    5-9 are passed through as plain Python lists. Field 3 of the samples is
    not read.

    Args:
        batch: Sequence of indexable samples with at least 10 fields each.

    Returns:
        A list of ten entries:
        ``[inp_grids, out_grids, in_tgt_seq, input_lines, out_tgt_seq,
        inp_worlds, out_worlds, targets, inp_test_worlds, out_test_worlds]``
        where ``in_tgt_seq`` is the padded field 2 with its last column
        dropped and ``input_lines`` is that same tensor as nested lists.
    """
    def gather(position):
        # Collect one positional field from every sample, keeping batch order.
        return [sample[position] for sample in batch]

    stacked_inputs = torch.stack(gather(0), 0)
    stacked_outputs = torch.stack(gather(1), 0)
    decoder_inputs = gather(2)
    decoder_targets = gather(4)
    input_worlds = gather(5)
    output_worlds = gather(6)
    target_list = gather(7)
    test_input_worlds = gather(8)
    test_output_worlds = gather(9)

    # Pad with 0, then drop the final column of the decoder input.
    decoder_inputs = pad_sequence(
        decoder_inputs, batch_first=True, padding_value=0
    )[:, :-1]
    decoder_targets = pad_sequence(
        decoder_targets, batch_first=True, padding_value=0
    )

    return [
        stacked_inputs,
        stacked_outputs,
        decoder_inputs,
        decoder_inputs.tolist(),
        decoder_targets,
        input_worlds,
        output_worlds,
        target_list,
        test_input_worlds,
        test_output_worlds,
    ]

def query_collate(batch):
    """Collate query samples into one batch.

    Fields 0, 1 and 3 of every sample are stacked along a new leading
    dimension; field 2 (variable-length sequences) is zero-padded to the
    longest sequence in the batch.

    Args:
        batch: Sequence of indexable samples with at least 4 fields each.

    Returns:
        A list ``[inp_grids, out_grids, in_tgt_seq, seq_len]``.
    """
    def pick(slot):
        # One positional field from every sample, in batch order.
        return [sample[slot] for sample in batch]

    grids_in = torch.stack(pick(0), 0)
    grids_out = torch.stack(pick(1), 0)
    token_seqs = pick(2)
    lengths = torch.stack(pick(3), 0)

    padded_seqs = pad_sequence(token_seqs, batch_first=True, padding_value=0)

    return [grids_in, grids_out, padded_seqs, lengths]

def expand_img(simulator, inp_grids, seq):
    """Run each program on its input grids and return the resulting grid pairs.

    For every batch entry the token sequence ``seq[b]`` is stripped of leading
    and trailing zeros, parsed with ``simulator.get_prog_ast`` and executed
    with ``simulator.run_prog`` on each of the entry's input grids.

    Side effect: ``inp_grids`` is modified in place -- channel 5 is set to 1
    along the four borders of every grid (presumably the wall channel --
    TODO confirm against the World encoding).

    Batch entries whose program fails to parse are reported on stdout and
    skipped, so the returned tensors can hold fewer batch entries than
    ``inp_grids``. If no program parses, ``torch.stack`` is called on an empty
    list and raises.

    Args:
        simulator: Object providing ``get_prog_ast(tokens)`` returning
            ``(parse_success, program)`` and ``run_prog(program, world)``
            returning a result with an ``outgrid`` attribute.
        inp_grids: Tensor indexed as ``[batch, io, channel, row, col]``.
        seq: Per-batch token sequences; zeros at either end are trimmed
            before parsing.

    Returns:
        Tuple ``(expanded_inp_grids, expanded_out_grids)`` of CUDA tensors,
        stacked first over the I/O index and then over the surviving batch
        entries.
    """
    # Frame every grid in channel 5 (in place on the caller's tensor).
    inp_grids[:, :, 5, 0, :] = 1
    inp_grids[:, :, 5, :, 0] = 1
    inp_grids[:, :, 5, -1, :] = 1
    inp_grids[:, :, 5, :, -1] = 1

    grid_size = inp_grids.shape[-1]
    out_grids_list = []
    inp_grids_list = []
    for batch_idx in range(inp_grids.shape[0]):
        out_tgt = np.trim_zeros(seq[batch_idx].cpu().numpy())
        parse_success, cand_prog = simulator.get_prog_ast(out_tgt.tolist())
        if not parse_success:
            print('parse error!')
            # Fixed: this used to index with the inner-loop variable, which is
            # unbound on a first-entry failure and stale on later ones.
            print(seq[batch_idx])
            #write_program('./log.tmp', out_tgt.tolist(), vocab["idx2tkn"])
            continue
        out_list = []
        inp_list = []
        for io_idx in range(inp_grids.shape[1]):
            inp_world = World.fromPytorchTensor(
                inp_grids[batch_idx][io_idx].detach().cpu().long())
            res_emu = simulator.run_prog(cand_prog, inp_world)
            # Both grids are read back after the run, as in the original.
            out_grid = res_emu.outgrid.toPytorchTensor(grid_size)
            inp_grid = inp_world.toPytorchTensor(grid_size)
            out_list.append(out_grid)
            inp_list.append(inp_grid)
        out_grids_list.append(torch.stack(out_list, 0).cuda())
        inp_grids_list.append(torch.stack(inp_list, 0).cuda())
    expanded_out_grids = torch.stack(out_grids_list, 0).cuda()
    expanded_inp_grids = torch.stack(inp_grids_list, 0).cuda()
    return expanded_inp_grids, expanded_out_grids

def save_data(inp_grids, out_grids, out_tgt_seq, data_list):
    """Append a compact record for every sequence in the batch to ``data_list``.

    Each grid is stored sparsely as the int16 flat indices of its non-zero
    cells. Each token sequence is cut just after its last non-zero token, so
    zeros inside the sequence are kept and trailing zeros are dropped.

    Args:
        inp_grids: Tensor indexed as ``[sequence, io, ...]``.
        out_grids: Tensor indexed the same way as ``inp_grids``.
        out_tgt_seq: Per-sequence token tensors; each must contain at least
            one non-zero token, otherwise indexing the last non-zero fails.
        data_list: List extended in place with one
            ``(grid_pairs, tokens)`` tuple per sequence, where ``grid_pairs``
            is a list of ``(input_indices, output_indices)`` tuples.

    Returns:
        None.
    """
    def nonzero_positions(grid):
        # Flat positions of the non-zero cells, stored as int16.
        return grid.view(-1).nonzero(as_tuple=False).view(-1).short().data

    records = []
    for row in range(inp_grids.shape[0]):
        pairs = [
            (nonzero_positions(inp_grids[row][col]),
             nonzero_positions(out_grids[row][col]))
            for col in range(inp_grids.shape[1])
        ]
        tokens = out_tgt_seq[row]
        last_used = tokens.nonzero(as_tuple=False)[-1]
        records.append((pairs, tokens[:last_used + 1].data))
    # Extend only once everything was built, so a failure leaves data_list untouched.
    data_list.extend(records)

def dump_log(epoch_idx,
             batch_idx,
             losses,
             recent_losses):
    """Average the recent loss across all ranks, log it on rank 0 and record it.

    Each process contributes the pair ``[mean(recent_losses), 1]``. After the
    all-reduce (default reduction, i.e. sum) the first entry holds the summed
    per-rank means and the second the number of contributing ranks, so their
    ratio is the mean of the per-rank means.

    Requires an initialised ``torch.distributed`` process group and a CUDA
    device. ``recent_losses`` must be non-empty.

    Args:
        epoch_idx: Epoch number, used in the log line only.
        batch_idx: Minibatch number, used in the log line only.
        losses: List that receives the cross-rank mean (appended on every rank).
        recent_losses: Non-empty sequence of loss values local to this rank.

    Returns:
        None.
    """
    local_mean = sum(recent_losses) / len(recent_losses)
    stats = torch.tensor([local_mean, 1], dtype=torch.float64, device='cuda')

    # Wait for every rank to arrive, then combine the per-rank pairs.
    dist.barrier()
    dist.all_reduce(stats)

    reduced_sum, rank_count = stats.tolist()
    loss_mean = reduced_sum / rank_count

    # Only rank 0 writes the log line; every rank stores the value.
    if dist.get_rank() == 0:
        message = 'Epoch : %d Minibatch : %d Loss : %.5f' % (
            epoch_idx, batch_idx, loss_mean)
        logging.info(message)
    losses.append(loss_mean)

def description(epoch_idx, losses_name, losses, optimizer, pbar):
    """Set the progress-bar caption to the epoch, named losses and learning rate.

    The caption has the form
    ``"epoch <n> <name> <value> ... lr <lr>"`` with each loss printed to three
    decimals and the learning rate in scientific notation with three decimals.

    Args:
        epoch_idx: Epoch number shown at the start of the caption.
        losses_name: Loss labels (strings); needs at least ``len(losses)``
            entries.
        losses: Loss values exposing ``.item()``.
        optimizer: Object whose ``state_dict()['param_groups'][0]['lr']`` is
            the learning rate to display.
        pbar: Object with a ``set_description(str)`` method.

    Returns:
        None.
    """
    pieces = [f"epoch {epoch_idx} "]
    for position in range(len(losses)):
        pieces.append(
            losses_name[position] + f" {losses[position].item():.3f} "
        )

    # Learning rate of the first parameter group.
    learning_rate = optimizer.state_dict()['param_groups'][0]['lr']
    pieces.append(f"lr {learning_rate:.3e}")

    pbar.set_description("".join(pieces))

class TrainSignal:
    """String identifiers for the training-signal modes.

    The values are plain strings; how each mode is consumed is not shown in
    this part of the file.
    """

    # Judging by the names: supervised training, RL, and a beam-search RL
    # variant -- TODO confirm against the code that consumes these values.
    SUPERVISED = "supervised"
    RL = "rl"
    BEAM_RL = "beam_rl"