from .criteria_dataloder import CriteriaDataLoder
from .traj_data import PairedData, Traj
import os, json
import json

class DpoPool:
    def __init__(self, dpo_pool_file: str):
        assert isinstance(dpo_pool_file, str)
        self.dpo_pool_file = dpo_pool_file
        
        self.init_status = False

        self._pool = dict()

        self._load()
        # raise NotImplementedError

    def _load(self):
        if self.init_status:
            raise RuntimeError(f"DpoPool has been initialized. ")
        for line in open(self.dpo_pool_file):
            item = json.loads(line)
            sample_id = item['id']
            assert sample_id not in self._pool, f"{sample_id} is duplicated in the dpo pool file: [{self.dpo_pool_file}]. check it. !"
            self._pool[sample_id] = item
        self.init_status = True
        print(f"DpoPool has loaded from dpo pool file: [{self.dpo_pool_file}]")

    def get_rejected_item(self, sample_id: str) -> Traj:
        assert isinstance(sample_id, str)
        # get the Traj from reject
        
        dpo_data_sample = self._pool.get(sample_id, None)
        if dpo_data_sample is None:
            return None
        return dpo_data_sample['rejected_item']



class ExplorationDataLoader(CriteriaDataLoder):
    """Data loader for the exploration stage.

    Builds ``Traj`` objects either from a raw paired dataset
    (``load_criteria_iter``) or from previously produced trajectories that are
    to be judged (``load_judge_iter``). Both iterators can resume from a
    recover file and honour ``self.args.max_input_size`` (``-1`` = unlimited).
    """

    def __init__(self, args):
        # NOTE(review): `self.args` is read below; presumably it is set by
        # CriteriaDataLoder.__init__ — confirm.
        super().__init__(args)

    def _recover(self, recover_file: str) -> set:
        """Return the ids of samples already present in ``recover_file``.

        Each non-blank line is parsed as JSON and its
        ``['paired_data']['id']`` is collected.
        """
        skip_id = set()
        with open(recover_file, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                skip_id.add(json.loads(line)['paired_data']['id'])
        return skip_id

    def _iter_pending(self, dataset_file: str, recover_file: str, get_id):
        """Yield parsed JSON lines of ``dataset_file`` that still need work.

        Args:
            dataset_file: JSON-lines input file.
            recover_file: Output file of a previous run; samples whose id is
                found there are skipped. Ignored if it does not exist.
            get_id: Callable extracting the sample id from a parsed line.

        Recovered samples count towards ``self.args.max_input_size``: once that
        many samples have been handled in total (recovered + yielded here),
        iteration stops. ``max_input_size == -1`` disables the limit.
        """
        skip_id = self._recover(recover_file) if os.path.exists(recover_file) else set()
        max_input_size = self.args.max_input_size
        handled_num = len(skip_id)
        with open(dataset_file, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    continue
                json_item = json.loads(line)
                if get_id(json_item) in skip_id:
                    continue
                # `>=` so at most `max_input_size` samples are handled; the
                # previous `>` let one extra sample through.
                if max_input_size != -1 and handled_num >= max_input_size:
                    return
                handled_num += 1
                yield json_item

    def load_language_criteria_iter(self, json_item, dpo_pool: DpoPool | None) -> Traj | None:
        """Build an exploration-criteria ``Traj`` for one text-only sample.

        Args:
            json_item: Parsed dataset line; reads ``id``, ``suffix``,
                ``conversations[0]['content']``, ``chosen['content']`` and
                ``rejected['content']``.
            dpo_pool: If given, the criteria list and answer are taken from the
                pool's rejected item for this sample id.

        Returns:
            The built ``Traj``, or ``None`` when ``dpo_pool`` is given but has
            no entry for this sample id.
        """
        p_data = PairedData(dict(
            id=json_item['id'],
            suffix=json_item['suffix'],
            query={"content": json_item['conversations'][0]['content']},
            chosen={"content": json_item['chosen']['content']},
            rejected={"content": json_item['rejected']['content']}
        ))
        if dpo_pool is not None:
            rejected_item = dpo_pool.get_rejected_item(p_data.id)
            if rejected_item is None:
                return None
            # Sanity checks: the pooled trajectory must belong to this sample
            # and carry a non-empty criteria list.
            assert rejected_item['traj']['paired_data']['id'] == p_data.id
            assert isinstance(rejected_item['traj']['criteria_list'], list)
            assert len(rejected_item['traj']['criteria_list']) != 0

            # Reuse the criteria and answer of the rejected trajectory.
            t_data = Traj()
            t_data.loads({
                "paired_data": p_data,
                "criteria_list": rejected_item['traj']['criteria_list'],
                "answer": rejected_item['traj']['answer'],
            })
            t_data.build_exploration_criteria_conversaion()
        else:
            t_data = Traj()
            t_data.loads({
                "paired_data": p_data,
            }, shuffle=True)
            t_data.build_exploration_criteria_conversaion()
        return t_data

    def load_criteria_iter(self, dataset_file: str, recover_file: str, dpo_pool: DpoPool | None):
        """Yield one item per pending sample of ``dataset_file``.

        Samples already present in ``recover_file`` are skipped and the total
        is capped by ``self.args.max_input_size``. Text-only samples whose
        ``load_language_criteria_iter`` result is ``None`` are reported and
        skipped (they still count towards the cap).
        """
        for json_item in self._iter_pending(dataset_file, recover_file, lambda item: item['id']):
            conversations = json_item['conversations']
            # NOTE(review): `conversations` is indexed as a list of messages in
            # load_language_criteria_iter, so these tests compare the strings
            # against list elements; confirm whether top-level keys of
            # `json_item` were meant instead.
            if "images" in conversations:
                yield self.load_images_iter(json_item)
            elif "videos" in conversations:
                yield self.load_videos_iter(json_item)
            elif "audios" in conversations:
                yield self.load_audio_iter(json_item)
            else:
                traj = self.load_language_criteria_iter(json_item, dpo_pool)
                if traj is None:
                    print(f"[{json_item['id']}] has no Rejected Response. !!!!!!!!")
                else:
                    yield traj

    def load_judge_iter(self, dataset_file, recover_file, manner: str):
        """Yield a judge ``Traj`` for each pending trajectory in ``dataset_file``.

        Lines are keyed by ``['paired_data']['id']`` for recovery; the total is
        capped by ``self.args.max_input_size``.
        """
        for json_item in self._iter_pending(dataset_file, recover_file, lambda item: item['paired_data']['id']):
            yield self._load_judge_iter(json_item, manner)

    def _load_judge_iter(self, json_item, manner: str):
        """Rebuild a ``Traj`` from a stored trajectory and set up its judge prompt.

        Reads ``paired_data``, ``answer`` and ``criteria_list`` from
        ``json_item``. ``manner`` is accepted but currently unused here.
        """
        p_data = PairedData(json_item['paired_data'])

        t_data = Traj()
        t_data.loads({
            "paired_data": p_data,
            "answer": json_item['answer'],
            "criteria_list": json_item['criteria_list']
        }, shuffle=False)
        t_data.build_exploration_judge_conversaion()
        return t_data
