# -*- coding: utf-8 -*-
"""idbGDMA_mMNIST.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1xgMUNTwUPUhj-szpw4rw9tsgOvaBgCoM
"""

# iDB-PD + GDMA variants (rho = 1,2,5,10)
#double MNIST

# PyTorch, matplotlib, etc.
# NOTE: the leading "!" is IPython/Colab shell-escape syntax, so this line
# only works when the file is executed as a notebook cell, not as a plain
# Python script. It installs the packages from the CUDA 12.1 PyTorch index.
!pip -q install matplotlib numpy torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121

#import relevant packages etc
import os, time, pickle, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import matplotlib.pyplot as plt
from google.colab import files
from google.colab import drive

#connect google drive to get data
#change as needed to load in data
# The mount must happen before the file is opened because data_path lives
# under the mounted Drive directory.
drive.mount('/content/drive')
data_path = "/content/drive/MyDrive/multi_mnist.pickle"
# NOTE(security): pickle.load can execute arbitrary code embedded in the
# file; only load pickles from a source you trust.
with open(data_path, 'rb') as f:
    data = pickle.load(f)

#setup
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # harmless in Colab
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
# Fix both RNG seeds so weight initialisation is repeatable between runs.
torch.manual_seed(123)
np.random.seed(123)

#helper for saving images
def save_and_show(fname: str):
    """Write the current matplotlib figure to ``plots_gdma/<fname>`` and show it.

    The output directory is created on demand. The figure is saved at
    300 dpi with a tight bounding box, then displayed, and the destination
    path is printed.

    Args:
        fname: File name (including extension) to use inside ``plots_gdma``.
    """
    out_dir = "plots_gdma"
    plt.tight_layout()
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, fname)
    plt.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.show()
    print(f"Saved: {out_path}")

# Load data
# data[2] is flattened to one row per sample; data[3] is indexed as a
# two-column label array (column 0 -> objective task, column 1 -> constraint
# task). Presumably these are the images and the two digit labels of the
# multi-MNIST set -- verify against whoever produced the pickle.
X_np = data[2]
n_samples = X_np.shape[0]
X = torch.tensor(X_np.reshape(n_samples, -1), dtype=torch.float32, device=device)
y_true = torch.tensor(data[3][:, 0], dtype=torch.long, device=device)

# The constraint task is evaluated on exactly the same inputs as the
# objective task. Nothing in this script modifies X in place, so share the
# tensor instead of materialising a second full copy on the device.
X_const = X
m_samples = X_const.shape[0]
y_true_const = torch.tensor(data[3][:, 1], dtype=torch.long, device=device)

#parameters of problem
n_classes  = 10            # number of output logits of SimpleNet
input_dim  = X.shape[1]    # flattened feature length of one sample
hidden_dim = 1500          # width of SimpleNet's hidden layer

lambda_reg = 1e-3          # coefficient of the quadratic penalty on the sample weights
gamma0     = 0.002         # parameter step size used during pretraining
gamma      = 0.001 /5.5    # parameter step size used by iDB-PD and the GDMA baseline
alpha      = 0.3/ gamma    # numerator of the decaying alphak schedule in iDB-PD
epochs     = 3000          # outer iterations (pretraining runs int(epochs/3))
sigma      = 0.5/ (lambda_reg * max(m_samples, n_samples))  # step size of the sample-weight updates

#NN model
class SimpleNet(nn.Module):
    """Fully connected classifier with one tanh hidden layer.

    Args:
        in_features: Length of one flattened input sample. When omitted the
            module-level ``input_dim`` is used.
        hidden_features: Width of the hidden layer. When omitted the
            module-level ``hidden_dim`` is used.
        out_features: Number of output logits. When omitted the module-level
            ``n_classes`` is used.
    """

    def __init__(self, in_features=None, hidden_features=None, out_features=None):
        super().__init__()
        # Resolve the defaults lazily so that ``SimpleNet()`` keeps reading
        # the script-level configuration, while explicit sizes make the class
        # usable (and testable) without those globals.
        if in_features is None:
            in_features = input_dim
        if hidden_features is None:
            hidden_features = hidden_dim
        if out_features is None:
            out_features = n_classes
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Return the raw (unnormalised) class logits for the batch ``x``."""
        x = torch.tanh(self.fc1(x))
        return self.fc2(x)

#helper functions for the sample-weight updates
def project_onto_simplex(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Map scores ``z`` onto the probability simplex with a softmax.

    Note that, despite the name, this is a softmax normalisation along
    ``dim`` and not a Euclidean projection: the result is strictly positive
    and sums to one along that dimension.

    Args:
        z: Tensor of unnormalised scores.
        dim: Dimension along which to normalise.

    Returns:
        Tensor of the same shape as ``z``.
    """
    return z.softmax(dim=dim)

def safe_log_weights(w: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Return the element-wise natural log of ``w`` after flooring it at ``eps``.

    Entries below ``eps`` (including zeros and negatives) are raised to
    ``eps`` first, so the result is always finite.

    Args:
        w: Tensor of weights.
        eps: Lower bound applied to ``w`` before taking the logarithm.

    Returns:
        Tensor of the same shape as ``w``.
    """
    floored = w.clamp(min=eps)
    return floored.log()

#pre-train to learn threshold r
# Alternates one gradient-descent step on the network parameters with several
# sample-weight ascent steps on the constraint task alone. The absolute value
# of the final pretraining objective becomes the constraint threshold r that
# the later stages subtract from the constraint value.
model0 = SimpleNet().to(device)
w_weights = torch.ones(m_samples, device=device, requires_grad=True) / m_samples
loss_history0, gradpsi_history = [], []
n_pretrain_epochs = int(epochs/3)

for epoch in range(n_pretrain_epochs):
    model0.train()
    logitss = model0(X_const)
    loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')

    # Weighted loss minus the quadratic penalty that keeps the weights close
    # to the uniform distribution 1/m_samples.
    weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
    reg_term_const = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
    total_loss_const = weighted_loss_const - reg_term_const

    grad_const_x = torch.autograd.grad(total_loss_const, model0.parameters(), retain_graph=False, create_graph=False)
    with torch.no_grad():
        for p, g in zip(model0.parameters(), grad_const_x):
            p -= gamma0 * g
        # Norm of the gradient over ALL parameters. Tracking only the last
        # parameter tensor's gradient would under-report stationarity.
        grad_psi_norm = torch.norm(torch.cat([g.reshape(-1) for g in grad_const_x]))

    # Number of weight-ascent steps grows logarithmically with the epoch.
    Mk = int(10*np.ceil(np.log(epoch+2)))
    for _ in range(Mk):
        # The per-sample losses are those computed before the parameter step.
        grad_const_w = loss_const_per_sample - lambda_reg*(m_samples * w_weights - torch.ones(m_samples, device=device))
        with torch.no_grad():
            w_weights.copy_(project_onto_simplex(safe_log_weights(w_weights) + sigma * grad_const_w))

    loss_history0.append(total_loss_const.item())
    gradpsi_history.append(grad_psi_norm.item())
    if epoch % 100 == 0 or epoch == n_pretrain_epochs - 1:
        print(f"[Pretrain] Epoch {epoch}: loss={total_loss_const.item():.4f}", flush=True)

r = float(abs(total_loss_const.item()))

# iDB-PD
# Trains a fresh network on the objective task (labels y_true) subject to the
# constraint task (labels y_true_const) staying below the pretrained
# threshold r. Each iteration performs one parameter step followed by several
# sample-weight updates for both weight vectors.
model = SimpleNet().to(device)
# Both weight vectors start at the uniform distribution.
y_weights = torch.ones(n_samples, device=device, requires_grad=True) / n_samples
w_weights = torch.ones(m_samples, device=device, requires_grad=True) / m_samples

# Per-iteration objective value, norm of the step direction, and positive
# part of the constraint value.
loss_history, dk_history, infeas_history = [], [], []

for epoch in range(epochs):
    model.train()
    # Two separate forward passes so each loss owns its own autograd graph.
    logits = model(X); loss_per_sample = F.cross_entropy(logits, y_true, reduction='none')
    logitss = model(X_const); loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')

    # Objective: weighted loss minus a quadratic penalty on the deviation of
    # the weights from uniform.
    weighted_loss = torch.dot(y_weights, loss_per_sample)
    reg_term     = lambda_reg / n_samples * torch.sum((n_samples * y_weights - 1) ** 2)
    total_loss   = weighted_loss - reg_term

    # Constraint: same construction on the second task, shifted by r so that
    # a positive value means the constraint is violated.
    weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
    reg_term_const      = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
    total_loss_const    = weighted_loss_const - reg_term_const - r

    # Gradients with respect to the network parameters only, flattened into
    # single vectors.
    grad_obj_x    = torch.autograd.grad(total_loss,       model.parameters(), create_graph=False)
    grad_const_x  = torch.autograd.grad(total_loss_const, model.parameters(), create_graph=False)
    alphak        = alpha / ((epoch+2)**1.001)  # decaying schedule
    grad_obj_vec   = torch.cat([g.view(-1) for g in grad_obj_x])
    grad_const_vec = torch.cat([g.view(-1) for g in grad_const_x])

    with torch.no_grad():
        param_vec = torch.cat([p.data.view(-1) for p in model.parameters()])
        norm_g2   = torch.norm(grad_const_vec)
        # The correction term is active only when the constraint is violated
        # and its gradient is non-zero (which also guards the division below).
        if max(0.0, total_loss_const.item()) * norm_g2 > 0:
            # lam = max(-<grad_const, grad_obj> + alphak * ||grad_const||, 0)
            lam = torch.max(-torch.dot(grad_const_vec, grad_obj_vec) + alphak * norm_g2, torch.tensor(0.0, device=device))
            lam_x_g2 = (lam * grad_const_vec) / (norm_g2 ** 2)
        else:
            lam = torch.tensor(0.0, device=device)
            lam_x_g2 = torch.zeros_like(grad_const_vec)

        # Step along the combined direction, then scatter the flat vector back
        # into the individual parameter tensors.
        updated_vec = param_vec - gamma * (grad_obj_vec + lam_x_g2)
        ptr = 0
        for p in model.parameters():
            n = p.numel()
            p.data.copy_(updated_vec[ptr:ptr+n].view_as(p))
            ptr += n

    # Sample-weight updates. The step counts grow logarithmically with the
    # epoch, and the per-sample losses used here are the ones computed before
    # the parameter step above.
    Nk = int(2*np.ceil(np.log(epoch+2)))
    for _ in range(Nk):
        grad_obj_y = loss_per_sample - lambda_reg*(n_samples * y_weights - torch.ones(n_samples, device=device))
        with torch.no_grad():
            # softmax(log y + sigma * grad) is a multiplicative-weights update.
            y_weights.copy_(project_onto_simplex(safe_log_weights(y_weights) + sigma * grad_obj_y))

    Mk = int(10*np.ceil(np.log(epoch+2)))
    for _ in range(Mk):
        grad_const_w = loss_const_per_sample - lambda_reg*(m_samples * w_weights - torch.ones(m_samples, device=device))
        with torch.no_grad():
            w_weights.copy_(project_onto_simplex(safe_log_weights(w_weights) + sigma * grad_const_w))

    # All three metrics are evaluated at the pre-update iterate.
    loss_history.append(total_loss.item())
    dk_history.append(torch.norm(grad_obj_vec + lam_x_g2).item())
    infeas_history.append(max(0.0, total_loss_const.item()))

    if epoch % 100 == 0 or epoch == epochs - 1:
        print(f"[iDB-PD] Epoch {epoch}: ||dk||={dk_history[-1]:.4f} Infeas={infeas_history[-1]:.4f}", flush=True)


#baseline method: GDMA with varying penalty (rho=1,2,5,10)

def run_penalty_baseline(pen_param, epochs=epochs):
    """Train a fresh ``SimpleNet`` with the fixed-penalty GDMA baseline.

    Every iteration takes one gradient step on the network parameters along
    ``grad(objective) + pen_param * grad(constraint)`` and then performs
    several multiplicative sample-weight updates for the objective weights
    and the constraint weights.

    Args:
        pen_param: Fixed multiplier (rho) applied to the constraint gradient.
        epochs: Number of outer iterations.

    Returns:
        Tuple ``(loss_hist, infeas_hist, dk_hist)`` of per-iteration lists:
        the objective value, the positive part of the constraint value, and
        the norm of the combined step direction.
    """
    net = SimpleNet().to(device)
    y_w = torch.ones(n_samples, device=device, requires_grad=True) / n_samples
    w_w = torch.ones(m_samples, device=device, requires_grad=True) / m_samples

    def flatten(tensors):
        # Concatenate a sequence of tensors into one 1-D vector.
        return torch.cat([t.view(-1) for t in tensors])

    def ascend_weights(weights, per_sample_loss, count, n_steps):
        # Run n_steps in-place weight updates using the given (fixed) losses.
        for _ in range(n_steps):
            ascent_grad = per_sample_loss - lambda_reg*(count * weights - torch.ones(count, device=device))
            with torch.no_grad():
                weights.copy_(project_onto_simplex(safe_log_weights(weights) + sigma * ascent_grad))

    loss_hist, infeas_hist, dk_hist = [], [], []
    for epoch in range(epochs):
        net.train()
        obj_per_sample = F.cross_entropy(net(X), y_true, reduction='none')
        const_per_sample = F.cross_entropy(net(X_const), y_true_const, reduction='none')

        # Objective and constraint values: weighted loss minus the quadratic
        # penalty on the weights; the constraint is additionally shifted by r.
        obj_value = (
            torch.dot(y_w, obj_per_sample)
            - lambda_reg / n_samples * torch.sum((n_samples * y_w - 1) ** 2)
        )
        const_value = (
            torch.dot(w_w, const_per_sample)
            - lambda_reg / m_samples * torch.sum((m_samples * w_w - 1) ** 2)
            - r
        )

        obj_grad = flatten(torch.autograd.grad(obj_value, net.parameters(), create_graph=False))
        const_grad = flatten(torch.autograd.grad(const_value, net.parameters(), create_graph=False))
        direction = obj_grad + pen_param * const_grad

        with torch.no_grad():
            params = list(net.parameters())
            stepped = torch.cat([p.data.view(-1) for p in params]) - gamma * direction
            # Hand each parameter tensor its slice of the updated flat vector.
            for p, chunk in zip(params, torch.split(stepped, [p.numel() for p in params])):
                p.data.copy_(chunk.view_as(p))

        # Step counts grow logarithmically with the epoch; the losses used
        # are the ones computed before the parameter step.
        ascend_weights(y_w, obj_per_sample, n_samples, int(2*np.ceil(np.log(epoch+2))))
        ascend_weights(w_w, const_per_sample, m_samples, int(10*np.ceil(np.log(epoch+2))))

        loss_hist.append(obj_value.item())
        infeas_hist.append(max(0.0, const_value.item()))
        dk_hist.append(torch.norm(direction).item())

        is_last = epoch == epochs - 1
        if epoch % 1000 == 0 or is_last:
            print(f"[GDMA ρ={pen_param}] Epoch {epoch}: ||dk||={dk_hist[-1]:.4f} Infeas={infeas_hist[-1]:.4f}", flush=True)

    return loss_hist, infeas_hist, dk_hist

# Run the GDMA baseline once per penalty value (rho = 1, 2, 5, 10), in that
# order, and bind each run's (loss, infeasibility, ||dk||) histories to the
# numbered names used by the plotting code.
(
    (loss_history2, infeas_history2, dk_history2),
    (loss_history3, infeas_history3, dk_history3),
    (loss_history4, infeas_history4, dk_history4),
    (loss_history5, infeas_history5, dk_history5),
) = [run_penalty_baseline(rho) for rho in (1, 2, 5, 10)]
# second name for the rho=10 infeasibility history
infeasibility5 = infeas_history5

#plots
def plot_series(metric_name, ys, labels, yscale='log', ylabel=None, fname=None):
    """Plot several per-iteration curves on one figure, then save and show it.

    Args:
        metric_name: Name of the metric; used as the y-axis label and to
            derive the output file name when those are not given.
        ys: Sequence of curves, each a sequence of values indexed by iteration.
        labels: Legend label for each curve, paired with ``ys`` in order.
        yscale: Matplotlib y-axis scale; a falsy value keeps the default scale.
        ylabel: Y-axis label; falls back to ``metric_name``.
        fname: Output file name; falls back to a name built from
            ``metric_name``.
    """
    plt.figure()
    for series, label in zip(ys, labels):
        plt.plot(range(len(series)), series, label=label, linewidth=2.0)
    if yscale:
        plt.yscale(yscale)

    plt.xlabel("Iteration", fontsize=14)
    axis_label = ylabel or metric_name
    plt.ylabel(axis_label, fontsize=14)

    # Minor grid lines are only drawn on a log axis.
    grid_which = "both" if yscale == 'log' else "major"
    plt.grid(True, which=grid_which, linestyle="--", alpha=0.7)
    plt.legend()

    out_name = fname if fname else f"{metric_name.lower().replace(' ', '_')}_gdma.png"
    save_and_show(out_name)

# Legend entries, in the same order as the curves of every plot below.
labels_all = [
    "iDB-PD",
    "GDMA (ρ=1)",
    "GDMA (ρ=2)",
    "GDMA (ρ=5)",
    "GDMA (ρ=10)"
]

# One entry per figure: (metric name, curves, y-scale, y-axis label, file name).
plot_specs = [
    # Infeasibility
    (
        "Infeasibility",
        [infeas_history, infeas_history2, infeas_history3, infeas_history4, infeas_history5],
        'log',
        r'Infeasibility ($[\psi(x_k,w_k)]_+$)',
        "infeasibility_gdma.png",
    ),
    # Stationarity
    (
        "Stationarity",
        [dk_history, dk_history2, dk_history3, dk_history4, dk_history5],
        'log',
        r'Stationarity ($\|d_k\|$)',
        "stationarity_gdma.png",
    ),
    # Objective Training Loss (linear axis)
    (
        "Objective Training Loss",
        [loss_history, loss_history2, loss_history3, loss_history4, loss_history5],
        None,
        "Objective Training Loss",
        "objective_loss_gdma.png",
    ),
]

for metric_title, curves, scale, axis_label, out_name in plot_specs:
    plot_series(
        metric_title,
        curves,
        labels_all,
        yscale=scale,
        ylabel=axis_label,
        fname=out_name,
    )

# Zip & download all plots together
# NOTE: the leading "!" is IPython/Colab shell-escape syntax, so the zip step
# only works when this file is run as a notebook cell. It archives the
# plots_gdma directory written by save_and_show; files.download then hands
# the archive to the browser via the google.colab helper.
!zip -r -q plots_gdma.zip plots_gdma
files.download('plots_gdma.zip')