import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.compose import TransformedTargetRegressor
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import Ridge, Lasso
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
import seaborn as sns
from scipy.optimize import minimize
import heapq

def load_data(file_path):
    """Load the experiment table and keep only rows with a valid ratio mix.

    The file is read as UTF-16, whitespace-separated text whose first line is
    a header. Its seven columns are renamed positionally to
    INDEX, CODE, MATH, GENERAL, 平均RL-LOSS, 最大RL-LOSS and PPL.

    Rows whose CODE + MATH + GENERAL is not within 0.01 of 1.0 are dropped;
    previously the filtered frame was computed but the unfiltered one was
    returned, so invalid rows leaked into training.

    Args:
        file_path: Path of the data file to read.

    Returns:
        A new DataFrame with the columns CODE, MATH, GENERAL, 平均RL-LOSS,
        最大RL-LOSS and PPL, restricted to the valid rows. The original row
        labels are kept.

    Raises:
        ValueError: Raised by pandas if the file does not have exactly
            seven columns.
    """
    df = pd.read_csv(file_path, sep=r'\s+', engine='python', encoding='utf-16')
    # Column names are assigned by position; PPL is the last column.
    df.columns = ['INDEX', 'CODE', 'MATH', 'GENERAL', '平均RL-LOSS', '最大RL-LOSS', 'PPL']

    # The three ratios describe a mixture, so they must sum to (about) 1.
    ratio_sum = df['CODE'] + df['MATH'] + df['GENERAL']
    is_valid = np.isclose(ratio_sum, 1.0, atol=0.01)

    wanted = ['CODE', 'MATH', 'GENERAL', '平均RL-LOSS', '最大RL-LOSS', 'PPL']
    return df.loc[is_valid, wanted].copy()

def train_model(X, y):
    """Fit and return a gradient-boosting model on a log1p-transformed target.

    The features pass through a StandardScaler and then a
    GradientBoostingRegressor. The target is transformed with ``np.log1p``
    for fitting and mapped back with ``np.expm1`` at prediction time.

    Args:
        X: Feature table used for fitting.
        y: Target values used for fitting.

    Returns:
        The fitted TransformedTargetRegressor.
    """
    booster = GradientBoostingRegressor(
        random_state=42,
        learning_rate=0.05,
        max_depth=4,
        n_estimators=200,
    )
    steps = [
        ('scaler', StandardScaler()),
        ('regressor', booster),
    ]

    model = TransformedTargetRegressor(
        regressor=Pipeline(steps),
        func=np.log1p,
        inverse_func=np.expm1,
    )
    model.fit(X, y)
    return model

def find_extreme_points(model, coarse_step=0.01, top_k=50):
    """Grid-search the ratio simplex for the lowest and highest predictions.

    CODE and MATH each take the values ``0.0001, 0.0001 + coarse_step, ...``
    below 1.0; GENERAL is set to ``1 - CODE - MATH`` and combinations with a
    negative GENERAL are skipped. All remaining points are predicted in a
    single batched ``model.predict`` call instead of one call per point.

    Args:
        model: Object with a ``predict`` method that accepts a DataFrame
            with the columns CODE, MATH, GENERAL and returns one value per
            row.
        coarse_step: Spacing of the CODE and MATH grids.
        top_k: How many best and how many worst points to return.

    Returns:
        A pair ``(best_points, worst_points)``. Each is a list of at most
        ``top_k`` tuples ``(pred, code, math, general)`` of floats;
        ``best_points`` is sorted by ascending prediction and
        ``worst_points`` by descending prediction. Points with equal
        predictions keep their grid scan order (CODE outer, MATH inner).
    """
    grid = np.arange(0.0001, 1.0, coarse_step)

    # indexing='ij' followed by ravel() enumerates the pairs with CODE as the
    # outer loop and MATH as the inner loop.
    code_mesh, math_mesh = np.meshgrid(grid, grid, indexing='ij')
    code_vals = code_mesh.ravel()
    math_vals = math_mesh.ravel()
    general_vals = 1.0 - code_vals - math_vals

    # Keep only points that lie on the simplex (non-negative GENERAL share).
    feasible = general_vals >= 0
    code_vals = code_vals[feasible]
    math_vals = math_vals[feasible]
    general_vals = general_vals[feasible]

    total_points = int(code_vals.size)
    print(f"总网格点数: {total_points}")

    # Nothing to predict on an empty grid; avoid calling predict with no rows.
    if total_points == 0:
        return [], []

    X_grid = pd.DataFrame({
        'CODE': code_vals,
        'MATH': math_vals,
        'GENERAL': general_vals,
    })
    preds = np.asarray(model.predict(X_grid))

    points = list(zip(
        preds.tolist(),
        code_vals.tolist(),
        math_vals.tolist(),
        general_vals.tolist(),
    ))

    # nsmallest/nlargest return their results already sorted by the key.
    best_points = heapq.nsmallest(top_k, points, key=lambda point: point[0])
    worst_points = heapq.nlargest(top_k, points, key=lambda point: point[0])

    return best_points, worst_points

def main():
    """Fit one model per target metric and report its extreme grid points.

    Loads ``1M_data_config.txt``, then for each of 平均RL-LOSS, 最大RL-LOSS
    and PPL trains a model on the CODE/MATH/GENERAL ratios, searches the
    grid for the lowest and highest predictions, and prints them. A
    comparison table covering all targets is printed at the end.
    """
    targets = ['平均RL-LOSS', '最大RL-LOSS', 'PPL']
    data = load_data('1M_data_config.txt')
    features = data[['CODE', 'MATH', 'GENERAL']]

    rule = '=' * 50
    # Per-target results, reused for the final comparison table.
    summary = {}

    # Model and optimise each target independently.
    for target in targets:
        print(f"\n{rule}")
        print(f"📌   正在拟合模型预测 {target} ...")
        print(rule)

        fitted = train_model(features, data[target])
        best_points, worst_points = find_extreme_points(fitted)
        summary[target] = {
            'best_points': best_points,
            'worst_points': worst_points,
        }

        # Lowest predictions first.
        print(f"\n✅   [{target}] 最佳点 (最小):")
        for rank, (value, code_r, math_r, general_r) in enumerate(best_points, start=1):
            print(f"  第{rank}佳: CODE={code_r:.4f}, MATH={math_r:.4f}, GENERAL={general_r:.4f}, {target}={value:.6f}")

        # Then the highest predictions.
        print(f"\n⚠️   [{target}] 最差点 (最大):")
        for rank, (value, code_r, math_r, general_r) in enumerate(worst_points, start=1):
            print(f"  第{rank}差: CODE={code_r:.4f}, MATH={math_r:.4f}, GENERAL={general_r:.4f}, {target}={value:.6f}")

    # Comparison tables; the dict preserves the order in which targets were added.
    table_header = "排名 |    CODE    |   MATH    |  GENERAL  |     值"
    sections = [
        ("\n最佳点 (最小):", 'best_points'),
        ("\n最差点 (最大):", 'worst_points'),
    ]
    print("\n📊   结果对比总结:")
    for target, result in summary.items():
        print(f"\n==== {target} ====")
        for heading, key in sections:
            print(heading)
            print(table_header)
            for rank, (value, code_r, math_r, general_r) in enumerate(result[key], start=1):
                print(f" {rank}   | {code_r:.6f} | {math_r:.6f} | {general_r:.6f} | {value:.6f}")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
