import argparse
import json
import os
from typing import List, Optional, Tuple

import pandas as pd
import yaml
from jinja2 import Template
from tqdm import tqdm

from decision_oaif.utils.openai import generate_from_openai_completion
from decision_oaif.utils.parser import parse_json

def correct_student_trajectory(student_trajectory: List[dict], privileged_state: str, correction_oracle_template: Template, model: str = "gpt-4o") -> Tuple[Optional[dict], float, str]:
    """Ask the correction oracle to rewrite a student trajectory.

    The template is rendered twice: once with ``system=True`` to produce the
    system prompt, and once with ``system=False`` plus the trajectory and the
    privileged state to produce the user prompt. Both are sent as a two-message
    chat to ``generate_from_openai_completion`` and the reply is passed through
    ``parse_json``.

    Args:
        student_trajectory: Per-timestep records handed to the template.
        privileged_state: Privileged information handed to the template.
        correction_oracle_template: Jinja2 template that renders both prompts.
        model: Model name forwarded to ``generate_from_openai_completion``.

    Returns:
        A ``(correction, cost, response)`` triple: the parsed reply, the cost
        reported by the completion helper, and the unparsed reply.
        NOTE(review): the caller in this file treats ``correction`` as possibly
        ``None`` and otherwise indexes it with ``'summary'`` / ``'trajectory'``;
        confirm against ``parse_json``.
    """
    system_prompt = correction_oracle_template.render(system=True)
    input_prompt = correction_oracle_template.render(
        system=False,
        student_trajectory=student_trajectory,
        privileged_state=privileged_state,
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": input_prompt},
    ]

    response, cost = generate_from_openai_completion(
        messages=messages, model=model
    )

    correction = parse_json(response=response)
    # The raw response is returned as well so callers can log it when parsing
    # or validation fails.
    return correction, cost, response

def check_corrected_trajectory(corrected_trajectory, original_trajectory):
    """Validate a corrected trajectory and backfill it from the original, in place.

    Both trajectories must have the same length, and every corrected step must
    contain ``corrected_reason`` and ``corrected_action``; otherwise a
    ``ValueError`` is raised. For each step, the original observation, reason
    and action are copied in under ``original_*`` keys when the corrected step
    does not already carry them. When the original step has
    ``candidate_actions``, they are copied over if missing, and a step that has
    an ``is_corrected`` flag but whose corrected action is not among the
    candidates gets that flag set to ``False``.

    Returns ``None``; ``corrected_trajectory``'s steps are mutated.
    """
    n_corrected = len(corrected_trajectory)
    n_original = len(original_trajectory)
    if n_corrected != n_original:
        raise ValueError(f"Len of corrected trajectory {len(corrected_trajectory)} != len of original trajectory {len(original_trajectory)}")

    required_keys = ("corrected_reason", "corrected_action")
    # (key on the corrected step, key on the original step), in fill order.
    backfill = (
        ('original_observation', 'observation'),
        ('original_reason', 'reason'),
        ('original_action', 'action'),
    )

    for idx, (traj_step, source_step) in enumerate(zip(corrected_trajectory, original_trajectory)):
        for key in required_keys:
            if key not in traj_step:
                raise ValueError(f"Trajectory index {idx} does not have {key} ")

        for target_key, source_key in backfill:
            # Only touch the original step when the value is actually needed.
            if target_key not in traj_step:
                traj_step[target_key] = source_step[source_key]

        if 'candidate_actions' not in source_step:
            continue

        if 'candidate_actions' not in traj_step:
            traj_step['candidate_actions'] = source_step['candidate_actions']

        # A correction that picks an action outside the candidate set is not
        # counted as a correction.
        has_flag = 'is_corrected' in traj_step
        if has_flag and traj_step['corrected_action'] not in traj_step['candidate_actions']:
            traj_step['is_corrected'] = False

def process_logs(log_files: List[str], correction_oracle_template: Template, id_to_privileged_state: dict, log_dir: str, output_log_dir: str, id_field_name: str, max_retries: int = 3) -> None:
    """Correct each logged trajectory with the oracle and write the results.

    For every file name in ``log_files`` the JSON log is read from ``log_dir``,
    its trajectory is sent to ``correct_student_trajectory`` together with the
    privileged state looked up by ``str(log[id_field_name])``, and the reply is
    validated with ``check_corrected_trajectory``. On success the log, with its
    ``summary`` and ``trajectory`` replaced, is written under the same file
    name in ``output_log_dir``. A file whose attempts all fail is reported on
    stdout and skipped.

    The progress bar description shows the cumulative oracle cost and a
    projection over all files. Every attempt that returns contributes its cost,
    including attempts whose reply is later rejected.

    Args:
        log_files: File names (relative to ``log_dir``) of the logs to correct.
        correction_oracle_template: Template passed to the oracle call.
        id_to_privileged_state: Mapping from stringified log id to the
            privileged state for that log.
        log_dir: Directory the logs are read from.
        output_log_dir: Directory the corrected logs are written to.
        id_field_name: Key in each log that holds its id.
        max_retries: Maximum number of oracle attempts per file.

    Raises:
        KeyError: If a log lacks ``id_field_name`` or ``'trajectory'``, a
            trajectory step lacks ``'observation'``, ``'reason'`` or
            ``'action'``, or the id has no entry in ``id_to_privileged_state``.
    """
    cumulative_cost = 0
    pbar = tqdm(log_files, desc="Process log files")

    for file_idx, filename in enumerate(pbar):
        filepath = os.path.join(log_dir, filename)
        with open(filepath, 'r') as file:
            log = json.load(file)

        log_id, original_trajectory = log[id_field_name], log['trajectory']

        student_trajectory = [
            {
                'timestep': timestep,
                'observation': datapoint['observation'],
                'candidate_actions': datapoint.get('candidate_actions', None),
                'reason': datapoint['reason'],
                'action': datapoint['action'],
            }
            for timestep, datapoint in enumerate(original_trajectory)
        ]

        privileged_state = id_to_privileged_state[str(log_id)]

        summary = None
        corrected_trajectory = None
        for attempt in range(1, max_retries + 1):
            # Reset per attempt so a failure inside the oracle call never
            # prints a stale (or unbound) response from an earlier attempt.
            response = None
            try:
                correction, cost, response = correct_student_trajectory(
                    student_trajectory=student_trajectory,
                    privileged_state=privileged_state,
                    correction_oracle_template=correction_oracle_template,
                )
                # Count the cost as soon as the call returns: rejected replies
                # were still paid for.
                cumulative_cost += cost

                if correction is None:
                    raise ValueError("Failed to parse a trajectory from response")

                candidate_summary = correction['summary']
                candidate_trajectory = correction['trajectory']
                check_corrected_trajectory(candidate_trajectory, original_trajectory)
            except Exception as e:
                # Broad on purpose: any failure in the call, the parsing or the
                # validation triggers another attempt rather than aborting.
                if response is not None:
                    print(response)
                print(f"Error processing {filename}: {e}. Retrying {attempt}/{max_retries}...")
            else:
                summary = candidate_summary
                corrected_trajectory = candidate_trajectory
                break

        if corrected_trajectory is not None:
            corrected_log = log
            corrected_log['summary'] = summary
            corrected_log['trajectory'] = corrected_trajectory
            output_filepath = os.path.join(output_log_dir, filename)
            with open(output_filepath, 'w') as file:
                json.dump(corrected_log, file, indent=4)
        else:
            print(f"Failed to correctly generate trajectory for {filename}")

        average_cost_per_file = cumulative_cost / (file_idx + 1)
        projected_total_cost = average_cost_per_file * len(log_files)
        pbar.set_description(f"Cost: ${cumulative_cost:.2f}, Projected: ${projected_total_cost:.2f}")

def load_config(file_path: str, iter: int):
    """Load the ``correct_student_trajectory`` section of a YAML config.

    Every top-level string value in that section has the literal placeholder
    ``{iter}`` replaced by ``str(iter)``; values of other types are returned
    untouched.

    Args:
        file_path: Path to the YAML config file.
        iter: Iteration number substituted for ``{iter}``.

    Returns:
        The section as a dict with placeholders resolved.
    """
    with open(file_path, 'r') as handle:
        document = yaml.safe_load(handle)

    section = document['correct_student_trajectory']
    return {
        name: (setting.replace('{iter}', str(iter)) if isinstance(setting, str) else setting)
        for name, setting in section.items()
    }
    
def main():
    """Command-line entry point: correct the low-scoring logs of one iteration.

    Parses ``--config`` and ``--iter``, loads the config section via
    ``load_config``, reads the prompt template and the privileged-state JSON,
    selects the rows of ``summary.csv`` in the log directory whose ``score`` is
    at most the configured threshold, and passes the matching
    ``<env_idx>.json`` file names to ``process_logs``.
    """
    cli = argparse.ArgumentParser(description="Correct student trajectory.")
    cli.add_argument('--config', type=str, required=True, help='Path to dataproc config file')
    cli.add_argument('--iter', type=int, required=True, help='Iteration number')
    cli_args = cli.parse_args()

    settings = load_config(cli_args.config, cli_args.iter)

    # Read every required setting up front so a missing key fails before any
    # directory is created or file is opened.
    required_settings = (
        'log_dir',
        'output_log_dir',
        'privileged_state_file',
        'prompt_file',
        'id_field_name',
        'correct_score_threshold',
    )
    (
        log_dir,
        output_log_dir,
        privileged_state_file,
        prompt_file,
        id_field_name,
        correct_score_threshold,
    ) = [settings[name] for name in required_settings]

    os.makedirs(output_log_dir, exist_ok=True)

    with open(prompt_file, "r") as prompt_handle:
        prompt_source = prompt_handle.read()
    correction_oracle_template = Template(prompt_source)

    with open(privileged_state_file, "r") as state_handle:
        id_to_privileged_state = json.load(state_handle)

    summary_path = os.path.join(log_dir, 'summary.csv')
    summary_table = pd.read_csv(summary_path)
    # Keep only the runs scoring at or below the threshold.
    needs_correction = summary_table['score'] <= correct_score_threshold
    env_indices = summary_table[needs_correction]['env_idx'].tolist()
    log_files = [f"{idx}.json" for idx in env_indices]

    process_logs(log_files, correction_oracle_template, id_to_privileged_state, log_dir, output_log_dir, id_field_name)

# Run the correction pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()