import json
import datetime
from typing import Any, Optional
import re
import hashlib
from copy import deepcopy

def get_prompt_hash(prompt: str) -> str:
    """Return the MD5 hex digest of *prompt* reduced to ASCII letters and digits.

    Every run of characters outside ``A-Z``, ``a-z`` and ``0-9`` (whitespace,
    punctuation, non-ASCII text) is dropped before hashing, so prompts that
    differ only in such characters produce the same hash.
    """
    # Splitting on the unwanted runs and re-joining the pieces removes them.
    kept_pieces = re.split(r'[^A-Za-z0-9]+', prompt)
    normalized = ''.join(kept_pieces)
    digest = hashlib.md5(normalized.encode())
    return digest.hexdigest()

class AgentTrajectoryLogger:
    """Build a JSON-serializable log of an agent run.

    The log has one entry per chat message plus a run-level metadata dict.

    This class defines no ``__init__`` and reads several attributes it never
    assigns, so it has to be combined with a class that provides them:

    * ``self.messages`` -- sequence of dicts with ``'role'`` and ``'content'``.
    * ``self.current_observation`` -- dict with ``'observation'`` and
      ``'observation_info'`` (itself a dict).
    * ``self._trajectory`` -- object with ``reward`` and ``steps``; each step
      exposes ``action``, ``reward``, ``model_response``, ``thought``,
      ``info`` (a dict) and ``done``.
    """

    # Not read by any method in this class; presumably checked by the host
    # class to decide whether logging is enabled -- confirm against callers.
    log_trajectory: bool = True

    # Format used for 'Start Time' / 'End Time' and parsed back when
    # computing 'Running Time'.
    _TRAJ_LOG_TIME_FORMAT = '%Y-%m-%d %H:%M:%S'

    def log_init(self, dt: Optional[datetime.datetime] = None):
        """Reset the log state and record the start time.

        Args:
            dt: Start timestamp to record; defaults to the current local time.
        """
        # log_init must be called after the first call to self.update_from_env
        # Deep copy so later changes to current_observation cannot alter the
        # snapshot attached to the initial system message.
        self._initial_observation = deepcopy(self.current_observation)

        self._traj_log: list[dict[Any, Any]] = []
        self._traj_log_total_steps: int = 0
        self._traj_log_metadata: dict[str, Any] = {
            'Total Reward': '-',
            'Total Steps': '-',
            'Done': 'false',
            'Running Time': '-',
            'Start Time': '-',
            'End Time': '-',
            'System Prompt Hash': '-',
        }

        if dt is None:
            dt = datetime.datetime.now()
        self._traj_log_metadata['Start Time'] = dt.strftime(self._TRAJ_LOG_TIME_FORMAT)

    @staticmethod
    def _traj_log_observation_info(observation: dict) -> dict:
        """Return the 'observation' entry plus every observation_info item, all marked as detailed."""
        return {
            'observation': {'value': observation['observation'], 'detailed': True},
            **{k: {'value': v, 'detailed': True} for k, v in observation['observation_info'].items()},
        }

    def log_step(self):
        """Append a log entry for every message not logged yet.

        An assistant message is paired with the next unlogged trajectory step
        (if there is one) and gets the step data plus an environment card
        built from ``self.current_observation``. A system message that is the
        very first log entry gets an environment card built from the initial
        observation. Every other message is logged as plain text.
        """
        assert len(self._traj_log) <= len(self.messages), \
            'trajectory log has more entries than there are messages'
        # This function must be called for every step; do not call it after several steps as it uses self.current_observation which may cause bugs
        assert len(self._trajectory.steps) - self._traj_log_total_steps <= 1, \
            'log_step must be called after every step (more than one step is unlogged)'

        first_unlogged = len(self._traj_log)
        for i in range(first_unlogged, len(self.messages)):
            msg = self.messages[i]
            role = msg['role']

            log_dict = {
                'id': i+1,
                'role': role,
            }

            if role == 'assistant' and self._traj_log_total_steps < len(self._trajectory.steps):
                # step
                log_dict['type'] = 'dict'
                step = self._trajectory.steps[self._traj_log_total_steps]
                step_number = self._traj_log_total_steps + 1

                log_dict['step'] = step_number
                log_dict['stepinfo'] = {
                    'Action': step.action,
                    'Reward': step.reward,
                }
                log_dict['fullText'] = step.model_response
                log_dict['content'] = {
                    'thought': step.thought,
                    'action': step.action,
                    'message_content': {
                        'value': msg['content'],
                        'detailed': True,
                    }
                }
                log_dict['environment'] = {
                    'header': f'Step {step_number}',
                    'thumbnail': self._get_thumbnail(self.current_observation),
                    # Later entries win on key clashes: step.info overrides
                    # observation info, then 'reward'/'done', then the
                    # visible info from _get_visible_obs_info.
                    'info': {
                        **self._traj_log_observation_info(self.current_observation),
                        **{k: {'value': v, 'detailed': True} for k, v in step.info.items()},
                        'reward': str(step.reward),
                        'done': str(step.done),
                        **self._get_visible_obs_info(self.current_observation),
                    },
                }
                self._traj_log_total_steps += 1
            elif role == 'system' and len(self._traj_log) == 0:
                # Attach initial environment info
                log_dict['type'] = 'text'
                log_dict['content'] = msg['content']

                log_dict['environment'] = {
                    'header': 'Initial State',
                    'thumbnail': self._get_thumbnail(self._initial_observation),
                    'info': {
                        **self._traj_log_observation_info(self._initial_observation),
                        **self._get_visible_obs_info(self._initial_observation),
                    },
                }
            else:
                # Just a message
                log_dict['type'] = 'text'
                log_dict['content'] = msg['content']

            # Record the hash of the first system message seen; later system
            # messages do not overwrite it.
            if role == 'system' and self._traj_log_metadata['System Prompt Hash'] == '-':
                self._traj_log_metadata['System Prompt Hash'] = get_prompt_hash(msg['content'])

            self._traj_log.append(log_dict)

    def log_last(self):
        """Mark the run as done and fill in the final metadata.

        Sets 'Done', 'Total Reward', 'Total Steps', 'End Time' and the final
        'Running Time'. The running time has to be fixed here because
        ``get_traj_log_dict`` stops refreshing it once 'Done' is 'true'.
        """
        finished_at = datetime.datetime.now()
        metadata = self._traj_log_metadata

        metadata['Done'] = 'true'
        metadata['Total Reward'] = str(self._trajectory.reward)
        metadata['Total Steps'] = str(len(self._trajectory.steps))
        metadata['End Time'] = finished_at.strftime(self._TRAJ_LOG_TIME_FORMAT)

        try:
            started_at = datetime.datetime.strptime(
                metadata.get('Start Time', '-'), self._TRAJ_LOG_TIME_FORMAT
            )
        except (TypeError, ValueError):
            # No parseable start time: keep whatever 'Running Time' holds.
            return
        metadata['Running Time'] = str(finished_at - started_at)

    def save_traj_log(self, file_path: str):
        """Write the trajectory log to *file_path* as indented UTF-8 JSON.

        Args:
            file_path: Destination path; an existing file is overwritten.
        """
        # The original code called self.dump_traj_log(), which this class does
        # not define. Use it when the host class provides one, otherwise fall
        # back to get_traj_log_dict().
        build_payload = getattr(self, 'dump_traj_log', self.get_traj_log_dict)
        # Serialize before opening the file so a serialization error does not
        # leave a truncated log behind.
        serialized = json.dumps(build_payload(), ensure_ascii=False, indent=2)
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(serialized)

    def get_traj_log_dict(self):
        """Return ``{'dashboard': metadata, 'messages': entries}``.

        While the run is not done, 'Running Time' is refreshed to the time
        elapsed since 'Start Time'. The returned dict references the live
        metadata dict and entry list, not copies.
        """
        metadata = self._traj_log_metadata
        still_running = metadata.get('Done', 'false') == 'false'
        if 'Running Time' in metadata and 'Start Time' in metadata and still_running:
            now = datetime.datetime.now()
            started_at = datetime.datetime.strptime(metadata['Start Time'], self._TRAJ_LOG_TIME_FORMAT)
            metadata['Running Time'] = str(now - started_at)

        return {'dashboard': metadata, 'messages': self._traj_log}

    def get_traj_log(self):
        """Return the list of per-message log entries (the live list, not a copy)."""
        return self._traj_log

    def _get_thumbnail(self, observation: dict) -> str:
        """Return the thumbnail string for an environment card; this base version returns ''."""
        return ''

    def _get_visible_obs_info(self, observation: dict) -> dict:
        """Return extra entries merged last into an environment card's 'info'; this base version returns {}."""
        # Information that will be shown directly on the environment step card. (e.g. progress, correct_cells)
        # If not visible_obs_info, it will be classified as detailed info and will be shown only when expanded.
        return {}
