import base64
import io
import re
import os
import json
import uuid
from typing import Any, Dict, List, Optional
from pathlib import Path

import requests
import torch
import numpy as np
from PIL import Image

from verl import DataProto
from verl.workers.reward_manager import register
from signal_config import SignalConfig
from dataclasses import fields

# Module-level setting: directory under which responses and scores are saved.
# Read from the SAVE_DIR environment variable, defaulting to "responses_sft_1".
SAVE_DIR = os.environ.get("SAVE_DIR", "responses_sft_1")

def _extract_prompt(text: str) -> str:
    """Pull the revised prompt out of a generated response.

    The following forms are tried in order and the first hit wins:

    1. ``Revised Prompt:`` followed by a newline and a brace-wrapped body;
       the text between the opening brace and the *last* closing brace is
       returned.
    2. ``Revised Prompt:`` followed by anything; everything after the marker
       is returned.
    3. No marker at all; the whole input is returned.

    In every case the result has leading/trailing whitespace removed.
    """
    candidate_patterns = (
        r"Revised Prompt:\n{(.*)}",
        r"Revised Prompt:(.*)",
    )
    for pattern in candidate_patterns:
        found = re.search(pattern, text, re.DOTALL)
        if found is not None:
            return found.group(1).strip()
    return text.strip()

def _to_data_url(obj) -> Optional[str]:
    """Convert an image-like object into a ``data:`` URL string.

    Supported inputs:

    * ``str`` - returned unchanged (it is not validated in any way).
    * ``bytes`` / ``bytearray`` - base64-encoded as-is and labelled
      ``image/jpeg``.
    * ``PIL.Image.Image`` - converted to RGB and re-encoded as JPEG
      (quality 90).
    * ``numpy.ndarray`` - clipped to ``[0, 255]`` and cast to ``uint8`` when
      it has another dtype, then encoded as JPEG. 2-D arrays and
      ``(H, W, 1)`` arrays are treated as grayscale; otherwise only the first
      three channels of the last axis are kept.

    Returns:
        The data URL, or ``None`` when ``obj`` has an unsupported type or the
        conversion fails (the failure is logged, never raised).
    """
    if isinstance(obj, str):
        return obj

    if isinstance(obj, (bytes, bytearray)):
        # NOTE(review): the payload is labelled image/jpeg without inspecting
        # the bytes - confirm that callers only ever pass JPEG data.
        return "data:image/jpeg;base64," + base64.b64encode(obj).decode("utf-8")

    def _jpeg_data_url(pil_image) -> str:
        """Encode a PIL image as an RGB JPEG data URL."""
        buf = io.BytesIO()
        pil_image.convert("RGB").save(buf, format="JPEG", quality=90)
        return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")

    try:
        if isinstance(obj, Image.Image):
            return _jpeg_data_url(obj)

        if isinstance(obj, np.ndarray):
            if obj.dtype != np.uint8:
                # NOTE(review): floats in [0, 1] collapse to 0/1 here;
                # presumably inputs are already on a 0-255 scale - confirm.
                obj = np.clip(obj, 0, 255).astype(np.uint8)
            if obj.ndim == 3 and obj.shape[-1] == 1:
                # A trailing singleton channel is grayscale; drop the axis so
                # the array takes the 2-D grayscale path below.
                obj = obj[..., 0]
            if obj.ndim == 2:
                # No explicit ``mode``: the array is uint8 at this point.
                return _jpeg_data_url(Image.fromarray(obj))
            # Drop alpha / extra channels beyond the first three.
            return _jpeg_data_url(Image.fromarray(obj[..., :3]))
    except Exception as exc:  # noqa: BLE001 - best-effort conversion
        print(f"[WARN] Failed to convert {type(obj).__name__} to data URL: {exc}")

    return None


@register("diff")
class DiffusionImageEditRewardManager:
    """Image-editing reward manager backed by an external signal service.

    For every batch the manager:

    1. decodes each response and extracts the improved prompt,
    2. submits all prompts to the signal service under a "data" key,
    3. blocks on the matching "reward" key until rewards are published,
    4. writes each reward at the last valid response token of its sample.
    """

    # Keys probed, in this order, when a reward entry has no "score" value.
    _SCORE_FALLBACK_KEYS = ("reward", "value", "avg")

    def __init__(self, tokenizer, fastapi_base_url: str, num_examine=5, compute_score=None, **kwargs):
        """Store the tokenizer, load signal settings and create the save directory.

        Args:
            tokenizer: Tokenizer used to decode response token ids.
            fastapi_base_url: Accepted but not used by this class; the service
                URL is taken from ``SignalConfig``.
            num_examine: Samples whose batch index is below this value are
                printed for inspection.
            compute_score: Stored on the instance; not called by this class.
            **kwargs: Accepted and ignored.
        """
        self.tokenizer = tokenizer

        self.signal_settings = SignalConfig()
        self.signal_base_url = self.signal_settings.url
        self.timeout = self.signal_settings.timeout

        self.num_examine = num_examine
        self.compute_score = compute_score
        self.session = requests.Session()

        # Create the output directory up front.
        self.save_dir = Path(SAVE_DIR)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def _submit_signal_payload(self, key: str, payload: Dict[str, Any]) -> bool:
        """POST ``payload`` as JSON to ``/submit/{key}``.

        Returns:
            ``True`` on a successful (non-error) HTTP response, ``False`` on
            any failure; failures are logged, never raised.
        """
        try:
            resp = self.session.post(
                f"{self.signal_base_url}/submit/{key}",
                json=payload,
                timeout=self.timeout,
            )
            resp.raise_for_status()
            return True
        except Exception as exc:  # noqa: BLE001
            print(f"[ERROR] Failed to submit payload for key {key}: {exc}")
            return False

    def _wait_signal_payload(self, key: str) -> Dict[str, Any]:
        """GET ``/wait/{key}`` and return the ``data`` field of the JSON reply.

        The configured timeout is passed to the service as a query parameter,
        and the HTTP client waits 10 seconds longer than that.

        Returns:
            The ``data`` value of the response, or ``{}`` when it is missing,
            falsy, or the request fails (failures are logged, never raised).
        """
        try:
            resp = self.session.get(
                f"{self.signal_base_url}/wait/{key}",
                params={"timeout": self.timeout},
                timeout=self.timeout + 10,
            )
            resp.raise_for_status()
            data = resp.json().get("data", {})
            return data or {}
        except Exception as exc:  # noqa: BLE001
            print(f"[ERROR] Failed to receive payload for key {key}: {exc}")
            return {}

    def _extract_images(self, sample_data) -> List[str]:
        """Collect data URLs for the entries under ``sample_data["image"]``.

        A single (non list/tuple) value is treated as a one-element list.
        Items that cannot be converted are skipped. Any failure is logged and
        the URLs gathered so far are returned.
        """
        images: List[str] = []
        try:
            image_data = sample_data.get("image", [])
            if not isinstance(image_data, (list, tuple)):
                image_data = [image_data]

            for item in image_data:
                url = _to_data_url(item)
                if url:
                    images.append(url)
        except Exception as exc:  # noqa: BLE001 - best-effort extraction
            print(f"[WARN] Failed to extract images from sample: {exc}")

        return images

    @staticmethod
    def _json_default(obj):
        """``json`` fallback converting numpy scalars/arrays to plain Python."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def _save_results(
        self,
        global_steps: int,
        data_source: str,
        idx: int,
        rollout_id: str,
        full_response: str,
        original_prompt: str,
        improved_prompt: str,
        scores: dict,
        extra_info: dict,
        original_images: Optional[List[str]] = None,
        category: Optional[str] = None,
        split: str = "train",
    ):
        """Persist the full response and its scores for one rollout.

        Files go to ``save_dir/{split}_{global_steps}/{data_source}_{idx}/``:
        ``{rollout_id}.json`` for the record and, when ``original_images`` is
        given, ``{rollout_id}_images/original_{i}.png`` for each image.

        ``scores`` is merged into the record last, so its keys override the
        record's own fields of the same name. Every failure is logged and
        swallowed; saving never interrupts reward computation.
        """
        try:
            safe_data_source = str(data_source).replace("/", "_").replace(" ", "_")
            parent_dir = self.save_dir / f"{split}_{global_steps}" / f"{safe_data_source}_{idx}"
            parent_dir.mkdir(parents=True, exist_ok=True)

            result_data = {
                "global_steps": global_steps,
                "data_source": data_source,
                "idx": idx,
                "rollout_id": rollout_id,
                "original_prompt": original_prompt,
                "improved_prompt": improved_prompt,
                "full_response": full_response,
                "extra_info": extra_info,
                "category": category,
                **scores
            }

            # Save the original images, if any were supplied.
            if original_images:
                image_dir = parent_dir / f"{rollout_id}_images"
                image_dir.mkdir(parents=True, exist_ok=True)
                original_image_paths = []
                for i, orig_img in enumerate(original_images):
                    try:
                        # Accept either a full data URL or a bare base64 string.
                        if orig_img.startswith("data:image"):
                            orig_img_base64 = orig_img.split(",", 1)[1]
                        else:
                            orig_img_base64 = orig_img

                        orig_img_bytes = base64.b64decode(orig_img_base64)
                        orig_img_pil = Image.open(io.BytesIO(orig_img_bytes))

                        orig_img_file = image_dir / f"original_{i}.png"
                        orig_img_pil.save(orig_img_file)
                        original_image_paths.append(str(orig_img_file.relative_to(self.save_dir)))
                    except Exception as e:
                        print(f"[ERROR] Failed to save original image {i}: {e}")

                if original_image_paths:
                    result_data["original_image_paths"] = original_image_paths

            # Serialize before touching the file so that a serialization
            # failure cannot leave a truncated JSON file behind.
            serialized = json.dumps(
                result_data, indent=2, ensure_ascii=False, default=self._json_default
            )
            json_file = parent_dir / f"{rollout_id}.json"
            json_file.write_text(serialized, encoding="utf-8")

        except Exception as e:
            print(f"[ERROR] Failed to save sample {global_steps}/{data_source}_{idx}/{rollout_id}: {e}")

    def _group_samples_by_source(self, data: DataProto, batch_size: int) -> dict:
        """Group batch indices by ``(data_source, sample_idx)``.

        ``sample_idx`` is ``extra_info["idx"]`` when available and the batch
        index otherwise.

        Returns:
            Dict mapping ``(data_source, sample_idx)`` to the list of batch
            indices sharing that key, in batch order.
        """
        groups: Dict[tuple, List[int]] = {}

        for idx in range(batch_size):
            data_source = data.non_tensor_batch["data_source"][idx]
            try:
                extra_info = data.non_tensor_batch.get("extra_info", [{}])[idx]
                sample_idx = extra_info.get("idx", idx)
            except Exception:
                # Missing or malformed extra_info: fall back to the batch index.
                sample_idx = idx

            groups.setdefault((data_source, sample_idx), []).append(idx)

        return groups

    @staticmethod
    def _coerce_score(value: Any) -> Optional[float]:
        """Return ``float(value)``, or ``None`` when the conversion fails."""
        try:
            return float(value)
        except (TypeError, ValueError, OverflowError):
            return None

    @classmethod
    def _raw_score(cls, entry: Dict[str, Any]) -> Any:
        """Return the unconverted score of a reward entry.

        Uses ``entry["score"]`` unless it is missing or ``None``; in that case
        the value of the first fallback key present in the entry is used.
        Returns ``None`` when nothing is found.
        """
        score = entry.get("score")
        if score is None:
            for candidate_key in cls._SCORE_FALLBACK_KEYS:
                if candidate_key in entry:
                    score = entry[candidate_key]
                    break
        return score

    @staticmethod
    def _meta_for_entry(entry: Any, metadata_by_rollout: Dict[str, Dict[str, Any]]) -> Optional[Dict[str, Any]]:
        """Find the sample record a reward entry refers to via its rollout id.

        The id is read from ``entry["rollout_id"]`` or, failing that, from
        ``entry["metadata"]["rollout_id"]``. Returns ``None`` for non-dict
        entries and for ids that are not strings or are unknown to this batch.
        """
        if not isinstance(entry, dict):
            return None
        rollout_id = entry.get("rollout_id")
        if not rollout_id and isinstance(entry.get("metadata"), dict):
            rollout_id = entry["metadata"].get("rollout_id")
        # Generated ids are strings; anything else (possibly unhashable)
        # cannot match and must not reach the dict lookup.
        if not isinstance(rollout_id, str):
            return None
        return metadata_by_rollout.get(rollout_id)

    def _decode_response(self, attn, resp_ids) -> tuple:
        """Decode one response and extract its improved prompt.

        The number of valid response tokens is the sum of the last
        ``resp_len`` attention-mask entries, and that many leading tokens of
        ``resp_ids`` are decoded.

        Returns:
            ``(valid_len, full_response, improved_prompt)``; both strings are
            empty when there are no valid tokens.
        """
        resp_len = resp_ids.shape[-1]
        valid_len = int(attn[-resp_len:].sum().item())
        if valid_len <= 0:
            return valid_len, "", ""

        full_response = self.tokenizer.decode(resp_ids[:valid_len], skip_special_tokens=True)
        return valid_len, full_response, _extract_prompt(full_response)

    def _build_sample_records(self, data: DataProto, attention_mask, responses, batch_size: int, global_steps) -> tuple:
        """Build the submission payloads and bookkeeping records for a batch.

        Returns:
            ``(submit_samples, sample_metadata, metadata_by_rollout)`` where
            ``submit_samples`` is what gets sent to the signal service,
            ``sample_metadata`` holds one record per sample in batch order and
            ``metadata_by_rollout`` indexes the same records by rollout id.
        """
        submit_samples: List[Dict[str, Any]] = []
        sample_metadata: List[Dict[str, Any]] = []
        metadata_by_rollout: Dict[str, Dict[str, Any]] = {}
        has_multi_modal = "multi_modal_data" in data.non_tensor_batch

        for idx in range(batch_size):
            data_source = data.non_tensor_batch["data_source"][idx]

            try:
                extra_info = data.non_tensor_batch.get("extra_info", [{}])[idx]
                original_prompt = extra_info.get("prompt", "")
                explanation = extra_info.get("explanation", "No additional explanation provided.")
                category = extra_info.get("category", "unknown")
                sample_idx = extra_info.get("idx", idx)
            except Exception:
                # Missing or malformed extra_info: use neutral defaults.
                extra_info = {}
                original_prompt = ""
                explanation = "No additional explanation provided."
                category = "unknown"
                sample_idx = idx

            # Short ids can collide; regenerate until unique within the batch
            # so that no record in metadata_by_rollout is overwritten.
            rollout_id = str(uuid.uuid4())[:8]
            while rollout_id in metadata_by_rollout:
                rollout_id = str(uuid.uuid4())[:8]

            valid_len, full_response, improved_prompt = self._decode_response(
                attention_mask[idx], responses[idx]
            )

            metadata_payload = {
                "idx": idx,
                "rollout_id": rollout_id,
                "sample_idx": sample_idx,
                "data_source": data_source,
                "global_steps": global_steps,
                "original_prompt": original_prompt,
                "improved_prompt": improved_prompt,
                "category": category,
                "explanation": explanation,
            }

            if has_multi_modal:
                metadata_payload["ori_image"] = extra_info.get("ori_image", None)
                metadata_payload["gt_image"] = extra_info.get("gt_image", None)

            submit_samples.append({
                "prompt": improved_prompt,
                "metadata": metadata_payload,
            })

            meta_record = {
                "idx": idx,
                "valid_len": valid_len,
                "rollout_id": rollout_id,
                "data_source": data_source,
                "sample_idx": sample_idx,
                "global_steps": global_steps,
                "original_prompt": original_prompt,
                "improved_prompt": improved_prompt,
                "full_response": full_response,
                "extra_info": extra_info,
                "category": category,
                "metadata": metadata_payload,
            }
            sample_metadata.append(meta_record)
            metadata_by_rollout[rollout_id] = meta_record

            if idx < self.num_examine:
                print(f"\n[Sample {idx} - {rollout_id}]")
                print(f"[global_steps] {global_steps}")
                print(f"[data_source] {data_source}")
                print(f"[sample_idx] {sample_idx}")
                print(f"[original_prompt] {original_prompt}")
                print(f"[full_response] {full_response}")
                print(f"[improved_prompt] {improved_prompt}")

        return submit_samples, sample_metadata, metadata_by_rollout

    @staticmethod
    def _parse_reward_entries(reward_payload: Any, mode: str) -> list:
        """Extract the list of reward entries from a reward payload.

        Accepts a dict whose ``"rewards"`` value is a list, a dict whose
        ``"rewards"`` value is itself a dict with a ``"rewards"`` list, or a
        bare list. Anything else yields an empty list. A mismatch between the
        payload's ``"mode"`` and ``mode`` is only logged.
        """
        if isinstance(reward_payload, dict):
            payload_mode = reward_payload.get("mode")
            if payload_mode is not None and payload_mode != mode:
                print(f"[WARN] Received reward payload mode '{payload_mode}' but expected '{mode}'.")
            raw_rewards = reward_payload.get("rewards")
        else:
            raw_rewards = reward_payload

        if isinstance(raw_rewards, dict):
            raw_rewards = raw_rewards.get("rewards", [])
        return raw_rewards if isinstance(raw_rewards, list) else []

    def _align_rewards(self, reward_entries: list, sample_metadata: List[Dict[str, Any]], metadata_by_rollout: Dict[str, Dict[str, Any]]) -> List[tuple]:
        """Pair reward entries with sample records.

        Entries carrying a rollout id known to this batch are bound to that
        sample. All remaining entries are paired positionally with the samples
        that received no identified entry, in batch order. A missing or
        non-numeric score becomes ``0.0`` instead of raising.

        Returns:
            List of ``(meta_record, score, payload_dict)`` tuples, where
            ``payload_dict`` is a copy of the entry with ``"score"`` set to
            the final float score.
        """
        assigned_entries: List[tuple] = []
        used_rollout_ids = set()
        unmatched_payloads = []

        for entry in reward_entries:
            meta = self._meta_for_entry(entry, metadata_by_rollout)
            if meta is None:
                unmatched_payloads.append(entry)
                continue

            score_val = self._coerce_score(self._raw_score(entry))
            if score_val is None:
                print(
                    f"[WARN] Reward entry for rollout '{meta['rollout_id']}' has no usable score; using 0.0."
                )
                score_val = 0.0
            payload_dict = dict(entry)
            payload_dict["score"] = score_val
            used_rollout_ids.add(meta["rollout_id"])
            assigned_entries.append((meta, score_val, payload_dict))

        remaining_metas = [
            meta for meta in sample_metadata if meta["rollout_id"] not in used_rollout_ids
        ]

        # Positional fallback for entries that could not be identified.
        for entry, meta in zip(unmatched_payloads, remaining_metas):
            if isinstance(entry, dict):
                score_val = self._coerce_score(self._raw_score(entry))
                payload_dict = dict(entry)
            else:
                score_val = self._coerce_score(entry)
                payload_dict = {}
            if score_val is None:
                score_val = 0.0
            payload_dict["score"] = score_val
            used_rollout_ids.add(meta["rollout_id"])
            assigned_entries.append((meta, score_val, payload_dict))

        return assigned_entries

    def _process_batch(self, data: DataProto, attention_mask, responses, batch_size: int):
        """Score a whole batch through the signal service.

        Extracts every improved prompt, submits them in one payload, waits
        for the rewards, aligns them with the samples and saves each scored
        sample to disk.

        Returns:
            List of ``(batch_idx, token_pos, score)`` tuples, where
            ``token_pos`` is the index of the last valid response token
            (``0`` for empty responses). When submission fails or no rewards
            arrive, every sample is returned with score ``0.0``; otherwise
            only samples that were paired with a reward entry are returned.
        """
        mode = "eval" if bool(data.meta_info.get("validate")) else "train"
        # Batch-level metadata: read once instead of once per sample.
        global_steps = data.meta_info.get("global_steps", 0)

        submit_samples, sample_metadata, metadata_by_rollout = self._build_sample_records(
            data, attention_mask, responses, batch_size, global_steps
        )

        default_results = [
            (meta["idx"], max(0, meta["valid_len"] - 1), 0.0) for meta in sample_metadata
        ]

        if not submit_samples:
            return default_results

        submit_key = self.signal_settings.get_key(mode, "data", global_steps)
        submit_payload = {
            "global_step": global_steps,
            "mode": mode,
            ("eval_data" if mode == "eval" else "train_data"): submit_samples,
        }

        print(
            f"[INFO] Submitting {len(submit_samples)} {mode} samples to signal key '{submit_key}' (global_steps {global_steps})"
        )
        if not self._submit_signal_payload(submit_key, submit_payload):
            return default_results

        reward_key = self.signal_settings.get_key(mode, "reward", global_steps)
        print(f"[INFO] Waiting for {mode} rewards from key '{reward_key}' ...")
        reward_payload = self._wait_signal_payload(reward_key)
        print(f"[INFO] Received reward payload for key '{reward_key}'.")

        reward_entries = self._parse_reward_entries(reward_payload, mode)
        if not reward_entries:
            print(f"[WARN] No rewards received for key '{reward_key}'.")
            return default_results

        assigned_entries = self._align_rewards(reward_entries, sample_metadata, metadata_by_rollout)
        if not assigned_entries:
            print(f"[WARN] Unable to align rewards for key '{reward_key}'.")
            return default_results

        results = []
        for meta, score_val, payload_dict in assigned_entries:
            self._save_results(
                global_steps=meta["global_steps"],
                data_source=meta["data_source"],
                idx=meta["sample_idx"],
                rollout_id=meta["rollout_id"],
                full_response=meta["full_response"],
                original_prompt=meta["original_prompt"],
                improved_prompt=meta["improved_prompt"],
                scores=payload_dict,
                extra_info=meta["extra_info"],
                category=meta.get("category", mode),
                split=mode,
            )

            if meta["idx"] < self.num_examine:
                print(f"[Sample {meta['idx']} - {meta['rollout_id']}] score: {score_val}")

            # The reward is placed on the last valid response token.
            pos = max(0, meta["valid_len"] - 1)
            results.append((meta["idx"], pos, float(score_val)))

        return results

    def __call__(self, data: DataProto, return_dict=False):
        """Compute the reward tensor for every sample in ``data``.

        When ``data.batch`` already contains ``"rm_scores"`` it is returned
        unchanged. Otherwise a float32 tensor of zeros shaped like
        ``data.batch["responses"]`` is created and each sample's score is
        written at its ``(batch_idx, token_pos)``.

        Returns:
            The reward tensor, or ``{"reward_tensor": tensor}`` when
            ``return_dict`` is true.
        """
        # Reuse precomputed scores when they exist.
        if "rm_scores" in data.batch.keys():
            result = data.batch["rm_scores"]
            return {"reward_tensor": result} if return_dict else result

        batch_size = data.batch.batch_size[0]
        attention_mask = data.batch["attention_mask"]
        responses = data.batch["responses"]

        reward_tensor = torch.zeros_like(responses, dtype=torch.float32)

        results = self._process_batch(data, attention_mask, responses, batch_size)

        for idx, pos, score in results:
            if pos >= 0:
                reward_tensor[idx, pos] = score

        return {"reward_tensor": reward_tensor} if return_dict else reward_tensor