import os
import json
import time
import asyncio
from typing import Dict, Any, Tuple, List
import math
import numpy as np
from PIL import Image
import imageio

from agents.base_agent import BaseAgent
from models.base_model import BaseModel, ObservationPrompt, Prompt, MaxTokenLimit
from prompts.action import get_action_prompt
from prompts.observation import get_observation_prompt
from prompts.system import get_system_prompt


class HumanAgent(BaseAgent, agent_type="human_agent"):
    """Agent whose actions are supplied by a human through shared files.

    For every observation the agent tiles the observation images into a single
    picture saved at ``pic_path`` and appends an observation record to the
    JSON-lines file ``log_path``. It then polls ``log_path`` until a line
    appended after that record contains an ``"action"`` key whose integer value
    is one of the legal actions.
    """

    # Instruction appended to every action prompt.
    _IO_PROMPT = "Do not include any extra commentary or explanation."

    def __init__(self,
                 env_name: str,
                 pic_path: str,
                 log_path: str,
                 polling_interval: float = 0.5,
                 **kwargs):
        """Store configuration; no files are touched here.

        Args:
            env_name: environment name passed to the prompt builders and
                written into every log record.
            pic_path: where the tiled observation picture is saved.
            log_path: JSON-lines file shared with the human: observation
                records are appended here and actions are read back from it.
            polling_interval: seconds to sleep between reads of ``log_path``.
            **kwargs: accepted and ignored.
        """
        # NOTE(review): BaseAgent.__init__ is not called; confirm it has no
        # required setup.
        self.env_name = env_name
        self.pic_path = pic_path
        self.log_path = log_path
        self.polling_interval = polling_interval
        # Number of non-initial observation records written so far.
        self.turn_count = 0

    def _build_prompt(self, observation) -> Prompt:
        """Assemble the system/observation/action prompt for ``observation``."""
        return Prompt(
            system_prompt=get_system_prompt(self.env_name),
            observation_prompt=get_observation_prompt(self.env_name, observation),
            action_prompt=f"{get_action_prompt(self.env_name, observation)}\n\n{self._IO_PROMPT}",
        )

    def init_info(self, observation):
        """Publish the initial observation (picture + ``is_begin`` record).

        Does nothing when ``observation`` is an empty list.
        """
        if observation == []:
            return
        prompt = self._build_prompt(observation)
        # update jpg and jsonl
        self._update_pic(prompt.observation_prompt.image_paths)
        self._update_jsonl(observation.legal_actions, prompt.observation_prompt.text,
                           self.env_name, is_begin=True)

    async def _act(self, observation) -> Tuple[int, Dict[str, Any]]:
        """Publish ``observation`` and wait for the human's choice.

        Returns:
            ``(action, agent_info)`` where ``agent_info`` holds the chat
            ``messages`` built for this turn, the ``action``, its
            ``action_string`` (``observation.legal_actions[action]``) and the
            ``legal_actions`` themselves.
        """
        prompt = self._build_prompt(observation)
        messages = self._get_messages(prompt)

        # update jpg and jsonl
        self._update_pic(prompt.observation_prompt.image_paths)
        self._update_jsonl(observation.legal_actions, prompt.observation_prompt.text, self.env_name)

        # Only lines appended after the observation record just written are
        # considered, so stale actions from earlier turns are never reused.
        start_line = self._line_count(self.log_path)
        action = await self._wait_for_user_action(observation.legal_actions, start_line=start_line)

        agent_info = {
            "messages": messages,
            "action": action,
            "action_string": observation.legal_actions[action],
            "legal_actions": observation.legal_actions
        }
        return action, agent_info

    def _update_pic(self, image_paths: List[str]) -> None:
        """Tile the observation images into one picture saved at ``pic_path``.

        Images are placed row-major on a near-square grid of
        ``ceil(sqrt(n))`` columns over a white background. Every cell is as
        large as the largest image, so differently sized images never overlap.
        Gaps between cells are one tenth of the cell width/height. This is
        best effort: load and save failures are printed, not raised.
        """
        if not image_paths:
            print("[HumanAgent] _update_pic: no images.")
            return

        imgs = []
        for p in image_paths:
            try:
                # Context manager closes the file handle; convert() returns an
                # independent, fully loaded copy.
                with Image.open(p) as src:
                    imgs.append(src.convert("RGB"))
            except Exception as e:
                print(f"[HumanAgent] _update_pic: failed to load {p}: {e}")

        if not imgs:
            print("[HumanAgent] _update_pic: all loads failed.")
            return

        # Cell size is the largest image; equals imgs[0].size when all match.
        cell_w = max(img.size[0] for img in imgs)
        cell_h = max(img.size[1] for img in imgs)
        col_gap = int(cell_w / 10)
        row_gap = int(cell_h / 10)

        n = len(imgs)
        cols = math.ceil(math.sqrt(n))
        rows = math.ceil(n / cols)

        canvas_w = cols * cell_w + (cols - 1) * col_gap
        canvas_h = rows * cell_h + (rows - 1) * row_gap
        canvas = Image.new("RGB", (canvas_w, canvas_h), color=(255, 255, 255))

        for idx, img in enumerate(imgs):
            r, c = divmod(idx, cols)
            canvas.paste(img, (c * (cell_w + col_gap), r * (cell_h + row_gap)))

        try:
            canvas.save(self.pic_path)
        except Exception as e:
            print(f"[HumanAgent] _update_pic: failed to save {self.pic_path}: {e}")

    def _update_jsonl(self, legal_actions, rules, env_name, is_begin=False):
        """Append one observation record as a JSON line to ``log_path``.

        Non-initial records increment ``turn_count`` first; the record's
        ``"step"`` is the resulting ``turn_count``.
        """
        if not is_begin:
            self.turn_count += 1
        obs_record = {
            "env": env_name,
            "step": self.turn_count,
            "legal_actions": legal_actions,
            "rules": rules,
            "timestamp": time.time(),
            "is_begin": is_begin,
        }
        # Serialise before opening the file and emit the whole line in a single
        # write: a serialisation error can no longer leave a partial line, and
        # the polling reader is less likely to observe a half-written record.
        line = json.dumps(obs_record) + "\n"
        with open(self.log_path, "a", encoding="utf-8") as f:
            f.write(line)

    async def _wait_for_user_action(self,
                                    legal_actions: Dict[int, str],
                                    start_line: int) -> int:
        """Poll ``log_path`` until a new line carries a legal ``"action"``.

        Args:
            legal_actions: container whose members (dict keys) are the
                accepted action ids.
            start_line: number of leading lines of ``log_path`` to ignore.

        Returns:
            The first ``int(record["action"])`` found in ``legal_actions``.
            Lines that are not JSON objects, lack ``"action"``, have a
            non-integer action, or name an illegal action are ignored.
        """
        while True:
            await asyncio.sleep(self.polling_interval)
            try:
                with open(self.log_path, "r", encoding="utf-8", errors="replace") as f:
                    lines = f.readlines()
            except FileNotFoundError:
                continue
            if len(lines) <= start_line:
                continue
            for line in lines[start_line:]:
                try:
                    data = json.loads(line.strip())
                except json.JSONDecodeError:
                    continue
                # Valid JSON that is not an object (e.g. a bare number) would
                # make the membership test / indexing below raise.
                if not isinstance(data, dict) or "action" not in data:
                    continue
                try:
                    act = int(data["action"])
                except (TypeError, ValueError):
                    continue
                if act in legal_actions:
                    return act
            # A final line without a trailing newline may still be mid-write;
            # re-examine it on the next poll instead of skipping it forever.
            if lines[-1].endswith("\n"):
                start_line = len(lines)
            else:
                start_line = len(lines) - 1

    @staticmethod
    def _line_count(path: str) -> int:
        """Return the number of lines in ``path`` (0 if it does not exist)."""
        try:
            # Same open options as the poller so both split lines identically.
            with open(path, "r", encoding="utf-8", errors="replace") as f:
                return sum(1 for _ in f)
        except FileNotFoundError:
            return 0

    def _get_messages(self, prompt):
        """Build the chat messages an LLM agent would receive for ``prompt``.

        Returns a system message carrying ``prompt.system_prompt`` and a user
        message with one ``image_url`` entry per observation image followed by
        the observation text and action prompt. They are recorded in
        ``agent_info``; no model is called.
        """
        user_content = [
            {
                "type": "image_url",
                "image_url": {
                    "url": image_path,
                    "detail": "high",
                },
            }
            for image_path in prompt.observation_prompt.image_paths
        ]
        user_content.append({
            "type": "text",
            "text": f"{prompt.observation_prompt.text}\n\n{prompt.action_prompt}",
        })
        return [
            {"role": "system", "content": [{"type": "text", "text": prompt.system_prompt}]},
            {"role": "user", "content": user_content},
        ]