# In[1]: imports & device setup
import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from pathlib import Path
from typing import List, Tuple, Dict
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# In[2]: metric‐computation functions (as provided, with minimal adaptation)
def find_segments_ones(array: np.ndarray) -> List[Tuple[int,int]]:
    """Locate every maximal run of consecutive entries equal to 1.

    Args:
        array: Array to scan (assumed 1-D; indices are taken along the
            first axis).

    Returns:
        One ``(start, stop)`` pair per run, in increasing order. ``start`` is
        the index of the run's first 1 and ``stop`` is one past its last 1.
        An array containing no 1 yields an empty list.
    """
    positions = np.nonzero(array == 1)[0]
    if positions.size == 0:
        return []

    # A jump larger than one between neighbouring positions starts a new run.
    breaks = np.nonzero(np.diff(positions) > 1)[0] + 1

    segments = []
    for run in np.split(positions, breaks):
        segments.append((run[0], run[-1] + 1))
    return segments

def exon_level(threshold: float, y_labels: np.ndarray, p_labels: np.ndarray, metrics: Dict[str,int]):
    """Add exact-match segment counts at one threshold to ``metrics``.

    Both inputs are binarised with ``>= threshold`` and split into runs of
    ones. A predicted run counts as a true positive only if a reference run
    has exactly the same ``(start, stop)``.

    Args:
        threshold: Cut-off applied to both arrays; also used as the suffix of
            the counter keys.
        y_labels: Reference values for one label column.
        p_labels: Predicted values for the same positions.
        metrics: Counter dict updated in place. It must already contain the
            keys ``TP_<threshold>``, ``FP_<threshold>`` and ``FN_<threshold>``.
    """
    reference_mask = (y_labels >= threshold).astype(int)
    predicted_mask = (p_labels >= threshold).astype(int)

    # Sets give exact-boundary matching and ignore ordering.
    reference = set(find_segments_ones(reference_mask))
    predicted = set(find_segments_ones(predicted_mask))

    matched = reference & predicted
    spurious = predicted - reference
    missed = reference - predicted

    metrics[f'TP_{threshold}'] += len(matched)
    metrics[f'FP_{threshold}'] += len(spurious)
    metrics[f'FN_{threshold}'] += len(missed)

def compute_exon_metrics(
    all_y: List[np.ndarray],
    all_p: List[np.ndarray],
) -> Dict[str, float]:
    """Compute exact-match segment precision/recall/F1 over a set of transcripts.

    Column 1 of every array is scored. For each threshold the column is
    binarised with ``>= threshold`` and split into runs of ones; a predicted
    run is a true positive only when a reference run has exactly the same
    start and stop.

    Transcripts are scored one at a time and their counts are summed, so a
    run that ends one transcript can never fuse with a run that starts the
    next one (which is what happens if the arrays are concatenated first).

    Args:
        all_y: Reference arrays, one per transcript, each 2-D with at least
            two columns.
        all_p: Predicted-probability arrays aligned one-to-one with ``all_y``.

    Returns:
        ``precision_exon_level_<t>``, ``recall_exon_level_<t>`` and
        ``f1_exon_level_<t>`` for every threshold ``t``, plus
        ``max_f1_exon_level`` (the best F1 over all thresholds). Every value
        is 0.0 when there is nothing to score.

    Raises:
        ValueError: If ``all_y`` and ``all_p`` differ in length.
    """
    if len(all_y) != len(all_p):
        raise ValueError(
            f"all_y and all_p must have the same length, got {len(all_y)} and {len(all_p)}"
        )

    # initialize counters
    thresholds = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,0.95]
    counts = {f'{m}_{t}': 0 for t in thresholds for m in ('TP','FP','FN')}

    # accumulate over transcripts: segments are extracted per transcript so
    # that runs touching a transcript boundary are not merged with the
    # neighbouring transcript's runs
    for y_arr, p_arr in zip(all_y, all_p):
        y_col = y_arr[:, 1]
        p_col = p_arr[:, 1]
        for t in thresholds:
            exon_level(t, y_col, p_col, counts)

    # compute precision/recall/f1 (0.0 whenever a denominator would be zero)
    metrics = {}
    for t in thresholds:
        tp = counts[f'TP_{t}']
        fp = counts[f'FP_{t}']
        fn = counts[f'FN_{t}']
        rec = tp/(tp+fn) if (tp+fn)>0 else 0.0
        prec = tp/(tp+fp) if (tp+fp)>0 else 0.0
        f1 = 2*prec*rec/(prec+rec) if (prec+rec)>0 else 0.0
        metrics[f'precision_exon_level_{t}'] = prec
        metrics[f'recall_exon_level_{t}']    = rec
        metrics[f'f1_exon_level_{t}']        = f1
    metrics['max_f1_exon_level'] = max(metrics[f'f1_exon_level_{t}'] for t in thresholds)
    return metrics

# In[3]: open HDF5 files, collect keys, create Datasets & DataLoaders

import h5py
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

# Label files and their matching embedding files; adjust these paths to your environment.
train_labels_path = '/home/jovyan/shares/SR003.nfs2/mane_no_intergenic_combined/mane_transcript_train_dataset_max_exon_cds.hdf5'
val_labels_path = '/home/jovyan/shares/SR003.nfs2/mane_no_intergenic_combined/mane_transcript_val_dataset_max_exon_cds.hdf5'
train_emb_path = './mane_transcript_train_dataset_max_exon_cds_evo2_embeddings_compressed_length_no_greater_32k.h5'
val_emb_path = './mane_transcript_val_dataset_max_exon_cds_evo2_embeddings_compressed_length_no_greater_32k.h5'

# Open every file read-only with swmr=True; the handles stay open for the
# rest of the script because the datasets read from them lazily.
train_lbl_f = h5py.File(train_labels_path, mode='r', swmr=True)
val_lbl_f = h5py.File(val_labels_path, mode='r', swmr=True)
train_emb_f = h5py.File(train_emb_path, mode='r', swmr=True)
val_emb_f = h5py.File(val_emb_path, mode='r', swmr=True)

# Top-level keys of the label files, in sorted order, define each split.
train_keys = sorted(train_lbl_f.keys())
val_keys = sorted(val_lbl_f.keys())
n_train, n_val = len(train_keys), len(val_keys)

# Each embedding file must contain exactly the same set of keys as its label file.
assert set(train_emb_f.keys()) == set(train_keys), "Train embed file keys mismatch!"
assert set(val_emb_f.keys()) == set(val_keys), "Val embed file keys mismatch!"

print(f"→ number of training transcripts:   {n_train}")
print(f"→ number of validation transcripts: {n_val}")

# Dataset wrapper
class H5TranscriptDataset(Dataset):
    """Map-style dataset yielding one ``(embeddings, labels)`` tensor pair per key.

    For a key ``k`` the labels are read from ``lbl_f[k]['labels_atcg']`` and
    the embeddings from index 0 along the first axis of
    ``emb_f[k]['embeddings']``. Both are returned as float tensors.

    If reading a sample raises ``OSError``, its index is appended to
    ``broken_indices.txt`` and the next index (wrapping around) is tried, so
    the returned sample may belong to a different key than the one requested.

    NOTE(review): the handles passed in are kept as-is; when this dataset is
    used by a multi-worker DataLoader the same handles are visible to the
    worker processes — confirm that h5py tolerates this in your setup.

    Args:
        lbl_f: Mapping-like object (e.g. an open HDF5 file) holding labels.
        emb_f: Mapping-like object holding embeddings under the same keys.
        keys: Sequence of keys; position in this sequence is the sample index.
        max_seq_length: If truthy, keep only this many leading rows of both
            arrays. ``None`` (or 0) disables truncation.
    """

    def __init__(self, lbl_f, emb_f, keys, max_seq_length=None):
        self.lbl_f = lbl_f
        self.emb_f = emb_f
        self.keys = keys
        self.max_seq_length = max_seq_length

        # Create the broken-indices log, or empty it if it already exists.
        with open('broken_indices.txt', 'w'):
            pass

    def __len__(self):
        return len(self.keys)

    def _read_pair(self, key):
        """Read, optionally truncate, and convert the arrays stored under ``key``."""
        labels = self.lbl_f[key]['labels_atcg'][:, :]
        # Index the first axis directly so only that slice is read.
        embeddings = self.emb_f[key]['embeddings'][0, :, :]

        limit = self.max_seq_length
        if limit:
            embeddings = embeddings[:limit]
            labels = labels[:limit]

        x = torch.from_numpy(embeddings).float()
        y = torch.from_numpy(labels).float()

        # Embeddings and labels must cover the same number of positions.
        assert x.shape[0] == y.shape[0]
        return x, y

    def __getitem__(self, idx):
        """Return ``(x, y)`` for ``idx``, skipping forward past unreadable samples.

        Raises:
            RuntimeError: If the search wraps all the way back to the starting
                index without finding a readable sample.
        """
        start_idx = idx
        while True:
            key = self.keys[idx]
            try:
                return self._read_pair(key)
            except OSError:
                # Record the unreadable index, then advance with wrap-around.
                with open('broken_indices.txt', 'a') as log:
                    log.write(f"{idx}\n")
                idx = (idx + 1) % len(self.keys)
                if idx == start_idx:
                    raise RuntimeError("All HDF5 samples appear to be corrupted.")


# Cap on rows kept per transcript (e.g. 20000); None disables truncation.
max_seq_length = 32000

# One dataset per split, both using the same truncation limit.
train_dataset = H5TranscriptDataset(
    lbl_f=train_lbl_f,
    emb_f=train_emb_f,
    keys=train_keys,
    max_seq_length=max_seq_length,
)
val_dataset = H5TranscriptDataset(
    lbl_f=val_lbl_f,
    emb_f=val_emb_f,
    keys=val_keys,
    max_seq_length=max_seq_length,
)

# Both loaders yield a single transcript per batch in shuffled order, with
# 32 workers and pin_memory enabled.
train_loader = DataLoader(dataset=train_dataset, batch_size=1, shuffle=True, num_workers=32, pin_memory=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=True, num_workers=32, pin_memory=True)

print(f"→ train_loader: {len(train_loader)} batches, shuffle=True")
print(f"→ val_loader:   {len(val_loader)} batches, shuffle=True")



# In[4]: hyperparameters & model/optimizer setup
# --- tunable hyperparameters ---
lr = 5e-4                    # learning rate passed to AdamW
weight_decay = 0.0           # AdamW weight decay (1e-4 noted as an alternative)
num_iterations = 4_000_000   # total number of gradient steps
val_period = 30_000          # validate once every this many iterations
val_sample_size = 1000       # transcripts sampled for each validation pass
# -------------------------------

# A single linear layer mapping 1920 input features to 5 outputs.
model = nn.Linear(in_features=1920, out_features=5)
model = model.to(device)
optimizer = optim.AdamW(params=model.parameters(), lr=lr, weight_decay=weight_decay)
criterion = nn.BCEWithLogitsLoss()
# Starts below any attainable F1, so the first validation result always counts as an improvement.
best_val_metric = -1.0

# In[5]: training + validation loop
# Runs `num_iterations` single-transcript gradient steps. Every `val_period`
# steps a random subset of the validation set is scored with the exon-level
# metrics, the metrics are appended to a text file, and a checkpoint is saved
# whenever the best max-F1 so far improves.
train_iter = iter(train_loader)

for it in tqdm(range(1, num_iterations+1)):
    # ---- TRAIN ----
    # Pull the next batch; once the loader is exhausted, start a fresh pass
    # over it so training continues beyond a single pass over the data.
    try:
        x_batch, y_batch = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        x_batch, y_batch = next(train_iter)

    # x_batch: (1, L, 1920), y_batch: (1, L, 5)
    # Drop the size-1 batch axis so every position becomes one row.
    x = x_batch.squeeze(0).to(device)
    y = y_batch.squeeze(0).to(device)

    # One optimisation step on this transcript.
    optimizer.zero_grad()
    logits = model(x)
    loss = criterion(logits, y)
    loss.backward()
    optimizer.step()

    # Report the loss of the current step only (not a running average).
    if it % 100 == 0:
        print(f"[train] iter {it:6d} — loss {loss.item():.4f}")

    # ---- VALIDATION ----
    if it % val_period == 0:
        # sample some validation indices
        # (without replacement; capped at the size of the validation set)
        idxs = random.sample(range(len(val_dataset)), min(val_sample_size, len(val_dataset)))
        all_y, all_p = [], []

        # Samples are read straight from val_dataset in this process;
        # val_loader is not used here.
        model.eval()
        with torch.no_grad():
            for idx in idxs:
                xv, yv = val_dataset[idx]
                xv = xv.to(device)
                logits_v = model(xv)
                # Convert logits to per-label probabilities before thresholding.
                probs_v  = torch.sigmoid(logits_v).cpu().numpy()
                all_p.append(probs_v)       # shape (L,5)
                all_y.append(yv.numpy())    # shape (L,5)
        model.train()

        # compute & print metrics
        metrics = compute_exon_metrics(all_y, all_p)
        mv = metrics['max_f1_exon_level']
        print(f"[val]   iter {it:6d} — max_f1_exon_level {mv:.4f}")

        # append metrics to text file
        # (opened in append mode, so results from earlier runs are kept)
        with open('validation_metrics.txt', 'a') as f:
            f.write(f"Iteration: {it}\n")
            for name, val in metrics.items():
                f.write(f"{name}: {val:.6f}\n")
            f.write("\n")

        # checkpoint if improved
        # The file name encodes lr, weight decay and the 4-decimal max-F1 but
        # not the iteration; the iteration is stored inside the checkpoint.
        if mv > best_val_metric:
            best_val_metric = mv
            ckpt = (
                f"ckpt_lr{lr}_wd{weight_decay}"
                f"_maxf1{mv:.4f}.pt"
            )
            torch.save({
                'iteration': it,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_val_metric': best_val_metric,
                'hyperparams': {
                    'lr': lr,
                    'weight_decay': weight_decay,
                    'max_seq_length': max_seq_length
                }
            }, ckpt)
            print(f" *** saved checkpoint: {ckpt}")