import time
import inspect
import re
import ast
import os
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Set
import random
import math

from harl.common.skills.skill import Skill
from harl.test import test_params

from transformers import BertTokenizer, BertModel, AutoImageProcessor, AutoModel
# import faiss
import json
import numpy as np
from collections import Counter
from sklearn.cluster import DBSCAN
from sklearn.metrics import silhouette_score
from tqdm import tqdm

def serialize_skills(skills: Dict[str, Skill]) -> Dict[str, Dict]:
    """Turn a ``name -> Skill`` mapping into a ``name -> dict`` mapping.

    Each value is whatever the skill's own ``to_dict()`` returns; the key
    order of ``skills`` is preserved.
    """
    serialized: Dict[str, Dict] = {}
    for skill_name, skill_obj in skills.items():
        serialized[skill_name] = skill_obj.to_dict()
    return serialized


def deserialize_skills(serialized_skills: Dict[str, Dict]) -> Dict[str, Skill]:
    """Inverse of ``serialize_skills``: rebuild each entry with ``Skill.from_dict``."""
    restored: Dict[str, Skill] = {}
    for skill_name, payload in serialized_skills.items():
        restored[skill_name] = Skill.from_dict(payload)
    return restored

def convert_expression_to_skill(input_string: str = "open_map()"):
    """Split a call expression such as ``attack(target_id=3)`` into its parts.

    The expression is analysed with :mod:`ast` (not a regex), so keyword values
    may themselves contain parentheses (``pos=(1, 2)``) or strings holding
    ``)``. Values must be Python literals; the lowercase JSON-style booleans
    ``true``/``false`` are also accepted because model output often uses them.

    Args:
        input_string: a single call expression using keyword arguments only.

    Returns:
        ``(function_name, arguments)`` where ``arguments`` maps each keyword to
        its literal value.

    Raises:
        SyntaxError: if ``input_string`` is not a valid Python expression.
        ValueError: if it is not a call to a plain name, passes positional or
            ``**`` arguments, or a value is not a literal.
    """
    parsed = ast.parse(input_string, mode='eval')
    call = parsed.body
    if not isinstance(call, ast.Call):
        raise ValueError("Input must be a function call")
    if not isinstance(call.func, ast.Name):
        # e.g. ``obj.method()`` or ``f()()``
        raise ValueError("Invalid function call format string.")
    # Positional arguments cannot be mapped to names; reject them instead of
    # silently dropping them.
    if call.args:
        raise ValueError("Call arguments not properly parsed!")

    arguments = {}
    for keyword in call.keywords:
        if keyword.arg is None:  # ``**kwargs`` unpacking
            raise ValueError("Call arguments not properly parsed!")
        value = keyword.value
        # To avoid simple errors based on faulty model output (true/false).
        if isinstance(value, ast.Name) and value.id in ("true", "false"):
            arguments[keyword.arg] = value.id == "true"
        else:
            arguments[keyword.arg] = ast.literal_eval(value)
    return call.func.id, arguments

def convert_code_to_skill_info(skill_code: str):
    """Return the name and argument type hints of the first function in ``skill_code``.

    "First" is the first ``def`` yielded by :func:`ast.walk`, which is
    breadth-first, so an outer function is found before functions nested in it.
    Only regular positional parameters (``node.args.args``) are reported. Each
    maps to the bare name of its annotation (``"int"``, or ``"List"`` for
    ``List[int]``), or to ``"Any"`` when the annotation is missing or is not a
    plain name (e.g. ``np.ndarray`` or ``typing.List[int]``).

    Returns:
        ``(function_name, arguments)``; ``(None, {})`` if no function is defined.

    Raises:
        SyntaxError: if ``skill_code`` does not parse.
    """
    # TODO: choosing the first function is naive; callers may want a specific name.

    def annotation_name(annotation) -> str:
        # Generic types such as List[int] are reported by their base name.
        if isinstance(annotation, ast.Subscript):
            annotation = annotation.value
        if isinstance(annotation, ast.Name):
            return annotation.id
        return "Any"

    tree = ast.parse(skill_code)
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            arguments = {arg.arg: annotation_name(arg.annotation) for arg in node.args.args}
            return node.name, arguments
    return None, {}

def test_skill_code(skill_code: str, original_code, overwrite=False, replaced_name=None,
                    updated_docstring=None) -> Tuple[bool, Optional[str], Optional[str]]:
    """Splice a generated skill into a host script, execute it and smoke-test it.

    The function defined in ``skill_code`` replaces the function of the same
    name inside ``original_code``; that function may be nested, since
    indentation is adapted. The first function of the resulting script is then
    optionally renamed and/or re-documented. It is executed with the SMACv2
    atomic actions in scope, its docstring is checked to mention every
    parameter, scripts containing ``while`` are rejected, and it is called once
    per entry of ``test_params`` and must return an ``int``.

    Args:
        skill_code: source of the generated function to splice in.
        original_code: host script containing a function with the same name.
        overwrite: unused; kept for backward compatibility.
        replaced_name: if given, whole-word occurrences of the script's first
            function name are renamed to this.
        updated_docstring: if given, becomes the docstring of the script's
            first function. The script is then regenerated with
            ``ast.unparse``, which drops comments.

    Returns:
        ``(ok, info, code)``. ``ok`` is True only if every check passed.
        ``info`` explains the failure (``None`` on success). ``code`` is the
        code as it stood when checking stopped (the assembled script once
        splicing succeeded).
    """

    def rename_first_function(code, new_name=None):
        """Rename every whole-word occurrence of the first function's name in ``code``."""
        old_name, _ = convert_code_to_skill_info(code)
        if not new_name:
            return code
        # \b boundaries stop e.g. ``attack`` from also rewriting ``attack_target``;
        # a callable replacement keeps any backslashes in new_name literal.
        return re.sub(rf"\b{re.escape(old_name)}\b", lambda _match: new_name, code)

    def update_docstring(code, docstring):
        """Return ``code`` with its first function's docstring set to ``docstring``.

        A falsy ``docstring`` returns ``code`` unchanged.
        """
        if not docstring:
            return code
        tree = ast.parse(code)
        for node in ast.walk(tree):
            if isinstance(node, ast.FunctionDef):
                doc_node = ast.Expr(value=ast.Constant(value=docstring))
                first_stmt = node.body[0]
                has_docstring = (
                    isinstance(first_stmt, ast.Expr)
                    and isinstance(first_stmt.value, ast.Constant)
                    and isinstance(first_stmt.value.value, str)
                )
                if has_docstring:
                    node.body[0] = doc_node
                else:
                    # Insert rather than clobber an unrelated first statement.
                    node.body.insert(0, doc_node)
                break
        return ast.unparse(ast.fix_missing_locations(tree))

    def check_param_description(skill) -> bool:
        """Return False if the skill has a docstring that misses some parameter."""
        docstring = inspect.getdoc(skill)
        if not docstring:
            return True
        for param in inspect.signature(skill).parameters.values():
            if not re.search(rf"\s+{param.name}.*:\s*([^\n]+)", docstring):
                return False
        return True

    def find_function_with_indent(script_content, func_name):
        """
        Find a function and its indentation level in the script

        Args:
            script_content (str): The full script content
            func_name (str): Name of function to find (e.g. "score_target" or "control_logic")

        Returns:
            tuple: (start_index, end_index, base_indent); start_index points at
            the ``def`` keyword, end_index at the first line after the body.
        """
        func_def = f"def {func_name}("
        start = script_content.find(func_def)
        if start == -1:
            raise ValueError(f"Could not find function {func_name}")

        # Base indentation is the whitespace preceding ``def`` on its line.
        line_start = script_content.rfind('\n', 0, start) + 1
        base_indent = script_content[line_start:start]

        # The body ends at the next non-blank line indented no deeper than the def.
        lines = script_content[start:].split('\n')
        end_idx = 0
        for i, line in enumerate(lines):
            if i == 0:  # Skip the function definition line
                continue
            if line.strip() and len(line) - len(line.lstrip()) <= len(base_indent):
                if i > 1:  # Make sure we're not on the first body line
                    # Offset of line i, i.e. everything before it is the function.
                    end_idx = sum(len(l) + 1 for l in lines[:i])
                    break

        if end_idx == 0:  # If we didn't find the end, take until end of script
            end_idx = len(script_content[start:])

        return start, start + end_idx, base_indent

    def apply_indentation(new_func_content, base_indent):
        """
        Indent the body of the first top-level ``def`` in new_func_content by
        base_indent; stop at the next top-level ``def``.
        """
        lines = new_func_content.split('\n')
        indented_lines = []
        first = True

        for line in lines:
            if line.strip():  # If line is not empty
                if line.startswith('def '):  # Function definition
                    if first:
                        # Already preceded by base_indent in the host script.
                        indented_lines.append(line)
                        first = False
                    else:
                        break
                else:  # Function body
                    indented_lines.append(base_indent + line)
            else:
                indented_lines.append('')

        return '\n'.join(indented_lines)

    def replace_function(script_content, func_name, new_func_content):
        """
        Replace a function while preserving proper indentation

        Args:
            script_content (str): The full script content
            func_name (str): Name of function to replace
            new_func_content (str): New function content (without indentation)

        Returns:
            str: Updated script content
        """
        start, end, base_indent = find_function_with_indent(script_content, func_name)
        indented_new_func = apply_indentation(new_func_content, base_indent)

        return script_content[:start] + indented_new_func + '\n' + script_content[end:]

    info = None
    try:
        if skill_code.count('(') < 2:
            info = "Skill code contains no functionality."
            return False, info, skill_code

        # Name of the function to replace in the host script (also raises
        # SyntaxError for unparsable skill code).
        skill_name, _ = convert_code_to_skill_info(skill_code)
        skill_code = replace_function(original_code, skill_name, skill_code)

        skill_code = rename_first_function(skill_code, replaced_name)
        # The skill that gets executed is the first function of the spliced script.
        skill_name, _ = convert_code_to_skill_info(skill_code)

        skill_code = update_docstring(skill_code, updated_docstring)

        # Add required imports
        from harl.common.skills.smacv2.atomic_actions.move import (
            move_north, move_south, move_east, move_west
        )
        from harl.common.skills.smacv2.atomic_actions.combat import attack
        from harl.common.skills.smacv2.atomic_actions.heal import heal
        from harl.common.skills.smacv2.atomic_actions.basic import stop
        from harl.common.skills.smacv2.composite_skills import default_tactic as default_action, find_path
        from harl.utils.skill_utils import parse_obs

        # Create execution context with required imports
        exec_namespace = globals().copy()
        exec_namespace.update({
            'move_north': move_north,
            'move_south': move_south,
            'move_east': move_east,
            'move_west': move_west,
            'attack': attack,
            'heal': heal,
            'stop': stop,
            'default_action': default_action,
            'find_path': find_path,
            'parse_obs': parse_obs,
            'random': random,
            'List': List,
            'Tuple': Tuple,
            'Dict': Dict
        })
        # A same-named module global must not stand in for a skill the
        # script failed to define.
        exec_namespace.pop(skill_name, None)
        # One namespace (instead of separate globals/locals) so the script's
        # functions can see its own top-level imports and helper functions.
        exec(skill_code, exec_namespace)
        skill = exec_namespace[skill_name]
    except Exception:
        info = "The format of skill code is wrong."
        return False, info, skill_code

    if not check_param_description(skill):
        info = "The format of parameter description is wrong."
        return False, info, skill_code

    try:
        # Check for control flow statements
        if any(isinstance(node, ast.While) for node in ast.walk(ast.parse(skill_code))):
            info = "Skill code should not contain control flow statements like 'while'"
            return False, info, skill_code

        for test_param in test_params:
            action = skill(**test_param)
            if not isinstance(action, int):
                info = "Skill function must return an integer action instead of a {}".format(type(action))
                return False, info, skill_code
    except Exception as e:
        info = f"Skill function is defined but failed to execute: {str(e)}"
        return False, info, skill_code

    return True, info, skill_code

def extract_key_events(actions, decisions, errors, reflections=None, summarys=None):
    """Collapse a per-step trace into key events, one per run of identical actions.

    Args:
        actions: per-step action strings. Consecutive equal actions form one run.
            Without ``reflections``, an empty-string action continues the current
            run (and is counted in its repeat total) instead of starting a new one.
        decisions: per-step decisions, aligned with ``actions``.
        errors: per-step errors, aligned with ``actions``.
        reflections: optional per-step reflections. When given, ``summarys``
            must be given too and each event gains two extra fields.
        summarys: optional per-step summaries (used only with ``reflections``).

    All sequences are consumed in lockstep with ``zip``, so the shortest one
    bounds the trace.

    Returns:
        List of tuples ``(step, action_desc, decision, error)`` or, with
        reflections, ``(step, action_desc, decision, error, reflection, summary)``.
        ``step`` and the per-step fields come from the first step of the run.
        ``action_desc`` is the action, suffixed with ``" (repeated Nx)"`` when
        the run spans N > 1 steps. Runs whose action is falsy are not emitted.
    """
    with_reflections = reflections is not None
    if with_reflections:
        rows = zip(actions, decisions, errors, reflections, summarys)
    else:
        rows = zip(actions, decisions, errors)

    def describe(action, count):
        return f"{action} (repeated {count}x)" if count > 1 else action

    key_events = []
    run_action = None   # action of the run in progress
    run_step = None     # step index where that run started
    run_details = ()    # per-step fields captured at run start
    run_length = 0

    for step, (action, *details) in enumerate(rows):
        # Empty actions extend the current run; previously they overwrote the
        # run's action with "" so the run was never emitted.
        continues_run = action == run_action or (not with_reflections and action == "")
        if continues_run:
            run_length += 1
            continue
        if run_action:
            key_events.append((run_step, describe(run_action, run_length), *run_details))
        run_action, run_step, run_details, run_length = action, step, tuple(details), 1

    # Flush the final run.
    if run_action:
        key_events.append((run_step, describe(run_action, run_length), *run_details))
    return key_events

class SkillExtractionPipeline:
    """Mine high-scoring action windows from trajectories and cluster them into candidate skills."""

    def __init__(self):
        # BERT encoder used to embed action-sequence text (768-d hidden states).
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.embedding_model = BertModel.from_pretrained('bert-base-uncased')
        # Optional FAISS L2 index sized to the 768-d embeddings. The
        # module-level ``import faiss`` is commented out, so a bare ``faiss``
        # reference raised NameError. Import lazily and fall back to None when
        # FAISS is not installed. Nothing in this class reads the index.
        try:
            import faiss
        except ImportError:
            self.action_index = None
        else:
            self.action_index = faiss.IndexFlatL2(768)

    def extract_action_patterns(self, trajectory_data):
        """Extract recurring, high-scoring multi-agent action patterns.

        The method:
        1. Slides windows of 10-14 steps over every trajectory.
        2. Scores each window from the unit health changes between its first
           and last state, keeping windows judged meaningful.
        3. Keeps the windows whose score is at least 2 standard deviations
           above the mean.
        4. Drops a window whose action text is contained in a strictly
           better kept one.
        5. Returns the first ``len // 5`` of what remains.

        Args:
            trajectory_data: iterable of trajectories indexed by ``'share_obs'``
                (per step, the state text as bytes) and ``'actions'`` (per step,
                a sequence of bytes actions, one per agent).

        Returns:
            List of dicts with keys ``'actions'``, ``'initial_state'``,
            ``'final_state'``, ``'score'`` and ``'length'``, best score first.
            Empty when no window qualifies.
        """
        def calculate_pattern_score(metrics):
            """Weighted quality score: rewards damage dealt and kills, penalises damage taken and losses."""
            return (
                metrics['damage_dealt'] * 2.0 +
                metrics['enemies_killed'] * 4.0 +
                metrics['average_enemy_health_change'] * 1.0 -
                metrics['damage_got'] * 3.0 -
                metrics['allies_lost'] * 5.0 -
                metrics['average_ally_health_change'] * 2.0
            )

        def is_subsequence(seq1, seq2):
            """True if seq1's action text occurs verbatim inside seq2's."""
            return seq1['actions'] in seq2['actions']

        action_sequences = []
        pattern_scores = []

        for trajectory in trajectory_data:
            states = trajectory['share_obs']
            step_actions = trajectory['actions']
            # Sliding windows of 10-14 steps are the candidate skill segments.
            for window_size in range(10, 15):
                for i in range(len(states) - window_size + 1):
                    state_before = states[i].decode('utf-8')
                    state_after = states[i + window_size - 1].decode('utf-8')

                    # Regroup the window's per-step joint actions by agent.
                    agent_actions = {}
                    for joint_action in step_actions[i:i + window_size]:
                        for agent_id, act in enumerate(joint_action):
                            agent_actions.setdefault(agent_id, []).append(act.decode('utf-8'))

                    # One line per agent: "Agent <id>: a -> b -> ...".
                    actions = "\n".join(
                        f"Agent {agent_id}: {' -> '.join(agent_actions[agent_id])}"
                        for agent_id in sorted(agent_actions)
                    )

                    success_metrics = self._calculate_success_metrics(state_before, state_after)
                    if self._is_meaningful_sequence(success_metrics):
                        pattern_score = calculate_pattern_score(success_metrics)
                        pattern_scores.append(pattern_score)
                        action_sequences.append({
                            'actions': actions,
                            'initial_state': state_before,
                            'final_state': state_after,
                            'score': pattern_score,
                            'length': window_size
                        })

        if not pattern_scores:
            return []

        # Keep only outstanding windows: z-score of at least 2 (2 std above the mean).
        scores = np.asarray(pattern_scores, dtype=float)
        std = scores.std()
        if std > 0:
            z_scores = (scores - scores.mean()) / std
            valid_indices = [i for i, z_score in enumerate(z_scores) if z_score >= 2.0]
        else:
            # All scores identical: no window can stand out.
            valid_indices = []
        print(f"Kept {len(valid_indices)/len(pattern_scores)*100:.1f}% of episodes")
        action_sequences = [action_sequences[i] for i in valid_indices]

        # Filter out redundant windows, best first.
        action_sequences.sort(key=lambda seq: -seq['score'])
        filtered_sequences = []
        for sequence in action_sequences:
            # Only keep if not contained in a strictly higher-scoring kept pattern.
            if not any(is_subsequence(sequence, better_seq)
                       for better_seq in filtered_sequences
                       if better_seq['score'] > sequence['score']):
                filtered_sequences.append(sequence)

        return filtered_sequences[:len(filtered_sequences) // 5]

    def cluster_similar_patterns(self, action_sequences, embeddings_file='sequence_embeddings.npy',
                                 cluster_file='action_clusters.json', n_clusters=1000):
        """Group similar action sequences into candidate skills.

        Each sequence's ``'actions'`` text (JSON-encoded) is embedded with BERT
        (mean-pooled last hidden state, L2-normalised) and clustered with
        k-means.

        Caching:
        - The embedding file is reused only when it has one row per sequence.
        - The cluster file is reused only when the embeddings themselves came
          from that cache.

        Args:
            action_sequences: pattern dicts as returned by
                :meth:`extract_action_patterns`.
            embeddings_file: path of the ``.npy`` embedding cache.
            cluster_file: path of the JSON cluster cache.
            n_clusters: requested k-means cluster count, capped at the number
                of sequences.

        Returns:
            Dict mapping each cluster label (as a string) to its member
            sequences; empty when ``action_sequences`` is empty.
        """
        def find_optimal_clusters(embeddings, min_clusters=2, max_clusters=500):
            """Find optimal number of clusters using silhouette analysis (currently unused)."""
            from sklearn.cluster import KMeans

            best_score = -1
            best_n = min_clusters
            for n in range(min_clusters, min(max_clusters, len(embeddings) // 3)):
                labels = KMeans(n_clusters=n, n_init=10, random_state=0).fit_predict(embeddings)
                score = silhouette_score(embeddings, labels)
                if score > best_score:
                    best_score = score
                    best_n = n
            return best_n

        def cluster_sequences(action_sequences, embeddings, method='kmeans'):
            """Cluster sequences with k-means or DBSCAN; DBSCAN outliers (label -1) are dropped."""
            if method == 'kmeans':
                # sklearn's KMeans stands in for faiss.Kmeans, whose module
                # import is disabled in this file. k may not exceed the sample count.
                from sklearn.cluster import KMeans

                k = min(n_clusters, len(embeddings))
                labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(embeddings)
            else:  # DBSCAN
                labels = DBSCAN(eps=0.1, min_samples=2).fit_predict(embeddings)

            clusters = {}
            for sequence, label in zip(action_sequences, labels):
                if label == -1:  # Outlier handling
                    continue
                clusters.setdefault(label, []).append(sequence)
            return clusters

        if not action_sequences:
            return {}

        embeddings = None
        if os.path.exists(embeddings_file):
            cached = np.load(embeddings_file)
            # A cache built for a different input would misalign rows and sequences.
            if cached.shape[0] == len(action_sequences):
                embeddings = cached
        embeddings_from_cache = embeddings is not None

        if embeddings is None:
            sequence_texts = [json.dumps(seq['actions']) for seq in action_sequences]
            embeddings = np.zeros((len(sequence_texts), 768))
            for i, text in tqdm(enumerate(sequence_texts), desc='Generating embeddings', total=len(sequence_texts)):
                inputs = self.tokenizer(
                    text,
                    max_length=512,
                    padding=True,
                    truncation=True,
                    return_tensors='pt'
                )
                outputs = self.embedding_model(**inputs)
                # Mean-pool token states, then L2-normalise the sequence vector.
                query_embedding = outputs.last_hidden_state.mean(dim=1).squeeze().detach().numpy()
                embeddings[i] = query_embedding / np.linalg.norm(query_embedding)
            # Save embeddings to file for faster future loading
            np.save(embeddings_file, embeddings)

        if embeddings_from_cache and os.path.exists(cluster_file):
            with open(cluster_file) as f:
                return json.load(f)

        clusters = cluster_sequences(action_sequences, embeddings)
        cluster_data = {str(label): list(members) for label, members in clusters.items()}
        with open(cluster_file, 'w') as f:
            json.dump(cluster_data, f, indent=2)
        return cluster_data

    def _calculate_success_metrics(self, state_before: str, state_after: str) -> Dict:
        """Compare unit health between two state texts.

        Health values are fractions (``"Health: 80%"`` -> 0.8). Metrics:

        - ``enemies_killed`` / ``allies_lost``: units whose health went from
          > 0 to 0.
        - ``damage_dealt`` / ``damage_got``: summed health lost (before - after)
          by enemies / allies.
        - ``average_enemy_health_change`` / ``average_ally_health_change``: mean
          health lost (before - after) by units still alive afterwards, so a
          positive value means that side took damage.

        A unit listed before but missing afterwards is treated as having 0
        health (assumption: such units died — confirm against the observation
        format).
        """
        def extract_unit_stats(state_text: str):
            """Helper to extract health values for units from state text"""
            ally_stats = {}
            enemy_stats = {}

            # Extract ally info
            ally_section = state_text.split("1. Ally Units:")[1].split("\n2. Enemy Units:")[0]
            for block in ally_section.split("- Ally #")[1:]:  # Skip text before first unit
                try:
                    ally_id = int(block.split(":")[0])
                    health = float(block.split("Health: ")[1].split("%")[0]) / 100
                except (ValueError, IndexError):
                    continue  # unit entry without a parsable id/health
                ally_stats[ally_id] = health

            # Extract enemy info
            enemy_section = state_text.split("2. Enemy Units:")[1].split("\n3." if "\n3." in state_text else "\n4.")[0]
            for block in enemy_section.split("- Enemy #")[1:]:  # Skip text before first unit
                try:
                    enemy_id = int(block.split(":")[0])
                    health = float(block.split("Health: ")[1].split("%")[0]) / 100
                except (ValueError, IndexError):
                    continue
                enemy_stats[enemy_id] = health

            return ally_stats, enemy_stats

        allies_before, enemies_before = extract_unit_stats(state_before)
        allies_after, enemies_after = extract_unit_stats(state_after)

        metrics = {
            "enemies_killed": 0,
            "damage_dealt": 0.0,
            "damage_got": 0.0,
            "allies_lost": 0,
            "average_enemy_health_change": 0.0,
            "average_ally_health_change": 0.0
        }

        # Count eliminated enemies and damage dealt.
        for enemy_id, health_before in enemies_before.items():
            health_after = enemies_after.get(enemy_id, 0.0)
            metrics["damage_dealt"] += health_before - health_after
            if health_before > 0 and health_after == 0:
                metrics["enemies_killed"] += 1

        # Average health lost by enemies still alive afterwards. Same sign
        # convention as the ally metric: previously computed as after - before,
        # which made dealing damage lower the pattern score.
        surviving_enemies = {
            e for e in enemies_before.keys() & enemies_after.keys()
            if enemies_after[e] > 0
        }
        if surviving_enemies:
            health_changes = [enemies_before[e] - enemies_after[e] for e in surviving_enemies]
            metrics["average_enemy_health_change"] = sum(health_changes) / len(surviving_enemies)

        # Count lost allies and damage received.
        for ally_id, health_before in allies_before.items():
            health_after = allies_after.get(ally_id, 0.0)
            metrics["damage_got"] += health_before - health_after
            if health_before > 0 and health_after == 0:
                metrics["allies_lost"] += 1

        # Average health lost by allies still alive afterwards.
        surviving_allies = {
            a for a in allies_before.keys() & allies_after.keys()
            if allies_after[a] > 0
        }
        if surviving_allies:
            health_changes = [allies_before[a] - allies_after[a] for a in surviving_allies]
            metrics["average_ally_health_change"] = sum(health_changes) / len(surviving_allies)

        return metrics

    def _is_meaningful_sequence(self, metrics: Dict) -> bool:
        """Return True if any success signal is present in ``metrics``.

        NOTE(review): ``allies_lost < 1`` alone accepts every window without an
        ally death, so this filter is very permissive — confirm it is intended.
        """
        return (metrics['damage_dealt'] > 0 or
                metrics['enemies_killed'] > 0 or
                metrics['allies_lost'] < 1 or
                metrics['average_enemy_health_change'] > 0)

@dataclass
class Ally:
    """An allied-unit entry of an observation (see ``ObsData.allies``).

    ``id``, ``distance`` and ``direction`` are required; every other attribute
    is optional.
    """
    id: int
    distance: float
    direction: str
    health: Optional[float] = None
    shield: Optional[float] = None
    position: Optional[Tuple[float, float]] = None
    unit_type: Optional[str] = None
    # Defaults to None, so the annotation must admit None.
    last_action: Optional[int] = None
    can_heal: bool = False

@dataclass
class Enemy:
    """An enemy-unit entry of an observation (see ``ObsData.enemies``).

    The first four attributes are required; the rest default to ``None``.
    """

    # Required attributes.
    id: int
    can_attack: bool
    distance: float
    direction: str

    # Optional attributes.
    health: Optional[float] = None
    shield: Optional[float] = None
    position: Optional[Tuple[float, float]] = None
    unit_type: Optional[str] = None

@dataclass
class ObsData:
    """Structured form of an agent's text observation, as returned by ``parse_obs``."""
    can_move: Dict[str, bool]  # movement direction (e.g. 'north') -> allowed
    enemies: List[Enemy]
    allies: List[Ally]
    own_health: float
    own_shield: Optional[float] = None
    own_sight_range: Optional[float] = None
    own_shoot_range: Optional[float] = None
    own_position: Optional[Tuple[float, float]] = None
    own_unit_type: Optional[str] = None
    # Integer action ids; None when not provided.
    available_actions: Optional[Set[int]] = None
    region_of_interest: Optional[str] = None
    last_action: Optional[int] = None

# Action names used in observation text -> discrete action ids. Targeted
# actions (attack/heal) are encoded as this base index plus the target id.
_OBS_ACTION_MAP = {
    'noop': 0,
    'stop': 1,
    'move_north': 2,
    'move_south': 3,
    'move_east': 4,
    'move_west': 5,
    'attack': 6,  # base index for attack actions
    'heal': 6,    # uses the same base as attack
}
_OBS_TARGETED_ACTIONS = ('attack', 'heal')


def _obs_field(text: str, key: str) -> str:
    """Return the rest of the line after the first ``key`` in ``text``, stripped."""
    return text.split(key)[1].split('\n')[0].strip()


def _obs_percent(text: str, key: str) -> float:
    """Parse a 'key: NN%' field as a fraction (NN / 100)."""
    return float(_obs_field(text, key).split('%')[0]) / 100


def _obs_units(text: str, key: str) -> float:
    """Parse a 'key: N units' field as a float."""
    return float(_obs_field(text, key).split('units')[0])


def _obs_position(text: str, key: str) -> Tuple[float, float]:
    """Parse a 'key: (x, y)' field into an (x, y) float tuple."""
    pos_str = text.split(key)[1].split(')')[0].strip('( ')
    x, y = map(float, pos_str.split(','))
    return (x, y)


def _obs_optional(parse, text: str, key: str, default=None):
    """Return ``parse(text, key)`` when ``key`` occurs in ``text``, else ``default``."""
    return parse(text, key) if key in text else default


def _obs_is_medivac(unit_type: Optional[str]) -> bool:
    """True when ``unit_type`` is 'medivac' (case-insensitive); None gives False."""
    return unit_type is not None and unit_type.lower() == 'medivac'


def _split_obs_action(action_part: str) -> Tuple[str, Optional[dict]]:
    """Split 'name, {params}' action text into (name, params).

    params is {} when no parameter text follows the name, and None when
    parameter text is present but is not a dict literal.
    """
    name, has_params, param_str = action_part.partition(',')
    name = name.strip()
    if not has_params:
        return name, {}
    try:
        # literal_eval (never eval): this text comes from the observation
        # string and must not be able to execute code.
        params = ast.literal_eval(param_str.strip())
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        return name, None
    return name, params if isinstance(params, dict) else None


def _obs_action_id(name: str, params: dict) -> Optional[int]:
    """Map an action name and its params to a discrete action id, or None."""
    if name in _OBS_TARGETED_ACTIONS:
        target_id = params.get('target_id')
        if not isinstance(target_id, int):
            return None  # a targeted action without a usable target id
        return _OBS_ACTION_MAP[name] + target_id
    return _OBS_ACTION_MAP.get(name)


def _parse_obs_last_action(text: str) -> Optional[int]:
    """Resolve the 'Last action:' line of a unit block to an action id, if any."""
    if 'Last action:' not in text:
        return None
    name, params = _split_obs_action(_obs_field(text, 'Last action:'))
    # Unparsable params still let param-free actions resolve; targeted
    # actions then lack a target id and resolve to None.
    return _obs_action_id(name, params if params is not None else {})


def _parse_obs_available_actions(section: str) -> Set[int]:
    """Collect the ids of the actions marked ': yes' in the actions section."""
    available = set()
    for line in section.split('\n')[1:]:  # first line is the section header
        # Split on the LAST colon: the params dict itself contains colons.
        action_part, colon, flag = line.rpartition(':')
        if not colon or flag.strip() != 'yes':
            continue
        name, params = _split_obs_action(action_part.strip('- '))
        if params is None:
            continue  # cannot tell which action this is; skip the line
        action_id = _obs_action_id(name, params)
        if action_id is not None:
            available.add(action_id)
    return available


def _parse_obs_movement(section: str) -> Dict[str, bool]:
    """Read the per-direction 'yes'/'no' flags of the movement section."""
    return {
        heading.lower(): 'yes' in section.split(f'{heading}:')[1].split('\n')[0]
        for heading in ('North', 'South', 'East', 'West')
    }


def _parse_obs_ego_minimap(section: Optional[str]) -> Dict[str, Dict[int, dict]]:
    """Parse the optional ego-minimap section into {'allies'|'enemies': {id: info}}.

    Only units with a parsed health above 0 are kept.
    """
    ego_data = {"allies": {}, "enemies": {}}
    if not section:
        return ego_data
    lines = section.split('\n')
    current_type = None
    for idx, line in enumerate(lines):
        if "Visible Allies:" in line:
            current_type = "allies"
            continue
        if "Visible Enemies:" in line:
            current_type = "enemies"
            continue
        if current_type is None or "- " not in line:
            continue

        unit_id = int(line.split('#')[1].split()[0])
        unit_info = {'unit_type': line.split('(')[1].split(')')[0], 'shield': 0}
        obs_type = "direct" if "Directly observed" in line else "shared"

        # Up to 5 detail lines follow a unit header; stop at the next header
        # or list heading so one unit's details never leak into another's.
        for detail in lines[idx + 1:idx + 6]:
            if "- " in detail or "Visible Allies:" in detail or "Visible Enemies:" in detail:
                break
            if 'Health:' in detail:
                unit_info['health'] = _obs_percent(detail, 'Health:')
            elif 'Shield:' in detail:
                unit_info['shield'] = _obs_percent(detail, 'Shield:')
            elif "Relative position:" in detail:
                unit_info['position'] = _obs_position(detail, "Relative position:")
            elif "Relative direction:" in detail:
                unit_info['direction'] = detail.split(':')[1].strip()
            elif "Last seen:" in detail:
                unit_info['last_seen'] = int(detail.split(':')[1].split()[0])

        # Missing health is treated like dead, as in the unit-list sections.
        if unit_info.get('health', 0) > 0:
            ego_data[current_type][unit_id] = {'obs_type': obs_type, **unit_info}
    return ego_data


def _parse_obs_enemies(section: str) -> List['Enemy']:
    """Parse the enemy-units section; enemies with unknown or zero health are skipped."""
    enemies = []
    for block in section.split('Enemy #')[1:]:
        enemy_id = int(block.split(':')[0])
        can_attack = 'Can be attacked: yes' in block
        distance = _obs_units(block, 'Distance:')
        direction = _obs_optional(_obs_field, block, 'Relative direction:')
        health = _obs_optional(_obs_percent, block, '* Health:')
        if health is None or health <= 0:
            continue
        enemies.append(Enemy(
            id=enemy_id,
            can_attack=can_attack,
            distance=distance,
            direction=direction,
            health=health,
            shield=_obs_optional(_obs_percent, block, '* Shield:', 0),
            position=_obs_optional(_obs_position, block, 'Relative position:'),
            unit_type=_obs_optional(_obs_field, block, 'Unit type:'),
        ))
    return enemies


def _parse_obs_allies(section: str) -> List['Ally']:
    """Parse the ally-units section; allies with unknown or zero health are skipped."""
    allies = []
    for block in section.split('Ally #')[1:]:
        ally_id = int(block.split(':')[0])
        distance = _obs_units(block, 'Distance:')
        direction = _obs_optional(_obs_field, block, 'Relative direction:')
        health = _obs_optional(_obs_percent, block, '* Health:')
        if health is None or health <= 0:
            continue
        unit_type = _obs_optional(_obs_field, block, 'Unit type:')
        allies.append(Ally(
            id=ally_id,
            distance=distance,
            direction=direction,
            health=health,
            shield=_obs_optional(_obs_percent, block, '* Shield:', 0),
            position=_obs_optional(_obs_position, block, 'Relative position:'),
            unit_type=unit_type,
            # Unparsable last-action params leave last_action unresolved
            # instead of dropping the ally.
            last_action=_parse_obs_last_action(block),
            can_heal=_obs_is_medivac(unit_type),
        ))
    return allies


def _parse_obs_own_unit(section: str) -> dict:
    """Parse the own-unit section into the matching ObsData keyword arguments."""
    return {
        'own_health': _obs_percent(section, 'Health:'),
        'own_shield': _obs_optional(_obs_percent, section, 'Shield:', 0),
        'own_sight_range': _obs_optional(_obs_units, section, 'Sight range:'),
        'own_shoot_range': _obs_optional(_obs_units, section, 'Shoot range:'),
        'own_position': _obs_optional(_obs_position, section, 'Position:'),
        'own_unit_type': _obs_optional(_obs_field, section, 'Unit type:'),
        'last_action': _parse_obs_last_action(section),
    }


def _obs_within_shoot_range(distance: float, shoot_range: Optional[float],
                            sight_range: Optional[float]) -> bool:
    """Range test for minimap units; False when either range is unknown or zero.

    NOTE(review): compares against shoot_range / sight_range, which assumes
    minimap relative positions are normalized by sight range — confirm.
    """
    if shoot_range is None or not sight_range:
        return False
    return distance <= shoot_range / sight_range


def _extend_with_ego_enemies(enemies: list, ego_enemies: Dict[int, dict],
                             shoot_range: Optional[float],
                             sight_range: Optional[float]) -> None:
    """Append, in place, minimap enemies whose ids are not already listed."""
    known_ids = {enemy.id for enemy in enemies}
    for enemy_id, info in ego_enemies.items():
        position = info.get('position')
        # A unit without a position cannot be ranged or placed; skip it.
        if enemy_id in known_ids or position is None:
            continue
        # NOTE(review): this is the norm of the minimap relative position and
        # may not share a scale with the 'Distance: N units' values above.
        distance = math.hypot(position[0], position[1])
        enemies.append(Enemy(
            id=enemy_id,
            can_attack=_obs_within_shoot_range(distance, shoot_range, sight_range),
            distance=distance,
            direction=info.get('direction'),
            health=info['health'],
            shield=info['shield'],
            position=position,
            unit_type=info['unit_type'],
        ))


def _extend_with_ego_allies(allies: list, ego_allies: Dict[int, dict]) -> None:
    """Append, in place, minimap allies whose ids are not already listed."""
    known_ids = {ally.id for ally in allies}
    for ally_id, info in ego_allies.items():
        position = info.get('position')
        if ally_id in known_ids or position is None:
            continue
        allies.append(Ally(
            id=ally_id,
            distance=math.hypot(position[0], position[1]),
            direction=info.get('direction'),
            health=info['health'],
            shield=info['shield'],
            position=position,
            unit_type=info['unit_type'],
            # The minimap parser records no last action; 0 is the 'noop' id.
            last_action=0,
            can_heal=_obs_is_medivac(info['unit_type']),
        ))


def parse_obs(obs: str) -> ObsData:
    """Parse observation text into an ``ObsData``.

    ``obs`` is split into blank-line-separated sections identified by their
    headings. The 'Available Actions:', 'Movement Information:',
    'Enemy Units Information:', 'Ally Units Information:' and
    'Own Unit Information:' sections are required. The 'Ego Minimap:' and
    'Region of Interest:' sections are optional.

    Units seen only on the ego minimap are appended after the directly
    listed enemies and allies. Health and shield values are stored as
    fractions (percent / 100). Units with unknown or non-positive health
    are dropped. Action parameters are parsed with ``ast.literal_eval``,
    never executed.

    Args:
        obs: The observation text.

    Returns:
        The parsed observation.

    Raises:
        StopIteration: If a required section is missing.
        IndexError, ValueError: If a present field is malformed.
    """
    sections = obs.split('\n\n')

    def section_with(marker: str) -> str:
        return next(s for s in sections if marker in s)

    available_actions = _parse_obs_available_actions(section_with("Available Actions:"))
    can_move = _parse_obs_movement(section_with("Movement Information:"))
    ego_data = _parse_obs_ego_minimap(
        next((s for s in sections if "Ego Minimap:" in s), None))
    enemies = _parse_obs_enemies(section_with("Enemy Units Information:"))
    allies = _parse_obs_allies(section_with("Ally Units Information:"))
    own = _parse_obs_own_unit(section_with("Own Unit Information:"))

    _extend_with_ego_enemies(enemies, ego_data['enemies'],
                             own['own_shoot_range'], own['own_sight_range'])
    _extend_with_ego_allies(allies, ego_data['allies'])

    roi_section = next((s for s in sections if "Region of Interest:" in s), None)
    roi = roi_section.strip() if roi_section else ""

    return ObsData(
        can_move=can_move,
        enemies=enemies,
        allies=allies,
        available_actions=available_actions,
        region_of_interest=roi,
        **own,
    )