import numpy as np
import matplotlib.pyplot as plt
import math
import pathlib
import random
import sklearn.linear_model
import torch
import torch.nn as nn
from torch.nn.modules.transformer import MultiheadAttention
from torch.optim.lr_scheduler import LambdaLR
import time
import sys
import logging

# ---------------- Logging Setup ---------------- #
# Every record goes both to training_VA_5D.log (opened with mode="w", so it is
# truncated on each run) and to stdout.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(message)s",
    level=logging.INFO,
    handlers=[
        logging.FileHandler("training_VA_5D.log", mode="w"),
        logging.StreamHandler(sys.stdout),
    ],
)
logger = logging.getLogger(__name__)

# ---------------- Device Setup ---------------- #
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    logger.info(f"Using GPU: {device}")
else:
    logger.warning("No GPU detected -> using CPU")

import os
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging setting
# (synchronous kernel launches) and can slow training. It is set here only after
# torch.cuda.is_available() has run; confirm it is intended and still takes effect.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.876066Z","iopub.execute_input":"2025-08-01T11:59:25.876453Z","iopub.status.idle":"2025-08-01T11:59:25.883208Z","shell.execute_reply.started":"2025-08-01T11:59:25.876433Z","shell.execute_reply":"2025-08-01T11:59:25.882435Z"}}
import torch
import matplotlib.pyplot as plt

def rbf_kernel(x1, x2, lengthscale=0.2, variance=1.0):
    """
    Squared-exponential (RBF) kernel between two point sets.

    k(a, b) = variance * exp(-||a - b||^2 / (2 * lengthscale^2))

    Args:
        x1: (N, D) tensor of points.
        x2: (M, D) tensor of points.
        lengthscale: single lengthscale shared by all input dimensions.
        variance: kernel amplitude (the value of k(a, a)).

    Returns:
        (N, M) kernel matrix.
    """
    # Broadcast to (N, M, D) pairwise differences.
    pairwise = x1[:, None] - x2[None]
    sq_dist = pairwise.pow(2).sum(dim=-1)
    return variance * torch.exp(-0.5 * sq_dist / lengthscale**2)

def get_gp_prior(num_datasets, num_features, num_points_in_each_dataset, hyperparameters):
    """
    Draw synthetic regression datasets from a multivariate-input GP prior.

    Inputs are sampled uniformly from the unit hypercube. Each dataset's targets
    are one joint draw from a GP with an RBF kernel, i.i.d. Gaussian observation
    noise and a constant mean ``mean_shift``, so the outputs sit around that
    level (e.g. a voltage-like range such as 0.9-1.2).

    Args:
        num_datasets: number of independent datasets B.
        num_features: input dimensionality D.
        num_points_in_each_dataset: points per dataset N.
        hyperparameters: dict with optional keys 'lengthscale' (default 0.2),
            'kernel_variance' (0.01), 'output_noise' (0.01, a standard
            deviation) and 'mean_shift' (1.05).

    Returns:
        xs: (B, N, D) inputs in [0, 1).
        ys: (B, N) targets.
    """
    lengthscale = hyperparameters.get('lengthscale', 0.2)
    kernel_variance = hyperparameters.get('kernel_variance', 0.01)  # small variance -> smooth, voltage-like signals
    output_noise = hyperparameters.get('output_noise', 0.01)
    mean_shift = hyperparameters.get('mean_shift', 1.05)  # centre of the expected voltage range

    xs = torch.rand(num_datasets, num_points_in_each_dataset, num_features)

    # Quantities shared by every dataset.
    noise_cov = output_noise**2 * torch.eye(num_points_in_each_dataset)
    prior_mean = torch.full((num_points_in_each_dataset,), mean_shift)

    def _draw_targets(points):
        """Sample one (N,) target vector for the (N, D) inputs `points`."""
        cov = rbf_kernel(points, points, lengthscale, kernel_variance) + noise_cov
        return torch.distributions.MultivariateNormal(prior_mean, cov).sample()

    # Datasets are sampled one after another, in order.
    ys = torch.stack([_draw_targets(points) for points in xs], dim=0)  # (B, N)
    return xs, ys

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.884109Z","iopub.execute_input":"2025-08-01T11:59:25.884413Z","iopub.status.idle":"2025-08-01T11:59:25.904102Z","shell.execute_reply.started":"2025-08-01T11:59:25.884390Z","shell.execute_reply":"2025-08-01T11:59:25.903403Z"}}
class PriorDataLoader(object):
    """
    Draws a fresh batch of prior datasets on every call, together with a random
    train/test split size.

    Args:
        get_prior_data_fn: callable ``(batch_size, num_points) -> data``.
        batch_size: default number of datasets per batch.
        num_points_in_each_dataset: points per dataset; also bounds the
            sampled train-set size.
    """
    def __init__(self, get_prior_data_fn, batch_size, num_points_in_each_dataset):
        self.batch_size = batch_size
        self.num_points_in_each_dataset = num_points_in_each_dataset
        self.get_prior_data_fn = get_prior_data_fn
        # Incremented by `train` on each get_batch call, i.e. counts training batches.
        self.epoch = 0

    def get_batch(self, train = True, batch_size=None):
        """
        Args:
            train: whether this batch counts towards ``self.epoch``.
            batch_size: overrides the default batch size when truthy.

        Returns:
            (data, trainset_size) where ``data`` is whatever
            ``get_prior_data_fn`` returns (e.g. ``(xs, ys)``).
        """
        self.epoch = self.epoch + train
        data = self.get_prior_data_fn(batch_size or self.batch_size, self.num_points_in_each_dataset)
        return data, self._sample_trainset_size()

    def _sample_trainset_size(self):
        """
        Sample a train-set size s in [1, n - 2] with weight 1 / (n - 1 - s),
        so larger train sets are more likely (Appendix E.1 of Muller et al. (2021)).
        This always leaves at least two test points.
        """
        lowest, upper = 1, self.num_points_in_each_dataset - 1
        span = upper - lowest
        # Weight 1/span for the smallest size, rising to 1/1 for the largest.
        weights = [1 / k for k in range(span, 0, -1)]
        return random.choices(range(lowest, upper), weights)[0]

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.905810Z","iopub.execute_input":"2025-08-01T11:59:25.906047Z","iopub.status.idle":"2025-08-01T11:59:25.947429Z","shell.execute_reply.started":"2025-08-01T11:59:25.906032Z","shell.execute_reply":"2025-08-01T11:59:25.946672Z"}}
def get_bucket_limts(num_outputs, ys):
    """
    Place `num_outputs` equal-count bucket borders over the empirical values `ys`.

    The values are flattened and sorted and split into `num_outputs` groups of
    equal size. Each inner border is the midpoint between the largest value of
    one group and the smallest value of the next. The outer borders are the
    min / max of the values used. If the count is not divisible by
    `num_outputs`, the trailing values in flattened order are dropped first.

    Args:
        num_outputs: number of buckets to create.
        ys: tensor of prior y samples (any shape).

    Returns:
        1-D tensor of ``num_outputs + 1`` bucket borders.
    """
    flat = ys.flatten()
    remainder = len(flat) % num_outputs
    if remainder:
        flat = flat[:-remainder]  # equal-sized buckets need an exact multiple

    per_bucket = len(flat) // num_outputs
    lowest, highest = flat.min(), flat.max()

    # Row k holds the sorted values falling into bucket k.
    grid = flat.sort(dim=0).values.view(num_outputs, per_bucket)
    upper_edges = grid[:-1, -1]  # largest value of buckets 0 .. K-2
    lower_edges = grid[1:, 0]    # smallest value of buckets 1 .. K-1
    inner = (upper_edges + lower_edges) / 2
    return torch.cat([lowest.unsqueeze(0), inner, highest.unsqueeze(0)], dim=0)


# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.948108Z","iopub.execute_input":"2025-08-01T11:59:25.948292Z","iopub.status.idle":"2025-08-01T11:59:25.961285Z","shell.execute_reply.started":"2025-08-01T11:59:25.948277Z","shell.execute_reply":"2025-08-01T11:59:25.960608Z"}}
def y_to_bucket_idx(ys, bl):
    """
    Map each value in `ys` to the index of its bucket in the borders `bl`.

    Bucket k covers the half-open interval (bl[k], bl[k + 1]]. Values at or
    below bl[0] go to bucket 0, and values at or above bl[-1] go to the last
    bucket (len(bl) - 2).

    Args:
        ys: tensor of values to bucketize (same dtype as `bl`).
        bl: sorted 1-D tensor of bucket borders.

    Returns:
        Integer tensor with the same shape as `ys`.
    """
    last_bucket = len(bl) - 2
    idx = torch.searchsorted(bl, ys) - 1
    idx = torch.where(ys <= bl[0], torch.zeros_like(idx), idx)
    idx = torch.where(ys >= bl[-1], torch.full_like(idx, last_bucket), idx)
    return idx

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.962218Z","iopub.execute_input":"2025-08-01T11:59:25.962728Z","iopub.status.idle":"2025-08-01T11:59:25.980598Z","shell.execute_reply.started":"2025-08-01T11:59:25.962701Z","shell.execute_reply":"2025-08-01T11:59:25.979965Z"}}
class Encoder(nn.Module):
    """
    One transformer layer with a train/test attention split.

    Positions ``[:trainset_size]`` (train points) attend to each other.
    Positions ``[trainset_size:]`` (test points) attend only to the train
    points, so test points never see one another. Attention is followed by a
    residual + LayerNorm, then a position-wise MLP with its own
    residual + LayerNorm.

    Args:
        d_model: embedding width.
        n_heads: number of attention heads.
        n_hidden: hidden width of the position-wise MLP.
        dropout: dropout probability for attention, the MLP and the residual branches.
    """
    def __init__(self, d_model, n_heads, n_hidden, dropout=0.1):
        super().__init__()
        # Attribute names and creation order are kept stable so saved
        # state_dicts still load and seeded initialisation is unchanged.
        self.self_attn = MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)
        self.out = nn.Sequential(
            nn.Linear(d_model, n_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(n_hidden, d_model),
        )
        self.dropout = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, src, trainset_size):
        """
        Args:
            src: (batch, points, d_model) tensor (batch-first layout).
            trainset_size: number of leading positions treated as train points.

        Returns:
            Tensor with the same shape as ``src``.
        """
        train_tokens = src[:, :trainset_size]
        test_tokens = src[:, trainset_size:]

        attended = torch.cat(
            [
                self.self_attn(train_tokens, train_tokens, train_tokens)[0],
                self.self_attn(test_tokens, train_tokens, train_tokens)[0],
            ],
            dim=1,
        )
        hidden = self.norm1(src + self.dropout(attended))
        return self.norm2(self.dropout(self.out(hidden)) + hidden)


class Transformer(nn.Module):
    """
    Prior-data fitted network built from a stack of ``Encoder`` layers.

    Each point is embedded as ``x_encoder(normalize(x))``. For the first
    ``trainset_size`` points (the train/context set) ``y_encoder(y)`` is added;
    the remaining (test) points carry only their x embedding. After the encoder
    stack, ``out`` maps every position to ``n_out`` logits (e.g. one per bucket
    of a bar distribution).

    Args:
        num_features: input dimensionality of x.
        n_out: number of output logits per point.
        n_layers: number of ``Encoder`` layers.
        d_model: embedding width.
        n_heads: attention heads per layer.
        n_hidden: hidden width of the encoder MLPs and the output head.
        dropout: dropout probability passed to every ``Encoder``.
        normalize: callable applied to x before embedding (identity by default).
    """
    def __init__(self, num_features, n_out, n_layers=2, d_model=512, n_heads=4, n_hidden=1024, dropout=0.0, normalize=lambda x:x):
        super().__init__()

        self.x_encoder = nn.Linear(num_features, d_model)
        self.y_encoder = nn.Linear(1, d_model)

        self.model = nn.ModuleList(
            [Encoder(d_model, n_heads, n_hidden, dropout) for _ in range(n_layers)]
        )

        self.out = nn.Sequential(
            nn.Linear(d_model, n_hidden),
            nn.GELU(),
            nn.Linear(n_hidden, n_out)
        )

        self.normalize = normalize
        self.init_weights()

    def forward(self, x, y, trainset_size):
        """
        Args:
            x: (num_datasets, num_points, num_features) inputs.
            y: (num_datasets, num_points, 1) targets. A (num_datasets, num_points)
                tensor is also accepted and gets a trailing feature dim.
                Only ``y[:, :trainset_size]`` influences the output.
            trainset_size: int, number of leading points used as the train set.

        Returns:
            (num_datasets, num_points, n_out) logits for every point.
        """
        if y.dim() == x.dim() - 1:
            # y_encoder is nn.Linear(1, d_model), so it needs a trailing size-1 dim.
            y = y.unsqueeze(-1)

        x_src = self.x_encoder(self.normalize(x))
        y_src = self.y_encoder(y)

        # Train points: x + y embedding; test points: x embedding only (no label leak).
        src = torch.cat([x_src[:, :trainset_size] + y_src[:, :trainset_size], x_src[:, trainset_size:]], dim=1)
        for encoder in self.model:
            src = encoder(src, trainset_size)

        return self.out(src)

    def init_weights(self):
        """Xavier-uniform init for every parameter with more than one dimension (weight matrices)."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.981445Z","iopub.execute_input":"2025-08-01T11:59:25.982242Z","iopub.status.idle":"2025-08-01T11:59:25.997695Z","shell.execute_reply.started":"2025-08-01T11:59:25.982214Z","shell.execute_reply":"2025-08-01T11:59:25.997066Z"}}
def get_halfnormal_with_p_weight_before(range_max, p=0.5):
    """
    Build a half-normal distribution whose CDF equals `p` at `range_max`.

    The scale is chosen as ``range_max / q_p``, where ``q_p`` is the p-quantile
    of the unit half-normal, so a fraction `p` of the mass lies in [0, range_max].

    Args:
        range_max: distance at which the CDF should reach `p` (scalar or tensor).
        p: target cumulative probability at `range_max`.

    Returns:
        torch.distributions.HalfNormal with the computed scale.
    """
    unit_quantile = torch.distributions.HalfNormal(torch.tensor(1.)).icdf(torch.tensor(p))
    return torch.distributions.HalfNormal(range_max / unit_quantile)

def compute_bar_distribution_loss(logits, target_y, bucket_limits, label_smoothing=0.0):
    """
    Negative log-likelihood under a full-support Riemann ("bar") distribution.

    See Appendix E of Muller et al. 2021. Inner buckets have a piecewise-constant
    density P(bucket) / width. The first and last buckets instead use half-normal
    tails anchored at ``bucket_limits[1]`` and ``bucket_limits[-2]``, so targets
    outside the calibrated range still get a finite likelihood.

    Args:
        logits: num_datasets x num_points_in_each_dataset x num_outputs_for_classification
        target_y: target value for each point (num_datasets x num_points_in_each_dataset)
        bucket_limits: 1-D tensor of num_outputs + 1 bucket borders
        label_smoothing: weight of the uniform-over-buckets smoothing term.

    Returns:
        loss: scalar, mean over all points.
    """
    target_y_idx = y_to_bucket_idx(target_y, bucket_limits)

    bucket_widths = bucket_limits[1:] - bucket_limits[:-1]
    bucket_log_probs = torch.log_softmax(logits, -1)
    # Log density of the bar distribution: log P(bucket) - log(bucket width).
    scaled_bucket_log_probs = bucket_log_probs - torch.log(bucket_widths)
    log_probs = scaled_bucket_log_probs.gather(-1, target_y_idx[..., None]).squeeze(-1)

    # Full-support correction using half normals. Adding log(width) cancels the
    # division above, leaving log P(bucket) + log half-normal density. The
    # *distance* is clamped to stay inside the half-normal's support [0, inf);
    # the log-probability itself must not be clamped, because it can be
    # negative and clamping it would erase the tail penalty.
    side_normals = (
        get_halfnormal_with_p_weight_before(bucket_widths[0]),
        get_halfnormal_with_p_weight_before(bucket_widths[-1])
    )

    # Left tail: distance below the inner border of the first bucket.
    first_bucket = target_y_idx == 0
    first_dist = (bucket_limits[1] - target_y[first_bucket]).clamp(min=1e-8)
    log_probs[first_bucket] += side_normals[0].log_prob(first_dist) + torch.log(bucket_widths[0])

    # Right tail: distance above the inner border of the last bucket.
    last_bucket = target_y_idx == len(bucket_widths) - 1
    last_dist = (target_y[last_bucket] - bucket_limits[-2]).clamp(min=1e-8)
    log_probs[last_bucket] += side_normals[1].log_prob(last_dist) + torch.log(bucket_widths[-1])

    nll_loss = -log_probs
    smooth_loss = -scaled_bucket_log_probs.mean(dim=-1)

    loss = (1 - label_smoothing) * nll_loss + label_smoothing * smooth_loss
    return loss.mean()

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:25.998299Z","iopub.execute_input":"2025-08-01T11:59:25.998484Z","iopub.status.idle":"2025-08-01T11:59:26.016106Z","shell.execute_reply.started":"2025-08-01T11:59:25.998468Z","shell.execute_reply":"2025-08-01T11:59:26.015490Z"}}
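# Adapted from the Hugging Face scheduler helpers (same signature); despite the name, this variant uses step decay after warmup, not cosine decay.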
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1,
                                    decay_every=4000, decay_factor=0.1):
    """
    Linear warmup followed by step decay (no cosine is used, despite the name).

    The LR multiplier rises linearly from 0 to 1 over `num_warmup_steps`
    scheduler steps. After that it is multiplied by `decay_factor` once every
    `decay_every` scheduler steps. `num_training_steps` and `num_cycles` are
    accepted only for signature compatibility and are ignored.

    Args:
        optimizer: optimizer whose learning rates are scaled.
        num_warmup_steps: length of the linear warmup, in scheduler steps.
        num_training_steps: unused.
        num_cycles: unused.
        last_epoch: forwarded to LambdaLR.
        decay_every: scheduler steps between decays (default 4000).
        decay_factor: multiplicative decay applied at each interval (default 0.1).

    Returns:
        torch.optim.lr_scheduler.LambdaLR

    Raises:
        ValueError: if `decay_every` < 1.
    """
    if decay_every < 1:
        raise ValueError(f"decay_every must be >= 1, got {decay_every}")

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # After warmup, decay by `decay_factor` every `decay_every` steps.
        num_decays = (current_step - num_warmup_steps) // decay_every
        return decay_factor ** num_decays

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T11:59:26.016989Z","iopub.execute_input":"2025-08-01T11:59:26.017267Z","iopub.status.idle":"2025-08-01T12:01:06.593889Z","shell.execute_reply.started":"2025-08-01T11:59:26.017234Z","shell.execute_reply":"2025-08-01T12:01:06.593028Z"}}
# Parameters related to prior data
num_points_in_each_dataset = 500
hyperparameters = {
    'lengthscale': 0.6,
    'kernel_variance': 0.01,
    'output_noise':1e-4,
    'mean_shift': 1.0
}
bias_term = True  # only reported in the log below
num_features = 5
num_outputs = 500  # number of buckets of the bar (Riemann) output distribution
max_num_datasets = 16  # datasets per training batch
calibration_num_datasets = 1000  # prior datasets drawn once to place the bucket borders
logger.info("  GP hyperparameters:")
for k, v in hyperparameters.items():
    logger.info(f"    {k}: {v}")
# Riemann (bar) distribution: bucket borders are placed at equal-count positions
# among calibration_num_datasets * num_points_in_each_dataset prior y samples.
_, ys = get_gp_prior(calibration_num_datasets, num_features, num_points_in_each_dataset, hyperparameters)
bucket_limits = get_bucket_limts(num_outputs, ys)

# training data: every get_batch() call draws a fresh batch of GP datasets
get_prior_data_fn = lambda x, y: get_gp_prior(x, num_features, y, hyperparameters)
dl = PriorDataLoader(get_prior_data_fn, batch_size=max_num_datasets, num_points_in_each_dataset=num_points_in_each_dataset)

# Models
# Maps U(0, 1) inputs to zero mean and unit variance.
uniform_normalize = lambda x: (x-0.5)/math.sqrt(1/12)

# Learning
epochs = 200
steps_per_epoch = 400
warmup_epochs = epochs // 4  # measured in epochs
validate_epoch = 10  # validate every validate_epoch epochs
lr = 0.001
bucket_limits = bucket_limits.to(device)

logger.info("Hyperparameters:")
logger.info(f"  epochs: {epochs}")
logger.info(f"  steps_per_epoch: {steps_per_epoch}")
logger.info(f"  warmup_epochs: {warmup_epochs}")
logger.info(f"  validate_epoch: {validate_epoch}")
logger.info(f"  learning_rate: {lr}")
logger.info(f"  num_points_in_each_dataset: {num_points_in_each_dataset}")
logger.info(f"  bias_term: {bias_term}")
logger.info(f"  num_features: {num_features}")
logger.info(f"  num_outputs: {num_outputs}")
logger.info(f"  max_num_datasets: {max_num_datasets}")
logger.info(f"  calibration_size: {calibration_num_datasets}")

# test data: one fixed validation batch of 64 datasets with a single train/test split
(test_xs, test_ys), test_trainset_size = dl.get_batch(False, 64)
test_xs, test_ys = test_xs.to(device), test_ys.to(device)
torch.save(bucket_limits.cpu(), "bucket_limits_5d_VA_500.pth") # save bucket limits to be used during inference

# Bucket calibration above is slow for a GP prior: each of the calibration_num_datasets
# draws builds a num_points_in_each_dataset x num_points_in_each_dataset covariance
# matrix and samples a multivariate normal from it.

# %% [code] {"execution":{"iopub.status.busy":"2025-08-01T12:01:06.596757Z","iopub.execute_input":"2025-08-01T12:01:06.597151Z"}}
import time
import os
# ---------------- Transformer Training ---------------- #
# Trains five independent Transformer models (runs A-E) on freshly sampled prior
# batches and saves the final state_dict of each run.
for run_id in ['A', 'B', 'C', 'D', 'E']:
    logger.info(f"--- Starting Transformer Training: Run {run_id} ---")
    # No `dropout` argument is passed, so Transformer's default (0.0) applies;
    # this is why validation below does not switch to eval mode.
    modelt = Transformer(
        num_features, n_out=num_outputs,
        d_model=64, n_layers=2, n_hidden=1024, n_heads=8,
        normalize=uniform_normalize
    )
    modelt.to(device)

    optimizer = torch.optim.AdamW(modelt.parameters(), lr=lr)
    # NOTE(review): at this point in the script, get_cosine_schedule_with_warmup is
    # the step-decay variant defined earlier (num_training_steps is ignored).
    # scheduler.step() runs once per optimizer step, but warmup_epochs is passed
    # as a step count, so warmup lasts warmup_epochs steps, not epochs. Confirm
    # this is intended.
    scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_epochs, epochs)

    # Create directory for saving Transformer models
    transformer_save_dir = os.path.join("5D_VA", "transformer")
    os.makedirs(transformer_save_dir, exist_ok=True)

    start_time = time.time()

    # `+ 1` so the last step (a multiple of the validation interval) is validated too.
    for step in range(epochs * steps_per_epoch + 1):
        epoch = step // steps_per_epoch  # current epoch (only used for logging)

        # Fresh prior batch plus a randomly sampled train/test split point.
        (xs, ys), trainset_size = dl.get_batch()
        xs, ys = xs.to(device), ys.to(device)
        pred_y = modelt(xs, ys.unsqueeze(-1), trainset_size)

        # Loss is computed only on the test part of each dataset.
        logits = pred_y[:, trainset_size:]
        target_y = ys[:, trainset_size:].clone().view(*logits.shape[:-1])

        loss = compute_bar_distribution_loss(logits, target_y, bucket_limits)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(modelt.parameters(), 1.)
        optimizer.step()
        scheduler.step()

        # Validate on the fixed test batch every validate_epoch epochs.
        if step % (validate_epoch * steps_per_epoch) == 0:
            with torch.no_grad():
                pred_test_y = modelt(test_xs, test_ys.unsqueeze(-1), test_trainset_size)

                logits = pred_test_y[:, test_trainset_size:]
                target_test_y = test_ys[:, test_trainset_size:].clone().view(*logits.shape[:-1])

                val_loss = compute_bar_distribution_loss(logits, target_test_y, bucket_limits)

            # get_last_lr() is read after scheduler.step(), i.e. it is the LR for the next step.
            logger.info(
                f"[Transformer Run {run_id}] Epoch {epoch}/{epochs} | Step {step} | "
                f"Val loss: {val_loss.item():.5f} | LR: {scheduler.get_last_lr()[0]:.6e} | "
                f"Train loss: {loss.item():.5f}"
            )

    total_time = time.time() - start_time
    logger.info(f"[Transformer Run {run_id}] 🚀 Total training time: {total_time:.2f} seconds")

    # Final save for the current run
    final_path = os.path.join(transformer_save_dir, f"transformer_final_run_{run_id}.pth")
    torch.save(modelt.state_dict(), final_path)
    logger.info(f"✅ Final Transformer model for run {run_id} saved at: {final_path}")

# %% [code]
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionBlock(nn.Module):
    """
    Single-head scaled dot-product self-attention with a residual connection.

    Queries and keys are projected to ``d_model // 2`` dimensions; values keep
    ``d_model``. There is no mask, so every position attends to every position.
    """
    def __init__(self, d_model):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model // 2)
        self.W_k = nn.Linear(d_model, d_model // 2)
        self.W_v = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """x: (batch, seq, d_model) -> tensor of the same shape (attention output + x)."""
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)
        scale = queries.size(-1) ** 0.5
        weights = self.softmax(torch.bmm(queries, keys.transpose(1, 2)) / scale)
        return x + torch.bmm(weights, values)

class CNN_Attention(nn.Module):
    """
    Convolution + attention layer used by ``CNNModel``.

    The train segment ``[:trainset_size]`` and the test segment
    ``[trainset_size:]`` are convolved separately along the point axis (with
    ``conv1`` and ``conv2``) and concatenated. The result is added to the input
    and LayerNorm-ed, then passed through an unmasked ``AttentionBlock`` over
    all points. With ``padding=kernel_size // 2`` only an odd `kernel_size`
    keeps the point count unchanged, which the residual additions require.
    """
    def __init__(self, d_model, kernel_size=5, dropout=0):
        super().__init__()
        pad = kernel_size // 2
        self.conv1 = nn.Conv1d(d_model, d_model, kernel_size, padding=pad)
        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size, padding=pad)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attn = AttentionBlock(d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def _conv_points(conv, seq):
        """Apply a Conv1d along the point axis of a (batch, points, channels) tensor."""
        return conv(seq.transpose(1, 2)).transpose(1, 2)

    def forward(self, src, trainset_size):
        convolved = torch.cat(
            [
                self._conv_points(self.conv1, src[:, :trainset_size, :]),
                self._conv_points(self.conv2, src[:, trainset_size:, :]),
            ],
            dim=1,
        )
        mixed = self.attn(self.norm1(src + self.dropout(convolved)))
        # NOTE(review): `mixed` already contains a residual (AttentionBlock adds its
        # input), and `src` is added once more here. Kept as-is so trained
        # checkpoints behave the same; confirm this double residual is intended.
        return self.norm2(src + self.dropout(mixed))

class CNNModel(nn.Module):
    """
    Prior-data fitted network that uses ``CNN_Attention`` layers instead of a
    transformer encoder stack.

    Train points are embedded as ``x_encoder(normalize(x)) + y_encoder(y)``;
    test points get only the x embedding. After the layer stack, ``out`` maps
    each position to ``n_out`` logits.

    Args:
        num_features: input dimensionality of x.
        n_out: number of output logits per point.
        n_layers: number of ``CNN_Attention`` layers.
        d_model: embedding / channel width.
        kernel_size: convolution kernel size (odd keeps the length unchanged).
        dropout: dropout probability passed to every layer.
        normalize: callable applied to x before embedding (identity by default).
    """
    def __init__(self, num_features, n_out, n_layers=6, d_model=128, kernel_size=5, dropout=0, normalize=lambda x: x):
        super().__init__()

        self.x_encoder = nn.Linear(num_features, d_model)
        self.y_encoder = nn.Linear(1, d_model)

        self.model = nn.ModuleList([CNN_Attention(d_model, kernel_size, dropout) for _ in range(n_layers)])

        self.out = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, n_out)
        )

        self.normalize = normalize
        self.init_weights()

    def forward(self, x, y, trainset_size):
        """
        Args:
            x: (num_datasets, num_points, num_features) inputs.
            y: (num_datasets, num_points, 1) targets. A (num_datasets, num_points)
                tensor is also accepted and gets a trailing feature dim.
                Only ``y[:, :trainset_size]`` influences the output.
            trainset_size: int, number of leading points used as the train set.

        Returns:
            (num_datasets, num_points, n_out) logits for every point.
        """
        if y.dim() == x.dim() - 1:
            # y_encoder is nn.Linear(1, d_model), so it needs a trailing size-1 dim.
            y = y.unsqueeze(-1)

        x_src = self.x_encoder(self.normalize(x))
        y_src = self.y_encoder(y)
        # Train points: x + y embedding; test points: x embedding only.
        src = torch.cat([x_src[:, :trainset_size] + y_src[:, :trainset_size], x_src[:, trainset_size:]], dim=1)

        for layer in self.model:
            src = layer(src, trainset_size)

        return self.out(src)

    def init_weights(self):
        """Xavier-uniform init for every parameter with more than one dimension (linear and conv weights)."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)



# %% [code]
# copied from huggingface
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
    """
    Linear warmup followed by cosine decay (Hugging Face style).

    During warmup the LR multiplier is ``step / num_warmup_steps``. Afterwards it
    is ``max(0, 0.5 * (1 + cos(2 * pi * num_cycles * progress)))``, where
    `progress` goes from 0 at the end of warmup to 1 at `num_training_steps`.
    With the default num_cycles=0.5 this is a half cosine from 1 down to 0.

    Both step arguments are counted in calls to ``scheduler.step()``.

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    warmup_denominator = float(max(1, num_warmup_steps))
    decay_denominator = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step >= num_warmup_steps:
            progress = float(current_step - num_warmup_steps) / decay_denominator
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))
        return float(current_step) / warmup_denominator

    return LambdaLR(optimizer, lr_lambda, last_epoch)

import os
import matplotlib.pyplot as plt
import time

# ---------------- CNN Training ---------------- #
# Trains five independent CNN models (runs A-E) on freshly sampled prior batches
# and saves the final state_dict of each run.

# Create directory for saving CNN models (shared by all runs).
cnn_save_dir = os.path.join("5D_VA", "cnn")
os.makedirs(cnn_save_dir, exist_ok=True)

for run_id in ['A', 'B', 'C', 'D', 'E']:
    logger.info(f"--- Starting CNN Training: Run {run_id} ---")
    modelc = CNNModel(
        num_features, n_out=num_outputs,
        d_model=32, kernel_size=5, n_layers=4,
        normalize=uniform_normalize
    )
    modelc.to(device)

    optimizer = torch.optim.AdamW(modelc.parameters(), lr=lr)
    # scheduler.step() is called once per optimizer step below, so warmup and
    # horizon must be given in steps. Passing epoch counts would end the cosine
    # after `epochs` steps and then make the LR oscillate (period
    # 2 * (epochs - warmup_epochs) steps) for the rest of training.
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=warmup_epochs * steps_per_epoch,
        num_training_steps=epochs * steps_per_epoch,
    )

    start_time = time.time()

    # `+ 1` so the last step (a multiple of the validation interval) is validated too.
    for step in range(epochs * steps_per_epoch + 1):
        epoch = step // steps_per_epoch  # current epoch (only used for logging)

        # Fresh prior batch plus a randomly sampled train/test split point.
        (xs, ys), trainset_size = dl.get_batch()
        xs, ys = xs.to(device), ys.to(device)
        pred_y = modelc(xs, ys.unsqueeze(-1), trainset_size)

        # Loss is computed only on the test part of each dataset.
        logits = pred_y[:, trainset_size:]
        target_y = ys[:, trainset_size:].clone().view(*logits.shape[:-1])

        loss = compute_bar_distribution_loss(logits, target_y, bucket_limits)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(modelc.parameters(), 1.)
        optimizer.step()
        scheduler.step()

        # Validate on the fixed test batch every validate_epoch epochs.
        if step % (validate_epoch * steps_per_epoch) == 0:
            modelc.eval()  # disable dropout (if any) for validation
            with torch.no_grad():
                pred_test_y = modelc(test_xs, test_ys.unsqueeze(-1), test_trainset_size)

                logits = pred_test_y[:, test_trainset_size:]
                target_test_y = test_ys[:, test_trainset_size:].clone().view(*logits.shape[:-1])

                val_loss = compute_bar_distribution_loss(logits, target_test_y, bucket_limits)
            modelc.train()

            # get_last_lr() is read after scheduler.step(), i.e. it is the LR for the next step.
            logger.info(
                f"[CNN Run {run_id}] Epoch {epoch}/{epochs} | Step {step} | "
                f"Val loss: {val_loss.item():.5f} | LR: {scheduler.get_last_lr()[0]:.6e} | "
                f"Train loss: {loss.item():.5f}"
            )

    total_time = time.time() - start_time
    logger.info(f"[CNN Run {run_id}] 🕒 Total training time: {total_time:.2f} seconds")

    # Final save for the current run
    final_path = os.path.join(cnn_save_dir, f"cnn_final_run_{run_id}.pth")
    torch.save(modelc.state_dict(), final_path)
    logger.info(f"✅ Final CNN model for run {run_id} saved at: {final_path}")