import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
# Make float64 the default floating-point dtype for tensors created after this call.
torch.set_default_dtype(torch.float64)

import pickle
import os
import argparse
import operator
import time
from functools import reduce
from torch.utils.data import TensorDataset, DataLoader

from utils import str_to_bool
import default_args

# Use the first CUDA device when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def main():
    """Load data and a pretrained solver network, then train the descent networks.

    Steps, in order:
      1. Parse ``--probType`` and fill every option that is still missing from
         ``default_args.method_default_args``.
      2. Unpickle the dataset for that problem type and move its tensor
         attributes to ``DEVICE``.
      3. Load the pretrained ``NNSolver`` weights from
         ``models/<probType>/NN/solver_net.dict``.
      4. Compute the initial solutions ``Y0`` for the training and test inputs
         (solver output passed through ``grad_steps``) and time the test pass.
      5. Call ``train_net`` with ``models/<probType>/descent`` as save directory.

    Raises:
        NotImplementedError: if the problem type is not supported.
        FileNotFoundError: if the dataset file or the solver checkpoint is missing.
    """
    parser = argparse.ArgumentParser(description='DC3 Descent-Net')
    parser.add_argument('--probType', type=str, default='simple',
                        choices=['simple', 'nonconvex'])
    # Unrecognised command-line options are tolerated and ignored.
    args, _ = parser.parse_known_args()
    args = vars(args)

    # Take hyper-parameters from default_args for every key not set on the CLI.
    defaults = default_args.method_default_args(args['probType'])
    for key in defaults.keys():
        if args.get(key) is None:
            args[key] = defaults[key]

    # Locate the dataset file for the selected problem type.
    prob_type = args['probType']
    if prob_type == 'simple':
        filepath = os.path.join('datasets', 'simple', "random_simple_dataset_var{}_ineq{}_eq{}_ex{}".format(
            args['simpleVar'], args['simpleIneq'], args['simpleEq'], args['simpleEx']))
    elif prob_type == 'nonconvex':
        filepath = os.path.join('datasets', 'nonconvex', "random_nonconvex_dataset_var{}_ineq{}_eq{}_ex{}".format(
            args['nonconvexVar'], args['nonconvexIneq'], args['nonconvexEq'], args['nonconvexEx']))
    else:
        raise NotImplementedError(f"Unsupported problem type: {prob_type}")

    # SECURITY NOTE: pickle.load executes arbitrary code embedded in the file.
    # Only load dataset files that come from a trusted source.
    with open(filepath, 'rb') as f:
        data = pickle.load(f)

    # Move every tensor attribute of the dataset object to the target device.
    for attr in dir(data):
        if attr.startswith("__"):
            continue
        var = getattr(data, attr)
        if torch.is_tensor(var):
            try:
                setattr(data, attr, var.to(DEVICE))
            except AttributeError:
                # Best effort: attributes that cannot be reassigned (presumably
                # read-only properties) are left where they are.
                pass
    data._device = DEVICE

    # Load the pretrained solver network.
    model_dir = os.path.join('models', args['probType'], 'NN')
    network_root = os.path.join(model_dir, 'solver_net.dict')
    if not os.path.exists(network_root):
        raise FileNotFoundError(f"Solver network not found at {network_root}")

    solver_net = NNSolver(data, args)
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host
    # (and vice versa) instead of failing while deserialising the tensors.
    solver_net.load_state_dict(torch.load(network_root, map_location=DEVICE))
    solver_net.to(DEVICE)
    solver_net.eval()

    use_cuda = DEVICE.type == 'cuda'

    # Pre-compute the initial solutions Y0.
    with torch.no_grad():
        train_x = data.trainX.to(DEVICE)
        Y0_train = grad_steps(data, train_x, solver_net(train_x), args)

        test_x = data.testX.to(DEVICE)
        # CUDA kernels run asynchronously; synchronise around the timed region
        # so that init_time covers the actual computation, not just the launch.
        if use_cuda:
            torch.cuda.synchronize(DEVICE)
        start_time = time.perf_counter()
        Y0_test = grad_steps(data, test_x, solver_net(test_x), args)
        if use_cuda:
            torch.cuda.synchronize(DEVICE)
        init_time = time.perf_counter() - start_time

    # Directory where the descent models are saved.
    descent_save_dir = os.path.join('models', args['probType'], 'descent')
    os.makedirs(descent_save_dir, exist_ok=True)

    # Train.
    train_net(data, args, descent_save_dir, Y0_train, Y0_test, init_time)


def train_net(data, args, descent_save_dir, Y0_train, Y0_test, init_time):
    """Train one ``CustomNetwork`` per descent step and report test metrics each epoch.

    Starting from the precomputed initial solutions, every batch is pushed
    through ``args['descent_step']`` updates of the form
    ``Y <- Y + sigmoid(step_ratio[s]) * stepSize(d, ...) * d`` where ``d`` is the
    direction produced by the ``s``-th network. The summed ``total_loss`` of the
    final iterate is back-propagated into the networks (Adam with a
    ``MultiStepLR`` schedule) and into ``step_ratio`` (SGD).

    After training, each network's state dict is written to
    ``descent_model_step{s}.dict`` and the learned ``step_ratio`` tensor to
    ``step_ratio.pt`` inside ``descent_save_dir``.

    Args:
        data: problem object; this function reads ``trainX``, ``trainY``,
            ``testX``, ``testY``, ``_ydim``, ``G``, ``A`` and calls
            ``ineq_resid`` and ``obj_grad`` on it.
        args: dict with at least ``epochs``, ``batchSize``, ``descent_step`` and
            the keys read by ``total_loss``.
        descent_save_dir: existing directory for the saved files.
        Y0_train: initial solutions aligned with ``data.trainX``.
        Y0_test: initial solutions aligned with ``data.testX``.
        init_time: time in seconds reported on the "Initial" row of the table.
    """
    nepochs = args['epochs']
    batch_size = args['batchSize']
    descent_step = args['descent_step']
    violation_threshold = 1e-5  # tolerance used in the per-epoch report

    train_dataset = TensorDataset(data.trainX, data.trainY, Y0_train)
    test_dataset = TensorDataset(data.testX, data.testY, Y0_test)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # The whole test set is evaluated as one batch.
    test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

    # One independent network per descent step.
    ydim = data._ydim
    descent_models = nn.ModuleList([
        CustomNetwork(ydim, 3*ydim, 3).to(DEVICE) for _ in range(descent_step)
    ])
    optimizer = optim.Adam(descent_models.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[50, 100, 150], gamma=0.1)

    # Learnable per-step scaling; passed through a sigmoid before use.
    step_ratio = nn.Parameter(torch.zeros(descent_step, device=DEVICE))
    step_optimizer = optim.SGD([step_ratio], lr=0.01)

    # Both quantities depend only on the dataset, so compute them once instead
    # of inverting A @ A.T again for every batch.
    ineq_grad = data.G
    proj_mat = projection_matrix(data.A)

    for i in range(nepochs):
        descent_models.train()

        for Xtrain, _, Y0_batch in train_loader:
            Xtrain = Xtrain.to(DEVICE)
            Ynew = Y0_batch.to(DEVICE)

            # Reset gradients of both optimizers.
            optimizer.zero_grad()
            step_optimizer.zero_grad()

            for s in range(descent_step):
                border = data.ineq_resid(Xtrain, Ynew)
                f_grad = data.obj_grad(Ynew)
                d = descent_models[s](f_grad, ineq_grad, border, proj_mat)
                dstep = stepSize(d, ineq_grad, border)
                Ynew = Ynew + torch.sigmoid(step_ratio[s]) * dstep * d

            train_loss = total_loss(data, Xtrain, Ynew, args)
            train_loss.sum().backward()

            optimizer.step()
            step_optimizer.step()

        scheduler.step()

        # Print the gamma parameters of every layer.
        log_gammas(descent_models, epoch=i)

        # Evaluation phase.
        descent_models.eval()
        with torch.no_grad():
            for Xtest, Ytest, Y0_batch in test_loader:
                Xtest = Xtest.to(DEVICE)
                Ytest = Ytest.to(DEVICE)
                Ynew = Y0_batch.to(DEVICE)

                table_data = test_descent_steps(
                    data, descent_models, step_ratio,
                    Xtest, Ytest, Ynew, descent_step, violation_threshold, init_time
                )
                print(f"\nEpoch {i} Results:")
                print_results_table(table_data, violation_threshold)

    # Save all networks.
    for s, model in enumerate(descent_models):
        path = os.path.join(descent_save_dir, f'descent_model_step{s}.dict')
        torch.save(model.state_dict(), path)
        print(f"Descent model (step {s}) saved to: {path}")

    # The step ratios are trained jointly with the networks; without them the
    # saved models cannot reproduce the learned update, so persist them too.
    ratio_path = os.path.join(descent_save_dir, 'step_ratio.pt')
    torch.save(step_ratio.detach().cpu(), ratio_path)
    print(f"Step ratios saved to: {ratio_path}")


def _solution_errors(data, Y_ref, Y_pred):
    """Return mean (L1 error, relative L1 error, objective error, relative objective error, objective)."""
    sol_err = torch.norm(Y_ref - Y_pred, dim=1, p=1)
    rel_sol_err = sol_err / (torch.norm(Y_ref, dim=1, p=1) + 1e-10)
    obj_pred = data.obj_fn(Y_pred)
    obj_ref = data.obj_fn(Y_ref)
    obj_err = torch.abs(obj_pred - obj_ref)
    rel_obj_err = torch.abs(obj_pred / (obj_ref + 1e-10) - 1)
    return (sol_err.mean().item(),
            rel_sol_err.mean().item(),
            torch.mean(obj_err).item(),
            torch.mean(rel_obj_err).item(),
            torch.mean(obj_pred).item())


def evaluate_step(data, X, Ytest, Y_current, violation_threshold):
    """Summarise constraint violation and accuracy of ``Y_current`` against ``Ytest``.

    A row counts as feasible when both its largest absolute equality residual
    and its largest positive inequality residual are at most
    ``violation_threshold``.

    Args:
        data: problem object providing ``eq_resid(X, Y)``, ``ineq_resid(X, Y)``
            and ``obj_fn(Y)``.
        X: inputs, one row per sample.
        Ytest: reference solutions.
        Y_current: candidate solutions to evaluate.
        violation_threshold: feasibility tolerance.

    Returns:
        dict with
          * ``obj_val``: mean objective of ``Y_current``;
          * ``ineq_vio`` / ``eq_vio``: mean violation over all entries;
          * ``ineq_vio_ratio`` / ``eq_vio_ratio``: percentage of rows whose
            largest violation exceeds the threshold;
          * ``sol_error_str``, ``rel_sol_error_str``, ``obj_error_str``,
            ``rel_obj_error_str``: ``"<all rows> <feasible rows>"`` formatted
            with one decimal in scientific notation; the second part is
            ``N/A`` when no row is feasible.
    """
    eq_vio = torch.abs(data.eq_resid(X, Y_current))
    ineq_vio = torch.clamp(data.ineq_resid(X, Y_current), 0)

    # Per-row worst violation, computed once and reused for ratio and mask.
    eq_worst = torch.max(eq_vio, dim=1)[0]
    ineq_worst = torch.max(ineq_vio, dim=1)[0]
    eq_vio_ratio = (eq_worst > violation_threshold).float().mean().item() * 100
    ineq_vio_ratio = (ineq_worst > violation_threshold).float().mean().item() * 100
    feasible_mask = (eq_worst <= violation_threshold) & (ineq_worst <= violation_threshold)

    all_errors = _solution_errors(data, Ytest, Y_current)
    obj_val = all_errors[4]

    if feasible_mask.any():
        feasible_errors = _solution_errors(data, Ytest[feasible_mask], Y_current[feasible_mask])
        feasible_parts = [f"{value:.1e}" for value in feasible_errors[:4]]
    else:
        feasible_parts = ["N/A"] * 4

    sol_error_str, rel_sol_error_str, obj_error_str, rel_obj_error_str = (
        f"{overall:.1e} {feasible}"
        for overall, feasible in zip(all_errors[:4], feasible_parts)
    )

    return {
        'obj_val': obj_val,
        'ineq_vio': torch.mean(ineq_vio).item(),
        'eq_vio': torch.mean(eq_vio).item(),
        'ineq_vio_ratio': ineq_vio_ratio,
        'eq_vio_ratio': eq_vio_ratio,
        'sol_error_str': sol_error_str,
        'rel_sol_error_str': rel_sol_error_str,
        'obj_error_str': obj_error_str,
        'rel_obj_error_str': rel_obj_error_str
    }


def _metrics_row(label, metrics, elapsed):
    """Format one ``evaluate_step`` result as a list of table cells."""
    return [label,
            f"{metrics['obj_val']:.4f}",
            f"{metrics['ineq_vio']:.1e}",
            f"{metrics['eq_vio']:.1e}",
            f"{metrics['ineq_vio_ratio']:.1f}%",
            f"{metrics['eq_vio_ratio']:.1f}%",
            metrics['sol_error_str'],
            metrics['rel_sol_error_str'],
            metrics['obj_error_str'],
            metrics['rel_obj_error_str'],
            f"{elapsed:.4f}"]


def test_descent_steps(data, descent_models, step_ratio, Xtest, Ytest, Y0, descent_step, violation_threshold, init_time):
    """Apply the descent steps to ``Y0`` and collect metrics after every step.

    Args:
        data: problem object; ``G``, ``A``, ``ineq_resid`` and ``obj_grad`` are
            used here, the rest by ``evaluate_step``.
        descent_models: indexable collection of direction networks, one per step.
        step_ratio: tensor of per-step logits; ``sigmoid(step_ratio[s])`` scales
            the step length of step ``s``.
        Xtest: test inputs.
        Ytest: reference solutions for ``Xtest``.
        Y0: initial solutions for ``Xtest``.
        descent_step: number of descent steps to apply.
        violation_threshold: feasibility tolerance passed to ``evaluate_step``.
        init_time: time in seconds shown on the "Initial" row.

    Returns:
        List of rows (lists of strings): one "Initial" row followed by one
        "Descent-{s}" row per step. The last cell of each descent row is the
        wall-clock time of that single update, excluding metric evaluation.
    """
    # Evaluate the initial solution.
    initial_metrics = evaluate_step(data, Xtest, Ytest, Y0, violation_threshold)
    table_data = [_metrics_row("Initial", initial_metrics, init_time)]

    Ynew = Y0
    ineq_grad = data.G
    proj_mat = projection_matrix(data.A)
    # CUDA work is asynchronous; without a synchronisation the timer would only
    # capture kernel launches, not the computation itself.
    on_cuda = Ynew.is_cuda
    for s in range(descent_step):
        if on_cuda:
            torch.cuda.synchronize(Ynew.device)
        start_time = time.perf_counter()
        border = data.ineq_resid(Xtest, Ynew)
        f_grad = data.obj_grad(Ynew)
        d = descent_models[s](f_grad, ineq_grad, border, proj_mat)
        dstep = stepSize(d, ineq_grad, border)
        Ynew = Ynew + torch.sigmoid(step_ratio[s]) * dstep * d
        if on_cuda:
            torch.cuda.synchronize(Ynew.device)
        step_time = time.perf_counter() - start_time

        step_metrics = evaluate_step(data, Xtest, Ytest, Ynew, violation_threshold)
        table_data.append(_metrics_row(f"Descent-{s}", step_metrics, step_time))

    return table_data


def print_results_table(table_data, violation_threshold):
    """Print ``table_data`` as a pipe-delimited table with aligned columns.

    Each column is as wide as its widest cell or header, with a minimum of 12
    characters. A fixed width of 12 is too narrow for the "Rel Obj Error"
    header and for the two-value error cells, which pushed the column
    separators out of line.

    Args:
        table_data: iterable of rows; each row is a sequence of cells that
            support left-aligned formatting (strings in practice).
        violation_threshold: tolerance shown in the two violation-ratio headers.
    """
    min_width = 12
    headers = ["Step", "Obj Val", "Ineq Vio", "Eq Vio",
               f"Ineq>{violation_threshold:.0e}",
               f"Eq>{violation_threshold:.0e}",
               "Sol MAE",
               "Rel Sol MAE",
               "Obj Error",
               "Rel Obj Error",
               "Time (s)"]

    # The data is iterated twice (measure, then print), so materialise it.
    rows = [list(row) for row in table_data]

    widths = [max(min_width, len(h)) for h in headers]
    for row in rows:
        for col, cell in enumerate(row):
            cell_width = max(min_width, len(str(cell)))
            if col < len(widths):
                widths[col] = max(widths[col], cell_width)
            else:
                # Rows longer than the header still get their own columns.
                widths.append(cell_width)

    def render(cells):
        return "| " + " | ".join(format(cell, f"<{width}") for cell, width in zip(cells, widths)) + " |"

    print(render(headers))
    print("| " + " | ".join("-" * width for width in widths[:len(headers)]) + " |")
    for row in rows:
        print(render(row))


# models
class CustomLayer(nn.Module):
    """One refinement layer for a descent direction.

    The layer subtracts a learned correction ``gamma * V(relu(W(f)))`` from the
    current direction, removes the component ``d @ proj_mat.T`` and finally
    rescales every row whose Euclidean norm exceeds 1 back to unit norm.

    Args:
        input_dim: size of the direction / feature vectors.
        hidden_dim: width of the hidden layer between ``W`` and ``V``.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, hidden_dim, bias=True)
        self.V = nn.Linear(hidden_dim, input_dim, bias=False)
        # Learnable scalar weight of the correction term.
        self.gamma = nn.Parameter(torch.tensor(0.1))

    def forward(self, d, f, proj_mat):
        """Return the updated direction.

        Args:
            d: current directions, one row per sample.
            f: feature vectors fed to the two-layer correction network.
            proj_mat: square matrix; ``d @ proj_mat.T`` is subtracted from ``d``.

        Returns:
            Tensor shaped like ``d`` whose rows have Euclidean norm at most 1.
        """
        T_f = self.V(F.relu(self.W(f)))
        d = d - self.gamma * T_f
        d = d - d @ proj_mat.T
        norms = torch.norm(d, dim=1, keepdim=True)
        # Same values as where(norms <= 1, 1, 1 / norms), but a zero-norm row no
        # longer evaluates 1 / 0 in the unselected branch, which would turn the
        # gradients into NaN during backpropagation.
        scale = 1.0 / torch.clamp(norms, min=1.0)
        d = d * scale
        return d


class CustomNetwork(nn.Module):
    """Stack of ``CustomLayer`` modules that refines a descent direction.

    Args:
        input_dim: size of the direction vectors.
        output_dim: hidden width handed to every ``CustomLayer`` (despite the
            name it is not the size of the network output).
        num_layers: number of stacked layers.
    """

    def __init__(self, input_dim, output_dim, num_layers):
        super().__init__()
        self.layers = nn.ModuleList([CustomLayer(input_dim, output_dim) for _ in range(num_layers)])

    def forward(self, grad, ineq_grad, border, proj_mat):
        """Refine the direction ``-grad`` layer by layer.

        Before every layer, each inequality row ``j`` with
        ``(d @ ineq_grad.T)[:, j] > -M * border[:, j]`` is selected and its row
        of ``ineq_grad``, weighted by ``cj``, is added to ``grad`` to form the
        feature vector passed to the layer.

        Args:
            grad: objective gradients, one row per sample.
            ineq_grad: matrix whose rows are multiplied with the directions;
                the caller passes ``data.G``.
            border: inequality residuals, one column per row of ``ineq_grad``.
            proj_mat: projection matrix forwarded to every layer.

        Returns:
            The direction produced by the last layer (``-grad`` if the network
            has no layers).
        """
        d = -grad
        M = 1.0
        # cj and the selection threshold depend only on grad and border, which
        # stay fixed across layers, so they are computed once.
        cj = torch.norm(grad, dim=1, keepdim=True) / (5e-4 - 0.5 * M * border)
        threshold = -M * border
        for layer in self.layers:
            # Cast the 0/1 selection to the dtype of d instead of a hard-coded
            # float64 so that non-float64 inputs do not hit a dtype mismatch in
            # the matmul below.
            index = ((d @ ineq_grad.T) > threshold).to(d.dtype)
            mask = cj * index
            u = grad + mask @ ineq_grad  # (batch size, n)
            d = layer(d, u, proj_mat)
        return d


# class CustomNetwork(nn.Module):
#     def __init__(self, input_dim, output_dim, num_layers):
#         super(CustomNetwork, self).__init__()
#         self.layers = nn.ModuleList([CustomLayer(input_dim, output_dim) for _ in range(num_layers)])

#     def forward(self, grad, ineq_grad, border, proj_mat):
#         d = -grad
#         for layer in self.layers:
#             index = 1/(0.001 - border) * ((d @ ineq_grad.T) > -1e-5*border).float().to(torch.float64)   #(batch size, m)
#             f = index @ ineq_grad + grad
#             d = layer(d, f, proj_mat)    
#         return d

def stepSize(d, ineq_grad, border):
    """Return, per sample, the smallest positive ratio ``-border / (d @ ineq_grad.T)``.

    Non-negative residuals are first replaced by ``-1e-6``. Ratios that are not
    strictly positive (including NaN) are ignored. If no ratio of a row is
    positive the result for that row is ``inf``; callers that multiply the
    direction by this value must be prepared for that.

    Args:
        d: directions, one row per sample.
        ineq_grad: matrix with one row per inequality.
        border: inequality residuals, one column per row of ``ineq_grad``.

    Returns:
        Tensor with one column holding the step length of every sample.
    """
    border = torch.where(border >= 0, -1e-6, border)
    product = torch.matmul(d, ineq_grad.T)
    step_arr = torch.div(-border, product)
    # A Python scalar adopts the dtype and device of step_arr, unlike a fresh
    # torch.tensor(inf), which is created on the default device.
    step_arr = torch.where(step_arr > 0, step_arr, float('inf'))
    step = torch.min(step_arr, dim=1)[0]
    return step.unsqueeze(-1)


def projection_matrix(A):
    """Return ``P = A.T @ (A @ A.T)^-1 @ A``.

    The linear system ``(A @ A.T) Z = A`` is solved directly instead of forming
    the explicit inverse, which the PyTorch documentation recommends as faster
    and numerically more stable.

    Args:
        A: 2-D matrix for which ``A @ A.T`` is invertible.

    Returns:
        Square matrix with as many rows and columns as ``A`` has columns.
    """
    return A.T @ torch.linalg.solve(A @ A.T, A)


def total_loss(data, X, Y, args):
    """Per-sample soft-penalty loss: objective plus weighted constraint norms.

    The inequality term is the row-wise norm of ``data.ineq_dist(X, Y)`` and
    the equality term the row-wise norm of ``data.eq_resid(X, Y)``.
    ``args['softWeight']`` scales both terms and ``args['softWeightEqFrac']``
    is the share of that weight given to the equality term.

    Returns:
        One loss value per row of ``Y``.
    """
    objective = data.obj_fn(Y)
    ineq_penalty = torch.norm(data.ineq_dist(X, Y), dim=1)
    eq_penalty = torch.norm(data.eq_resid(X, Y), dim=1)

    # Split the soft weight between the two constraint families.
    soft_weight = args['softWeight']
    eq_frac = args['softWeightEqFrac']
    ineq_weight = soft_weight * (1 - eq_frac)
    eq_weight = soft_weight * eq_frac

    loss = objective + ineq_weight * ineq_penalty
    loss = loss + eq_weight * eq_penalty
    return loss


def grad_steps(data, X, Y, args):
    """Apply momentum gradient-correction steps to ``Y`` when enabled.

    When ``args['useTrainCorr']`` is falsy, ``Y`` is returned unchanged.
    Otherwise ``args['corrTrainSteps']`` updates
    ``Y <- Y - (corrLr * step + corrMomentum * previous_update)`` are applied,
    where ``step`` is ``data.ineq_partial_grad(X, Y)`` in ``'partial'``
    correction mode and otherwise the ``softWeightEqFrac``-weighted mix of
    ``data.ineq_grad(X, Y)`` and ``data.eq_grad(X, Y)``.

    Raises:
        AssertionError: if ``corrMode`` is ``'partial'`` while ``useCompl`` is
            falsy.
    """
    if not args['useTrainCorr']:
        return Y

    lr = args['corrLr']
    num_steps = args['corrTrainSteps']
    momentum = args['corrMomentum']
    partial_var = args['useCompl']
    partial_corr = args['corrMode'] == 'partial'
    if partial_corr and not partial_var:
        # Raised explicitly: an `assert` statement would be stripped under
        # `python -O` and let the invalid configuration through.
        raise AssertionError("Partial correction not available without completion.")

    Y_new = Y
    old_Y_step = 0
    for _ in range(num_steps):
        if partial_corr:
            Y_step = data.ineq_partial_grad(X, Y_new)
        else:
            ineq_step = data.ineq_grad(X, Y_new)
            eq_step = data.eq_grad(X, Y_new)
            Y_step = (1 - args['softWeightEqFrac']) * ineq_step + args['softWeightEqFrac'] * eq_step

        new_Y_step = lr * Y_step + momentum * old_Y_step
        Y_new = Y_new - new_Y_step

        old_Y_step = new_Y_step

    return Y_new


class NNSolver(nn.Module):
    def __init__(self, data, args):
        super().__init__()
        self._data = data
        self._args = args
        layer_sizes = [data.xdim, self._args['hiddenSize'], self._args['hiddenSize']]
        layers = reduce(operator.add,
            [[nn.Linear(a,b), nn.BatchNorm1d(b), nn.ReLU(), nn.Dropout(p=0.2)]
                for a,b in zip(layer_sizes[0:-1], layer_sizes[1:])])
        
        output_dim = data.ydim - data.nknowns

        if self._args['useCompl']:
            layers += [nn.Linear(layer_sizes[-1], output_dim - data.neq)]
        else:
            layers += [nn.Linear(layer_sizes[-1], output_dim)]

        for layer in layers:
            if type(layer) == nn.Linear:
                nn.init.kaiming_normal_(layer.weight)

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
 
        if self._args['useCompl']:
            if 'acopf' in self._args['probType']:
                out = nn.Sigmoid()(out)   # used to interpolate between max and min values
            return self._data.complete_partial(x, out)
        else:
            return self._data.process_output(x, out)
        

def log_gammas(descent_models, epoch=None):
    """Print the ``gamma`` value of every layer of every descent network.

    Args:
        descent_models: iterable of networks exposing a ``layers`` iterable
            whose items have a scalar ``gamma`` tensor.
        epoch: when given, a header line naming the epoch is printed first.
    """
    show_header = epoch is not None
    if show_header:
        print(f"\nEpoch {epoch} Gamma values:")

    step = 0
    for net in descent_models:
        gammas = []
        for layer in net.layers:
            gammas.append(f"{layer.gamma.item():.2e}")
        print(f"  Step {step}: {gammas}")
        step += 1

        

# Run the training pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()