import concurrent.futures
import pandas as pd
import json
from common_io import table as common_io_table

class ODPSHelper:
    """Move data between an ODPS table and pandas / JSON(L) files via ``common_io``.

    ``read`` downloads ``table_name`` into a DataFrame using a process pool,
    one table slice per worker. The ``write_*`` methods upload lists of dicts
    (or JSON / JSONL files) into a table in batches, printing progress.
    """

    # Rows buffered per ``TableWriter.write`` call; progress is printed per batch.
    _WRITE_BATCH_SIZE = 100

    def __init__(self, table_name):
        # Table read by ``read`` and written by the ``write_*_jsonl`` methods.
        self.table_name = table_name

    @staticmethod
    def _rows_to_columns(rows, columns):
        """Turn positional row tuples into ``{column_name: pd.Series}``.

        Column ``columns[i]`` is filled from ``row[i]`` of every row, in order;
        ``bytes`` cells are decoded with ``bytes.decode()`` (UTF-8 by default).
        """
        return {
            name: pd.Series([
                row[i].decode() if isinstance(row[i], bytes) else row[i]
                for row in rows
            ])
            for i, name in enumerate(columns)
        }

    @staticmethod
    def _download(slice_count, slice_id, tableName, selected_cols):
        """Read one slice of ``tableName`` into a DataFrame.

        Args:
            slice_count: Total number of slices the table is split into.
            slice_id: Index of the slice to read.
            tableName: ODPS table to read.
            selected_cols: Comma-separated column names, or ``""`` to take every
                column listed by the reader's schema.

        Returns:
            A DataFrame with one column per selected column, in that order.
        """
        reader = common_io_table.TableReader(tableName,
                                             selected_cols=selected_cols,
                                             slice_id=slice_id,
                                             slice_count=slice_count)
        try:
            total = reader.get_row_count()
            if selected_cols == "":
                # The column name is the first element of each schema entry.
                columns = [field[0] for field in reader.get_schema()]
            else:
                columns = selected_cols.split(',')
            # Skip the read when this slice holds no rows; there is nothing to fetch.
            if total > 0:
                data = reader.read(num_records=total, allow_smaller_final_batch=True)
            else:
                data = []
        finally:
            # Release the reader even if the schema lookup or the read fails.
            reader.close()

        series_by_column = ODPSHelper._rows_to_columns(data, columns)
        # Drop the raw rows before building the frame to lower peak memory.
        del data
        return pd.DataFrame(series_by_column)

    def read(self, selected_cols='', num_workers=10, rank=0, world_size=1):
        """Download this caller's share of ``self.table_name`` as one DataFrame.

        The table is split into ``num_workers * world_size`` slices. The caller
        with index ``rank`` reads slices ``rank * num_workers`` through
        ``rank * num_workers + num_workers - 1``, each in its own worker process.

        Args:
            selected_cols: Comma-separated column names, or ``''`` for all columns.
            num_workers: Worker processes (and slices) used by this caller.
            rank: Index of this caller among ``world_size`` cooperating callers.
            world_size: Number of cooperating callers sharing the table.

        Returns:
            The slices concatenated in slice order, re-indexed 0..n-1.
        """
        slice_count = num_workers * world_size
        first_slice = rank * num_workers
        with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
            futures = [
                executor.submit(self._download, slice_count, first_slice + offset,
                                self.table_name, selected_cols)
                for offset in range(num_workers)
            ]
            frames = [future.result() for future in futures]
        # Every slice frame is indexed from 0; renumber so row labels are unique.
        return pd.concat(frames, ignore_index=True)

    @staticmethod
    def _write_rows(rows, write_to_table_name, column_indices):
        """Write row tuples to ``write_to_table_name`` in fixed-size batches.

        ``rows`` may be a lazy iterable. A batch is written every
        ``_WRITE_BATCH_SIZE`` rows and any remainder at the end, with a progress
        line printed after each write. The writer is closed even if producing
        or writing a row raises.

        Args:
            rows: Iterable of tuples, one per table row.
            write_to_table_name: Destination ODPS table.
            column_indices: Column indices passed to ``TableWriter.write``.
        """
        batch_size = ODPSHelper._WRITE_BATCH_SIZE
        writer = common_io_table.TableWriter(write_to_table_name, slice_id=0)
        try:
            batch = []
            for index, record in enumerate(rows):
                batch.append(record)
                if len(batch) == batch_size:
                    writer.write(batch, column_indices)
                    print(f"当前是第{index + 1}条，本次已插入了{len(batch)}条数据到{write_to_table_name}中.")
                    batch = []
            if batch:
                writer.write(batch, column_indices)
                print(f"还有剩余, 本次已插入了{len(batch)}条数据到{write_to_table_name}中.")
        finally:
            writer.close()

    @staticmethod
    def write_train_data_to_table(data_list, write_to_table_name):
        """Write preference rows to columns 0-2 as (instruction, chosen, rejected).

        Args:
            data_list: Iterable of dicts with keys ``instruction``, ``chosen`` and
                ``rejected``.
            write_to_table_name: Destination ODPS table.
        """
        ODPSHelper._write_rows(
            ((row['instruction'], row['chosen'], row['rejected']) for row in data_list),
            write_to_table_name,
            (0, 1, 2),
        )

    @staticmethod
    def write_sft_data_to_table(data_list, write_to_table_name):
        """Write SFT rows to columns 0-1 as (instruction, output).

        Args:
            data_list: Iterable of dicts with keys ``instruction`` and ``output``.
            write_to_table_name: Destination ODPS table.
        """
        ODPSHelper._write_rows(
            ((row['instruction'], row['output']) for row in data_list),
            write_to_table_name,
            (0, 1),
        )

    @staticmethod
    def _iter_jsonl(jsonl_path):
        """Yield the parsed JSON value of every non-blank line of a UTF-8 JSONL file."""
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                if line.strip():
                    yield json.loads(line)

    def write_dpo_jsonl(self, jsonl_path):
        """Upload a JSONL file with ``prompt``/``chosen``/``rejected`` keys to ``self.table_name``.

        The whole file is parsed before anything is written, so a malformed line
        or missing key aborts the upload without writing any rows.
        """
        data_list = [
            {
                'instruction': item['prompt'],
                'chosen': item['chosen'],
                'rejected': item['rejected'],
            }
            for item in self._iter_jsonl(jsonl_path)
        ]
        self.write_train_data_to_table(data_list, self.table_name)

    def write_sft_jsonl(self, jsonl_path):
        """Upload a JSONL file with ``prompt``/``response`` keys to ``self.table_name``.

        The whole file is parsed before anything is written.
        """
        data_list = [
            {'instruction': item['prompt'], 'output': item['response']}
            for item in self._iter_jsonl(jsonl_path)
        ]
        self.write_sft_data_to_table(data_list, self.table_name)

    @staticmethod
    def _dataset_row(row):
        """Map one dataset record to the 7-column table layout.

        The layout is the same for every record (image-dataset format): id,
        question, content, content_type, 选项 (options, stored as a JSON
        string), extra, oss_url (``''`` when the record has no ``oss_url``).
        """
        return (
            row['id'],
            row['question'],
            row['content'],
            row['content_type'],
            # json.dumps keeps its default ensure_ascii=True, so non-ASCII option
            # text is stored as \uXXXX escapes.
            json.dumps(row['选项']),
            row['extra'],
            row.get('oss_url', ''),
        )

    @staticmethod
    def write_dataset_to_table(data_list, write_to_table_name):
        """Write dataset records to columns 0-6 using the ``_dataset_row`` layout.

        Args:
            data_list: Iterable of dicts with keys ``id``, ``question``,
                ``content``, ``content_type``, ``选项``, ``extra`` and optionally
                ``oss_url``.
            write_to_table_name: Destination ODPS table.
        """
        ODPSHelper._write_rows(
            (ODPSHelper._dataset_row(row) for row in data_list),
            write_to_table_name,
            (0, 1, 2, 3, 4, 5, 6),
        )

    @staticmethod
    def _parse_dataset_text(content):
        """Parse dataset text either as one JSON array or as JSONL.

        If the stripped text starts with ``[`` and ends with ``]``, it is parsed
        as a single JSON array. Otherwise each non-blank line is parsed as one
        JSON value. Lines that fail to parse are reported (1-based line number
        and the first 100 characters) and skipped.
        """
        stripped = content.strip()
        if stripped.startswith('[') and stripped.endswith(']'):
            # JSON array format
            return json.loads(stripped)

        # JSONL format. Split on '\n' only (not str.splitlines, which also breaks
        # on characters such as U+2028) so numbering matches the file's lines.
        records = []
        for line_num, raw_line in enumerate(content.split('\n'), 1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"第{line_num}行JSON解析错误: {e}")
                print(f"问题行内容: {line[:100]}...")
        return records

    def write_dataset_jsonl(self, jsonl_path):
        """Upload a UTF-8 dataset file (JSON array or JSONL) to ``self.table_name``."""
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            content = f.read()
        self.write_dataset_to_table(self._parse_dataset_text(content), self.table_name)

if __name__ == '__main__':
    # Example: upload the image-dataset file into an ODPS table.
    # 'xxx' looks like a placeholder; set it to the real destination table name.
    target_table = 'xxx'
    source_file = 'data/dataset/EVADE/evade_dataset_test_image.json'

    helper = ODPSHelper(target_table)
    helper.write_dataset_jsonl(source_file)