# -*- coding: utf-8 -*-
"""idbGDMA_yeast.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Cur0Rgtm_zi2paKyMP_g3ePKALiKX59o
"""

#iDB-PD, GDMA (Yeast)

!pip -q install matplotlib numpy pandas liac-arff torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
!pip -q install git+https://github.com/cooper-org/cooper

import os, time, numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
import matplotlib.pyplot as plt
import arff
from google.colab import drive, files

import cooper
from cooper import ConstrainedMinimizationProblem, Constraint, ConstraintType
from cooper.optim import SimultaneousOptimizer
from cooper.multipliers import DenseMultiplier

# Load the Yeast ARFF file from Google Drive; change arff_path if the file
# lives somewhere else in your Drive.
drive.mount('/content/drive')

arff_path = '/content/drive/MyDrive/yeast.arff'

with open(arff_path, 'r') as f:
    dataset = arff.load(f)

# Column names are the first entry of each ARFF attribute declaration.
attribute_names = [attr[0] for attr in dataset['attributes']]
df = pd.DataFrame(dataset['data'], columns=attribute_names)
# Every column (inputs and labels alike) is cast to float32.
data_np = df.to_numpy().astype(np.float32)


def save_and_show(fname: str):
    """Write the current matplotlib figure to ``plots_gdma/<fname>`` and close it.

    Despite the name, nothing is displayed (no ``plt.show``) and the file is
    not downloaded individually; the saved path is printed.

    Args:
        fname: File name (including extension) inside ``plots_gdma``.
    """
    out_dir = "plots_gdma"
    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, fname)

    plt.tight_layout()
    plt.savefig(target, dpi=300, bbox_inches="tight")
    plt.close()  # free the figure so repeated plots don't accumulate
    print(f"Saved: {target}")

# YEAST: the first n_features columns are inputs, the next n_labels columns are
# labels. The label block is split into two groups of n_classes labels: the
# first group defines the objective task, the second the constraint task.
n_features = 103
n_labels   = 14
n_classes  = 7      # classes per task (n_labels = 2 * n_classes here)

if data_np.shape[1] < n_features + n_labels:
    raise ValueError(
        f"expected at least {n_features + n_labels} columns "
        f"({n_features} features + {n_labels} labels), got {data_np.shape[1]}"
    )

X1 = data_np[:, :n_features]
Y1 = data_np[:, n_features:n_features + n_labels]

# Collapse each label group to one class with argmax: the first column holding
# the row's maximum wins, so an all-zero group maps to class 0.
label_A = np.argmax(Y1[:, :n_classes], axis=1)   # objective task
label_B = np.argmax(Y1[:, n_classes:], axis=1)   # constraint task

# torch setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
torch.manual_seed(123)
np.random.seed(123)

# tensors (the constraint task reuses the same inputs)
X = torch.tensor(X1, dtype=torch.float32, device=device)
y_true        = torch.tensor(label_A, dtype=torch.long, device=device)
y_true_const  = torch.tensor(label_B, dtype=torch.long, device=device)
X_const = X

n_samples = X.shape[0]        # objective sample count (length of y weights)
m_samples = X_const.shape[0]  # constraint sample count (length of w weights)

# Hyperparameters
input_dim  = X.shape[1]
hidden_dim = 150
lambda_reg = 1e-3      # coefficient of the (N*weights - 1)^2 regulariser on the DRO weights
gamma0     = 0.01      # pre-training step size for x on the constraint
gamma      = 0.001/6   # step size for x in the main runs
alpha      = 0.06/gamma  # scale of alpha_k = alpha / (k + 2)^1.001
epochs     = 3000
sigma      = 0.1/(lambda_reg * max(m_samples, n_samples))  # mirror-ascent step for y, w

# Model
class SimpleNet(nn.Module):
    """Two-layer MLP: Linear -> tanh -> Linear, producing class logits.

    Each dimension is optional and falls back to the module-level
    ``input_dim``, ``hidden_dim`` and ``n_classes`` when omitted, so
    ``SimpleNet()`` behaves exactly as before.

    Args:
        in_dim: Number of input features.
        hid_dim: Width of the hidden tanh layer.
        out_dim: Number of output logits.
    """

    def __init__(self, in_dim=None, hid_dim=None, out_dim=None):
        super().__init__()
        in_dim = input_dim if in_dim is None else in_dim
        hid_dim = hidden_dim if hid_dim is None else hid_dim
        out_dim = n_classes if out_dim is None else out_dim
        # fc1 is created before fc2, as before, so seeded initialisation is unchanged.
        self.fc1 = nn.Linear(in_dim, hid_dim)
        self.fc2 = nn.Linear(hid_dim, out_dim)

    def forward(self, x):
        """Return unnormalised logits of shape ``(..., out_dim)``."""
        x = torch.tanh(self.fc1(x))
        return self.fc2(x)


def project_onto_simplex(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Map ``z`` onto the probability simplex along ``dim`` with a softmax.

    This is not a Euclidean projection. It is the entropic mirror map: feeding
    it ``log(weights) + step * grad`` yields an exponentiated-gradient update
    of ``weights``.
    """
    normalized = z.softmax(dim=dim)
    return normalized

def safe_log_weights(w: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Elementwise natural log of ``w`` with every entry floored at ``eps``.

    The floor keeps zero or underflowed weights finite, so log-space updates
    never see ``-inf``.
    """
    floored = w.clamp(min=eps)
    return floored.log()

# Network used only by the pre-training run that calibrates the threshold r.
model0 = SimpleNet().to(device)
# DRO weights w over the constraint samples, initialised uniform on the
# simplex. They are updated manually by mirror ascent, not by an optimizer.
w_weights = torch.ones(m_samples, device=device, requires_grad=True) / m_samples

# Pre-training histories: constraint value and gradient norm per epoch.
loss_history0 = []
gradpsi_history = []

# Pre-training: minimise the DRO constraint loss alone (gradient descent on x,
# mirror ascent on w) to find the threshold r for the constrained runs.
pretrain_epochs = int(epochs/2)
for epoch in range(pretrain_epochs):
    model0.train()

    # Per-sample constraint losses L_i(x) at the current parameters.
    logitss = model0(X_const)
    loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')

    # DRO constraint value: sum_i w_i L_i(x) - (lambda/m) * sum_i (m w_i - 1)^2.
    weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
    reg_term_const = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
    total_loss_const = weighted_loss_const - reg_term_const

    # Gradient descent step on the network parameters x.
    grad_const_x = torch.autograd.grad(total_loss_const, model0.parameters())
    with torch.no_grad():
        for param, grad in zip(model0.parameters(), grad_const_x):
            param -= gamma0 * grad
        # Norm of the gradient over ALL parameter tensors, not just the last one.
        grad_norm = torch.norm(torch.cat([g.view(-1) for g in grad_const_x])).item()

    # Mirror (exponentiated-gradient) ascent on w. The losses stay fixed at the
    # pre-update x, so a detached copy avoids recording an unused graph per step.
    losses_const = loss_const_per_sample.detach()
    Mk = int(10*np.ceil(np.log(epoch+2)))
    with torch.no_grad():
        for _ in range(Mk):
            grad_const_w = losses_const - lambda_reg*(m_samples * w_weights - 1)
            w_weights.copy_(project_onto_simplex(safe_log_weights(w_weights) + sigma * grad_const_w))

    loss_history0.append(total_loss_const.item())
    gradpsi_history.append(grad_norm)

    if epoch % 100 == 0 or epoch == pretrain_epochs - 1:
        print(f"Epoch {epoch}: loss = {total_loss_const.item():.4f} norm_grad = {grad_norm:.4f}")

# Constraint threshold r: magnitude of the constraint value from the last
# pre-training epoch. The constrained run below targets psi(x, w) - r <= 0.
r = np.abs(total_loss_const.item())
model = SimpleNet().to(device)
# DRO weights on the simplex, both initialised uniform and updated manually by
# mirror ascent: y weights the objective losses, w the constraint losses.
y_weights = torch.ones(n_samples, device=device, requires_grad=True) / n_samples
w_weights = torch.ones(m_samples, device=device, requires_grad=True) / m_samples

# Per-epoch histories: objective value, ||d_k||, [psi - r]_+ and |lam * (psi - r)|.
loss_history = []
dk_history = []
infeas_history = []
slack_history = []
for epoch in range(epochs):
    model.train()

    # Per-sample losses L_i(x) for the objective and constraint tasks.
    logits = model(X)
    loss_per_sample = F.cross_entropy(logits, y_true, reduction='none')
    logitss = model(X_const)
    loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')

    # DRO objective phi(x, y) = sum_i y_i L_i(x) - (lambda/n) * sum_i (n y_i - 1)^2.
    weighted_loss = torch.dot(y_weights, loss_per_sample)
    reg_term = lambda_reg / n_samples * torch.sum((n_samples * y_weights - 1) ** 2)
    total_loss = weighted_loss - reg_term

    # DRO constraint psi(x, w) - r; the iterate is infeasible when this is > 0.
    weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
    reg_term_const = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
    total_loss_const = weighted_loss_const - reg_term_const - r
    infeasibility = max(0, total_loss_const.item())

    # Gradients of objective and constraint with respect to the parameters x.
    grad_obj_x = torch.autograd.grad(total_loss, model.parameters())
    grad_const_x = torch.autograd.grad(total_loss_const, model.parameters())
    alphak = alpha / ((epoch+2)**1.001)
    grad_obj_vector = torch.cat([g.view(-1) for g in grad_obj_x])
    grad_const_vector = torch.cat([g.view(-1) for g in grad_const_x])

    with torch.no_grad():
        param_vector = torch.cat([p.data.view(-1) for p in model.parameters()])
        norm_grad2 = torch.norm(grad_const_vector)
        if infeasibility > 0 and norm_grad2.item() > 0:
            # lam = max(0, alpha_k * ||grad psi|| - <grad psi, grad phi>) ensures
            # <grad psi, d_k> >= alpha_k * ||grad psi||. Use clamp: torch.max(t, 0)
            # would treat 0 as a dim, return (values, indices) and never clip.
            lam = torch.clamp(
                -torch.dot(grad_const_vector, grad_obj_vector) + alphak * norm_grad2,
                min=0.0,
            )
            lam_x_grad2 = (lam * grad_const_vector) / (norm_grad2 ** 2)
        else:
            lam = torch.zeros((), device=grad_const_vector.device)
            lam_x_grad2 = torch.zeros_like(grad_const_vector)
        d_k = grad_obj_vector + lam_x_grad2
        updated_vector = param_vector - gamma * d_k

        # Write the flat updated vector back into each parameter tensor.
        pointer = 0
        for p in model.parameters():
            numel = p.numel()
            p.data.copy_(updated_vector[pointer:pointer+numel].view_as(p))
            pointer += numel
        dk_norm = torch.norm(d_k).item()

    # Mirror (exponentiated-gradient) ascent on y and w. The losses stay fixed at
    # the pre-update x, so detached copies avoid recording an unused graph.
    losses_obj = loss_per_sample.detach()
    losses_const = loss_const_per_sample.detach()
    Nk = int(2*np.ceil(np.log(epoch+2)))
    Mk = int(10*np.ceil(np.log(epoch+2)))
    with torch.no_grad():
        for _ in range(Nk):
            grad_obj_y = losses_obj - lambda_reg*(n_samples * y_weights - 1)
            y_weights.copy_(project_onto_simplex(safe_log_weights(y_weights) + sigma * grad_obj_y))
        for _ in range(Mk):
            grad_const_w = losses_const - lambda_reg*(m_samples * w_weights - 1)
            w_weights.copy_(project_onto_simplex(safe_log_weights(w_weights) + sigma * grad_const_w))

    loss_history.append(total_loss.item())
    dk_history.append(dk_norm)
    infeas_history.append(infeasibility)
    slack_history.append(torch.abs(lam * total_loss_const.item()).item())

    if epoch % 100 == 0 or epoch == epochs - 1:
        print(f"Epoch {epoch}: norm_dk = {dk_norm:.4f} Feasibility = {infeasibility:.4f}")


# GDMA/penalty baselines (rho=1,2,5,10)
def run_penalty_baseline(pen_param, epochs=epochs):
    """Train a fresh SimpleNet with the fixed-multiplier (GDMA) baseline.

    Each epoch takes one step on x along grad phi(x, y) + pen_param * grad psi(x, w),
    applied whether or not the constraint is violated. It then runs Nk
    mirror-ascent steps on the objective weights y and Mk on the constraint
    weights w, with the same schedules as the iDB-PD loop.

    Args:
        pen_param: Constant weight rho applied to the constraint gradient.
        epochs: Number of outer iterations (defaults to the global ``epochs``).

    Returns:
        (loss_hist, infeas_hist, dk_hist): per-epoch objective value,
        ``max(0, psi - r)``, and the norm of the x-update direction.
    """
    model_p = SimpleNet().to(device)
    # Uniform simplex weights for the objective (y) and constraint (w) samples.
    y_w = torch.ones(n_samples, device=device, requires_grad=True) / n_samples
    w_w = torch.ones(m_samples, device=device, requires_grad=True) / m_samples

    loss_hist, infeas_hist, dk_hist = [], [], []
    for epoch in range(epochs):
        model_p.train()
        logits = model_p(X)
        loss_per_sample = F.cross_entropy(logits, y_true, reduction='none')
        logitss = model_p(X_const)
        loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')

        weighted_loss = torch.dot(y_w, loss_per_sample)
        reg_term      = lambda_reg / n_samples * torch.sum((n_samples * y_w - 1) ** 2)
        total_loss    = weighted_loss - reg_term

        weighted_loss_const = torch.dot(w_w, loss_const_per_sample)
        reg_term_const      = lambda_reg / m_samples * torch.sum((m_samples * w_w - 1) ** 2)
        total_loss_const    = weighted_loss_const - reg_term_const - r

        grad_obj_x    = torch.autograd.grad(total_loss,       model_p.parameters())
        grad_const_x  = torch.autograd.grad(total_loss_const, model_p.parameters())
        grad_obj_vec   = torch.cat([g.view(-1) for g in grad_obj_x])
        grad_const_vec = torch.cat([g.view(-1) for g in grad_const_x])
        direction = grad_obj_vec + pen_param * grad_const_vec

        with torch.no_grad():
            param_vec = torch.cat([p.data.view(-1) for p in model_p.parameters()])
            updated_vec = param_vec - gamma * direction
            # Write the flat updated vector back into each parameter tensor.
            ptr = 0
            for p in model_p.parameters():
                numel = p.numel()
                p.data.copy_(updated_vec[ptr:ptr+numel].view_as(p))
                ptr += numel

        # Mirror ascent on y and w from the losses at the pre-update x. Detached
        # copies keep the inner steps from recording an unused autograd graph.
        losses_obj = loss_per_sample.detach()
        losses_const = loss_const_per_sample.detach()
        Nk = int(2*np.ceil(np.log(epoch+2)))
        Mk = int(10*np.ceil(np.log(epoch+2)))
        with torch.no_grad():
            for _ in range(Nk):
                grad_obj_y = losses_obj - lambda_reg*(n_samples * y_w - 1)
                y_w.copy_(project_onto_simplex(safe_log_weights(y_w) + sigma * grad_obj_y))
            for _ in range(Mk):
                grad_const_w = losses_const - lambda_reg*(m_samples * w_w - 1)
                w_w.copy_(project_onto_simplex(safe_log_weights(w_w) + sigma * grad_const_w))

        loss_hist.append(total_loss.item())
        infeas_hist.append(max(0.0, total_loss_const.item()))
        dk_hist.append(torch.norm(direction).item())

        if epoch % 1000 == 0 or epoch == epochs - 1:
            print(f"[GDMA ρ={pen_param}] Epoch {epoch}: ||dk||={dk_hist[-1]:.4f} Infeas={infeas_hist[-1]:.4f}", flush=True)

    return loss_hist, infeas_hist, dk_hist

# Penalty baselines for rho = 1, 2, 5, 10, stored under the numbered names the
# plotting cells read (suffix 2..5 in rho order).
loss_history2, infeas_history2, dk_history2 = run_penalty_baseline(1)
loss_history3, infeas_history3, dk_history3 = run_penalty_baseline(2)
loss_history4, infeas_history4, dk_history4 = run_penalty_baseline(5)
loss_history5, infeas_history5, dk_history5 = run_penalty_baseline(10)
# Alias kept so the older name still refers to the same list.
infeasibility5 = infeas_history5


#plots
def plot_series(metric_name, ys, labels, yscale='log', ylabel=None, fname=None):
    """Plot each history in ``ys`` against its iteration index and save the figure.

    Args:
        metric_name: Metric title. It is the fallback y-label and the basis for
            the default file name.
        ys: Sequence of per-iteration value sequences, one curve each.
        labels: Legend labels, paired with ``ys`` (zip stops at the shorter).
        yscale: Matplotlib y-axis scale name; a falsy value keeps the default.
        ylabel: Optional y-axis label overriding ``metric_name``.
        fname: Output file name; defaults to ``<metric_name>_gdma.png`` with
            spaces replaced by underscores and lower-cased.
    """
    plt.figure()
    for series, label in zip(ys, labels):
        plt.plot(range(len(series)), series, label=label, linewidth=2.0)
    if yscale:
        plt.yscale(yscale)
    plt.xlabel("Iteration", fontsize=14)
    plt.ylabel(ylabel if ylabel else metric_name, fontsize=14)
    grid_which = "both" if yscale == 'log' else "major"
    plt.grid(True, which=grid_which, linestyle="--", alpha=0.7)
    plt.legend()
    if not fname:
        fname = f"{metric_name.lower().replace(' ', '_')}_gdma.png"
    save_and_show(fname)

# Legend labels, in the same order as the histories handed to plot_series:
# the iDB-PD run first, then the penalty baselines.
labels_all = ["iDB-PD"] + [f"GDMA (ρ={rho})" for rho in (1, 2, 5, 10)]

# One entry per figure: (metric name, histories, y-scale, y-axis label, file name).
figure_specs = [
    (
        "Infeasibility",
        [infeas_history, infeas_history2, infeas_history3, infeas_history4, infeas_history5],
        'log',
        r'Infeasibility ($[\psi(x_k,w_k)]_+$)',
        "infeasibility_gdma.png",
    ),
    (
        "Stationarity",
        [dk_history, dk_history2, dk_history3, dk_history4, dk_history5],
        'log',
        r'Stationarity ($\|d_k\|$)',
        "stationarity_gdma.png",
    ),
    (
        "Objective Training Loss",
        [loss_history, loss_history2, loss_history3, loss_history4, loss_history5],
        None,
        "Objective Training Loss",
        "objective_loss_gdma.png",
    ),
]

for metric, histories, scale, axis_label, out_name in figure_specs:
    plot_series(metric, histories, labels_all, yscale=scale, ylabel=axis_label, fname=out_name)

# Bundle every figure saved under plots_gdma/ into one archive (IPython shell
# escape) and trigger a single browser download via google.colab.files.
!zip -r -q plots_gdma.zip plots_gdma
files.download('plots_gdma.zip')