import copy
import json
import os
import time
from datetime import datetime
from typing import Dict, Any, List

from core.environment import WorkEnvironment

class HistoryManager:
    """Buffer per-turn snapshots of a work environment and persist them as JSON.

    Records are kept in memory, grouped by ``"<dataset_id>.json"`` and indexed
    by turn, and are written to disk only when ``save`` is asked to.
    """

    def __init__(self, save_dir: str = "data/env_history"):
        """Create the history directory and initialise the in-memory buffers.

        Args:
            save_dir: Parent directory; history files are written to its
                ``env_history`` subdirectory.
        """
        # NOTE(review): 'env_history' is always appended, so the default
        # save_dir resolves to "data/env_history/env_history". Left unchanged
        # so existing files stay where callers expect them -- confirm intent.
        self.history_dir = os.path.join(save_dir, 'env_history')
        # exist_ok avoids the check-then-create race between two processes.
        os.makedirs(self.history_dir, exist_ok=True)

        # Latest environment snapshot per turn index (not keyed by dataset).
        self.env_turns: List[Dict] = []
        # "<dataset_id>.json" -> list of per-turn records, indexed by turn.
        self.memory_buffer: Dict[str, List[Dict]] = {}

    @staticmethod
    def _store_at(slots: List[Dict], index: int, value: Dict) -> None:
        """Assign ``slots[index] = value``, padding with empty dicts if short."""
        shortfall = index + 1 - len(slots)
        if shortfall > 0:
            slots.extend({} for _ in range(shortfall))
        slots[index] = value

    def save(self, env: "WorkEnvironment", dataset_id: str, step: int, finish_msg: str,
             turn_index: int = 0, to_write: bool = False):
        """Record a snapshot of ``env`` for one turn and optionally flush to disk.

        Args:
            env: Environment whose ``objective``, ``scratchpads``,
                ``task_history``, ``thinking_history`` and
                ``dispatch_list_history`` attributes are captured.
            dataset_id: Identifier used to build the ``"<dataset_id>.json"``
                buffer key and file name.
            step: Step counter; stored as ``step + 1`` under ``"steps"``.
            finish_msg: Message stored verbatim with the record.
            turn_index: Position of this record in the per-dataset list. Any
                missing earlier positions are filled with empty dicts, and an
                existing record at this position is overwritten.
            to_write: When true, dump the dataset's whole buffered record list
                to its JSON file inside the history directory.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # Deep-copy so that later in-place mutation of the environment's
        # containers cannot retroactively alter turns already buffered.
        env_snapshot = copy.deepcopy({
            "objective": env.objective,
            "scratchpads": env.scratchpads,
            "task_history": env.task_history,
            "thinking_history": env.thinking_history,
            "dispatch_list_history": env.dispatch_list_history,
        })
        self._store_at(self.env_turns, turn_index, env_snapshot)

        snapshot_record = {
            "turn_index": turn_index,
            "steps": step + 1,
            "timestamp": timestamp,
            "finish_msg": finish_msg,
            "WorkEnvironment": env_snapshot
        }

        file_name = f"{dataset_id}.json"
        records = self.memory_buffer.setdefault(file_name, [])
        self._store_at(records, turn_index, snapshot_record)

        if to_write:
            file_path = os.path.join(self.history_dir, file_name)
            # Write to a sibling temp file and swap it in, so a failed or
            # interrupted dump never leaves a truncated history file behind.
            tmp_path = f"{file_path}.tmp"
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(records, f, indent=4, ensure_ascii=False)
            os.replace(tmp_path, file_path)

    def get_history(self, dataset_id: str) -> List[Dict]:
        """Return the buffered records for ``dataset_id`` (empty list if none)."""
        file_name = f"{dataset_id}.json"
        return self.memory_buffer.get(file_name, [])

    def clear_buffer(self, dataset_id: str):
        """Drop the buffered records for ``dataset_id``; no-op if none exist."""
        self.memory_buffer.pop(f"{dataset_id}.json", None)
