# from modules.agents import REGISTRY as agent_REGISTRY
from modules.decomposers import REGISTRY as decomposer_REGISTRY
from components.action_selectors import REGISTRY as action_REGISTRY
import torch as th
import numpy as np
import torch.nn.functional as F
from copy import deepcopy
from modules.agents import REGISTRY
from modules.agents.transfer.tr_agent import UnifiedTrActor
from utils.rl_utils import EMAMeanStd
from modules.critics.comad import SkillEncoder

class TrCOMADMAC:
    def __init__(self, all_tasks, train_tasks, trans_tasks, task2scheme, task2args, main_args):
        """Store task metadata, build one decomposer per task, then build the networks.

        Args:
            all_tasks: Re-iterable collection of task identifiers (it is iterated
                several times); one decomposer is created for each entry.
            train_tasks: Stored on the instance unchanged.
            trans_tasks: Stored on the instance unchanged.
            task2scheme: Mapping task -> scheme dict, read later by ``_get_input_shape``.
            task2args: Mapping task -> per-task args; ``n_agents`` and ``env`` are read here.
            main_args: Global args; ``rho_dim``, ``xi_dim``, ``agent_output_type``,
                ``action_selector`` and ``env`` are read here, ``is_finetune`` optionally.

        Raises:
            NotImplementedError: If ``main_args.env`` has no registered decomposer name.
        """
        self.all_tasks = all_tasks
        self.train_tasks = train_tasks
        self.trans_tasks = trans_tasks
        self.task2scheme = task2scheme
        self.task2args = task2args
        self.task2n_agents = {task: self.task2args[task].n_agents for task in all_tasks}
        self.main_args = main_args
        self.rho_dim = self.main_args.rho_dim
        self.xi_dim = self.main_args.xi_dim

        self.agent_output_type = main_args.agent_output_type
        self.action_selector = action_REGISTRY[main_args.action_selector](main_args)

        # Environment name -> key of the decomposer class in decomposer_REGISTRY.
        env2decomposer = {
            "sc2": "sc2_decomposer",
            "sc2_v2": "sc2_v2_decomposer",
            "gymma": "gymma_decomposer",
            "grid_mpe": "mpe_decomposer",
            "mamujoco": "mamujoco_decomposer",  # TODO mamujoco
        }
        if self.main_args.env not in env2decomposer:
            raise NotImplementedError(
                f"No decomposer available for env {self.main_args.env!r}; "
                f"supported envs: {sorted(env2decomposer)}"
            )
        self.task2decomposer = {}
        self.surrogate_decomposer = None

        self.is_finetune = getattr(main_args, 'is_finetune', False)

        # Every environment builds one decomposer per task, selected by the task's own env.
        for task in all_tasks:
            task_args = self.task2args[task]
            self.task2decomposer[task] = decomposer_REGISTRY[
                env2decomposer[task_args.env]
            ](task_args)

        if self.main_args.env in ("sc2", "sc2_v2"):
            # StarCraft tasks additionally get their feature layouts aligned: take the
            # largest shield-bit widths and the union of unit types over all tasks.
            aligned_shield_bits_ally = 0
            aligned_shield_bits_enemy = 0
            map_type_set = set()
            for task in all_tasks:
                task_decomposer = self.task2decomposer[task]
                aligned_shield_bits_ally = max(
                    aligned_shield_bits_ally, task_decomposer.shield_bits_ally
                )
                aligned_shield_bits_enemy = max(
                    aligned_shield_bits_enemy, task_decomposer.shield_bits_enemy
                )
                map_type_set.update(task_decomposer.unit_types)

            # A single unit type needs no type bits at all.
            aligned_unit_type_bits = (
                0 if len(map_type_set) == 1 else len(map_type_set)
            )
            for task in all_tasks:
                self.task2decomposer[task].align_feats_dim(
                    aligned_unit_type_bits,
                    aligned_shield_bits_ally,
                    aligned_shield_bits_enemy,
                    map_type_set,
                )

        # The surrogate is the first truthy decomposer in task order.
        for task in all_tasks:
            if self.surrogate_decomposer:
                break
            self.surrogate_decomposer = self.task2decomposer[task]

        # build agents
        self.task2input_shape_info = self._get_input_shape()
        self._build_agents()

        self.hidden_states = None

    def select_actions(self, ep_batch, t_ep, t_env, task, bs=slice(None), test_mode=False): # for execution
        avail_actions = ep_batch["avail_actions"][:, t_ep]
        
        agent_outputs = self.forward(ep_batch, None, t_ep, task, train_stage=0, test_mode=test_mode)[0]
        
        agent_outputs = agent_outputs.expand_as(avail_actions)

        chosen_actions = self.action_selector.select_action(agent_outputs[bs], avail_actions[bs], t_env, test_mode=test_mode)
        
        return chosen_actions
    
    def process_logits(self, agent_out, ep_batch, t, task, test_mode):
        if agent_out is None:
            return None
        
        avail_actions = ep_batch["avail_actions"][:, t] # (bs, n_agents, n_act)
        if len(agent_out.shape) == 4:
            avail_actions = avail_actions.unsqueeze(0).expand_as(agent_out)
        if getattr(self.main_args, "mask_before_softmax", True):
            # Make the logits for unavailable actions very negative to minimise their affect on the softmax
            agent_out[avail_actions == 0] = -1e10

        agent_out = F.softmax(agent_out, dim=-1)
        if not test_mode:
            epsilon_action_num = agent_out.size(-1)
            if getattr(self.main_args, "mask_before_softmax", True):
                epsilon_action_num = avail_actions.sum(dim=-1, keepdim=True).float()

            agent_out = ((1 - self.action_selector.epsilon) * agent_out
                            + th.ones_like(agent_out) * self.action_selector.epsilon/epsilon_action_num)

            if getattr(self.main_args, "mask_before_softmax", True):
                # Zero out the unavailable actions, which have been softmax.
                agent_out[avail_actions == 0] = 0.0
                
        return agent_out
    
    def forward(self, ep_batch, rho, t, task, train_stage=None, test_mode=False):
        agent_inputs = self._build_inputs(ep_batch, None, task)

        mask = ep_batch["filled"].float() * (1 - ep_batch["terminated"].float())
        mask = mask.squeeze(-1)
        
        all_forward = True if train_stage == 2 else False
        self.hidden_states, agent_out, old_agent_out, off_agent_out, rho_infer = self.agent.forward(agent_inputs[:, t], rho, self.hidden_states, task, all_forward=all_forward)

        processed_old_agent_out = []
        if self.agent_output_type == "pi_logits":
            agent_out = self.process_logits(agent_out, ep_batch, t, task, test_mode)
            if all_forward:
                off_agent_out = self.process_logits(off_agent_out, ep_batch, t, task, test_mode)
                old_agent_out = self.process_logits(old_agent_out, ep_batch, t, task, test_mode)

        return agent_out, old_agent_out, off_agent_out, rho_infer
    
    def init_hidden(self, batch_size, task):
        self.hidden_states = self.agent.init_hidden(batch_size, task)

    def parameters(self):
        return self.agent.enc.parameters(), self.rho.parameters()

    def load_state(self, other_mac):
        self.agent.load_state_dict(other_mac.agent.state_dict())
        self.rho.load_state_dict(other_mac.rho.state_dict())

    def cuda(self):
        self.agent.cuda()
        self.rho.cuda()

    def save_models(self, path):
        self.agent.save(path, 'agent')
        self.rho.save(path)

    def load_models(self, path):
        self.agent.load(path, 'agent')
        self.rho.load(path)
    
    def _build_agents(self):
        inputs = [self.task2input_shape_info, self.task2decomposer, self.task2n_agents, self.surrogate_decomposer, self.main_args]
        self.agent = UnifiedTrActor(*inputs)
        self.rho = SkillEncoder(*inputs) # Posterior
    
    def _build_inputs(self, batch, t, task):
        # Assumes homogenous agents with flat observations.
        # Other MACs might want to e.g. delegate building inputs to each agent
        bs, max_t = batch.batch_size, batch.max_seq_length
        task_args, n_agents = self.task2args[task], self.task2n_agents[task]
        
        if not hasattr(self, '_agent_id_cache'):
            self._agent_id_cache = {}
        
        if task not in self._agent_id_cache:
            self._agent_id_cache[task] = th.eye(n_agents, device=batch.device)
        
        agent_id_tensor = self._agent_id_cache[task]
        
        inputs = []
        if t is not None:
            inputs.append(batch["obs"][:, t])
            if task_args.obs_last_action:
                if t == 0:
                    inputs.append(th.zeros_like(batch["actions_onehot"][:, t]))
                else:
                    inputs.append(batch["actions_onehot"][:, t-1])
            if task_args.obs_agent_id: # Not used, binary_embed is used instead
                inputs.append(agent_id_tensor.unsqueeze(0).expand(bs, -1, -1))

            if len(inputs) > 1:
                return th.cat(inputs, dim=-1)
            else:
                return inputs[0]
        else:
            inputs.append(batch["obs"])
            assert batch.max_seq_length == batch["state"].shape[1]
            if task_args.obs_last_action:
                inputs.append(th.cat([th.zeros_like(batch["actions_onehot"][:, :1]), batch["actions_onehot"][:, :-1]], dim=1))
            if task_args.obs_agent_id:
                inputs.append(agent_id_tensor.unsqueeze(0).unsqueeze(0).expand(bs, max_t, -1, -1).to(task_args.device))
            
            if len(inputs) > 1:
                return th.cat(inputs, dim=-1)
            else:
                return inputs[0]

    def _get_input_shape(self):
        task2input_shape_info = {}
        for task in self.all_tasks:
            task_scheme = self.task2scheme[task]
            obs_shape = task_scheme["obs"]["vshape"]
            input_shape = obs_shape
            last_action_shape = task_scheme["actions_onehot"]["vshape"][0]
            # joint_action_shape = task_scheme["actions_onehot"]["vshape"][0] * self.task2n_agents[task]
            agent_id_shape = self.task2n_agents[task]
            if self.task2args[task].obs_last_action:
                input_shape += last_action_shape
            if self.task2args[task].obs_agent_id:
                input_shape += agent_id_shape

            task2input_shape_info[task] = {
                "input_shape": input_shape,
                "obs_shape": obs_shape,
                "last_action_shape": last_action_shape,
                "agent_id_shape": agent_id_shape,
                #"joint_action_shape": joint_action_shape,
            }
        return task2input_shape_info
