import os
import os.path as osp
import yaml
import pandas as pd
import json
from utils import lang_dict, clean_punctuation
from generation_check import check_quality, extract_zh_from_prompt, extract_tar

# Fixed ASS (Advanced SubStation Alpha) header written at the top of every
# generated .ass file. It declares:
# - a 1920x1080 play resolution;
# - a single "Default" style (Arial, size 20, Alignment 2, i.e. bottom-centre
#   in ASS numpad notation);
# - the [Events] "Format:" line.
# The "Dialogue:" lines written after this header must follow that field
# order. The literal ends with a newline so the first Dialogue line starts on
# its own line.
ass_header = """[Script Info]
; Script generated by SRT to ASS converter
Title: Converted from SRT
ScriptType: v4.00+
PlayResX: 1920
PlayResY: 1080

[V4+ Styles]
Format: Name, Fontname, Fontsize, PrimaryColour, SecondaryColour, OutlineColour, BackColour, Bold, Italic, Underline, StrikeOut, ScaleX, ScaleY, Spacing, Angle, BorderStyle, Outline, Shadow, Alignment, MarginL, MarginR, MarginV, Encoding
Style: Default,Arial,20,&H00FFFFFF,&H000000FF,&H00000000,&H00000000,-1,0,0,0,100,100,0,0,1,1.5,0,2,10,10,10,1

[Events]
Format: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text
"""

if __name__ == '__main__':
    # Build per-episode .ass subtitle files and .csv tables from the model's
    # predictions (generated_predictions.jsonl) on the translation test set.
    #
    # sys.exit is used instead of the site-provided `exit`, which is not
    # guaranteed to exist (e.g. under `python -S` or in frozen builds).
    import sys

    dirname = osp.dirname(osp.abspath(__file__))
    config_path = osp.join(dirname, '..', 'config.yaml')
    with open(config_path, 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)
    # threshold_limit, src_lang_str and filter_threshold are not used below.
    # They are still read, in the original order, so a config missing these
    # keys fails exactly as before.
    sft_model = config['sft_model']
    tr_model = config['tr_model']
    trpe_model = config['trpe_model']
    threshold_limit = config['threshold_limit']
    context_len = config['context_len']
    step = config['step']
    lang = config['lang']
    sft_proportion = config['sft_proportion']
    dpo_mode = config['dpo_mode']
    dpo_finetuning_type = config['dpo_finetuning_type']
    src_lang_str = lang_dict[lang.split('2')[0]]
    lang_str = lang_dict[lang.split('2')[1]]
    filter_threshold = config['filter_threshold']
    evaluation_mode = config['evaluation_mode']

    if evaluation_mode:
        sys.exit()

    dataset_path = osp.join(dirname, '..', 'LLaMAFactory', 'data', f'translation_test_{tr_model}_{trpe_model}_{lang}.json')
    info_dir = osp.join(dirname, 'info', lang)
    if sft_proportion == 1.0:
        output_name = f'sft_{sft_model}_{tr_model}_{trpe_model}'
    else:
        output_name = f'dpo_{sft_model}_{tr_model}_{trpe_model}_{dpo_mode}_{sft_proportion}_{dpo_finetuning_type}'
    output_dir = osp.join(dirname, '..', 'Inference', lang, output_name)
    infer_path = osp.join(output_dir, 'generated_predictions.jsonl')

    # Load inputs with context managers so every file handle is closed.
    with open(dataset_path, 'r', encoding='utf-8') as file:
        dataset_list = json.load(file)
    infer_list = []
    with open(infer_path, 'r', encoding='utf-8') as file:
        for raw_line in file:
            raw_line = raw_line.strip()
            if raw_line:
                infer_list.append(json.loads(raw_line))
    # Predictions are paired with samples by position. Fail fast with a clear
    # message instead of an IndexError partway through the loop. Extra
    # predictions are ignored, as before.
    if len(infer_list) < len(dataset_list):
        raise ValueError(
            f'{infer_path} has {len(infer_list)} predictions but '
            f'{dataset_path} has {len(dataset_list)} samples'
        )

    skip_path = osp.join(output_dir, 'final_exception_indexes.json')
    # Sample indexes to drop. A set gives O(1) membership tests per sample.
    skip_index = set()
    if osp.exists(skip_path):
        with open(skip_path, 'r', encoding='utf-8') as file:
            skip_index = set(json.load(file))

    # play name -> episode -> highest 'original line' number seen.
    episode_maxline_mapping = {}
    # play name -> episode -> line number -> [start, end, prompt_zh, pred_tar].
    episode_subtitle_dict = {}

    for sample_idx, (data, prediction) in enumerate(zip(dataset_list, infer_list)):
        play_name = data['play name']
        episode = data['episode']
        start_timestamp = data['source start timestamp']
        end_timestamp = data['source end timestamp']
        original_line = data['original line']

        episode_max = max(original_line)
        play_maxlines = episode_maxline_mapping.setdefault(play_name, {})
        if episode in play_maxlines:
            episode_max = max(play_maxlines[episode], episode_max)
        play_maxlines[episode] = episode_max

        # Created even when this sample is skipped below, so every episode in
        # the dataset still gets output files (padded with placeholder rows).
        episode_subtitles = episode_subtitle_dict.setdefault(play_name, {}).setdefault(episode, {})

        prompt = data["instruction"]
        predict = prediction["predict"]

        # check_quality is only called for samples that are not in skip_index.
        if sample_idx in skip_index or not check_quality(prompt, predict):
            continue

        prompt_zhs = extract_zh_from_prompt(prompt)
        pred_tars = extract_tar(predict)

        # A window of exactly step + 2 * context_len lines keeps only its
        # central `step` entries (the first and last context_len are dropped,
        # presumably because they are context). Any other length is kept whole.
        if len(original_line) == step + 2 * context_len:
            begin_idx = context_len
            end_idx = step + context_len
        else:
            begin_idx = 0
            end_idx = len(original_line)

        # Slicing a range has the same semantics as slicing list(range(...)).
        for j in range(len(original_line))[begin_idx:end_idx]:
            episode_subtitles[original_line[j]] = [start_timestamp[j], end_timestamp[j], prompt_zhs[j], pred_tars[j]]

    clean_zh = '2zh' in lang
    for play_name, episodes in episode_subtitle_dict.items():
        for episode, subtitle_dict in episodes.items():
            max_line = episode_maxline_mapping[play_name][episode]
            subtitle_list = []
            for line_no in range(max_line + 1):
                entry = subtitle_dict.get(line_no)
                if entry is None:
                    # Placeholder row for a line with no accepted prediction.
                    subtitle_list.append([0, 0, '', ''])
                    continue
                start_ts, end_ts, prompt_zh, pred_tar = entry
                if clean_zh:
                    pred_tar = clean_punctuation(pred_tar)
                subtitle_list.append([start_ts, end_ts, prompt_zh, pred_tar])

            subtitle_path = osp.join(output_dir, 'ass', play_name)
            os.makedirs(subtitle_path, exist_ok=True)
            subtitle_file = osp.join(subtitle_path, f'{play_name} {episode}_{lang_str}.ass')
            with open(subtitle_file, 'w', encoding='utf-8') as file:
                file.write(ass_header)
                for start_ts, end_ts, _, pred_tar in subtitle_list:
                    # A raw newline would split the Dialogue event across two
                    # lines and corrupt the file. ASS marks hard breaks as \N.
                    text = str(pred_tar).replace('\r\n', '\\N').replace('\n', '\\N')
                    file.write(f"Dialogue: 0,{start_ts},{end_ts},Default,,0,0,0,,{text}\n")

            csv_path = osp.join(output_dir, 'csv', play_name)
            os.makedirs(csv_path, exist_ok=True)
            csv_file = osp.join(csv_path, f'{play_name} {episode}_{lang_str}.csv')
            # Columns: start time, end time, source text, translation.
            df = pd.DataFrame(subtitle_list, columns=['开始时间', '结束时间', '原文', '译文'])
            # utf-8-sig writes a BOM so spreadsheet tools detect UTF-8.
            df.to_csv(csv_file, index=False, encoding='utf-8-sig')