from PIL import Image
import io, os, json
import numpy as np
import torch
from collections import defaultdict
from concurrent import futures
from torch.utils.data import Dataset, Sampler
import matplotlib.pyplot as plt
from typing import Tuple, List, Dict, Optional, Union
import torch.distributed as dist
import random

def _defaultdict_float_factory():
    return defaultdict(float)

def _defaultdict_int_factory():
    return defaultdict(int)


class EvalStats:
    """Accumulate per-key metric sums on each rank and reduce them across ranks.

    Args:
        basic_info: Optional extra entries merged into :meth:`to_dict`. Each element
            is unpacked as a ``(key, value)`` pair. NOTE(review): annotated as
            ``List[dict]``, but a dict only unpacks this way if it has exactly two
            keys — confirm the real element type with callers.
        track_keys: Metric names to track. Updates for other keys are ignored.
            ``None`` means nothing is tracked.
    """

    def __init__(self, basic_info: Optional[List[dict]] = None, track_keys: Optional[List[str]] = None):
        self.basic_info = basic_info
        self.track_keys = track_keys
        self.track_stats = {}        # key -> running sum (float; a tensor after aggregate())
        self.track_stats_count = {}  # key -> number of updates (int; a tensor after aggregate())
        self.mean_stats = {}         # results of the most recent aggregate()
        for key in track_keys or []:
            self.track_stats[key] = 0
            self.track_stats_count[key] = 0

    def update(self, key: str, value: Union[float, int]):
        """Add ``value`` to the running sum of ``key`` if ``key`` is tracked."""
        if self.track_keys is not None and key in self.track_keys:
            self.track_stats[key] += float(value)
            self.track_stats_count[key] += 1

    def aggregate(self, normalize=True, log_overall=True):
        """Sum every tracked statistic (and its count) over all ranks.

        When torch.distributed is not initialized, only the local values are used.

        Args:
            normalize: Divide each sum by its update count (sums with a zero count
                are reported unnormalized).
            log_overall: Also store ``"Overall"``, the mean of the per-key results.
                It is skipped when no keys are tracked.

        Returns:
            ``self.mean_stats``: the metric name (as ``str``) mapped to a 0-dim float tensor.

        NOTE: the reduced totals overwrite ``track_stats``/``track_stats_count``.
        Calling this more than once therefore re-reduces already-reduced totals.
        """
        keys = self.track_keys or []
        device = "cuda" if torch.cuda.is_available() else "cpu"
        distributed = dist.is_available() and dist.is_initialized()
        mean_score = 0
        for key in keys:
            # float() accepts both the initial Python numbers and tensors from an earlier aggregate().
            total = torch.tensor(float(self.track_stats[key]), device=device, dtype=torch.float)
            count = torch.tensor(float(self.track_stats_count[key]), device=device, dtype=torch.float)
            if distributed:
                dist.all_reduce(total, op=dist.ReduceOp.SUM)
                dist.all_reduce(count, op=dist.ReduceOp.SUM)
            self.track_stats[key] = total
            self.track_stats_count[key] = count

            name = f"{key}"
            if normalize and count.item() > 0:
                self.mean_stats[name] = total / count
            else:
                self.mean_stats[name] = total
            mean_score += self.mean_stats[name]
        if log_overall and keys:
            self.mean_stats["Overall"] = mean_score / len(keys)
        return self.mean_stats

    def to_dict(self):
        """Merge the ``basic_info`` pairs and the latest aggregated stats into one dict."""
        return_dict = {}
        if self.basic_info is not None:
            return_dict.update({k: v for k, v in self.basic_info})
        if self.mean_stats is not None:
            return_dict.update(self.mean_stats)
        return return_dict

class MixRLDataset(Dataset):
    """Mixture of several JSONL files that can be sampled with per-file probabilities.

    Args:
        dataset_list: Paths of JSONL files, one JSON object per non-blank line.
        probs: Sampling probability for each file. Defaults to uniform. Must have
            one entry per file; ``numpy`` requires it to sum to 1 when sampling.
    """

    def __init__(self, dataset_list, probs=None):
        self.dataset_list = dataset_list
        self.data = []  # one list of records per file, in dataset_list order
        self.num_samples = 0

        for data_file in dataset_list:
            with open(data_file, "r", encoding='utf-8') as f:
                # Skip blank lines (e.g. a trailing empty line) instead of failing in json.loads.
                records = [json.loads(line) for line in f if line.strip()]
            self.data.append(records)
            self.num_samples += len(records)

        if probs is None:
            self.probs = [1.0 / len(dataset_list)] * len(dataset_list)
        else:
            assert len(probs) == len(dataset_list), f"probs length {len(probs)} != dataset_list length {len(dataset_list)}"
            self.probs = probs

    def __len__(self):
        """Total number of records across all files."""
        return self.num_samples

    def get_data(self, seed: int):
        """Pick a file according to ``probs``, then a uniformly random record from it.

        Both draws come from one ``numpy`` Generator seeded with ``seed``.
        ``np.random.default_rng`` returns a Generator unchanged, so the second draw
        continues the same stream.
        """
        rng = np.random.default_rng(seed)
        dataset_idx = rng.choice(len(self.dataset_list), p=self.probs)
        return self.get_data_by_idx(rng, dataset_idx)

    def get_data_by_idx(self, seed, idx=0):
        """Return a uniformly random record from file ``idx``.

        ``seed`` may be an int or an existing ``numpy.random.Generator``.
        """
        rng = np.random.default_rng(seed)
        records = self.data[idx]
        return records[rng.integers(0, len(records))]

    def __getitem__(self, idx):
        """Index into the concatenation of all files (consistent with ``__len__``)."""
        if idx < 0:
            idx += self.num_samples
        for records in self.data:
            if 0 <= idx < len(records):
                return records[idx]
            idx -= len(records)
        raise IndexError("MixRLDataset index out of range")
    

def compute_advantages_over_policy_groups(rewards, policy_group):
    """Standardize ``rewards`` using the mean/std of all rewards in ``policy_group``.

    Every rank's ``rewards`` are gathered and concatenated along dim 0. The mean and
    (unbiased) std are taken over dim 0. When torch.distributed is not initialized,
    the local ``rewards`` alone are used.

    Args:
        rewards: Reward tensor. It is cast to float. ``all_gather`` requires the same
            shape on every rank in the group.
        policy_group: Process group to gather over (``None`` = default group).

    Returns:
        ``(advantage, group_mean, group_std)``, where
        ``advantage = (rewards - mean) / (std + 1e-6)``.

    NOTE: with a single reward in the whole group the unbiased std is NaN.
    """
    rewards = rewards.float()
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size(policy_group)
        gathered = [torch.zeros_like(rewards) for _ in range(world_size)]
        dist.all_gather(gathered, rewards, group=policy_group)
        group_rewards = torch.cat(gathered, dim=0)
    else:
        # Single-process fallback: the "group" is just this process.
        group_rewards = rewards
    group_mean = group_rewards.mean(dim=0)
    group_std = group_rewards.std(dim=0)
    advantage = (rewards - group_mean) / (group_std + 1e-6)
    return advantage, group_mean, group_std


class PromptDataset(Dataset):
    """Plain-text prompt dataset: one prompt per non-blank line, whitespace-stripped.

    Args:
        file_path: UTF-8 text file of prompts.
        num_samples: If given, reported as the dataset length. Indices wrap around
            the prompt list, so it may exceed the number of prompts.
    """

    def __init__(self, file_path, num_samples=None):
        self.num_samples = num_samples
        with open(file_path, "r", encoding='utf-8') as f:
            self.prompts = [line.strip() for line in f if line.strip()]

    def __len__(self):
        if self.num_samples is not None:
            return self.num_samples
        return len(self.prompts)

    def __getitem__(self, idx):
        if not self.prompts:
            raise IndexError("PromptDataset is empty")
        # Wrap so every index below num_samples is valid, matching GenevalPromptDataset.
        return self.prompts[idx % len(self.prompts)]

    @staticmethod
    def collate_fn(examples):
        """Return the batch unchanged, as a list of prompt strings."""
        return examples
    

class JsonDataset(Dataset):
    """Dataset backed by a single JSON document loaded in full.

    Args:
        file_path: UTF-8 JSON file. It is usually a top-level list of records.
        num_samples: If given, reported as the dataset length. For list data,
            indices wrap around, so it may exceed the number of records.
    """

    def __init__(self, file_path, num_samples=None):
        self.num_samples = num_samples
        with open(file_path, "r", encoding='utf-8') as f:
            self.data = json.load(f)

    def __len__(self):
        if self.num_samples is not None:
            return self.num_samples
        return len(self.data)

    def __getitem__(self, idx):
        if isinstance(self.data, list):
            if not self.data:
                raise IndexError("JsonDataset is empty")
            # Wrap so every index below num_samples is valid (like GenevalPromptDataset).
            idx = idx % len(self.data)
        # Non-list JSON (e.g. a top-level object) is indexed directly, as before.
        return self.data[idx]

    @staticmethod
    def collate_fn(examples):
        """Return the batch unchanged, as a list of records."""
        return examples
    

class JsonlDataset(Dataset):
    """Dataset backed by a JSONL file: one JSON record per non-blank line.

    Args:
        file_path: UTF-8 JSONL file.
        num_samples: If given, reported as the dataset length. Indices wrap around
            the records, so it may exceed the number of lines.
    """

    def __init__(self, file_path, num_samples=None):
        self.num_samples = num_samples
        with open(file_path, "r", encoding='utf-8') as f:
            # Skip blank lines (e.g. a trailing empty line) instead of failing in json.loads.
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        if self.num_samples is not None:
            return self.num_samples
        return len(self.data)

    def __getitem__(self, idx):
        if not self.data:
            raise IndexError("JsonlDataset is empty")
        # Wrap so every index below num_samples is valid, matching GenevalPromptDataset.
        return self.data[idx % len(self.data)]

    @staticmethod
    def collate_fn(examples):
        """Return the batch unchanged, as a list of records."""
        return examples
    

class EditDataset(Dataset):
    """Image-editing dataset read from a JSON list of records.

    Only records whose ``'idx'`` field is greater than 100000 are kept. The meaning
    of this cut-off is not visible here — TODO confirm. Each item loads the source
    and target images as 512x512 RGB PIL images together with the text fields.

    Args:
        file_path: UTF-8 JSON file holding a list of record dicts.
        transform: Stored on the instance. NOTE(review): it is never applied inside
            this class; confirm whether callers expect it to be.
    """

    def __init__(self, file_path="",
                 transform=None):
        self.transform = transform
        with open(file_path, "r", encoding='utf-8') as f:
            data_dict = json.load(f)
        self.all_data = [data for data in data_dict if data['idx'] > 100000]

    def __len__(self):
        return len(self.all_data)

    @staticmethod
    def _load_image(path):
        """Open ``path``, convert it to RGB and resize to 512x512, closing the file afterwards."""
        # The context manager releases the file handle even for lazily-loaded or
        # multi-frame formats; convert() returns an independent, fully loaded image.
        with Image.open(path) as img:
            return img.convert('RGB').resize((512, 512))

    def __getitem__(self, idx):
        """Return a dict with both images and the record's instruction, reasoning, caption and step fields."""
        record = self.all_data[idx]
        return {
            'source_image': self._load_image(record['source_image']),
            'target_image': self._load_image(record['target_image']),
            'original_instruction': record['original_instruction'],
            'editing_reasoning': record['editing_reasoning'],
            'long_caption': record['long_caption'],
            'short_caption': record['short_caption'],
            'editing_step': record['editing_step'],
        }

    @staticmethod
    def collate_fn(examples):
        """Return the batch unchanged, as a list of item dicts."""
        return examples
    


class GenevalPromptDataset(Dataset):
    """Prompt dataset backed by a GenEval-style JSONL metadata file.

    Args:
        dataset: Directory containing ``{split}_metadata{postfix}.jsonl``, or the
            JSONL file itself when ``split`` is empty.
        split: Split name used to build the filename. ``''`` means ``dataset`` is
            already the file path.
        postfix: Suffix inserted before ``.jsonl``.
        num_samples: If given, reported as the dataset length. Indices wrap around
            the real prompt list.
        shuffle: Accepted for API compatibility but not used by this class.
    """

    def __init__(self, dataset, split='train', postfix="", num_samples=None, shuffle=True):
        self.num_samples = num_samples
        self.file_path = (
            os.path.join(dataset, f'{split}_metadata{postfix}.jsonl')
            if len(split) > 0
            else dataset
        )
        with open(self.file_path, 'r', encoding='utf-8') as handle:
            self.metadatas = list(map(json.loads, handle))
        self.prompts = [record['prompt'] for record in self.metadatas]

    def __len__(self):
        return len(self.prompts) if self.num_samples is None else self.num_samples

    def __getitem__(self, idx):
        wrapped = idx % len(self.prompts)
        return {"prompt": self.prompts[wrapped], "metadata": self.metadatas[wrapped]}

    @staticmethod
    def collate_fn(examples):
        """Split a batch of items into ``(prompts, metadatas)`` lists."""
        return [ex["prompt"] for ex in examples], [ex["metadata"] for ex in examples]
    

class TIIFDataset(Dataset):
    """Prompt dataset read from a JSONL file, using one chosen field as the prompt.

    Args:
        file_path: UTF-8 JSONL file, one record per line.
        num_samples: If given, reported as the dataset length. Indices wrap around
            the real prompt list.
        key: Record field used as the prompt. It is also copied into each record
            under ``'prompt'``.
    """

    def __init__(self, file_path, num_samples=None, key="short_description"):
        self.num_samples = num_samples
        self.file_path = file_path
        with open(self.file_path, 'r', encoding='utf-8') as handle:
            self.metadatas = list(map(json.loads, handle))
        self.prompts = [record[key] for record in self.metadatas]
        # Copy the selected field into each record's 'prompt' entry.
        for record in self.metadatas:
            record['prompt'] = record[key]

    def __len__(self):
        return len(self.prompts) if self.num_samples is None else self.num_samples

    def __getitem__(self, idx):
        wrapped = idx % len(self.prompts)
        return {"prompt": self.prompts[wrapped], "metadata": self.metadatas[wrapped]}

    @staticmethod
    def collate_fn(examples):
        """Split a batch of items into ``(prompts, metadatas)`` lists."""
        return [ex["prompt"] for ex in examples], [ex["metadata"] for ex in examples]
            

class DistributedKRepeatSampler(Sampler):
    """Infinite per-rank batch sampler in which each chosen sample appears ``k`` times globally.

    Each step:
    - draws ``m = world_size * batch_size // k`` distinct dataset indices;
    - repeats every index ``k`` times and shuffles the result;
    - gives rank ``r`` the slice ``[r * batch_size, (r + 1) * batch_size)``.

    All ranks seed a fresh generator with ``seed + epoch``, so they compute the same
    global permutation.

    NOTE: ``__iter__`` re-seeds with ``seed + epoch`` before every batch and never
    advances ``epoch`` itself. It keeps yielding the same batch until
    :meth:`set_epoch` changes the epoch. If ``len(dataset) < m``, fewer distinct
    indices are drawn and batches come out short.
    """

    def __init__(self, dataset, batch_size, k, world_size, rank, seed=0):
        self.dataset = dataset
        self.batch_size = batch_size  # per-rank batch size
        self.k = k                    # how many times each sample is repeated
        self.world_size = world_size  # number of ranks
        self.rank = rank              # this rank's index
        self.seed = seed              # base seed shared by all ranks

        # Samples per step across all ranks, and how many of them are distinct.
        self.total_samples = self.world_size * self.batch_size
        assert self.total_samples % self.k == 0, f"k can not div n*b, k{k}-num_replicas{world_size}-batch_size{batch_size}"
        assert 0 <= self.rank < self.world_size, f"rank {rank} out of range for world_size {world_size}"
        self.m = self.total_samples // self.k
        self.epoch = 0

    def __iter__(self):
        # Use CUDA when present (unchanged sequence on GPU hosts) and fall back to CPU otherwise.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        while True:
            generator = torch.Generator(device=device)
            generator.manual_seed(self.seed + self.epoch)
            distinct = torch.randperm(len(self.dataset), generator=generator, device=device)[: self.m].tolist()
            repeated = [idx for idx in distinct for _ in range(self.k)]
            order = torch.randperm(len(repeated), generator=generator, device=device).tolist()
            shuffled = [repeated[i] for i in order]
            # Only this rank's contiguous share is needed.
            start = self.rank * self.batch_size
            yield shuffled[start:start + self.batch_size]

    def set_epoch(self, epoch):
        """Set the epoch mixed into the seed; all ranks must use the same value."""
        self.epoch = epoch


def geneval_score(url=""):
    """Create a reward function that scores images through a remote GenEval server.

    The returned closure JPEG-encodes the images, POSTs them in batches of 64 to
    ``url`` as a pickled payload, and concatenates the server's responses.

    Args:
        url: HTTP(S) endpoint of the scoring server.

    Returns:
        ``_fn(images, prompts, metadatas, only_strict=False, return_reason=False)``.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 64
    sess = requests.Session()
    retries = Retry(
        total=3, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    adapter = HTTPAdapter(max_retries=retries)
    # Mount on both schemes so HTTPS endpoints get the same retry policy.
    sess.mount("http://", adapter)
    sess.mount("https://", adapter)

    def _split(seq):
        """Split ``seq`` into ceil(len / batch_size) near-equal chunks; ``[]`` for empty input."""
        if len(seq) == 0:
            return []
        return np.array_split(seq, np.ceil(len(seq) / batch_size))

    def _merge_group_dicts(dicts):
        """Concatenate per-key lists across the per-batch dicts returned by the server."""
        merged = defaultdict(list)
        for current in dicts:
            for key, value in current.items():
                merged[key].extend(value)
        return dict(merged)

    def _fn(images, prompts, metadatas, only_strict=False, return_reason=False):
        """Score ``images`` against ``metadatas`` on the GenEval server.

        Args:
            images: Either a tensor in NCHW layout (scaled by 255, rounded and
                clamped to [0, 255]; presumably in [0, 1]) or a sequence of HWC
                uint8 arrays accepted by ``PIL.Image.fromarray``.
            prompts: Ignored.
            metadatas: Per-image metadata sent to the server with the images.
            only_strict: Forwarded to the server unchanged.
            return_reason: Also return the server's ``reasons``.

        Returns:
            ``(scores, rewards, strict_rewards, group_rewards, group_strict_rewards)``,
            plus ``reasons`` when ``return_reason`` is true.
            - The first three are flat lists over all images.
            - The group dicts map each key to the lists concatenated across batches.
            - ``reasons`` holds one entry per request batch.

        Raises:
            requests.HTTPError: if the server responds with an error status.
        """
        del prompts  # unused
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC
        all_scores = []
        all_rewards = []
        all_strict_rewards = []
        all_group_strict_rewards = []
        all_group_rewards = []
        all_reasons = []
        for image_batch, metadata_batch in zip(_split(images), _split(metadatas)):
            # Compress the images as JPEG to keep the request small.
            jpeg_images = []
            for image in image_batch:
                buffer = BytesIO()
                Image.fromarray(image).save(buffer, format="JPEG")
                jpeg_images.append(buffer.getvalue())

            data = {
                "images": jpeg_images,
                "meta_datas": list(metadata_batch),
                "only_strict": only_strict,
            }
            response = sess.post(url, data=pickle.dumps(data), timeout=120)
            # Fail loudly on HTTP errors instead of trying to unpickle an error page.
            response.raise_for_status()
            # SECURITY: pickle.loads can execute arbitrary code embedded in the payload.
            # Only point ``url`` at a trusted scoring server.
            response_data = pickle.loads(response.content)
            if return_reason:
                all_reasons.append(response_data["reasons"])
            all_scores += response_data["scores"]
            all_rewards += response_data["rewards"]
            all_strict_rewards += response_data["strict_rewards"]
            all_group_strict_rewards.append(response_data["group_strict_rewards"])
            all_group_rewards.append(response_data["group_rewards"])

        all_group_strict_rewards_dict = _merge_group_dicts(all_group_strict_rewards)
        all_group_rewards_dict = _merge_group_dicts(all_group_rewards)
        if return_reason:
            return all_scores, all_rewards, all_strict_rewards, all_group_rewards_dict, all_group_strict_rewards_dict, all_reasons
        return all_scores, all_rewards, all_strict_rewards, all_group_rewards_dict, all_group_strict_rewards_dict

    return _fn


def plot_rewards_histogram(
    gathered_rewards: torch.Tensor,
    save_dir: str,
    curr_step: int,
    num_bins: int = 50,
    figsize: Tuple[int, int] = (10, 6),
    title: str = 'Distribution of Gathered Rewards',
    xlabel: str = 'Reward Value',
    ylabel: str = 'Frequency',
    save_path: str = 'rewards_histogram.png',
    dpi: int = 300,
    color: str = 'skyblue',
    show_stats: bool = True,
    show_plot: bool = False,
    save_pdf: bool = False,
    value_range: Optional[Tuple[float, float]] = (0, 1),
) -> dict:
    """Bin the gathered rewards, plot a histogram and save it to disk.

    The image is written to ``os.path.join(save_dir, f"{curr_step}_{save_path}")``.

    Args:
        gathered_rewards: Reward values as a tensor or array-like. Input is
            flattened, so any shape is accepted.
        save_dir: Directory the figure is written to.
        curr_step: Training step, used as the filename prefix.
        num_bins: Number of histogram bins.
        figsize: Figure size as (width, height).
        title: Plot title.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        save_path: Base filename of the saved image.
        dpi: Resolution of the saved image.
        color: Histogram bar color.
        show_stats: Draw mean and mean ± std lines.
        show_plot: Call ``plt.show()`` after saving.
        save_pdf: Also save a PDF next to the image, with the same name but a ``.pdf`` suffix.
        value_range: ``(low, high)`` range of the bins. Values outside it are not
            counted. ``None`` uses the data's min/max.

    Returns:
        dict with these keys:
        - ``mean``, ``std``, ``min``, ``max``: floats, NaN for empty input;
        - ``counts``, ``bins``: as returned by ``plt.hist``;
        - ``save_path``: path of the saved image.
    """
    # Convert to a flat numpy array. 2-D input would otherwise be plotted as several datasets.
    if isinstance(gathered_rewards, torch.Tensor):
        rewards = gathered_rewards.detach()
        if rewards.dtype == torch.bfloat16:
            rewards = rewards.float()  # numpy has no bfloat16
        rewards_np = rewards.cpu().numpy()
    else:
        rewards_np = np.array(gathered_rewards)
    rewards_np = rewards_np.ravel()

    if rewards_np.size > 0:
        mean_reward = float(np.mean(rewards_np))
        std_reward = float(np.std(rewards_np))
        min_reward = float(np.min(rewards_np))
        max_reward = float(np.max(rewards_np))
    else:
        mean_reward = std_reward = min_reward = max_reward = float('nan')

    png_path = os.path.join(save_dir, f'{curr_step}_{save_path}')
    fig = plt.figure(figsize=figsize)
    try:
        counts, bins, _ = plt.hist(
            rewards_np,
            bins=num_bins,
            range=value_range,
            alpha=0.7,
            color=color,
            edgecolor='black'
        )
        plt.title(title, fontsize=16)
        plt.xlabel(xlabel, fontsize=12)
        plt.ylabel(ylabel, fontsize=12)
        plt.grid(True, alpha=0.3)

        if show_stats and rewards_np.size > 0:
            plt.axvline(mean_reward, color='red', linestyle='--', linewidth=2,
                        label=f'Mean: {mean_reward:.3f}')
            plt.axvline(mean_reward + std_reward, color='orange', linestyle='--',
                        alpha=0.7, label=f'Mean + Std: {mean_reward + std_reward:.3f}')
            plt.axvline(mean_reward - std_reward, color='orange', linestyle='--',
                        alpha=0.7, label=f'Mean - Std: {mean_reward - std_reward:.3f}')
            plt.legend()

        plt.tight_layout()
        plt.savefig(png_path, dpi=dpi, bbox_inches='tight')

        if save_pdf:
            # Save next to the image, not in the CWD, and keep the step prefix.
            pdf_path = os.path.splitext(png_path)[0] + '.pdf'
            plt.savefig(pdf_path, bbox_inches='tight')

        if show_plot:
            plt.show()
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)

    return {
        "mean": mean_reward,
        "std": std_reward,
        "min": min_reward,
        "max": max_reward,
        "counts": counts,
        "bins": bins,
        "save_path": png_path,
    }


def compute_advantages_for_multi_round_rollout(
    per_sample_rewards: List[List[float]],
    current_policy_group: Optional["dist.ProcessGroup"] = None,
    # First-round configuration
    first_subtract_mean: bool = True,
    first_divide_std: bool = False,
    first_scaler: float = 1.0,
    # Later-round configuration
    later_operations: str = "subtract_prev,subtract_mean",
    later_scaler: float = 1.0,
) -> Tuple[List[List[float]], Dict[str, Union[float, Dict[str, float]]]]:
    """Compute per-round advantages for multi-round RL rollouts.

    Args:
        per_sample_rewards: Rewards indexed as ``[sample][round]``. The number of
            rounds is taken from the first sample.
        current_policy_group: Process group used for the "group" mean/std. ``None``
            means this rank's samples only.
        first_subtract_mean: Round 0 — subtract the group mean of the rewards.
        first_divide_std: Round 0 — divide by the group std (+1e-6).
        first_scaler: Multiplier applied to round-0 advantages.
        later_operations: Comma-separated steps applied in order for rounds >= 1:
            - ``subtract_prev``: subtract the same sample's previous-round reward;
            - ``subtract_mean``: subtract the group mean of the current advantages;
            - ``divide_std``: divide by the group std of the raw current-round
              rewards (+1e-6), not of the running advantages;
            - ``subtract_mean_global``: subtract the mean over all ranks.
            Unrecognized names are silently ignored.
        later_scaler: Multiplier applied to advantages of rounds >= 1.

    Returns:
        ``(per_sample_advantages, improvement_stats)``.
        - ``per_sample_advantages`` mirrors the input layout.
        - ``improvement_stats`` contains ``round_mean_rewards`` (global mean per
          round), ``round_improvements`` (mean change vs. the previous round),
          ``total_improvement`` (last round minus round 0) and, when rounds exist,
          ``round_0_global_mean_std``.

    NOTE: means are averages of per-rank means (``ReduceOp.AVG``), which equals the
    true mean only if every rank holds the same number of samples. ``all_gather``
    also requires equal sizes. ``ReduceOp.AVG`` is backend-dependent (assumed NCCL).
    Without an initialized process group, "global" quantities are local.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    distributed = dist.is_available() and dist.is_initialized()
    num_rounds = len(per_sample_rewards[0]) if per_sample_rewards else 0
    operations = [op.strip() for op in later_operations.split(',') if op.strip()]

    per_sample_advantages: List[List[float]] = []
    global_round_mean_rewards: List[float] = []  # for logging / improvement stats
    improvement_stats = {
        "round_mean_rewards": {},
        "round_improvements": {},
        "total_improvement": 0.0
    }

    def round_rewards(round_idx):
        """Return column ``round_idx`` of ``per_sample_rewards`` as a float tensor."""
        return torch.tensor(
            [sample_reward[round_idx] for sample_reward in per_sample_rewards],
            device=device, dtype=torch.float
        )

    def global_mean(values):
        """Mean of ``values``, averaged over all ranks when distributed."""
        mean = values.mean()
        if distributed:
            dist.all_reduce(mean, op=dist.ReduceOp.AVG)
        return mean

    def group_mean(values):
        """Mean of ``values``, averaged within ``current_policy_group`` if given."""
        mean = values.mean()
        if current_policy_group is not None:
            dist.all_reduce(mean, op=dist.ReduceOp.AVG, group=current_policy_group)
        return mean

    def group_std(values):
        """Unbiased std over all values in ``current_policy_group`` (gathered), else local."""
        if current_policy_group is None:
            return values.std()
        group_size = dist.get_world_size(current_policy_group)
        gathered = [torch.zeros_like(values) for _ in range(group_size)]
        dist.all_gather(gathered, values, group=current_policy_group)
        return torch.cat(gathered, dim=0).std()

    # Round 0
    if num_rounds > 0:
        first_round_rewards = round_rewards(0)
        first_mean = global_mean(first_round_rewards).item()
        global_round_mean_rewards.append(first_mean)
        improvement_stats["round_mean_rewards"]["round_0"] = first_mean

        first_round_advantages = first_round_rewards.clone()
        if first_subtract_mean:
            first_round_advantages = first_round_advantages - group_mean(first_round_rewards)

        # The std is always computed because it is logged, even when not dividing by it.
        std = group_std(first_round_advantages)
        logged_std = std.clone()
        if distributed:
            dist.all_reduce(logged_std, op=dist.ReduceOp.AVG)
        improvement_stats["round_0_global_mean_std"] = logged_std.item()
        if first_divide_std:
            first_round_advantages = first_round_advantages / (std + 1e-6)

        per_sample_advantages = [[adv * first_scaler] for adv in first_round_advantages.cpu().tolist()]

    # Rounds >= 1
    for round_idx in range(1, num_rounds):
        current_round_rewards = round_rewards(round_idx)
        current_mean = global_mean(current_round_rewards).item()
        global_round_mean_rewards.append(current_mean)
        improvement_stats["round_mean_rewards"][f"round_{round_idx}"] = current_mean
        improvement_stats["round_improvements"][f"round_{round_idx}_vs_{round_idx - 1}"] = (
            global_round_mean_rewards[-1] - global_round_mean_rewards[-2]
        )

        round_advantages = current_round_rewards.clone()
        for op in operations:
            if op == "subtract_prev":
                round_advantages = round_advantages - round_rewards(round_idx - 1)
            elif op == "subtract_mean":
                round_advantages = round_advantages - group_mean(round_advantages)
            elif op == "divide_std":
                round_advantages = round_advantages / (group_std(current_round_rewards) + 1e-6)
            elif op == "subtract_mean_global":
                round_advantages = round_advantages - global_mean(round_advantages)

        for sample_advantages, adv in zip(per_sample_advantages, round_advantages.cpu().tolist()):
            sample_advantages.append(adv * later_scaler)

    # Overall improvement: last round relative to the first.
    if len(global_round_mean_rewards) >= 2:
        improvement_stats["total_improvement"] = global_round_mean_rewards[-1] - global_round_mean_rewards[0]

    return per_sample_advantages, improvement_stats