import torch as th
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class MamujocoDecomposer:
    """Split Multi-Agent MuJoCo states, observations and actions into parts.

    The interface follows the ally / enemy / last-action / timestep layout used
    by SMAC-style decomposers. MuJoCo has no enemies, no attack actions, and no
    last-action or timestep features in the state, so those parts are always
    empty (or ``None`` for attack actions). Per-agent feature sizes are
    heuristic per-scenario estimates, not values read from the environment.
    """

    def __init__(self, args):
        """Read environment settings from ``args`` and set feature sizes.

        Args:
            args: config object exposing ``env_args`` (a mapping with optional
                ``scenario``, ``agent_conf`` and ``agent_obsk`` keys),
                ``n_agents``, and optionally ``n_actions`` (defaults to 8).
        """
        # Get environment configuration
        self.env_scenario = args.env_args.get('scenario', 'Ant-v2')
        self.agent_conf = args.env_args.get('agent_conf', '2x4')
        self.agent_obsk = args.env_args.get('agent_obsk', 1)

        # Number of agents
        self.n_agents = args.n_agents
        self.n_enemies = 0  # No enemies in MuJoCo environments

        # State parameters
        self.state_last_action = False
        self.state_timestep_number = False
        self.timestep_number_state_dim = 0

        # Parse the agent configuration, e.g. "2x4" -> 2 agents with 4 body
        # parts each. Other formats (such as "9|8") or a missing/None value
        # fall back to the defaults. These values are only recorded; they do
        # not drive any of the dimensions below.
        try:
            conf_n_agents, n_parts_per_agent = map(int, self.agent_conf.split('x'))
        except (AttributeError, ValueError):
            conf_n_agents, n_parts_per_agent = 2, 4
        self.conf_n_agents = conf_n_agents
        self.n_parts_per_agent = n_parts_per_agent

        # Heuristic per-agent feature sizes keyed on the scenario name; the
        # first matching keyword wins, anything else uses the fallback of 25.
        scenario = self.env_scenario.lower()
        for keyword, dim in (('ant', 30), ('cheetah', 20), ('humanoid', 45)):
            if keyword in scenario:
                break
        else:
            dim = 25  # default fallback
        self.obs_dim_per_agent = dim
        self.state_nf_al = dim  # State features per agent
        self.state_nf_en = 0  # No enemies in MuJoCo

        # Observation parameters
        self.own_obs_dim = self.obs_dim_per_agent
        self.obs_nf_al = self.obs_dim_per_agent
        self.obs_nf_en = 0  # No enemies in MuJoCo
        self.obs_dim = self.obs_dim_per_agent  # Width of one agent's observation

        # Action parameters: continuous control actions only
        self.n_actions_no_attack = getattr(args, 'n_actions', 8)
        self.n_actions_attack = 0  # No attack actions in MuJoCo
        self.n_actions = self.n_actions_no_attack + self.n_actions_attack

        # Width of a one-hot agent-ID encoding
        self.agent_id_dim = self.n_agents

    def decompose_state(self, state_input):
        """
        Decompose the global state into per-agent components.

        All agents share one global state, so every entry of ``agent_states``
        is a view of the same (possibly ID-stripped) tensor.

        Args:
            state_input: tensor of shape [batch_size, state_dim] or
                [batch_size, n, state_dim] (n may be n_agents or a sequence
                length; the method does not distinguish them).

        Returns:
            agent_states: list of ``n_agents`` tensors. For 2-D input each is
                [batch_size, state_dim']; for higher-rank input dim 1 is
                squeezed only if it has size 1.
            enemy_states: empty list (no enemies in MuJoCo)
            last_action_states: empty list (no last actions in state)
            timestep_number_state: empty list (no timestep in state)
        """
        # Give 2-D input a singleton dim 1 so the squeeze(1) below restores
        # [batch_size, state_dim'] even when state_dim' happens to be 1.
        if state_input.dim() == 2:
            state_input = state_input.unsqueeze(1)

        # Heuristic: a trailing width larger than the per-agent estimate is
        # assumed to end with an n_agents-wide agent-ID encoding, which is
        # stripped. NOTE(review): confirm this layout against the env wrapper.
        state_dim = state_input.shape[-1]
        if state_dim > self.obs_dim_per_agent:
            base_state = state_input[..., :max(state_dim - self.n_agents, 0)]
        else:
            base_state = state_input

        agent_states = [base_state.squeeze(1) for _ in range(self.n_agents)]

        # No enemies, no last actions in state, no timestep in MuJoCo
        enemy_states = []
        last_action_states = []
        timestep_number_state = []

        return agent_states, enemy_states, last_action_states, timestep_number_state

    def decompose_obs(self, obs_input):
        """
        Decompose agent observations.

        Args:
            obs_input: tensor of shape [batch_size, obs_dim],
                [batch_size, obs_dim * n_agents] (concatenated agents), or
                [batch_size, n_agents, obs_dim].

        Returns:
            own_obs: the input, reshaped to [batch_size, n_agents, obs_dim]
                when it is a 2-D concatenation of all agents' observations;
                otherwise returned unchanged.
            enemy_feats: empty list (no enemies)
            ally_feats: empty list (ally features are not separated out)
        """
        own_obs = obs_input
        if obs_input.dim() == 2:
            batch_size, obs_dim = obs_input.shape
            # A width equal to obs_dim is a single-agent observation and takes
            # precedence (this matters when n_agents == 1).
            if obs_dim != self.obs_dim and obs_dim == self.obs_dim * self.n_agents:
                # reshape (not view) so non-contiguous inputs do not raise.
                own_obs = obs_input.reshape(batch_size, self.n_agents, self.obs_dim)

        # No enemies in MuJoCo; each agent's observation is kept whole, so no
        # separate ally features are produced.
        enemy_feats = []
        ally_feats = []

        return own_obs, enemy_feats, ally_feats

    def decompose_action_info(self, action_info):
        """
        Split action information into no-attack / attack / compact parts.

        Args:
            action_info: tensor of shape [(bs), n_agent, n_action]

        Returns:
            no_attack_action_info: ``action_info`` unchanged (all actions are
                continuous control actions)
            attack_action_info: ``None`` (no attack actions in MuJoCo)
            compact_action_info: ``action_info`` unchanged
        """
        no_attack_action_info = action_info
        attack_action_info = None
        compact_action_info = action_info
        return no_attack_action_info, attack_action_info, compact_action_info
    