import torch
import torch.distributed as dist
from torch.nn import BCEWithLogitsLoss
from torch_geometric.loader import DataListLoader
import os
import logging
import argparse as args
import matplotlib.pyplot as plt 
import time
import networkx as nx
import numpy as np
import pymatching
from pymatching import Matching
import torch.nn.functional as F
from GT_C import stab_to_coord, path_between_stabilizers, path_between_stabilizers_X
import math
import copy
import ldpc
import scipy.sparse
from Data import Get_toric_Code, Get_rotated_surface_Code
from Rotated_GT import get_qubits_from_edge_path


# ---- Module-level caches -------------------------------------------------
# BP-OSD inputs, built lazily by get_bposd_matrices().
H_SPARSE_CACHE: "object | None" = None      # parity-check matrix as scipy CSR
L_TORCH_CACHE: "object | None" = None       # logical operators as a CPU long tensor
# Rotated-code (H_Z, H_X) blocks keyed by lattice size, see get_rotated_H_matrices().
H_ROTATED_CACHE: dict = dict()
# Toric stabilizer-pair paths filled by precompute_toric_paths():
TORIC_PATH_CACHE_Z: dict = dict()  # Paths for Z-stabilizers (X-errors)
TORIC_PATH_CACHE_X: dict = dict()  # Paths for X-stabilizers (Z-errors)
# Lattice size the toric path caches currently describe (None = not computed).
PATHS_PRECOMPUTED_FOR_L: "int | None" = None
def precompute_toric_paths(L):
    """Precompute stabilizer-to-stabilizer paths for a toric code of size ``L``.

    For every unordered pair (u, v) of the ``L * L`` stabilizer indices this
    fills two module-level caches, storing each path as a ``frozenset`` under
    both ``(u, v)`` and ``(v, u)``:

    * ``TORIC_PATH_CACHE_Z`` - paths from ``path_between_stabilizers``
      (Z-stabilizers, handling X errors);
    * ``TORIC_PATH_CACHE_X`` - paths from ``path_between_stabilizers_X``
      (X-stabilizers, handling Z errors).

    Calling again with the same ``L`` is a no-op. Calling with a different
    ``L`` discards the previously cached paths first, so sizes are never mixed.
    """
    global PATHS_PRECOMPUTED_FOR_L

    if PATHS_PRECOMPUTED_FOR_L == L:
        return  # Already done for this size

    logging.info(f"Precomputing Toric paths for L={L}...")

    # Invalidate before touching the caches. If the loop below raises, a later
    # call must not mistake half-filled caches for a finished size.
    PATHS_PRECOMPUTED_FOR_L = None
    # Cleared in place so any other holder of these dict objects sees the update.
    TORIC_PATH_CACHE_Z.clear()
    TORIC_PATH_CACHE_X.clear()

    num_stabs = L * L
    for u in range(num_stabs):
        for v in range(u + 1, num_stabs):
            # Z-stabilizer paths (handling X errors)
            _, path_z = path_between_stabilizers(u, v, L=L)
            path_set_z = frozenset(path_z)
            TORIC_PATH_CACHE_Z[(u, v)] = path_set_z
            TORIC_PATH_CACHE_Z[(v, u)] = path_set_z

            # X-stabilizer paths (handling Z errors)
            _, path_x = path_between_stabilizers_X(u, v, L=L)
            path_set_x = frozenset(path_x)
            TORIC_PATH_CACHE_X[(u, v)] = path_set_x
            TORIC_PATH_CACHE_X[(v, u)] = path_set_x

    PATHS_PRECOMPUTED_FOR_L = L
    logging.info("Precomputation complete.")

def get_rotated_H_matrices(L):
    """Return the ``(H_Z, H_X)`` blocks of the rotated surface code, memoised per ``L``.

    The full check matrix comes from ``Get_rotated_surface_Code(L, full_H=True)``.
    With ``n_qubits = L**2`` and ``n_stabs = (L**2 - 1) // 2``:

    * ``H_Z`` is the block ``[:n_stabs, :n_qubits]``;
    * ``H_X`` is the block ``[n_stabs:, n_qubits:]``.

    Results are stored in ``H_ROTATED_CACHE[L]``.
    """
    cached = H_ROTATED_CACHE.get(L)
    if cached is not None:
        return cached

    # Not in cache, compute them
    logging.info(f"Caching H_Z and H_X matrices for L={L}")
    n_qubits = L ** 2
    n_stabs = (n_qubits - 1) // 2
    H_full, _ = Get_rotated_surface_Code(L, full_H=True)

    blocks = (H_full[:n_stabs, :n_qubits], H_full[n_stabs:, n_qubits:])
    H_ROTATED_CACHE[L] = blocks
    return blocks

def get_bposd_matrices(args):
    """Return ``(H_sparse, L_torch)``, the BP-OSD inputs for the configured code.

    ``H_sparse`` is the code's parity-check matrix as a ``scipy.sparse``
    CSR matrix. ``L_torch`` is its logical-operator matrix as a CPU ``long``
    tensor.

    Both are produced by ``Get_toric_Code`` or ``Get_rotated_surface_Code``
    (chosen by ``args.code_type``). They are called with the lattice size
    ``args.code_L_orig`` when that attribute exists, otherwise ``args.code_L``.
    ``full_H`` is set when ``args.noise_type == 'depolarization'``.

    The result is cached in ``H_SPARSE_CACHE`` / ``L_TORCH_CACHE``. It is
    rebuilt whenever ``(code_type, L, full_H)`` differs from the cached one, so
    changing the configuration never returns stale matrices.

    Raises:
        ValueError: if ``args.code_type`` is neither ``'toric'`` nor ``'rotated'``.
    """
    global H_SPARSE_CACHE, L_TORCH_CACHE

    L_val = args.code_L if not hasattr(args, 'code_L_orig') else args.code_L_orig
    full_H = args.noise_type == 'depolarization'
    cache_key = (args.code_type, L_val, full_H)

    cache_is_stale = getattr(get_bposd_matrices, "_cache_key", None) != cache_key
    if H_SPARSE_CACHE is None or cache_is_stale:
        # Explicit dispatch instead of eval() on a constructed function name.
        if args.code_type == 'toric':
            code_func = Get_toric_Code
        elif args.code_type == 'rotated':
            code_func = Get_rotated_surface_Code
        else:
            raise ValueError(f"Unknown code_type: {args.code_type}")

        H, Lz = code_func(L_val, full_H=full_H)

        H_SPARSE_CACHE = scipy.sparse.csr_matrix(H)
        L_TORCH_CACHE = torch.from_numpy(Lz).long().cpu()
        get_bposd_matrices._cache_key = cache_key

    return H_SPARSE_CACHE, L_TORCH_CACHE


# Module-level history of the predicted LER per physical error rate, appended
# by plot_ler_vs_p() each time it runs with an epoch:
#   {p: {'epochs': [epoch, ...], 'lers': [ler, ...]}}
ler_vs_epochs_data: dict = dict()


def train_step(model, batch, optimizer, scheduler, device, use_warmup, alpha = 0.5, beta = 0.5, confidence_weight = 0.01):
    """Run one optimisation step on ``batch`` and return the loss as a float.

    The loss has two terms:

    * ``BCEWithLogits(edge_logits, batch.y)``;
    * ``confidence_weight`` times the mean binary entropy of
      ``sigmoid(edge_logits)``. Minimising this term pushes the predicted
      edge probabilities towards 0 or 1.

    Gradients are clipped to norm 1. When ``use_warmup`` is set and a
    scheduler is given, the scheduler is stepped once per call.

    Args:
        model: called as ``model(x, edge_index, edge_attr, syndrome)``; its
            output is flattened and must line up element-wise with ``batch.y``.
        batch: graph-like object with ``x``, ``edge_index``, ``edge_attr``,
            ``syndrome`` and per-edge labels ``y`` (any numeric dtype).
        alpha, beta: unused; kept for call-site compatibility.
        confidence_weight: weight of the entropy regulariser (default 0.01,
            the previously hard-coded value).

    Returns:
        The scalar loss of this step as a Python float.
    """
    model.train()
    batch = batch.to(device)

    optimizer.zero_grad()

    # Forward pass: one logit per edge (flattened).
    edge_logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.syndrome).view(-1)
    # BCE-with-logits needs floating-point targets matching the logits' dtype.
    targets = batch.y.view(-1).to(edge_logits.dtype)

    bce_loss = F.binary_cross_entropy_with_logits(edge_logits, targets)
    # Mean binary entropy H(sigmoid(x)) = softplus(x) - x * sigmoid(x).
    # This equals BCE(sigmoid(x), sigmoid(x)), but it is evaluated from the
    # logits, so it stays numerically stable when sigmoid saturates.
    confidence_loss = (F.softplus(edge_logits) - edge_logits * torch.sigmoid(edge_logits)).mean()

    loss = bce_loss + confidence_weight * confidence_loss

    # Backward pass
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
    optimizer.step()

    if use_warmup and scheduler is not None:
        scheduler.step()

    return loss.item()



def train_model(model, dataloader, optimizer, scheduler, device, args, test_dataloader_list, ps_test, start_epoch, best_loss, writer):
    """Train ``model`` for epochs ``start_epoch + 1 .. args.epochs`` with periodic testing.

    Each element yielded by ``dataloader`` is a list of graphs. Every graph
    gets its own optimiser step via ``train_step``.

    After every epoch:

    * the mean loss and learning rate are printed, logged and written to the
      TensorBoard ``writer``;
    * ``last_checkpoint.pt`` is written to ``args.path``;
    * ``best_checkpoint.pt`` is also written when the epoch loss improves on
      ``best_loss``.

    Starting at ``args.eval_start_epoch`` (default 1150), the model is tested
    with ``test_model(final_testing=True)``:

    * every ``args.best_eval_every`` epochs (default 7), using the weights in
      ``best_checkpoint.pt``; the training weights are restored afterwards;
    * otherwise every ``args.last_eval_every`` epochs (default 11), using the
      current weights, which are also saved as ``last_model_epoch_XXX.pt``.

    Args:
        best_loss: best training loss seen so far, e.g. restored from a
            checkpoint when resuming. ``None`` means "no best yet".

    Returns:
        ``test_accuracies``, currently always an empty dict (kept for callers).
    """
    # Continue from the caller's best loss instead of resetting it to +inf.
    # Otherwise a resumed run overwrites best_checkpoint.pt with a worse model.
    best_loss = float("inf") if best_loss is None else best_loss

    eval_start = getattr(args, 'eval_start_epoch', 1150)
    best_eval_every = getattr(args, 'best_eval_every', 7)
    last_eval_every = getattr(args, 'last_eval_every', 11)

    train_losses = []
    test_accuracies = {}
    learning_rates = []

    for epoch in range(start_epoch + 1, args.epochs + 1):
        start_time = time.time()

        model.train()
        total_loss = 0.0
        batch_count = 0
        node_counts = []

        for graph_list in dataloader:
            if not graph_list:
                continue  # nothing to average over
            batch_loss = 0.0
            for graph in graph_list:
                graph = graph.to(device)
                node_counts.append(torch.unique(graph.edge_index).numel())
                batch_loss += train_step(model, graph, optimizer, scheduler, device, args.use_warmup)
            total_loss += batch_loss / len(graph_list)
            batch_count += 1

        # NaN for an epoch without data, so it can never be saved as "best".
        avg_loss = total_loss / batch_count if batch_count else float("nan")
        train_losses.append(avg_loss)

        duration = time.time() - start_time
        avg_nodes = sum(node_counts) / len(node_counts) if node_counts else 0.0

        print(f"Epoch {epoch}: Average Loss = {avg_loss:.4f} | Time = {duration:.2f} sec")
        logging.info(f"Epoch {epoch}: Average Loss = {avg_loss:.4f} | Time = {duration:.2f} sec | Number of nodes per sample = {avg_nodes}")

        if not args.use_warmup and scheduler is not None:  # per-epoch scheduling without warmup
            scheduler.step()

        current_lr = optimizer.param_groups[0]['lr']
        learning_rates.append(current_lr)
        logging.info(f"Epoch {epoch}: Learning Rate = {current_lr:.6f}")

        # tensor board
        writer.add_scalar("Loss/train", avg_loss, epoch)
        writer.add_scalar("LR", current_lr, epoch)

        # Periodic evaluation; the best-checkpoint schedule takes precedence.
        if epoch >= eval_start and epoch % best_eval_every == 0:
            _evaluate_best_checkpoint(model, device, args, test_dataloader_list, ps_test, epoch, writer)
        elif epoch >= eval_start and epoch % last_eval_every == 0:
            _evaluate_current_model(model, optimizer, device, args, test_dataloader_list, ps_test, epoch, writer)

        # Update best_loss before building the checkpoint, so the saved
        # 'best_loss' reflects this epoch when it is the new best.
        is_best = avg_loss < best_loss
        if is_best:
            best_loss = avg_loss

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_loss': best_loss
        }
        torch.save(checkpoint, os.path.join(args.path, 'last_checkpoint.pt'))

        if is_best:
            torch.save(checkpoint, os.path.join(args.path, 'best_checkpoint.pt'))
            print("Best model saved.")
            logging.info("Best model saved.")

    plot_training(train_losses, args.path)
    plot_learning_rate(learning_rates, args.path)

    return test_accuracies


def _evaluate_best_checkpoint(model, device, args, test_dataloader_list, ps_test, epoch, writer):
    """Test the weights in ``best_checkpoint.pt``, then restore the training weights."""
    best_model_path = os.path.join(args.path, 'best_checkpoint.pt')
    if not os.path.exists(best_model_path):
        print("Warning: best_checkpoint.pt not found during validation.")
        logging.warning("best_checkpoint.pt not found during validation.")
        return

    training_state = copy.deepcopy(model.state_dict())

    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"[BEST MODEL] Loaded best model (epoch {checkpoint['epoch']}) for testing at epoch {epoch}.")
    logging.info(f"[BEST MODEL] Loaded best model (epoch {checkpoint['epoch']}) for testing at epoch {epoch}.")

    # Save tagged best model
    tagged_best_path = os.path.join(args.path, f'best_model_epoch_{epoch:03d}.pt')
    torch.save(checkpoint, tagged_best_path)

    try:
        test_model(model, test_dataloader_list, device, ps_test, args, final_testing=True, epoch=epoch, writer=writer)
    finally:
        # Always put the training weights back, even if evaluation fails.
        model.load_state_dict(training_state)
        model.train()


def _evaluate_current_model(model, optimizer, device, args, test_dataloader_list, ps_test, epoch, writer):
    """Save the current weights as ``last_model_epoch_XXX.pt`` and test them."""
    print(f"[LAST MODEL] Testing last model directly at epoch {epoch}.")
    logging.info(f"[LAST MODEL] Testing last model directly at epoch {epoch}.")

    last_model_path = os.path.join(args.path, f'last_model_epoch_{epoch:03d}.pt')
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, last_model_path)
    print(f"Saved last model checkpoint at epoch {epoch}.")
    logging.info(f"Saved last model checkpoint at epoch {epoch}.")

    test_model(model, test_dataloader_list, device, ps_test, args, final_testing=True, epoch=epoch, writer=writer)


@torch.no_grad()
def test_model(model, dataloader_list, device, ps_test, args, final_testing = False, epoch= None, writer=None):
    """Evaluate ``model`` on one dataloader per physical error rate in ``ps_test``.

    For every graph, the model's edge logits are turned into probabilities
    with a sigmoid. What happens next depends on ``final_testing``.

    ``final_testing=False``: only a "confident accuracy" is computed. Edges
    with probability > 0.8 count as 1 and edges < 0.2 count as 0; all others
    are ignored. The accuracy is printed and logged per error rate.

    ``final_testing=True``:

    * directed edge pairs are merged into undirected edges, keeping the
      maximum probability;
    * the merged graph is decoded with ``decode_and_evaluate`` (toric) or
      ``decode_and_evaluate_rotated`` (rotated), which also return the
      ``*_manhattan`` comparison metrics;
    * the same syndrome is decoded with BP-OSD;
    * BER/LER per error rate are printed, logged and plotted;
    * a latency breakdown is logged (depolarizing noise only).

    Under depolarizing noise, column 3 of ``edge_attr`` splits the edges into
    a Z part (``== 1``) and an X part (``== 0``), which are decoded separately.

    Side effects: writes ``predicted_edge_weights.pt`` to ``args.path``. When
    ``epoch`` is given, it also writes the raw sigmoid outputs and their
    histogram there.
    """
    if args.code_type == 'toric':
        precompute_toric_paths(args.code_L)
    model.eval()
    H_full_sparse, L_full_torch = get_bposd_matrices(args)
    is_depolarizing = args.noise_type == "depolarization"

    collected_weights = []  # sigmoid outputs of every graph (weight-histogram ablation)
    results = []  # NOTE(review): never filled, but still saved below as before
    ps_vals = [float(p) for p in ps_test]

    ber_all_pred, ber_all_manhattan, ber_all_bposd = [], [], []
    ler_all_pred, ler_all_manhattan, ler_all_bposd = [], [], []

    # Per-sample latency records (filled in depolarizing final testing only).
    times_gpu = []
    times_cpu_mwpm = []
    times_total_latency = []
    all_build_times = []
    all_solve_times = []

    # CUDA events are only usable on CUDA devices. On CPU the forward pass is
    # timed with time.perf_counter() instead.
    use_cuda_timing = device.type == 'cuda'
    if use_cuda_timing:
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)

    for i, dataloader in enumerate(dataloader_list):  # one dataloader per noise level
        p_val = ps_test[i]
        p_float = ps_vals[i]
        correct = 0
        total = 0
        bposd_failures = 0

        ber_list_pred, ber_list_manhattan, ber_list_bposd = [], [], []
        ler_list_pred, ler_list_manhattan, ler_list_bposd = [], [], []

        bposd_decoder = ldpc.bposd_decoder(
            H_full_sparse,
            error_rate=p_float,
            max_iter=60,
            bp_method="ms",
        )

        for graph_list in dataloader:
            for graph in graph_list:
                graph = graph.to(device)
                start_time_cpu = time.perf_counter()
                if use_cuda_timing:
                    starter.record()
                edge_logits = model(graph.x, graph.edge_index, graph.edge_attr, graph.syndrome)
                preds = torch.sigmoid(edge_logits)
                if use_cuda_timing:
                    ender.record()
                    torch.cuda.synchronize()
                    gpu_time_ms = starter.elapsed_time(ender)
                else:
                    gpu_time_ms = (time.perf_counter() - start_time_cpu) * 1000

                preds_cpu = preds.detach().cpu()
                collected_weights.append(preds_cpu)

                if not final_testing:
                    preds_flat = preds_cpu.flatten()
                    labels_flat = graph.y.cpu().flatten()
                    confident_mask = (preds_flat > 0.8) | (preds_flat < 0.2)
                    # Round predictions: >0.8 -> 1, <0.2 -> 0
                    confident_preds = torch.where(preds_flat > 0.8, torch.tensor(1.0), preds_flat)
                    confident_preds = torch.where(preds_flat < 0.2, torch.tensor(0.0), confident_preds)
                    correct += (confident_preds == labels_flat)[confident_mask].sum().item()
                    total += confident_mask.sum().item()
                    continue  # decoding is only done in final testing

                edge_index = graph.edge_index.cpu()

                if is_depolarizing:
                    edge_type = graph.edge_attr.cpu()[:, 3]
                    mask_Z = edge_type == 1
                    mask_X = edge_type == 0
                    edge_index_clean_Z, weights_clean_Z = _merge_undirected_edges(edge_index[:, mask_Z], preds_cpu[mask_Z])
                    edge_index_clean_X, weights_clean_X = _merge_undirected_edges(edge_index[:, mask_X], preds_cpu[mask_X])

                    if args.code_type == 'rotated':
                        # The rotated decoder reports no build/solve timings.
                        b_time_Z, s_time_Z, b_time_X, s_time_X = 0.0, 0.0, 0.0, 0.0
                        ber_pred_Z, ber_manhattan_Z, ler_pred_Z, ler_manhattan_Z = decode_and_evaluate_rotated(
                            graph.z_Z, graph.syndrome_Z, graph.stab_t_Z,
                            edge_index_clean_Z, weights_clean_Z, args, p_val
                        )
                        ber_pred_X, ber_manhattan_X, ler_pred_X, ler_manhattan_X = decode_and_evaluate_rotated(
                            graph.z_X, graph.syndrome_X, graph.stab_t_X,
                            edge_index_clean_X, weights_clean_X, args, p_val
                        )
                    else:  # toric
                        # NOTE(review): both Z and X calls pass graph.num_nodes_Z
                        # (the independent-noise call passes 0). This looks like a
                        # node-index offset; confirm against decode_and_evaluate.
                        ber_pred_Z, ber_manhattan_Z, ler_pred_Z, ler_manhattan_Z, b_time_Z, s_time_Z = decode_and_evaluate(
                            graph.z_Z, graph.syndrome_Z, graph.stab_t_Z, edge_index_clean_Z, weights_clean_Z,
                            args, graph.L, graph.num_nodes_Z, graph.y_Z)
                        ber_pred_X, ber_manhattan_X, ler_pred_X, ler_manhattan_X, b_time_X, s_time_X = decode_and_evaluate(
                            graph.z_X, graph.syndrome_X, graph.stab_t_X, edge_index_clean_X, weights_clean_X,
                            args, graph.L, graph.num_nodes_Z, graph.y_X)
                    end_time_cpu = time.perf_counter()

                    all_build_times.append((b_time_Z + b_time_X) / 2)
                    all_solve_times.append((s_time_Z + s_time_X) / 2)

                    total_latency_ms = (end_time_cpu - start_time_cpu) * 1000
                    times_gpu.append(gpu_time_ms)
                    times_cpu_mwpm.append(total_latency_ms - gpu_time_ms)
                    times_total_latency.append(total_latency_ms)

                    # Combine the Z and X parts: max() for LER, mean for BER.
                    ler_list_pred.append(max(ler_pred_Z, ler_pred_X))
                    ler_list_manhattan.append(max(ler_manhattan_Z, ler_manhattan_X))
                    ber_list_pred.append(np.mean([ber_pred_Z, ber_pred_X]))
                    ber_list_manhattan.append(np.mean([ber_manhattan_Z, ber_manhattan_X]))

                else:  # independent noise
                    edge_index_clean, weights_clean = _merge_undirected_edges(edge_index, preds_cpu)

                    if args.code_type == 'rotated':
                        ber_pred, ber_manhattan, ler_pred, ler_manhattan = decode_and_evaluate_rotated(
                            graph.z, graph.syndrome, graph.stab_t,
                            edge_index_clean, weights_clean, args, p_val)
                    else:  # toric
                        ber_pred, ber_manhattan, ler_pred, ler_manhattan, b_time, s_time = decode_and_evaluate(
                            graph.z, graph.syndrome, graph.stab_t, edge_index_clean, weights_clean,
                            args, graph.L, 0, graph.y)

                    ber_list_pred.append(ber_pred)
                    ber_list_manhattan.append(ber_manhattan)
                    ler_list_pred.append(ler_pred)
                    ler_list_manhattan.append(ler_manhattan)

                # ===== BP-OSD DECODING =====
                ber_bposd, ler_bposd, bposd_ok = _evaluate_bposd(bposd_decoder, graph, L_full_torch, is_depolarizing)
                if not bposd_ok:
                    bposd_failures += 1
                ber_list_bposd.append(ber_bposd)
                ler_list_bposd.append(ler_bposd)

        if bposd_failures:
            logging.warning(f"BP-OSD failed on {bposd_failures} sample(s) at p={p_float:.3f}; "
                            "they were counted as BER=0.5, LER=1.")

        # Per-p reporting after all batches.
        print(f"Test @ p={p_float:.3f}:")
        if final_testing:
            print(f"BER_Pred={np.mean(ber_list_pred):.4f}, BER_Man={np.mean(ber_list_manhattan):.4f}, "  # ber and ler are per p
                  f"LER_Pred={np.mean(ler_list_pred):.4f}, LER_Man={np.mean(ler_list_manhattan):.4f}")
        else:
            acc = correct / total if total > 0 else 0
            print(f"acc={acc}")

        logging.info(f"Test @ p={p_float:.3f}:")
        if final_testing:
            logging.info(f"BER_Pred={np.mean(ber_list_pred):.4f}, BER_Man={np.mean(ber_list_manhattan):.4f}, BER_BPOSD={np.mean(ber_list_bposd):.4f}, "
                    f"LER_Pred={np.mean(ler_list_pred):.4f}, LER_Man={np.mean(ler_list_manhattan):.4f}, LER_BPOSD={np.mean(ler_list_bposd):.4f}")
            # Per-p averages feed the final plots only.
            ber_all_manhattan.append(np.mean(ber_list_manhattan))
            ber_all_pred.append(np.mean(ber_list_pred))
            ler_all_manhattan.append(np.mean(ler_list_manhattan))
            ler_all_pred.append(np.mean(ler_list_pred))
            ber_all_bposd.append(np.mean(ber_list_bposd))
            ler_all_bposd.append(np.mean(ler_list_bposd))
        else:
            logging.info(f"acc:{acc}:")

    # Save all predictions
    save_path = os.path.join(args.path, 'predicted_edge_weights.pt')
    torch.save(results, save_path)
    logging.info(f"Predicted edge weights saved to {save_path}")
    print(f"Predicted edge weights saved to {save_path}")

    if epoch is not None and collected_weights:  # weight-distribution ablation
        all_weights_flat = torch.cat(collected_weights)
        torch.save(all_weights_flat, os.path.join(args.path, f"weights_epoch_{epoch}.pt"))

        plot_path = os.path.join(args.path, f"weights_hist_epoch_{epoch}.png")
        plot_weight_hist(all_weights_flat, plot_path, epoch)

        print(f"Saved weight histogram to {plot_path}")
        logging.info(f"Saved weight histogram to {plot_path}")

    # plot ber and ler
    if final_testing:
        suffix = f"epoch_{epoch:03d}" if epoch is not None else "final"
        plot_ber_vs_p(ps_vals, ber_all_pred, ber_all_manhattan, args.path, suffix, epoch, writer=writer)
        plot_ler_vs_p(ps_vals, ler_all_pred, ler_all_manhattan, ler_all_bposd, args.path, suffix, epoch, writer=writer)
        plot_ber_vs_p_log(ps_vals, ber_all_pred, ber_all_manhattan, args.path, suffix, epoch, writer=writer)
        plot_ler_vs_p_log(ps_vals, ler_all_pred, ler_all_manhattan, args.path, suffix, epoch, writer=writer)
        _log_latency_summary(times_gpu, times_cpu_mwpm, times_total_latency, all_build_times, all_solve_times)

    return


def _merge_undirected_edges(edge_index, weights):
    """Merge directed edges (u, v) and (v, u) into one undirected edge.

    The merged edge keeps the maximum weight. Edges are keyed by
    ``(min(u, v), max(u, v))`` and kept in first-seen order.

    Args:
        edge_index: long tensor of shape ``[2, E]``.
        weights: tensor with one value per edge (flattened before use).

    Returns:
        ``(edge_index_clean, weights_clean)``: a long tensor built from the
        ``[u, v]`` pairs and transposed, and a float32 tensor of merged weights.
    """
    merged = {}
    for (u, v), w in zip(edge_index.t().tolist(), weights.reshape(-1).tolist()):
        key = (u, v) if u <= v else (v, u)
        prev = merged.get(key)
        merged[key] = w if prev is None else max(prev, w)

    edge_index_clean = torch.tensor([list(key) for key in merged], dtype=torch.long).T
    weights_clean = torch.tensor(list(merged.values()), dtype=torch.float32)
    return edge_index_clean, weights_clean


def _evaluate_bposd(bposd_decoder, graph, L_full_torch, depolarization):
    """Decode ``graph.syndrome`` with BP-OSD and score the residual error.

    The error vector is ``graph.z``, or ``cat(graph.z_Z, graph.z_X)`` under
    depolarizing noise. The decoder's correction is zero-padded to the error
    length if shorter.

    Returns:
        ``(ber, ler, ok)``. Any failure yields ``(0.5, 1.0, False)``, so the
        sample is scored as a logical failure (best-effort, as before).
    """
    try:
        syndrome_full = graph.syndrome.cpu().numpy().astype(np.int32)
        if depolarization:
            z_full = torch.cat([graph.z_Z, graph.z_X], dim=0).cpu()
        else:
            z_full = graph.z.cpu()

        z_hat = torch.tensor(bposd_decoder.decode(syndrome_full), dtype=torch.float32)

        pad_len = z_full.shape[0] - z_hat.shape[0]
        if pad_len > 0:
            z_hat = F.pad(z_hat, (0, pad_len), value=0)

        corrected = (z_full + z_hat) % 2
        ber = torch.mean(corrected).item()
        logical_syndrome = torch.matmul(L_full_torch, corrected.long()) % 2
        ler = torch.any(logical_syndrome).item()
    except Exception:
        logging.debug("BP-OSD evaluation failed", exc_info=True)
        return 0.5, 1.0, False
    return ber, ler, True


def _log_latency_summary(times_gpu, times_cpu_mwpm, times_total_latency, all_build_times, all_solve_times):
    """Print and log mean/std of the per-sample latency components collected by test_model."""
    if times_total_latency:
        # 1. GPU Time
        avg_gpu_time = np.mean(times_gpu)
        std_gpu_time = np.std(times_gpu)
        # 2. CPU Time
        avg_cpu_time = np.mean(times_cpu_mwpm)
        std_cpu_time = np.std(times_cpu_mwpm)
        # Total Latency
        avg_total_time = np.mean(times_total_latency)
        std_total_time = np.std(times_total_latency)

        print("\n--- Total Decoding Latency Breakdown (ms) ---")
        print(f"1. GNN Forward (GPU Time):   {avg_gpu_time:.4f} ms (Std: {std_gpu_time:.4f} ms)")
        print(f"2. MWPM Solve (CPU Time):    {avg_cpu_time:.4f} ms (Std: {std_cpu_time:.4f} ms)")
        print(f"3. Total Latency:            {avg_total_time:.4f} ms (Std: {std_total_time:.4f} ms)")

        logging.info("\n--- Total Decoding Latency Breakdown (ms) ---")
        logging.info(f"1. GNN Forward (GPU Time):   {avg_gpu_time:.4f} ms (Std: {std_gpu_time:.4f} ms)")
        logging.info(f"2. MWPM Solve (CPU Time):    {avg_cpu_time:.4f} ms (Std: {std_cpu_time:.4f} ms)")
        logging.info(f"3. Total Latency:            {avg_total_time:.4f} ms (Std: {std_total_time:.4f} ms)")

    if all_solve_times:
        avg_gpu_time = np.mean(times_gpu)
        avg_total_time = np.mean(times_total_latency)
        # Means from the decoder's own isolated build/solve timers.
        avg_pure_solve = np.mean(all_solve_times)
        avg_python_build = np.mean(all_build_times)

        logging.info("\n--- MEAN DECODING PERFORMANCE BREAKDOWN (ms) ---")
        logging.info(f"1. GNN Forward (GPU):      {avg_gpu_time:.4f} ms")
        logging.info(f"2. Python Graph Build:     {avg_python_build:.4f} ms")  # The overhead
        logging.info(f"3. MWPM Solve (Pure CPU):  {avg_pure_solve:.4f} ms")    # The real solver speed
        logging.info("------------------------------------------------")
        logging.info(f"Total Mean Latency:        {avg_total_time:.4f} ms")




#plots
def plot_weight_hist(weights_tensor, save_path, epoch):  # TODO ablation
    """Save a 100-bin, log-scaled density histogram of ``weights_tensor`` to ``save_path``.

    ``weights_tensor`` must be a CPU tensor, since it is converted with ``.numpy()``.
    """
    values = weights_tensor.numpy()

    plt.figure(figsize=(10, 6))
    plt.hist(values, bins=100, density=True, color='blue', alpha=0.7)
    plt.yscale('log')

    plt.title(f"Weight Distribution - Epoch {epoch}")
    plt.xlabel("Predicted Probability")
    plt.ylabel("Density (Log Scale)")
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def plot_training(train_losses, save_dir):
    """Plot the per-epoch training loss and save it as ``training_plot.png`` in ``save_dir``.

    The x axis numbers the recorded losses 1..len(train_losses).
    """
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, train_losses, label='Training Loss')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    # Only the loss is drawn, so the title no longer advertises test accuracy.
    plt.title("Training Loss")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, "training_plot.png"))
    plt.close()

def plot_learning_rate(learning_rates, save_dir):
    """Plot the per-epoch learning rate and save it as ``learning_rate_plot.png`` in ``save_dir``."""
    import matplotlib.pyplot as plt

    x_axis = range(1, len(learning_rates) + 1)
    out_file = os.path.join(save_dir, "learning_rate_plot.png")

    plt.figure(figsize=(10, 5))
    plt.plot(x_axis, learning_rates, color='tab:red')
    for set_label, text in ((plt.xlabel, "Epoch"),
                            (plt.ylabel, "Learning Rate"),
                            (plt.title, "Learning Rate Schedule")):
        set_label(text)
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(out_file)
    plt.close()

    
def plot_test_acc(test_accuracies, save_dir, final_epoch, test_interval=40):
    """Plot test accuracy versus epoch for each noise level.

    The figure is saved as ``test_accuracy_per_p.png`` in ``save_dir``.

    Args:
        test_accuracies: mapping ``p -> list of accuracies``. All entries but
            the last are taken every ``test_interval`` epochs (``test_interval``,
            ``2 * test_interval``, ...). The last entry is the final test at
            ``final_epoch``.
        save_dir: output directory.
        final_epoch: epoch of the last accuracy entry.
        test_interval: spacing of the periodic tests (default 40, the
            previously hard-coded value).
    """
    import matplotlib.pyplot as plt
    import os

    plt.figure(figsize=(10, 6))
    for p in sorted(test_accuracies):
        acc_list = test_accuracies[p]
        if not acc_list:
            continue  # nothing to draw; avoids an x/y length mismatch
        # Make the x-axis match the test epochs.
        epochs = [test_interval * k for k in range(1, len(acc_list))] + [final_epoch]
        # float() so string keys such as "0.05" also format correctly.
        plt.plot(epochs, acc_list, label=f"p={float(p):.3f}", marker='o')

    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Test Accuracy per Noise Level (p)")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, "test_accuracy_per_p.png"))
    plt.close()

def plot_ber_vs_p(ps_vals, ber_pred, ber_man, save_dir, suffix="", epoch = None, writer=None):
    """Plot predicted and Manhattan BER against the physical error rate (linear axes).

    The figure is saved as ``ber_vs_p_<suffix>.png``, or ``ber_vs_p.png``
    without a suffix. It is also sent to ``writer`` as ``BER_vs_p`` when both
    ``writer`` and ``epoch`` are given.
    """
    plt.figure()
    for series, name in ((ber_pred, 'BER Predicted'), (ber_man, 'BER Manhattan')):
        plt.plot(ps_vals, series, label=name, marker="s")
    plt.xlabel('Physical Error Rate (p)')
    plt.ylabel('Bit Error Rate (BER)')
    if epoch is None:
        plt.title('BER vs Physical Error Rate (final)')
    else:
        plt.title(f'BER vs Physical Error Rate (epoch {epoch})')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    out_name = "ber_vs_p.png" if not suffix else f"ber_vs_p_{suffix}.png"
    plt.savefig(os.path.join(save_dir, out_name))
    if not (writer is None or epoch is None):
        writer.add_figure("BER_vs_p", plt.gcf(), global_step=epoch)
    plt.close()



def plot_ler_vs_p(ps_vals, ler_pred, ler_man, ler_bposd, save_dir, suffix="", epoch = None, writer=None):
    """Plot predicted and Manhattan LER against the physical error rate.

    Saves ``ler_vs_p_<suffix>.png`` (or ``ler_vs_p.png``) and forwards the
    figure to ``writer`` as ``LER_vs_p`` when ``writer`` and ``epoch`` are set.
    ``ler_bposd`` is accepted but currently not drawn.

    When ``epoch`` is given, the predicted LER of each ``p`` is also appended
    to the module-level ``ler_vs_epochs_data`` history, which is then
    re-plotted with ``plot_ler_vs_epochs``.
    """
    plt.figure()
    for series, name in ((ler_pred, 'LER Predicted'), (ler_man, 'LER Manhattan')):
        plt.plot(ps_vals, series, label=name, marker="s")
    # BP-OSD curve intentionally disabled:
    # plt.plot(ps_vals, ler_bposd, label='LER BP-OSD', marker="^")
    plt.xlabel('Physical Error Rate (p)')
    plt.ylabel('Logical Error Rate (LER)')
    if epoch is None:
        plt.title('LER vs Physical Error Rate (final)')
    else:
        plt.title(f'LER vs Physical Error Rate (epoch {epoch})')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()

    out_name = "ler_vs_p.png" if not suffix else f"ler_vs_p_{suffix}.png"
    plt.savefig(os.path.join(save_dir, out_name))
    if not (writer is None or epoch is None):
        writer.add_figure("LER_vs_p", plt.gcf(), global_step=epoch)
    plt.close()

    if epoch is None:
        return

    for idx, p_rate in enumerate(ps_vals):
        history = ler_vs_epochs_data.setdefault(p_rate, {'epochs': [], 'lers': []})
        history["epochs"].append(epoch)
        history['lers'].append(ler_pred[idx])  # Using the predicted LER
    plot_ler_vs_epochs(ler_vs_epochs_data, save_dir)
            


def plot_ber_vs_p_log(ps_vals, ber_pred, ber_man, save_dir, suffix="", epoch = None, writer=None):
    """Plot predicted and Manhattan BER against p on a logarithmic y-axis.

    The figure is saved in ``save_dir`` as ``ber_vs_p_log.png`` (or
    ``ber_vs_p_log_<suffix>.png`` when ``suffix`` is non-empty). When both
    ``writer`` and ``epoch`` are given, it is also logged under the tag
    ``"BER_vs_p_log"``.
    """
    curves = (('BER Predicted', ber_pred), ('BER Manhattan', ber_man))

    plt.figure()
    for curve_label, curve in curves:
        plt.plot(ps_vals, curve, label=curve_label, marker='s')
    plt.xlabel('Physical Error Rate (p)')
    plt.ylabel('BER (log scale)')
    plt.yscale('log')
    if epoch is None:
        plot_title = 'BER vs p (Log Scale) (final)'
    else:
        plot_title = f'BER vs p (Log Scale) (epoch {epoch})'
    plt.title(plot_title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    out_name = "ber_vs_p_log.png" if not suffix else f"ber_vs_p_log_{suffix}.png"
    plt.savefig(os.path.join(save_dir, out_name))
    if epoch is not None and writer is not None:
        writer.add_figure("BER_vs_p_log", plt.gcf(), global_step=epoch)
    plt.close()

def plot_ler_vs_p_log(ps_vals, ler_pred, ler_man, save_dir, suffix="", epoch = None, writer = None):
    """Plot predicted and Manhattan LER against p on a logarithmic y-axis.

    The figure is saved in ``save_dir`` as ``ler_vs_p_log.png`` (or
    ``ler_vs_p_log_<suffix>.png`` when ``suffix`` is non-empty). When both
    ``writer`` and ``epoch`` are given, it is also logged under the tag
    ``"LER_vs_p_log"``.
    """
    curves = (('LER Predicted', ler_pred), ('LER Manhattan', ler_man))

    plt.figure()
    for curve_label, curve in curves:
        plt.plot(ps_vals, curve, label=curve_label, marker='s')
    plt.xlabel('Physical Error Rate (p)')
    plt.ylabel('LER (log scale)')
    plt.yscale('log')
    if epoch is None:
        plot_title = 'LER vs p (Log Scale) (final)'
    else:
        plot_title = f'LER vs p (Log Scale) (epoch {epoch})'
    plt.title(plot_title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    out_name = "ler_vs_p_log.png" if not suffix else f"ler_vs_p_log_{suffix}.png"
    plt.savefig(os.path.join(save_dir, out_name))
    if epoch is not None and writer is not None:
        writer.add_figure("LER_vs_p_log", plt.gcf(), global_step=epoch)
    plt.close()


def plot_ler_vs_epochs(ler_data, save_dir):
    """Plot the recorded LER history of every physical error rate against epoch.

    ``ler_data`` maps each p to ``{'epochs': [...], 'lers': [...]}``. Entries
    whose ``'epochs'`` list is empty are skipped. The figure is written to
    ``<save_dir>/epoch_plots/ler_vs_epochs_log.png``; the folder is created if
    missing. The output path is printed and logged.

    Note: despite the file name, the y-axis is linear (the log scale is
    commented out).
    """
    out_dir = os.path.join(save_dir, 'epoch_plots')
    os.makedirs(out_dir, exist_ok=True)

    plt.figure(figsize=(12, 7))
    for p_rate, series in ler_data.items():
        # float() keeps the legend formatting uniform whatever numeric type p_rate has.
        p_label = float(p_rate)
        if not series['epochs']:
            continue
        plt.plot(series['epochs'], series['lers'], label=f'p = {p_label:.3f}', marker='o')

    plt.xlabel("Epoch")
    plt.ylabel("Logical Error Rate (LER)")
    plt.title("LER vs. Epochs for each Physical Error Rate")
    plt.legend()
    plt.grid(True, which="both", ls="--")
    plt.tight_layout()

    out_path = os.path.join(out_dir, "ler_vs_epochs_log.png")
    plt.savefig(out_path)
    plt.close()
    print(f"LER vs. Epochs plot updated and saved to {out_path}")
    logging.info(f"LER vs. Epochs plot updated and saved to {out_path}")

#======== decoder ===========


def _toric_logic_matrix(logical_mat, noise_type, stab_t, L):
    """Return the slice of the logical-operator matrix used to score one toric decode.

    ``"independent"`` noise uses the whole matrix. ``"depolarization"`` uses:
    - rows ``:2`` / columns ``:2*L*L`` for ``stab_t == 1`` (Z stabilizers, X errors);
    - rows ``2:`` / columns ``2*L*L:`` for ``stab_t == 0`` (X stabilizers, Z errors).

    Raises:
        ValueError: if ``noise_type`` is neither of the above.
    """
    if noise_type == "independent":
        return logical_mat
    if noise_type == "depolarization":
        n_cols = 2 * L * L
        if stab_t == 1:
            return logical_mat[:2, :n_cols]
        return logical_mat[2:, n_cols:]
    raise ValueError(f"unknown noise model: {noise_type}")


def _toric_correction_metrics(z, z_hat, logic_mat, n_qubits):
    """Apply correction ``z_hat`` to error ``z`` and return ``(ber, logical_failure)``.

    ``z_hat`` is right-padded with zeros up to ``n_qubits`` because PyMatching
    only returns entries up to the largest fault id present in its graph.
    ``logical_failure`` is True when the residual error has odd overlap with
    any row of ``logic_mat``.
    """
    if len(z_hat) < n_qubits:
        z_hat = F.pad(z_hat, (0, n_qubits - len(z_hat)), value=0)
    corrected = (z + z_hat) % 2  # flip the qubits named by the correction
    ber = torch.mean(corrected).item()
    logical_syndrome = torch.matmul(logic_mat.long(), corrected.long()) % 2
    return ber, torch.any(logical_syndrome).item()


def decode_and_evaluate(z, syndrome, stab_t, edge_index_clean, predicted_weights, args, L, node_shift, y):
    """Decode one toric-code sample with MWPM under two edge weightings and score both.

    Two PyMatching graphs are built over the same candidate edges:
    - one weighted by the model, ``-log(max(w, 1e-6))`` of each predicted weight;
    - one weighted by the toric Manhattan distance between the two stabilizers.
    Each edge's fault ids (qubits) come from the precomputed stabilizer-path caches.

    Args:
        z: error vector of the sample; corrections are padded to ``2*L*L``
            entries before being added to it.
        syndrome: syndrome of the stabilizer type being decoded.
        stab_t: ``0`` => X stabilizers (Z errors), ``1`` => Z stabilizers
            (X errors). A Python number or a single-element tensor.
        edge_index_clean: ``(2, E)`` tensor of candidate defect pairs, i.e.
            edge_index with no duplicates. For ``stab_t == 0`` the node ids
            are offset by ``L*L``.
        predicted_weights: per-edge model outputs aligned with the columns of
            ``edge_index_clean``.
        args: must provide ``args.code.logic_matrix`` and ``args.noise_type``
            (``"independent"`` or ``"depolarization"``).
        L: lattice size.
        node_shift, y: unused; kept for call-site compatibility.

    Returns:
        ``(ber_pred, ber_manhattan, ler_pred, ler_manhattan, build_ms, solve_ms)``.
        The two timings (milliseconds) cover building and decoding the
        predicted-weight graph only.

    Raises:
        ValueError: unknown ``stab_t`` or ``noise_type``.
    """
    if isinstance(stab_t, torch.Tensor):
        stab_t = int(stab_t.item())  # 0 => X-stab, 1 => Z-stab
    if stab_t not in (0, 1):
        raise ValueError(f"unknown stab_t: {stab_t}")

    n_qubits = 2 * L * L
    z = z.cpu()
    syndrome = syndrome.cpu().numpy().astype(np.int32)
    final_logic_mat = _toric_logic_matrix(args.code.logic_matrix.cpu(), args.noise_type, stab_t, L)

    if edge_index_clean.numel() == 0:
        # No edges: neither decoder has anything to match, so both apply the
        # identity correction. The residual z is still scored, because an error
        # that triggers no syndrome can still be a logical operator.
        no_correction = torch.zeros(z.shape[0], dtype=torch.float32)
        ber, ler = _toric_correction_metrics(z, no_correction, final_logic_mat, z.shape[0])
        return ber, ber, ler, ler, 0.0, 0.0

    max_node_idx = int(edge_index_clean.max().item())
    assert np.all(syndrome[max_node_idx + 1:] == 0), \
        f"Syndrome contains active stabilizers beyond used nodes: {np.nonzero(syndrome[max_node_idx + 1:])}"

    # The syndrome is trimmed to the highest stabilizer index seen in the graph.
    if stab_t == 0:  # X stabilizers decoding - Z errors; node ids start at L*L
        syndrome_trimmed = syndrome[:max_node_idx + 1 - L * L]
        edge_index_clean = edge_index_clean - (L * L)
    else:  # Z stabilizers decoding - X errors
        syndrome_trimmed = syndrome[:max_node_idx + 1]

    # Without the path cache every fault-id lookup returns None. The decoders
    # would then return corrections with no qubits, i.e. all zeros after padding.
    if PATHS_PRECOMPUTED_FOR_L != L:
        precompute_toric_paths(L)
    path_cache = TORIC_PATH_CACHE_X if stab_t == 0 else TORIC_PATH_CACHE_Z

    t0_build = time.perf_counter()
    edges = [(int(u), int(v)) for u, v in edge_index_clean.t().tolist()]
    pred_matching = Matching()
    for i, (u, v) in enumerate(edges):
        w = predicted_weights[i].item()
        # -log of the model output (presumably an edge probability), clamped to avoid log(0).
        pred_matching.add_edge(u, v, fault_ids=path_cache.get((u, v)), weight=-math.log(max(w, 1e-6)))
    t1_build = time.perf_counter()

    t0_solve = time.perf_counter()
    z_hat_pred_list = pred_matching.decode(syndrome_trimmed)
    t1_solve = time.perf_counter()
    z_hat_pred = torch.tensor(z_hat_pred_list, dtype=torch.float32)

    manhattan_matching = Matching()
    for u, v in edges:
        # MWPM minimizing the total periodic lattice distance.
        manhattan_matching.add_edge(u, v, fault_ids=path_cache.get((u, v)), weight=manhattan_dist_calc(u, v, L))
    z_hat_manhattan = torch.tensor(manhattan_matching.decode(syndrome_trimmed), dtype=torch.float32)

    ber_pred, ler_pred = _toric_correction_metrics(z, z_hat_pred, final_logic_mat, n_qubits)
    ber_manhattan, ler_manhattan = _toric_correction_metrics(z, z_hat_manhattan, final_logic_mat, n_qubits)

    build_time = (t1_build - t0_build) * 1000
    solve_time = (t1_solve - t0_solve) * 1000
    return ber_pred, ber_manhattan, ler_pred, ler_manhattan, build_time, solve_time




    


def manhattan_dist_calc(u, v, L):
    """Return the periodic (toric) Manhattan distance between stabilizers ``u`` and ``v``.

    Both indices are mapped to lattice coordinates with ``stab_to_coord``. Along
    each axis, the shorter of the direct and the wrap-around separation is used.
    """
    def _axis_gap(a, b):
        # Shortest separation on a ring of length L.
        delta = abs(a - b)
        return min(delta, L - delta)

    coord_u = stab_to_coord(u, L)
    coord_v = stab_to_coord(v, L)
    return _axis_gap(coord_u[0], coord_v[0]) + _axis_gap(coord_u[1], coord_v[1])





def decode_and_evaluate_rotated(z, syndrome, stab_t, edge_index_clean, predicted_weights, args, p_val):
    """Decode one rotated-surface-code sample with MWPM under two weightings and score both.

    Two PyMatching graphs are built over the same candidate edges:
    - a model-weighted graph: interior edges weigh ``-log(max(w, 1e-6))``,
      boundary edges ``-log(max(w, 1e-9))``;
    - a baseline graph weighted by the precomputed distances in
      ``args.precomputed_data``.
    Edges that touch the virtual node (local index ``(L*L - 1) // 2``) become
    boundary edges.

    Args:
        z: residual-error vector for the decoded error type. Corrections are
            padded/trimmed to its length.
        syndrome: syndrome of the stabilizer type being decoded.
        stab_t: ``1`` => Z stabilizers (X errors), ``0`` => X stabilizers
            (Z errors). A Python int or a single-element tensor.
        edge_index_clean: ``(2, E)`` candidate edges. For ``stab_t == 0`` the
            node ids are offset by ``(L*L - 1) // 2 + 1``.
        predicted_weights: per-edge model outputs aligned with the columns of
            ``edge_index_clean``.
        args: needs ``code_L`` and ``code.logic_matrix``, plus
            ``precomputed_data`` when edges are present.
        p_val: unused; kept for call-site compatibility.

    Returns:
        ``(ber_pred, ber_baseline, ler_pred, ler_baseline)``.

    Raises:
        ValueError: unknown ``stab_t``; a non-zero syndrome with no candidate
            edges; or edges that touch only the virtual node.
    """
    z = z.cpu()
    syndrome = syndrome.cpu().numpy().astype(np.int32)
    stab_t = int(stab_t.item()) if isinstance(stab_t, torch.Tensor) else int(stab_t)  # 1 for Z, 0 for X
    if stab_t not in (0, 1):
        raise ValueError(f"unknown stab_t: {stab_t}")

    L = args.code_L
    L_qubits = L * L
    logical_mat = args.code.logic_matrix.cpu()
    # Z stabilizers (X errors) are scored with logical row 0 on the first L*L
    # columns; X stabilizers (Z errors) with logical row 1 on the remaining ones.
    if stab_t == 1:
        final_logic_mat = logical_mat[0:1, :L_qubits]
    else:
        final_logic_mat = logical_mat[1:2, L_qubits:]

    def _score(z_hat):
        """Pad/trim ``z_hat`` to ``len(z)``, apply it, and return ``(ber, logical_failure)``."""
        diff = z.shape[0] - z_hat.shape[0]
        if diff > 0:
            z_hat = F.pad(z_hat, (0, diff), value=0)
        elif diff < 0:
            z_hat = z_hat[:z.shape[0]]
        corrected = (z + z_hat) % 2
        ber = torch.mean(corrected).item()
        logical_syndrome = torch.matmul(final_logic_mat.long(), corrected.long()) % 2
        return ber, torch.any(logical_syndrome).item()

    if edge_index_clean.numel() == 0:
        if np.any(syndrome):
            raise ValueError("non-empty syndrome but no candidate edges to decode with")
        # Nothing to match: both decoders apply the identity correction. The
        # residual z is still scored, since it may be an undetected logical error.
        ber, ler = _score(torch.zeros(0, dtype=torch.float32))
        return ber, ber, ler, ler

    num_stabs_per_type = (L_qubits - 1) // 2  # 12 for L=5
    virtual_node_idx = num_stabs_per_type
    H_Z, H_X = get_rotated_H_matrices(L)  # parity-check matrices
    precomputed_data = args.precomputed_data

    if stab_t == 1:  # Z stabilizers (X errors)
        H_matrix = H_Z
        dist_map = precomputed_data['z_dist_map']  # distance between nodes
        edge_path_map = precomputed_data['z_edge_path_map']
        boundary_dist_map = precomputed_data['z_boundary_dist_map']  # distance to boundary
        boundary_edge_path_map = precomputed_data['z_boundary_edge_path_map']
        edge_index_local = edge_index_clean
    else:  # X stabilizers (Z errors)
        H_matrix = H_X
        dist_map = precomputed_data['x_dist_map']
        edge_path_map = precomputed_data['x_edge_path_map']
        boundary_dist_map = precomputed_data['x_boundary_dist_map']
        boundary_edge_path_map = precomputed_data['x_boundary_edge_path_map']
        # X-stabilizer ids follow the Z-stabilizers plus one virtual node.
        edge_index_local = edge_index_clean - (num_stabs_per_type + 1)

    all_nodes = edge_index_local.flatten().unique()
    real_nodes = all_nodes[all_nodes != virtual_node_idx]
    if real_nodes.numel() == 0:
        raise ValueError("edge_index_clean contains no real (non-virtual) stabilizer nodes")
    # The syndrome is trimmed to the highest real stabilizer index seen in the graph.
    syndrome_trimmed = syndrome[:real_nodes.max().item() + 1]

    edges = [(int(u), int(v)) for u, v in edge_index_local.t().tolist()]
    pred_matching = Matching()
    baseline_matching = Matching()  # weighted with the precomputed dist_map

    # Pass 1: edges between two real defects.
    for i, (u, v) in enumerate(edges):
        if virtual_node_idx in (u, v):
            continue
        key = (min(u, v), max(u, v))
        if key not in dist_map or key not in edge_path_map:
            continue
        w_prime = -math.log(max(predicted_weights[i].item(), 1e-6))
        fault_id_set = set(get_qubits_from_edge_path(edge_path_map[key], H_matrix))
        pred_matching.add_edge(u, v, fault_ids=fault_id_set, weight=w_prime)
        baseline_matching.add_edge(u, v, fault_ids=fault_id_set, weight=dist_map[key])

    # Pass 2: edges to the virtual node become boundary edges of the real endpoint.
    for i, (u, v) in enumerate(edges):
        if virtual_node_idx not in (u, v):
            continue
        real_node = u if v == virtual_node_idx else v
        if real_node not in boundary_edge_path_map:
            continue
        # NOTE(review): boundary weights clamp at 1e-9 while interior edges clamp
        # at 1e-6; kept as in the original, confirm this asymmetry is intended.
        w_prime = -math.log(max(predicted_weights[i].item(), 1e-9))
        # boundary_edge_path_map already holds the qubit list for this node.
        qubits_set = set(boundary_edge_path_map[real_node])
        pred_matching.add_boundary_edge(real_node, fault_ids=qubits_set, weight=w_prime)
        if real_node in boundary_dist_map:
            baseline_matching.add_boundary_edge(real_node, fault_ids=qubits_set, weight=boundary_dist_map[real_node])

    z_hat_pred = torch.tensor(pred_matching.decode(syndrome_trimmed), dtype=torch.float32)
    z_hat_baseline = torch.tensor(baseline_matching.decode(syndrome_trimmed), dtype=torch.float32)

    ber_pred, ler_pred = _score(z_hat_pred)
    ber_baseline, ler_baseline = _score(z_hat_baseline)

    # The baseline plays the role of the "manhattan" curve in the plots.
    return ber_pred, ber_baseline, ler_pred, ler_baseline