# MIT License

# Copyright (c) 2023 Replicable-MARL

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from ray.rllib.utils.torch_ops import FLOAT_MIN
from functools import reduce
import copy
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.rllib.models.torch.misc import SlimFC, SlimConv2d, normc_initializer
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.typing import Dict, TensorType, List
from torch.optim import Adam

# from marllib.marl.models.zoo.encoder.cc_encoder import CentralizedEncoder
from marllib.marl.models.zoo.tree.base_tree import BaseTree 
from marllib.marl.models.zoo.tree.tree_utils import TreeLSTM, extract_features, Transformer

torch, nn = try_import_torch()

class CentralizedCriticTree(BaseTree):
    """Tree-based model extended with a centralized value function.

    The centralized critic has four parts:
      * a ``TreeLSTM`` that embeds per-agent tree observations;
      * an MLP that embeds per-agent attribute vectors;
      * a stack of three ``Transformer`` layers over the concatenation of
        both embeddings;
      * a linear value head.

    The value head's per-agent outputs are averaged over dimension 1 to give
    the centralized value.
    """

    def __init__(
            self,
            obs_space,
            action_space,
            num_outputs,
            model_config,
            name,
            **kwargs,
    ):
        super().__init__(obs_space, action_space, num_outputs, model_config,
                         name, **kwargs)

        # Fail fast on unsupported algorithms, before allocating any critic
        # modules. self.custom_config is expected to be populated by
        # BaseTree.__init__ (not visible here).
        if self.custom_config["algorithm"] in ["coma", "hatrpo", "happo"]:
            raise NotImplementedError

        # The critic always produces a state value, never per-action Q
        # values (the Q-value algorithms are rejected above).
        self.q_flag = False

        # Hard-coded sizes. agent_attr_sz and node_sz presumably must match
        # the feature layout produced by extract_features — TODO confirm.
        agent_attr_sz = 83
        node_sz = 12
        hidden_sz = 128
        tree_embedding_sz = 128
        joint_sz = hidden_sz + tree_embedding_sz

        # Encoder for the centralized value function.
        self.cc_vf_tree_lstm = TreeLSTM(node_sz, tree_embedding_sz)
        self.cc_vf_attr_embedding = nn.Sequential(
            nn.Linear(agent_attr_sz, 2 * hidden_sz),
            nn.GELU(),
            nn.Linear(2 * hidden_sz, 2 * hidden_sz),
            nn.GELU(),
            nn.Linear(2 * hidden_sz, 2 * hidden_sz),
            nn.GELU(),
            nn.Linear(2 * hidden_sz, hidden_sz),
            nn.GELU(),
        )

        self.cc_vf_transformer = nn.Sequential(
            Transformer(joint_sz, 4),
            Transformer(joint_sz, 4),
            Transformer(joint_sz, 4),
        )

        # Linear value head; the small-std init keeps initial values near 0.
        self.cc_vf_branch = SlimFC(
            in_size=joint_sz,
            out_size=1,
            initializer=normc_initializer(0.01),
            activation_fn=None)

    def central_value_function(self, state, opponent_actions=None) -> TensorType:
        """Compute the centralized value estimate from ``state``.

        Args:
            state: The centralized input decoded by ``extract_features`` into
                agent attributes plus tree (forest/adjacency/order) tensors.
            opponent_actions: Accepted for interface compatibility; it is
                not used by this critic.

        Returns:
            If ``self.q_flag`` is set, a tensor reshaped to
            ``[-1, self.num_outputs]``. Otherwise a flat 1-D tensor with one
            value per batch entry.
        """
        assert self._features is not None, "must call forward() first"

        agent_attr, forest, adjacency, node_order, edge_order = extract_features(state)

        # Tree-based encoder, same structure as in base_tree.py.
        tree_embedding = self.cc_vf_tree_lstm(forest, adjacency, node_order, edge_order)
        agent_attr_embedding = self.cc_vf_attr_embedding(agent_attr)

        # Assumed shape [bs, num_agents, hidden_sz + tree_embedding_sz].
        features = torch.cat([agent_attr_embedding, tree_embedding], dim=-1)
        features = self.cc_vf_transformer(features)

        # Score each agent first, then average the scores over dim 1
        # (assumed to be the agent axis): [bs, num_agents, 1] -> [bs, 1].
        values = self.cc_vf_branch(features).mean(1)

        if self.q_flag:
            return torch.reshape(values, [-1, self.num_outputs])
        return torch.reshape(values, [-1])

    @override(BaseTree)
    def critic_parameters(self):
        """Return a flat list of every centralized-critic parameter.

        Parameters are ordered tree-LSTM, attribute embedding, transformer
        stack, then value head.
        """
        critics = [
            self.cc_vf_tree_lstm,
            self.cc_vf_attr_embedding,
            self.cc_vf_transformer,
            self.cc_vf_branch,
        ]
        # A single flat comprehension avoids rebuilding the list on every
        # concatenation.
        return [param for module in critics for param in module.parameters()]

    def link_other_agent_policy(self, agent_id, policy):
        """Register ``policy`` for ``agent_id`` in ``self.other_policies``.

        ``self.other_policies`` is expected to be a mapping created
        elsewhere (not visible here).

        Raises:
            ValueError: If a different policy is already registered for
                ``agent_id``.
        """
        if agent_id in self.other_policies:
            if self.other_policies[agent_id] != policy:
                raise ValueError('the policy is not same with the two time look up')
        else:
            self.other_policies[agent_id] = policy

    def update_actor(self, loss, lr, grad_clip):
        """Take one optimizer step that ascends ``loss``.

        The negated loss is minimized with ``self.actor_optimizer``, which is
        expected to be set elsewhere (not visible here). ``lr`` is accepted
        but not used by this method.
        """
        CentralizedCriticTree.update_use_torch_adam(
            loss=(-1 * loss),
            optimizer=self.actor_optimizer,
            # Materialize the generator so the callee may iterate it safely
            # more than once.
            parameters=list(self.parameters()),
            grad_clip=grad_clip
        )

    @staticmethod
    def update_use_torch_adam(loss, parameters, optimizer, grad_clip):
        """Run one optimization step on ``loss``.

        The step zeroes gradients, backpropagates, clips the total gradient
        norm of ``parameters`` to ``grad_clip``, and then steps
        ``optimizer``.
        """
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(parameters, grad_clip)
        optimizer.step()

