import copy
import itertools
import random

import torch
import torch.nn as nn
from torch_geometric.data.batch import Batch

from ..buffers import RolloutBuffer
import gcip.utils.io as pb_io
import itertools
class GraphPPO(nn.Module):
    """Proximal Policy Optimization (PPO) trainer for a graph actor-critic policy.

    Wraps an actor-critic ``policy`` and implements the PPO clipped surrogate
    objective. Rollouts are collected with a snapshot of the policy
    (``policy_old``) into a ``RolloutBuffer``; :meth:`forward` evaluates the
    current policy on those rollouts and returns the per-sample loss, and
    :meth:`forward_end` copies the updated weights back into ``policy_old``
    and clears the buffer.

    Presumed training cycle (inferred from the methods below — confirm
    against the training loop): ``prepare_forward`` -> ``forward`` (one or
    more optimizer steps) -> ``forward_end``.

    Args:
        policy (nn.Module): The actor-critic policy being optimized. It must
            provide ``actor_params()``, ``critic_params()``, ``compute(...)``,
            ``evaluate(states, actions)``, an ``action_distr`` with a
            ``mean()`` method, and ``gnn_actor``, ``gnn_critic``, ``actor``
            and ``critic`` sub-networks exposing ``training``.
        eps_clip (float): Clipping range for the probability ratio, which is
            clamped to ``[1 - eps_clip, 1 + eps_clip]``.
        gamma (float): Discount factor passed to ``RolloutBuffer.prepare``.
        coeff_mse (float): Weight of the value-function MSE term in the loss.
        coeff_entropy (float): Weight of the entropy bonus in the objective.
    """

    # NOTE(review): the eps_clip default of 0.9 is much looser than the
    # 0.1-0.3 range used in the PPO paper; confirm it is intentional.
    def __init__(self, policy,
                 eps_clip=0.9,
                 gamma=0.99,
                 coeff_mse=0.0,
                 coeff_entropy=0.0):
        # Initialise nn.Module first so that every attribute assignment below
        # (including sub-modules) is always valid.
        super().__init__()

        self.eps_clip = eps_clip
        self.gamma = gamma
        self.coeff_mse = coeff_mse
        self.coeff_entropy = coeff_entropy

        self.buffer = RolloutBuffer()
        self.policy = policy
        # Snapshot of the policy used to collect rollouts (see act());
        # re-synced with `policy` in forward_end().
        self.policy_old = copy.deepcopy(policy)

        # Unreduced so that forward() can return a per-sample loss.
        self.mse_loss = nn.MSELoss(reduction='none')

    def get_optimization_config(self, lr_actor, lr_critic):
        """Return optimizer parameter groups with separate actor/critic learning rates.

        Args:
            lr_actor (float): Learning rate for ``policy.actor_params()``.
            lr_critic (float): Learning rate for ``policy.critic_params()``.

        Returns:
            list[dict]: Two ``{'params': ..., 'lr': ...}`` groups, in the
            per-parameter-group format accepted by ``torch.optim`` optimizers.
        """
        return [
            {'params': self.policy.actor_params(), 'lr': lr_actor},
            {'params': self.policy.critic_params(), 'lr': lr_critic},
        ]

    @torch.no_grad()
    def act(self, state, return_logprobs=False, sample=True, values=False):
        """Query ``policy_old`` for an action on ``state`` without tracking gradients.

        Args:
            state: Environment state, passed straight to ``policy_old.compute``.
            return_logprobs (bool): Also request action log-probabilities
                (forwarded to ``compute`` as ``action=2`` instead of ``action=1``).
            sample (bool): Forwarded to ``compute``; presumably selects a
                sampled vs. deterministic action — TODO confirm.
            values (bool): Forwarded to ``compute``; presumably also requests
                the critic's state values — TODO confirm.

        Returns:
            The single element when ``compute`` returns a 1-tuple, otherwise
            the whole tuple (``prepare_forward`` unpacks it as
            ``(action, logprobs, state_values)``).
        """
        output = self.policy_old.compute(state,
                                         sample=sample,
                                         detach=True,
                                         action=2 if return_logprobs else 1,
                                         values=values)
        return output[0] if len(output) == 1 else output

    def forward(self, shuffle=False):
        """Compute the PPO loss on the experiences stored in ``self.buffer``.

        Per sample, the maximised objective is::

            min(r * A, clamp(r, 1 - eps_clip, 1 + eps_clip) * A)
                - coeff_mse * (V - rewards_norm) ** 2
                + coeff_entropy * entropy

        with ``r = exp(logprob - old_logprob)`` and ``A`` the buffer's
        advantages. ``'loss'`` is the negated objective and is NOT reduced
        here (presumably the caller averages it — confirm).

        Args:
            shuffle (bool): Not supported; must be ``False``.

        Returns:
            dict: ``'loss'`` (carrying gradients) plus detached diagnostics
            ``'surrogate'``, ``'diff_surr'``, ``'ratios'``, ``'actions_mu'``,
            ``'mse'``, ``'mse_coeff'``, ``'entropy'`` and ``'rewards'``.
        """
        assert not shuffle
        # The policy and all of its sub-networks must be in training mode.
        assert self.policy.training
        assert self.policy.gnn_actor.training
        assert self.policy.gnn_critic.training
        assert self.policy.actor.training
        assert self.policy.critic.training

        (old_states, old_actions, old_logprobs,
         rewards_norm, rewards, advantages) = self.buffer.experiences
        logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
        action_mean = self.policy.action_distr.mean()
        # Drop singleton dims so state_values lines up with rewards_norm.
        state_values = torch.squeeze(state_values)

        # Importance ratio pi_theta / pi_theta_old, computed in log space.
        ratios = torch.exp(logprobs - old_logprobs.detach())

        # Clipped surrogate terms.
        surr1 = ratios * advantages
        surr2 = torch.clamp(ratios, 1 - self.eps_clip, 1 + self.eps_clip) * advantages
        diff_surr = surr1 - surr2  # diagnostic only
        my_surr = torch.min(surr1, surr2)

        # Critic regression towards rewards_norm (presumably normalised
        # discounted returns produced by the buffer — confirm).
        my_mse = self.mse_loss(state_values, rewards_norm)
        my_mse_coeff = self.coeff_mse * my_mse

        objective = my_surr - my_mse_coeff + self.coeff_entropy * dist_entropy
        loss = -objective

        return {'loss': loss,
                'surrogate': my_surr.detach(),
                'diff_surr': diff_surr.detach(),
                'ratios': ratios.detach(),
                'actions_mu': action_mean.detach(),
                'mse': my_mse.detach(),
                'mse_coeff': my_mse_coeff.detach(),
                'entropy': dist_entropy.detach(),
                'rewards': rewards.detach(),
                }

    @torch.no_grad()
    def prepare_forward(self, env, n_steps=10):
        """Collect up to ``n_steps`` transitions with ``policy_old`` into the buffer.

        Dataset indices are visited in a random order. A new episode is
        started with ``env.reset(idx=...)`` only when the previous one has
        finished, so indices drawn while an episode is still running are not
        used for a reset. Afterwards ``self.buffer.prepare`` is called with
        ``self.gamma`` and ``device='cpu'``.

        Args:
            env: Environment exposing ``loader``, ``reset(idx=...)`` and
                ``step(action)`` returning ``(state, reward, done, info)``.
            n_steps (int): Maximum number of environment steps; fewer are
                taken (with a warning) when the dataset has fewer graphs.
        """
        done = True
        state = None

        idx_list = list(range(len(env.loader)))
        random.shuffle(idx_list)

        if len(idx_list) < n_steps:
            pb_io.print_warning(f"Number of graphs {len(idx_list)} in the dataset is less than the number of steps requested. ")
        for idx in idx_list[:n_steps]:
            if done:
                state = env.reset(idx=idx)
            action, logprobs, state_values = self.act(state,
                                                      return_logprobs=True,
                                                      sample=True,
                                                      values=True)
            self.buffer.append(state=state,
                               state_value=state_values)
            state, reward, done, _ = env.step(action)

            self.buffer.append(reward=reward,
                               done=done,
                               action=action,
                               action_logprob=logprobs)

        self.buffer.prepare(gamma=self.gamma,
                            device='cpu')

    def forward_end(self):
        """Copy the updated ``policy`` weights into ``policy_old`` and clear the buffer."""
        self.policy_old.load_state_dict(self.policy.state_dict())
        self.buffer.clear()

    @torch.no_grad()
    def run_episode(self, batch, env, sample=True, transform=True, num_samples=1):
        """Roll ``policy_old`` out to completion on every graph of ``batch``.

        Each input graph is rolled out ``num_samples`` times. For every
        rollout, ``env.get_final_graph()`` is taken, its ``batch`` attribute
        is removed, and the last action is attached as ``action``. All results
        are re-collated into a single ``Batch``, after which
        ``env.on_episode_end()`` is called. Both ``policy`` and ``policy_old``
        must be in eval mode.

        Args:
            batch (Batch): Input graphs.
            env: Environment providing ``reset``, ``step``,
                ``get_final_graph`` and ``on_episode_end``.
            sample (bool): Forwarded to :meth:`act`.
            transform (bool): Forwarded to ``env.reset``.
            num_samples (int): Number of rollouts per input graph.

        Returns:
            Batch: ``num_samples`` output graphs per input graph, in input order.
        """
        assert not self.policy.training
        assert not self.policy_old.training

        outputs = []
        for graph in batch.to_data_list():
            for _ in range(num_samples):
                state = env.reset(graph=graph, transform=transform)
                done = False
                while not done:
                    action = self.act(state=state, sample=sample)
                    state, _, done, _ = env.step(action=action)

                graph_out = env.get_final_graph()
                delattr(graph_out, 'batch')
                graph_out.action = action
                outputs.append(graph_out)

        result = Batch.from_data_list(data_list=outputs)

        env.on_episode_end()

        return result
