from typing import Optional, List
import itertools
import numpy as np

import torch.nn.functional as F
from torch import nn

import torch
from torch import distributions, optim

from infrastructure import pytorch_util as ptu
from infrastructure.distributions import make_tanh_transformed, make_multi_normal
from networks.gnn import GCNNet, GINE
from torch_geometric.nn import global_mean_pool, global_add_pool
from torch_geometric.data import Data, Batch
class GCNNestedPolicyPG(nn.Module):
    """
    Graph policy with five autoregressive categorical action heads.

    A GINE encoder embeds the observation graphs; five MLP heads then produce
    logits. Head k also receives the logits of the heads before it, except
    that heads 3 and 4 are siblings which both see the embedding plus heads
    0-2. `forward()` returns the raw logits, `get_action()` samples from them
    and `evaluate()` scores stored actions for policy-gradient updates.
    """

    def __init__(
        self,
        ac_dim: List[int],
        ob_feature_dim: int,
        embed_dim: int,
        discrete: bool,
        n_gcn_layers: int,
        n_layers: int,
        layer_size: int,
        use_tanh: bool = False,
        state_dependent_std: bool = False,
        fixed_std: Optional[float] = None,
    ):
        """
        Args:
            ac_dim: number of choices of each action head; ``ac_dim[0]`` to
                ``ac_dim[4]`` are used.
            ob_feature_dim: node feature size given to the GINE encoder.
            embed_dim: hidden/output size of the GINE encoder.
            discrete, use_tanh, state_dependent_std, fixed_std: stored on the
                instance but not used by the code of this class.
            n_gcn_layers: number of encoder layers. The MLP input size is
                ``embed_dim * n_gcn_layers`` (presumably GINE concatenates one
                embedding per layer -- TODO confirm).
            n_layers, layer_size: depth and width of every MLP head.
        """
        super().__init__()
        self.num_actions = len(ac_dim)

        self.use_tanh = use_tanh
        self.discrete = discrete
        self.state_dependent_std = state_dependent_std
        self.fixed_std = fixed_std

        self.gcn_out_dim = embed_dim
        self.gcn = GINE(ob_feature_dim, embed_dim, self.gcn_out_dim, n_gcn_layers).to(ptu.device)

        graph_dim = self.gcn_out_dim * n_gcn_layers
        # Not used by forward(); kept so the module's state_dict keys are unchanged.
        self.norm = nn.LayerNorm(graph_dim).to(ptu.device)

        # Heads are created in order 0..4 so weight initialisation draws from
        # the RNG in the same order as before.
        self.nn0 = self._build_head(graph_dim, ac_dim[0], n_layers, layer_size)
        self.nn1 = self._build_head(graph_dim + ac_dim[0], ac_dim[1], n_layers, layer_size)
        self.nn2 = self._build_head(graph_dim + sum(ac_dim[0:2]), ac_dim[2], n_layers, layer_size)
        # nn3 and nn4 take the same input: embedding + logits of heads 0-2.
        self.nn3 = self._build_head(graph_dim + sum(ac_dim[0:3]), ac_dim[3], n_layers, layer_size)
        self.nn4 = self._build_head(graph_dim + sum(ac_dim[0:3]), ac_dim[4], n_layers, layer_size)

    @staticmethod
    def _build_head(input_size: int, output_size: int, n_layers: int, size: int) -> nn.Module:
        """Build one MLP action head and move it to the policy device."""
        return ptu.build_mlp(
            input_size=input_size,
            output_size=output_size,
            n_layers=n_layers,
            size=size,
        ).to(ptu.device)

    @staticmethod
    def _to_graph(o) -> "Data":
        """
        Convert one observation into a torch_geometric ``Data`` on the policy device.

        ``o.node_type`` is used as node features. An observation whose
        ``edge_index`` is None gets a single all-zero self-loop on node 0 so
        the encoder can still run.
        """
        edge_index, edge_attr = o.edge_index, o.edge_attr
        if edge_index is None:
            # HACK: placeholder edge. The 7 edge features must match the
            # encoder's edge feature size -- TODO confirm against GINE / env.
            edge_index = torch.tensor([[0], [0]], dtype=torch.long)
            edge_attr = torch.zeros((1, 7), dtype=torch.float32)
        return Data(
            # NOTE(review): squeeze() may also drop the node axis for a
            # single-node graph -- confirm node_type's shape.
            x=o.node_type.to(ptu.device).to(torch.float32).squeeze(),
            edge_index=edge_index.to(ptu.device),
            edge_attr=edge_attr.to(torch.float32).to(ptu.device),
        )

    def forward(self, obs) -> List[torch.Tensor]:
        """
        Compute the logits of all five action heads for a batch of graphs.

        Args:
            obs: sequence of graph observations exposing ``node_type``,
                ``edge_index`` and ``edge_attr``.

        Returns:
            ``[ac0, ac1, ac2, ac3, ac4]``: one differentiable 2-D logits tensor
            per head (presumably one row per observation, since the encoder
            receives the batch vector).
        """
        batch = Batch.from_data_list([self._to_graph(o) for o in obs])
        embeddings = self.gcn(batch.x, edge_index=batch.edge_index,
                              edge_attr=batch.edge_attr, dropout=0, batch=batch.batch)

        # Autoregressive chain: each head sees the embedding plus the logits
        # of the previous heads.
        ac0 = self.nn0(embeddings)
        features = torch.cat((embeddings, ac0), dim=1)
        ac1 = self.nn1(features)
        # NOTE(author): heads 1/2 were observed to stop learning -- investigate.
        features = torch.cat((features, ac1), dim=1)
        ac2 = self.nn2(features)
        features = torch.cat((features, ac2), dim=1)
        ac3 = self.nn3(features)
        ac4 = self.nn4(features)
        return [ac0, ac1, ac2, ac3, ac4]

    @staticmethod
    def _masked_categorical(logits: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Build a Categorical over ``logits``, optionally restricted by ``mask``.

        The mask is added in log space, so zero entries become -inf logits and
        are never sampled (masks are presumably 0/1). ``mask=None`` means no
        restriction.
        """
        if mask is None:
            return torch.distributions.categorical.Categorical(logits=logits)
        return torch.distributions.categorical.Categorical(logits=logits + torch.log(mask))

    def _sample_masked(self, logits: torch.Tensor, np_mask) -> np.ndarray:
        """Sample one head under a numpy mask and return the sample as numpy."""
        dist = self._masked_categorical(logits, ptu.from_numpy(np_mask))
        return ptu.to_numpy(dist.sample())

    @torch.no_grad()
    def get_action(self, obs, mask=None, verbose=False, sym=None) -> List[np.ndarray]:
        """
        Sample one action for a single observation.

        Args:
            obs: one graph observation (must support ``.to(device)``).
            mask: ``None`` (no masking), a numpy array applied to heads 0 and 1,
                or a mask object with ``get_mask_from/get_mask_to/get_mask_edge/
                get_mask_sym/get_mask_term`` whose masks depend on the heads
                already sampled.
            verbose: print shapes and the sampled action (array/None mask only).
            sym: forwarded to the mask object's ``get_mask_to``,
                ``get_mask_sym`` and ``get_mask_term``.

        Returns:
            With ``mask`` None or an array: a list of five numpy samples, one
            per head. With a mask object: ``(action, masks)``, where ``masks``
            are deep copies of the five masks used, ordered like the heads.
        """
        ac = self([obs.to(ptu.device)])

        if mask is None or isinstance(mask, (np.ndarray, np.generic)):
            # With no mask the logits are used as-is; this also works when
            # heads 0 and 1 have different sizes.
            shared_mask = None if mask is None else ptu.from_numpy(mask)
            if verbose:
                print(ac[0].shape, ac[2].shape, ac[3].shape, None if mask is None else mask.shape)
            dists = [
                self._masked_categorical(ac[0], shared_mask),
                self._masked_categorical(ac[1], shared_mask),
                self._masked_categorical(ac[2]),
                self._masked_categorical(ac[3]),
                self._masked_categorical(ac[4]),
            ]
            action = [ptu.to_numpy(dist.sample()) for dist in dists]
            if verbose:
                print(action)
            return action

        import copy

        # Sequential sampling: each mask depends on the heads sampled so far.
        # Head 4 (sym) is sampled before head 3 (term) because the term mask
        # needs act4.
        mask_from = mask.get_mask_from()
        act0 = self._sample_masked(ac[0], mask_from)
        mask_to = mask.get_mask_to(act0[0], sym)
        act1 = self._sample_masked(ac[1], mask_to)
        mask_edges = mask.get_mask_edge(act0[0], act1[0])
        act2 = self._sample_masked(ac[2], mask_edges)
        mask_sym = mask.get_mask_sym(act0[0], act1[0], act2[0], sym)
        act4 = self._sample_masked(ac[4], mask_sym)
        mask_term = mask.get_mask_term(act0[0], act1[0], act2[0], act4[0], sym)
        act3 = self._sample_masked(ac[3], mask_term)

        action = [act0, act1, act2, act3, act4]
        masks = copy.deepcopy([mask_from, mask_to, mask_edges, mask_term, mask_sym])
        return action, masks

    @staticmethod
    def _action_column(actions, k: int) -> torch.Tensor:
        """
        Return head ``k`` of a batch of stored actions as a 1-D tensor.

        ``reshape(-1)`` keeps a length-1 axis for a single-sample batch
        (``squeeze`` would produce an un-iterable 0-d tensor).
        """
        column = actions[np.arange(len(actions)), k]
        return torch.tensor(np.array(column)).reshape(-1)

    @staticmethod
    def _print_samples(acts, log_probs, reward=None, valid=None):
        """Print each sample's action and per-head probabilities (debug aid)."""
        batch_size = len(acts[0])
        rewards = [None] * batch_size if reward is None else reward
        valids = [None] * batch_size if valid is None else valid
        probs = [torch.exp(lp) for lp in log_probs]  # computed once, not per sample
        for idx, (v, r, *row) in enumerate(zip(valids, rewards, *acts)):
            action = [int(a) for a in row]
            prob = [float(p[idx]) for p in probs]
            if valid is None:
                print("reward: {}, action: {}, prob: {}".format(r, action, prob))
            else:
                print("valid: {}, reward: {}, action: {}, prob: {}".format(v, r, action, prob))

    def evaluate(self, obs: List[np.ndarray], actions: List[np.ndarray], mask=None, reward=None, valid=None,
                 verbose=False):
        """
        Score stored actions under the current policy.

        Args:
            obs: batch of graph observations (see ``forward``).
            actions: array indexable as ``actions[i, k]`` (sample ``i``,
                head ``k``); a plain list is converted with ``np.asarray``.
            mask: ``None`` (no masking), one numpy array applied to heads 0
                and 1, or per-sample mask lists as returned by ``get_action``
                (indexed ``mask[i][k]``).
            reward, valid: only used for verbose printing.
            verbose: print every sample's action and per-head probabilities.

        Returns:
            ``(logp, dist_entropy)``: per-sample sums over the five heads of
            the action log-probabilities and of the distribution entropies.
        """
        ac_dist = self(obs)
        if isinstance(actions, list):
            actions = np.asarray(actions)
        batch_size = len(actions)
        acts = [self._action_column(actions, k).to(ptu.device) for k in range(5)]

        if mask is None or isinstance(mask, (np.ndarray, np.generic)):
            shared_mask = None if mask is None else ptu.from_numpy(mask)
            head_masks = [shared_mask, shared_mask, None, None, None]
        else:
            # Stack the per-sample masks of each head into (batch, n_choices);
            # reshape (not squeeze) so a (1, K) mask or K == 1 cannot broadcast
            # into a (batch, batch) tensor.
            head_masks = [
                torch.as_tensor(np.stack([mask[i][k] for i in range(batch_size)]))
                .reshape(batch_size, -1)
                .to(ptu.device)
                for k in range(5)
            ]

        dists = [self._masked_categorical(logits, m) for logits, m in zip(ac_dist, head_masks)]
        log_probs = [dist.log_prob(act) for dist, act in zip(dists, acts)]
        if verbose:
            self._print_samples(acts, log_probs, reward, valid)

        logp = sum(log_probs)
        dist_entropy = sum(dist.entropy() for dist in dists)
        return logp, dist_entropy

