# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from omegaconf import ListConfig
import os
from typing import List, Union
import copy
import pandas as pd

import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, PreTrainedTokenizer
from verl.utils.fs import copy_local_path_from_hdfs

from verl.utils.model import compute_position_id_with_mask
import verl.utils.torch_functional as verl_F
import json

def collate_fn(data_list: list[dict]) -> dict:
    """Merge a list of per-sample dicts into a single batch dict.

    Values that are ``torch.Tensor`` instances are gathered per key and
    stacked along a new leading dimension; every other value is gathered per
    key into a NumPy array with ``dtype=object``.

    Args:
        data_list: The samples to batch, each a mapping from field name to value.

    Returns:
        A dict holding the stacked tensor fields first, followed by the
        object-array fields. If a key shows up with both tensor and
        non-tensor values, the object array is the one that is kept.
    """
    tensor_fields = {}
    object_fields = {}

    for sample in data_list:
        for name, value in sample.items():
            # Route each value to the bucket that matches its type.
            bucket = tensor_fields if isinstance(value, torch.Tensor) else object_fields
            bucket.setdefault(name, []).append(value)

    batch = {name: torch.stack(values, dim=0) for name, values in tensor_fields.items()}
    batch.update({name: np.array(values, dtype=object) for name, values in object_fields.items()})
    return batch


class RLHFDataset(Dataset):
    """Prompt dataset for RLHF training, backed by one or more parquet files.

    Every row is expected to hold a chat-style prompt under ``prompt_key``: a
    sequence of message dicts whose first element has a ``'content'`` entry.
    All remaining columns are passed through unchanged in the dict returned by
    ``__getitem__``.
    """

    def __init__(self,
                 parquet_files: Union[str, List[str]],
                 tokenizer: PreTrainedTokenizer,
                 prompt_key='prompt',
                 max_prompt_length=1024,
                 filter_prompts=True,
                 cache_dir='~/.cache/verl/rlhf',
                 chat_template_func=None,
                 return_raw_chat=False,
                 truncation='error',
                 sample_n_file=None):
        """Load, filter and index the dataset.

        Args:
            parquet_files: A single path or a list / ``ListConfig`` of paths to
                parquet files. Each path is passed through
                ``copy_local_path_from_hdfs`` before being read.
            tokenizer: Tokenizer used for length filtering, chat templating and
                encoding.
            prompt_key: Name of the column that holds the chat messages.
            max_prompt_length: Prompts whose first message encodes to more
                tokens than this are dropped; also used as ``max_length`` when
                tokenizing in ``__getitem__``.
            filter_prompts: Stored on the instance. NOTE: not consulted by the
                code in this class; over-long prompts are always filtered.
            cache_dir: Cache directory handed to ``copy_local_path_from_hdfs``
                (``~`` is expanded).
            chat_template_func: Stored on the instance; not used by the code in
                this class.
            return_raw_chat: If true, ``__getitem__`` also returns the original
                chat messages under ``'raw_prompt'``.
            truncation: Forwarded to ``verl_F.tokenize_and_postprocess_data``.
            sample_n_file: Optional path to a JSON object mapping sample id to
                the rollout count ``n`` for that sample.
        """
        if not isinstance(parquet_files, (List, ListConfig)):
            parquet_files = [parquet_files]

        self.parquet_files = copy.deepcopy(parquet_files)
        self.original_parquet_files = copy.deepcopy(parquet_files)  # use for resume
        self.cache_dir = os.path.expanduser(cache_dir)
        self.tokenizer = tokenizer

        self.prompt_key = prompt_key
        self.max_prompt_length = max_prompt_length
        self.filter_prompts = filter_prompts

        self.return_raw_chat = return_raw_chat
        self.chat_template_func = chat_template_func
        self.truncation = truncation

        # whether to store the dataset in state_dict()
        # default not store
        self.serialize_dataset = False
        # flipped to True once the first templated prompt has been printed
        self.print_example = False
        self._download()
        self._read_files_and_tokenize()
        # read the difficulty level data id file and match data by id
        self._read_category_data_id()

        self.sample_id2rollout_n = None
        if sample_n_file is not None:
            print(f"[Dynamic Rollout On] reading sample_n_file: {sample_n_file}")
            self._read_n_for_each_sample(sample_n_file)
        print(f"example prompt: {self.tokenizer.decode(self.__getitem__(0)['input_ids'])}")

    def _read_category_data_id(self):
        """Add a ``category`` column by matching row ids against category id files.

        Looks for ``*_data_ids.json`` files next to the first parquet file. The
        part of each file name before the first ``_`` is the category name, and
        the ids obtained by iterating the loaded JSON are assigned to that
        category. Rows whose ``id`` is in no file get a missing value. Does
        nothing (apart from printing a notice) when no such file exists.
        """
        import glob

        base_dir = os.path.dirname(self.parquet_files[0])
        # glob order is arbitrary; sort so that an id listed in several files
        # always resolves to the same category.
        matching_files = sorted(glob.glob(os.path.join(base_dir, "*_data_ids.json")))
        if not matching_files:
            print(f"No category file found in {base_dir}")
            return

        id2cat = {}
        for path in matching_files:
            category = os.path.basename(path).split("_")[0]
            # context manager so the handle is closed even if parsing fails
            with open(path, "r") as f:
                category_data_ids = json.load(f)
            for data_id in category_data_ids:
                id2cat[data_id] = category
        # add a column named category to the dataframe
        self.dataframe["category"] = self.dataframe["id"].map(id2cat)

    def _read_n_for_each_sample(self, sample_n_file):
        """Load the per-sample rollout counts from ``sample_n_file``.

        The file must contain a JSON object; its keys are converted to ``int``
        and the result is stored in ``self.sample_id2rollout_n``.
        """
        with open(sample_n_file, "r") as f:
            raw_mapping = json.load(f)
        # JSON object keys are always strings; lookups in __getitem__ use the prompt id
        self.sample_id2rollout_n = {int(k): v for k, v in raw_mapping.items()}

    def _download(self, use_origin_parquet=False):
        """Resolve every parquet path through ``copy_local_path_from_hdfs``.

        Args:
            use_origin_parquet: If true, start from the paths originally given
                to the constructor instead of the current ``self.parquet_files``.
        """
        from verl.utils.fs import copy_local_path_from_hdfs
        parquet_files = self.parquet_files if not use_origin_parquet else self.original_parquet_files
        for i, parquet_file in enumerate(parquet_files):
            self.parquet_files[i] = copy_local_path_from_hdfs(src=parquet_file, cache_dir=self.cache_dir)

    def _read_files_and_tokenize(self):
        """Read all parquet files into ``self.dataframe`` and drop over-long prompts.

        A row is kept when the ``'content'`` of its first chat message encodes
        to at most ``max_prompt_length`` tokens. The length check is done on the
        raw content, without applying the chat template.
        """
        dataframes = [pd.read_parquet(parquet_file) for parquet_file in self.parquet_files]
        self.dataframe = pd.concat(dataframes)

        print(f'original dataset len: {len(self.dataframe)}')

        # filter out too long prompts
        tokenizer = self.tokenizer
        prompt_key = self.prompt_key
        max_prompt_length = self.max_prompt_length
        keep_mask = self.dataframe.apply(
            lambda doc: len(tokenizer.encode(doc[prompt_key][0]['content'])) <= max_prompt_length, axis=1)
        self.dataframe = self.dataframe[keep_mask]
        # reset index so that row labels are 0..len-1 again
        self.dataframe = self.dataframe.reset_index(drop=True)

        print(f'filter dataset len: {len(self.dataframe)}')

    def resume_dataset_state(self):
        """Rebuild ``self.dataframe`` after the dataset has been restored from a checkpoint.

        When the restored object knows its original parquet paths, the files
        are resolved and read again and the ``category`` column is re-attached,
        mirroring what ``__init__`` does. Otherwise the dataframe is assumed to
        have been serialized with the object and only a notice is printed.
        """
        self.serialize_dataset = not hasattr(self, 'original_parquet_files')
        # resume dataframe if not it's serialized in data.pt
        if not self.serialize_dataset:
            self._download(use_origin_parquet=True)  # download and resume from original parquet files
            self._read_files_and_tokenize()
            # the freshly read dataframe has no category column yet
            self._read_category_data_id()
        else:
            print(r'old dataloader ckpt file is used, please train from scratch for better ckpt performance')

    def __len__(self):
        """Return the number of rows left after filtering."""
        return len(self.dataframe)

    def __getitem__(self, item):
        """Build the sample dict for row ``item``.

        The returned dict contains every column of the row except the prompt
        column, plus ``input_ids``, ``attention_mask`` and ``position_ids``.
        Depending on the row and configuration it also contains ``raw_prompt``
        (when ``return_raw_chat`` is set), ``index``, ``prompt_id``,
        ``question`` and ``n``.
        """
        row_dict = self.dataframe.iloc[item].to_dict()

        chat = row_dict.pop(self.prompt_key)

        prompt = chat[0]['content']
        # Prompts that already carry chat markup, or that start with
        # "Question:", are used verbatim; everything else goes through the
        # tokenizer's chat template.
        if "<|im_end|>\n<|im_start|>" in prompt or prompt.startswith("Question:"):
            prompt_with_chat_template = prompt
        else:
            prompt_with_chat_template = self.tokenizer.apply_chat_template(chat,
                                                                           add_generation_prompt=True,
                                                                           tokenize=False)
        if not self.print_example:
            # print the first prompt once for inspection
            print(" ======== Example Prompt ======== ")
            print(prompt_with_chat_template)
            print(" ======== Example Prompt ======== ")
            self.print_example = True
        input_ids, attention_mask = verl_F.tokenize_and_postprocess_data(prompt=prompt_with_chat_template,
                                                                         tokenizer=self.tokenizer,
                                                                         max_length=self.max_prompt_length,
                                                                         pad_token_id=self.tokenizer.pad_token_id,
                                                                         left_pad=True,
                                                                         truncation=self.truncation)

        position_ids = compute_position_id_with_mask(attention_mask)

        # the helpers return results with a leading dimension; take the single entry
        row_dict['input_ids'] = input_ids[0]
        row_dict['attention_mask'] = attention_mask[0]
        row_dict['position_ids'] = position_ids[0]

        # encode prompts without chat template
        if self.return_raw_chat:
            row_dict['raw_prompt'] = chat.tolist()

        # add index for each prompt
        if "extra_info" in row_dict and "index" in row_dict["extra_info"]:
            row_dict["index"] = row_dict["extra_info"]["index"]
            row_dict["prompt_id"] = row_dict["extra_info"]["index"]
        else:
            row_dict["prompt_id"] = row_dict["unique_id"] if "unique_id" in row_dict else row_dict.get("id", None)

        if "extra_info" in row_dict and "question" in row_dict["extra_info"]:
            row_dict["question"] = row_dict["extra_info"]["question"]

        if self.sample_id2rollout_n is not None:
            row_dict["n"] = self.sample_id2rollout_n[row_dict["prompt_id"]]
            # Work on a copy: the dict in row_dict is the very object stored in
            # the dataframe, and rows without extra_info must not raise here.
            extra_info = dict(row_dict.get("extra_info") or {})
            extra_info["n"] = row_dict["n"]
            row_dict["extra_info"] = extra_info

        return row_dict

    def __getstate__(self):
        """Return the picklable state, omitting the dataframe unless ``serialize_dataset`` is set."""
        state = self.__dict__.copy()
        if not self.serialize_dataset:
            state.pop('dataframe', None)
        return state

    def __iter__(self):
        """Iterate over ``(index, row)`` pairs of the underlying dataframe."""
        return self.dataframe.iterrows()

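# A random sampler that draws dataset indices using per-category sampling weights.
# category_sample_ratio may be a string of "<category>-<weight>" entries joined by "_",
# e.g. "easy-0.1_medium-0.1_hard-0.8": every easy or medium sample gets weight 0.1 and every
# hard sample gets weight 0.8. The weights are assigned per sample, so the share of each
# category among the drawn samples also depends on how many samples that category contains.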
from torch.utils.data import WeightedRandomSampler
import numpy as np

class RandomSamplerByCategory:
    """Weighted random sampler whose per-sample weight is chosen by category.

    Every sample gets the weight configured for its ``category``; samples
    whose category is not configured keep a weight of ``1.0``. The indices are
    drawn by ``torch.utils.data.WeightedRandomSampler``. Because weights are
    assigned per sample, the share of a category among the drawn indices also
    depends on how many samples belong to it.

    ``data_source`` must support ``len()`` and iterate as ``(index, row)``
    pairs where ``row`` offers ``.get('category', default)``.
    """

    def __init__(self, data_source, category_sample_ratio, replacement=True):
        """Compute the per-sample weights and build the underlying sampler.

        Args:
            data_source: The dataset to sample from (see the class docstring
                for the protocol it has to follow).
            category_sample_ratio: Either a mapping from category to weight or
                a string such as ``'easy-0.1_medium-0.1_hard-0.8'``.
            replacement: Forwarded to ``WeightedRandomSampler``.
        """
        self.data_source = data_source
        # Parse category_sample_ratio string if it's a string
        if isinstance(category_sample_ratio, str):
            self.category_sample_ratio = self._parse_ratio_string(category_sample_ratio)
        else:
            self.category_sample_ratio = category_sample_ratio
        print(f"category_sample_ratio: {self.category_sample_ratio}")

        # Calculate weights for each sample based on category
        self.weights = self._calculate_weights()

        # One epoch draws as many indices as the data source has samples.
        self.sampler = WeightedRandomSampler(
            weights=self.weights,
            num_samples=len(self.data_source),
            replacement=replacement,
        )

    def _parse_ratio_string(self, ratio_string):
        """Parse ``'easy-0.1_medium-0.1_hard-0.8'`` into ``{'easy': 0.1, 'medium': 0.1, 'hard': 0.8}``.

        Entries are separated by ``_``. Within an entry the text after the
        last ``-`` is the weight, so category names may themselves contain ``-``.

        Raises:
            ValueError: If an entry has no ``-`` or its weight is not a number.
        """
        ratio_dict = {}
        for part in ratio_string.split('_'):
            # split on the LAST '-' so that 'hard-case-0.8' -> ('hard-case', '0.8')
            category, separator, ratio = part.rpartition('-')
            if not separator:
                raise ValueError(
                    f"Invalid category ratio entry {part!r} in {ratio_string!r}; expected '<category>-<ratio>'")
            ratio_dict[category] = float(ratio)
        return ratio_dict

    def _calculate_weights(self):
        """Return an array with one sampling weight per sample of ``data_source``.

        Weights are stored by iteration position rather than by the index
        label yielded by the data source, so the array always lines up with
        the positional indices produced by the sampler.
        """
        weights = np.ones(len(self.data_source))
        print("length of dataframe", len(self.data_source))

        unmatched = {}
        for position, (_, row) in enumerate(self.data_source):
            category = row.get('category', 'default')
            if category in self.category_sample_ratio:
                weights[position] = self.category_sample_ratio[category]
            else:
                # keep the default weight of 1.0 and report once per category below
                label = str(category)
                unmatched[label] = unmatched.get(label, 0) + 1

        # One line per unknown category instead of one line per sample.
        for label, count in unmatched.items():
            print(f"Warning: Category {label} not found in category_sample_ratio "
                  f"({count} samples keep weight 1.0)")
        return weights

    def __iter__(self):
        """Return indices sampled according to category weights using PyTorch's WeightedRandomSampler"""
        return iter(self.sampler)

    def __len__(self):
        """Return the number of samples in the data source."""
        return len(self.data_source)
