"""
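Defines the recurrent program-decoder network (ImitateJoint), a parser that
turns its label output into program expressions (ParseModelOutput), and a
validity check for postfix programs (validity).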
"""
import torch.nn as nn
import numpy as np
from torch.autograd.variable import Variable
from ..utils.generators.mixed_len_generator import Parser, \
    SimulateStack, Draw
from typing import List
from ..utils.grammar import Stack, Mask, ImageStack
from src.utils.Gumbel import *

# Print tensors with up to 100000 elements in full instead of summarizing them.
# NOTE(review): `torch` is not imported by name in the visible imports; it
# presumably comes into scope through the star import of src.utils.Gumbel.
torch.set_printoptions(threshold=100000)




class ImitateJoint(nn.Module):
    """
    Image-conditioned LSTM decoder that emits program tokens step by step.

    ``forward`` encodes the input with ``encoder`` and then runs the LSTM for
    ``3 * program_len`` steps while tracking ``k`` candidate programs (beams)
    per batch element. From the second step on, candidates are ranked by the
    perturbed scores returned by ``gumbel_with_maximum`` and the top ones are
    kept via ``torch.topk``.
    """
    def __init__(self,
                 hd_sz,
                 input_size,
                 encoder,
                 stack_encoder,
                 mode,
                 batch_size,
                 temperature,
                 num_layers=1,
                 time_steps=3,
                 num_draws=None,
                 canvas_shape=[64, 64],
                 unique_draw=None,
                 dropout=0.5,
                 num_GPU=1,
                 toy = False):
        """
        Build the layers of the decoder and pre-render the shape tokens.

        :param hd_sz: hidden size of the LSTM and of the dense layer after it
        :param input_size: size of the encoder feature that is concatenated
            with the token embedding to form the LSTM input
        :param encoder: feature extractor; ``forward`` calls ``encoder.encode``
        :param stack_encoder: stored on the instance; not used by this class
        :param mode: stored on the instance; not read by this class
        :param batch_size: accepted but unused here (only referenced by a
            commented-out line)
        :param temperature: stored on the instance; not read by this class
        :param num_layers: number of LSTM layers.
            NOTE(review): ``forward`` creates its initial state with a leading
            dimension of 1, which looks like it assumes ``num_layers == 1``.
        :param time_steps: stored on the instance; ``forward`` takes the
            program length from its input instead
        :param num_draws: vocabulary size; width of the one-hot token input
            and of the output logits. The ``None`` default cannot be used
            because it is passed straight to ``nn.Linear``.
        :param canvas_shape: stored on the instance.
            NOTE(review): the default is a mutable list shared by every
            instance that does not pass its own value.
        :param unique_draw: sequence of token strings. All but the last 10
            entries are parsed as shape tokens. The ``None`` default cannot be
            used because ``len(unique_draw)`` is evaluated.
        :param dropout: dropout probability used before both dense layers
        :param num_GPU: stored on the instance; not read by this class
        :param toy: stored on the instance; not read by this class
        """
        super(ImitateJoint, self).__init__()
        self.hd_sz = hd_sz
        self.in_sz = input_size
        self.num_layers = num_layers
        self.encoder = encoder
        self.time_steps = time_steps
        self.mode = mode
        self.canvas_shape = canvas_shape
        self.num_draws = num_draws
        self.unique_draw = unique_draw
        self.temperature = temperature
        self.init_exp = 0.2

        #### Create Shape Dict
        # Render every shape token once. A token is read as "<letter>(a,b,c)":
        # the first character selects the primitive and the three
        # comma-separated integers are handed to the Draw helper.
        # NOTE(review): the last 10 entries of unique_draw are skipped,
        # presumably because they are operators / stop symbols — confirm.
        draw = Draw()
        shape_dict = []
        for shape_id in range(len(unique_draw) - 10):
            close_paren = unique_draw[shape_id].index(")")
            # Strip the leading "<letter>(" and everything from ")" onwards.
            value = unique_draw[shape_id][2:close_paren].split(",")
            # NOTE(review): a token starting with any other letter leaves
            # shape_arr unset (NameError on the first token) or stale from the
            # previous iteration.
            if unique_draw[shape_id][0] == "c":
                shape_arr = draw.draw_circle([int(value[0]), int(value[1])], int(value[2]))
            elif unique_draw[shape_id][0] == "t":
                shape_arr = draw.draw_triangle([int(value[0]), int(value[1])], int(value[2]))
            elif unique_draw[shape_id][0] == "s":
                shape_arr = draw.draw_square([int(value[0]), int(value[1])], int(value[2]))
            shape_dict.append(shape_arr)
        shape_dict = np.array(shape_dict)
        # `1 * shape_dict` presumably turns a boolean raster into numbers
        # before the float conversion — depends on what Draw returns.
        self.shape_dict = torch.from_numpy(1 * shape_dict).type(torch.FloatTensor)

        # TODO determine batch size
        # self.stack = Stack(int(batch_size/num_GPU))
        # self.mask = Mask()

        # self.n_T = 400 # we don't want EOP end up in the stack

        # Dense layer to project input ops(labels) to input of rnn
        self.input_op_sz = 128
        self.num_GPU = num_GPU
        self.dense_input_op = nn.Linear(
            in_features=self.num_draws, out_features=self.input_op_sz)

        # The LSTM consumes the image feature concatenated with the embedded
        # previous token; inputs are laid out (seq, batch, feature).
        self.rnn = nn.LSTM(
            input_size=self.in_sz + self.input_op_sz,
            hidden_size=self.hd_sz,
            num_layers=self.num_layers,
            batch_first=False)

        # adapt logsoftmax and softmax for different versions of pytorch
        # Third character of the version string (the minor digit of "0.4.1").
        # NOTE(review): only referenced by the disabled snippet below.
        self.pytorch_version = torch.__version__[2]
        """
        if self.pytorch_version == "1":
            self.logsoftmax = nn.LogSoftmax()
            self.softmax = nn.Softmax()

        elif self.pytorch_version == "3" or self.pytorch_version == "4":
        """
        # Both normalizations run over dim 1, the vocabulary axis of the logits.
        self.logsoftmax = nn.LogSoftmax(1)
        self.softmax = nn.Softmax(1)
        self.dense_fc_1 = nn.Linear(
            in_features=self.hd_sz, out_features=self.hd_sz)
        self.dense_output = nn.Linear(
            in_features=self.hd_sz, out_features=(self.num_draws))
        self.drop = nn.Dropout(dropout)
        # sigmoid, batch_norm_emb and stack_encoder are created/stored but not
        # used by forward() in this class.
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.batch_norm_emb = nn.BatchNorm1d(self.input_op_sz)
        self.stack_encoder = stack_encoder
        self.toy = toy

    def forward(self, x: List, k = 1, debug = False, normalization = True):
        """
        Decode ``3 * program_len`` tokens per image while tracking ``k``
        candidate programs for every batch element.

        Every working tensor is created on the GPU (``.cuda()`` /
        ``torch.cuda.*Tensor``), so this method requires CUDA.

        :param x: ``[data, program_len, epoch_n]``. ``data`` is indexed as a
            4-D tensor whose first dimension is the batch; only
            ``data[:, 0:1]`` is passed to the encoder. ``epoch_n`` is unpacked
            but not used.
        :param k: number of candidates per batch element. Per-candidate
            tensors have ``batch_size * k`` rows, organised as ``k``
            consecutive groups of ``batch_size`` rows.
            NOTE(review): step 1 writes ``3 * batch_size`` rows into those
            tensors, which looks like it assumes ``k >= 3`` — confirm.
        :param debug: unused in this method
        :param normalization: unused in this method
        :return: the list ``[outputs, samples, entropy_element, w_p_q, log_p,
            [], [], ent_plot]``:

            * ``outputs`` - allocated as zeros and only reordered, never
              filled with model outputs in this method; regrouped so the
              candidate groups sit side by side along dim 1.
            * ``samples`` - chosen token id for every step, regrouped the same
              way.
            * ``entropy_element`` - per batch element, the running sum over
              steps >= 1 of the weight-normalized ``sum(p * log p)`` terms.
            * ``w_p_q``, ``log_p`` - importance weights and candidate
              log-probabilities from the last step; they are only defined when
              the loop runs at least two steps.
            * ``ent_plot`` - same accumulation as ``entropy_element`` without
              the division by the weight sum.
        """
        # program length in this case is the maximum time step that RNN runs

        data, program_len, epoch_n = x

        batch_size = data.size()[0]
        stack = Stack(int(batch_size), k = k, unique_draws=self.unique_draw)
        # mask and imgS are constructed but not referenced again below.
        mask = Mask(unique_draws=self.unique_draw)
        imgS = ImageStack(self.shape_dict.cuda(), int(batch_size), program_len, k = k, unique_draws=self.unique_draw)
        # Zero initial LSTM state for all batch_size * k candidates.
        hidden = Variable(torch.zeros(1, batch_size * k, self.hd_sz)).cuda()
        context = Variable(torch.zeros(1, batch_size * k, self.hd_sz)).cuda()
        # Encode once, then repeat the feature so every candidate group sees
        # the same image feature.
        x_f = self.encoder.encode(data[:, 0:1, :, :])
        x_f = x_f.view(1, batch_size, self.in_sz).repeat(1, k, 1)
        outputs = torch.zeros(batch_size*k, program_len*3).cuda()
        samples = torch.zeros(batch_size*k, program_len*3).cuda()
        # G: perturbed score kept for each candidate; logp: accumulated
        # log-probability of each candidate's token sequence.
        G = torch.zeros(batch_size * k, 1).cuda()
        logp = torch.zeros(batch_size * k, 1).cuda()
        entropy_element = torch.zeros(batch_size).cuda()
        ent_plot = torch.zeros(batch_size).cuda()


        #neg_entropy = torch.zeros((3,)).cuda()

        # The first value popped from the stack seeds the one-hot input of
        # step 0 — presumably the start-token ids; see Stack for details.
        stack.init()
        sample = stack.pop()
        for timestep in range(0, program_len * 3):
            # X_f is the input to the RNN at every time step along with previous
            # predicted label

            # One-hot encode the previously chosen token of every candidate.
            arr = Variable(
                torch.cuda.FloatTensor(batch_size * k, self.num_draws).fill_(0).scatter_(
                    1, torch.cuda.LongTensor(sample).view(-1, 1), 1.0))

            arr = arr.detach()

            temp_input_op = arr

            input_op_rnn = self.relu(self.dense_input_op(temp_input_op))
            input_op_rnn = input_op_rnn.view(1, batch_size * k,
                                             self.input_op_sz)


            x_f = x_f.view(1, batch_size * k, self.in_sz)

            input = torch.cat((x_f, input_op_rnn), 2)

            # Single-step LSTM call; the state is carried over manually.
            out, hc = self.rnn(input, (hidden, context))
            hidden = hc[0]
            context = hc[1]

            hd = self.relu(self.dense_fc_1(self.drop(hidden[0])))
            dense_output = self.dense_output(self.drop(hd))

            # No masking is applied; the logits are used as they are.
            dense_output_mask = dense_output

            output_ = self.logsoftmax(dense_output_mask) # batch size by num draw

            output_probs_ = self.softmax(dense_output_mask)

            # Greedy choice for every row; for steps >= 1 the first
            # num_branch * batch_size rows are overwritten further down.
            sample = torch.max(output_probs_, 1)[1].view(batch_size*k, 1)

            if timestep > 0:
                # Score of each possible extension: step log-prob plus the
                # candidate's accumulated log-prob (no gradient).
                phi = output_.detach() + logp.detach().repeat((1, self.num_draws))
                # Opaque helper from src.utils.Gumbel — presumably perturbs phi
                # with Gumbel noise conditioned on the parent's score G.
                g, argmax_g, g_phi = gumbel_with_maximum(phi, G)
                ##### WITH ADJUSTMENTS #####
                # Split the rows back into per-candidate groups.
                g = torch.split(g, batch_size, dim=0)
                ############################

                # How many unique branches there are
                # NOTE(review): step 1 expands to 3 branches but step 2 only
                # looks at the first 2 groups — confirm this is intended.
                num_unique_branch = timestep if timestep <= 2 else k
                g = g[:num_unique_branch]
                # Side by side: column = group * num_draws + token id.
                g = torch.cat(g, dim=1)
                g_phi = torch.cat(torch.split(g_phi, batch_size, dim=0)[:num_unique_branch], dim=1)

                # How many unique branches there could be expanded from
                num_branch = 3 if timestep == 1 else k
                g_val, index = torch.topk(g, num_branch, dim=1)
                #g_phi_select = torch.zeros_like(g_phi).scatter_(
                #    1, torch.cuda.LongTensor(index), 1.0)
                #g_phi_val = torch.masked_select(g_phi, g_phi_select.byte())
                #g_phi_val = torch.transpose(torch.cat(torch.split(g_phi_val.unsqueeze(1), num_branch, dim=0), dim=1), 0, 1)
                #index = index[:, :-1]
                # Store the winning scores back in the grouped row layout.
                G[:batch_size*num_branch, :] = torch.cat(torch.split(g_val, 1, dim=1), dim=0)

                # deciding on the beam ID and action ID of the selected branching
                # NOTE(review): relies on `/` followed by .long() acting as
                # integer division; this depends on the PyTorch version.
                beam_ind = (index / self.num_draws).long()
                action_ind = (index % self.num_draws).long()

                # solidify the sample outputs
                sample_ = torch.split(action_ind, 1, dim=1)
                sample_ = torch.cat(sample_, dim=0)
                sample[:num_branch*batch_size, :] = sample_

                # rearrange the log probability for further branching
                # `order` maps each new row to the row of its parent candidate:
                # parent group * batch_size + position inside the batch. Rows
                # beyond the expanded ones keep their own index.
                beam_ind = torch.split(beam_ind, 1, dim=1)
                beam_ind = torch.cat(beam_ind, dim=0)
                initial_order = torch.arange(0, batch_size * k).cuda().long()
                order = (beam_ind * batch_size).squeeze(1)  + torch.arange(0, batch_size).repeat((num_branch)).cuda().long()
                initial_order[:batch_size*num_branch] = order
                order = initial_order


                output = torch.index_select(output_, 0, order)
                output_probs = torch.index_select(output_probs_, 0, order)
                logp = torch.index_select(logp, 0, order)
                # Add the log-prob of the token each candidate just chose.
                output_mask = torch.cuda.FloatTensor(batch_size*k, self.num_draws).fill_(0).scatter_(
                    1, torch.cuda.LongTensor(sample).view(-1, 1), 1.0)
                logp += torch.sum(output * output_mask, dim = 1).unsqueeze(1)
                #entropy_element = torch.index_select(entropy_element, 0, order)
                samples = torch.index_select(samples, 0, order)
                outputs = torch.index_select(outputs, 0, order)

                # Entropy calculation

                # sum(p * log p) per row, i.e. the negative entropy of the
                # step distribution of the parent candidate.
                ent = torch.sum((output * output_probs), dim=1).unsqueeze(1)

                # Last column of g_val is used as the threshold for the
                # remaining num_branch - 1 candidates (torch.topk returns
                # values in descending order by default).
                g_k = g_val[:, -1].unsqueeze(1).repeat((1, num_branch-1)).detach()
                phi_k = torch.split(logp, batch_size, dim=0)
                phi_k = phi_k[:num_branch-1]
                phi_k = torch.cat(phi_k, dim=1).detach()
                # Opaque helper from src.utils.Gumbel — presumably the log of
                # the probability that a Gumbel sample exceeds its argument.
                log_q = gumbel_log_survival(g_k - phi_k)
                log_p = torch.cat(torch.split(logp, batch_size, dim=0)[:num_branch-1], dim=1)
                log_p_q = log_p - log_q
                # Importance weights p / q, excluded from the gradient.
                w_p_q = torch.exp(log_p_q).detach()
                W = torch.sum(w_p_q, dim=1)
                ent = torch.split(ent, batch_size, dim=0)
                ent = ent[:num_branch-1]
                ent = torch.cat(ent, dim=1)
                # normalization weights for the entropy
                ent_plot += torch.sum(w_p_q * ent, dim=1)
                ent = torch.sum(w_p_q * ent, dim=1) / W
                entropy_element += ent

                # rnn state rearrange
                # In-place so that each candidate continues from its parent's
                # LSTM state at the next step.
                hidden[0] = torch.index_select(hidden[0], 0, order)
                context[0] = torch.index_select(context[0], 0, order)

            else:
                # Step 0: nothing is reordered and logp is left at zero.
                output = output_
                output_probs = output_probs_
                #entropy_element[:, 0] = torch.sum((output * output_probs) * mask.get_mask(pop_sym), dim=1)

            # Stopping the gradient to flow backward from samples
            sample = sample.detach()
            samples[:, timestep] = sample.squeeze(1)


            #sample = sample.cpu()

        # Regroup from (k * batch_size, steps) to (batch_size, k * steps): the
        # candidate groups are placed next to each other along dim 1.
        samples = torch.cat(torch.split(samples, batch_size, dim=0), dim=1)
        outputs = torch.cat(torch.split(outputs, batch_size, dim=0), dim=1)

        return [outputs, samples, entropy_element, w_p_q, log_p, [], [], ent_plot]



class ParseModelOutput:
    """Convert label sequences produced by the network into programs.

    Offers helpers to turn label ids into expression strings and to execute
    expressions on the stack simulator to obtain their final canvases.
    """

    def __init__(self, unique_draws: List, stack_size: int, steps: int,
                 canvas_shape: List):
        """
        This class parses complete output from the network which are in joint
        fashion. This class can be used to generate final canvas and
        expressions.
        :param unique_draws: Unique draw/op operations in the current dataset;
            must contain the "EOP" token
        :param stack_size: Stack size handed to the stack simulator
        :param steps: Number of steps in the program (stored only)
        :param canvas_shape: Shape of the canvases
        :raises ValueError: if "EOP" is not in ``unique_draws``
        """
        self.canvas_shape = canvas_shape
        self.stack_size = stack_size
        self.steps = steps
        self.Parser = Parser()
        self.sim = SimulateStack(self.stack_size, self.canvas_shape)
        self.unique_draws = unique_draws
        # self.pytorch_version = torch.__version__[2]

        # Labels below this index are appended to expressions; "EOP" and
        # everything after it is never rendered.
        self.n_T = unique_draws.index("EOP")

    def get_final_canvas(self,
                         outputs: List,
                         if_just_expressions=False,
                         if_pred_images=False):
        """Not implemented yet; always returns ``None``."""
        # TODO
        return

    def expression2stack(self, expressions: List):
        """Assuming all the expression are correct and coming from
        groundtruth labels. Helpful in visualization of programs
        :param expressions: List, each element an expression of program
        :return: float32 array stacking, for every expression, the first
            entry of the simulator's final stack state
        """
        stacks = []
        for exp in expressions:
            program = self.Parser.parse(exp)
            self.sim.generate_stack(program)
            # Keep only the top entry of the last recorded stack state.
            stacks.append(np.array(self.sim.stack_t[-1][0:1]))
        return np.stack(stacks, 0).astype(dtype=np.float32)

    def labels2exps(self, labels: np.ndarray, steps: int):
        """
        Assuming groundtruth labels, we want to find expressions for them
        :param labels: Groundtruth labels, batch_size x time_steps; either a
            numpy array or an object exposing ``.data.cpu().numpy()``
        :param steps: unused; kept for interface compatibility
        :return: ``(expressions, p_len)`` where ``expressions[j]`` is the
            concatenation of the tokens of row ``j`` up to the stop symbol
            '$' and ``p_len[j]`` is the position of that stop symbol, or 5
            when the row contains none
        :raises ValueError: if a label at or past "EOP" is met and '$' is not
            in ``unique_draws``
        """
        if not isinstance(labels, np.ndarray):
            labels = labels.data.cpu().numpy()

        # TODO replace the pre specified value
        default_len = 5
        # Looked up on first use only, so a vocabulary without '$' still
        # works as long as every label is below n_T.
        stop_id = None

        expressions = []
        p_len = []
        for row in labels:
            tokens = []
            length = default_len
            for position, label in enumerate(row):
                if label < self.n_T:
                    tokens.append(self.unique_draws[int(label)])
                    continue
                if stop_id is None:
                    stop_id = self.unique_draws.index('$')
                if label == stop_id:
                    length = position
                    break
                # Any other label at or past n_T (e.g. "EOP") is skipped.
            expressions.append("".join(tokens))
            p_len.append(length)

        return expressions, p_len


def validity(program: List, max_time: int, timestep: int):
    """
    Decide whether a (possibly partial) postfix program is well formed.

    Works like a pushdown automaton: every "draw" pushes one operand and
    every "op" combines two operands into one, so the number of draws must
    stay strictly ahead of the number of ops after every instruction.
    :param program: List of dictionaries, each with a "type" key of "draw",
        "op" or "stop"
    :param max_time: Max allowed length of program
    :param timestep: Current timestep of the program, i.e. the index of the
        latest instruction
    :return: True if the program is acceptable so far, False otherwise. A
        program is only accepted at a "stop" symbol or at the last allowed
        timestep when it has exactly one more draw than ops.
    """
    draws = 0
    ops = 0
    for instruction in program:
        kind = instruction["type"]

        if kind == "stop":
            # Nothing after the stop symbol matters. A balanced program has
            # exactly one draw more than ops, within the draw budget implied
            # by the program's own length.
            draw_budget = (len(program) - 1) // 2 + 1
            return draws <= draw_budget and draws - 1 == ops

        if kind == "draw":
            # draw a shape on canvas kind of operation
            draws += 1
        elif kind == "op":
            # +, *, - kind of operation
            ops += 1

        # Fewer than two operands for an op, or more operands than a program
        # of max_time instructions could ever consume (stack overflow).
        if draws <= ops or draws > (max_time // 2 + 1):
            return False

    if timestep == (max_time - 1):
        # Out of time: the program has to be complete right now.
        return draws - 1 == ops
    return True
