from split import get_cot,ave_length
from agent_split import split_agent
import os
import sys
from tqdm import tqdm
import matplotlib.pyplot as plt
import argparse
import json
import torch
import gc

# Make PROJECT_ROOT importable (presumably where the `env` module lives).
# This must run before the `from env import ...` line below, so the
# statement order here matters.
project_root = os.environ.get("PROJECT_ROOT")
if project_root and project_root not in sys.path:
    sys.path.append(project_root)
from env import Reward_Skywork,Reward_Prm,Gsm8kDataset,MathDataset

# Base directory for datasets and model checkpoints.
# NOTE(review): if DATA_ROOT is unset this is None, and every f-string path
# built from it in this file becomes "None/..." instead of failing early.
data_root = os.environ.get("DATA_ROOT")

def split_reward_argparser(argv=None):
    """Build the command-line parser for this script and parse arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            parse ``sys.argv[1:]``, which preserves the original behaviour.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Score split chain-of-thought steps with a reward model "
                    "and plot per-step reward statistics."
    )
    parser.add_argument("--data_path", type=str, default="simple_split/gsm8k_qwen7b_3_3,3.log",
                        help="Answer log file, relative to ./raw_answer/.")
    # choices come from the registries so an unknown name is rejected here,
    # instead of surfacing later as a bare KeyError.
    parser.add_argument("--dataset", type=str, default="gsm8k",
                        choices=sorted(dataset_path),
                        help="Dataset whose prompts are paired with the answers.")
    parser.add_argument("--reward", type=str, default="prm",
                        choices=sorted(reward_path),
                        help="Reward model used to score each step.")
    parser.add_argument("--agent_path", type=str, default=f"{data_root}/Qwen2.5-7B-Instruct/",
                        help="Model path forwarded to split_agent.")
    parser.add_argument("--cuda_id", type=int, default=0,
                        help="GPU index for the reward model and split_agent.")
    parser.add_argument("--n", type=int, default=20,
                        help="Number of step columns in the statistics/plot.")
    parser.add_argument("--save_path", type=str, default="simple_split/gsm8k_qwen7b_3_3,3_prm_20_agent.png",
                        help="Only its extension-less form is used, to name "
                             "the output directory under ./results/.")
    parser.add_argument("--split_method", type=str, default="agent",  # or "original", "new"
                        help="Splitting strategy forwarded to split_agent.")
    return parser.parse_args(argv)

# Dataset name -> dataset class (instantiated in __main__ with its path).
dataset = dict(
    gsm8k=Gsm8kDataset,
    math=MathDataset,
)

# Dataset name -> on-disk dataset location under data_root.
dataset_path = dict(
    gsm8k=f"{data_root}/gsm8k",
    math=f"{data_root}/efficient-reasoning/competition_math",
)

# Reward name -> reward-model wrapper class. Despite the "_path" suffix the
# values are classes, not paths; the name is kept for compatibility.
reward_path = dict(
    skywork=Reward_Skywork,
    prm=Reward_Prm,
)

# Reward name -> checkpoint directory handed to the matching class above.
reward_model_path = dict(
    skywork=f"{data_root}/Skywork-Reward-Llama-3.1-8B-v0.2",
    prm=f"{data_root}/Qwen2.5-Math-PRM-7B",
)

def _fit_to_length(rewards, n):
    """Return a copy of ``rewards`` with exactly ``n`` entries.

    Shorter lists are right-padded with ``None`` ("no step here"), so padding
    can never be confused with a genuine reward of 0. Longer lists keep their
    first ``n - 1`` entries and collapse all remaining steps into one final
    entry holding their mean.
    """
    if len(rewards) < n:
        return list(rewards) + [None] * (n - len(rewards))
    if len(rewards) > n:
        tail = rewards[n - 1:]
        return list(rewards[:n - 1]) + [sum(tail) / len(tail)]
    return list(rewards)


def _column_stats(rows, n):
    """Compute per-step (column) average, count, max and min over ``rows``.

    Each row must have length ``n``; ``None`` entries are padding and are
    skipped. A column with no real entries reports 0 for avg/max/min.
    """
    column_avg = []
    column_counts = []
    column_max = []
    column_min = []
    for j in range(n):
        column = [row[j] for row in rows if row[j] is not None]
        column_counts.append(len(column))
        if column:
            column_avg.append(sum(column) / len(column))
            column_max.append(max(column))
            column_min.append(min(column))
        else:
            column_avg.append(0)
            column_max.append(0)
            column_min.append(0)
    return {
        "column_averages": column_avg,
        "column_counts": column_counts,
        "column_max": column_max,
        "column_min": column_min,
    }


def _plot_stats(stats, n, image_save_path):
    """Save a two-panel plot of per-step reward statistics to ``image_save_path``.

    Left panel: average, max and min per step. Right panel: average only.
    """
    steps = range(1, n + 1)
    fig = plt.figure(figsize=(15, 6))
    try:
        # First subplot: All statistics
        plt.subplot(1, 2, 1)
        plt.plot(steps, stats["column_averages"], marker='o', linestyle='-', linewidth=2, label='Average')
        plt.plot(steps, stats["column_max"], marker='^', linestyle='--', linewidth=1, label='Maximum')
        plt.plot(steps, stats["column_min"], marker='v', linestyle='--', linewidth=1, label='Minimum')
        plt.title('Reward Statistics by Step')
        plt.xlabel('Step')
        plt.ylabel('Reward')
        plt.grid(True)
        plt.legend()

        # Second subplot: Average only
        plt.subplot(1, 2, 2)
        plt.plot(steps, stats["column_averages"], marker='o', linestyle='-', linewidth=2, color='blue')
        plt.title('Average Reward by Step')
        plt.xlabel('Step')
        plt.ylabel('Average Reward')
        plt.grid(True)

        plt.tight_layout()
        plt.savefig(image_save_path)
    finally:
        # Always release the figure, even if saving fails.
        plt.close(fig)


def split_reward(data_path, dataset, agent_path, cuda_id, n, save_path, split_method, reward=None):
    """Score every split CoT step with a reward model and summarise rewards per step.

    Pipeline:
      1. Load prompts from ``dataset`` and raw answers from ``./raw_answer/<data_path>``.
      2. Split each answer into steps with ``split_agent``.
      3. Score every step with the selected reward model.
      4. Write per-step statistics to ``./results/<save_path without extension>/data.json``
         and a plot to ``plot.png`` in the same directory.

    Args:
        data_path: Answer log path, relative to ``./raw_answer/``.
        dataset: Object exposing ``get_prompt()``. Its result is indexed in
            parallel with the loaded answers; this is assumed aligned (see the
            TODO below).
        agent_path: Model path forwarded to ``split_agent``.
        cuda_id: GPU index. The reward model goes on ``cuda:<cuda_id>`` and the
            same id is forwarded to ``split_agent``.
        n: Number of step columns (must be >= 1). Answers with fewer steps are
            padded. For answers with more steps, steps n..end (1-based) are
            averaged into column n.
        save_path: Only its extension-less form is used, to name the output
            directory.
        split_method: Forwarded to ``split_agent``.
        reward: Key into ``reward_path`` / ``reward_model_path`` (e.g. "prm",
            "skywork"). When ``None``, falls back to the module-level
            ``args.reward`` set by the CLI entry point if present, else "prm".

    Returns:
        List of length ``n`` with the average reward per step (0 for steps
        that no answer reached).

    Raises:
        ValueError: if ``n`` < 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n}")
    if reward is None:
        # Backward compatibility: this function used to read the CLI global
        # `args` directly, which raised NameError when imported as a library.
        cli_args = globals().get("args")
        reward = getattr(cli_args, "reward", "prm")

    prompt = dataset.get_prompt()
    cot = get_cot("./raw_answer/" + data_path)
    print(ave_length(cot))

    # Create output directory if it doesn't exist
    output_dir = os.path.join("./results", os.path.splitext(save_path)[0])
    os.makedirs(output_dir, exist_ok=True)

    # Split before loading the reward model, so the reward model is not yet
    # resident on the GPU while split_agent runs (presumably it loads its own
    # model from agent_path onto the same device).
    split_results = split_agent(cot, agent_path, cuda_id, split_method)
    reward_model = reward_path[reward](reward_model_path[reward], f"cuda:{cuda_id}")

    reward_list = []
    for i, steps in enumerate(tqdm(split_results, desc="reward calculating")):
        # TODO: fix user logic -- assumes prompt[i] is the question for cot[i].
        reward_list.append(
            [reward_model(system_prompt="", user=prompt[i], answer=step) for step in steps]
        )

    fitted = [_fit_to_length(row, n) for row in reward_list]
    result_data = _column_stats(fitted, n)

    # Save JSON file in the output directory
    json_save_path = os.path.join(output_dir, "data.json")
    with open(json_save_path, 'w') as f:
        json.dump(result_data, f, indent=4)

    print("Column averages:", result_data["column_averages"])
    print("Column counts:", result_data["column_counts"])
    print("Column max values:", result_data["column_max"])
    print("Column min values:", result_data["column_min"])

    # Save image in the output directory
    _plot_stats(result_data, n, os.path.join(output_dir, "plot.png"))

    return result_data["column_averages"]

if __name__ == "__main__":
    # `args` stays a module-level global because split_reward may read the
    # reward choice (args.reward) from it.
    args = split_reward_argparser()
    # Bind the instance to its own name so the module-level `dataset`
    # registry (name -> class) is not clobbered.
    data = dataset[args.dataset](dataset_path[args.dataset], "better")
    split_reward(args.data_path, data, args.agent_path, args.cuda_id, args.n, args.save_path, args.split_method)