# -*- coding: utf-8 -*-
"""idbdisc_yeast.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/17tCi4amj-3fKT-765Tm2L4KkKJWMVWgY
"""

#iDB-PD, AD (Yeast)

!pip -q install matplotlib numpy pandas liac-arff torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
!pip -q install git+https://github.com/cooper-org/cooper

import os, time, numpy as np, pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
import matplotlib.pyplot as plt
import arff
from google.colab import drive, files

import cooper
from cooper import ConstrainedMinimizationProblem, Constraint, ConstraintType
from cooper.optim import SimultaneousOptimizer
from cooper.multipliers import DenseMultiplier

# Mount Google Drive and load the Yeast ARFF file (adjust arff_path as needed).
drive.mount('/content/drive')

arff_path = '/content/drive/MyDrive/yeast.arff'

with open(arff_path, 'r') as arff_file:
    dataset = arff.load(arff_file)

# The first element of each entry in dataset['attributes'] is used as the
# column name; dataset['data'] holds the rows. Everything is cast to float32.
column_names = [attribute[0] for attribute in dataset['attributes']]
df = pd.DataFrame(dataset['data'], columns=column_names)
data_np = df.to_numpy().astype(np.float32)

#plot helper function
def save_and_show(fname: str):
    """Save the current matplotlib figure to ``plots_ad/<fname>`` and close it.

    Despite the name, nothing is displayed and nothing is downloaded. The
    figure is written at 300 dpi with a tight bounding box, then closed, and
    the saved path is printed.
    """
    target_dir = "plots_ad"
    os.makedirs(target_dir, exist_ok=True)
    target = os.path.join(target_dir, fname)
    plt.tight_layout()
    plt.savefig(target, dpi=300, bbox_inches="tight")
    plt.close()
    print(f"Saved: {target}")

# YEAST data: the first 103 columns are the input features, the next 14
# columns (103..116) are the labels, split below into two groups of 7.
X1, Y1 = data_np[:, :103], data_np[:, 103:117]

# Group A (label columns 0-6) becomes the objective target and group B
# (label columns 7-13) the constraint target. argmax returns the first
# maximal column in each group.
# NOTE(review): a row with no active label in a group maps to class 0 — confirm intended.
label_A = Y1[:, :7].argmax(axis=1)
label_B = Y1[:, 7:].argmax(axis=1)


# Device and reproducibility.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
torch.manual_seed(123)
np.random.seed(123)

# Objective and constraint share the same inputs; only the targets differ.
X = torch.tensor(X1, dtype=torch.float32, device=device)
y_true = torch.tensor(label_A, dtype=torch.long, device=device)
y_true_const = torch.tensor(label_B, dtype=torch.long, device=device)
X_const = X

n_samples = X.shape[0]        # length of the objective weight vector y
m_samples = X_const.shape[0]  # length of the constraint weight vector w

# Hyper-parameters.
n_classes = 7
input_dim = X.shape[1]
hidden_dim = 150
lambda_reg = 1e-3              # weight of the quadratic penalty on (n*y - 1) / (m*w - 1)
gamma0 = 0.01                  # x step size during constraint pre-training
gamma = 0.001/6                # x step size in the main loop
alpha = 0.06/ gamma            # scale of alpha_k = alpha / (epoch+2)**1.001
epochs = 3000
sigma = 0.1/(lambda_reg * max(m_samples, n_samples))  # step size of the dual (y, w) ascent


# Model
class SimpleNet(nn.Module):
    """Two-layer MLP (Linear -> tanh -> Linear) that returns raw logits.

    Args:
        in_features: input width; defaults to the module-level ``input_dim``.
        hidden_features: hidden width; defaults to the module-level ``hidden_dim``.
        num_classes: number of output logits; defaults to the module-level ``n_classes``.

    ``None`` defaults are resolved when the network is constructed, so
    ``SimpleNet()`` behaves exactly as before while other sizes can be
    requested explicitly.
    """

    def __init__(self, in_features=None, hidden_features=None, num_classes=None):
        super().__init__()
        if in_features is None:
            in_features = input_dim
        if hidden_features is None:
            hidden_features = hidden_dim
        if num_classes is None:
            num_classes = n_classes
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Map ``x`` (last dim = in_features) to logits (last dim = num_classes)."""
        return self.fc2(torch.tanh(self.fc1(x)))

# Network used only for constraint pre-training (which calibrates r).
model0 = SimpleNet().to(device)
# Constraint sample weights w, started at the uniform point of the simplex.
w_weights = torch.ones(m_samples, device=device, requires_grad=True) / m_samples

# Pre-training histories: constraint value and gradient norm per epoch.
loss_history0, gradpsi_history = [], []
def project_onto_simplex(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Map ``z`` onto the probability simplex along ``dim`` using softmax.

    This is not the Euclidean projection. It is the entropic (KL) mirror map
    ``exp(z) / sum(exp(z))``. Callers pass ``log(w) + step * grad``, so one
    call performs an exponentiated-gradient (mirror-ascent) update of ``w``
    that stays on the simplex.
    """
    return z.softmax(dim=dim)

# Pre-training to find the threshold r. Minimise the DRO constraint value
#     psi0(x, w) = <w, L_c(x)> - (lambda_reg/m) * ||m*w - 1||^2
# on its own, with gradient descent in x and entropic ascent in w.
# |psi0| at the last epoch becomes the threshold r for the main run.
start_time = time.time()
pre_epochs = int(epochs/2)
for epoch in range(pre_epochs):
    model0.train()

    # Forward pass for the constraint: per-sample losses L_i(x).
    logitss = model0(X_const)
    loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')

    # DRO constraint value.
    weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
    reg_term_const = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
    total_loss_const = weighted_loss_const - reg_term_const

    # Gradient descent step in x.
    grad_const_x = torch.autograd.grad(total_loss_const, model0.parameters(), retain_graph=False, create_graph=False)
    with torch.no_grad():
        for param, grad2 in zip(model0.parameters(), grad_const_x):
            param -= gamma0 * grad2
        # Gradient norm over ALL parameters. The leftover loop variable
        # `grad2` would only cover the last parameter's gradient.
        grad_norm = torch.norm(torch.cat([g.reshape(-1) for g in grad_const_x])).item()

        # Mk entropic ascent steps in w, using this epoch's pre-step losses.
        # These steps need no autograd graph.
        Mk = int(10*np.ceil(np.log(epoch+2)))
        for _ in range(Mk):
            grad_const_w = loss_const_per_sample - lambda_reg*(m_samples * w_weights - torch.ones(m_samples, device=device))
            w_weights.copy_(project_onto_simplex(torch.log(w_weights) + sigma * grad_const_w))

    # Store plain floats so the histories do not keep device tensors alive.
    loss_history0.append(total_loss_const.item())
    gradpsi_history.append(grad_norm)

    if epoch % 100 == 0 or epoch == pre_epochs - 1:
        print(f"Epoch {epoch}: loss = {total_loss_const.item():.4f} norm_grad = {grad_norm:.4f}")

runtime_pre = time.time() - start_time
r = np.abs(total_loss_const.item())
# Fresh network for the main iDB-PD run.
model = SimpleNet().to(device)
# Objective weights y and constraint weights w, both started uniform.
y_weights = torch.ones(n_samples, device=device, requires_grad=True) / n_samples
w_weights = torch.ones(m_samples, device=device, requires_grad=True) / m_samples

# Per-epoch diagnostics: objective value, ||d_k||, [psi]_+, |lambda_k * psi|.
loss_history, dk_history, infeas_history, slack_history = [], [], [], []
# Main iDB-PD loop. Minimise the DRO objective phi(x, y) subject to the DRO
# constraint psi(x, w) <= 0. Each epoch takes one primal step in x along
#     d_k = grad phi + lam * grad psi / ||grad psi||^2,
# with lam = 0 unless the constraint is violated. It then takes Nk / Mk
# entropic (softmax) ascent steps in y / w at this epoch's pre-step losses.
for epoch in range(epochs):
    model.train()

    # Forward passes: per-sample losses for the objective and the constraint.
    logits = model(X)
    loss_per_sample = F.cross_entropy(logits, y_true, reduction='none')
    logitss = model(X_const)
    loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')

    # DRO objective: <y, L(x)> - (lambda_reg/n) * ||n*y - 1||^2
    weighted_loss = torch.dot(y_weights, loss_per_sample)
    reg_term = lambda_reg / n_samples * torch.sum((n_samples * y_weights - 1) ** 2)
    total_loss = weighted_loss - reg_term

    # DRO constraint: <w, L_c(x)> - (lambda_reg/m) * ||m*w - 1||^2 - r
    weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
    reg_term_const = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
    total_loss_const = weighted_loss_const - reg_term_const - r

    # Gradients in x, flattened into single vectors.
    grad_obj_x = torch.autograd.grad(total_loss, model.parameters(), create_graph=False)
    grad_const_x = torch.autograd.grad(total_loss_const, model.parameters(), create_graph=False)
    grad_obj_vector = torch.cat([g.view(-1) for g in grad_obj_x])
    grad_const_vector = torch.cat([g.view(-1) for g in grad_const_x])
    alphak = alpha / ((epoch+2)**1.001)
    psi_value = total_loss_const.item()

    with torch.no_grad():
        param_vector = torch.cat([p.view(-1) for p in model.parameters()])
        norm_grad2 = torch.norm(grad_const_vector)
        if max(0, psi_value) * norm_grad2 > 0:
            # lam = max(alpha_k * ||grad psi|| - <grad psi, grad phi>, 0).
            # torch.clamp is required: torch.max(t, 0) on a 0-d tensor is a
            # dim-0 reduction that returns t itself, so it never clipped
            # negative values.
            lam = torch.clamp(
                -torch.dot(grad_const_vector, grad_obj_vector) + alphak * norm_grad2,
                min=0.0,
            )
            lam_x_grad2 = (lam * grad_const_vector) / (norm_grad2 ** 2)
        else:
            lam = torch.tensor(0.0, device=device)
            lam_x_grad2 = torch.zeros_like(grad_const_vector)
        direction = grad_obj_vector + lam_x_grad2
        updated_vector = param_vector - gamma * direction

        # Write the flat parameter vector back into the model in place.
        pointer = 0
        for p in model.parameters():
            numel = p.numel()
            p.copy_(updated_vector[pointer:pointer+numel].view_as(p))
            pointer += numel

        # Dual ascent for y (Nk steps) and w (Mk steps) at this epoch's losses.
        Nk = int(2*np.ceil(np.log(epoch+2)))
        for _ in range(Nk):
            grad_obj_y = loss_per_sample - lambda_reg*(n_samples * y_weights - torch.ones(n_samples, device=device))
            y_weights.copy_(project_onto_simplex(torch.log(y_weights) + sigma * grad_obj_y))

        Mk = int(10*np.ceil(np.log(epoch+2)))
        for _ in range(Mk):
            grad_const_w = loss_const_per_sample - lambda_reg*(m_samples * w_weights - torch.ones(m_samples, device=device))
            w_weights.copy_(project_onto_simplex(torch.log(w_weights) + sigma * grad_const_w))

    dk_norm = torch.norm(direction).item()
    infeas = max(0, psi_value)
    loss_history.append(total_loss.item())
    dk_history.append(dk_norm)
    infeas_history.append(infeas)
    slack_history.append(torch.abs(lam * psi_value).item())

    if epoch % 100 == 0 or epoch == epochs - 1:
        print(f"Epoch {epoch}: norm_dk = {dk_norm:.4f} Feasibility = {infeas:.4f}")


#adaptive discretization (COOPER)
class DiscreteMinimaxCMP(ConstrainedMinimizationProblem):
    """Cooper CMP for the discretised min-max problem.

    Objective:  max over i in Y_set of CE(model(X)_i, y_true_i).
    Constraint: max over j in W_set of CE(model(X_const)_j, y_const_j) - r,
    registered as an INEQUALITY constraint with a dense multiplier.
    ``Y_set`` and ``W_set`` are stored by reference, so later in-place changes
    to those lists affect subsequent ``compute_cmp_state`` calls.
    """

    def __init__(self, model, X, y_true, X_const, y_const, Y_set, W_set, r):
        super().__init__()
        self.model = model
        self.X = X
        self.y_true = y_true
        self.Xc = X_const
        self.yc = y_const
        self.Y_set = Y_set
        self.W_set = W_set
        self.r = r
        # One inequality constraint with a single dense Lagrange multiplier.
        self.constraint = Constraint(
            multiplier=DenseMultiplier(num_constraints=1, device=X.device),
            constraint_type=ConstraintType.INEQUALITY
        )

    def compute_cmp_state(self, model=None, inputs=None, targets=None):
        """Build the CMPState from the stored data; the arguments are ignored."""
        objective_losses = F.cross_entropy(self.model(self.X), self.y_true, reduction='none')
        worst_objective = objective_losses[self.Y_set].max()
        constraint_losses = F.cross_entropy(self.model(self.Xc), self.yc, reduction='none')
        worst_constraint = constraint_losses[self.W_set].max() - self.r
        violation_state = cooper.ConstraintState(violation=worst_constraint)
        return cooper.CMPState(
            loss=worst_objective,
            observed_constraints={self.constraint: violation_state},
        )

def adaptive_discretization_solver(model, X, y_true, Xc, yc, r, max_iters=3000, eps=1e-3, lr=1e-3):
    """Adaptive-discretisation baseline solved with Cooper.

    Each iteration takes one Cooper primal-dual step on the discretised
    problem (max over Y_set / W_set). It then adds the worst-case objective
    and constraint samples over the full data to the index sets.

    Args:
        model: network to train (updated in place).
        X, y_true: objective inputs and targets.
        Xc, yc: constraint inputs and targets.
        r: constraint threshold.
        max_iters: maximum number of outer iterations.
        eps: tolerance for both stopping tests.
        lr: learning rate for the primal (Adam) and dual (SGD) optimizers.

    Returns:
        (model, Y_set, W_set, dk_history, infeas_history, slack_history,
        loss_history). The histories hold per iteration: gradient norm,
        [psi_max - r]_+, |lambda * (psi_max - r)|, and the full-data
        worst-case objective loss.
    """
    n_samples, m_samples = X.shape[0], Xc.shape[0]
    dk_history_c, infeas_history_c, slack_history_c, loss_history_c = [], [], [], []
    Y_set = [torch.randint(0, n_samples, (1,)).item()]
    W_set = [torch.randint(0, m_samples, (1,)).item()]

    # Build the CMP and the optimizers once. The CMP holds references to the
    # Y_set / W_set lists, so the in-place appends below are seen by later
    # rolls. Adam's moment estimates and the Lagrange multiplier now persist
    # across iterations instead of being reset every step.
    cmp = DiscreteMinimaxCMP(model, X, y_true, Xc, yc, Y_set, W_set, r)
    primal_opt = torch.optim.Adam(model.parameters(), lr=lr)
    dual_opt = torch.optim.SGD(cmp.dual_parameters(), lr=lr, maximize=True)
    coop_opt = SimultaneousOptimizer(cmp, primal_opt, dual_opt)

    for k in range(max_iters):
        model.train()
        coop_opt.roll(compute_cmp_state_kwargs={})

        model.eval()
        with torch.no_grad():
            full_loss = F.cross_entropy(model(X), y_true, reduction='none')
            full_const = F.cross_entropy(model(Xc), yc, reduction='none')

            # Worst-case samples over the FULL data.
            y_new = torch.argmax(full_loss).item()
            w_new = torch.argmax(full_const).item()
            phi_new = full_loss[y_new].item()
            psi_new = full_const[w_new].item()

            # Worst case over the CURRENT discretisation. This must be
            # measured before augmentation; afterwards it equals phi_new by
            # construction and the stopping test becomes vacuous.
            phi_set = torch.max(full_loss[Y_set]).item()

            # Gradients left in p.grad by the roll.
            # NOTE(review): assumes roll() does not zero them after stepping — confirm.
            grads = [p.grad.view(-1) for p in model.parameters() if p.grad is not None]
            dk = torch.norm(torch.cat(grads)).item() if grads else 0.0

        if y_new not in Y_set:
            Y_set.append(y_new)
        if w_new not in W_set:
            W_set.append(w_new)

        infeas = max(0.0, psi_new - r)
        lambda_k = cmp.constraint.multiplier().item()
        slack = abs(lambda_k * (psi_new - r))

        dk_history_c.append(dk)
        infeas_history_c.append(infeas)
        slack_history_c.append(slack)
        loss_history_c.append(phi_new)

        outer_phi = phi_new - phi_set
        outer_psi = psi_new - r

        if k % 1000 == 0:
            print(f"[Adaptive Discretization] Iter {k}: phi_max={phi_new:.4f}, psi_max={psi_new:.4f}, "
                  f"outer_phi={outer_phi:.4f}, outer_psi={outer_psi:.4f}", flush=True)

        if outer_phi <= eps and outer_psi <= eps:
            print("[Adaptive Discretization] Converged at iteration", k)
            break

    return model, Y_set, W_set, dk_history_c, infeas_history_c, slack_history_c, loss_history_c

# Run the adaptive-discretization baseline from a freshly initialised network.
model_ad = SimpleNet().to(device)
(model_ad, Y_set, W_set,
 dk_history_ad, infeas_history_ad, slack_history_ad, loss_history_ad) = adaptive_discretization_solver(
    model=model_ad, X=X, y_true=y_true, Xc=X_const, yc=y_true_const, r=r,
    max_iters=3000, eps=1e-3, lr=1e-3,
)

#plots
def plot_series(metric_name, ys, labels, yscale='log', ylabel=None, fname=None):
    """Plot each series against its iteration index and save the figure.

    Args:
        metric_name: used as the y-label when ``ylabel`` is falsy, and to
            derive the default file name.
        ys: sequence of series, one per method.
        labels: legend labels, paired with ``ys`` via ``zip``.
        yscale: matplotlib y-axis scale; a falsy value skips ``plt.yscale``.
        ylabel: optional y-axis label.
        fname: output file name; if falsy, defaults to
            ``<metric_name lower-cased, spaces replaced by '_'>_ad.png``.

    NOTE(review): on a 'log' axis matplotlib cannot draw non-positive values,
    e.g. an infeasibility of exactly 0.
    """
    plt.figure()
    for series, label in zip(ys, labels):
        plt.plot(range(len(series)), series, label=label, linewidth=2.0)
    if yscale:
        plt.yscale(yscale)
    plt.xlabel("Iteration", fontsize=14)
    plt.ylabel(ylabel if ylabel else metric_name, fontsize=14)
    grid_which = "both" if yscale == 'log' else "major"
    plt.grid(True, which=grid_which, linestyle="--", alpha=0.7)
    plt.legend()
    if fname:
        out_name = fname
    else:
        out_name = f"{metric_name.lower().replace(' ', '_')}_ad.png"
    save_and_show(out_name)

# Comparison plots (iDB-PD vs Adaptive Discretization), one figure per metric.
_METHOD_LABELS = ["iDB-PD", "Adaptive Discretization"]
_PLOT_SPECS = [
    dict(
        metric_name="Infeasibility",
        ys=[infeas_history, infeas_history_ad],
        yscale='log',
        ylabel=r'Infeasibility ($[\psi(x_k,w_k)]_+$)',
        fname="infeasibility_ad.png",
    ),
    dict(
        metric_name="Stationarity",
        ys=[dk_history, dk_history_ad],
        yscale='log',
        ylabel=r'Stationarity ($\|d_k\|$)',
        fname="stationarity_ad.png",
    ),
    dict(
        metric_name="Slackness",
        ys=[slack_history, slack_history_ad],
        yscale='log',
        ylabel=r'Slackness ($|\lambda_k \cdot \psi(x_k,w_k)|$)',
        fname="slackness_ad.png",
    ),
    dict(
        metric_name="Objective Training Loss",
        ys=[loss_history, loss_history_ad],
        yscale=None,
        ylabel="Objective Training Loss",
        fname="objective_loss_ad.png",
    ),
]
for _spec in _PLOT_SPECS:
    plot_series(labels=_METHOD_LABELS, **_spec)

# Zip & download all plots together
# Archives the plots_ad directory into plots_ad.zip and downloads it through
# Colab's files.download. The `!` line is an IPython shell escape, so this
# cell only runs in a notebook/Colab runtime, not as plain Python.
!zip -r -q plots_ad.zip plots_ad
files.download('plots_ad.zip')