import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
torch.set_default_dtype(torch.float64)

import pickle
import os
import argparse
import operator
import time
from functools import reduce
from torch.utils.data import TensorDataset, DataLoader

from utils import str_to_bool
import default_args

import numpy as np
import cvxpy as cp

# Target device for the data and models moved with .to(DEVICE): the first CUDA GPU when
# available, otherwise the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def main():
    """Load a dataset and a pretrained solver network, then train the descent networks.

    Only ``--probType`` is read from the command line (unknown flags are ignored); every other
    setting is filled in from ``default_args.method_default_args``. The pretrained solver
    (``models/<probType>/NN/solver_net.dict``) produces the starting points, post-processed by
    ``grad_steps``, from which ``train_net`` runs its descent steps.

    Raises:
        NotImplementedError: if the problem type has no dataset path rule.
        FileNotFoundError: if the pretrained solver state dict is missing.
    """
    parser = argparse.ArgumentParser(description='DC3 Descent-Net')
    parser.add_argument('--probType', type=str, default='simple',
                        choices=['simple', 'nonconvex'])
    args, _unknown = parser.parse_known_args()
    args = vars(args)

    # Every setting not given on the command line comes from the per-problem defaults.
    defaults = default_args.method_default_args(args['probType'])
    for key, value in defaults.items():
        if args.get(key) is None:
            args[key] = value

    prob_type = args['probType']
    if prob_type == 'simple':
        filepath = os.path.join('datasets', 'simple', "random_simple_dataset_var{}_ineq{}_eq{}_ex{}".format(
            args['simpleVar'], args['simpleIneq'], args['simpleEq'], args['simpleEx']))
    elif prob_type == 'nonconvex':
        filepath = os.path.join('datasets', 'nonconvex', "random_nonconvex_dataset_var{}_ineq{}_eq{}_ex{}".format(
            args['nonconvexVar'], args['nonconvexIneq'], args['nonconvexEq'], args['nonconvexEx']))
    else:
        raise NotImplementedError(f"Unsupported problem type: {prob_type}")

    # NOTE(review): pickle.load can execute arbitrary code; only load dataset files you trust.
    with open(filepath, 'rb') as f:
        data = pickle.load(f)

    # Move every tensor attribute of the problem object onto DEVICE. Attributes that cannot be
    # reassigned (setattr raising AttributeError, e.g. read-only properties) are left in place,
    # which is why trainX/testX are still moved explicitly with .to(DEVICE) below.
    for attr in dir(data):
        var = getattr(data, attr)
        if not callable(var) and not attr.startswith("__") and torch.is_tensor(var):
            try:
                setattr(data, attr, var.to(DEVICE))
            except AttributeError:
                pass
    data._device = DEVICE

    model_dir = os.path.join('models', args['probType'], 'NN')
    network_root = os.path.join(model_dir, 'solver_net.dict')
    if not os.path.exists(network_root):
        raise FileNotFoundError(f"Solver network not found at {network_root}")

    solver_net = NNSolver(data, args)
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine (and vice versa).
    solver_net.load_state_dict(torch.load(network_root, map_location=DEVICE))
    solver_net.to(DEVICE)
    solver_net.eval()

    # Starting points for the descent steps; no gradients are needed through the solver.
    with torch.no_grad():
        Y0_train = grad_steps(data, data.trainX.to(DEVICE), solver_net(data.trainX.to(DEVICE)), args)
        Y0_test = grad_steps(data, data.testX.to(DEVICE), solver_net(data.testX.to(DEVICE)), args)

    train_net(data, args, Y0_train, Y0_test)


def train_net(data, args, Y0_train, Y0_test):
    """Train one CustomNetwork per descent step, then analyse the per-step subproblems.

    Every training batch starts from the precomputed points ``Y0_train`` and applies
    ``args['descent_step']`` steps. Step ``s`` uses ``descent_models[s]`` to produce a direction
    ``d`` and updates ``Y <- Y + sigmoid(step_ratio[s]) * stepSize(d, ...) * d``. All networks and
    ``step_ratio`` are trained jointly on ``total_loss`` of the final iterate.

    Args:
        data: problem object; this function reads ``trainX``, ``trainY``, ``testX``, ``G``, ``A``,
            ``_ydim`` and calls ``ineq_resid`` / ``obj_grad``.
        args: settings dict; reads 'epochs', 'batchSize', 'descent_step', 'lr' here and the loss
            weights inside ``total_loss``.
        Y0_train: starting points aligned row-by-row with ``data.trainX``.
        Y0_test: starting points aligned row-by-row with ``data.testX``.
    """
    nepochs = args['epochs']
    batch_size = args['batchSize']
    descent_step = args['descent_step']
    ydim = data._ydim

    train_dataset = TensorDataset(data.trainX, data.trainY, Y0_train)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # One direction network per descent step, each with its own Adam optimizer and LR schedule.
    descent_models = nn.ModuleList([
        CustomNetwork(ydim, 3 * ydim, 3).to(DEVICE) for _ in range(descent_step)
    ])
    optimizers = [optim.Adam(model.parameters(), lr=args['lr']) for model in descent_models]
    schedulers = [optim.lr_scheduler.StepLR(opt, step_size=20, gamma=0.5) for opt in optimizers]

    # Logits of the per-step step fractions: step s moves sigmoid(step_ratio[s]) times the step
    # returned by stepSize (0.5 at initialisation).
    step_ratio = nn.Parameter(torch.zeros(descent_step, device=DEVICE))
    step_optimizer = optim.SGD([step_ratio], lr=0.01)

    # Batch-independent constraint data: computed once here instead of on every batch, so the
    # (A A^T) solve in projection_matrix runs a single time. No gradient flows into it.
    ineq_grad = data.G
    proj_mat = projection_matrix(data.A).detach()

    for epoch in range(nepochs):
        descent_models.train()

        for Xtrain, _Ytrain, Y0_batch in train_loader:
            Xtrain = Xtrain.to(DEVICE)
            Ynew = Y0_batch.to(DEVICE)

            for opt in optimizers:
                opt.zero_grad()
            step_optimizer.zero_grad()

            for s in range(descent_step):
                border = data.ineq_resid(Xtrain, Ynew)
                f_grad = data.obj_grad(Ynew)

                # Only the last layer's direction is used for the actual update.
                d = descent_models[s](f_grad, ineq_grad, border, proj_mat)[-1]
                dstep = stepSize(d, ineq_grad, border)
                Ynew = Ynew + torch.sigmoid(step_ratio[s]) * dstep * d

            train_loss = total_loss(data, Xtrain, Ynew, args)
            train_loss.sum().backward()

            for opt in optimizers:
                opt.step()
            step_optimizer.step()

        for scheduler in schedulers:
            scheduler.step()

        print(f"Epoch {epoch} training completed.")

    descent_models.eval()
    print("\nTraining finished. Running per-step subproblem analysis on test set...")
    analyze_descent_steps_subproblems(data, descent_models, step_ratio, data.testX, Y0_test, descent_step)


class CustomLayer(nn.Module):
    """One learned refinement of a batch of search directions.

    The layer subtracts a learned correction ``alpha * V(relu(W(f)))`` from the direction, then
    removes the component ``d @ proj_mat.T``, and finally rescales every row whose Euclidean norm
    exceeds 1 back to unit norm. Rows with norm <= 1 are left as they are.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, hidden_dim, bias=True)
        self.V = nn.Linear(hidden_dim, input_dim, bias=False)
        # Learnable scale of the correction term, initialised to 1.
        self.alpha = nn.Parameter(torch.tensor(1.0))

    def forward(self, d, f, proj_mat):
        """Refine directions ``d`` using features ``f``.

        Args:
            d: current directions, one row per sample, ``input_dim`` columns.
            f: features fed to the correction MLP, ``input_dim`` columns.
            proj_mat: square matrix; the component ``d @ proj_mat.T`` is removed from ``d``.

        Returns:
            The refined directions, each row with L2 norm at most 1.
        """
        correction = self.V(F.relu(self.W(f)))
        stepped = d - self.alpha * correction
        projected = stepped - stepped @ proj_mat.T

        row_norm = torch.norm(projected, dim=1, keepdim=True)
        shrink = torch.where(row_norm <= 1, torch.ones_like(row_norm), 1.0 / row_norm)
        return projected * shrink


class CustomNetwork(nn.Module):
    """Stack of CustomLayer refinements that maps an objective gradient to search directions.

    ``forward`` starts from the normalised negative gradient. Before each layer it forms
    ``u = grad + (cj * index) @ ineq_grad``, where ``index`` marks the inequality rows ``j`` with
    ``(d @ ineq_grad.T)_j > -M * border_j`` (``M = 1``) and ``cj`` is the same weight that
    ``penalty_fun`` uses. Each layer output is renormalised and collected.
    """

    def __init__(self, input_dim, output_dim, num_layers):
        super().__init__()
        # NOTE: ``output_dim`` is passed to CustomLayer as its hidden width; the directions
        # produced by forward keep the width of the input gradient.
        self.layers = nn.ModuleList([CustomLayer(input_dim, output_dim) for _ in range(num_layers)])

    def forward(self, grad, ineq_grad, border, proj_mat):
        """Return the sequence of directions produced by the network.

        Args:
            grad: objective gradients, one row per sample.
            ineq_grad: inequality matrix; directions are compared via ``d @ ineq_grad.T``.
            border: inequality residuals, same shape as ``d @ ineq_grad.T``.
            proj_mat: matrix passed through unchanged to every CustomLayer.

        Returns:
            List of ``num_layers + 1`` tensors: the initial direction followed by each layer's
            output, every row normalised to unit L2 norm (up to F.normalize's epsilon).
        """
        M = 1.0
        d = F.normalize(-grad, p=2, dim=1)
        d_list = [d]
        # The penalty weights depend only on grad and border, not on d: compute them once.
        cj = torch.norm(grad, dim=1, keepdim=True) / (1e-4 - 0.5 * torch.clamp(border, max=0.0))
        for layer in self.layers:
            # Rows the current direction would push past their border. Cast to the working dtype
            # (rather than hard-coding float64) so the matmul below also works in float32.
            index = ((d @ ineq_grad.T) > -M * border).to(grad.dtype)
            u = grad + (cj * index) @ ineq_grad
            d = F.normalize(layer(d, u, proj_mat), p=2, dim=1)
            d_list.append(d)
        return d_list


def stepSize(d, ineq_grad, border):
    """Return, per sample, the smallest positive step along ``d`` that reaches an inequality border.

    The candidate step for constraint ``j`` is ``-border_j / (d @ ineq_grad.T)_j``. Only strictly
    positive candidates are kept, and the minimum over ``j`` is returned. Non-negative residuals
    are first replaced by ``-1e-6``, so a constraint that is already active or violated gives a
    tiny positive step when ``d`` moves into it.

    Args:
        d: search directions, one row per sample.
        ineq_grad: inequality matrix; row ``j`` is multiplied against ``d``.
        border: inequality residuals, same shape as ``d @ ineq_grad.T``.

    Returns:
        Tensor of shape ``(batch, 1)``. A row is ``inf`` when no candidate is positive, i.e. no
        constraint limits the step along that direction.
    """
    border = torch.where(border >= 0, -1e-6, border)
    product = d @ ineq_grad.T
    step_arr = -border / product
    # Build the inf fill on step_arr's own device/dtype; a freshly created CPU tensor is fragile
    # when the inputs live on a CUDA device.
    step_arr = torch.where(step_arr > 0, step_arr, torch.full_like(step_arr, float('inf')))
    return step_arr.min(dim=1).values.unsqueeze(-1)


def projection_matrix(A):
    """Return the orthogonal projector onto the row space of ``A``.

    Computes ``P = A^T (A A^T)^{-1} A``. ``P`` is symmetric and idempotent, and
    ``d - d @ P.T`` removes the row-space component of ``d`` so that ``A`` annihilates the result.
    Assumes ``A`` is a 2-D matrix with full row rank, so that ``A A^T`` is invertible. A singular
    ``A A^T`` raises ``torch.linalg.LinAlgError``, as ``torch.inverse`` did.
    """
    # Solve (A A^T) X = A instead of forming the inverse explicitly: same P, but faster and
    # numerically more stable.
    return A.T @ torch.linalg.solve(A @ A.T, A)


def total_loss(data, X, Y, args):
    """Per-sample soft-penalty loss: objective plus weighted constraint violations.

        loss = obj_fn(Y)
               + softWeight * (1 - softWeightEqFrac) * ||ineq_dist(X, Y)||_2
               + softWeight * softWeightEqFrac       * ||eq_resid(X, Y)||_2

    The norms are taken row-wise (``dim=1``), so one loss value is returned per sample.
    """
    objective = data.obj_fn(Y)
    ineq_violation = torch.norm(data.ineq_dist(X, Y), dim=1)
    eq_violation = torch.norm(data.eq_resid(X, Y), dim=1)

    weight = args['softWeight']
    eq_frac = args['softWeightEqFrac']
    ineq_weight = weight * (1 - eq_frac)
    eq_weight = weight * eq_frac

    return objective + ineq_weight * ineq_violation + eq_weight * eq_violation


def grad_steps(data, X, Y, args):
    """Optionally refine ``Y`` with momentum gradient steps.

    When ``args['useTrainCorr']`` is falsy, ``Y`` is returned unchanged. Otherwise
    ``args['corrTrainSteps']`` steps are taken with learning rate ``args['corrLr']`` and
    momentum ``args['corrMomentum']``. With ``args['corrMode'] == 'partial'`` the step direction
    is ``data.ineq_partial_grad(X, Y)``, which requires ``args['useCompl']``. Otherwise it is
    ``(1 - softWeightEqFrac) * data.ineq_grad(X, Y) + softWeightEqFrac * data.eq_grad(X, Y)``.
    (Presumably these are gradients of the constraint-violation penalties; confirm in the data class.)

    Raises:
        ValueError: if partial correction is requested without completion.
    """
    if not args['useTrainCorr']:
        return Y

    lr = args['corrLr']
    num_steps = args['corrTrainSteps']
    momentum = args['corrMomentum']
    partial_var = args['useCompl']
    partial_corr = args['corrMode'] == 'partial'
    if partial_corr and not partial_var:
        # An explicit raise, unlike ``assert``, is not stripped when running under ``python -O``.
        raise ValueError("Partial correction not available without completion.")

    Y_new = Y
    old_Y_step = 0
    for _ in range(num_steps):
        if partial_corr:
            Y_step = data.ineq_partial_grad(X, Y_new)
        else:
            ineq_step = data.ineq_grad(X, Y_new)
            eq_step = data.eq_grad(X, Y_new)
            Y_step = (1 - args['softWeightEqFrac']) * ineq_step + args['softWeightEqFrac'] * eq_step

        # Heavy-ball momentum: the new step reuses a fraction of the previous one.
        new_Y_step = lr * Y_step + momentum * old_Y_step
        Y_new = Y_new - new_Y_step
        old_Y_step = new_Y_step

    return Y_new


class NNSolver(nn.Module):
    """Fully connected network that maps problem inputs to a (partial) prediction of the solution.

    Architecture: two hidden blocks ``Linear -> BatchNorm1d -> ReLU -> Dropout(0.2)`` of width
    ``args['hiddenSize']``, then a linear output layer of width ``ydim - nknowns`` (minus
    ``neq`` more when ``args['useCompl']`` is set). All Linear weights use Kaiming-normal init.
    The raw output is handed to ``data.complete_partial`` (with ``useCompl``) or to
    ``data.process_output`` (without).
    """

    def __init__(self, data, args):
        super().__init__()
        self._data = data
        self._args = args
        layer_sizes = [data.xdim, self._args['hiddenSize'], self._args['hiddenSize']]
        # Flatten one (Linear, BatchNorm1d, ReLU, Dropout) group per consecutive size pair.
        layers = [
            module
            for in_dim, out_dim in zip(layer_sizes[:-1], layer_sizes[1:])
            for module in (nn.Linear(in_dim, out_dim), nn.BatchNorm1d(out_dim), nn.ReLU(), nn.Dropout(p=0.2))
        ]

        output_dim = data.ydim - data.nknowns
        if self._args['useCompl']:
            layers.append(nn.Linear(layer_sizes[-1], output_dim - data.neq))
        else:
            layers.append(nn.Linear(layer_sizes[-1], output_dim))

        for layer in layers:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight)

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run the MLP on ``x`` and post-process its output with the problem object."""
        out = self.net(x)
        if self._args['useCompl']:
            # ACOPF-style problems squash the raw prediction into (0, 1) before completion.
            if 'acopf' in self._args['probType']:
                out = torch.sigmoid(out)
            return self._data.complete_partial(x, out)
        return self._data.process_output(x, out)


def penalty_fun(d, f_grad, ineq_grad, border):
    """Mean exact-penalty value of the directions ``d`` over the batch.

    For each sample:
        f_grad . d + sum_j c_j * max((d @ ineq_grad.T)_j, -border_j)
    with ``c = ||f_grad||_2 / (1e-4 - 0.5 * min(border, 0))``. The per-sample values are then
    averaged into a 0-dim tensor.
    """
    weights = f_grad.norm(dim=1, keepdim=True) / (1e-4 - 0.5 * border.clamp(max=0.0))
    linear_part = (d * f_grad).sum(dim=1, keepdim=True)
    hinge = torch.maximum(d @ ineq_grad.T, -border)
    penalty_part = (weights * hinge).sum(dim=1, keepdim=True)
    return (linear_part + penalty_part).mean()


def solve_optimization(f_grad, ineq_grad, border, eq_grad, solver=None):
    """Solve each sample's direction subproblem exactly with CVXPY and return the mean optimum.

    For every row ``i`` this solves

        minimize_d   f_grad[i] . d + sum_j c[i, j] * max(ineq_grad[j] . d, -border[i, j])
        subject to   ||d||_2 <= 1,   eq_grad @ d == 0

    with ``c = ||f_grad||_2 / (1e-4 - 0.5 * min(border, 0))`` computed row-wise. This is the
    same objective that ``penalty_fun`` evaluates.

    Args:
        f_grad: (batch, n) array-like of objective gradients.
        ineq_grad: (m, n) array-like inequality matrix.
        border: (batch, m) array-like inequality residuals.
        eq_grad: (l, n) array-like equality matrix.
        solver: CVXPY solver name. Defaults to ``cp.ECOS`` as before. ECOS must be installed;
            otherwise pass another installed conic solver.

    Returns:
        Mean optimal value over samples whose status is OPTIMAL / OPTIMAL_INACCURATE, or
        ``np.nan`` when none are. Failing samples are skipped (best effort) and reported in a
        single warning instead of being dropped silently.
    """
    import warnings

    if solver is None:
        solver = cp.ECOS

    f_grad = np.asarray(f_grad)
    ineq_grad = np.asarray(ineq_grad)
    border = np.asarray(border)
    eq_grad = np.asarray(eq_grad)
    # Penalty weights. The denominator is >= 1e-4, so c >= 0 and the objective is DCP-convex.
    c = np.linalg.norm(f_grad, ord=2, axis=1, keepdims=True) / (
        1e-4 - 0.5 * np.clip(border, a_min=None, a_max=0.0))
    batch_size, n = f_grad.shape
    n_ineq = ineq_grad.shape[0]
    n_eq = eq_grad.shape[0]

    optimal_values = []
    n_failed = 0
    for i in range(batch_size):
        try:
            d_var = cp.Variable(n)
            constraints = [cp.norm(d_var, 2) <= 1]
            if n_eq > 0:
                # Vectorised form of the per-row constraints eq_grad[k] @ d == 0.
                constraints.append(eq_grad @ d_var == 0)
            objective = f_grad[i] @ d_var
            if n_ineq > 0:
                hinge = cp.maximum(ineq_grad @ d_var, -border[i])
                objective = objective + cp.sum(cp.multiply(c[i], hinge))
            prob = cp.Problem(cp.Minimize(objective), constraints)
            prob.solve(solver=solver, verbose=False)
        except Exception:
            # Best effort: a sample that cannot be built or solved is skipped, but counted.
            n_failed += 1
            continue

        if prob.status in (cp.OPTIMAL, cp.OPTIMAL_INACCURATE):
            optimal_values.append(prob.value)
        else:
            n_failed += 1

    if n_failed:
        warnings.warn(
            f"solve_optimization: {n_failed} of {batch_size} subproblems did not solve to "
            f"optimality and were skipped."
        )

    return np.mean(optimal_values) if optimal_values else np.nan


def analyze_descent_steps_subproblems(data, descent_models, step_ratio, Xtest, Y0, descent_step, max_samples=50):
    """Compare each layer's direction with the exact per-step subproblem optimum and print a table.

    The trained descent steps are replayed on the first ``max_samples`` test points. At each
    step, the mean ``penalty_fun`` value of every layer's direction is compared with the mean
    optimum returned by ``solve_optimization`` for the same subproblem. The relative error
    ``|val - opt| / (|opt| + 1e-8)`` is printed as a (layer x step) table. An entry is NaN when
    the solver returned no value or ``|opt| < 1e-12``.
    """
    n_samples = min(Xtest.shape[0], max_samples)
    Xs = Xtest[:n_samples].to(DEVICE)
    Ynew = Y0[:n_samples].to(DEVICE)

    ineq_grad = data.G
    proj_mat = projection_matrix(data.A)
    # The constraint matrices do not change between steps: convert them to NumPy once.
    ineq_grad_np = ineq_grad.detach().cpu().numpy()
    eq_grad_np = data.A.detach().cpu().numpy()

    # table[s][k]: relative error of layer k's direction at descent step s.
    table = []

    for s in range(descent_step):
        border = data.ineq_resid(Xs, Ynew)
        f_grad = data.obj_grad(Ynew)

        with torch.no_grad():
            d_lists = descent_models[s](f_grad, ineq_grad, border, proj_mat)

        # Reference value: mean optimum of the exact subproblem at this step.
        solver_val = solve_optimization(
            f_grad.detach().cpu().numpy(), ineq_grad_np, border.detach().cpu().numpy(), eq_grad_np)

        row = []
        for d_layer in d_lists:
            val = penalty_fun(d_layer, f_grad, ineq_grad, border).detach().cpu().item()
            if np.isnan(solver_val) or abs(solver_val) < 1e-12:
                row.append(np.nan)
            else:
                row.append(abs(val - solver_val) / (abs(solver_val) + 1e-8))
        table.append(row)

        # Advance to the next iterate exactly as in training. Detach so the autograd graph
        # through step_ratio does not keep growing across analysis steps.
        d_final = d_lists[-1]
        dstep = stepSize(d_final, ineq_grad, border)
        Ynew = (Ynew + torch.sigmoid(step_ratio[s]) * dstep * d_final).detach()

    _print_relative_error_table(table, descent_step)


def _print_relative_error_table(table, descent_step):
    """Print ``table[step][layer]`` with one row per layer and one column per descent step."""
    print("\nRelative Error Table:")
    col_width = 12
    header = "Layer".ljust(8) + "".join([f"| Step{s}".center(col_width) for s in range(descent_step)])
    print(header)
    print("-" * len(header))

    # default=0: with descent_step == 0 the table is empty and a bare max() would raise.
    n_layers = max((len(row) for row in table), default=0)
    for layer_idx in range(n_layers):
        row_vals = []
        for step in range(descent_step):
            if layer_idx < len(table[step]):
                val = table[step][layer_idx]
                row_vals.append(f"{val:.2e}".center(col_width) if not np.isnan(val) else "   nan   ".center(col_width))
            else:
                row_vals.append("   -   ".center(col_width))
        print(f"{str(layer_idx).ljust(8)}" + "".join(row_vals))


# Script entry point: reads --probType, loads the dataset and pretrained solver, then trains.
if __name__ == '__main__':
    main()
