import torch
import torch.nn as nn
from src.model.mlp import MLP
from torch.distributions import Categorical
import numpy as np

"""
Categorical model 1: left/right edge and root edge models.
"""


# class EdgesModelCategorical(nn.Module):
#     def __init__(self, edge_cat_cfg):
#         super(EdgesModelCategorical, self).__init__()
#         self.root_edge_model = nn.Linear(384, 40)#MLP(edge_cat_cfg.ROOT_EDGE_HEAD)
#         # self.lr_model = nn.Linear(3, 5) #MLP(edge_cat_cfg.HEAD)
#
#     def forward(self, root_edges_reps):
#         root_edges_logits = self.root_edge_model(root_edges_reps)
#         return root_edges_logits

class EdgesModelCategorical(nn.Module):
    """Categorical policy over edge actions, with one head per step kind.

    Rows whose ``batch_nb_seq`` equals 2 are scored by ``root_edge_model``
    (the last-step / root edge head). Rows whose ``batch_nb_seq`` is greater
    than 2 are scored by ``lr_model`` (the head for the first n-1 steps, i.e.
    left and right edges). Rows with any other ``batch_nb_seq`` value get
    action 0 and log-probability 0.
    """

    def __init__(self, edge_cat_cfg):
        super().__init__()
        # model last step edge (used when batch_nb_seq == 2)
        self.root_edge_model = MLP(edge_cat_cfg.ROOT_EDGE_HEAD)

        # model for first n-1 steps left and right edges (used when batch_nb_seq > 2)
        self.lr_model = MLP(edge_cat_cfg.HEAD)
        self.logsoftmax = torch.nn.LogSoftmax(dim=1)
        # Action ids of each head; used for uniform random exploration in `sample`.
        self.lr_actions_list = np.arange(edge_cat_cfg.HEAD.OUTPUT_SIZE)
        self.root_edge_actions_list = np.arange(edge_cat_cfg.ROOT_EDGE_HEAD.OUTPUT_SIZE)

    def forward(self, summary_reps, left_trees, right_trees, input_dict):
        """Sample an edge action per row and compute its log-probability.

        Args:
            summary_reps, left_trees, right_trees: per-row representations,
                concatenated along dim 1 to form the head input
                (presumably each is (batch, dim_i) -- TODO confirm).
            input_dict: must contain ``'batch_nb_seq'``, a tensor with one
                entry per row. Optional keys:
                ``'random_spec'`` -- forwarded to :meth:`sample`;
                ``'input_edge_actions'`` -- if given, overwrites the first
                ``len(input_edge_actions)`` entries of ``edge_actions``
                before log-probabilities are computed.

        Returns:
            dict with ``'edge_actions'`` (long tensor, one entry per row) and
            ``'log_paths_pf'`` (log-probability of each row's action under
            its head). For each head that had at least one row, it also holds
            ``'first_edges_actions'``/``'first_edges_ret'`` or
            ``'root_edges_actions'``/``'root_edges_ret'``. The per-head
            ``*_actions`` entries are the sampled actions, before any
            ``input_edge_actions`` override.
        """
        random_spec = input_dict.get('random_spec', None)
        input_edge_actions = input_dict.get('input_edge_actions', None)

        # for now the representation is the concatenation of summary, left and right tree
        rep = torch.cat([summary_reps, left_trees, right_trees], dim=1)
        batch_nb_seq = input_dict['batch_nb_seq']
        root_edges_flag = batch_nb_seq == 2
        first_edges_flag = batch_nb_seq > 2

        ret = {}
        # Explicit dtype/device: `.long().to(batch_nb_seq)` would also copy
        # batch_nb_seq's dtype (e.g. int32), silently undoing `.long()`.
        edge_actions = torch.zeros(
            len(batch_nb_seq), dtype=torch.long, device=batch_nb_seq.device
        )

        if first_edges_flag.any():
            first_edges_ret = {'logits': self.lr_model(rep[first_edges_flag])}
            actions = self.sample(first_edges_ret, self.lr_actions_list, random_spec)
            ret['first_edges_actions'] = actions
            edge_actions[first_edges_flag] = actions
            ret['first_edges_ret'] = first_edges_ret

        if root_edges_flag.any():
            root_edges_ret = {'logits': self.root_edge_model(rep[root_edges_flag])}
            actions = self.sample(root_edges_ret, self.root_edge_actions_list, random_spec)
            ret['root_edges_actions'] = actions
            edge_actions[root_edges_flag] = actions
            ret['root_edges_ret'] = root_edges_ret

        if input_edge_actions is not None:
            # Externally provided actions take precedence so their
            # log-probabilities are what gets scored below.
            edge_actions[:len(input_edge_actions)] = input_edge_actions
        ret['edge_actions'] = edge_actions
        ret['log_paths_pf'] = self.compute_log_path_pf(ret, batch_nb_seq, edge_actions)
        return ret

    def sample(self, ret, actions_list, random_spec):
        """Sample one action per row of ``ret['logits']``.

        ``random_spec`` selects the strategy:
          * ``None`` -- plain sampling (same as ``random_action_prob=0``).
          * has ``'random_action_prob'`` (p) -- sample from the logits, then
            independently replace each row's action, with probability p, by
            an action drawn uniformly from ``actions_list``.
          * otherwise it must have ``'T'`` -- sample from ``logits / T``
            (raises KeyError if missing).

        Returns:
            Tensor of sampled actions, one per row of the logits.
        """
        logits = ret['logits']
        if random_spec is None:
            random_spec = {
                'random_action_prob': 0.0
            }
        if 'random_action_prob' in random_spec:
            random_p = random_spec['random_action_prob']
            edge_action = Categorical(logits=logits).sample()
            if random_p > 0:
                batch_size = edge_action.shape[0]
                # Draw on CPU (keeps the RNG stream unchanged), then move the
                # mask to the actions' device before the masked assignment.
                rand_flag = (torch.empty(batch_size).uniform_(0, 1) <= random_p).to(edge_action.device)
                rand_num = rand_flag.sum().item()
                if rand_num > 0:
                    rand_actions = torch.tensor(np.random.choice(actions_list, rand_num)).to(edge_action)
                    edge_action[rand_flag] = rand_actions
        else:
            T = random_spec['T']
            edge_action = Categorical(logits=logits / T).sample()
        return edge_action

    def compute_log_path_pf(self, ret, batch_nb_seq, edge_actions):
        """Return the log-probability of ``edge_actions`` under each row's head.

        Args:
            ret: dict holding ``'first_edges_ret'`` / ``'root_edges_ret'``
                (each with ``'logits'``) for every head with at least one row.
            batch_nb_seq: per-row tensor used to route rows to heads
                (> 2 -> lr head, == 2 -> root head).
            edge_actions: per-row action indices.

        Returns:
            Float tensor with one entry per row. Rows matched by neither head
            stay 0.
        """
        root_edges_flag = batch_nb_seq == 2
        first_edges_flag = batch_nb_seq > 2
        log_paths_pf = torch.zeros(len(batch_nb_seq), device=edge_actions.device)

        if first_edges_flag.any():
            pf = self._selected_log_prob(
                ret['first_edges_ret']['logits'], edge_actions[first_edges_flag]
            )
            log_paths_pf[first_edges_flag] = log_paths_pf[first_edges_flag] + pf

        if root_edges_flag.any():
            pf = self._selected_log_prob(
                ret['root_edges_ret']['logits'], edge_actions[root_edges_flag]
            )
            log_paths_pf[root_edges_flag] = log_paths_pf[root_edges_flag] + pf

        return log_paths_pf

    def _selected_log_prob(self, logits, actions):
        """Return ``log_softmax(logits)[i, actions[i]]`` for every row ``i``."""
        log_p = self.logsoftmax(logits)
        # gather needs a long index on log_p's device; avoids a CPU arange.
        index = actions.to(device=log_p.device, dtype=torch.long).unsqueeze(1)
        return log_p.gather(1, index).squeeze(1)
