from copy import deepcopy
import os
import json
import yaml
import re
from typing import Dict, List, Any
from pathlib import Path
from dataclasses import dataclass

from metagpt.logs import logger

def custom_json_parser(input_string):
    """Parse ``input_string`` as JSON, falling back to a lenient key/value scan.

    The text is handled in three stages, stopping at the first that works:

    1. The stripped text is passed to ``json.loads``.
    2. The first block delimited by a line holding only ``{`` and a later
       line holding only ``}`` is extracted and passed to ``json.loads``.
    3. The remaining text (the extracted block if one was found, otherwise
       the whole input) is scanned with a regex for ``key: value`` pairs.
       Each value is parsed as JSON when possible and kept as a stripped
       string otherwise.

    Limitation: the block extraction does not track nesting, so an inner
    object whose braces sit alone on their own lines ends the block early.

    Args:
        input_string: Raw text that is expected to contain a JSON object.

    Returns:
        The ``json.loads`` result when strict parsing succeeds (stage 1 may
        therefore yield any JSON type), otherwise a dict built from the
        pairs found in stage 3 (possibly empty).
    """
    logger.info(f"DEBUG: input_string: {input_string}")

    def find_json_like_structures(text):
        """Return every ``{`` ... ``}`` block whose braces stand alone on a line."""
        structures = []
        collected = None  # None while we are outside of a structure
        for line in text.split('\n'):
            marker = line.strip()
            if marker == '{':
                # A lone opening brace always (re)starts a structure.
                collected = ['{']
            elif collected is not None:
                if marker == '}':
                    collected.append('}')
                    structures.append(''.join(collected))
                    # Reset so that a later stray '}' cannot emit a bogus block.
                    collected = None
                else:
                    collected.append(line)
        return structures

    def parse_value(value):
        """Decode ``value`` as JSON, or return it as a stripped string."""
        try:
            return json.loads(value)
        except json.JSONDecodeError:
            return value.strip()

    # Remove any leading/trailing whitespace
    input_string = input_string.strip()

    # Stage 1: the input may already be valid JSON.
    try:
        return json.loads(input_string)
    except json.JSONDecodeError:
        pass

    # Stage 2: valid JSON surrounded by prose; scan the text only once.
    structures = find_json_like_structures(input_string)
    if structures:
        input_string = structures[0]
        try:
            return json.loads(input_string)
        except json.JSONDecodeError:
            pass

    # Stage 3: manual scan. Drop one enclosing brace pair first, otherwise the
    # closing brace ends up glued to the last value.
    if input_string.startswith('{') and input_string.endswith('}'):
        input_string = input_string[1:-1]

    # A key is an optionally quoted word; its value runs until the next
    # ", key:" sequence or the end of the text.
    pairs = re.findall(r'"?(\w+)"?\s*:\s*(.+?)(?=,\s*"?\w+"?\s*:|$)', input_string, re.DOTALL)

    result = {}
    for key, value in pairs:
        # Remove quotes from the key if present
        result[key.strip('"')] = parse_value(value)

    return result


def setup_output_directory(args, timestamp_str: str=""):
    """Create the results directory for a run and return its path.

    The directory lives under ``results`` and is laid out as
    ``<family>/<task_name>/<config>_<num_agents>_<max_rounds>[_<failure_prob>]``
    with an optional trailing ``/<timestamp_str>``. ``<family>`` is
    ``MultiAgent+failure`` when ``args.failure_prob`` is positive and
    ``MultiAgent+`` otherwise; ``<config>`` is the base name of
    ``args.config`` without its extension.

    Args:
        args: Object exposing ``config``, ``failure_prob``, ``task_name``,
            ``num_agents`` and ``max_rounds`` attributes.
        timestamp_str: Optional sub-directory name appended when non-empty.

    Returns:
        The path of the directory, which exists when the function returns.
    """
    config_stem, _extension = os.path.splitext(os.path.basename(args.config))
    run_label = f"{config_stem}_{args.num_agents}_{args.max_rounds}"

    if args.failure_prob > 0:
        family = "MultiAgent+failure"
        run_label = f"{run_label}_{args.failure_prob}"
    else:
        family = "MultiAgent+"

    relative_dir = f"{family}/{args.task_name}/{run_label}"
    if timestamp_str:
        relative_dir = f"{relative_dir}/{timestamp_str}"

    target_dir = os.path.join('results', relative_dir)
    os.makedirs(target_dir, exist_ok=True)
    return target_dir

def save_config(args, output_dir, timestamp_str):
    """Write the run arguments plus a timestamp to ``<output_dir>/config.json``.

    Note that ``vars(args)`` is the live attribute dictionary of ``args``, so
    storing the timestamp in it also sets ``args.timestamp`` on the caller's
    object.

    Args:
        args: Object whose attributes are dumped (they must be JSON
            serialisable).
        output_dir: Existing directory that receives ``config.json``.
        timestamp_str: Value stored under the ``timestamp`` key.
    """
    settings = vars(args)
    settings.update(timestamp=timestamp_str)

    destination = os.path.join(output_dir, 'config.json')
    with open(destination, 'w') as handle:
        json.dump(settings, handle, indent=2)


def score_change_to_text(type:str, score:float, change:float, profiles:dict=None, agent_name:str=""):
    """Turn a profile score and its change into feedback text for an agent.

    Args:
        type: Which metric the score belongs to: ``"clarity"``,
            ``"differentiation"`` or ``"alignment"``. Any other value
            produces no score-related sentence.
        score: Current metric value; criticism is only emitted below 0.5.
        change: Difference to the previous score, or ``None`` when there is
            no previous profile to compare with.
        profiles: Optional mapping of agent name to profile. For the
            ``"differentiation"`` metric the other agents' profiles are
            quoted when ``agent_name`` is one of its keys.
        agent_name: Name of the agent the feedback is addressed to.

    Returns:
        The feedback text; an empty string when there is nothing to report.
    """
    prefix = "Your profile is"
    parts = []

    if score < 0.5:
        if type == "clarity":
            parts.append(f"{prefix} not clear or specific. ")
        elif type == "differentiation":
            parts.append(f"{prefix} not well-differentiated, think about other roles that you can take on. ")
            if profiles is not None and agent_name in profiles:
                # Quote every profile except the agent's own.
                others = "".join(
                    f"{name}: {profile} "
                    for name, profile in profiles.items()
                    if name != agent_name
                )
                parts.append(f"Your profile is similar to others: [{others}]. ")
        elif type == "alignment":
            parts.append(f"{prefix} not well-aligned, as your profile is not aligned with the task. ")

    # ``change`` is None on the first evaluation, when no trend can be reported.
    if change is not None:
        if change > 0.05:
            parts.append("Compared to your previous profile, your current profile has improved. Keep up the good work!")
        elif change < -0.05:
            parts.append("Compared to your previous profile, your current profile has degraded.")
        else:
            parts.append("Compared to your previous profile, your current profile has not changed much.")

    return "".join(parts)
        
def get_all_eval_prompt(clarity_score:float, clarity_change:float, differentiation_score:float, differentiation_change:float, alignment_score:float, alignment_change:float, profiles: dict=None, agent_name: str=""):
    """Build the feedback texts for the three profile metrics.

    Each score/change pair is passed to ``score_change_to_text``; only the
    differentiation call also receives ``profiles`` and ``agent_name``.

    Returns:
        A ``(clarity, differentiation, alignment)`` tuple of feedback texts.
    """
    return (
        score_change_to_text("clarity", clarity_score, clarity_change),
        score_change_to_text(
            "differentiation",
            differentiation_score,
            differentiation_change,
            profiles,
            agent_name,
        ),
        score_change_to_text("alignment", alignment_score, alignment_change),
    )


def save_profiles(output_dir: str, task_id: str, profiles: dict, round: int):
    """Append one record of agent profiles to ``agent_profiles.jsonl``.

    The record is written as a single JSON line with the keys ``task_id``,
    ``round`` and ``profiles``. Non-ASCII characters are stored verbatim
    (``ensure_ascii=False``), so the file is always written as UTF-8 rather
    than in the platform's default encoding.

    Args:
        output_dir: Existing directory holding ``agent_profiles.jsonl``.
        task_id: Identifier of the task the profiles belong to.
        profiles: JSON-serialisable mapping of profiles to store.
        round: Round number the profiles were captured in.
    """
    profiles_file = Path(output_dir) / 'agent_profiles.jsonl'

    # Serialise before touching the file so that a failure cannot leave a
    # truncated line behind.
    line = json.dumps(
        {"task_id": task_id, "round": round, "profiles": profiles},
        ensure_ascii=False,
    )

    with profiles_file.open('a', encoding='utf-8') as f:
        f.write(line + '\n')

def load_config(config_path: str) -> Dict:
    """Load a YAML configuration file and expand ``${VAR}`` placeholders.

    The file is read as UTF-8 so that the result does not depend on the
    platform's default encoding. After parsing, every string that consists
    solely of ``${NAME}`` is replaced, at any nesting depth of dicts and
    lists, by the value of the environment variable ``NAME``. Placeholders
    whose variable is not set are left unchanged.

    Args:
        config_path: Path of the YAML file.

    Returns:
        The parsed configuration with placeholders substituted.
    """
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    def replace_env_vars(item):
        """Recursively substitute ``${NAME}`` strings inside ``item``."""
        if isinstance(item, dict):
            return {k: replace_env_vars(v) for k, v in item.items()}
        if isinstance(item, list):
            return [replace_env_vars(i) for i in item]
        if isinstance(item, str) and item.startswith('${') and item.endswith('}'):
            # Strip the leading "${" and trailing "}"; fall back to the
            # placeholder itself when the variable is undefined.
            return os.getenv(item[2:-1], item)
        return item

    return replace_env_vars(config)

def distribute_configs(configs: List[Dict], num_agents: int) -> List[Dict]:
    """Assign one configuration to each agent in round-robin order.

    Args:
        configs: Available configurations.
        num_agents: Number of agents that need a configuration.

    Returns:
        A list of ``num_agents`` entries in which agent ``i`` receives
        ``configs[i % len(configs)]`` (the same objects, not copies), or an
        empty list when ``configs`` is empty.
    """
    if not configs:
        return []

    pool_size = len(configs)
    assigned = []
    # Cycle through the pool so the configurations are spread evenly.
    for agent_index in range(num_agents):
        assigned.append(configs[agent_index % pool_size])
    return assigned


def parse_answer(input_string):
    r"""Extract the content of the first ``\boxed{...}`` in ``input_string``.

    The pattern accepts braces nested up to two levels deep inside the box.

    Returns:
        The text between the outer braces, or ``""`` when nothing matches.
    """
    boxed = re.compile(r'\\boxed\{((?:[^{}]|\{(?:[^{}]|\{[^{}]*\})*\})*)\}')
    found = boxed.search(input_string)
    if found is None:
        return ""
    return found.group(1)

@dataclass
class CostTracker:
    """Accumulates token counts, cost and budget over several reports."""

    # Names of the counters, in the order they are updated and reported.
    _COUNTERS = (
        "total_prompt_tokens",
        "total_completion_tokens",
        "total_cost",
        "total_budget",
    )

    # Running sums; each starts at zero.
    total_prompt_tokens: int = 0
    total_completion_tokens: int = 0
    total_cost: float = 0
    total_budget: float = 0

    def update(self, costs: Dict[str, Any]):
        """Add the values of ``costs`` to the running totals.

        Args:
            costs: Mapping that must contain the keys
                ``total_prompt_tokens``, ``total_completion_tokens``,
                ``total_cost`` and ``total_budget``.

        Raises:
            KeyError: If one of those keys is missing. Counters handled
                before the missing key have already been updated.
        """
        for counter in self._COUNTERS:
            setattr(self, counter, getattr(self, counter) + costs[counter])

    def get_costs(self) -> Dict[str, Any]:
        """Return the current totals as a dict keyed by counter name."""
        return {counter: getattr(self, counter) for counter in self._COUNTERS}
