import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pickle
import time
import os
from torch.utils.data import DataLoader, TensorDataset
from training_all import *

# Uncomment to make float64 the global default dtype for newly created tensors:
# torch.set_default_dtype(torch.float64)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Module-level default configuration; `config` comes from `training_all`.
defaults = config()


def main():
    """Script entry point: configure the ACOPF problem, load it, and train."""
    cfg = config()
    cfg['probType'] = 'acopf'
    # Alternative instance: cfg['opfSize'] = [118, 20000]
    cfg['opfSize'] = [30, 10000]
    data, result_save_dir, model_save_dir = load_instance(cfg)
    train_descent(data, cfg, model_save_dir, result_save_dir)


def precompute_initial_solutions(data, Xtrain, solver_net, homeo_mapping, args, batch_size=512):
    """Precompute the initial solutions produced by the pretrained NN solver.

    Runs ``initial_sol`` (with test-time correction forced on) over ``Xtrain``
    in mini-batches under ``torch.no_grad()``.

    Args:
        data: problem instance; only ``data.ydim`` is read here directly.
        Xtrain: input tensor whose first dimension indexes samples.
        solver_net: pretrained solver network (switched to eval mode).
        homeo_mapping: pretrained mapping network (switched to eval mode).
        args: configuration dict. It is NOT modified; a copy with
            ``args['proj_para']['useTestCorr'] = True`` is passed to ``initial_sol``.
        batch_size: number of samples per forward pass. Values < 1 are treated
            as 1 so that ``batch_size=len(X)`` on an empty set is harmless.

    Returns:
        Tensor with one row per sample of ``Xtrain``, stacked from the outputs
        of ``initial_sol`` and moved to ``DEVICE``. Its dtype is the dtype
        ``initial_sol`` returns, not the global default dtype.
    """
    solver_net.eval()
    homeo_mapping.eval()

    # Work on a copy so the caller's configuration is not mutated.
    corr_args = args.copy()
    corr_args['proj_para'] = args['proj_para'].copy()
    corr_args['proj_para']['useTestCorr'] = True

    num_samples = Xtrain.shape[0]
    if num_samples == 0:
        # NOTE(review): dtype of the empty result is taken from Xtrain — confirm it matches Y.
        return torch.zeros(0, data.ydim, device=DEVICE, dtype=Xtrain.dtype)
    step = max(int(batch_size), 1)

    chunks = []
    with torch.no_grad():
        for start in range(0, num_samples, step):
            X_batch = Xtrain[start:start + step]
            chunks.append(initial_sol(data, X_batch, solver_net, homeo_mapping, corr_args))

    # Concatenating keeps the solver's dtype. A zeros() buffer would silently
    # cast to the default dtype, which may be float32.
    return torch.cat(chunks, dim=0).to(DEVICE)


def train_epoch(data, descent_models, step_ratio, optimizer, step_optimizer, train_dataloader, Y0_train_cache, descent_step, batch_size, args, grad_clip_norm=2.0):
    """Run one optimisation pass over ``train_dataloader``.

    For every mini-batch, the cached initial solutions are pushed through
    ``descent_step`` unrolled learned descent steps. The resulting iterate is
    scored by ``data.obj_fn`` plus L1 penalties on the equality and inequality
    residuals, weighted by ``args['nn_para']['softWeightEqFrac']`` and
    ``args['nn_para']['softWeightInEqFrac']``. Both the descent networks
    (``optimizer``) and the per-step ``step_ratio`` logits (``step_optimizer``)
    are updated.

    NOTE: warm starts are taken from ``Y0_train_cache`` by batch POSITION
    (``batch_idx * batch_size``). The dataloader must therefore yield samples
    in the same order as the cache; independent shuffling misaligns them.
    """
    cache_len = len(Y0_train_cache)
    for batch_no, (X_batch, _targets) in enumerate(train_dataloader):
        optimizer.zero_grad()
        step_optimizer.zero_grad()

        # Positional slice of the precomputed warm starts for this batch.
        first = batch_no * batch_size
        Y = Y0_train_cache[first:min(first + batch_size, cache_len)]

        for k in range(descent_step):
            slack = data.ineq_real(X_batch, Y)
            jac = data.eq_jac(Y)
            # Sensitivity of the remaining variables w.r.t. the partial variables,
            # -(J_other)^{-1} J_partial, from a single batched LU factorisation.
            lu, piv = torch.linalg.lu_factor(jac[:, :, data.other_vars])
            dep_dz = -torch.linalg.lu_solve(lu, piv, jac[:, :, data.partial_vars])

            obj_grad = data.obj_partial_grad(Y, dep_dz)
            cons_grad = data.ineq_partial_jac(Y, dep_dz)
            direction = descent_models[k](obj_grad, cons_grad, slack)
            max_step = stepSize(direction, cons_grad, slack)
            # sigmoid(step_ratio[k]) in (0, 1) scales the maximal feasible step.
            scale = torch.sigmoid(step_ratio[k]) * max_step
            Y = Y + scale * data.complete_d(direction, dep_dz)

        ineq_pen = torch.norm(data.ineq_resid(X_batch, Y), dim=1, p=1)
        eq_pen = torch.norm(data.eq_resid(X_batch, Y), dim=1, p=1)
        objective = data.obj_fn(Y)
        soft = args['nn_para']
        per_sample = objective + soft['softWeightEqFrac'] * eq_pen + soft['softWeightInEqFrac'] * ineq_pen

        # The loss is SUMMED over the batch (not averaged).
        per_sample.sum().backward()

        # Gradient clipping, applied separately to the networks and the step logits.
        torch.nn.utils.clip_grad_norm_(descent_models.parameters(), max_norm=grad_clip_norm)
        torch.nn.utils.clip_grad_norm_([step_ratio], max_norm=grad_clip_norm)

        optimizer.step()
        step_optimizer.step()


def _error_metrics(data, Y_ref, Y_pred):
    """Return batch-averaged (L1 sol. error, rel. L1 sol. error, |obj gap|, rel. obj gap).

    Each intermediate is computed once, so ``data.obj_fn`` is called exactly
    twice per invocation.
    """
    sol_l1 = torch.norm(Y_ref - Y_pred, dim=1, p=1)
    ref_l1 = torch.norm(Y_ref, dim=1, p=1)
    obj_pred = data.obj_fn(Y_pred)
    obj_ref = data.obj_fn(Y_ref)
    return (
        sol_l1.mean().item(),
        (sol_l1 / ref_l1).mean().item(),
        torch.mean(torch.abs(obj_pred - obj_ref)).item(),
        torch.mean(torch.abs(obj_pred / obj_ref - 1)).item(),
    )


def evaluate_step(data, Xtest, Ytest, Y_current, violation_threshold):
    """Compute constraint-violation and accuracy metrics for one iterate.

    Args:
        data: problem instance providing ``eq_resid``, ``ineq_real`` and ``obj_fn``.
        Xtest: test inputs.
        Ytest: reference solutions (one row per sample).
        Y_current: iterate being evaluated, same shape as ``Ytest``.
        violation_threshold: a sample counts as violating when its maximum
            equality (absolute) or inequality (positive part) violation
            exceeds this value.

    Returns:
        dict with
        - mean violations ``ineq_vio`` / ``eq_vio``;
        - the percentage of violating samples, ``ineq_vio_ratio`` / ``eq_vio_ratio``;
        - four ``*_str`` entries of the form ``"<all> <feasible-only>"``. These
          are formatted with ``.1e``; ``N/A`` replaces the second value when no
          sample is feasible. Relative errors divide by the reference norm /
          objective and are inf/NaN if those are zero.
    """
    eq_vio = torch.abs(data.eq_resid(Xtest, Y_current))
    ineq_vio = torch.clamp(data.ineq_real(Xtest, Y_current), 0)
    eq_max_vio = torch.max(eq_vio, dim=1)[0]
    ineq_max_vio = torch.max(ineq_vio, dim=1)[0]

    eq_vio_ratio = (eq_max_vio > violation_threshold).float().mean().item() * 100
    ineq_vio_ratio = (ineq_max_vio > violation_threshold).float().mean().item() * 100
    # Explicit <= (not the negation of >) so NaN violations count as infeasible.
    feasible_mask = (eq_max_vio <= violation_threshold) & (ineq_max_vio <= violation_threshold)

    metrics_all = _error_metrics(data, Ytest, Y_current)
    if bool(feasible_mask.any()):
        metrics_feasible = _error_metrics(data, Ytest[feasible_mask], Y_current[feasible_mask])
        tails = [f"{value:.1e}" for value in metrics_feasible]
    else:
        tails = ["N/A"] * len(metrics_all)

    sol_error_str, rel_sol_error_str, obj_error_str, rel_obj_error_str = (
        f"{value:.1e} {tail}" for value, tail in zip(metrics_all, tails)
    )

    return {
        'ineq_vio': torch.mean(ineq_vio).item(),
        'eq_vio': torch.mean(eq_vio).item(),
        'ineq_vio_ratio': ineq_vio_ratio,
        'eq_vio_ratio': eq_vio_ratio,
        'sol_error_str': sol_error_str,
        'rel_sol_error_str': rel_sol_error_str,
        'obj_error_str': obj_error_str,
        'rel_obj_error_str': rel_obj_error_str
    }


def _format_result_row(label, metrics, time_str):
    """Turn an ``evaluate_step`` metrics dict into one results-table row."""
    return [
        label,
        f"{metrics['ineq_vio']:.1e}",
        f"{metrics['eq_vio']:.1e}",
        f"{metrics['ineq_vio_ratio']:.1f}%",
        f"{metrics['eq_vio_ratio']:.1f}%",
        metrics['sol_error_str'],
        metrics['rel_sol_error_str'],
        metrics['obj_error_str'],
        metrics['rel_obj_error_str'],
        time_str,
    ]


def _sync_if_cuda(tensor):
    """Wait for queued CUDA work on ``tensor``'s device (no-op for CPU tensors).

    CUDA kernels are launched asynchronously. Without a synchronisation point,
    wall-clock timing measures launch overhead rather than compute time.
    """
    if tensor.is_cuda:
        torch.cuda.synchronize(tensor.device)


def test_descent_steps(data, descent_models, step_ratio, Xtest, Ytest, Y0_test, descent_step, violation_threshold, initial_time_sec=None):
    """Run the learned descent steps on the test set and collect per-step metrics.

    Args:
        data: problem instance (residuals, Jacobians, completion helpers).
        descent_models: indexable collection with one descent network per step.
        step_ratio: per-step logits; ``sigmoid(step_ratio[s])`` scales the step.
        Xtest, Ytest: test inputs and reference solutions.
        Y0_test: initial solutions the descent starts from.
        descent_step: number of descent steps to run.
        violation_threshold: tolerance forwarded to ``evaluate_step``.
        initial_time_sec: time spent producing ``Y0_test``. It is shown in the
            "Initial" row, or as "N/A" when None.

    Returns:
        List of table rows (lists of strings): the "Initial" row followed by
        one "Descent-<s>" row per step. The last column holds the step's
        wall-clock time in seconds, measured with CUDA synchronisation so GPU
        work is actually included.
    """
    initial_time_str = f"{initial_time_sec:.4f}" if (initial_time_sec is not None) else "N/A"
    initial_metrics = evaluate_step(data, Xtest, Ytest, Y0_test, violation_threshold)
    table_data = [_format_result_row("Initial", initial_metrics, initial_time_str)]

    Ynew = Y0_test
    for s in range(descent_step):
        _sync_if_cuda(Ynew)
        t0 = time.perf_counter()

        border_real = data.ineq_real(Xtest, Ynew)
        eq_jac = data.eq_jac(Ynew)
        # One LU factorisation serves both the sensitivity solve and the
        # equality-residual correction below.
        LU, pivots = torch.linalg.lu_factor(eq_jac[:, :, data.other_vars])
        dynz_dz = -torch.linalg.lu_solve(LU, pivots, eq_jac[:, :, data.partial_vars])

        f_grad = data.obj_partial_grad(Ynew, dynz_dz)
        ineq_grad = data.ineq_partial_jac(Ynew, dynz_dz)
        d = descent_models[s](f_grad, ineq_grad, border_real)
        eq_res = data.eq_resid(Xtest, Ynew)
        dstep = stepSize(d, ineq_grad, border_real)
        stepsize = torch.sigmoid(step_ratio[s]) * dstep

        # Residual correction pre-divided by -stepsize so that, after the
        # `stepsize *` below, it contributes -(J_other)^{-1} eq_res.
        # NOTE(review): a zero or infinite stepsize makes this inf/NaN.
        corr_eq_res = torch.linalg.lu_solve(LU, pivots, eq_res.unsqueeze(-1)).squeeze(-1) / (-stepsize)
        Ynew = Ynew + stepsize * data.complete_d_pami_cbwf(d, dynz_dz, corr_eq_res)

        _sync_if_cuda(Ynew)
        step_elapsed = time.perf_counter() - t0

        step_metrics = evaluate_step(data, Xtest, Ytest, Ynew, violation_threshold)
        table_data.append(_format_result_row(f"Descent-{s}", step_metrics, f"{step_elapsed:.4f}"))

    return table_data


# Not a pytest test despite the ``test_`` prefix; keep collectors from picking it up.
test_descent_steps.__test__ = False


def print_results_table(table_data, violation_threshold):
    """Print ``table_data`` as a pipe-delimited table with left-aligned 15-char columns."""
    width = 15
    columns = [
        "Step",
        "Ineq Vio",
        "Eq Vio",
        f"Ineq>{violation_threshold:.0e}",
        f"Eq>{violation_threshold:.0e}",
        "Sol MAE  Corr.",
        "Rel Sol MAE  Corr.",
        "Obj Error  Corr.",
        "Rel Obj Error  Corr.",
        "Time(s)",
    ]

    def emit(cells):
        # Pad (never truncate) each cell to `width` characters.
        padded = [f"{cell:<{width}}" for cell in cells]
        print("| " + " | ".join(padded) + " |")

    emit(columns)
    emit(["-" * width] * len(columns))
    for row in table_data:
        emit(row)
    print()


def train_descent(data, args, model_save_dir, result_save_dir):
    """Train the unrolled descent networks on top of the pretrained NN solver.

    Loads ``mapping.pth`` and ``solver_net.pth`` from ``model_save_dir`` and
    precomputes initial solutions for the train and test sets (the test one
    is timed). It then trains for ``args['nn_para']['epochs']`` epochs, printing
    the gamma values and a per-step test-metrics table after every epoch.
    ``result_save_dir`` is currently unused.

    Each epoch draws ONE random permutation and applies it jointly to the
    inputs, the targets and the cached initial solutions. It then iterates
    them in order, because ``train_epoch`` slices the cache by batch position.
    A shuffling DataLoader would pair every batch with another sample's warm
    start.
    """
    data._device = DEVICE

    def _sync():
        # CUDA kernels run asynchronously; synchronise so wall-clock timing
        # includes the actual GPU work.
        if DEVICE.type == 'cuda':
            torch.cuda.synchronize(DEVICE)

    # Load data
    Xtrain = data.trainX.to(DEVICE)
    Ytrain = data.trainY.squeeze().to(DEVICE)
    Xtest = data.testX.to(DEVICE)
    Ytest = data.testY.squeeze().to(DEVICE)
    nepochs = args['nn_para']['epochs']
    batch_size = args['nn_para']['batch_size']
    grad_clip_norm = 1.0  # gradient-clipping threshold
    violation_threshold = 1e-5  # tolerance used in the per-epoch test report

    # Descent model: one independent network per unrolled descent step.
    zdim = len(data.partial_vars)
    descent_step = args['nn_para']['descent_step']
    num_layers_per_step = [args['nn_para']['num_layer'] for _ in range(descent_step)]
    descent_models = nn.ModuleList([CustomNetwork(zdim, 10*zdim, num_layers_per_step[i]) for i in range(descent_step)]).to(DEVICE)

    optimizer = optim.Adam(descent_models.parameters(), lr=0.01)
    solver_shce = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100, 150], gamma=0.1)

    # Per-step step-size logits (sigmoid(0) = 0.5 at start).
    step_ratio = nn.Parameter(torch.zeros(descent_step, device=DEVICE))
    step_optimizer = optim.SGD([step_ratio], lr=0.01)

    # Load pretrained models (full pickled modules).
    homeo_mapping = torch.load(os.path.join(model_save_dir, 'mapping.pth'), map_location=DEVICE, weights_only=False)
    solver_net = torch.load(os.path.join(model_save_dir, 'solver_net.pth'), map_location=DEVICE, weights_only=False)
    solver_net.eval()

    # Shallow copies so precomputation cannot alter the caller's config.
    precompute_args = args.copy()
    precompute_args['proj_para'] = args['proj_para'].copy()
    Y0_train_cache = precompute_initial_solutions(data, Xtrain, solver_net, homeo_mapping, precompute_args)

    # Precompute initial solutions (test) WITH timing
    _sync()
    t_init_start = time.perf_counter()
    Y0_test_cache = precompute_initial_solutions(data, Xtest, solver_net, homeo_mapping, precompute_args, batch_size=len(Xtest))
    _sync()
    initial_time_sec = time.perf_counter() - t_init_start
    print(f"[Timing] Test initial solutions computed in {initial_time_sec:.4f} s")

    num_train = Xtrain.shape[0]
    for i in range(nepochs):
        # Shuffle X, Y and the warm-start cache with the SAME permutation, then
        # iterate in order so train_epoch's positional cache slices line up.
        perm = torch.randperm(num_train, device=Xtrain.device)
        train_dataloader = DataLoader(TensorDataset(Xtrain[perm], Ytrain[perm]), batch_size=batch_size, shuffle=False)
        Y0_epoch = Y0_train_cache[perm.to(Y0_train_cache.device)]

        train_epoch(data, descent_models, step_ratio, optimizer, step_optimizer,
                    train_dataloader, Y0_epoch, descent_step, batch_size, args, grad_clip_norm)
        solver_shce.step()

        log_gammas(descent_models, epoch=i)

        # Testing phase
        with torch.no_grad():
            print(f"Epoch: {i}:")
            table_data = test_descent_steps(
                data, descent_models, step_ratio, Xtest, Ytest,
                Y0_test_cache, descent_step, violation_threshold,
                initial_time_sec=initial_time_sec
            )
            print_results_table(table_data, violation_threshold)


def initial_sol(data, X, solver_net, homeo_mapping, args):
    """Predict solutions with ``solver_net`` and optionally correct infeasible rows.

    Args:
        data: problem instance providing ``scale``, ``complete_partial``,
            ``check_feasibility`` and (for some algorithms) ``opt_proj`` /
            ``opt_warmstart``.
        X: batch of inputs.
        solver_net: network mapping X to a raw prediction.
        homeo_mapping: mapping network used by the H_Bis / G_Bis corrections.
        args: config dict. Reads ``args['algoType']`` and
            ``args['proj_para'][...]`` keys ``corrTestMaxSteps``, ``corrEps``
            and ``useTestCorr``.

    Returns:
        ``Ycorr``: a detached clone of the (completed) prediction. When
        ``useTestCorr`` is set, rows whose max |violation| exceeds ``corrEps``
        are replaced by the chosen correction.
    """
    solver_net.eval()
    homeo_mapping.eval()
    ### NN solution prediction
    with torch.no_grad():
        Y_pred = solver_net(X)
        Y_pred_scale = data.scale(X, Y_pred)
        # 'Eq' algorithms predict only part of the variables; complete the rest.
        if 'Eq' in args['algoType']:
            Y = data.complete_partial(X, Y_pred_scale, backward=False)
        else:
            Y = Y_pred_scale

    ### Post-processing for infeasible only
    # NOTE(review): post-processing runs outside the no_grad block above.
    # NOTE(review): `steps` is read here but never used. The bisection/D_Proj
    # branches overwrite it, and it is not returned.
    steps = args['proj_para']['corrTestMaxSteps']
    eps_converge = args['proj_para']['corrEps']
    violation = data.check_feasibility(X, Y)
    # Per-sample max absolute violation; rows above corrEps are "infeasible".
    penalty = torch.max(torch.abs(violation), dim=1)[0]
    infeasible_index = (penalty > eps_converge).view(-1)
    # Despite the name, these rows are taken from the scaled/completed Y, not Y_pred.
    Y_pred_infeasible = Y[infeasible_index]
    num_infeasible_prediction = Y_pred_infeasible.shape[0]
    Ycorr = Y.detach().clone()
    if num_infeasible_prediction > 0:
        if args['proj_para']['useTestCorr']            :
            # Dispatch on substrings of algoType. Order matters: 'D_Proj'
            # must be tested before 'Proj' because it contains 'Proj'.
            # The bisection variants receive the raw (unscaled) Y_pred rows;
            # the others receive rows of Y.
            # homeo_bisection / gauge_bisection / diff_projection are defined
            # outside this file (presumably training_all).
            if 'H_Bis' in args['algoType']:
                Yproj, steps = homeo_bisection(homeo_mapping, data, args, Y_pred[infeasible_index], X[infeasible_index])
            elif 'G_Bis' in args['algoType']:
                Yproj, steps = gauge_bisection(homeo_mapping, data, args, Y_pred[infeasible_index], X[infeasible_index])
            elif 'D_Proj' in args['algoType']:
                Yproj, steps = diff_projection(data, X[infeasible_index], Y[infeasible_index], args)
            elif 'Proj' in args['algoType']:
                Yproj = data.opt_proj(X[infeasible_index], Y[infeasible_index]).to(Y.device).view(
                    Y_pred_infeasible.shape)
            elif 'WS' in args['algoType']:
                Yproj = data.opt_warmstart(X[infeasible_index], Y[infeasible_index]).to(Y.device).view(
                    Y_pred_infeasible.shape)
            else:
                # No correction method selected: keep the rows unchanged.
                Yproj = Y_pred_infeasible
            Ycorr[infeasible_index] = Yproj

    return Ycorr


### Descent Net
class CustomLayer(nn.Module):
    """One learned correction layer: ``d <- d - gamma * V(relu(W f))``.

    Args:
        input_dim: size of the direction ``d`` and of the feature ``f``.
        hidden_dim: width of the hidden ReLU layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, hidden_dim, bias=True)
        self.V = nn.Linear(hidden_dim, input_dim, bias=False)
        # Learnable scale of the correction, initialised to 1.
        self.gamma = nn.Parameter(torch.tensor(1.0))

    def forward(self, d, f):
        T_f = self.V(F.relu(self.W(f)))
        return d - self.gamma * T_f


class CustomNetwork(nn.Module):
    """Stack of ``CustomLayer`` s that refines ``-grad`` into a descent direction.

    Args:
        input_dim: dimension of the gradient / direction.
        output_dim: hidden width passed to every ``CustomLayer``. Despite the
            name, the returned direction has size ``input_dim``.
        num_layers: number of correction layers. With 0 layers the network
            returns ``-grad`` unchanged (not normalised).
    """

    def __init__(self, input_dim, output_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([CustomLayer(input_dim, output_dim) for _ in range(num_layers)])

    def forward(self, grad, ineq_grad, border):
        """Return the refined direction, L2-normalised per row after each layer.

        Args:
            grad: (batch, n) objective gradient.
            ineq_grad: (batch, m, n) inequality-constraint gradients.
            border: (batch, m) constraint values. Presumably <= 0 when feasible
                — confirm. The weight denominator ``1e-4 - 0.5*border`` blows up
                if ``border`` reaches 2e-4.
        """
        M = 1.0
        d = -grad

        # Per-constraint weights depend only on grad and border, so they are
        # loop-invariant: cj = ||grad|| / (1e-4 - M*border/2), shape (batch, m, 1).
        cj = (torch.norm(grad, dim=1, keepdim=True) / (1e-4 - 0.5 * M * border)).unsqueeze(-1)

        for layer in self.layers:
            # 1 for constraints with <a_i, d> > -M*border_i, else 0. Cast to the
            # working dtype; a hard float64 cast breaks float32 models.
            index = ((ineq_grad @ d.unsqueeze(-1)) > -M * border.unsqueeze(-1)).to(cj.dtype)
            mask = cj * index
            u = grad + (ineq_grad.transpose(1, 2) @ mask).squeeze(-1)  # (batch, n)
            d = F.normalize(layer(d, u), p=2, dim=1)
        return d
    

def stepSize(d, ineq_grad, border, max_step=None):
    """Return the smallest positive ratio ``-border_i / <a_i, d>`` per sample.

    Under a linearisation of the inequality constraints (assuming ``border``
    holds their current values g_i), this is the step along ``d`` at which the
    first constraint becomes active. Non-positive ratios are ignored, as are
    NaN ratios from 0/0. Samples with no positive ratio get ``inf``.

    Args:
        d: (batch, n) direction.
        ineq_grad: (batch, m, n) constraint gradients ``a_i``.
        border: (batch, m) constraint values.
        max_step: optional upper bound applied to the result, e.g. to avoid an
            infinite step when no constraint limits the direction. None (the
            default) applies no cap.

    Returns:
        (batch, 1) tensor on the same device and dtype as the inputs.
    """
    product = (ineq_grad @ d.unsqueeze(-1)).squeeze(-1)  # (b,m,n) x (b,n,1) -> (b,m)

    step_arr = torch.div(-border, product)
    # full_like keeps device/dtype aligned with step_arr (a bare
    # torch.tensor(inf) would live on the CPU with the default dtype).
    step_arr = torch.where(step_arr > 0, step_arr, torch.full_like(step_arr, float('inf')))
    step = torch.min(step_arr, dim=1)[0]
    if max_step is not None:
        step = torch.clamp(step, max=max_step)
    return step.unsqueeze(-1)


def log_gammas(descent_models, epoch=None):
    """Print every layer's ``gamma`` (``.2e`` format), one line per descent step.

    A header naming ``epoch`` is printed first when ``epoch`` is given.
    """
    if epoch is not None:
        print(f"\nEpoch {epoch} Gamma values:")

    def _fmt(layer):
        return f"{layer.gamma.item():.2e}"

    for step_idx, step_net in enumerate(descent_models):
        print(f"  Step {step_idx}: {list(map(_fmt, step_net.layers))}")

    
if __name__ == "__main__":
    # Train only when executed as a script, never on import.
    main()
