import itertools
from typing import List
from torch import nn
from torch import optim

import numpy as np
import torch

from infrastructure import pytorch_util as ptu
from networks.gnn import GCNNet, GINE
class ValueCritic(nn.Module):
    """Value network mapping graph observations to scalar value estimates.

    Every observation is encoded by a ``GINE`` graph network, the stacked
    encodings are batch-normalised, and an MLP head emits one value per
    observation. The module owns its own Adam optimizer, used by ``update``.
    """

    def __init__(
        self,
        ob_feature_dim: int,
        embed_dim: int,
        n_gcn_layers: int,
        n_layers: int,
        layer_size: int,
        learning_rate: float,
    ):
        """Build the graph encoder, normalisation layer, MLP head and optimizer.

        Args:
            ob_feature_dim: Forwarded as the first argument of ``GINE``
                (presumably the per-node input feature width -- confirm
                against ``networks.gnn.GINE``).
            embed_dim: Forwarded to ``GINE`` as both its second and third
                argument; also stored as ``self.gcn_out_dim``.
            n_gcn_layers: Forwarded to ``GINE``; the MLP head and the batch
                norm are sized ``embed_dim * n_gcn_layers``.
            n_layers: The MLP head is built with ``2 * n_layers`` layers.
            layer_size: ``size`` argument handed to ``ptu.build_mlp``.
            learning_rate: Adam learning rate.
        """
        super().__init__()
        self.gcn_out_dim = embed_dim
        self.gcn = GINE(ob_feature_dim, embed_dim, self.gcn_out_dim, n_gcn_layers)

        # The head consumes gcn_out_dim * n_gcn_layers features per graph
        # (presumably GINE concatenates its per-layer outputs -- TODO confirm).
        head_input_dim = self.gcn_out_dim * n_gcn_layers
        self.norm = nn.BatchNorm1d(head_input_dim).to(ptu.device)
        self.network = ptu.build_mlp(
            input_size=head_input_dim,
            output_size=1,
            n_layers=n_layers * 2,
            size=layer_size,
        ).to(ptu.device)

        # Do NOT store the parameter iterator on ``self.parameters``: that
        # would shadow ``nn.Module.parameters`` with a one-shot iterator that
        # the optimizer exhausts, breaking any later ``critic.parameters()``.
        # NOTE(review): the BatchNorm affine parameters (``self.norm``) are
        # not handed to the optimizer, so they stay at their initial values.
        # Kept as in the original; confirm whether this is intentional.
        self.optimizer = optim.Adam(
            itertools.chain(self.gcn.parameters(), self.network.parameters()),
            lr=learning_rate,
        )

    def forward(self, obs: List[np.ndarray]) -> torch.Tensor:
        """Return the value estimate for each observation.

        Args:
            obs: Iterable of graph observations. Each item must expose
                ``node_type``, ``edge_index`` and ``edge_attr`` attributes
                (despite the annotation, plain arrays do not qualify).

        Returns:
            The MLP head's output for the batch-normalised graph embeddings,
            one row per observation in input order.

        Raises:
            ValueError: If ``obs`` is empty.
        """
        # Encode every graph independently; dropout is disabled for the critic.
        graph_embeddings = [
            self.gcn(
                o.node_type.to(torch.float32).squeeze(),
                o.edge_index,
                o.edge_attr,
                dropout=0,
            )
            for o in obs
        ]
        if not graph_embeddings:
            raise ValueError("ValueCritic.forward requires at least one observation")

        # Concatenate once instead of growing a tensor inside the loop, which
        # re-copied all previously accumulated rows on every iteration.
        embeddings = self.norm(torch.cat(graph_embeddings, dim=0))
        # Then compute the multilayer perceptron value.
        return self.network(embeddings)

    def update(self, obs: List[np.ndarray], q_values: np.ndarray) -> dict:
        """Take one optimizer step regressing predicted values onto ``q_values``.

        Args:
            obs: Observations, in the format accepted by ``forward``.
            q_values: Regression targets, one per observation.

        Returns:
            A dict with the key ``"Baseline Loss"`` holding the loss of this
            step (half the mean squared error) converted by ``ptu.to_numpy``.
        """
        # Drop only the trailing unit dimension so the batch axis is preserved.
        predictions = self(obs).squeeze(-1)
        # Align the targets with the predictions explicitly: an (N, 1) target
        # against (N,) predictions would otherwise broadcast to (N, N) and
        # silently produce a wrong loss.
        targets = ptu.from_numpy(q_values).reshape(predictions.shape)

        loss = torch.mean((predictions - targets) ** 2 / 2)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {
            "Baseline Loss": ptu.to_numpy(loss.detach()),
        }
   