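"""
Defines the neural-network program decoder (``ImitateJoint``), a parser that
turns its outputs into program expressions (``ContinuousParseModelOutput``),
and a validity check for partially generated programs (``validity``).
"""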
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
from torch.autograd.variable import Variable
from ..utils.generators.mixed_len_generator import Parser, \
    SimulateStack
from typing import List
from ..utils.grammar import Stack, Mask, ContinuousImageStack
from torch.distributions.multivariate_normal import MultivariateNormal
import pdb

# Global, process-wide setting: tensors with up to 100000 elements are printed
# in full instead of being summarised with "...".
torch.set_printoptions(threshold=100000)





class ImitateJoint(nn.Module):
    """
    RNN decoder that, conditioned on encoder features of the input image and
    of the two canvases held by a ``ContinuousImageStack``, emits one discrete
    program token per step plus continuous parameters for primitive ('P')
    steps.
    """

    def __init__(self,
                 hd_sz,
                 input_size,
                 encoder,
                 mode,
                 batch_size,
                 temperature,
                 num_layers=1,
                 time_steps=3,
                 num_draws=None,
                 canvas_shape=[64, 64],
                 unique_draw=None,
                 dropout=0.5,
                 num_GPU=1):
        """
        Defines RNN structure that takes features encoded by CNN and produces program
        instructions at every time step.

        NOTE: ``forward`` reads ``self.epsilon`` (exploration probability),
        which is not set here; callers must assign it before calling forward.

        :param hd_sz: rnn hidden size
        :param input_size: input_size (CNN feature size) to rnn
        :param encoder: Feature extractor network object; must provide ``encode``
        :param mode: Mode of training, RNN, BDRNN or something else (not used
            by this class)
        :param batch_size: not used; ``forward`` reads the batch size from its input
        :param temperature: stored; not used by ``forward``
        :param num_layers: Number of layers of the LSTM
        :param time_steps: max length of program
        :param num_draws: Total number of tokens present in the dataset or total
            number of operations to be predicted + a stop symbol. Required: it
            sizes the one-hot input and the token output layer.
        :param canvas_shape: canvas shape; also sets the scale of the continuous
            parameters. Not mutated.
        :param unique_draw: unique draw tokens (stored only)
        :param dropout: dropout probability
        :param num_GPU: stored only
        """
        super(ImitateJoint, self).__init__()
        self.hd_sz = hd_sz
        self.in_sz = input_size
        self.num_layers = num_layers
        self.encoder = encoder
        self.time_steps = time_steps
        self.mode = mode
        self.canvas_shape = canvas_shape
        self.num_draws = num_draws
        self.unique_draw = unique_draw
        self.temperature = temperature
        self.init_exp = 0.2
        # Continuous parameters per primitive. The parameter head predicts
        # param_size means followed by param_size diagonal covariance entries.
        self.param_size = 3
        # TODO determine batch size

        # Dense layers projecting the popped symbol (one-hot) and the previous
        # continuous parameters into the RNN input.
        self.input_op_sz = 128
        self.num_GPU = num_GPU
        self.dense_input_op = nn.Linear(
            in_features=self.num_draws, out_features=self.input_op_sz)
        self.dense_input_op_param = nn.Linear(
            in_features=self.param_size, out_features=self.input_op_sz
        )
        # RNN input: image features + two stack-canvas features + op embedding.
        self.rnn = nn.LSTM(
            input_size=self.in_sz * 3 + self.input_op_sz,
            hidden_size=self.hd_sz,
            num_layers=self.num_layers,
            batch_first=False)

        # adapt logsoftmax and softmax for different versions of pytorch
        # NOTE(review): keeps only the 3rd character of the version string,
        # which is wrong for two-digit minor versions; unused in this class.
        self.pytorch_version = torch.__version__[2]

        self.logsoftmax = nn.LogSoftmax(1)
        self.softmax = nn.Softmax(1)
        self.dense_fc_1 = nn.Linear(
            in_features=self.hd_sz, out_features=self.hd_sz)
        self.dense_fc_cont_1 = nn.Linear(
            in_features=self.hd_sz, out_features=self.hd_sz)
        self.dense_output = nn.Linear(
            in_features=self.hd_sz, out_features=(self.num_draws))
        self.dense_output_param = nn.Linear(
            in_features=self.hd_sz, out_features=(self.param_size * 2))

        self.drop = nn.Dropout(dropout)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.batch_norm_emb = nn.BatchNorm1d(self.input_op_sz)

    def forward(self, x: List):
        """
        Unroll the decoder for ``program_len * 3`` steps.

        At every step a symbol is popped from the grammar stack. The RNN is fed
        the input-image features, the embedding of the popped symbol plus the
        previous continuous parameters, and the features of the two canvases in
        the image stack. It then picks a token (epsilon-greedy over the
        grammar-masked distribution) and continuous parameters (sampled on
        primitive rows with probability epsilon, otherwise the mean), and
        pushes them onto the stacks.

        Requires ``self.epsilon`` to be set by the caller.

        :param x: ``[data, program_len, epoch_n]``; channel 0 of ``data`` is
            encoded once as the conditioning feature. ``epoch_n`` is unused.
        :return: ``[outputs, samples, neg_entropy, params, param_log_probs]``.
            The lists hold one entry per step: masked log-probabilities of the
            chosen tokens (batch, num_draws), chosen token ids (batch, 1),
            continuous parameters (batch, param_size) that are zero on
            non-primitive rows, and their log-probabilities (batch, 1).
            ``neg_entropy`` is a 4-vector of accumulated regularisers.
        """
        data, program_len, epoch_n = x
        batch_size = data.size()[0]
        stack = Stack(int(batch_size), continuous=True)
        mask = Mask(continuous=True)
        imgS = ContinuousImageStack(int(batch_size), program_len)
        # nn.LSTM expects initial states shaped (num_layers, batch, hidden).
        hidden = Variable(torch.zeros(self.num_layers, batch_size, self.hd_sz)).cuda()
        context = Variable(torch.zeros(self.num_layers, batch_size, self.hd_sz)).cuda()
        param_sample = Variable(torch.zeros(1, batch_size, self.param_size)).cuda()
        x_f = self.encoder.encode(data[:, 0:1, :, :])
        x_f = x_f.view(1, batch_size, self.in_sz)
        outputs = []
        samples = []
        params = []
        param_log_probs = []

        neg_entropy = torch.zeros((4,)).cuda()

        # Loop-invariant scale that maps sigmoid outputs in (0, 1) to canvas
        # units. The first param_size entries scale the means, the rest the
        # diagonal covariance.
        width = self.canvas_shape[0]
        scale = torch.cuda.FloatTensor(
            list(self.canvas_shape)
            + [width / 2., width / 2., width / 2., width / 4.])

        stack.init()

        for _ in range(program_len * 3):
            pop_sym = stack.pop()

            # One-hot encoding of the popped symbol (no gradient).
            arr = Variable(
                torch.cuda.FloatTensor(batch_size, self.num_draws).fill_(0).scatter_(
                    1, torch.cuda.LongTensor(pop_sym).view(-1, 1), 1.0))
            arr = arr.detach()

            # Embed the popped symbol together with the previous step's params.
            param_sample = param_sample.view(1, batch_size, self.param_size)
            input_op_rnn = self.relu(
                self.dense_input_op(arr) + self.dense_input_op_param(param_sample))
            input_op_rnn = input_op_rnn.view(1, batch_size, self.input_op_sz)

            # Features of the two canvases currently held by the image stack.
            x_f_1 = self.encoder.encode(imgS.batch_image[:, 0:1, :, :])
            x_f_1 = x_f_1.view(1, batch_size, self.in_sz)
            x_f_2 = self.encoder.encode(imgS.batch_image[:, 1:2, :, :])
            x_f_2 = x_f_2.view(1, batch_size, self.in_sz)

            rnn_input = torch.cat((x_f, input_op_rnn, x_f_1, x_f_2), 2)
            _, (hidden, context) = self.rnn(rnn_input, (hidden, context))
            # The heads read the top LSTM layer (the only one when num_layers == 1).
            top_hidden = hidden[-1]

            ###### Discrete Actions #####
            hd = self.relu(self.dense_fc_1(self.drop(top_hidden)))
            dense_output = self.dense_output(self.drop(hd))
            # Grammar mask for the popped symbol, added in log space.
            dense_output_mask = dense_output + mask.get_mask_logP(pop_sym)
            output = self.logsoftmax(dense_output_mask)
            output_probs = self.softmax(dense_output_mask)

            # Epsilon-greedy: sample with probability epsilon, else take argmax.
            if np.random.rand() < self.epsilon:
                sample = torch.multinomial(output_probs.cpu(), 1).cuda()
            else:
                sample = torch.max(output_probs, 1)[1].view(batch_size, 1)

            # Keep only the log-probability of the chosen token.
            output_mask = torch.cuda.FloatTensor(batch_size, self.num_draws).fill_(0).scatter_(
                1, sample.view(-1, 1), 1.0)
            outputs.append(output * output_mask)

            ##### Continuous Actions ######
            hd_cont = self.relu(self.dense_fc_cont_1(self.drop(top_hidden)))
            param = self.sigmoid(self.dense_output_param(self.drop(hd_cont)))
            param_scale = param * scale
            # Offset the third mean by 6 canvas units.
            param_scale[:, 2] += 6 * torch.ones((batch_size,)).cuda()

            # 1.0 on rows whose popped symbol is a primitive ('P'), else 0.0.
            prim_mask = (pop_sym == stack.sym2idx['P']).float()
            explore_params = np.random.rand() < self.epsilon
            param_sample, param_entropy, param_log_prob = self._sample_params(
                param_scale, prim_mask, explore_params, batch_size)

            # Negative-entropy terms (p * log p), split by token-id range
            # relative to 'EOP', plus a penalty on the parameter "entropy".
            entropy_element = (output * output_probs) * mask.get_mask(pop_sym)
            eop = stack.sym2idx['EOP']
            ent_ele0 = torch.sum(entropy_element[:, :eop - 3], dim=1)
            neg_entropy[0] += ent_ele0.sum() / batch_size
            ent_ele1 = torch.sum(entropy_element[:, eop - 3:eop], dim=1)
            neg_entropy[1] += ent_ele1.sum() / batch_size
            ent_ele2 = torch.sum(entropy_element[:, eop:], dim=1)
            neg_entropy[2] += ent_ele2.sum() / batch_size
            neg_entropy[3] += 5. * (torch.exp(-1 * param_entropy[param_entropy != 0])).sum() / batch_size

            # Stop the gradient flowing backward from the samples.
            param_log_probs.append(param_log_prob.unsqueeze(1))
            sample = sample.detach()
            param_sample = param_sample.detach()
            samples.append(sample)
            params.append(param_sample)

            # Only token ids >= sym2idx['E'] go onto the grammar stack; other
            # ids are replaced by 0.
            masked_sample = torch.where(
                sample >= stack.sym2idx['E'], sample, torch.zeros_like(sample))
            stack.push(masked_sample.squeeze(1))
            imgS.push(sample, param_sample)

            # TODO push into stack only if pop_sym is a non-T and remove non-T in reinforce

        return [outputs, samples, neg_entropy, params, param_log_probs]

    def _sample_params(self, param_scale, prim_mask, explore, batch_size):
        """
        Choose one step's continuous parameters from a diagonal Gaussian.

        Both branches use the canvas-scaled outputs, so the parameters pushed
        to the image stack are in the same units whether or not we explore.

        :param param_scale: (batch, 2 * param_size) scaled head outputs; the
            first param_size columns are means, the rest diagonal covariance
            entries.
        :param prim_mask: (batch,) float mask, 1.0 on primitive rows.
        :param explore: if True and some row is primitive, sample those rows;
            otherwise use the distribution mean.
        :param batch_size: number of rows.
        :return: ``(param_sample, param_entropy, param_log_prob)`` shaped
            (batch, param_size), (batch,), (batch,); zero on non-primitive rows.
        """
        valid_idx = torch.nonzero(prim_mask)
        if explore and valid_idx.nelement() != 0:
            valid_idx = valid_idx.squeeze(1)
            rows = param_scale.index_select(0, valid_idx)
            variances = rows[:, self.param_size:]
            param_dist = MultivariateNormal(rows[:, :self.param_size],
                                            torch.diag_embed(variances))
            drawn = param_dist.sample()
            # new_zeros keeps the device/dtype of param_scale.
            param_entropy = param_scale.new_zeros((batch_size,))
            param_sample = param_scale.new_zeros((batch_size, self.param_size))
            param_log_prob = param_scale.new_zeros((batch_size,))
            # NOTE(review): this "entropy" is the squared norm of the variance
            # vector, not the Gaussian entropy used by the other branch.
            param_entropy[valid_idx] = (variances * variances).sum(dim=1)
            param_sample[valid_idx, :] = drawn
            param_log_prob[valid_idx] = param_dist.log_prob(drawn)
        else:
            param_dist = MultivariateNormal(
                param_scale[:, :self.param_size],
                torch.diag_embed(param_scale[:, self.param_size:]))
            param_sample = param_dist.mean
            param_entropy = param_dist.entropy() * prim_mask
            param_log_prob = param_dist.log_prob(param_sample) * prim_mask
            param_sample = param_sample * prim_mask.unsqueeze(1)
        return param_sample, param_entropy, param_log_prob


class ContinuousParseModelOutput:
    """Turns network outputs / label sequences into program expressions."""

    def __init__(self, unique_draws: List, stack_size: int, steps: int,
                 canvas_shape: List):
        """
        This class parses complete output from the network which are in joint
        fashion. This class can be used to generate final canvas and
        expressions.
        :param unique_draws: Unique draw/op operations in the current dataset
        :param stack_size: Stack size
        :param steps: Number of steps in the program
        :param canvas_shape: Shape of the canvases
        """
        self.canvas_shape = canvas_shape
        self.stack_size = stack_size
        self.steps = steps
        self.Parser = Parser()
        self.sim = SimulateStack(self.stack_size, self.canvas_shape)
        self.unique_draws = unique_draws
        # HARDCODED: labels2exps only writes out labels below n_T
        # (EOP cannot render an image).
        self.n_T = 6

    def get_final_canvas(self,
                         outputs: List,
                         if_just_expressions=False,
                         if_pred_images=False):
        """Not implemented yet; always returns None."""
        # TODO
        return

    def expression2stack(self, expressions: List):
        """Assuming all the expression are correct and coming from
        groundtruth labels. Helpful in visualization of programs
        :param expressions: List, each element an expression of program
        :return: float32 array stacking, for each expression, the first entry
            of the simulator's final stack
        """
        stacks = []
        for exp in expressions:
            program = self.Parser.parse(exp)
            self.sim.generate_stack(program)
            stacks.append(np.array(self.sim.stack_t[-1][0:1]))
        return np.stack(stacks, 0).astype(dtype=np.float32)

    def labels2exps(self, labels: np.ndarray, steps: int, params):
        """
        Convert label sequences into program expressions.

        :param labels: batch_size x time_steps token ids; a numpy array or a
            torch tensor (moved to CPU and converted to numpy).
        :param steps: unused; kept for interface compatibility.
        :param params: per-time-step parameters indexed ``params[i][j]`` for
            time step ``i`` and batch element ``j`` (presumably the ``params``
            list produced by ``ImitateJoint.forward`` -- confirm).
        :return: ``(expressions, p_len)``. There is one expression string per
            batch element. ``p_len`` is the index of its stop label (11), or 5
            when none is found.
        """
        if isinstance(labels, np.ndarray):
            batch_size = labels.shape[0]
        else:
            batch_size = labels.size()[0]
            labels = labels.data.cpu().numpy()

        expressions = [""] * batch_size
        # NOTE(review): 5 is the length reported when no stop label is found.
        p_len = [5] * batch_size
        for j in range(batch_size):
            for i in range(labels.shape[1]):
                # TODO replace the pre specified value
                label = labels[j, i]
                if 2 < label < self.n_T:
                    # Parameter-free token.
                    expressions[j] += self.unique_draws[label]
                elif label < 3:
                    # Token followed by its clamped integer parameters.
                    expressions[j] += (self.unique_draws[label]
                                       + self._format_params(params[i][j]))
                elif label == 11:
                    p_len[j] = i
                    break
        return expressions, p_len

    def _format_params(self, p):
        """Format a parameter vector as ``"(a,b,...)"`` of clamped ints.

        The first two values are clamped to ``[0, canvas_shape[k]]`` (the same
        axes ``ImitateJoint`` scales them by); the rest are clamped to
        ``[1, canvas_shape[0]]``.
        """
        parts = []
        for k in range(p.shape[0]):
            value = int(p[k])
            if k < 2:
                parts.append(str(max(min(value, self.canvas_shape[k]), 0)))
            else:
                parts.append(str(max(min(value, self.canvas_shape[0]), 1)))
        return "(" + ",".join(parts) + ")"


def validity(program: List, max_time: int, timestep: int):
    """
    Check whether a (partial) postfix program is still well formed.

    The scan keeps a count of ``draw`` entries (operands) and ``op`` entries
    (operators). After every entry that is not ``stop``, draws must outnumber
    ops and must not exceed ``max_time // 2 + 1``. A ``stop`` entry ends the
    scan: the program is then valid exactly when draws do not exceed
    ``(len(program) - 1) // 2 + 1`` and ``draws == ops + 1``. Without a stop
    entry the prefix is accepted, unless ``timestep`` is the last allowed step
    (``max_time - 1``); in that case the counts must balance.

    :param program: List of dicts, each with a ``"type"`` key such as
        ``"draw"``, ``"op"`` or ``"stop"``
    :param max_time: Max allowed length of program
    :param timestep: Current timestep of the program, i.e. its length so far
    :return: True if the program is (so far) valid, otherwise False
    """
    draws = ops = 0
    for entry in program:
        kind = entry["type"]
        if kind == "stop":
            too_many_draws = draws > (len(program) - 1) // 2 + 1
            if too_many_draws or draws <= ops:
                return False
            return draws - 1 == ops
        if kind == "draw":
            draws += 1
        elif kind == "op":
            ops += 1
        # Prefix checks: operands must outnumber operators, and the number of
        # operands must not overflow the stack.
        if draws <= ops or draws > max_time // 2 + 1:
            return False
    if timestep == max_time - 1:
        return draws - 1 == ops
    return True
