import copy

import torch
import torch.nn as nn

from gcip.modules.gnn.pooling import GraphPooling
from gcip.utils.activations import get_act_fn
from ..distributions import MultiGraphBernoulliDistribution, MultiGraphContBernoulliDistribution

import gcip.utils.io as pb_io
from gcip.modules.mlp import MLP


class GraphActorCritic(nn.Module):
    """Actor-Critic architecture for graph-based reinforcement learning.

    The actor and the critic each own a graph neural network. The actor uses
    ``gnn`` directly and the critic uses an independent deep copy of it, so
    the two feature extractors do not share weights. The actor applies a
    single linear layer to obtain one action logit per node or per edge:

    * ``action_refers_to='node'``: the layer sees each node embedding.
    * ``action_refers_to='edge'``: the layer sees the concatenated embeddings
      of each edge's two endpoints.

    The critic pools its node embeddings per graph and applies a single
    linear layer, giving one state value per graph.

    Args:
        gnn (nn.Module): Graph neural network used as the actor's feature
            extractor. It must expose ``output_size`` and ``device``.
        action_refers_to (str): ``'node'`` or ``'edge'``. Selects whether one
            action is produced per node or per edge.
        pool_type (str): Type of graph pooling used by the critic (forwarded
            to ``GraphPooling``).
        action_distr (str): ``'ber'`` for a Bernoulli or ``'cb'`` for a
            continuous Bernoulli action distribution.
        act_fn (str): Name of the activation (resolved with ``get_act_fn``).
            It is applied to the GNN outputs and also passed to the pooling
            layer.
        bn (bool, optional): Currently unused, because the heads are single
            linear layers. Kept for interface compatibility. Defaults to
            False.
        dropout (float, optional): Currently unused; see ``bn``.
            Defaults to 0.0.
        init_fn (callable, optional): If given, applied to every submodule
            via ``nn.Module.apply``. Defaults to None.

    Raises:
        ValueError: If ``action_refers_to`` or ``action_distr`` is not one of
            the supported values.
    """

    # Constant subtracted from every action logit produced by the actor head.
    # NOTE(review): presumably meant to bias the initial policy towards
    # "off" actions; the exact effect depends on the distribution classes.
    ACTION_LOGIT_SHIFT = 2.0

    def __init__(self, gnn,
                 action_refers_to,
                 pool_type,
                 action_distr,
                 act_fn,
                 bn=False,
                 dropout=0.0,
                 init_fn=None):
        super(GraphActorCritic, self).__init__()

        # Explicit checks (not ``assert``) so validation survives ``python -O``.
        if action_refers_to not in ('node', 'edge'):
            raise ValueError(
                f"action_refers_to must be 'node' or 'edge', got {action_refers_to!r}")
        if action_distr not in ('ber', 'cb'):  # Bernoulli or continuous Bernoulli
            raise ValueError(
                f"action_distr must be 'ber' or 'cb', got {action_distr!r}")

        hidden_channels = gnn.output_size

        self.action_refers_to = action_refers_to

        self.gnn_actor = gnn
        self.device = gnn.device

        # Independent copy: the critic's feature extractor does not share
        # weights with the actor's.
        self.gnn_critic = copy.deepcopy(gnn)

        # Edge actions see the concatenated embeddings of both endpoints.
        actor_in = hidden_channels if action_refers_to == 'node' else 2 * hidden_channels

        # NOTE: ``bn`` and ``dropout`` are unused while the heads are single
        # linear layers (an earlier variant used 2-layer ``MLP`` heads).
        # The ``nn.Sequential`` wrappers keep parameter names such as
        # ``actor.0.weight`` stable for saved state dicts.
        self.actor = nn.Sequential(
            nn.Linear(actor_in, 1, device=self.device),
        )
        self.critic = nn.Sequential(
            nn.Linear(hidden_channels, 1, device=self.device),
        )

        self.act_fn = get_act_fn(act_fn)

        self.graph_pooling = GraphPooling(pool_type=pool_type,
                                          in_channels=hidden_channels,
                                          activation=get_act_fn(act_fn),
                                          bn=False)

        if action_distr == 'ber':
            self.action_distr = MultiGraphBernoulliDistribution()
        else:
            self.action_distr = MultiGraphContBernoulliDistribution()

        if init_fn is not None:
            self.apply(init_fn)

    def _get_attr(self, batch, attr, refers_to='node'):
        """Return ``getattr(batch, attr)``, re-indexed per edge when requested.

        When ``attr == 'batch'`` and ``refers_to == 'edge'``, the node-indexed
        ``batch`` vector is gathered at each edge's source node
        (``edge_index[0]``). This yields one graph index per edge.
        """
        value = getattr(batch, attr)
        if attr == 'batch' and refers_to == 'edge':
            value = value[batch.edge_index[0]]
        return value

    def _get_edge_attr(self, state):
        """Return ``state.edge_attr``, else ``state.edge_feature``, else None.

        NOTE(review): if ``state`` is a PyG ``Data`` object, ``hasattr(state,
        'edge_attr')`` may be True even when the attribute is unset (it can
        read as None). That would make the ``edge_feature`` fallback
        unreachable; confirm.
        """
        if hasattr(state, 'edge_attr'):
            return self._get_attr(state, 'edge_attr')
        if hasattr(state, 'edge_feature'):
            return self._get_attr(state, 'edge_feature')
        return None

    def actor_params(self):
        """Return the actor head's and the actor GNN's parameters as an ``nn.ParameterList``."""
        return nn.ParameterList([*self.actor.parameters(), *self.gnn_actor.parameters()])

    def critic_params(self):
        """Return the critic head's and the critic GNN's parameters as an ``nn.ParameterList``."""
        return nn.ParameterList([*self.critic.parameters(), *self.gnn_critic.parameters()])

    def feature_extractor_actor(self, batch, **kwargs):
        """Run the actor GNN on a clone of ``batch`` and apply ``self.act_fn``.

        The GNN receives ``batch.clone()`` rather than the caller's object,
        presumably to guard against in-place modification (confirm).
        """
        return self.act_fn(self.gnn_actor(batch.clone(), **kwargs))

    def feature_extractor_critic(self, batch, **kwargs):
        """Run the critic GNN on a clone of ``batch`` and apply ``self.act_fn``."""
        return self.act_fn(self.gnn_critic(batch.clone(), **kwargs))

    def forward(self):
        raise NotImplementedError

    def get_action_logits_from_features(self, z, edge_index):
        """Map node embeddings ``z`` to one action logit per node or per edge.

        For edge actions, each edge's input is ``cat(z[src], z[dst])`` along
        dim 1. ``ACTION_LOGIT_SHIFT`` is subtracted from every logit.
        """
        if self.action_refers_to == 'edge':
            src, dst = edge_index[0], edge_index[1]
            z = torch.cat([z[src], z[dst]], dim=1)
        return self.actor(z) - self.ACTION_LOGIT_SHIFT

    def compute_action_logits(self, state):
        """Compute the actor's action logits for ``state``."""
        edge_index = self._get_attr(state, 'edge_index')
        z_actor = self.feature_extractor_actor(batch=state)
        return self.get_action_logits_from_features(z=z_actor, edge_index=edge_index)

    def _sample_actions(self, action_logits, batch_idx, sample):
        """Parameterize ``self.action_distr`` and draw actions.

        Returns ``sample()`` when ``sample`` is truthy, otherwise ``mode()``.
        """
        self.action_distr.proba_distribution(action_logits=action_logits,
                                             batch_idx=batch_idx)
        return self.action_distr.sample() if sample else self.action_distr.mode()

    def act(self, state, sample=True):
        """Choose actions for ``state``.

        Returns:
            tuple: ``(actions, action_logprobs)``, both detached.
        """
        state = state.to(self.device)
        action_logits = self.compute_action_logits(state=state)
        batch = self._get_attr(state, 'batch', refers_to=self.action_refers_to)
        actions = self._sample_actions(action_logits, batch, sample)
        action_logprobs = self.action_distr.log_prob(actions)
        return actions.detach(), action_logprobs.detach()

    def _graph_pooling(self, z, batch):
        """Pool node embeddings ``z`` per graph using ``batch`` as the node-to-graph index."""
        return self.graph_pooling.forward2(x=z, batch=batch)

    def compute(self, state, sample, detach=True, action=2, values=True):
        """Compute any combination of actions, log-probs and state values.

        Args:
            state: Graph batch. It must support ``.to(device)`` and expose
                ``edge_index`` and ``batch``.
            sample (bool): Sample actions if True, otherwise take the mode.
            detach (bool): Detach every returned tensor.
            action (int): ``0`` computes no actions, ``1`` computes actions
                only, ``2`` computes actions and their log-probabilities.
                Any other positive value behaves like ``1``.
            values (bool): Also compute per-graph state values.

        Returns:
            list: Tensors in the order
            ``[actions][, action_logprobs][, state_values]``.
        """
        state = state.to(self.device)
        output = []

        if action > 0:
            # The actor GNN only runs when actions are actually requested.
            action_logits = self.compute_action_logits(state=state)
            batch = self._get_attr(state, 'batch', refers_to=self.action_refers_to)
            actions = self._sample_actions(action_logits, batch, sample)
            output.append(actions)
            if action == 2:
                output.append(self.action_distr.log_prob(actions))

        if values:
            output.append(self.get_state_values(state))

        if detach:
            output = [o.detach() for o in output]
        return output

    def get_state_values(self, state):
        """Return the critic's value estimate for each graph in ``state``."""
        state = state.to(self.device)
        z = self.feature_extractor_critic(batch=state)
        z_pool = self._graph_pooling(z=z, batch=self._get_attr(state, 'batch'))
        return self.critic(z_pool)

    @staticmethod
    def _cuda_to_cpu(tensor):
        """Move ``tensor`` to CPU if it is on a CUDA device; otherwise return it unchanged."""
        return tensor.to("cpu") if "cuda" in tensor.device.type else tensor

    def evaluate(self, state, action):
        """Evaluate given ``action``s under the current policy and critic.

        Returns:
            tuple: ``(action_logprobs, state_values, dist_entropy)``. Any
            tensor on a CUDA device is moved to CPU.
        """
        state = state.to(self.device)
        action = action.to(self.device)

        action_logits = self.compute_action_logits(state=state)
        batch = self._get_attr(state, 'batch', refers_to=self.action_refers_to)
        self.action_distr.proba_distribution(action_logits=action_logits,
                                             batch_idx=batch)
        action_logprobs = self.action_distr.log_prob(action)
        dist_entropy = self.action_distr.entropy()

        state_values = self.get_state_values(state)

        return (self._cuda_to_cpu(action_logprobs),
                self._cuda_to_cpu(state_values),
                self._cuda_to_cpu(dist_entropy))
