# -*- coding: utf-8 -*-
"""idbdisc_mMNIST.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/11WMHizgDJambifJRGrzYll3m4fZ8-mGZ
"""

# iDB-PD vs. adaptive discretization (AD) comparison on the double-MNIST data.

# The lines starting with `!` are IPython/Colab shell escapes; they are not
# valid plain Python and only run inside a notebook kernel.
# Install the plotting/numeric/PyTorch stack, adding the CUDA 12.1 wheel index.
!pip -q install matplotlib numpy torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
# Install cooper straight from its GitHub repository.
# NOTE(review): no tag or commit is pinned here, so the installed API may drift.
!pip -q install git+https://github.com/cooper-org/cooper

import os, time, pickle, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import matplotlib.pyplot as plt
from google.colab import files
from google.colab import drive

import cooper
from cooper import ConstrainedMinimizationProblem, Constraint, ConstraintType
from cooper.optim import SimultaneousOptimizer
from cooper.multipliers import DenseMultiplier

# Load the dataset from Google Drive; change `data_path` as needed.
drive.mount('/content/drive')
data_path = "/content/drive/MyDrive/multi_mnist.pickle"
# NOTE: unpickling can execute arbitrary code, so only point this at a file
# you trust.
with open(data_path, 'rb') as fh:
    data = pickle.load(fh)

# Tolerate duplicate OpenMP runtimes (harmless in Colab).
os.environ.update({"KMP_DUPLICATE_LIB_OK": "TRUE"})

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# Fix the PyTorch and NumPy seeds.
_SEED = 123
torch.manual_seed(_SEED)
np.random.seed(_SEED)

# Helper functions for drawing figures and saving them to disk.
def save_and_show(fname: str):
    """Save the current matplotlib figure as ``plots_ad/<fname>`` and close it.

    The ``plots_ad`` directory is created when missing. The figure is written
    at 300 dpi with a tight bounding box, then closed, and the output path is
    printed.

    Args:
        fname: File name (including extension) to use inside ``plots_ad``.
    """
    target_dir = "plots_ad"
    destination = os.path.join(target_dir, fname)

    plt.tight_layout()
    os.makedirs(target_dir, exist_ok=True)
    plt.savefig(destination, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"Saved: {destination}")

def plot_series(metric_name, ys, labels, yscale='log', ylabel=None, fname=None):
    """Plot several per-iteration series on one figure and save it.

    Args:
        metric_name: Name of the metric; used as the y-axis label when
            ``ylabel`` is falsy and to derive the file name when ``fname`` is
            falsy.
        ys: Iterable of sequences, one per curve; each is plotted against its
            own index range.
        labels: Legend labels, paired with ``ys`` via ``zip``.
        yscale: Scale passed to ``plt.yscale``; a falsy value leaves the
            default scale untouched.
        ylabel: Optional y-axis label.
        fname: Optional output file name handed to ``save_and_show``.
    """
    plt.figure()
    for series, legend_label in zip(ys, labels):
        plt.plot(range(len(series)), series, label=legend_label, linewidth=2.0)

    if yscale:
        plt.yscale(yscale)

    plt.xlabel("Iteration", fontsize=14)
    axis_label = ylabel or metric_name
    plt.ylabel(axis_label, fontsize=14)

    # Minor grid lines are only drawn on a log axis.
    grid_which = "both" if yscale == 'log' else "major"
    plt.grid(True, which=grid_which, linestyle="--", alpha=0.7)
    plt.legend()

    if fname:
        out_name = fname
    else:
        out_name = f"{metric_name.lower().replace(' ', '_')}_ad.png"
    save_and_show(out_name)


# Build the objective and constraint datasets from the loaded pickle.
# Both use the same flattened images (data[2]); the objective takes label
# column 0 of data[3] and the constraint takes label column 1.
X_np = data[2]
n_samples = X_np.shape[0]
_flat_images = X_np.reshape(n_samples, -1)
_label_table = data[3]

X = torch.tensor(_flat_images, dtype=torch.float32, device=device)
y_true = torch.tensor(_label_table[:, 0], dtype=torch.long, device=device)

X_const = torch.tensor(_flat_images, dtype=torch.float32, device=device)
m_samples = X_const.size(0)
y_true_const = torch.tensor(_label_table[:, 1], dtype=torch.long, device=device)

# Network sizes.
n_classes = 10
input_dim = X.size(1)
hidden_dim = 1500

# Optimisation hyperparameters.
lambda_reg = 1e-3                 # weight of the quadratic penalty on the sample weights
gamma0 = 0.002                    # primal step size during pre-training
gamma = 0.001 / 5.5               # primal step size for iDB-PD
alpha = 0.3 / gamma               # numerator of the decaying alpha_k schedule
epochs = 3000                     # iDB-PD iterations (pre-training runs a third of this)
sigma = 0.5 / (lambda_reg * max(m_samples, n_samples))  # step size of the weight updates

# Model: one-hidden-layer classifier shared by all methods.
class SimpleNet(nn.Module):
    """Fully connected classifier with one tanh hidden layer.

    Args:
        in_features: Size of each flattened input. Defaults to the
            module-level ``input_dim``.
        hidden_features: Width of the hidden layer. Defaults to the
            module-level ``hidden_dim``.
        out_features: Number of output logits. Defaults to the module-level
            ``n_classes``.

    Calling ``SimpleNet()`` with no arguments builds the same network as
    before; the arguments only make the sizes overridable.
    """

    def __init__(self, in_features=None, hidden_features=None, out_features=None):
        super().__init__()
        # The globals are only looked up when a size is not given explicitly,
        # so the class can also be used without the notebook-level constants.
        if in_features is None:
            in_features = input_dim
        if hidden_features is None:
            hidden_features = hidden_dim
        if out_features is None:
            out_features = n_classes
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw class logits for a batch of flattened inputs."""
        hidden = torch.tanh(self.fc1(x))
        return self.fc2(hidden)


def project_onto_simplex(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Map ``z`` onto the probability simplex with a softmax along ``dim``.

    This is the normalisation step of a mirror-ascent style update: the
    entries along ``dim`` come out positive and sum to one.
    """
    simplex_point = z.softmax(dim=dim)
    return simplex_point

def safe_log_weights(w: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Return ``log(w)`` element-wise after flooring every entry at ``eps``.

    The floor keeps the logarithm finite for zero entries.
    """
    floored = w.clamp(min=eps)
    return floored.log()

# Pre-training to obtain the constraint level r.
# A separate network is trained on the constraint data alone with a
# descent-ascent scheme: gradient descent on the network parameters and
# exponentiated-gradient (log-weights + step, then softmax) ascent on the
# per-sample weights. The absolute value of the last recorded loss becomes r.
model0 = SimpleNet().to(device)
# The weights are updated by hand below, so they do not need autograd tracking.
w_weights = torch.ones(m_samples, device=device) / m_samples
loss_history0, gradpsi_history = [], []
pretrain_epochs = epochs // 3

for epoch in range(pretrain_epochs):
    model0.train()
    logitss = model0(X_const)
    loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')

    # Weighted loss minus the quadratic penalty that keeps w near uniform.
    weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
    reg_term_const = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
    total_loss_const = weighted_loss_const - reg_term_const

    grad_const_x = torch.autograd.grad(total_loss_const, model0.parameters())
    with torch.no_grad():
        # Plain gradient-descent step on the network parameters.
        for p, g in zip(model0.parameters(), grad_const_x):
            p -= gamma0 * g
        # Norm of the FULL parameter gradient (previously only the gradient of
        # the last parameter tensor was recorded).
        grad_norm = torch.norm(torch.cat([g.reshape(-1) for g in grad_const_x]))

    # Number of inner ascent steps on w grows logarithmically with the epoch.
    Mk = int(10*np.ceil(np.log(epoch+2)))
    with torch.no_grad():
        # The per-sample losses are the ones computed before the parameter
        # step above; only w changes between inner iterations.
        # NOTE(review): the analytic derivative of the penalty is
        # 2*lambda_reg*(m*w - 1); the factor 2 is not applied here - confirm
        # whether that is intentional.
        for _ in range(Mk):
            grad_const_w = loss_const_per_sample - lambda_reg * (m_samples * w_weights - 1)
            w_weights.copy_(project_onto_simplex(safe_log_weights(w_weights) + sigma * grad_const_w))

    loss_history0.append(total_loss_const.item())
    gradpsi_history.append(grad_norm.item())
    if epoch % 100 == 0 or epoch == pretrain_epochs - 1:
        print(f"[Pretrain] Epoch {epoch}: loss={total_loss_const.item():.4f}", flush=True)

# Constraint level used by the later sections: |loss| from the final epoch
# (evaluated before that epoch's parameter update).
r = float(abs(total_loss_const.item()))

# iDB-PD: primal-dual training of a fresh network.
# Objective  : weighted cross-entropy on (X, y_true) with sample weights y_weights.
# Constraint : weighted cross-entropy on (X_const, y_true_const) with sample
#              weights w_weights, required to stay at or below the level r.
# Each epoch takes one primal step on the network parameters followed by
# several exponentiated-gradient ascent steps on both weight vectors.
model = SimpleNet().to(device)
# Both weight vectors start uniform.
# NOTE(review): dividing a requires_grad tensor makes these non-leaf tensors;
# nothing below differentiates with respect to them (autograd.grad is only
# called on model.parameters()), so the flag looks unnecessary - confirm.
y_weights = torch.ones(n_samples, device=device, requires_grad=True) / n_samples
w_weights = torch.ones(m_samples, device=device, requires_grad=True) / m_samples

# Per-epoch diagnostics: objective value, norm of the update direction,
# positive part of the constraint, and |lam * constraint|.
loss_history, dk_history, infeas_history, slack_history = [], [], [], []
start_time_idbpd = time.time()

for epoch in range(epochs):
    model.train()
    # Per-sample losses for the objective and for the constraint.
    logits = model(X); loss_per_sample = F.cross_entropy(logits, y_true, reduction='none')
    logitss = model(X_const); loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')

    # Objective: weighted loss minus a quadratic penalty on the weights.
    weighted_loss = torch.dot(y_weights, loss_per_sample)
    reg_term     = lambda_reg / n_samples * torch.sum((n_samples * y_weights - 1) ** 2)
    total_loss   = weighted_loss - reg_term

    # Constraint: same construction, shifted by r so that "<= 0" means feasible.
    weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
    reg_term_const      = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
    total_loss_const    = weighted_loss_const - reg_term_const - r

    # Gradients of both functions w.r.t. the network parameters, flattened
    # into single vectors. alphak decays slightly faster than 1/(epoch+2).
    grad_obj_x    = torch.autograd.grad(total_loss,       model.parameters(), create_graph=False)
    grad_const_x  = torch.autograd.grad(total_loss_const, model.parameters(), create_graph=False)
    alphak        = alpha / ((epoch+2)**1.001)
    grad_obj_vec   = torch.cat([g.view(-1) for g in grad_obj_x])
    grad_const_vec = torch.cat([g.view(-1) for g in grad_const_x])

    with torch.no_grad():
        param_vec = torch.cat([p.data.view(-1) for p in model.parameters()])
        norm_g2   = torch.norm(grad_const_vec)
        # The multiplier is only active when the constraint is violated
        # (positive value) and its gradient is non-zero; the second condition
        # also guards the division by norm_g2 ** 2.
        if max(0.0, total_loss_const.item()) * norm_g2 > 0:
            # lam = max(-<g_const, g_obj> + alphak * ||g_const||, 0)
            lam = torch.max(-torch.dot(grad_const_vec, grad_obj_vec) + alphak * norm_g2, torch.tensor(0.0, device=device))
            # Correction direction: lam * g_const / ||g_const||^2.
            lam_x_g2 = (lam * grad_const_vec) / (norm_g2 ** 2)
        else:
            lam = torch.tensor(0.0, device=device)
            lam_x_g2 = torch.zeros_like(grad_const_vec)

        # Primal step along -(g_obj + correction), then scatter the flat
        # vector back into the individual parameter tensors.
        updated_vec = param_vec - gamma * (grad_obj_vec + lam_x_g2)
        ptr = 0
        for p in model.parameters():
            n = p.numel()
            p.data.copy_(updated_vec[ptr:ptr+n].view_as(p))
            ptr += n

    # Ascent steps on the objective weights. The per-sample losses are the
    # ones computed before the parameter update above; only y_weights changes
    # between inner iterations. The step count grows with log(epoch).
    Nk = int(2*np.ceil(np.log(epoch+2)))
    for _ in range(Nk):
        grad_obj_y = loss_per_sample - lambda_reg*(n_samples * y_weights - torch.ones(n_samples, device=device))
        with torch.no_grad():
            # Exponentiated-gradient update: step in log-space, then softmax.
            y_weights.copy_(project_onto_simplex(safe_log_weights(y_weights) + sigma * grad_obj_y))

    # Same update for the constraint weights, with five times as many steps.
    Mk = int(10*np.ceil(np.log(epoch+2)))
    for _ in range(Mk):
        grad_const_w = loss_const_per_sample - lambda_reg*(m_samples * w_weights - torch.ones(m_samples, device=device))
        with torch.no_grad():
            w_weights.copy_(project_onto_simplex(safe_log_weights(w_weights) + sigma * grad_const_w))

    # Diagnostics use the values computed at the start of this epoch.
    loss_history.append(total_loss.item())
    dk_history.append(torch.norm(grad_obj_vec + lam_x_g2).item())
    infeas_history.append(max(0.0, total_loss_const.item()))
    slack_history.append(torch.abs(lam * total_loss_const.item()).item())

    if epoch % 100 == 0 or epoch == epochs - 1:
        print(f"[iDB-PD] Epoch {epoch}: ||dk||={dk_history[-1]:.4f} Infeas={infeas_history[-1]:.4f}", flush=True)

# Wall-clock time of the whole iDB-PD loop, in seconds.
runtime_idbpd = time.time() - start_time_idbpd

# Adaptive discretization baseline, solved with the cooper library.
class DiscreteMinimaxCMP(ConstrainedMinimizationProblem):
    """Discretised minimax problem expressed as a cooper CMP.

    The loss is the largest per-sample cross-entropy over the index set
    ``Y_set`` of the objective data. The single inequality constraint is the
    largest per-sample cross-entropy over the index set ``W_set`` of the
    constraint data, minus the level ``r``.

    The index sets are stored by reference, so indices appended by the caller
    are seen by later calls to ``compute_cmp_state``.
    """

    def __init__(self, model, X, y_true, X_const, y_const, Y_set, W_set, r):
        super().__init__()
        self.model = model
        self.X = X
        self.y_true = y_true
        self.Xc = X_const
        self.yc = y_const
        self.Y_set = Y_set
        self.W_set = W_set
        self.r = r
        # One multiplier for the one inequality constraint, on X's device.
        dual_variable = DenseMultiplier(num_constraints=1, device=X.device)
        self.constraint = Constraint(
            multiplier=dual_variable,
            constraint_type=ConstraintType.INEQUALITY,
        )

    def compute_cmp_state(self, model=None, inputs=None, targets=None):
        """Evaluate the loss and the constraint violation for the stored model.

        The ``model``, ``inputs`` and ``targets`` arguments are accepted but
        not used; everything is read from the instance attributes.
        """
        objective_losses = F.cross_entropy(self.model(self.X), self.y_true, reduction='none')
        worst_objective = objective_losses[self.Y_set].max()

        constraint_losses = F.cross_entropy(self.model(self.Xc), self.yc, reduction='none')
        worst_violation = constraint_losses[self.W_set].max() - self.r

        observed = {self.constraint: cooper.ConstraintState(violation=worst_violation)}
        return cooper.CMPState(loss=worst_objective, observed_constraints=observed)

def adaptive_discretization_solver(model, X, y_true, Xc, yc, r, max_iters=3000, eps=1e-3, lr=1e-3):
    """Train ``model`` with an adaptive-discretisation primal-dual scheme.

    Two index sets, ``Y_set`` (objective samples) and ``W_set`` (constraint
    samples), start with one random index each. Every iteration takes one
    simultaneous primal (Adam) / dual (SGD ascent) step on the discretised
    problem, then adds the sample with the largest objective loss and the
    sample with the largest constraint loss to the sets.

    Args:
        model: Network to train in place.
        X, y_true: Objective inputs and integer labels.
        Xc, yc: Constraint inputs and integer labels.
        r: Constraint level; the constraint is ``max loss - r <= 0``.
        max_iters: Maximum number of iterations.
        eps: Tolerance of the stopping test.
        lr: Learning rate shared by the primal and dual optimizers.

    Returns:
        ``(model, Y_set, W_set, dk_history, infeas_history, slack_history,
        loss_history)``; the four histories hold one value per iteration.
    """
    n_samples, m_samples = X.shape[0], Xc.shape[0]
    dk_history_c, infeas_history_c, slack_history_c, loss_history_c = [], [], [], []
    Y_set = [torch.randint(0, n_samples, (1,)).item()]
    W_set = [torch.randint(0, m_samples, (1,)).item()]

    # Build the problem and the optimizers ONCE. Rebuilding them every
    # iteration creates a fresh multiplier each time (cooper's default
    # initialisation is understood to be zero), so the primal step would
    # never see the constraint, and Adam's moment estimates would be
    # discarded. The CMP keeps references to Y_set / W_set, so indices
    # appended below are picked up automatically.
    cmp = DiscreteMinimaxCMP(model, X, y_true, Xc, yc, Y_set, W_set, r)
    primal_opt = torch.optim.Adam(model.parameters(), lr=lr)
    dual_opt = torch.optim.SGD(cmp.dual_parameters(), lr=lr, maximize=True)
    coop_opt = SimultaneousOptimizer(cmp, primal_opt, dual_opt)

    for k in range(max_iters):
        # The evaluation below switches to eval mode; switch back for the step.
        model.train()
        coop_opt.roll(compute_cmp_state_kwargs={})

        model.eval()
        with torch.no_grad():
            full_loss = F.cross_entropy(model(X), y_true, reduction='none')
            full_const = F.cross_entropy(model(Xc), yc, reduction='none')

        # Worst samples over the full data sets.
        y_new = torch.argmax(full_loss).item()
        w_new = torch.argmax(full_const).item()
        phi_new = full_loss[y_new].item()
        psi_new = full_const[w_new].item()

        # Discretised objective maximum BEFORE the set is enlarged. Taking it
        # afterwards always reproduces phi_new, which made the gap below
        # identically zero and the first half of the stopping test vacuous.
        phi_disc = torch.max(full_loss[Y_set]).item()

        if y_new not in Y_set:
            Y_set.append(y_new)
        if w_new not in W_set:
            W_set.append(w_new)

        # Norm of the gradients left on the parameters by the last roll().
        with torch.no_grad():
            grads = [p.grad.view(-1) for p in model.parameters() if p.grad is not None]
            dk = torch.norm(torch.cat(grads)).item() if grads else 0.0

        # Diagnostics are reported over the full data sets (the same values
        # that were logged before, since the enlarged sets contain the argmax).
        infeas = max(0.0, psi_new - r)
        lambda_k = cmp.constraint.multiplier().item()
        slack = abs(lambda_k * (psi_new - r))

        dk_history_c.append(dk)
        infeas_history_c.append(infeas)
        slack_history_c.append(slack)
        loss_history_c.append(phi_new)

        outer_phi = phi_new - phi_disc   # gap between full and discretised max
        outer_psi = psi_new - r          # constraint violation over the full set

        if k % 1000 == 0:
            print(f"[Adaptive Discretization] Iter {k}: phi_max={phi_new:.4f}, psi_max={psi_new:.4f}, "
                  f"outer_phi={outer_phi:.4f}, outer_psi={outer_psi:.4f}", flush=True)

        if outer_phi <= eps and outer_psi <= eps:
            print("[Adaptive Discretization] Converged at iteration", k)
            break

    return model, Y_set, W_set, dk_history_c, infeas_history_c, slack_history_c, loss_history_c

# Run the adaptive-discretization baseline on a freshly initialised network,
# using the same data and constraint level r as iDB-PD.
model_ad = SimpleNet().to(device)

_ad_outputs = adaptive_discretization_solver(
    model_ad, X, y_true, X_const, y_true_const, r,
    max_iters=3000, eps=1e-3, lr=1e-3,
)
(
    model_ad,
    Y_set,
    W_set,
    dk_history_ad,
    infeas_history_ad,
    slack_history_ad,
    loss_history_ad,
) = _ad_outputs


# Plots: one figure per metric, each comparing iDB-PD with adaptive discretization.
_method_labels = ["iDB-PD", "Adaptive Discretization"]

# (metric name, [iDB-PD series, AD series], y-scale, y-axis label, output file)
_plot_specs = [
    # Infeasibility
    (
        "Infeasibility",
        [infeas_history, infeas_history_ad],
        'log',
        r'Infeasibility ($[\psi(x_k,w_k)]_+$)',
        "infeasibility_ad.png",
    ),
    # Stationarity
    (
        "Stationarity",
        [dk_history, dk_history_ad],
        'log',
        r'Stationarity ($\|d_k\|$)',
        "stationarity_ad.png",
    ),
    # Slackness (only methods that have a multiplier)
    (
        "Slackness",
        [slack_history, slack_history_ad],
        'log',
        r'Slackness ($|\lambda_k \cdot \psi(x_k,w_k)|$)',
        "slackness_ad.png",
    ),
    # Objective training loss (linear axis)
    (
        "Objective Training Loss",
        [loss_history, loss_history_ad],
        None,
        "Objective Training Loss",
        "objective_loss_ad.png",
    ),
]

for _metric, _series, _scale, _axis_label, _out_file in _plot_specs:
    plot_series(
        _metric,
        _series,
        _method_labels,
        yscale=_scale,
        ylabel=_axis_label,
        fname=_out_file,
    )


# Zip & download all plots together.
# `!zip` is an IPython shell escape (not plain Python): recursively (-r) and
# quietly (-q) archive the plots_ad directory that save_and_show writes into.
!zip -r -q plots_ad.zip plots_ad
# Hand the archive to the google.colab `files` helper for download.
files.download('plots_ad.zip')