from PIL import Image
import io
import numpy as np
import torch
from collections import defaultdict

def jpeg_incompressibility():
    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        images = [Image.fromarray(image) for image in images]
        buffers = [io.BytesIO() for _ in images]
        for image, buffer in zip(images, buffers):
            image.save(buffer, format="JPEG", quality=95)
        sizes = [buffer.tell() / 1000 for buffer in buffers]
        return np.array(sizes), {}

    return _fn


def jpeg_compressibility():
    """Build a reward that favours images which compress well as JPEG.

    Wraps the callable produced by ``jpeg_incompressibility`` and returns its
    scores negated and divided by 500; the second return value is passed
    through unchanged.
    """
    incompressibility = jpeg_incompressibility()

    def _fn(images, prompts, metadata):
        sizes, info = incompressibility(images, prompts, metadata)
        negated = -sizes
        return negated / 500, info

    return _fn


def cotracker_reward(device):
    """Build a reward that compares CoTracker tracks with a target trajectory.

    The returned callable takes ``(images, traj_mask)``. Points are extracted
    from channel 0 of ``traj_mask`` (pixels > 0.85), the points of frame 0 are
    used as tracking queries, and the reward is
    ``clamp(0.3 - mse, 0, 0.3) * 20`` where ``mse`` is the per-sample mean
    squared error between predicted and target coordinates, both normalised
    by image width/height, taken over real (non-padding) target points only.

    Args:
        device: Device the tracker, images and trajectories are moved to.

    Returns:
        A function returning ``(reward, {})`` with ``reward`` of shape ``[B]``.
        A sample with no target point at all yields NaN (mean of nothing).
    """
    # The checkpoint is loaded lazily on first use and then reused; loading it
    # inside every reward call repeats an expensive hub load for no benefit.
    model_cache = {}

    def get_cotracker():
        if "model" not in model_cache:
            model_cache["model"] = torch.hub.load(
                "co-tracker", "cotracker2", source="local"
            ).to(device)
        return model_cache["model"]

    def traj_mask_to_coords(traj_mask):
        """
        traj_mask: Tensor, [B, T, C, H, W] (only channel 0 is read)
        return: float32 Tensor, [B, T, N, 2] of (x, y); N is the largest
                per-frame point count and missing entries are filled with -1
        """
        B, T, C, H, W = traj_mask.shape
        traj_coords = []
        for b in range(B):
            coords_b = []
            for t in range(T):
                mask = traj_mask[b, t, 0]  # [H, W]
                y, x = (mask > 0.85).nonzero(as_tuple=True)
                # [N, 2]; x is the column index, y is the row index
                coords_b.append(torch.stack([x, y], dim=-1))
            traj_coords.append(coords_b)
        # default=0 keeps an empty batch / zero-frame input from raising.
        max_N = max(
            (coords.shape[0] for video in traj_coords for coords in video),
            default=0,
        )
        # -1 marks "no point" so padded slots can be masked out later
        traj_out = torch.full((B, T, max_N, 2), -1, dtype=torch.float32)

        for b in range(B):
            for t in range(T):
                coords = traj_coords[b][t]
                traj_out[b, t, :coords.shape[0]] = coords

        return traj_out

    def _fn(images, traj_mask):
        """
        images: Tensor, [B, T, C, H, W]
        traj_mask: Tensor, [B, T, C, H, W]
        """
        cotracker = get_cotracker()

        if isinstance(images, torch.Tensor):
            # Rescale by 255 (rounded, clamped); the layout and float dtype
            # are left untouched.
            images = (images * 255).round().clamp(0, 255).to(device)

        # 1. Convert mask to coordinates
        traj = traj_mask_to_coords(traj_mask).to(device)
        # Computed before normalisation, while padding is still exactly -1.
        valid_mask = traj[..., 0] != -1  # [B, T, N]
        init_points = traj[:, 0]
        # Prepend timestamp t=0 so every query is (t, x, y).
        t_coords = torch.zeros(
            init_points.shape[0], init_points.shape[1], 1, device=device
        )
        init_points = torch.cat([t_coords, init_points], dim=-1)  # [B, N, 3]

        # 2. Run CoTracker (the visibility output is not used by the reward)
        pred_tracks, _ = cotracker(images, queries=init_points)

        # 3. Normalize to [0,1]
        B, _, _, H, W = images.shape
        pred_tracks = pred_tracks.clone()
        pred_tracks[..., 0] = pred_tracks[..., 0] / W
        pred_tracks[..., 1] = pred_tracks[..., 1] / H
        traj = traj.clone()
        traj[..., 0] = traj[..., 0] / W
        traj[..., 1] = traj[..., 1] / H

        # 4. Per-sample mean over valid points only. Boolean indexing followed
        # by view(B, -1) flattens across the batch, which mixes samples (or
        # fails) whenever samples have different numbers of valid points, so
        # use a masked sum divided by the per-sample count instead.
        sq_err = (pred_tracks - traj) ** 2
        valid = valid_mask.unsqueeze(-1).expand_as(sq_err)
        masked_err = torch.where(valid, sq_err, torch.zeros_like(sq_err))
        valid_count = valid.reshape(B, -1).sum(dim=1).to(sq_err.dtype)
        mse_loss = masked_err.reshape(B, -1).sum(dim=1) / valid_count

        reward = 0.3 - mse_loss
        reward = torch.clamp(reward, 0.0, 0.3) * 20
        return reward, {}

    return _fn

def aesthetic_score():
    """Build a reward backed by ``flow_grpo``'s ``AestheticScorer``.

    The scorer is created once, in float32, and moved to CUDA. The returned
    callable takes ``(images, prompts, metadata)`` (``prompts`` and
    ``metadata`` are unused) and returns ``(scorer(images), {})``.

    Tensor input is scaled by 255, rounded, clamped and cast to uint8;
    any other input is transposed NHWC -> NCHW and wrapped in a uint8 tensor.
    """
    from flow_grpo.aesthetic_scorer import AestheticScorer

    scorer = AestheticScorer(dtype=torch.float32).cuda()

    def _fn(images, prompts, metadata):
        if not isinstance(images, torch.Tensor):
            nchw = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
            batch = torch.tensor(nchw, dtype=torch.uint8)
        else:
            batch = (images * 255).round().clamp(0, 255).to(torch.uint8)
        return scorer(batch), {}

    return _fn

def clip_score():
    """Build a reward backed by ``flow_grpo``'s ``ClipScorer``.

    The scorer is created once, in float32, and moved to CUDA. The returned
    callable takes ``(images, prompts, metadata)`` (``metadata`` is unused)
    and returns ``(scorer(images, prompts), {})``.

    Tensor input is forwarded untouched; any other input is transposed
    NHWC -> NCHW, wrapped in a uint8 tensor and divided by 255.
    """
    from flow_grpo.clip_scorer import ClipScorer

    scorer = ClipScorer(dtype=torch.float32).cuda()

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            batch = images
        else:
            nchw = images.transpose(0, 3, 1, 2)  # NHWC -> NCHW
            batch = torch.tensor(nchw, dtype=torch.uint8) / 255.0
        return scorer(batch, prompts), {}

    return _fn

def image_similarity_score(device):
    """Build a reward comparing generated images with reference images.

    A ``ClipScorer`` is constructed with ``device=device`` and then moved to
    CUDA. The returned callable takes ``(images, ref_images)`` and returns
    ``(scorer.image_similarity(images, ref_images), {})``.

    Tensor arguments are forwarded untouched. Non-tensor ``images`` are
    treated as an NHWC array; non-tensor ``ref_images`` are treated as an
    iterable of items convertible with ``np.array`` and stacked first. Both
    are then converted to a uint8 NCHW tensor divided by 255.
    """
    from flow_grpo.clip_scorer import ClipScorer

    scorer = ClipScorer(device=device).cuda()

    def _nhwc_to_unit_tensor(nhwc):
        # NHWC -> NCHW, then uint8 tensor scaled by 1/255.
        nchw = nhwc.transpose(0, 3, 1, 2)
        return torch.tensor(nchw, dtype=torch.uint8) / 255.0

    def _fn(images, ref_images):
        if not isinstance(images, torch.Tensor):
            images = _nhwc_to_unit_tensor(images)
        if not isinstance(ref_images, torch.Tensor):
            stacked = np.array([np.array(ref) for ref in ref_images])
            ref_images = _nhwc_to_unit_tensor(stacked)
        return scorer.image_similarity(images, ref_images), {}

    return _fn

def pickscore_score(device):
    """Build a reward backed by ``flow_grpo``'s ``PickScoreScorer``.

    The scorer is created once (float32, on ``device``). The returned
    callable takes ``(images, prompts, metadata)`` (``metadata`` is unused)
    and returns ``(scorer(prompts, images), {})``. Tensor images are scaled
    by 255, cast to uint8, moved NCHW -> NHWC and turned into a list of PIL
    images; any other input is handed to the scorer unchanged.
    """
    from flow_grpo.pickscore_scorer import PickScoreScorer

    scorer = PickScoreScorer(dtype=torch.float32, device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            pixels = (images * 255).round().clamp(0, 255).to(torch.uint8)
            nhwc = pixels.cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC
            images = list(map(Image.fromarray, nhwc))
        return scorer(prompts, images), {}

    return _fn

def imagereward_score(device):
    """Build a reward backed by ``flow_grpo``'s ``ImageRewardScorer``.

    The scorer is created once (float32, on ``device``). The returned
    callable takes ``(images, prompts, metadata)`` (``metadata`` is unused)
    and returns ``(scorer(prompts, images), {})``. Tensor images are turned
    into a list of PIL images (x255, uint8, NCHW -> NHWC); ``prompts`` is
    copied into a plain list before scoring.
    """
    from flow_grpo.imagereward_scorer import ImageRewardScorer

    scorer = ImageRewardScorer(dtype=torch.float32, device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            pixels = (images * 255).round().clamp(0, 255).to(torch.uint8)
            nhwc = pixels.cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC
            images = list(map(Image.fromarray, nhwc))
        prompt_list = list(prompts)
        return scorer(prompt_list, images), {}

    return _fn

def qwenvl_score(device):
    """Build a reward backed by ``flow_grpo``'s ``QwenVLScorer``.

    The scorer is created once (bfloat16, on ``device``). The returned
    callable takes ``(images, prompts, metadata)`` (``metadata`` is unused)
    and returns ``(scorer(prompts, images), {})``. Tensor images are turned
    into a list of PIL images (x255, uint8, NCHW -> NHWC); ``prompts`` is
    copied into a plain list before scoring.
    """
    from flow_grpo.qwenvl import QwenVLScorer

    scorer = QwenVLScorer(dtype=torch.bfloat16, device=device)

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            pixels = (images * 255).round().clamp(0, 255).to(torch.uint8)
            nhwc = pixels.cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC
            images = list(map(Image.fromarray, nhwc))
        prompt_list = list(prompts)
        return scorer(prompt_list, images), {}

    return _fn

    
def ocr_score(device):
    """Build a reward backed by ``flow_grpo``'s ``OcrScorer``.

    ``device`` is accepted but not used by this factory. The returned
    callable takes ``(images, prompts, metadata)`` (``metadata`` is unused)
    and returns ``(scorer(images, prompts), {})``. Tensor images are converted
    to a uint8 NHWC numpy array first; other inputs are passed through.
    """
    from flow_grpo.ocr import OcrScorer

    scorer = OcrScorer()

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            pixels = (images * 255).round().clamp(0, 255).to(torch.uint8)
            images = pixels.cpu().numpy().transpose(0, 2, 3, 1)  # NCHW -> NHWC
        return scorer(images, prompts), {}

    return _fn

def video_ocr_score(device):
    """Build a reward backed by ``flow_grpo``'s ``OcrScorer_video_or_image``.

    ``device`` is accepted but not used by this factory. The returned
    callable takes ``(images, prompts, metadata)`` (``metadata`` is unused)
    and returns ``(scorer(images, prompts), {})``.

    Tensor input is made channels-last when it is 4-D with 3 channels on
    axis 1, or 5-D with 3 channels on axis 2; it is then scaled by 255 and
    converted to a uint8 numpy array. Other inputs are passed through.
    """
    from flow_grpo.ocr import OcrScorer_video_or_image

    scorer = OcrScorer_video_or_image()

    # rank -> (axis expected to hold the 3 channels, permutation to channels-last)
    channel_last_rules = {
        4: (1, (0, 2, 3, 1)),
        5: (2, (0, 1, 3, 4, 2)),
    }

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            rule = channel_last_rules.get(images.dim())
            if rule is not None:
                channel_axis, order = rule
                if images.shape[channel_axis] == 3:
                    images = images.permute(*order)
            pixels = (images * 255).round().clamp(0, 255).to(torch.uint8)
            images = pixels.cpu().numpy()
        return scorer(images, prompts), {}

    return _fn

def deqa_score_remote(device):
    """Submits images to a DeQA scoring server and collects its scores.

    ``device`` is accepted but not used here. A ``requests`` session with
    retry (up to 1000 retries on HTTP 500, any method) is created once.

    The returned callable takes ``(images, prompts, metadata)``; ``prompts``
    is discarded. Images are JPEG-encoded, sent in pickled batches to
    ``http://127.0.0.1:18086`` and the ``"outputs"`` entries of the pickled
    replies are concatenated.

    Returns:
        A function returning ``(all_scores, {})``; an empty input yields
        ``([], {})`` without contacting the server.

    Raises (from the returned function):
        requests.HTTPError: if the server replies with an error status.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 64
    url = "http://127.0.0.1:18086"
    sess = requests.Session()
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _fn(images, prompts, metadata):
        del prompts
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        # np.array_split rejects 0 sections, so handle the empty batch here.
        if len(images) == 0:
            return [], {}

        num_batches = int(np.ceil(len(images) / batch_size))
        images_batched = np.array_split(images, num_batches)
        all_scores = []
        for image_batch in images_batched:
            jpeg_images = []

            # Compress the images using JPEG
            for image in image_batch:
                img = Image.fromarray(image)
                buffer = BytesIO()
                img.save(buffer, format="JPEG")
                jpeg_images.append(buffer.getvalue())

            # payload format expected by the scoring server
            data = {
                "images": jpeg_images,
            }
            data_bytes = pickle.dumps(data)

            # send the batch to the scoring server
            response = sess.post(url, data=data_bytes, timeout=120)
            # Surface HTTP errors directly instead of as an obscure unpickling
            # failure on an error page.
            response.raise_for_status()
            # NOTE(security): pickle.loads executes arbitrary code from the
            # payload; this is only acceptable because the server is trusted.
            response_data = pickle.loads(response.content)

            all_scores += response_data["outputs"]

        return all_scores, {}

    return _fn

def geneval_score(device):
    """Submits images to a GenEval scoring server and collects its rewards.

    ``device`` is accepted but not used here. A ``requests`` session with
    retry (up to 1000 retries on HTTP 500, any method) is created once.

    The returned callable takes ``(images, prompts, metadatas, only_strict)``;
    ``prompts`` is discarded. Images are JPEG-encoded and sent, together with
    the matching slice of ``metadatas`` and the ``only_strict`` flag, in
    pickled batches to ``http://127.0.0.1:18085``.

    Returns:
        A function returning the 5-tuple ``(scores, rewards, strict_rewards,
        group_rewards, group_strict_rewards)``. The first three are lists
        concatenated over batches; the last two are dicts whose per-key lists
        are concatenated over batches. An empty input yields
        ``([], [], [], {}, {})`` without contacting the server.

    Raises (from the returned function):
        requests.HTTPError: if the server replies with an error status.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 64
    url = "http://127.0.0.1:18085"
    sess = requests.Session()
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _fn(images, prompts, metadatas, only_strict):
        del prompts
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        all_scores = []
        all_rewards = []
        all_strict_rewards = []
        all_group_strict_rewards_dict = defaultdict(list)
        all_group_rewards_dict = defaultdict(list)

        # np.array_split rejects 0 sections, so handle the empty batch here.
        if len(images) == 0:
            return all_scores, all_rewards, all_strict_rewards, {}, {}

        images_batched = np.array_split(
            images, int(np.ceil(len(images) / batch_size))
        )
        metadatas_batched = np.array_split(
            metadatas, int(np.ceil(len(metadatas) / batch_size))
        )
        for image_batch, metadata_batched in zip(images_batched, metadatas_batched):
            jpeg_images = []

            # Compress the images using JPEG
            for image in image_batch:
                img = Image.fromarray(image)
                buffer = BytesIO()
                img.save(buffer, format="JPEG")
                jpeg_images.append(buffer.getvalue())

            # payload format expected by the scoring server
            data = {
                "images": jpeg_images,
                "meta_datas": list(metadata_batched),
                "only_strict": only_strict,
            }
            data_bytes = pickle.dumps(data)

            # send the batch to the scoring server
            response = sess.post(url, data=data_bytes, timeout=120)
            # Surface HTTP errors directly instead of as an obscure unpickling
            # failure on an error page.
            response.raise_for_status()
            # NOTE(security): pickle.loads executes arbitrary code from the
            # payload; this is only acceptable because the server is trusted.
            response_data = pickle.loads(response.content)

            all_scores += response_data["scores"]
            all_rewards += response_data["rewards"]
            all_strict_rewards += response_data["strict_rewards"]
            # Merge the per-group lists of this batch into the running dicts.
            for key, value in response_data["group_strict_rewards"].items():
                all_group_strict_rewards_dict[key].extend(value)
            for key, value in response_data["group_rewards"].items():
                all_group_rewards_dict[key].extend(value)

        return (
            all_scores,
            all_rewards,
            all_strict_rewards,
            dict(all_group_rewards_dict),
            dict(all_group_strict_rewards_dict),
        )

    return _fn

def unifiedreward_score_remote(device):
    """Submits images and prompts to a remote UnifiedReward server.

    ``device`` is accepted but not used here. A ``requests`` session with
    retry (up to 1000 retries on HTTP 500, any method) is created once.

    The returned callable takes ``(images, prompts, metadata)`` (``metadata``
    is unused). Images are JPEG-encoded and sent with the matching slice of
    ``prompts`` in pickled batches to ``http://10.82.120.15:18085``; the
    ``"outputs"`` entries of the pickled replies are concatenated.

    Returns:
        A function returning ``(all_scores, {})``; an empty input yields
        ``([], {})`` without contacting the server.

    Raises (from the returned function):
        requests.HTTPError: if the server replies with an error status.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 64
    url = "http://10.82.120.15:18085"
    sess = requests.Session()
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _fn(images, prompts, metadata):
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        # np.array_split rejects 0 sections, so handle the empty batch here.
        if len(images) == 0:
            return [], {}

        images_batched = np.array_split(
            images, int(np.ceil(len(images) / batch_size))
        )
        prompts_batched = np.array_split(
            prompts, int(np.ceil(len(prompts) / batch_size))
        )

        all_scores = []
        for image_batch, prompt_batch in zip(images_batched, prompts_batched):
            jpeg_images = []

            # Compress the images using JPEG
            for image in image_batch:
                img = Image.fromarray(image)
                buffer = BytesIO()
                img.save(buffer, format="JPEG")
                jpeg_images.append(buffer.getvalue())

            # payload format expected by the scoring server; prompt_batch is
            # sent as the numpy slice produced by array_split
            data = {
                "images": jpeg_images,
                "prompts": prompt_batch
            }
            data_bytes = pickle.dumps(data)

            # send the batch to the scoring server
            response = sess.post(url, data=data_bytes, timeout=120)
            # Surface HTTP errors directly instead of as an obscure unpickling
            # failure on an error page. (The raw pickled reply is no longer
            # dumped to stdout.)
            response.raise_for_status()
            # NOTE(security): pickle.loads executes arbitrary code from the
            # payload; this is only acceptable because the server is trusted.
            response_data = pickle.loads(response.content)

            all_scores += response_data["outputs"]

        return all_scores, {}

    return _fn



def multi_score(device, score_dict):
    """Build a callable that combines several reward functions by weight.

    Args:
        device: Passed to every scorer factory that declares a ``device``
            parameter.
        score_dict: Mapping ``score name -> weight``. Names must be keys of
            the registry below.

    Returns:
        A function ``_fn(images, prompts, metadata=None, ref_images=None,
        traj=None, only_strict=True, is_video=False)`` returning
        ``(score_details, {})``. ``score_details`` maps each score name to its
        scores and ``'avg'`` to the weighted combination; for ``"geneval"``
        the accuracy entries reported by that scorer are added as well.

        With ``is_video=True`` the first two axes of ``images`` are taken as
        (batch, frames); every frame is scored separately, per-scorer values
        are averaged over frames, and ``'avg'`` is the weighted sum divided by
        the number of frames. ``"cotracker"`` is only dispatched in the
        non-video path, where it receives ``(images, traj)``.
    """
    import inspect

    score_functions = {
        "deqa": deqa_score_remote,
        "ocr": ocr_score,
        "video_ocr": video_ocr_score,
        "imagereward": imagereward_score,
        "pickscore": pickscore_score,
        "qwenvl": qwenvl_score,
        "aesthetic": aesthetic_score,
        "jpeg_compressibility": jpeg_compressibility,
        # The factory defined in this file is unifiedreward_score_remote; the
        # previously referenced *_sglang name does not exist here.
        "unifiedreward": unifiedreward_score_remote,
        "geneval": geneval_score,
        "clipscore": clip_score,
        "image_similarity": image_similarity_score,
        "cotracker": cotracker_reward
    }
    score_fns = {}
    for score_name in score_dict:
        factory = score_functions[score_name]
        # Look at the real parameters; __code__.co_varnames also lists locals.
        if "device" in inspect.signature(factory).parameters:
            score_fns[score_name] = factory(device)
        else:
            score_fns[score_name] = factory()

    def _as_float_array(scores):
        """Return scores (tensor, ndarray or sequence) as a float32 ndarray."""
        if isinstance(scores, torch.Tensor):
            return scores.detach().float().cpu().numpy()
        return np.asarray([float(score) for score in scores], dtype=np.float32)

    # only_strict is only for geneval. During training, only the strict reward is needed, and non-strict rewards don't need to be computed, reducing reward calculation time.
    def _fn(images, prompts, metadata=None, ref_images=None, traj=None, only_strict=True,is_video=False):
        total_scores = []
        score_details = {}
        if is_video:
            batch_size, num_frames, *image_shape = images.shape
            # Running weighted sum over all scorers and frames, per sample
            total_scores = np.zeros(batch_size, dtype=np.float32)
            # Per-scorer list holding one score array per frame
            frame_scores = {score_name: [] for score_name in score_dict}

            for i in range(num_frames):
                image = images[:, i]  # frame i of every sample
                for score_name, weight in score_dict.items():
                    if score_name == "geneval":
                        scores, rewards, strict_rewards, group_rewards, group_strict_rewards = score_fns[score_name](
                            image, prompts, metadata, only_strict
                        )
                        if i == 0:  # Only store once, assuming these are frame-independent
                            score_details['accuracy'] = rewards
                            score_details['strict_accuracy'] = strict_rewards
                            for key, value in group_strict_rewards.items():
                                score_details[f'{key}_strict_accuracy'] = value
                            for key, value in group_rewards.items():
                                score_details[f'{key}_accuracy'] = value
                    elif score_name == "image_similarity":
                        scores, rewards = score_fns[score_name](image, ref_images)
                    else:
                        scores, rewards = score_fns[score_name](image, prompts, metadata)

                    # Scorers return tensors, arrays or lists; normalise so the
                    # accumulation below works for all of them.
                    scores = _as_float_array(scores)
                    frame_scores[score_name].append(scores)
                    total_scores += weight * scores

                # Check for nan values
                if np.any(np.isnan(total_scores)):
                    print(f"Warning: NaN detected in total_scores for frame {i}")
                    total_scores = np.nan_to_num(total_scores, nan=0.0)

            # Aggregate only after every frame has been scored; doing this
            # inside the frame loop replaced the per-frame lists too early.
            for score_name, per_frame in frame_scores.items():
                if per_frame:
                    stacked = np.stack(per_frame, axis=1)  # (batch_size, num_frames)
                    score_details[score_name] = stacked.mean(axis=1).tolist()
                else:
                    score_details[score_name] = []

            print("total_score",total_scores)
            # Normalize by number of frames (guarding the zero-frame case)
            score_details['avg'] = (total_scores / max(num_frames, 1)).tolist()

        else:
            for score_name, weight in score_dict.items():
                if score_name == "geneval":
                    scores, rewards, strict_rewards, group_rewards, group_strict_rewards = score_fns[score_name](images, prompts, metadata, only_strict)
                    score_details['accuracy'] = rewards
                    score_details['strict_accuracy'] = strict_rewards
                    for key, value in group_strict_rewards.items():
                        score_details[f'{key}_strict_accuracy'] = value
                    for key, value in group_rewards.items():
                        score_details[f'{key}_accuracy'] = value
                elif score_name == "image_similarity":
                    scores, rewards = score_fns[score_name](images, ref_images)
                elif score_name == "cotracker":
                    scores, rewards = score_fns[score_name](images, traj)
                else:
                    scores, rewards = score_fns[score_name](images, prompts, metadata)

                score_details[score_name] = scores
                weighted_scores = [weight * score for score in scores]

                if not total_scores:
                    total_scores = weighted_scores
                else:
                    total_scores = [total + weighted for total, weighted in zip(total_scores, weighted_scores)]

            score_details['avg'] = total_scores
        return score_details, {}

    return _fn

def video_reward(device, score_dict):
    """Build a callable that combines several reward functions for videos.

    Args:
        device: Passed to every scorer factory that declares a ``device``
            parameter.
        score_dict: Mapping ``score name -> weight``. Names must be keys of
            the registry below.

    Returns:
        A function ``_fn(images, prompts, metadata=None, ref_images=None,
        traj=None, only_strict=True, is_video=False, random_block=None)``
        returning ``(score_details, {})``. ``score_details`` maps each score
        name to its scores and ``'avg'`` to the weighted combination; for
        ``"geneval"`` the accuracy entries reported by that scorer are added.

        With ``is_video=True`` the first two axes of ``images`` are taken as
        (batch, frames); every frame is scored separately, per-scorer values
        are averaged over frames, and ``'avg'`` is the weighted sum divided by
        the number of frames.

        With ``is_video=False``, ``"geneval"``, ``"image_similarity"`` and
        ``"cotracker"`` receive ``images`` as given, while every other scorer
        receives ``images[:, random_block]``.
    """
    import inspect

    score_functions = {
        "deqa": deqa_score_remote,
        "ocr": ocr_score,
        "video_ocr": video_ocr_score,
        "imagereward": imagereward_score,
        "pickscore": pickscore_score,
        "qwenvl": qwenvl_score,
        "aesthetic": aesthetic_score,
        "jpeg_compressibility": jpeg_compressibility,
        # The factory defined in this file is unifiedreward_score_remote; the
        # previously referenced *_sglang name does not exist here.
        "unifiedreward": unifiedreward_score_remote,
        "geneval": geneval_score,
        "clipscore": clip_score,
        "image_similarity": image_similarity_score,
        "cotracker": cotracker_reward
    }
    score_fns = {}
    for score_name in score_dict:
        factory = score_functions[score_name]
        # Look at the real parameters; __code__.co_varnames also lists locals.
        if "device" in inspect.signature(factory).parameters:
            score_fns[score_name] = factory(device)
        else:
            score_fns[score_name] = factory()

    def _as_float_array(scores):
        """Return scores (tensor, ndarray or sequence) as a float32 ndarray."""
        if isinstance(scores, torch.Tensor):
            return scores.detach().float().cpu().numpy()
        return np.asarray([float(score) for score in scores], dtype=np.float32)

    # only_strict is only for geneval. During training, only the strict reward is needed, and non-strict rewards don't need to be computed, reducing reward calculation time.
    def _fn(images, prompts, metadata=None, ref_images=None, traj=None, only_strict=True,is_video=False, random_block=None):
        total_scores = []
        score_details = {}
        if is_video:
            batch_size, num_frames, *image_shape = images.shape
            # Running weighted sum over all scorers and frames, per sample
            total_scores = np.zeros(batch_size, dtype=np.float32)
            # Per-scorer list holding one score array per frame
            frame_scores = {score_name: [] for score_name in score_dict}

            for i in range(num_frames):
                image = images[:, i]  # frame i of every sample
                for score_name, weight in score_dict.items():
                    if score_name == "geneval":
                        scores, rewards, strict_rewards, group_rewards, group_strict_rewards = score_fns[score_name](
                            image, prompts, metadata, only_strict
                        )
                        if i == 0:  # Only store once, assuming these are frame-independent
                            score_details['accuracy'] = rewards
                            score_details['strict_accuracy'] = strict_rewards
                            for key, value in group_strict_rewards.items():
                                score_details[f'{key}_strict_accuracy'] = value
                            for key, value in group_rewards.items():
                                score_details[f'{key}_accuracy'] = value
                    elif score_name == "image_similarity":
                        scores, rewards = score_fns[score_name](image, ref_images)
                    else:
                        scores, rewards = score_fns[score_name](image, prompts, metadata)

                    # Scorers return tensors, arrays or lists; normalise so the
                    # accumulation below works for all of them.
                    scores = _as_float_array(scores)
                    frame_scores[score_name].append(scores)
                    total_scores += weight * scores

                # Check for nan values
                if np.any(np.isnan(total_scores)):
                    print(f"Warning: NaN detected in total_scores for frame {i}")
                    total_scores = np.nan_to_num(total_scores, nan=0.0)

            # Aggregate only after every frame has been scored; doing this
            # inside the frame loop replaced the per-frame lists too early.
            for score_name, per_frame in frame_scores.items():
                if per_frame:
                    stacked = np.stack(per_frame, axis=1)  # (batch_size, num_frames)
                    score_details[score_name] = stacked.mean(axis=1).tolist()
                else:
                    score_details[score_name] = []

            print("total_score",total_scores)
            # Normalize by number of frames (guarding the zero-frame case)
            score_details['avg'] = (total_scores / max(num_frames, 1)).tolist()

        else:
            for score_name, weight in score_dict.items():
                if score_name == "geneval":
                    scores, rewards, strict_rewards, group_rewards, group_strict_rewards = score_fns[score_name](images, prompts, metadata, only_strict)
                    score_details['accuracy'] = rewards
                    score_details['strict_accuracy'] = strict_rewards
                    for key, value in group_strict_rewards.items():
                        score_details[f'{key}_strict_accuracy'] = value
                    for key, value in group_rewards.items():
                        score_details[f'{key}_accuracy'] = value
                elif score_name == "image_similarity":
                    scores, rewards = score_fns[score_name](images, ref_images)
                elif score_name == "cotracker":
                    scores, rewards = score_fns[score_name](images, traj)
                else:
                    # NOTE(review): with the default random_block=None this
                    # indexing inserts a new axis instead of selecting a
                    # frame — confirm callers always pass random_block here.
                    scores, rewards = score_fns[score_name](images[:,random_block], prompts, metadata)

                score_details[score_name] = scores
                weighted_scores = [weight * score for score in scores]

                if not total_scores:
                    total_scores = weighted_scores
                else:
                    total_scores = [total + weighted for total, weighted in zip(total_scores, weighted_scores)]

            score_details['avg'] = total_scores
        return score_details, {}

    return _fn

def main():
    """Smoke test: score ``nasa.jpg`` with the "unifiedreward" scorer.

    Loads the image as an RGB tensor via torchvision's ``ToTensor``, builds a
    ``multi_score`` callable with weight 1.0 for "unifiedreward" on CUDA when
    available (CPU otherwise), and prints the resulting score details.
    """
    import torchvision.transforms as transforms

    to_tensor = transforms.Compose([
        transforms.ToTensor(),  # Convert to tensor
    ])

    image_paths = [
        "nasa.jpg",
    ]
    loaded = []
    for path in image_paths:
        rgb = Image.open(path).convert('RGB')
        loaded.append(to_tensor(rgb))
    images = torch.stack(loaded)

    prompts = [
        'A astronaut’s glove floating in zero-g with "NASA 2049" on the wrist',
    ]
    metadata = {}  # Example metadata
    weights = {
        "unifiedreward": 1.0
    }

    # Pick the device, build the combined scorer and run it once.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    scoring_fn = multi_score(device, weights)
    scores, _ = scoring_fn(images, prompts, metadata)
    print("Scores:", scores)


# Run the smoke test in main() only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()