# -*- coding: utf-8 -*-
"""idbGDMA_CHD49.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1m4dY2LoBhbdgwCP3a9ZfaS3NzfDdnJX2
"""

# iDB-PD compared against GDMA penalty variants (rho = 1, 2, 5, 10)
# on the CHD49 dataset

!pip -q install matplotlib numpy torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu121
!pip -q install git+https://github.com/cooper-org/cooper

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import scipy.io
import os
import time
from google.colab import files

import cooper
from cooper import ConstrainedMinimizationProblem, Constraint, ConstraintType
from cooper.optim import SimultaneousOptimizer
from cooper.multipliers import DenseMultiplier

# Load the dataset from Google Drive.
# Change mat_path below if the .mat file is stored somewhere else.
from google.colab import drive
drive.mount('/content/drive')

mat_path = '/content/drive/MyDrive/CHD_49.mat'

# Load the .mat file
mat = scipy.io.loadmat(mat_path)

# Inspect keys
print("Available keys:", mat.keys())

# Seed the PyTorch and NumPy random number generators
torch.manual_seed(2025)
np.random.seed(2025)

#helper function to save plots later
def save_and_show(fname: str):
    """Save the current matplotlib figure as ``plots_gdma/<fname>`` and close it.

    Despite the name, the figure is never displayed: it is only written to
    disk (300 dpi, tight bounding box) and then closed. The output directory
    is created on demand.

    Args:
        fname: File name (including extension) to use inside ``plots_gdma``.
    """
    target_dir = "plots_gdma"
    os.makedirs(target_dir, exist_ok=True)
    out_path = os.path.join(target_dir, fname)

    plt.tight_layout()
    plt.savefig(out_path, dpi=300, bbox_inches="tight")
    plt.close()

    print(f"Saved: {out_path}")

# Features are read from the 'data' key and the -1/+1 coded label matrix
# from the 'targets' key of the loaded .mat file.
X = mat['data']
Y = mat['targets']

# Convert from -1/1 to 0/1
Y_binary = (Y + 1) // 2

# Convert to two multiclass labels: the first three label columns give the
# objective task (label_A), the remaining columns the constraint task (label_B).
label_A = np.argmax(Y_binary[:, :3], axis=1)
label_B = np.argmax(Y_binary[:, 3:], axis=1)
# Objective and constraint use the same samples, so both counts coincide.
n_samples = label_A.shape[0]
m_samples = label_A.shape[0]
# Convert to PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
label_A_tensor = torch.tensor(label_A, dtype=torch.long)
label_B_tensor = torch.tensor(label_B, dtype=torch.long)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# X is already a float32 tensor at this point: move it instead of wrapping it
# in torch.tensor() again, which copy-constructs and raises a UserWarning.
X = X.to(device)

y_true = torch.tensor(label_A, dtype=torch.long).to(device)

# Constraint inputs are an independent copy of the same features.
X_const = X.clone()

y_true_const = torch.tensor(label_B, dtype=torch.long).to(device)


# Config
n_classes = 3
input_dim = X.shape[1]
hidden_dim = 60
lambda_reg = 1e-3   # weight of the quadratic regulariser on the sample weights
gamma0 = 0.001      # step size for the network parameters during pre-training
gamma = 0.001/22    # step size for the network parameters in the main runs
alpha = 0.25/gamma  # base value of alphak, decayed as alpha / (epoch+2)**1.001
epochs = 3000
sigma = 0.4/(lambda_reg*max(m_samples,n_samples))  # step size of the sample-weight updates


# Model
class SimpleNet(nn.Module):
    """Two-layer perceptron: Linear -> tanh -> Linear, returning raw logits.

    Layer sizes are taken from the module-level ``input_dim``, ``hidden_dim``
    and ``n_classes`` at construction time.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        hidden = self.fc1(x).tanh()
        logits = self.fc2(hidden)
        return logits

# Network used only by the pre-training loop that determines the threshold r
model0 = SimpleNet().to(device)
# Constraint sample weights w, initialised uniformly (1/m_samples each).
# NOTE(review): dividing the requires_grad tensor makes w_weights a non-leaf
# tensor; it is only ever updated in place under torch.no_grad().
w_weights = torch.ones(m_samples, device=device, requires_grad=True) / m_samples

# Histories recorded during pre-training (loss value and gradient norm)
loss_history0 = []
gradpsi_history = []
# Map onto the probability simplex via softmax (not a Euclidean projection)
def project_onto_simplex(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Normalise ``z`` onto the probability simplex with a softmax along ``dim``.

    This is a softmax, not a Euclidean projection. The callers in this file
    pass ``log(weights) + step * gradient``, so the result is the weights
    rescaled by ``exp(step * gradient)`` and renormalised to sum to one.

    Args:
        z: Unnormalised log-weights.
        dim: Dimension along which the output sums to one.

    Returns:
        A tensor of the same shape as ``z`` with non-negative entries that
        sum to one along ``dim``.
    """
    simplex_point = z.softmax(dim=dim)
    return simplex_point

# Pre-training to find the threshold r for the constraint:
# run descent on the network / ascent on the weights w for the constraint
# task alone, and use the last recorded loss magnitude as r.
n_pretrain_epochs = int(epochs / 2)
for epoch in range(n_pretrain_epochs):
    model0.train()

    # Forward pass for constraint
    logitss = model0(X_const)
    loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')  # L_i(x)

    # DRO constraint: weighted loss minus quadratic regulariser on w
    weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
    reg_term_const = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
    total_loss_const = weighted_loss_const - reg_term_const

    # Backward for x
    grad_const_x = torch.autograd.grad(total_loss_const, model0.parameters(), retain_graph=False, create_graph=False)
    # Norm of the FULL parameter gradient. Taking the norm of the loop
    # variable after the update loop would only cover the last parameter
    # tensor (the output-layer bias).
    grad_norm = torch.norm(torch.cat([g.reshape(-1) for g in grad_const_x])).item()

    with torch.no_grad():
        for param, grad2 in zip(model0.parameters(), grad_const_x):
            param -= gamma0 * grad2

    # Backward for w (gradient ascent), Mk inner steps with the network's
    # per-sample losses held fixed at their pre-update values.
    Mk = int(10*np.ceil(np.log(epoch+2)))
    with torch.no_grad():
        for i in range(Mk):
            grad_const_w = loss_const_per_sample - lambda_reg*(m_samples * w_weights - torch.ones(m_samples,device=device))
            w_weights.copy_(project_onto_simplex(torch.log(w_weights) + sigma * grad_const_w))

    loss_history0.append(total_loss_const.item())
    # Store a plain float so the history does not hold on to device tensors.
    gradpsi_history.append(grad_norm)

    if epoch % 100 == 0 or epoch == n_pretrain_epochs - 1:
        print(f"Epoch {epoch}: loss = {total_loss_const.item():.4f} norm_grad = {grad_norm:.4f}")

# Constraint threshold: magnitude of the last pre-training loss value.
r = np.abs(total_loss_const.item())
model = SimpleNet().to(device)
# Sample weights on the simplex, initialised uniformly:
# y weights the objective samples, w weights the constraint samples.
y_weights = torch.ones(n_samples, device=device, requires_grad=True) / n_samples
w_weights = torch.ones(m_samples, device=device, requires_grad=True) / m_samples

# Per-iteration diagnostics
loss_history = []
dk_history = []
infeas_history = []
slack_history = []

#idb-PD
for epoch in range(epochs):
    model.train()

    # Forward pass for objective
    logits = model(X)
    loss_per_sample = F.cross_entropy(logits, y_true, reduction='none')

    # Forward pass for constraint
    logitss = model(X_const)
    loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')

    # DRO objective
    weighted_loss = torch.dot(y_weights, loss_per_sample)
    reg_term = lambda_reg / n_samples * torch.sum((n_samples * y_weights - 1) ** 2)
    total_loss = weighted_loss - reg_term

    # DRO constraint (feasible when total_loss_const <= 0)
    weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
    reg_term_const = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
    total_loss_const = weighted_loss_const - reg_term_const - r

    # Backward for x
    grad_obj_x = torch.autograd.grad(total_loss, model.parameters(), create_graph=False)
    grad_const_x = torch.autograd.grad(total_loss_const, model.parameters(), create_graph=False)
    alphak = alpha / ((epoch+2)**1.001)
    grad_obj_vector = torch.cat([g.view(-1) for g in grad_obj_x])
    grad_const_vector = torch.cat([g.view(-1) for g in grad_const_x])

    with torch.no_grad():
        param_vector = torch.cat([p.data.view(-1) for p in model.parameters()])
        norm_grad2 = torch.norm(grad_const_vector)
        infeasibility = max(0, total_loss_const.item())
        if infeasibility * norm_grad2 > 0:
            # lam = max(-<g_const, g_obj> + alphak * ||g_const||, 0).
            # torch.max(t, 0) would interpret 0 as a *dimension* and return
            # (values, indices) of the scalar itself, i.e. no clamping at all,
            # so the non-negativity has to be enforced with clamp.
            lam = torch.clamp(-torch.dot(grad_const_vector, grad_obj_vector) + alphak * norm_grad2, min=0.0)
            lam_x_grad2 = (lam * grad_const_vector) / (norm_grad2 ** 2)
        else:
            # Constraint satisfied (or zero constraint gradient): no correction.
            lam = torch.tensor(0.0, device=device)
            lam_x_grad2 = torch.zeros_like(grad_const_vector)
        direction = grad_obj_vector + lam_x_grad2
        updated_vector = param_vector - gamma * direction

        # Scatter the flat update back into the individual parameter tensors.
        pointer = 0
        for p in model.parameters():
            numel = p.numel()
            p.data.copy_(updated_vector[pointer:pointer+numel].view_as(p))
            pointer += numel

        # Backward for y (gradient ascent), per-sample losses held fixed
        Nk = int(2*np.ceil(np.log(epoch+2)))
        for i in range(Nk):
            grad_obj_y = loss_per_sample - lambda_reg*(n_samples * y_weights - torch.ones(n_samples,device=device))
            y_weights.copy_(project_onto_simplex(torch.log(y_weights) + sigma * grad_obj_y))

        # Backward for w (gradient ascent), per-sample losses held fixed
        Mk = int(10*np.ceil(np.log(epoch+2)))
        for i in range(Mk):
            grad_const_w = loss_const_per_sample - lambda_reg*(m_samples * w_weights - torch.ones(m_samples,device=device))
            w_weights.copy_(project_onto_simplex(torch.log(w_weights) + sigma * grad_const_w))

        dk_norm = torch.norm(direction).item()

    loss_history.append(total_loss.item())
    dk_history.append(dk_norm)
    infeas_history.append(infeasibility)
    slack_history.append(torch.abs(lam * total_loss_const.item()).item())

    if epoch % 100 == 0 or epoch == epochs - 1:
        print(f"Epoch {epoch}: norm_dk = {dk_norm:.4f} Feasibility = {infeasibility:.4f}")


# GDMA/penalty baselines (rho=1,2,5,10)

def run_gdma_penalty(pen_param):
    """Train a fresh ``SimpleNet`` with the fixed-penalty (GDMA) baseline.

    Each epoch the network parameters take one gradient-descent step of size
    ``gamma`` on ``objective + pen_param * constraint``; afterwards the
    objective weights ``y`` and constraint weights ``w`` take ``Nk`` and
    ``Mk`` softmax-based ascent steps with the per-sample losses held fixed.

    Reads the module-level data (``X``, ``y_true``, ``X_const``,
    ``y_true_const``), hyperparameters (``epochs``, ``gamma``, ``sigma``,
    ``lambda_reg``) and the constraint threshold ``r``.

    Args:
        pen_param: Penalty weight (rho) multiplying the constraint gradient.

    Returns:
        Tuple ``(model, loss_hist, infeas_hist, dk_hist)``: the trained
        network and, per epoch, the objective value, the positive part of the
        constraint value and the norm of the update direction.
    """
    model = SimpleNet().to(device)
    # Sample weights on the simplex, initialised uniformly
    y_weights = torch.ones(n_samples, device=device, requires_grad=True) / n_samples
    w_weights = torch.ones(m_samples, device=device, requires_grad=True) / m_samples
    loss_hist = []
    infeas_hist = []
    dk_hist = []

    for epoch in range(epochs):
        model.train()
        # Forward pass for objective
        logits = model(X)
        loss_per_sample = F.cross_entropy(logits, y_true, reduction='none')  # L_i(x)

        # Forward pass for constraint
        logitss = model(X_const)
        loss_const_per_sample = F.cross_entropy(logitss, y_true_const, reduction='none')  # L_i(x)

        # DRO objective
        weighted_loss = torch.dot(y_weights, loss_per_sample)
        reg_term = lambda_reg / n_samples * torch.sum((n_samples * y_weights - 1) ** 2)
        total_loss = weighted_loss - reg_term

        # DRO constraint
        weighted_loss_const = torch.dot(w_weights, loss_const_per_sample)
        reg_term_const = lambda_reg / m_samples * torch.sum((m_samples * w_weights - 1) ** 2)
        total_loss_const = weighted_loss_const - reg_term_const - r

        # Backward for x
        grad_obj_x = torch.autograd.grad(total_loss, model.parameters(), create_graph=False)
        grad_const_x = torch.autograd.grad(total_loss_const, model.parameters(), create_graph=False)
        grad_obj_vector = torch.cat([g.view(-1) for g in grad_obj_x])
        grad_const_vector = torch.cat([g.view(-1) for g in grad_const_x])

        with torch.no_grad():
            param_vector = torch.cat([p.data.view(-1) for p in model.parameters()])
            direction = grad_obj_vector + pen_param*grad_const_vector
            updated_vector = param_vector - gamma * direction

            # Scatter the flat update back into the individual parameter tensors.
            pointer = 0
            for p in model.parameters():
                numel = p.numel()
                p.data.copy_(updated_vector[pointer:pointer+numel].view_as(p))
                pointer += numel

            # Backward for y (gradient ascent)
            Nk = int(2*np.ceil(np.log(epoch+2)))
            for i in range(Nk):
                grad_obj_y = loss_per_sample - lambda_reg*(n_samples * y_weights - torch.ones(n_samples,device=device))
                y_weights.copy_(project_onto_simplex(torch.log(y_weights) + sigma * grad_obj_y))

            # Backward for w (gradient ascent)
            Mk = int(10*np.ceil(np.log(epoch+2)))
            for i in range(Mk):
                grad_const_w = loss_const_per_sample - lambda_reg*(m_samples * w_weights - torch.ones(m_samples,device=device))
                w_weights.copy_(project_onto_simplex(torch.log(w_weights) + sigma * grad_const_w))

            dk_norm = torch.norm(direction).item()

        obj_value = total_loss.item()
        infeasibility = max(0, total_loss_const.item())
        loss_hist.append(obj_value)
        infeas_hist.append(infeasibility)
        dk_hist.append(dk_norm)

        if epoch % 1000 == 0 or epoch == epochs - 1:
            print(f"Epoch {epoch}: norm_dk = {dk_norm:.4f} loss_obj = {obj_value:.4f} Feasibility = {infeasibility:.4f}")

    return model, loss_hist, infeas_hist, dk_hist


# Run the four baselines in the same order as before so that the random
# initialisation of each network draws from the same RNG state.
_, loss_history2, infeas_history2, dk_history2 = run_gdma_penalty(1)
_, loss_history3, infeas_history3, dk_history3 = run_gdma_penalty(2)
_, loss_history4, infeas_history4, dk_history4 = run_gdma_penalty(5)
# `model` keeps referring to the last trained (rho=10) network.
model, loss_history5, infeas_history5, dk_history5 = run_gdma_penalty(10)


#plots

# Legend entries, in the order the histories are passed to the helper below.
_method_labels = (
    r'iDB-PD',
    r'GDMA on weighted MTL ($\rho=1$)',
    r'GDMA on weighted MTL ($\rho=2$)',
    r'GDMA on weighted MTL ($\rho=5$)',
    r'GDMA on weighted MTL ($\rho=10$)',
)


def _plot_method_comparison(curves, y_label, fname, log_scale):
    """Plot one curve per method against the iteration index and save it.

    Args:
        curves: Per-epoch histories, ordered like ``_method_labels``.
        y_label: Text for the y axis.
        fname: Output file name handed to ``save_and_show``.
        log_scale: Whether the y axis is logarithmic.
    """
    plt.figure()
    iterations = range(epochs)
    for curve, legend_entry in zip(curves, _method_labels):
        plt.plot(iterations, curve, label=legend_entry, linewidth=2.5)
    if log_scale:
        plt.yscale("log")
    plt.xlabel("Iteration", fontsize=14)
    plt.ylabel(y_label, fontsize=14)
    plt.grid(True, linestyle="--", alpha=0.7)
    plt.legend()
    save_and_show(fname)


# Infeasibility
_plot_method_comparison(
    [infeas_history, infeas_history2, infeas_history3, infeas_history4, infeas_history5],
    r'Infeasibility ($[\psi(x_k,w_k)]_+$)',
    "infeasibility_gdma.png",
    True,
)

# Stationarity
_plot_method_comparison(
    [dk_history, dk_history2, dk_history3, dk_history4, dk_history5],
    r'Stationarity',
    "stationarity_gdma.png",
    True,
)

# Objective Training Loss (linear y axis)
_plot_method_comparison(
    [loss_history, loss_history2, loss_history3, loss_history4, loss_history5],
    "Objective Training Loss",
    "objective_loss_gdma.png",
    False,
)

# Zip & download: bundle the plots_gdma directory into a single archive and
# trigger a browser download of it through the Colab files helper.
!zip -r -q plots_gdma.zip plots_gdma
files.download('plots_gdma.zip')

# Print the runtime's CPU model, total memory, OS release and Python version.
!cat /proc/cpuinfo | grep 'model name' | uniq

!cat /proc/meminfo | grep MemTotal

!lsb_release -a   # Ubuntu version
!python --version