"""
Het-JO-IPPO

Jointly Observable Heterogeneous IPPO (Het-JO-IPPO) provides the joint observations of all agents as input, in addition
to each agent's own observation.

Configurations:
- Joint observations encoded by SAE to latent_dim (pre-trained/policy losses/reconstruction losses)
- Joint observations encoded by MLP to latent_dim (pre-trained/policy losses/reconstruction losses)
- Joint observations not encoded
"""

from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
import torch

import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../"))
from sae.model import AutoEncoder as PISA
from train_cnn import Autoencoder as CNNAE
from scenario_config import SCENARIO_CONFIG

# TODO: What matters in on-policy RL?
# - Separate independent policy and value networks
# - For policy net, MLP width correlates with task complexity
# - For value net, MLP should be wide (sometimes wider than policy)
# - MLP depth should be 2 hidden layers for both policy and value net
# - Final layer weights should be 100x smaller than other layers

# Hidden-layer width of the policy MLP head.
POLICY_WIDTH = 128
# Hidden-layer width of the value MLP head (deliberately wider than the policy).
VALUE_WIDTH = 1024

class PolicyHetJOIPPO(TorchModelV2, torch.nn.Module):
    """Het-JO-IPPO model with independent policy and value MLP heads.

    On every forward pass the RGB observations of all agents are encoded by a
    frozen, pre-trained CNN autoencoder and (unless ``no_comms`` is set)
    aggregated by a PISA encoder. The result is concatenated with the CNN
    encoding of the current agent's own observation and, for "in the matrix"
    scenarios, with inventory features, and then fed to both heads.
    """

    def __init__(self, observation_space, action_space, num_outputs, model_config, name, *args, **kwargs):
        """Build the encoders and the policy/value heads.

        The positional arguments are forwarded unchanged to ``TorchModelV2``.

        Required keyword arguments:
            scenario: Key into ``SCENARIO_CONFIG``; a name containing
                ``"matrix"`` enables the inventory features.
            task_agnostic, task_specific: If either is truthy, a PISA model is
                loaded from ``pisa_path`` and frozen.
            train_specific: If truthy, a randomly initialised (trainable) PISA
                model is created.
            pisa_dim: Latent dimension of the CNN encoder and of PISA.
            no_comms: If truthy, the joint-observation features are replaced
                by zeros and PISA is not used in ``forward``.
            cnn_path: Path of the CNN autoencoder state dict.
            pisa_path: Path of the PISA state dict.
            single_obs_size: Flattened size of one RGB observation; the image
                width is derived as ``sqrt(single_obs_size / 3)``.

        Raises:
            ValueError: If communication is enabled (``no_comms`` falsy) but
                none of the PISA modes is selected, since ``forward`` would
                then have no PISA encoder to call.
        """

        # Call super class constructors
        TorchModelV2.__init__(self, observation_space, action_space, num_outputs, model_config, name)
        torch.nn.Module.__init__(self)

        # NOTE(review): the sub-modules are placed on CUDA whenever it is
        # available; confirm this matches the device RLlib assigns the policy.
        device = "cuda" if torch.cuda.is_available() else "cpu"

        self.scenario = kwargs["scenario"]
        self.task_agnostic = kwargs["task_agnostic"]
        self.task_specific = kwargs["task_specific"]
        self.train_specific = kwargs["train_specific"]
        self.pisa_dim = kwargs["pisa_dim"]
        self.no_comms = kwargs["no_comms"]
        self.cnn_path = kwargs["cnn_path"]
        self.pisa_path = kwargs["pisa_path"]
        self.n_agents = SCENARIO_CONFIG[self.scenario]["num_agents"]
        self.in_the_matrix = "matrix" in self.scenario
        if self.in_the_matrix:
            self.inv_dim = SCENARIO_CONFIG[self.scenario]["INVENTORY"]

        uses_pretrained_pisa = self.task_agnostic or self.task_specific
        if not self.no_comms and not (uses_pretrained_pisa or self.train_specific):
            # Fail fast: forward() would otherwise hit a missing self.pisa.
            raise ValueError(
                "no_comms is disabled but none of task_agnostic, task_specific "
                "or train_specific is set, so no PISA encoder would be available"
            )

        obs_size = kwargs["single_obs_size"]
        obs_w = int((obs_size / 3) ** 0.5)

        # Load pre-trained image encoder
        self.cnn_autoencoder = CNNAE(
            image_width=obs_w,
            latent_dim=self.pisa_dim
        ).to(device)
        self.cnn_autoencoder.load_state_dict(torch.load(
            self.cnn_path,
            map_location=torch.device(device)
        ))

        # Freeze the image encoder
        self._freeze_parameters(self.cnn_autoencoder)

        if uses_pretrained_pisa:
            # Load pre-trained PISA
            self.pisa = PISA(
                dim=self.pisa_dim,
                hidden_dim=self.pisa_dim * self.n_agents
            ).to(device)
            self.pisa.load_state_dict(torch.load(
                self.pisa_path,
                map_location=torch.device(device)
            ))

            # Freeze PISA
            self._freeze_parameters(self.pisa)

        if self.train_specific:
            # Construct randomly initialised PISA.
            # NOTE(review): this replaces a pre-trained PISA loaded above when
            # both modes are enabled -- confirm that is intended.
            self.pisa = PISA(
                dim=self.pisa_dim,
                hidden_dim=self.pisa_dim * self.n_agents,
            ).to(device)

        # Input features are PISA-encoding of all observations + image encoding
        # of the current agent's observation (+ two inventory vectors in the
        # "in the matrix" scenarios).
        feature_dim = self.pisa_dim * (self.n_agents + 1)
        if self.in_the_matrix:
            feature_dim += self.inv_dim * 2

        self.policy_head = self._make_mlp_head(
            in_features=feature_dim,
            width=POLICY_WIDTH,
            out_features=num_outputs,  # Discrete: action_space[0].n
            last_name="policy_last",
        )
        self.value_head = self._make_mlp_head(
            in_features=feature_dim,
            width=VALUE_WIDTH,
            out_features=1,
            last_name="value_last",
        )

        self.current_value = None

    @staticmethod
    def _freeze_parameters(module):
        """Disable gradients for every parameter of ``module``."""
        for p in module.parameters():
            p.requires_grad = False
            p.detach_()

    @staticmethod
    def _make_mlp_head(in_features, width, out_features, last_name):
        """Build a 2-hidden-layer Tanh MLP with a small-weight output layer.

        Hidden layers are initialised from N(0, 1); the output layer, which is
        registered under ``last_name``, from N(0, 0.01) so that its weights
        are 100x smaller than those of the other layers.
        """
        head = torch.nn.Sequential(
            torch.nn.Linear(in_features=in_features, out_features=width),
            torch.nn.Tanh(),
            torch.nn.Linear(in_features=width, out_features=width),
            torch.nn.Tanh(),
        )
        for layer in head:
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.normal_(layer.weight, mean=0.0, std=1.0)
                torch.nn.init.normal_(layer.bias, mean=0.0, std=1.0)

        # The output layer is created only after the hidden layers have been
        # initialised, so its default init does not interleave with theirs.
        last = torch.nn.Linear(in_features=width, out_features=out_features)
        torch.nn.init.normal_(last.weight, mean=0.0, std=0.01)
        torch.nn.init.normal_(last.bias, mean=0.0, std=0.01)
        head.add_module(last_name, last)
        return head

    def forward(self, inputs, state, seq_lens):
        """Compute action logits and cache the value estimate.

        Args:
            inputs: Mapping whose ``"obs"`` entry holds ``"RGB"`` (own
                observation), ``"all"`` (per-agent observations keyed
                ``player_<i>``) and, for matrix scenarios, ``"INVENTORY"``.
            state: Passed through unchanged.
            seq_lens: Unused.

        Returns:
            Tuple ``(logits, state)``. The value estimate is stored in
            ``self.current_value`` for ``value_function``.
        """

        observation, batch, agent_features, _ = self.process_inputs(inputs)

        x = self.cnn_autoencoder.encode(observation)
        if not self.no_comms:
            x = self.pisa.encoder(x, batch=batch)
        else:
            # Only the shape is needed here; the values are zeroed out below.
            x = x.reshape(-1, self.pisa_dim * self.n_agents)

        o_i = self.cnn_autoencoder.encode(agent_features)

        # Without communication the joint features are replaced by zeros.
        joint = torch.zeros_like(x) if self.no_comms else x
        parts = [joint, o_i]

        # If we are looking at an * in the matrix environment,
        # we also need to provide the inventory of the agents.
        # We can add the inventories of both agents to create a permutation-invariant state.
        # And we can add the inventory of the current agent separately.
        # We do not expose the ready_to_shoot function.
        if self.in_the_matrix:
            # Retrieve the relevant agents' own inventory
            self_inv = inputs['obs']['INVENTORY']

            if not self.no_comms:
                # Construct a permutation-invariant state for the inventory
                inv = torch.zeros((agent_features.shape[0], self.inv_dim), device=x.device)
                for key, value in inputs['obs']['all'].items():
                    if 'INVENTORY' in key:
                        inv += value
            else:
                inv = torch.zeros_like(self_inv)

            parts += [inv, self_inv]

        input_features = torch.cat(parts, dim=1)

        self.current_value = self.value_head(input_features).squeeze(1)
        logits = self.policy_head(input_features)

        # If training with policy losses, NaNs can occasionally occur.
        logits = torch.nan_to_num(
            logits, nan=0.0, posinf=0.0, neginf=0.0
        )
        self.current_value = torch.nan_to_num(
            self.current_value, nan=0.0, posinf=0.0, neginf=0.0
        )

        return logits, state

    def value_function(self):
        """Return the value estimate cached by the most recent ``forward``."""
        return self.current_value  # [batches]

    def process_inputs(self, inputs):
        """Extract and normalise the image observations from ``inputs``.

        Returns:
            Tuple ``(observation, batch, this_obs, n_batches)`` where
            ``observation`` holds all agents' images flattened over the
            (batch, agent) axes, ``batch`` maps each of its rows to its sample
            index, ``this_obs`` is the current agent's image and ``n_batches``
            is the number of samples. Images are scaled by 1/255.
        """

        # Rank-4 input, assumed channel-last [batches, H, W, C] -> channel-first.
        # The division is deliberately out-of-place: ``.float()`` returns the
        # same tensor when the input is already float32 and ``permute`` is a
        # view, so ``/=`` would rescale the caller's observation on every call.
        this_obs = inputs["obs"]["RGB"].permute(0, 3, 1, 2).float() / 255.0

        # Collect the observations of all agents
        all_obs = inputs["obs"]["all"]
        observation = torch.stack(
            [
                all_obs[f"player_{i}"].permute(0, 3, 1, 2).float()
                for i in range(self.n_agents)
            ],
            dim=1
        )  # [batches, agents, C, H, W]
        # torch.stack allocated a fresh tensor, so in-place scaling is safe here.
        observation /= 255.0

        n_batches = observation.shape[0]

        observation = torch.flatten(observation, start_dim=0, end_dim=1)  # [batches * agents, C, H, W]
        batch = torch.arange(
            n_batches, device=observation.device
        ).repeat_interleave(self.n_agents)

        return observation.contiguous(), batch, this_obs.contiguous(), n_batches
