import numpy as np
import matplotlib.pyplot as plt

import math
import pathlib
import random
import sklearn.linear_model
import torch
import torch.nn as nn
from torch.nn.modules.transformer import MultiheadAttention
from torch.optim.lr_scheduler import LambdaLR
import time
import sys
import logging

# ---------------- Logging Setup ---------------- #
# Log every record at INFO or above both to a file (truncated on each run)
# and to standard output.
logging.basicConfig(
    handlers=[
        logging.FileHandler("2D_DVA_training.log", mode="w"),
        logging.StreamHandler(sys.stdout),
    ],
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# ---------------- Device Setup ---------------- #
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    logger.info(f"Using GPU: {device}")
else:
    logger.warning("No GPU detected -> using CPU")

import os
# Debugging aid: ask CUDA to run kernel launches synchronously.
# NOTE(review): this is set after the CUDA availability query above - confirm
# the variable is still picked up at this point.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

import torch
import matplotlib.pyplot as plt

def rbf_kernel(x1, x2, lengthscale=0.2, variance=1.0):
    """
    Evaluate the RBF (squared-exponential) kernel for every pair of rows.

    Args:
        x1: tensor of shape (N, D)
        x2: tensor of shape (M, D)
        lengthscale: kernel lengthscale
        variance: multiplicative kernel variance

    Returns:
        (N, M) tensor with entries variance * exp(-0.5 * ||x1_i - x2_j||^2 / lengthscale^2)
    """
    # Broadcast (N, 1, D) against (1, M, D) to get all pairwise differences.
    pairwise_delta = x1[:, None] - x2[None]
    squared_distance = torch.sum(pairwise_delta ** 2, dim=-1)  # (N, M)
    exponent = -0.5 * squared_distance / lengthscale**2
    return variance * torch.exp(exponent)

def get_gp_prior(num_datasets, num_features, num_points_in_each_dataset, hyperparameters):
    """
    Draw synthetic regression datasets from a GP prior with an RBF kernel.

    Inputs are sampled uniformly from [0, 1)^num_features; the targets of each
    dataset are one joint sample from a multivariate normal with constant mean
    `mean_shift` and covariance `K(x, x) + output_noise^2 * I`.

    Args:
        num_datasets: number of independent datasets to draw
        num_features: input dimensionality
        num_points_in_each_dataset: number of points per dataset
        hyperparameters: dict, optional keys 'lengthscale' (default 0.2),
            'kernel_variance' (default 0.01), 'output_noise' (default 0.01)
            and 'mean_shift' (default 1.05)

    Returns:
        xs: (num_datasets, num_points_in_each_dataset, num_features)
        ys: (num_datasets, num_points_in_each_dataset)
    """
    lengthscale = hyperparameters.get('lengthscale', 0.2)
    # Low default variance keeps the sampled functions close to the mean
    kernel_variance = hyperparameters.get('kernel_variance', 0.01)
    output_noise = hyperparameters.get('output_noise', 0.01)
    # Constant prior mean; the default sits mid-range of the expected voltages
    mean_shift = hyperparameters.get('mean_shift', 1.05)

    xs = torch.rand(num_datasets, num_points_in_each_dataset, num_features)

    # Both terms are the same for every dataset, so build them once.
    noise_cov = output_noise**2 * torch.eye(num_points_in_each_dataset)
    prior_mean = torch.full((num_points_in_each_dataset,), mean_shift)

    def _draw_targets(x):
        # x: (N, D) -> one joint GP sample of shape (N,)
        cov = rbf_kernel(x, x, lengthscale, kernel_variance)
        cov += noise_cov
        return torch.distributions.MultivariateNormal(prior_mean, cov).sample()

    ys = torch.stack([_draw_targets(x) for x in xs], dim=0)  # (B, N)
    return xs, ys

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.884109Z","iopub.execute_input":"2025-08-01T11:59:25.884413Z","iopub.status.idle":"2025-08-01T11:59:25.904102Z","shell.execute_reply.started":"2025-08-01T11:59:25.884390Z","shell.execute_reply":"2025-08-01T11:59:25.903403Z"}}
class PriorDataLoader(object):
    """Serves freshly sampled prior batches together with a random context size."""

    def __init__(self, get_prior_data_fn, batch_size, num_points_in_each_dataset):
        # Callable invoked as get_prior_data_fn(batch_size, num_points_in_each_dataset)
        self.get_prior_data_fn = get_prior_data_fn
        self.batch_size = batch_size
        self.num_points_in_each_dataset = num_points_in_each_dataset

        # Incremented once per get_batch(train=True) call
        self.epoch = 0

    def get_batch(self, train=True, batch_size=None):
        """
        Sample one batch from the prior.

        Args:
            train: when truthy, the internal `epoch` counter is advanced
            batch_size: overrides the default batch size when truthy

        Returns:
            (data, trainset_size) where `data` is whatever `get_prior_data_fn`
            returns and `trainset_size` comes from `_sample_trainset_size`
        """
        self.epoch += train
        effective_batch_size = batch_size or self.batch_size
        data = self.get_prior_data_fn(effective_batch_size, self.num_points_in_each_dataset)
        return data, self._sample_trainset_size()

    def _sample_trainset_size(self):
        """
        Draw the number of context points, favouring larger values.

        Candidates are 1 .. num_points_in_each_dataset - 2 (inclusive); the
        i-th candidate is weighted by 1 / (number_of_candidates - i).
        Appendix E.1 of Muller et al. (2021).
        """
        lower = 1
        upper = self.num_points_in_each_dataset - 1
        span = upper - lower

        candidates = range(lower, upper)
        weights = [1 / (span - offset) for offset in range(span)]
        (chosen,) = random.choices(candidates, weights)
        return chosen

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.905810Z","iopub.execute_input":"2025-08-01T11:59:25.906047Z","iopub.status.idle":"2025-08-01T11:59:25.947429Z","shell.execute_reply.started":"2025-08-01T11:59:25.906032Z","shell.execute_reply":"2025-08-01T11:59:25.946672Z"}}
def get_bucket_limts(num_outputs, ys):
    """
    Derive bucket borders from sample targets so each bucket holds the same
    number of samples.

    Args:
        num_outputs: number of buckets to create
        ys: tensor of target values drawn from the prior (any shape)

    Returns:
        bucket_limits: 1-D tensor of num_outputs + 1 borders; the outer two are
        the min and max of the values kept after trimming.
    """
    flat = ys.flatten()

    # Drop trailing values so the count divides evenly into the buckets.
    leftover = len(flat) % num_outputs
    if leftover:
        flat = flat[:-leftover]

    per_bucket = len(flat) // num_outputs
    lowest, highest = flat.min(), flat.max()

    ordered = flat.sort(0).values

    # An inner border is the midpoint between the last value of one bucket and
    # the first value of the next.
    last_of_bucket = ordered[per_bucket - 1::per_bucket][:-1]
    first_of_next = ordered[per_bucket::per_bucket]
    inner_borders = (last_of_bucket + first_of_next) / 2

    return torch.cat([lowest.unsqueeze(0), inner_borders, highest.unsqueeze(0)], dim=0)


# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.948108Z","iopub.execute_input":"2025-08-01T11:59:25.948292Z","iopub.status.idle":"2025-08-01T11:59:25.961285Z","shell.execute_reply.started":"2025-08-01T11:59:25.948277Z","shell.execute_reply":"2025-08-01T11:59:25.960608Z"}}
def y_to_bucket_idx(ys, bl):
    """
    Look up the bucket index of each value in `ys`.

    Args:
        ys: tensor of values to be mapped into buckets
        bl: sorted 1-D tensor of bucket limits (len(bl) - 1 buckets)

    Returns:
        integer tensor shaped like `ys`; values at or below the first limit map
        to bucket 0 and values at or above the last limit map to the last bucket
    """
    last_bucket = len(bl) - 2
    below_range = ys <= bl[0]
    above_range = ys >= bl[-1]

    # searchsorted gives the insertion position; the bucket is one to its left.
    raw_idx = torch.searchsorted(bl, ys) - 1
    return raw_idx.masked_fill(below_range, 0).masked_fill(above_range, last_bucket)
# --- new encoder block ---
class GPEncoder(nn.Module):
    """
    GP-style attention layer: queries and keys come from x, values from y.

    The same multi-head attention module is used twice per forward pass: once
    with the context ("left") points attending to themselves and once with the
    query ("right") points attending to the context. Each result is added to
    its input and layer-normalised, then a residual feed-forward block is
    applied to the concatenation.

    Note: `attn_scale` is registered as a parameter but is not used in `forward`.
    """
    def __init__(self, d_model, n_heads, n_hidden, dropout=0.0, use_pos=False):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, n_hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(n_hidden, d_model),
        )

        # Layer norms: inputs of Q/K, inputs of V, post-attention residual, pre-FFN
        self.ln_qk = nn.LayerNorm(d_model)
        self.ln_v = nn.LayerNorm(d_model)
        self.ln_att = nn.LayerNorm(d_model)
        self.ln_ff = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

        # Learnable scalar initialised to 1 / sqrt(head_dim)
        head_dim = d_model // n_heads
        self.attn_scale = nn.Parameter(torch.tensor(1.0 / head_dim ** 0.5))

        # Optional learnable projection for positional encodings
        self.use_pos = use_pos
        if self.use_pos:
            self.pos_proj = nn.Linear(d_model, d_model)

    def add_pos(self, x, pos):
        """Add the projected positional encoding `pos` to `x`; no-op when `pos` is None."""
        return x if pos is None else x + self.pos_proj(pos)

    def _attend(self, query, key, value):
        # Shared attention call; returns (output, per-head attention weights).
        return self.attn(
            query=query, key=key, value=value,
            need_weights=True, average_attn_weights=False
        )

    def forward(self, x_left, x_right, y_left, pos_left=None, pos_right=None, debug=False):
        """
        Args:
            x_left: (B, T, D) encoded context inputs
            x_right: (B, M, D) encoded query inputs
            y_left: (B, T, D) encoded context targets
            pos_left, pos_right: optional positional encodings, only used when
                the layer was built with use_pos=True
            debug: when True, also return the two attention-weight tensors

        Returns:
            h of shape (B, T + M, D), or (h, (weights_left, weights_right)) when debug
        """
        if self.use_pos:
            x_left = self.add_pos(x_left, pos_left)
            x_right = self.add_pos(x_right, pos_right)

        # Normalised context acts as both query and key; values come from y.
        context_keys = self.ln_qk(x_left)
        context_values = self.ln_v(y_left)

        # Context self-attention followed by residual + norm
        left_out, left_weights = self._attend(context_keys, context_keys, context_values)
        left_hidden = self.ln_att(x_left + left_out)

        # Query points cross-attend to the context, then residual + norm
        target_queries = self.ln_qk(x_right)
        right_out, right_weights = self._attend(target_queries, context_keys, context_values)
        right_hidden = self.ln_att(x_right + right_out)

        # Re-join context and query positions and apply the residual FFN
        h = torch.cat([left_hidden, right_hidden], dim=1)  # (B, T+M, D)
        h = h + self.dropout(self.ffn(self.ln_ff(h)))

        return (h, (left_weights, right_weights)) if debug else h

class Transformer(nn.Module):
    """
    Stack of GPEncoder layers with separate linear encoders for x and y and an
    MLP head that emits `n_out` logits per position.
    """
    def __init__(self, num_features, n_out, n_layers=2, d_model=512, n_heads=4,
                 n_hidden=1024, dropout=0.0, normalize=lambda x:x):
        super().__init__()
        self.x_encoder = nn.Linear(num_features, d_model)   # embeds inputs only
        self.y_encoder = nn.Linear(1, d_model)              # embeds targets only
        encoder_layers = [
            GPEncoder(d_model, n_heads, n_hidden, dropout)
            for _ in range(n_layers)
        ]
        self.model = nn.ModuleList(encoder_layers)
        self.out = nn.Sequential(
            nn.Linear(d_model, n_hidden),
            nn.GELU(),
            nn.Linear(n_hidden, n_out),
        )
        self.normalize = normalize
        self.init_weights()

    def forward(self, x, y, trainset_size):
        """
        Args:
            x: inputs; the first `trainset_size` positions along dim 1 are the context
            y: targets with a trailing singleton feature dimension (fed to Linear(1, d_model))
            trainset_size: number of context positions

        Returns:
            logits for every position (context followed by query points)
        """
        x_enc = self.x_encoder(self.normalize(x))
        y_enc = self.y_encoder(y)

        # Only the context targets are ever used as attention values.
        context_values = y_enc[:, :trainset_size]

        h = x_enc
        for layer in self.model:
            # Split the running representation back into context / query parts.
            h = layer(h[:, :trainset_size], h[:, trainset_size:], context_values)

        return self.out(h)

    def init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)


def get_halfnormal_with_p_weight_before(range_max, p=0.5):
    """
    Build a half-normal distribution whose CDF at `range_max` equals `p`.

    Args:
        range_max: point below which a fraction `p` of the probability mass should lie
        p: cumulative probability to place below `range_max`

    Returns:
        torch.distributions.HalfNormal with the matching scale
    """
    unit_half_normal = torch.distributions.HalfNormal(torch.tensor(1.))
    # Quantile of the unit-scale half normal at p; the scale follows by proportion.
    unit_quantile = unit_half_normal.icdf(torch.tensor(p))
    scale = range_max / unit_quantile
    return torch.distributions.HalfNormal(scale)

def compute_bar_distribution_loss(logits, target_y, bucket_limits, label_smoothing=0.0):
    """
    Negative log-likelihood under a Riemann (bar) distribution with half-normal
    tails on the two outermost buckets. See Appendix E of Muller et al. 2021.

    Args:
        logits: num_datasets x num_points_in_each_dataset x num_outputs_for_classification
        target_y: target value for each point (shape of logits without the last dim)
        bucket_limits: 1-D tensor of bucket borders (num_outputs + 1 entries)
        label_smoothing: constant to define the amount of label smoothing to be used.

    Returns:
        loss: scalar value (mean over all points)
    """
    target_y_idx = y_to_bucket_idx(target_y, bucket_limits)

    bucket_widths = bucket_limits[1:] - bucket_limits[:-1]
    bucket_log_probs = torch.log_softmax(logits, -1)
    # Density inside a bucket is its probability divided by its width.
    scaled_bucket_log_probs = bucket_log_probs - torch.log(bucket_widths)
    log_probs = scaled_bucket_log_probs.gather(-1, target_y_idx[..., None]).squeeze(-1)

    # Full-support correction: the outermost buckets get half-normal tails
    # scaled so that half of the tail mass lies inside the bucket.
    side_normals = (
        get_halfnormal_with_p_weight_before(bucket_widths[0]),
        get_halfnormal_with_p_weight_before(bucket_widths[-1])
    )

    # First bucket: swap the uniform density for the half-normal density measured
    # from the bucket's inner border. Adding log(width) cancels the uniform term.
    # The clamp guards the *distance* against zero / tiny negative values; it must
    # not be applied to the log-density, which would floor it and hide the tail
    # penalty.
    first_bucket = target_y_idx == 0
    first_distance = (bucket_limits[1] - target_y[first_bucket]).clamp(min=1e-8)
    log_probs[first_bucket] += side_normals[0].log_prob(first_distance) + torch.log(bucket_widths[0])

    # Last bucket: same correction, mirrored.
    last_bucket = target_y_idx == len(bucket_widths) - 1
    last_distance = (target_y[last_bucket] - bucket_limits[-2]).clamp(min=1e-8)
    log_probs[last_bucket] += side_normals[1].log_prob(last_distance) + torch.log(bucket_widths[-1])

    nll_loss = -log_probs
    # Label-smoothing term: average negative log-density over all buckets.
    smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)

    loss = (1-label_smoothing) * nll_loss + label_smoothing * smooth_loss
    return loss.mean()

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.998299Z","iopub.execute_input":"2025-08-01T11:59:25.998484Z","iopub.status.idle":"2025-08-01T11:59:26.016106Z","shell.execute_reply.started":"2025-08-01T11:59:25.998468Z","shell.execute_reply":"2025-08-01T11:59:26.015490Z"}}
# copied from huggingface
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
    """
    Linear warmup followed by a step decay (despite the name, no cosine is used).

    The multiplier rises linearly from 0 to 1 over `num_warmup_steps`, then is
    multiplied by 0.1 after every 4000 further steps. `num_training_steps` and
    `num_cycles` are accepted but not used.

    A second function with the same name is defined later in this file and
    replaces this one once the module has been executed.
    """
    decay_interval = 4000

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            # Number of completed decay intervals since warmup ended
            completed_intervals = (current_step - num_warmup_steps) // decay_interval
            return 0.1 ** completed_intervals
        return float(current_step) / float(max(1, num_warmup_steps))

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:26.016989Z","iopub.execute_input":"2025-08-01T11:59:26.017267Z","iopub.status.idle":"2025-08-01T12:01:06.593889Z","shell.execute_reply.started":"2025-08-01T11:59:26.017234Z","shell.execute_reply":"2025-08-01T12:01:06.593028Z"}}
# Parameters related to prior data
num_points_in_each_dataset = 100
hyperparameters = {
    'lengthscale': 0.6,
    'kernel_variance': 0.01,
    'output_noise':1e-2,
    'mean_shift': 1.0
}
bias_term = True
num_features = 2
num_outputs = 100
max_num_datasets = 16
# Number of prior datasets sampled once to calibrate the bucket borders.
# Kept in one place so the call and the logged value cannot drift apart.
calibration_size = 1000
logger.info("  GP hyperparameters:")
for k, v in hyperparameters.items():
    logger.info(f"    {k}: {v}")
# Riemann distribution: sample `calibration_size` datasets from the prior and
# place the bucket borders so every bucket holds the same number of targets.
_, ys = get_gp_prior(calibration_size, num_features, num_points_in_each_dataset, hyperparameters)
bucket_limits = get_bucket_limts(num_outputs, ys)

# training data: the loader calls this with (batch_size, points_per_dataset)
get_prior_data_fn = lambda x, y: get_gp_prior(x, num_features, y, hyperparameters)
dl = PriorDataLoader(get_prior_data_fn, batch_size=max_num_datasets, num_points_in_each_dataset=num_points_in_each_dataset)

# Models
# Standardise inputs drawn from U(0, 1): mean 0.5, variance 1/12
uniform_normalize = lambda x: (x-0.5)/math.sqrt(1/12)

# Learning
epochs = 100
steps_per_epoch = 500
warmup_epochs = epochs // 4
validate_epoch = 10
lr = 0.001
bucket_limits = bucket_limits.to(device)

logger.info("Hyperparameters:")
logger.info(f"  epochs: {epochs}")
logger.info(f"  steps_per_epoch: {steps_per_epoch}")
logger.info(f"  warmup_epochs: {warmup_epochs}")
logger.info(f"  validate_epoch: {validate_epoch}")
logger.info(f"  learning_rate: {lr}")
logger.info(f"  bucket_limits: {bucket_limits}")
logger.info(f"  num_points_in_each_dataset: {num_points_in_each_dataset}")
logger.info(f"  bias_term: {bias_term}")
logger.info(f"  num_features: {num_features}")
logger.info(f"  num_outputs: {num_outputs}")
logger.info(f"  max_num_datasets: {max_num_datasets}")
logger.info(f"  calibration_size: {calibration_size}")

# test data: one fixed batch of 64 datasets, reused for every validation pass
(test_xs, test_ys), test_trainset_size = dl.get_batch(False, 64)
test_xs, test_ys = test_xs.to(device), test_ys.to(device)
torch.save(bucket_limits.cpu(), "bucket_limits_2d_100.pth") # save bucket limits to be used during inference

# NOTE: the bucket calibration above is slow for the GP prior, because each of the 1000 calibration datasets requires a sample from a multivariate normal with a dense covariance matrix.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GPAttention(nn.Module):
    """Single-head scaled dot-product attention without learned parameters."""

    def __init__(self, d_model):
        # d_model is accepted for interface symmetry; it is not used here.
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, Q, K, V):
        """
        Args:
            Q: (B, Lq, D) queries
            K: (B, Lk, D) keys
            V: (B, Lk, Dv) values

        Returns:
            (B, Lq, Dv) attention-weighted sum of the values
        """
        scale = 1.0 / (Q.size(-1) ** 0.5)
        scores = torch.bmm(Q, K.transpose(1, 2)) * scale
        weights = self.softmax(scores)
        return torch.bmm(weights, V)


class CNN_Attention_GPAligned(nn.Module):
    """
    Convolution + GP-style attention block.

    The x-embeddings are first passed through a depthwise, bias-free 1-D
    convolution along the sequence axis. Queries and keys are then projected
    from the convolved x, values from y, and the attention output is added
    back to the convolved x. A residual feed-forward layer and a final layer
    norm complete the block.
    """
    def __init__(self, d_model, kernel_size=5):
        super().__init__()
        # groups=d_model -> one filter per channel (depthwise)
        self.conv = nn.Conv1d(
            d_model, d_model, kernel_size,
            padding=kernel_size//2, groups=d_model, bias=False,
        )

        # Projections: queries/keys from x, values from y
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_vy = nn.Linear(d_model, d_model)

        self.attn = GPAttention(d_model)

        self.ln_qk = nn.LayerNorm(d_model)
        self.ln_v = nn.LayerNorm(d_model)
        self.ln_out = nn.LayerNorm(d_model)

        # Small feed-forward layer applied with a residual connection
        self.ff = nn.Sequential(nn.Linear(d_model, d_model), nn.GELU())

    def _smooth_x(self, x):
        # Conv1d expects (B, D, L): move channels forward, convolve, move back.
        channels_first = x.transpose(1, 2)
        return self.conv(channels_first).transpose(1, 2)

    def forward(self, x_left, x_right, y_left):
        """
        Args:
            x_left: (B, T, D) encoded context inputs
            x_right: (B, M, D) encoded query inputs
            y_left: (B, T, D) encoded context targets

        Returns:
            (B, T + M, D) tensor, context positions first
        """
        context_x = self._smooth_x(x_left)
        query_x = self._smooth_x(x_right)

        # Keys and values come from the context; queries from both parts
        keys = self.W_k(self.ln_qk(context_x))
        context_queries = self.W_q(self.ln_qk(context_x))
        target_queries = self.W_q(self.ln_qk(query_x))
        values = self.W_vy(self.ln_v(y_left))

        # Attention with a residual connection around it
        context_hidden = context_x + self.attn(context_queries, keys, values)
        target_hidden = query_x + self.attn(target_queries, keys, values)

        joined = torch.cat([context_hidden, target_hidden], dim=1)

        # Residual feed-forward, then the output norm
        return self.ln_out(joined + self.ff(joined))
class CNNModel_GPAligned(nn.Module):
    """
    Model built from CNN_Attention_GPAligned blocks, with separate linear
    encoders for x and y and an MLP head that emits `n_out` logits per position.
    """
    def __init__(self, num_features, n_out, n_layers=1, d_model=32,
                 kernel_size=5, normalize=lambda x: x):
        super().__init__()
        self.x_encoder = nn.Linear(num_features, d_model)
        self.y_encoder = nn.Linear(1, d_model)
        block_list = [
            CNN_Attention_GPAligned(d_model, kernel_size)
            for _ in range(n_layers)
        ]
        self.blocks = nn.ModuleList(block_list)
        self.out = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, n_out),
        )
        self.normalize = normalize
        self._init()

    def forward(self, x, y, trainset_size):
        """
        Args:
            x: inputs; the first `trainset_size` positions along dim 1 are the context
            y: targets with a trailing singleton feature dimension (fed to Linear(1, d_model))
            trainset_size: number of context positions

        Returns:
            logits for every position (context followed by query points)
        """
        x_enc = self.x_encoder(self.normalize(x))  # (B,N,D)
        y_enc = self.y_encoder(y)                  # (B,N,D)

        # Only the context targets are used as attention values.
        context_values = y_enc[:, :trainset_size]

        h = x_enc
        for block in self.blocks:
            h = block(h[:, :trainset_size], h[:, trainset_size:], context_values)

        return self.out(h)

    def _init(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        for weight in (p for p in self.parameters() if p.dim() > 1):
            nn.init.xavier_uniform_(weight)

def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
    """
    Build a LambdaLR schedule: linear warmup from 0 to 1 over
    `num_warmup_steps`, then a cosine curve over the remaining steps up to
    `num_training_steps`. With the default `num_cycles=0.5` the multiplier
    falls from 1 to 0 exactly once.

    All step counts are in units of `scheduler.step()` calls.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / warmup_span
        # Fraction of the post-warmup phase completed so far
        progress = float(current_step - num_warmup_steps) / decay_span
        cosine = math.cos(math.pi * float(num_cycles) * 2.0 * progress)
        return max(0.0, 0.5 * (1.0 + cosine))

    return LambdaLR(optimizer, lr_lambda, last_epoch)

# ---------------- CNN Training ---------------- #
modelc = CNNModel_GPAligned(
    num_features, n_out=num_outputs, 
    d_model=32, kernel_size=5, n_layers=1,
    normalize=uniform_normalize
)
modelc.to(device)

optimizer = torch.optim.AdamW(modelc.parameters(), lr=lr)

# scheduler.step() is called once per optimisation step in the loop below, so
# the warmup and total lengths must be given in steps, not epochs. Passing
# epoch counts would end warmup after a handful of steps and run the cosine
# far past its intended end, making the learning rate oscillate.
num_training_steps = epochs * steps_per_epoch
num_warmup_steps = warmup_epochs * steps_per_epoch
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
)

# Create directory for saving CNN checkpoints and plots
cnn_save_dir = os.path.join("2D_DVA", "cnn")
os.makedirs(cnn_save_dir, exist_ok=True)

logger.info("Starting CNN training...")
start_time = time.time()

# Validation losses and the epochs at which they were measured (for plotting)
val_losses = []
val_epochs = []

for step in range(epochs * steps_per_epoch + 1):
    epoch = step // steps_per_epoch  # current epoch

    # Fresh batch from the prior with a random context size
    (xs, ys), trainset_size = dl.get_batch()
    xs, ys = xs.to(device), ys.to(device)
    pred_y = modelc(xs, ys.unsqueeze(-1), trainset_size)

    # Only the query positions (after the context) contribute to the loss
    logits = pred_y[:, trainset_size:]
    target_y = ys[:, trainset_size:].clone().view(*logits.shape[:-1])

    loss = compute_bar_distribution_loss(logits, target_y, bucket_limits)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(modelc.parameters(), 1.)
    optimizer.step()
    scheduler.step()

    if step % (validate_epoch * steps_per_epoch) == 0:
        # Evaluate on the fixed test batch in eval mode, then switch back
        modelc.eval()
        with torch.no_grad():
            pred_test_y = modelc(test_xs, test_ys.unsqueeze(-1), test_trainset_size)

            logits = pred_test_y[:, test_trainset_size:]
            target_test_y = test_ys[:, test_trainset_size:].clone().view(*logits.shape[:-1])

            val_loss = compute_bar_distribution_loss(logits, target_test_y, bucket_limits)
        modelc.train()

        val_losses.append(val_loss.item())
        val_epochs.append(epoch)

        logger.info(
            f"[CNN] Epoch {epoch}/{epochs} | Step {step} | "
            f"Val loss: {val_loss.item():.5f} | LR: {scheduler.get_last_lr()[0]:.6e} | "
            f"Train loss: {loss.item():.5f}"
        )

        # Save model every 10th epoch (skip epoch 0)
        if epoch > 0 and epoch % 10 == 0:
            save_path = os.path.join(cnn_save_dir, f"cnn_epoch_{epoch}.pth")
            torch.save(modelc.state_dict(), save_path)
            logger.info(f"💾 Saved CNN model at: {save_path}")

total_time = time.time() - start_time
logger.info(f"[CNN] 🕒 Total training time: {total_time:.2f} seconds")

# Final save
final_path = os.path.join(cnn_save_dir, "cnn_final.pth")
torch.save(modelc.state_dict(), final_path)
logger.info(f"✅ Final CNN model saved at: {final_path}")

logger.info(f"  CNNModel config: d_model=32, kernel_size=5, n_layers=1, normalize={uniform_normalize}")

# === Save Error vs Epoch Curve ===
# Validation only runs every `validate_epoch` epochs, so plot against the
# recorded epoch numbers rather than the index of the measurement.
plt.figure()
plt.plot(val_epochs, val_losses, marker='o', label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Validation Loss vs Epoch (CNN)")
plt.grid(True)
plt.legend()
curve_path = os.path.join(cnn_save_dir, "val_loss_curve.png")
plt.savefig(curve_path)
plt.close()
logger.info(f"📉 Validation loss curve saved at: {curve_path}")


# # %% [code] {"execution":{"iopub.status.busy":"2025-08-01T12:01:06.596757Z","iopub.execute_input":"2025-08-01T12:01:06.597151Z"}}
import time
import os
# ---------------- Transformer Training ---------------- #
modelt = Transformer(
    num_features, n_out=num_outputs, 
    d_model=128, n_layers=1, n_hidden=512, n_heads=4, 
    normalize=uniform_normalize
)
modelt.to(device)

optimizer = torch.optim.AdamW(modelt.parameters(), lr=lr)

# The scheduler is stepped once per optimisation step, so both lengths are in steps.
num_training_steps = epochs * steps_per_epoch
num_warmup_steps  = warmup_epochs * steps_per_epoch
scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps
)

# Create directory for saving Transformer models
transformer_save_dir = os.path.join("2D_DVA", "transformer")
os.makedirs(transformer_save_dir, exist_ok=True)

logger.info("Starting Transformer training...")
start_time = time.time()

# Validation losses and the epochs at which they were measured (for plotting)
val_losses = []
val_epochs = []

for step in range(epochs * steps_per_epoch + 1):
    epoch = step // steps_per_epoch  # current epoch

    # Fresh batch from the prior with a random context size
    (xs, ys), trainset_size = dl.get_batch()
    xs, ys = xs.to(device), ys.to(device)
    pred_y = modelt(xs, ys.unsqueeze(-1), trainset_size)

    # Only the query positions (after the context) contribute to the loss
    logits = pred_y[:, trainset_size:]
    target_y = ys[:, trainset_size:].clone().view(*logits.shape[:-1])

    loss = compute_bar_distribution_loss(logits, target_y, bucket_limits)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(modelt.parameters(), 1.)
    optimizer.step()
    scheduler.step()

    if step % (validate_epoch * steps_per_epoch) == 0:
        # Evaluate on the fixed test batch in eval mode, then switch back
        modelt.eval()
        with torch.no_grad():
            pred_test_y = modelt(test_xs, test_ys.unsqueeze(-1), test_trainset_size)

            logits = pred_test_y[:, test_trainset_size:]
            target_test_y = test_ys[:, test_trainset_size:].clone().view(*logits.shape[:-1])

            val_loss = compute_bar_distribution_loss(logits, target_test_y, bucket_limits)
        modelt.train()

        val_losses.append(val_loss.item())
        val_epochs.append(epoch)

        logger.info(
            f"[Transformer] Epoch {epoch}/{epochs} | Step {step} | "
            f"Val loss: {val_loss.item():.5f} | LR: {scheduler.get_last_lr()[0]:.6e} | "
            f"Train loss: {loss.item():.5f}"
        )

        # Save model every 10th epoch (but not at epoch 0)
        if epoch > 0 and epoch % 10 == 0:
            save_path = os.path.join(transformer_save_dir, f"transformer_epoch_{epoch}.pth")
            torch.save(modelt.state_dict(), save_path)
            logger.info(f"💾 Saved Transformer model at: {save_path}")

total_time = time.time() - start_time
logger.info(f"[Transformer] 🚀 Total training time: {total_time:.2f} seconds")

# Final save
final_path = os.path.join(transformer_save_dir, "transformer_final.pth")
torch.save(modelt.state_dict(), final_path)
logger.info(f"✅ Final Transformer model saved at: {final_path}")

# === Save Error vs Epoch Curve ===
# Validation only runs every `validate_epoch` epochs, so plot against the
# recorded epoch numbers rather than the index of the measurement.
plt.figure()
plt.plot(val_epochs, val_losses, marker='o', label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Validation Loss vs Epoch (Transformer)")
plt.grid(True)
plt.legend()
curve_path = os.path.join(transformer_save_dir, "val_loss_curve.png")
plt.savefig(curve_path)
plt.close()
logger.info(f"📉 Validation loss curve saved at: {curve_path}")

logger.info(f"  Transformer config: d_model=128, n_layers=1, n_hidden=512, n_heads=4, normalize={uniform_normalize}")
