import torch
torch.set_default_dtype(torch.float64)

import pickle
import os
import argparse
import numpy as np
from training_all import *
import default_args

# Compute device used for all tensors in this script: the first CUDA GPU when
# one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _load_problem(args):
    """Load the test instance for ``args['probType']`` and choose the solvers to compare.

    Args:
        args: dict of run options. Must contain ``'probType'`` and, for the
            ``simple`` / ``nonconvex`` types, the size keys used to build the
            dataset file name.

    Returns:
        ``(data, solver_list)``: the loaded problem object and the list of
        ``solver_type`` names that will be passed to ``data.opt_solve``.

    Raises:
        NotImplementedError: if ``probType`` is not a supported type.
    """
    prob_type = args['probType']

    if prob_type == 'simple':
        filepath = os.path.join('datasets', 'simple',
                                "random_simple_dataset_var{}_ineq{}_eq{}_ex{}".format(
                                    args['simpleVar'], args['simpleIneq'],
                                    args['simpleEq'], args['simpleEx']))
        # pickle.load can execute arbitrary code: only load locally generated datasets.
        with open(filepath, 'rb') as f:
            return pickle.load(f), ["osqp", "qpth"]

    if prob_type == 'nonconvex':
        filepath = os.path.join('datasets', 'nonconvex',
                                "random_nonconvex_dataset_var{}_ineq{}_eq{}_ex{}".format(
                                    args['nonconvexVar'], args['nonconvexIneq'],
                                    args['nonconvexEq'], args['nonconvexEx']))
        # pickle.load can execute arbitrary code: only load locally generated datasets.
        with open(filepath, 'rb') as f:
            return pickle.load(f), ["ipopt"]

    if prob_type == 'acopf':
        # ACOPF builds its own option dict via config(); the CLI / default_args
        # values assembled in main() are NOT forwarded to load_instance.
        acopf_args = config()
        acopf_args['probType'] = 'acopf'
        # Alternative size previously used here: [30, 10000].
        acopf_args['opfSize'] = [118, 20000]
        data, _result_save_dir, _model_save_dir = load_instance(acopf_args)
        return data, ["pypower"]

    raise NotImplementedError(f"Unsupported problem type: {prob_type}")


def _move_data_to_device(data, device):
    """Move every tensor attribute of ``data`` onto ``device`` and record it in ``data._device``.

    This is best effort: attributes that cannot be reassigned are left where
    they are. For acopf this may be redundant if loading already placed the
    tensors on the device.
    """
    for attr in dir(data):
        if attr.startswith("__"):
            continue
        value = getattr(data, attr)
        if callable(value) or not torch.is_tensor(value):
            continue
        try:
            setattr(data, attr, value.to(device))
        except AttributeError:
            # Read-only attributes (e.g. properties without a setter) keep their tensor as-is.
            pass
    data._device = device


def _run_solvers(data, solver_list, tol=1e-5):
    """Solve ``data.testX`` with each solver in ``solver_list``.

    Returns:
        dict mapping solver name to ``{"sols": float64 tensor on DEVICE,
        "obj_val": mean of data.obj_fn over the test set, "time": the
        parallel_time value returned by data.opt_solve}``.
    """
    results = {}
    for solver_name in solver_list:
        sols, _total_time, parallel_time = data.opt_solve(
            data.testX, solver_type=solver_name, tol=tol
        )
        # as_tensor avoids the copy-and-warn path of torch.tensor when opt_solve
        # already returns a tensor.
        sols_tensor = torch.as_tensor(sols, dtype=torch.float64, device=DEVICE)
        results[solver_name] = {
            "sols": sols_tensor,
            "obj_val": data.obj_fn(sols_tensor).mean().item(),
            "time": parallel_time,
        }
    return results


def _select_best_solver(solver_results):
    """Return the name of the solver with the lowest mean objective value."""
    def objective_key(name):
        obj_val = solver_results[name]["obj_val"]
        # NaN compares False against everything, which would make min() depend on
        # dict order; rank a NaN objective last instead.
        return float("inf") if np.isnan(obj_val) else obj_val

    return min(solver_results, key=objective_key)


def _build_table_rows(data, solver_results, best_solver, violation_threshold):
    """Build one formatted table row per solver, measured against ``best_solver``.

    Every solver, the reference included, reports its own constraint
    violations. The reference's solution/objective errors against itself are
    0 by definition and are written as such.
    """
    X = data.testX.to(DEVICE)
    best_sols = solver_results[best_solver]["sols"]
    zero_error = f"{0:.2e}"

    rows = []
    for solver_name, res in solver_results.items():
        metrics = evaluate_step(data, X, best_sols, res["sols"], violation_threshold)
        if solver_name == best_solver:
            # Self-comparison: pin the error columns to exactly 0 (the 1e-10
            # stabilisers in the relative errors would otherwise leave tiny
            # non-zero values), but keep the real violation figures.
            for key in ("sol_error", "rel_sol_error", "obj_error", "rel_obj_error"):
                metrics[key] = zero_error
        rows.append([
            solver_name,
            f"{metrics['ineq_vio']}",
            f"{metrics['eq_vio']}",
            f"{metrics['vio_ratio']}",
            f"{metrics['sol_error']}",
            f"{metrics['rel_sol_error']}",
            f"{metrics['obj_error']}",
            f"{metrics['rel_obj_error']}",
            f"{res['time']:.4f}",
        ])
    return rows


def main():
    """Benchmark the classical solvers for one problem type and print a comparison table.

    The solver with the lowest mean objective is taken as the reference. Every
    solver is reported with its own constraint violations; non-reference
    solvers are additionally reported with their errors relative to the
    reference.
    """
    parser = argparse.ArgumentParser(description='Optimization Solver')
    parser.add_argument('--probType', type=str, default='acopf',
                        choices=['simple', 'nonconvex', 'acopf'])
    # Unrecognised CLI flags are tolerated and ignored rather than rejected.
    args, _unknown = parser.parse_known_args()
    args = vars(args)

    # Fill in hyperparameters from default_args for anything not set on the CLI.
    defaults = default_args.method_default_args(args['probType'])
    for key in defaults.keys():
        if args.get(key) is None:
            args[key] = defaults[key]

    data, solver_list = _load_problem(args)
    _move_data_to_device(data, DEVICE)

    solver_results = _run_solvers(data, solver_list)
    best_solver = _select_best_solver(solver_results)

    violation_threshold = 1e-5
    table_data = _build_table_rows(data, solver_results, best_solver, violation_threshold)
    print_results_table(table_data, violation_threshold)


def _per_sample_max(values):
    """Return the row-wise max of a 2-D residual tensor, or zeros when it has no columns.

    ``torch.max(..., dim=1)`` raises on a zero-width dimension, which happens
    when a problem has no equality (or no inequality) constraints.
    """
    if values.shape[1] == 0:
        return values.new_zeros(values.shape[0])
    return torch.max(values, dim=1)[0]


def _mean_or_zero(values):
    """Return the mean of ``values`` as a float, or 0.0 for an empty tensor (mean would be NaN)."""
    if values.numel() == 0:
        return 0.0
    return torch.mean(values).item()


def evaluate_step(data, X, Y_best, Y_current, violation_threshold):
    """Compare ``Y_current`` with the reference ``Y_best`` on inputs ``X``.

    Args:
        data: problem object providing ``eq_resid(X, Y)``, ``ineq_resid(X, Y)``
            and ``obj_fn(Y)``.
        X: problem inputs, batched along dim 0.
        Y_best: reference solutions, one row per sample.
        Y_current: solutions to evaluate, same shape as ``Y_best``.
        violation_threshold: a sample counts as violating if any equality
            residual magnitude or positive inequality residual exceeds this.

    Returns:
        dict of pre-formatted strings:
        ``ineq_vio`` / ``eq_vio`` (mean violation over all entries),
        ``vio_ratio`` (percentage of violating samples),
        ``sol_error`` / ``rel_sol_error`` (mean per-sample L1 distance to
        ``Y_best``, absolute and relative to ``||Y_best||_1``), and
        ``obj_error`` / ``rel_obj_error`` (mean absolute objective gap,
        absolute and relative to ``|obj(Y_best)|``).
    """
    eq_vio = torch.abs(data.eq_resid(X, Y_current))
    ineq_vio = torch.clamp(data.ineq_resid(X, Y_current), 0)

    # Share of samples with at least one constraint violated beyond the threshold.
    vio_mask = (_per_sample_max(eq_vio) > violation_threshold) | \
               (_per_sample_max(ineq_vio) > violation_threshold)
    vio_ratio = vio_mask.float().mean().item() * 100

    sol_diff = torch.norm(Y_best - Y_current, dim=1, p=1)
    sol_error_all = sol_diff.mean().item()
    rel_sol_error_all = (sol_diff /
                         (torch.norm(Y_best, dim=1, p=1) + 1e-10)).mean().item()

    obj_current = data.obj_fn(Y_current)
    obj_best = data.obj_fn(Y_best)
    obj_diff = torch.abs(obj_current - obj_best)
    obj_error_all = torch.mean(obj_diff).item()
    # |a/b - 1| == |a - b| / |b|; the epsilon goes on the magnitude so a negative
    # objective close to -1e-10 cannot cancel it and divide by zero.
    rel_obj_error_all = torch.mean(obj_diff / (torch.abs(obj_best) + 1e-10)).item()

    return {
        'ineq_vio': f"{_mean_or_zero(ineq_vio):.2e}",
        'eq_vio': f"{_mean_or_zero(eq_vio):.2e}",
        'vio_ratio': f"{vio_ratio:.2f}%",
        'sol_error': f"{sol_error_all:.2e}",
        'rel_sol_error': f"{rel_sol_error_all:.2e}",
        'obj_error': f"{obj_error_all:.2e}",
        'rel_obj_error': f"{rel_obj_error_all:.2e}"
    }


def print_results_table(table_data, violation_threshold):
    """Print ``table_data`` as a pipe-delimited (markdown-style) table.

    Args:
        table_data: iterable of rows, each a sequence of 9 cells in header
            order (solver, violations, ratio, errors, time).
        violation_threshold: threshold shown in the violation-ratio header.

    Every cell is left-aligned and padded to 18 characters.
    """
    col_width = 18
    columns = [
        "Solver",
        "Ineq Vio",
        "Eq Vio",
        "Viol>" + format(violation_threshold, ".0e"),
        "Sol MAE",
        "Rel Sol MAE",
        "Obj Error",
        "Rel Obj Error",
        "Time (s)",
    ]

    def render(cells):
        # One table line: "| c1 | c2 | ... |" with each cell padded on the right.
        padded = [format(cell, "<" + str(col_width)) for cell in cells]
        return "| {} |".format(" | ".join(padded))

    print(render(columns))
    print(render(["-" * col_width] * len(columns)))
    for row in table_data:
        print(render(row))


if __name__ == '__main__':
    # Run the solver comparison only when executed as a script, not on import.
    main()
