import random
#import multiprocessing
#from multiprocessing import set_start_method
#from torch.multiprocessing import Pool, Process, set_start_method
#try:
#    set_start_method('spawn', force=True)
#except RuntimeError:
#    pass
import torch.nn.functional as F
from torch import nn

import params
import torch
from model.model import BaseModel, PCCoder
from model.query import Query

from dsl.program import Program
from dsl.example import Example
from dsl.value import Value

from env.statement import Statement, statement_to_index
from env.env import ProgramEnv


class Combinar(BaseModel):
    """Couple a learned input-query module with a PCCoder step predictor.

    Each forward pass asks ``self.query`` for program inputs, executes the
    label-encoded target programs on those inputs through the DSL, replays a
    prefix of every program in a ``ProgramEnv`` and feeds the resulting
    variable encodings to ``self.model`` (a ``PCCoder``).

    Args:
        le: decoder for program tensors; ``env_step`` calls
            ``le.inverse_transform(row)[0]`` to recover the program source
            string (presumably a fitted label encoder -- TODO confirm).
    """

    def __init__(self, le):
        super().__init__()
        self.le = le
        # Keep the original construction/assignment order
        # (embedding, [embedding_q], model, query) so random-init order and
        # the order returned by parameters() are unchanged.
        self.embedding = self._new_embedding()
        if params.diff_embedding:
            # The query module gets its own, independently initialised embedding.
            self.embedding_q = self._new_embedding()
            query_embedding = self.embedding_q
        else:
            # The query module shares the PCCoder embedding.
            query_embedding = self.embedding
        self.model = PCCoder(self.embedding)
        self.query = Query(query_embedding)

    @staticmethod
    def _new_embedding():
        """Return a bias-free ``nn.Linear`` embedding with N(0, 1)-initialised weights."""
        embedding = nn.Linear(params.integer_range + 1, params.embedding_size, bias=False)
        torch.nn.init.normal_(embedding.weight)
        return embedding

    def forward(self, x, var, var_types, program, step, drop, typ, query_step):
        """Run one query step, execute the programs and predict with PCCoder.

        Args:
            x: query I/O tensor passed to ``self.query``; extended along dim 1
                with the new query I/O when ``query_step > 0``, otherwise
                replaced by it.
            var: accumulated variable encodings; extended along dim 1 when
                ``query_step > 0``, otherwise replaced.
            var_types: accumulated variable types; extended along dim 1 when
                ``query_step > 0``, otherwise replaced.
            program: per-sample label-encoded programs (forwarded to ``env_step``).
            step: per-sample number of statements to replay (forwarded).
            drop: per-sample droppable-variable masks (forwarded).
            typ: passed to ``self.query`` truncated to its first
                ``max(1, query_step)`` entries along dim 1.
            query_step: index of the current query step.

        Returns:
            ``(pred, x, var, var_types)`` where ``pred = self.model(var, var_types)``
            and the other three are the updated accumulators.
        """
        query_inp, query_index = self.query(x, typ[:, :int(max(1, query_step))],
                                            hard_softmax=params.hard_softmax)
        # TODO: how to choose var_start?
        query_io, var_encoded, var_typ = self.env_step(query_index, query_inp, program, step, drop)
        if query_step > 0:
            x = torch.cat([x, query_io], 1)
            var = torch.cat([var, var_encoded], 1)
            var_types = torch.cat([var_types, var_typ], 1)
        else:
            # First query step: start the accumulators from scratch.
            x = query_io
            var = var_encoded
            var_types = var_typ
        pred = self.model(var, var_types)
        return pred, x, var, var_types

    def env_step(self, input_index_batch, input_batch, program_batch, step_batch, drop_batch):
        """Execute each sample's program on its queried input and encode the state.

        For every batch element this:
        - decodes and parses the program;
        - evaluates it on the inputs selected by ``input_index_batch``;
        - replays its first ``step_batch[i]`` statements in a ``ProgramEnv``;
        - one-hot encodes the resulting variables.

        While replaying, once ``params.max_program_vars`` variables exist, each
        further statement overwrites a random slot whose ``drop_batch[i]``
        entry is > 0.

        Returns:
            query_io: ``input_batch`` concatenated along dim 2 with the one-hot
                program output (the last variable slot of ``var_encoded``).
            var_encoded: per-sample one-hot variable encodings concatenated on
                dim 0. The input slots are sliced directly from ``input_batch``
                (so they keep its autograd history); the remaining slots are
                one-hot encodings of discrete values and carry no gradient.
            typ: columns ``:2`` of each ``ProgramEnv`` encoding (presumably the
                variable type information -- TODO confirm), stacked on dim 0.

        Raises:
            IndexError: from ``random.choice`` if a variable must be dropped but
                no entry of the sample's drop mask is positive.
        """
        # Build everything on the device the query tensors live on. The
        # original hard-coded CUDA; both agree whenever the original could run
        # (it fails in torch.cat otherwise).
        device = input_batch.device
        var_batch = []
        typ_batch = []
        for batch_idx in range(program_batch.shape[0]):
            source = self.le.inverse_transform(program_batch[batch_idx].cpu())[0]
            program = Program.parse(source.rstrip())
            num_inputs = len(program.input_types)
            step = step_batch[batch_idx]
            to_drop = drop_batch[batch_idx]
            input_index = input_index_batch[batch_idx, 0]
            sample_inputs = input_batch[batch_idx]

            input_vals = self.get_input_val(input_index, program)
            output = self.get_query_output(program, input_vals, device=device)

            example = Example(input_vals, output.tolist(), print_=False)
            env = ProgramEnv([example])
            for statement in program.statements[:step]:
                statement = Statement(statement.function, list(statement.args))
                if env.num_vars < params.max_program_vars:
                    env.step(statement)
                else:
                    # Choose a random var (that is not used anymore) to drop.
                    droppable = [j for j in range(len(to_drop)) if to_drop[j] > 0]
                    env.step(statement, random.choice(droppable))

            encoding = torch.Tensor(env.get_encoding()).to(device)
            typ_batch.append(encoding[:, :, :2].unsqueeze(0))

            # Drop the first num_inputs slots, the last slot and the first two
            # columns, then append the program output as the final variable.
            values = encoding[:, num_inputs:-1, 2:]
            values = torch.cat([values, output.unsqueeze(0).unsqueeze(0)], -2).unsqueeze(0).long()
            var_encoded = F.one_hot(values, params.integer_range + 1).float()
            # Input slots are taken straight from the (possibly differentiable)
            # query tensor rather than re-encoded from discrete values.
            var_encoded = torch.cat([sample_inputs[:, :num_inputs].unsqueeze(0), var_encoded], -3)
            var_batch.append(var_encoded)

        var_encoded = torch.cat(var_batch, 0)
        output_onehot = var_encoded[:, :, -1:]
        query_io = torch.cat([input_batch, output_onehot], axis=2)
        typ = torch.cat(typ_batch, 0)
        return query_io, var_encoded, typ

    def get_input_val(self, input_val, program):
        """Convert queried input index rows into raw Python values for ``program``.

        Args:
            input_val: iterable of per-input-slot index rows (assumed to be
                tensors, since ``.tolist()`` is called on them).
            program: parsed ``Program``; its ``input_types`` decide how each
                row is interpreted.

        Returns:
            One entry per declared program input. For ``INT``, the row's first
            element shifted by ``params.integer_min``. For ``LIST``, the whole
            row shifted by ``params.integer_min``; padding positions are NOT
            stripped.

        Raises:
            ValueError: if an input type is neither INT nor LIST.
        """
        # Pad the declared types to 3 slots with 'NULL' markers (skipped below).
        input_types = list(program.input_types)
        input_types.extend(['NULL'] * (3 - len(input_types)))

        input_vals = []
        for raw_val, input_type in zip(input_val, input_types):
            type_name = str(input_type)
            if type_name == 'NULL':
                continue
            if type_name == 'INT':
                raw_val = raw_val[0] + params.integer_min
            elif type_name == 'LIST':
                raw_val = raw_val + params.integer_min
            else:
                raise ValueError('bad type {}'.format(input_type))
            input_vals.append(raw_val.tolist())
        return input_vals

    def get_query_output(self, program, input_vals, device='cuda', max_len=20):
        """Run ``program`` on ``input_vals`` and return its output as an index vector.

        Args:
            program: parsed ``Program``, called with ``Value``-wrapped inputs.
            input_vals: raw Python input values (as returned by ``get_input_val``).
            device: device of the returned tensor (defaults to ``'cuda'`` as before).
            max_len: length of the returned vector. A list output longer than
                this raises a shape-mismatch error on assignment.

        Returns:
            Long tensor of shape ``(max_len,)``. Every position holds
            ``params.integer_range`` (presumably the padding/NULL index --
            TODO confirm) except where the output is written:
            - an ``INT`` output goes at position 0;
            - a non-empty list output fills the leading positions.
            Written values are shifted by ``-params.integer_min``.
        """
        values = [Value.construct(raw_val) for raw_val in input_vals]
        output_val = program(*values)
        output = torch.zeros(max_len, device=device) + params.integer_range
        if str(output_val.type) == 'INT':
            output[0] = output_val.val - params.integer_min
        elif output_val.val is not None and len(output_val.val) != 0:
            output[:len(output_val.val)] = torch.Tensor(output_val.val) - params.integer_min
        return output.long()


#    def single_process(self, batch_idx):
#        program = self.le.inverse_transform(self.program_batch[batch_idx].cpu())[0]
#        #program = self.le.inverse_transform(program.cpu())[0]
#        program = Program.parse(program.rstrip())
#        num_inputs = len(program.input_types)
#        step = self.step_batch[batch_idx]
#        to_drop = self.drop_batch[batch_idx]
#        input_index = self.input_index_batch[batch_idx, 0]
#        input = self.input_batch[batch_idx]
#
#        input_vals = self.get_input_val(input_index, program)
#
#        output = self.get_query_output(program, input_vals)
#        #print('output', output)
##        output_batch.append(output.unsqueeze(0).unsqueeze(0).unsqueeze(0))
#
#        example = Example(input_vals, output.tolist(), print_=False)
#        env = ProgramEnv([example])
#        for i, statement in enumerate(program.statements[:step]):
#            f, args = statement.function, list(statement.args)
#            #print(statement)
#            statement = Statement(f, args)
#
#            if env.num_vars < params.max_program_vars:
#                env.step(statement)
#            else:
#                # Choose a random var (that is not used anymore) to drop.
#                rand_idx = random.choice([j for j in range(len(to_drop)) if to_drop[j] > 0])
#                env.step(statement, rand_idx)
#
#        var = torch.Tensor(env.get_encoding())#.cuda()
##        typ_batch.append(var[:, :, :2].unsqueeze(0))
#        #print(var)
#        #print('var:', var.shape)
#        #print('output:', output.shape)
#        #print('input_index:', input_index.shape)
#        #print('input:', input.shape)
#        var = var[:, num_inputs : -1, 2:]
#        var = torch.cat([var, output.unsqueeze(0).unsqueeze(0)], -2).unsqueeze(0).long()
#        var_encoded = F.one_hot(var, params.integer_range + 1).float()
#        #print('var_encoded:', var_encoded.shape)
#        var_encoded = torch.cat([input[:, :num_inputs].unsqueeze(0), var_encoded], -3)
#        #print('var_encoded:', var_encoded.shape)
#
##        var_batch.append(var_encoded)
#        #print('new var:', var.shape)
#        return var_encoded.cuda(), var[:, :, :2].unsqueeze(0).cuda()
#
##    def single_process(self, program, step, to_drop, input_index, input):
##        #program = self.le.inverse_transform(self.program_batch[batch_idx].cpu())[0]
##        program = self.le.inverse_transform(program.cpu())[0]
##        program = Program.parse(program.rstrip())
##        num_inputs = len(program.input_types)
##        #step = self.step_batch[batch_idx]
##        #to_drop = self.drop_batch[batch_idx]
##        input_index = input_index[0]
##        #input = self.input_batch[batch_idx]
##
##        input_vals = self.get_input_val(input_index, program)
##
##        output = self.get_query_output(program, input_vals)
##        #print('output', output)
###        output_batch.append(output.unsqueeze(0).unsqueeze(0).unsqueeze(0))
##
##        example = Example(input_vals, output.tolist(), print_=False)
##        env = ProgramEnv([example])
##        for i, statement in enumerate(program.statements[:step]):
##            f, args = statement.function, list(statement.args)
##            #print(statement)
##            statement = Statement(f, args)
##
##            if env.num_vars < params.max_program_vars:
##                env.step(statement)
##            else:
##                # Choose a random var (that is not used anymore) to drop.
##                rand_idx = random.choice([j for j in range(len(to_drop)) if to_drop[j] > 0])
##                env.step(statement, rand_idx)
##
##        var = torch.Tensor(env.get_encoding())#.cuda()
###        typ_batch.append(var[:, :, :2].unsqueeze(0))
##        #print(var)
##        #print('var:', var.shape)
##        #print('output:', output.shape)
##        #print('input_index:', input_index.shape)
##        #print('input:', input.shape)
##        var = var[:, num_inputs : -1, 2:]
##        var = torch.cat([var, output.unsqueeze(0).unsqueeze(0)], -2).unsqueeze(0).long()
##        var_encoded = F.one_hot(var, params.integer_range + 1).float()
##        #print('var_encoded:', var_encoded.shape)
##        var_encoded = torch.cat([input[:, :num_inputs].unsqueeze(0), var_encoded], -3)
##        #print('var_encoded:', var_encoded.shape)
##
###        var_batch.append(var_encoded)
##        #print('new var:', var.shape)
###        var_encoded, var_typ = torch.ones(1, 1, 12, 20, 513, device='cuda'), torch.zeros(1, 12, 2, device='cuda')
###        return var_encoded.cuda(), var[:, :, :2].unsqueeze(0).cuda()
###        return var_encoded.cuda(), var_typ.cuda()
##        return

#    def env_step(self, input_index_batch, input_batch, program_batch, step_batch, drop_batch, pool):
#        self.output_batch = []
#        self.var_batch = []
#        self.typ_batch = []
#        #print('input_batch:', input_batch.shape)
#        #print('input_index_batch:', input_batch.shape)
##        for batch_idx in range(program_batch.shape[0]):
#
#        #
#
##        res = list(pool.starmap(self.single_process, zip(program_batch.detach().cpu(), step_batch.detach().cpu(), drop_batch.detach().cpu(), input_index_batch.detach().cpu(), input_batch.detach().cpu())))
##        with Pool(processes=1) as pool:
##        res = list(pool.imap(self.single_process, list(range(8))))
##        res = list(pool.imap(self.f, list(range(8))))
#        res = list(pool.imap(self.single_process, list(range(program_batch.shape[0]))))
##        with Pool(processes=8) as pool:
##            res = list(pool.starmap(self.single_process, zip(program_batch.detach().cpu(), step_batch.detach().cpu(), drop_batch.detach().cpu(), input_index_batch.detach().cpu(), input_batch.detach().cpu())))
##        print(len(res))
##        print(len(list(zip(*res))))
#        #var_batch, typ_batch = zip(*res)
#            
#        #var_encoded = torch.cat(var_batch, 0)
#        var_encoded, typ = torch.ones(32, 1, 12, 20, 513, device='cuda'), torch.zeros(32, 12, 2, device='cuda')
#        #print('var_encoded:', var_encoded.requires_grad)
#        #print('var_encoded:', var_encoded.shape)
#        #print('input:', input_batch.requires_grad)
#        #print('input:', input_batch.shape)
#
#        output_onehot = var_encoded[:, :, -1:]
#        #print('output_onehot:', output_onehot.shape)
#        #print('input_batch:', input_batch.shape)
#        query_io = torch.cat([input_batch, output_onehot], axis=2)
#        #print('query_io', query_io.shape)
#        #print('query_io', query_io.requires_grad)
#
#        #typ = torch.cat(typ_batch, 0)
#        #print('typ', typ.shape)
#        #print('typ', typ.requires_grad)
#
#        return query_io, var_encoded, typ
#