import numpy as np
import matplotlib.pyplot as plt
import math
import pathlib
import random
import sklearn.linear_model
import torch
import torch.nn as nn
from torch.nn.modules.transformer import MultiheadAttention
from torch.optim.lr_scheduler import LambdaLR
import time
import sys
import logging

# ---------------- Logging Setup ---------------- #
# Configure the root logger once, at import time. Records at INFO level and
# above are written both to "optuna_64d.log" and to stdout. The file handler
# opens the log with mode="w", so an existing log file is overwritten each
# time this module is loaded.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[
        logging.FileHandler("optuna_64d.log", mode="w"),
        logging.StreamHandler(sys.stdout)
    ]
)
# Module-level logger used by the rest of this script.
logger = logging.getLogger(__name__)

# ---------------- Device Setup ---------------- #
# Pick the compute device once; the rest of the script moves models and
# tensors to `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
    logger.info(f"Using GPU: {device}")
else:
    device = torch.device("cpu")
    logger.warning("No GPU detected -> using CPU")

import os
# CUDA_LAUNCH_BLOCKING=1 is the CUDA debugging switch for synchronous kernel
# launches.
# NOTE(review): it is set after torch has been imported and after
# torch.cuda.is_available() has been called -- confirm that it still takes
# effect at this point, and whether it should stay enabled outside debugging.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.876066Z","iopub.execute_input":"2025-08-01T11:59:25.876453Z","iopub.status.idle":"2025-08-01T11:59:25.883208Z","shell.execute_reply.started":"2025-08-01T11:59:25.876433Z","shell.execute_reply":"2025-08-01T11:59:25.882435Z"}}
import torch
import matplotlib.pyplot as plt

def rbf_kernel(x1, x2, lengthscale=0.2, variance=1.0):
    """
    Evaluate the squared-exponential (RBF) kernel on every pair of rows.

    Args:
        x1: tensor of shape (N, D)
        x2: tensor of shape (M, D)
        lengthscale: kernel lengthscale shared by all D input dimensions
        variance: multiplicative amplitude of the kernel

    Returns:
        (N, M) tensor whose entry [i, j] equals
        variance * exp(-0.5 * ||x1[i] - x2[j]||^2 / lengthscale^2)
    """
    # Broadcasting (N, 1, D) against (1, M, D) gives all pairwise differences.
    pairwise_delta = x1[:, None] - x2[None]
    squared_distance = (pairwise_delta ** 2).sum(-1)  # (N, M)
    exponent = -0.5 * squared_distance / lengthscale**2
    return variance * torch.exp(exponent)

def get_gp_prior(num_datasets, num_features, num_points_in_each_dataset, hyperparameters):
    """
    Draw synthetic regression datasets from a Gaussian-process prior.

    Inputs are sampled uniformly from [0, 1) and each dataset's targets are one
    joint draw from a GP with an RBF kernel, observation noise added on the
    diagonal, and a constant mean.

    Args:
        num_datasets: number of independent datasets to draw
        num_features: input dimensionality D
        num_points_in_each_dataset: number of points N per dataset
        hyperparameters: dict with optional keys 'lengthscale' (default 0.2),
            'kernel_variance' (default 0.01), 'output_noise' (default 0.01)
            and 'mean_shift' (default 1.05, the constant GP mean)

    Returns:
        xs: (num_datasets, N, D) tensor of inputs
        ys: (num_datasets, N) tensor of sampled targets
    """
    n_points = num_points_in_each_dataset
    lengthscale = hyperparameters.get('lengthscale', 0.2)
    kernel_variance = hyperparameters.get('kernel_variance', 0.01)
    output_noise = hyperparameters.get('output_noise', 0.01)
    mean_shift = hyperparameters.get('mean_shift', 1.05)

    xs = torch.rand(num_datasets, n_points, num_features)

    def draw_targets(x):
        # Covariance = RBF kernel on the inputs plus noise variance on the diagonal.
        cov = rbf_kernel(x, x, lengthscale, kernel_variance)  # (N, N)
        cov += output_noise**2 * torch.eye(n_points)
        mean = torch.full((n_points,), mean_shift)
        return torch.distributions.MultivariateNormal(mean, cov).sample()

    # One joint GP draw per dataset, sampled in dataset order.
    ys = torch.stack([draw_targets(x) for x in xs], dim=0)  # (B, N)
    return xs, ys

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.884109Z","iopub.execute_input":"2025-08-01T11:59:25.884413Z","iopub.status.idle":"2025-08-01T11:59:25.904102Z","shell.execute_reply.started":"2025-08-01T11:59:25.884390Z","shell.execute_reply":"2025-08-01T11:59:25.903403Z"}}
class PriorDataLoader(object):
    """
    Produces batches of freshly sampled prior datasets together with a random
    train/test split position.
    """

    def __init__(self, get_prior_data_fn, batch_size, num_points_in_each_dataset):
        # Callable invoked as get_prior_data_fn(batch_size, num_points_in_each_dataset).
        self.get_prior_data_fn = get_prior_data_fn
        self.batch_size = batch_size
        self.num_points_in_each_dataset = num_points_in_each_dataset

        # Counts how many training batches have been requested so far.
        self.epoch = 0

    def get_batch(self, train = True, batch_size=None):
        """
        Sample one batch from the prior.

        Args:
            train: added to the `epoch` counter (True counts as one batch)
            batch_size: optional override; a falsy value falls back to the
                loader's default batch size

        Returns:
            (result of get_prior_data_fn, trainset_size)
        """
        self.epoch += train
        effective_batch_size = batch_size or self.batch_size
        data = self.get_prior_data_fn(effective_batch_size, self.num_points_in_each_dataset)
        return data, self._sample_trainset_size()

    def _sample_trainset_size(self):
        # Appendix E.1 of Muller et al. (2021): draw the number of training
        # points with weights that grow towards the larger sizes.
        # Candidates are range(1, n - 1), i.e. 1 .. n - 2 inclusive.
        smallest = 1
        upper_bound = self.num_points_in_each_dataset - 1
        span = upper_bound - smallest

        weights = [1 / (span - offset) for offset in range(span)]
        candidates = range(smallest, upper_bound)
        return random.choices(candidates, weights)[0]

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.905810Z","iopub.execute_input":"2025-08-01T11:59:25.906047Z","iopub.status.idle":"2025-08-01T11:59:25.947429Z","shell.execute_reply.started":"2025-08-01T11:59:25.906032Z","shell.execute_reply":"2025-08-01T11:59:25.946672Z"}}
def get_bucket_limts(num_outputs, ys):
    """
    Build equal-frequency bucket borders from sampled target values.

    Args:
        num_outputs: number of buckets to create
        ys: tensor of y values drawn from the prior (any shape; it is flattened)

    Returns:
        1-D tensor of num_outputs + 1 borders: the minimum, the midpoints
        between neighbouring buckets, and the maximum of the values used.
    """
    values = ys.flatten()

    # Drop trailing values so the count is an exact multiple of num_outputs.
    leftover = len(values) % num_outputs
    if leftover:
        values = values[:-leftover]

    per_bucket = len(values) // num_outputs
    lowest, highest = values.min(), values.max()

    ordered = values.sort(0)[0]

    # An inner border sits halfway between the last value of one bucket and
    # the first value of the next one.
    last_of_bucket = ordered[per_bucket-1::per_bucket][:-1]
    first_of_next = ordered[per_bucket::per_bucket]
    inner_borders = (last_of_bucket + first_of_next) / 2

    return torch.cat([lowest.unsqueeze(0), inner_borders, highest.unsqueeze(0)], dim=0)


# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.948108Z","iopub.execute_input":"2025-08-01T11:59:25.948292Z","iopub.status.idle":"2025-08-01T11:59:25.961285Z","shell.execute_reply.started":"2025-08-01T11:59:25.948277Z","shell.execute_reply":"2025-08-01T11:59:25.960608Z"}}
def y_to_bucket_idx(ys, bl):
    """
    Look up the bucket index for every value in `ys`.

    Args:
        ys: tensor of y values to assign to buckets
        bl: sorted 1-D tensor of bucket borders (bucket k spans bl[k] .. bl[k + 1])

    Returns:
        Integer tensor shaped like `ys` with the bucket index of each value.
        Values at or below the first border map to bucket 0 and values at or
        above the last border map to the last bucket.
    """
    final_bucket = len(bl) - 1 - 1
    bucket_idx = torch.searchsorted(bl, ys) - 1
    # Fold out-of-range values into the outermost buckets (lower side first,
    # then upper side).
    bucket_idx = bucket_idx.masked_fill(ys <= bl[0], 0)
    bucket_idx = bucket_idx.masked_fill(ys >= bl[-1], final_bucket)
    return bucket_idx

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.962218Z","iopub.execute_input":"2025-08-01T11:59:25.962728Z","iopub.status.idle":"2025-08-01T11:59:25.980598Z","shell.execute_reply.started":"2025-08-01T11:59:25.962701Z","shell.execute_reply":"2025-08-01T11:59:25.979965Z"}}
# Paste the model code here from the model files whose hyperparameters you want to tune
class Encoder(nn.Module):
    """
    Transformer encoder layer with split attention: the first `trainset_size`
    positions attend to each other, and the remaining positions attend only
    to those first positions.
    """
    def __init__(self, d_model, n_heads, n_hidden, dropout=0.0):
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)

        # Position-wise feed-forward network.
        feed_forward_layers = [
            nn.Linear(d_model, n_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(n_hidden, d_model),
        ]
        self.out = nn.Sequential(*feed_forward_layers)

        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, src, trainset_size):
        """
        Args:
            src: (batch, sequence, d_model) tensor (attention is batch_first)
            trainset_size: number of leading sequence positions that form the
                training set

        Returns:
            tensor with the same shape as `src`
        """
        context = src[:, :trainset_size]
        queries = src[:, trainset_size:]

        # Training points attend to each other.
        attended_context, _ = self.self_attn(context, context, context)
        # Test points attend to the training points only.
        attended_queries, _ = self.self_attn(queries, context, context)
        attended = torch.cat([attended_context, attended_queries], dim=1)

        # Residual connection + layer norm around attention, then around the
        # feed-forward network.
        hidden = self.norm1(src + self.dropout(attended))
        return self.norm2(self.dropout(self.out(hidden)) + hidden)


class Transformer(nn.Module):
    """
    Embeds (x, y) training points and x-only test points, runs them through a
    stack of `Encoder` layers and maps every position to `n_out` outputs.
    """
    def __init__(self, num_features, n_out, n_layers=2, d_model=512, n_heads=4, n_hidden=1024, dropout=0.0, normalize=lambda x:x):
        super().__init__()

        # Separate linear embeddings for inputs and for scalar targets.
        self.x_encoder = nn.Linear(num_features, d_model)
        self.y_encoder = nn.Linear(1, d_model)

        encoder_layers = []
        for _ in range(n_layers):
            encoder_layers.append(Encoder(d_model, n_heads, n_hidden, dropout))
        self.model = nn.ModuleList(encoder_layers)

        # Output head applied to every position.
        self.out = nn.Sequential(
            nn.Linear(d_model, n_hidden),
            nn.GELU(),
            nn.Linear(n_hidden, n_out)
        )

        # Callable applied to x before it is embedded (identity by default).
        self.normalize = normalize
        self.init_weights()

    def forward(self, x, y, trainset_size):
        """
        Args:
            x: num_datasets x number_of_points x num_features
            y: targets for the points in x; the last dimension must have size 1
                because it is fed to nn.Linear(1, d_model)
            trainset_size: int specifying the number of points to use as training dataset size

        Returns:
            outputs for each x (n_out values per point)
        """
        x_embedded = self.x_encoder(self.normalize(x))
        y_embedded = self.y_encoder(y)

        # Training positions carry x + y information; test positions only x.
        context = x_embedded[:, :trainset_size] + y_embedded[:, :trainset_size]
        queries = x_embedded[:, trainset_size:]
        hidden = torch.cat([context, queries], dim=1)

        for layer in self.model:
            hidden = layer(hidden, trainset_size)

        return self.out(hidden)

    def init_weights(self):
        """Xavier-uniform initialise every parameter with more than one dimension."""
        multi_dim_params = (p for p in self.parameters() if p.dim() > 1)
        for weight in multi_dim_params:
            nn.init.xavier_uniform_(weight)

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.981445Z","iopub.execute_input":"2025-08-01T11:59:25.982242Z","iopub.status.idle":"2025-08-01T11:59:25.997695Z","shell.execute_reply.started":"2025-08-01T11:59:25.982214Z","shell.execute_reply":"2025-08-01T11:59:25.997066Z"}}
def get_halfnormal_with_p_weight_before(range_max, p=0.5):
    """
    Build a half-normal distribution whose p-quantile equals `range_max`.

    Args:
        range_max: value below which a fraction `p` of the probability mass should lie
        p: cumulative probability to place within `range_max`

    Returns:
        torch.distributions.HalfNormal with the matching scale
    """
    # The p-quantile of a half normal is proportional to its scale, so divide
    # by the unit-scale quantile to get the required scale.
    unit_half_normal = torch.distributions.HalfNormal(torch.tensor(1.))
    unit_quantile = unit_half_normal.icdf(torch.tensor(p))
    scale = range_max / unit_quantile
    return torch.distributions.HalfNormal(scale)

def compute_bar_distribution_loss(logits, target_y, bucket_limits, label_smoothing=0.0):
    """
    Negative log-likelihood of the targets under a full-support Riemann ("bar")
    distribution. See Appendix E of Muller et al. 2021.

    Inside a bucket the density is the bucket probability divided by the bucket
    width. The first and last buckets are replaced by half-normal tails that
    start at their inner border, so targets outside the bucket range still get
    a finite, distance-dependent likelihood.

    Args:
        logits: num_datasets x num_points_in_each_dataset x num_outputs_for_classification
        target_y: target value for each point (shape of `logits` without the last dimension)
        bucket_limits: sorted 1-D tensor of bucket borders (number of buckets + 1 entries)
        label_smoothing: constant to define the amount of label smoothing to be used.

    Returns:
        loss: scalar value (mean over all points)
    """
    target_y_idx = y_to_bucket_idx(target_y, bucket_limits)

    bucket_widths = bucket_limits[1:] - bucket_limits[:-1]
    bucket_log_probs = torch.log_softmax(logits, -1)
    # Turn bucket probability mass into a density: log p(bucket) - log(width).
    scaled_bucket_log_probs = bucket_log_probs - torch.log(bucket_widths)
    log_probs = scaled_bucket_log_probs.gather(-1, target_y_idx[..., None]).squeeze(-1)

    # Full-support correction: half normals scaled so that half of their mass
    # lies within the width of the corresponding outer bucket.
    side_normals = (
        get_halfnormal_with_p_weight_before(bucket_widths[0]),
        get_halfnormal_with_p_weight_before(bucket_widths[-1])
    )

    # In both side buckets the uniform density is swapped for the half-normal
    # density; adding log(width) cancels the width scaling applied above.
    # The clamp belongs on the *distance* passed to log_prob (keeping it
    # strictly positive), not on the resulting log-density: clamping the
    # log-density from below would overwrite every negative log-probability
    # and remove the penalty for targets far out in the tails.

    # Correction for the first bucket: distance measured down from its inner border.
    first_bucket = target_y_idx == 0
    dist_below = (bucket_limits[1] - target_y[first_bucket]).clamp(min=1e-8)
    log_probs[first_bucket] += side_normals[0].log_prob(dist_below) + torch.log(bucket_widths[0])

    # Correction for the last bucket: distance measured up from its inner border.
    last_bucket = target_y_idx == len(bucket_widths) - 1
    dist_above = (target_y[last_bucket] - bucket_limits[-2]).clamp(min=1e-8)
    log_probs[last_bucket] += side_normals[1].log_prob(dist_above) + torch.log(bucket_widths[-1])

    nll_loss = -log_probs
    # Label smoothing term: average negative log-density over all buckets.
    smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)

    loss = (1-label_smoothing) * nll_loss + label_smoothing * smooth_loss
    return loss.mean()

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.998299Z","iopub.execute_input":"2025-08-01T11:59:25.998484Z","iopub.status.idle":"2025-08-01T11:59:26.016106Z","shell.execute_reply.started":"2025-08-01T11:59:25.998468Z","shell.execute_reply":"2025-08-01T11:59:26.015490Z"}}
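# Linear warmup followed by step decay. The signature mirrors Hugging Face's get_cosine_schedule_with_warmup, but the body is not a cosine schedule.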
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
    """
    Create a LambdaLR schedule with linear warmup followed by step decay.

    Despite the name, no cosine is involved: the learning-rate factor rises
    linearly from 0 to 1 over `num_warmup_steps` and is then multiplied by 0.1
    after every 4000 further steps.

    Args:
        optimizer: optimizer whose learning rate is scheduled
        num_warmup_steps: number of linear warmup steps
        num_training_steps: accepted but not used by this schedule
        num_cycles: accepted but not used by this schedule
        last_epoch: forwarded to LambdaLR

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    warmup_span = float(max(1, num_warmup_steps))

    def lr_lambda(current_step):
        steps_past_warmup = current_step - num_warmup_steps
        if steps_past_warmup >= 0:
            # After warmup: one factor of 0.1 per completed block of 4000 steps.
            completed_decays = steps_past_warmup // 4000
            return 0.1 ** completed_decays
        # Still warming up: linear ramp towards 1.
        return float(current_step) / warmup_span

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

# Parameters related to prior data
num_points_in_each_dataset = 500  # points per sampled dataset (train + test part)
# GP prior settings read by get_gp_prior via hyperparameters.get(...)
hyperparameters = {
    'lengthscale': 0.6,
    'kernel_variance': 0.01,
    'output_noise': 1e-4,
    'mean_shift': 1.0
}
bias_term = True  # NOTE(review): not referenced anywhere in the visible part of this file
num_features = 10  # input dimensionality of each point
num_outputs = 500  # number of buckets of the bar distribution (= model output size)
max_num_datasets = 16  # training batch size (datasets per batch)

logger.info("  GP hyperparameters:")
for k, v in hyperparameters.items():
    logger.info(f"    {k}: {v}")

# Riemann (bar) distribution: sample 1000 datasets from the prior once and use
# their y values to place the bucket borders.
_, ys = get_gp_prior(1000, num_features, num_points_in_each_dataset, hyperparameters)  # calibration sample for the borders
bucket_limits = get_bucket_limts(num_outputs, ys)

# Training data: adapter so PriorDataLoader can call the prior as
# fn(batch_size, num_points_in_each_dataset).
get_prior_data_fn = lambda x, y: get_gp_prior(x, num_features, y, hyperparameters)
dl = PriorDataLoader(get_prior_data_fn, batch_size=max_num_datasets, num_points_in_each_dataset=num_points_in_each_dataset)

# Models
# Inputs are drawn with torch.rand, i.e. uniform on [0, 1) with mean 0.5 and
# variance 1/12; this maps them to zero mean and unit variance.
uniform_normalize = lambda x: (x - 0.5) / math.sqrt(1 / 12)

# Learning settings
epochs = 200  # only used for scheduler init, not full training here
steps_per_epoch = 400  # NOTE(review): not referenced in the visible part of this file
warmup_epochs = epochs // 4
validate_epoch = 10  # NOTE(review): not referenced in the visible part of this file
lr = 0.001  # module-level default; objective() assigns its own local `lr` per trial
bucket_limits = bucket_limits.to(device)

# Test data: one fixed batch of 64 datasets (train=False, so the loader's
# epoch counter is not advanced), reused for validation in every trial.
(test_xs, test_ys), test_trainset_size = dl.get_batch(False, 64)
test_xs, test_ys = test_xs.to(device), test_ys.to(device)

import optuna

def objective(trial):
    """
    Optuna objective: build a Transformer from the suggested hyperparameters,
    take one optimisation step on a freshly sampled prior batch and return the
    loss on the fixed validation batch.

    Relies on module-level state: `dl`, `bucket_limits`, `device`,
    `num_features`, `num_outputs`, `uniform_normalize`, `epochs`, `test_xs`,
    `test_ys` and `test_trainset_size`.

    Args:
        trial: optuna trial used to sample the hyperparameters

    Returns:
        validation loss (float) after the single training step; the training
        loss is stored on the trial as user attribute "train_loss".
    """
    # Suggest hyperparameters. suggest_float(..., log=True) samples
    # log-uniformly and replaces the deprecated suggest_loguniform.
    lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True)
    d_model = trial.suggest_categorical("d_model", [32, 64, 128, 256])
    n_heads = trial.suggest_categorical("n_heads", [2, 4, 8])
    n_hidden = trial.suggest_categorical("n_hidden", [128, 256, 512, 1024])
    n_layers = trial.suggest_categorical("n_layers", [1, 2, 4])

    logger.info(f"🔍 Trial {trial.number} starting with params: "
                f"lr={lr}, d_model={d_model}, n_heads={n_heads}, "
                f"n_hidden={n_hidden}, n_layers={n_layers}")

    # Create model
    modelt = Transformer(
        num_features, n_out=num_outputs,
        d_model=d_model, n_layers=n_layers,
        n_hidden=n_hidden, n_heads=n_heads,
        normalize=uniform_normalize
    ).to(device)

    optimizer = torch.optim.AdamW(modelt.parameters(), lr=lr)
    # No warmup for this single-step probe: with a warmup phase the schedule's
    # factor at step 0 is 0, and LambdaLR applies it on construction, so the
    # only optimizer step would run with lr == 0. The model would then stay at
    # its random initialisation and the suggested lr could not influence the
    # objective at all.
    scheduler = get_cosine_schedule_with_warmup(optimizer, 0, epochs)

    # === Single quick training step ===
    modelt.train()
    (xs, ys), trainset_size = dl.get_batch()
    xs, ys = xs.to(device), ys.to(device)
    pred_y = modelt(xs, ys.unsqueeze(-1), trainset_size)

    # Only the test part of each dataset contributes to the loss.
    logits = pred_y[:, trainset_size:]
    target_y = ys[:, trainset_size:].clone().view(*logits.shape[:-1])

    loss = compute_bar_distribution_loss(logits, target_y, bucket_limits)

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(modelt.parameters(), 1.)
    optimizer.step()
    scheduler.step()

    # === Quick validation ===
    # Evaluate in eval mode so layers such as dropout behave deterministically.
    modelt.eval()
    with torch.no_grad():
        pred_test_y = modelt(test_xs, test_ys.unsqueeze(-1), test_trainset_size)
        logits = pred_test_y[:, test_trainset_size:]
        target_test_y = test_ys[:, test_trainset_size:].clone().view(*logits.shape[:-1])
        val_loss = compute_bar_distribution_loss(logits, target_test_y, bucket_limits)

    logger.info(f"[Trial {trial.number}] Train Loss: {loss.item():.5f} | Val Loss: {val_loss.item():.5f}")

    # Keep the training loss on the trial (the validation loss is the trial
    # value) so trials can be filtered by either one later.
    trial.set_user_attr("train_loss", loss.item())

    return val_loss.item()


if __name__ == "__main__":
    # Run the hyperparameter search, then report the best trial by validation
    # loss and the best trial by training loss.
    def _train_loss_or_inf(candidate):
        # Trials without a recorded training loss sort last.
        return candidate.user_attrs.get("train_loss", float("inf"))

    def _log_trial_params(chosen_trial):
        logger.info("  Params:")
        for param_name, param_value in chosen_trial.params.items():
            logger.info(f"    {param_name}: {param_value}")

    study = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
    logger.info("🚀 Starting Optuna optimization...")
    study.optimize(objective, n_trials=1000, timeout=None)

    # Best by validation loss (the value returned by the objective)
    best_val_trial = study.best_trial

    # Best by training loss (stored as a user attribute on each trial)
    best_train_trial = min(study.trials, key=_train_loss_or_inf)

    logger.info("🏆 Best trial (Validation Loss):")
    logger.info(f"  Val Loss: {best_val_trial.value}")
    _log_trial_params(best_val_trial)

    logger.info("🏆 Best trial (Training Loss):")
    logger.info(f"  Train Loss: {best_train_trial.user_attrs['train_loss']}")
    _log_trial_params(best_train_trial)

