import random
#import multiprocessing
#from multiprocessing import set_start_method
#from torch.multiprocessing import Pool, Process, set_start_method
#try:
#    set_start_method('spawn', force=True)
#except RuntimeError:
#    pass
import torch.nn.functional as F
from torch import nn
import numpy as np

import params
import torch
from model.model import BaseModel
from model.query_MI6 import QueryMI

from karel.world import World

class Combinar(BaseModel):
    """Pair a ``QueryMI`` query generator with Karel program execution.

    A query grid is produced by ``self.query``, the target program is run on
    it through a simulator, and the accumulated results are handed to
    ``self.model``.

    NOTE(review): ``self.model`` is not created in this class; presumably it
    is provided by ``BaseModel`` or a subclass -- confirm.
    """

    def __init__(self, le=None):
        """Create the query sub-module.

        Args:
            le: Not used by this class; kept so existing callers that pass it
                continue to work.
        """
        super().__init__()
        self.query = QueryMI(
            vocab_size=params.vocab_size,
            input_dropout_p=0.0,
            dropout_p=0.0,
            rnn_cell='lstm',
        )

    def forward(self, x, var, var_types, program, program_seq, step, drop, typ, plengths, query_step):
        """Generate a query from the program and fold its result into the state.

        ``typ`` is truncated along dim 1 to the first ``max(1, query_step)``
        entries before being given to ``self.query``.

        Returns:
            The ``(pred, x, var, var_types)`` tuple produced by :meth:`ps`.
        """
        typ_prefix = typ[:, :int(max(1, query_step))]
        query_inp, query_index = self.query(
            x, typ_prefix, program_seq, plengths, params.hard_softmax
        )
        return self.ps(query_index, query_inp, program, step, drop, query_step, x, var, var_types)

    def ps(self, query_index, query_inp, program, step, drop, query_step, x, var, var_types):
        """Run the environment step and append its outputs to the running state.

        When ``query_step > 0`` the new results are concatenated to ``x``,
        ``var`` and ``var_types`` along dim 1; otherwise they replace them.

        Returns:
            ``(pred, x, var, var_types)`` where ``pred`` is
            ``self.model(var, var_types)``.
        """
        # NOTE(review): this call passes five positional arguments and unpacks
        # three results, but ``env_step`` as defined in this class accepts at
        # most four arguments and returns a single value. Unless a subclass
        # overrides ``env_step``, this line raises TypeError -- confirm the
        # intended signature before relying on this method.
        query_io, var_encoded, var_typ = self.env_step(query_index, query_inp, program, step, drop)
        if query_step > 0:
            x = torch.cat([x, query_io], 1)
            var = torch.cat([var, var_encoded], 1)
            var_types = torch.cat([var_types, var_typ], 1)
        else:
            x = query_io
            var = var_encoded
            var_types = var_typ
        pred = self.model(var, var_types)

        return pred, x, var, var_types

    def predict(self, x, var, var_types, program, step, drop, typ, query_step):
        """Same as :meth:`forward`, but the query comes from ``predict_from_io``.

        Returns:
            The ``(pred, x, var, var_types)`` tuple produced by :meth:`ps`.
        """
        typ_prefix = typ[:, :int(max(1, query_step))]
        query_inp, query_index = self.query.predict_from_io(x, typ_prefix, params.hard_softmax)
        return self.ps(query_index, query_inp, program, step, drop, query_step, x, var, var_types)

    def env_step(self, query_inp, program_seq, simulator, hard_softmax=False):
        """Execute each program on its query grid and return the output grids.

        Args:
            query_inp: Batch of query grids.
            program_seq: Batch of token-id sequences; the first token of every
                row is dropped before parsing (presumably a start-of-sequence
                marker -- confirm).
            simulator: Object providing ``get_prog_ast`` and ``run_prog``.
            hard_softmax: If true, ``query_inp`` is used as given; otherwise
                it is first converted to a hard 0/1 grid.

        Returns:
            The tensor returned by :meth:`get_query_out` or
            :meth:`get_query_out_hard`.
        """
        out_tgt_seq = program_seq[:, 1:]
        if hard_softmax:
            return self.get_query_out(query_inp, out_tgt_seq, simulator)
        return self.get_query_out_hard(query_inp, out_tgt_seq, simulator)

    def get_query_out_hard(self, query_inp, out_tgt_seq, simulator):
        """Harden ``query_inp`` to a 0/1 grid, then simulate the programs on it.

        The one-cell spatial border is stripped, channels ``0..3`` ("hero")
        and ``5..`` ("map") are made one-hot, channel 4 is zero-filled, the
        border is restored with zero padding, and channel 5 is set to 1 along
        the whole border (presumably the wall channel -- confirm).

        Returns:
            Same as :meth:`get_query_out`.
        """
        interior = query_inp[:, :, :, 1:-1, 1:-1]
        hero = interior[:, :, :4, :, :]
        grid_map = interior[:, :, 5:, :, :]

        # NOTE(review): the arg-max is taken along the last axis only, so every
        # row of every hero channel receives a 1 -- confirm this is intended
        # rather than a single arg-max over the whole hero block.
        hero_idx = hero.max(-1, keepdim=True)[1]
        hero_hard = torch.zeros_like(
            hero, memory_format=torch.legacy_contiguous_format
        ).scatter_(-1, hero_idx, 1.0)

        # One-hot across the map channels for every cell.
        map_idx = grid_map.max(2, keepdim=True)[1]
        map_hard = torch.zeros_like(
            grid_map, memory_format=torch.legacy_contiguous_format
        ).scatter_(2, map_idx, 1.0)

        # Filler for channel 4; created next to the data instead of on a
        # hard-coded 'cuda' device so CPU and non-default-GPU inputs work.
        filler = torch.zeros(
            *map_hard.shape[0:2], 1, *map_hard.shape[-2:],
            dtype=map_hard.dtype, device=map_hard.device,
        )
        query_hard = torch.cat([hero_hard, filler, map_hard], 2)
        query_hard = F.pad(query_hard, (1, 1, 1, 1))

        # Mark channel 5 on all four edges of the padded grid.
        query_hard[:, :, 5, :, 0] = 1
        query_hard[:, :, 5, 0, :] = 1
        query_hard[:, :, 5, :, -1] = 1
        query_hard[:, :, 5, -1, :] = 1

        # Padding restores the original spatial size, so the shared simulation
        # loop sees the same ``shape[-1]`` as it would for ``query_inp``.
        return self.get_query_out(query_hard, out_tgt_seq, simulator)

    def get_query_out(self, query_inp, out_tgt_seq, simulator):
        """Run every program in ``out_tgt_seq`` on the matching query grid.

        For sample ``i`` the world is built from ``query_inp[i][0]``, and the
        program is ``out_tgt_seq[i]`` with leading and trailing zeros trimmed.

        NOTE(review): samples whose program fails to parse are reported on
        stdout and skipped, so the returned batch can be shorter than the
        input batch and is then no longer index-aligned with it.

        Args:
            query_inp: Batch of query grids; only index 0 along dim 1 is used.
            out_tgt_seq: Batch of token-id sequences.
            simulator: Object providing ``get_prog_ast(tokens)`` returning
                ``(parse_success, ast)`` and ``run_prog(ast, world)`` returning
                a result with an ``outgrid`` attribute.

        Returns:
            Output grids stacked along dim 0 with a singleton dim inserted at
            position 1, placed on ``query_inp``'s device.

        Raises:
            RuntimeError: If no program could be parsed (nothing to stack).
        """
        # Global print setting so a failing program is printed in full below.
        torch.set_printoptions(profile="full")
        grid_size = query_inp.shape[-1]

        out_list = []
        for idx, tgt_row in enumerate(out_tgt_seq):
            inp_world = World.fromPytorchTensor2(query_inp[idx][0].detach().cpu().long())
            tokens = np.trim_zeros(tgt_row.cpu().numpy()).tolist()
            parse_success, cand_prog = simulator.get_prog_ast(tokens)
            if not parse_success:
                print('parse error!')
                print(tgt_row)
                continue
            res_emu = simulator.run_prog(cand_prog, inp_world)
            out_list.append(res_emu.outgrid.toPytorchTensor2(grid_size))

        if not out_list:
            raise RuntimeError("no program in the batch could be parsed; nothing to stack")
        return torch.stack(out_list, 0).unsqueeze(1).to(query_inp.device)