import itertools
import os
import json
from typing import *
import random
from argparse import ArgumentParser
from tqdm import tqdm
import math

import torch
from torch import nn, optim
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models
from PIL import Image
import time

import blackbox
from mipll_pruning_algorithm import structural_pruning

class HWFDataset(torch.utils.data.Dataset):
  """Handwritten-formula (HWF) dataset.

  Each item is one formula: the list of per-symbol image tensors, the number of
  symbols, the sample's `res` value and the ground-truth symbol tuple built from
  its `expr` field. Metadata is read from `<root>/HWF/hwf_<max_length>_<split>.json`
  and images from `<root>/HWF/Handwritten_Math_Symbols/`.
  """

  def __init__(self, root: str, prefix: str, split: str, max_length: int = 7, exact_length: bool = False, num_train_samples: int = None):
    """
    Args:
      root: data directory containing the `HWF` folder.
      prefix: accepted for interface compatibility; not used by this class.
      split: split name used in the metadata filename (e.g. "train", "test").
      max_length: selects the metadata file and filters formulas by symbol
        count; a value <= 0 keeps every sample of that file.
      exact_length: keep only formulas with exactly `max_length` symbols
        instead of up to `max_length`.
      num_train_samples: when given and `split == "train"`, randomly keep this
        many formulas (drawn with the global `random` state).
    """
    super(HWFDataset, self).__init__()
    self.root = root
    self.split = split

    # Load metadata from the appropriate file; `with` closes the handle promptly.
    md_path = os.path.join(root, f"HWF/hwf_{max_length}_{split}.json")
    with open(md_path) as md_file:
      md = json.load(md_file)

    # Filter metadata based on exact length or up-to length
    if max_length > 0:
      if exact_length:
        self.metadata = [m for m in md if len(m['img_paths']) == max_length]
        print(f"Using exact length filtering: {len(self.metadata)} samples with length exactly {max_length}")
      else:
        self.metadata = [m for m in md if len(m['img_paths']) <= max_length]
        print(f"Using up-to length filtering: {len(self.metadata)} samples with length up to {max_length}")
    else:
      self.metadata = md
      print(f"Using all samples: {len(self.metadata)} samples")

    # Limit training samples if specified
    if num_train_samples is not None and split == "train":
      original_count = len(self.metadata)
      self.metadata = random.sample(self.metadata, min(num_train_samples, len(self.metadata)))
      print(f"Limited training samples: {len(self.metadata)}/{original_count} samples")

    # ToTensor scales pixels to [0, 1]; Normalize then shifts them to [-0.5, 0.5].
    self.img_transform = torchvision.transforms.Compose([
      torchvision.transforms.ToTensor(),
      torchvision.transforms.Normalize((0.5,), (1,))])

  def __getitem__(self, index):
    """Return `(img_seq, img_seq_len, res, gt)` for formula `index`.

    `img_seq` is a list of grayscale image tensors (one per symbol), `res` is
    the metadata `res` value and `gt` is `tuple(sample["expr"])`.
    """
    sample = self.metadata[index]

    # Input is a sequence of images
    img_seq = []
    for img_path in sample["img_paths"]:
      img_full_path = os.path.join(self.root, "HWF/Handwritten_Math_Symbols", img_path)
      # Close the underlying file as soon as the converted copy exists.
      with Image.open(img_full_path) as pil_img:
        img = pil_img.convert("L")
      img_seq.append(self.img_transform(img))
    img_seq_len = len(img_seq)

    # Output is the "res" in the sample of metadata
    res = sample["res"]

    # GT is the ground truth label for each image
    gt = tuple(sample["expr"])

    return (img_seq, img_seq_len, res, gt)

  def __len__(self):
    """Number of formulas after filtering/subsampling."""
    return len(self.metadata)

  @staticmethod
  def collate_fn(batch):
    """Pad and stack a list of `__getitem__` tuples into batch tensors.

    Shorter image sequences are right-padded with all-zero images up to the
    longest sequence in the batch.

    Returns:
      (img_seqs, img_seq_len, results, gts): `img_seqs` has shape
      (batch, max_len, *image_shape); `img_seq_len` is a long tensor of true
      lengths; `results` stacks the `res` values; `gts` is a plain list of tuples.
    """
    max_len = max([img_seq_len for (_, img_seq_len, _, _) in batch])
    zero_img = torch.zeros_like(batch[0][0][0])
    pad_zero = lambda img_seq: img_seq + [zero_img] * (max_len - len(img_seq))
    img_seqs = torch.stack([torch.stack(pad_zero(img_seq)) for (img_seq, _, _, _) in batch])
    img_seq_len = torch.stack([torch.tensor(img_seq_len).long() for (_, img_seq_len, _, _) in batch])
    results = torch.stack([torch.tensor(res) for (_, _, res, _) in batch])
    gts = [gt for (_, _, _, gt) in batch]
    return (img_seqs, img_seq_len, results, gts)


def hwf_loader(data_dir, batch_size, prefix, max_length=7, exact_length=False, num_train_samples=None):
  """Create the shuffled (train_loader, test_loader) pair for the HWF dataset.

  Only the training split is subsampled by `num_train_samples`; both loaders
  batch with `HWFDataset.collate_fn`.
  """
  def build(split, sample_limit):
    # Dataset first, then its loader: keeps construction (and print) order.
    dataset = HWFDataset(data_dir, prefix, split, max_length, exact_length, sample_limit)
    return torch.utils.data.DataLoader(
      dataset,
      collate_fn=HWFDataset.collate_fn,
      batch_size=batch_size,
      shuffle=True)

  train_loader = build("train", num_train_samples)
  test_loader = build("test", None)
  return (train_loader, test_loader)


class SymbolNet(nn.Module):
  """Small CNN mapping one symbol image to probabilities over 14 classes.

  `fc1` expects 64 * 22 * 22 = 30976 features: the 3x3, padding-1 convolutions
  keep the spatial size and the 2x2 max-pool halves it (floor), so inputs are
  1-channel 44x44 or 45x45 images. The output is a softmax, not logits.
  """

  def __init__(self):
    super().__init__()
    # Attribute names define the state_dict keys used by saved checkpoints.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
    self.fc1 = nn.Linear(in_features=30976, out_features=128)
    self.fc2 = nn.Linear(in_features=128, out_features=14)

  def forward(self, x):
    # conv1 -> ReLU -> conv2 -> max-pool (no ReLU after conv2), then dropout.
    feats = F.max_pool2d(self.conv2(F.relu(self.conv1(x))), 2)
    feats = F.dropout(feats, p=0.25, training=self.training)
    hidden = F.relu(self.fc1(torch.flatten(feats, 1)))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    return F.softmax(self.fc2(hidden), dim=1)


def hwf_eval(symbols: List[str]):
  """Evaluate an alternating digit/operator symbol list as a Python expression.

  Even positions must be ASCII digit strings; odd positions must be one of
  "+", "-", "*", "/". The joined string is evaluated with Python semantics
  (usual precedence, `/` is true division).

  Raises:
    ValueError: a symbol is not allowed at its position (message starts "BAD").
    SyntaxError: from `eval`, e.g. for an empty list or a trailing operator.
    ZeroDivisionError: from `eval` on division by zero.
  """
  # Sanitize the input so only digits and the four operators reach eval.
  for i, s in enumerate(symbols):
    if i % 2 == 0:
      # isascii(): isdigit() also accepts e.g. "²", which eval cannot parse.
      if not (s.isascii() and s.isdigit()):
        raise ValueError(f"BAD: expected a digit at position {i}, got {s!r}")
    elif s not in ("+", "-", "*", "/"):
      raise ValueError(f"BAD: expected an operator at position {i}, got {s!r}")

  # Evaluate the result
  return eval("".join(symbols))


class HWFNet(nn.Module):
  """SymbolNet classifier composed with a black-box formula evaluator.

  `symbol_cnn` yields, for every image position, a distribution over
  `self.symbols` (digits "0"-"9" followed by "+", "-", "*", "/").
  `eval_formula` wraps `hwf_eval` in `blackbox.BlackBoxFunction` configured with
  categorical sampling, minmax aggregation and fallback output 0; the exact
  semantics of those options live in the project-local `blackbox` module.
  """

  def __init__(self, sample_count):
    super(HWFNet, self).__init__()
    self.symbol_cnn = SymbolNet()
    self.symbols = [str(i) for i in range(10)] + ["+", "-", "*", "/"]
    # Symbol -> class index; digits map to 0-9 and operators to 10-13.
    self.symbol_idx_map = {s: i for i, s in enumerate(self.symbols)}
    # NOTE(review): 7 is presumably the maximum formula length the list input accepts.
    self.eval_formula = blackbox.BlackBoxFunction(
      hwf_eval,
      (blackbox.ListInputMapping(7, blackbox.DiscreteInputMapping(self.symbols)),),
      blackbox.UnknownDiscreteOutputMapping(fallback=0),
      sample_count=sample_count,
      sample_strategy="categorical",
      aggregate_strategy="minmax")

  def forward(self, img_seq, img_seq_len, preimages=None, labels=None):
    """Classify every symbol image and evaluate the formula through the black box.

    Args:
      img_seq: (batch, formula_length, C, H, W) padded image tensor.
      img_seq_len: per-formula number of real symbols.
      preimages: optional per-example list of candidate symbol sequences; each
        position's distribution is then restricted to the symbols those
        sequences use at that position.
      labels: optional per-example ground-truth symbol sequences; when given,
        the count of positions whose CNN argmax matches is appended to the result.

    Returns:
      The black-box result (callers unpack it as `(output_mapping, y_pred)`),
      with `(cnn_correct,)` appended when `labels` is given.
    """
    batch_size, formula_length, _, _, _ = img_seq.shape
    length = [l.item() for l in img_seq_len]
    symbol = self.symbol_cnn(img_seq.flatten(start_dim=0, end_dim=1)).view(batch_size, formula_length, -1)

    # CNN accuracy is measured on the unconstrained distributions.
    cnn_correct = self._count_cnn_correct(symbol, length, labels) if labels is not None else 0

    if preimages is not None:
      symbol = self._constrain_to_preimages(symbol, length, preimages)
    result = self.eval_formula(blackbox.ListInput(symbol, length))

    if labels is not None:
      return result + (cnn_correct,)
    return result

  def _count_cnn_correct(self, symbol, length, labels):
    """Count real positions whose CNN argmax equals the ground-truth symbol's class."""
    # One argmax and one device->host transfer instead of a sync per position.
    preds = torch.argmax(symbol, dim=-1).tolist()
    correct = 0
    for pred_row, gt_row, seq_len in zip(preds, labels, length):
      for j in range(seq_len):
        if pred_row[j] == self.symbol_idx_map[gt_row[j]]:
          correct += 1
    return correct

  def _constrain_to_preimages(self, symbol, length, preimages):
    """Restrict each position's distribution to the symbols its preimages use there.

    For example b and position pos < length[b], the allowed set is
    {seq[pos] for seq in preimages[b] if pos < len(seq)}; unknown symbols are
    ignored. The distribution is masked to that set and renormalized, or made
    uniform over the set if the masked mass is <= 1e-8. Positions with an
    empty set, and padding positions, keep their original probabilities.
    """
    batch_size, formula_length, num_symbols = symbol.shape
    # Build the 0/1 mask on the host (pure Python work), then move it once.
    mask = torch.zeros(batch_size, formula_length, num_symbols, dtype=symbol.dtype)
    for b in range(batch_size):
      for pos in range(length[b]):
        allowed = {self.symbol_idx_map[seq[pos]] for seq in preimages[b]
                   if pos < len(seq) and seq[pos] in self.symbol_idx_map}
        if allowed:
          mask[b, pos, sorted(allowed)] = 1.0
    mask = mask.to(symbol.device)

    allowed_count = mask.sum(dim=-1, keepdim=True)
    masked = symbol * mask
    mass = masked.sum(dim=-1, keepdim=True)
    # Clamping keeps the unselected torch.where branches finite, avoiding NaN gradients.
    renormalized = masked / mass.clamp(min=1e-8)
    uniform = mask / allowed_count.clamp(min=1.0)
    constrained = torch.where(mass > 1e-8, renormalized, uniform)
    return torch.where(allowed_count > 0, constrained, symbol)


class Trainer():
  """Train and evaluate an HWFNet.

  Training modes:
    * `train_epoch` / `train`: plain training without preimage constraints.
    * `train_epoch_w_preimage`: each example's per-position symbol distribution
      is restricted to symbols used by (a sample of) its preimages.
    * `train_epoch_w_purification`: as above, but the sampled preimages are
      first filtered by `structural_pruning` using ResNet-18 image features.

  A preimage of result r for formula length l is a digit/operator sequence of
  length l whose `eval` result formats (via `format_result_string`) to the same
  string as r. `test_epoch` saves a checkpoint whenever the summed test loss
  improves.
  """

  def __init__(self, train_loader, test_loader, device, model_root, model_name, learning_rate, sample_count, use_preimages=False, use_pruning=False):
    """
    Args:
      train_loader, test_loader: loaders yielding (img_seq, img_seq_len, label, gts)
        batches (e.g. built by `hwf_loader`).
      device: torch device for the network and the optional feature extractor.
      model_root, model_name: directory and filename of the saved checkpoint.
      learning_rate: Adam learning rate.
      sample_count: forwarded to HWFNet (black-box sample count).
      use_preimages: eagerly build preimage tables for lengths 1, 3, 5 and 7.
      use_pruning: together with `use_preimages`, also load the frozen ImageNet
        ResNet-18 feature extractor used by `train_epoch_w_purification`.
    """
    self.network = HWFNet(sample_count).to(device)
    self.device = device
    self.use_pruning = use_pruning

    if use_preimages:
      self.preimages = {}
      for l in range(1, 8, 2):  # lengths 1, 3, 5, 7
        t_begin = time.time()
        print(f"Preimage for length {l}")
        self.build_preimage(l)
        print(f"Time taken: {time.time() - t_begin}")

      # Setup ResNet transform and feature extractor only if pruning is needed
      if use_pruning:
        self.resnet_transform = transforms.Compose([
          # Repeat the single grayscale channel to the 3 channels ResNet expects.
          transforms.Lambda(lambda x: x.expand(x.shape[0],3,*x.shape[2:]) ),
          transforms.Resize(224),
          transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet normalization
                              std=[0.229, 0.224, 0.225])
        ])
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        for p in model.parameters():
          p.requires_grad = False
        # Drop the final fc layer: the extractor outputs pooled features.
        self.feature_extractor = torch.nn.Sequential(*(list(model.children())[:-1]))
        self.feature_extractor = self.feature_extractor.to(self.device)
        self.feature_extractor.eval()
      else:
        self.resnet_transform = None
        self.feature_extractor = None
    else:
      self.preimages = None
      self.resnet_transform = None
      self.feature_extractor = None

    self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
    self.train_loader = train_loader
    self.test_loader = test_loader
    self.loss_fn = F.binary_cross_entropy
    self.model_root = model_root
    self.model_name = model_name
    self.min_test_loss = 100000000.0
    self.best_accuracy = 0.0
    self.best_cnn_accuracy = 0.0

  def eval_result_eq(self, a, b, threshold=0.01):
    """Return True when two numeric results differ by less than `threshold`."""
    result = abs(a - b) < threshold
    return result

  def build_preimage(self, length):
    """Populate `self.preimages[length]`; a no-op if it is already present.

    Enumerates every sequence of `length` symbols alternating digit/operator
    (digits at even positions), evaluates it with `eval`, and groups sequences
    by `format_result_string(result)`. Sequences raising ZeroDivisionError are
    skipped; even lengths end on an operator, so `eval` raises SyntaxError.
    `eval` is safe here because inputs come from fixed digit/operator alphabets.
    The enumeration size grows as 10^ceil(l/2) * 4^floor(l/2).
    """
    if length in self.preimages:
      return
    digits = [str(i) for i in range(10)]
    ops = ["+", "-", "*", "/"]
    possibilities = [digits if i % 2 == 0 else ops for i in range(length)]
    preimages = {}
    for comb in itertools.product(*possibilities):
      try:
        result = eval("".join(comb))
      except ZeroDivisionError:
        continue
      preimages.setdefault(self.format_result_string(result), []).append(comb)
    self.preimages[length] = preimages

  def get_preimage(self, length, result):
    """Return the cached list of symbol tuples of `length` that evaluate to `result`.

    Builds the table for `length` on demand. Raises KeyError (after printing
    the known keys) if no sequence of that length formats to `result`.
    The returned list is the cache itself; callers must not mutate it.
    """
    if length not in self.preimages:
      self.build_preimage(length)
    formatted_result = self.format_result_string(result)
    if formatted_result not in self.preimages[length]:
      print(f"Formatted result {formatted_result} not in preimages: {sorted(list(self.preimages[length].keys()))}")
    preimages = self.preimages[length][formatted_result]

    return preimages

  def format_result_string(self, result):
    """Format a result as a lookup key.

    Floats use 4 decimal places, with a trailing ".0000" stripped so that, e.g.,
    2.0 and 2 share the key "2". Ints use `str`. Other types raise ValueError.
    """
    if isinstance(result, float):
      new_result = f"{result:.4f}"
      if new_result[-5:] == ".0000":
        new_result = new_result[:-5]
      return new_result
    elif isinstance(result, int):
      return f"{result}"
    else:
      raise ValueError(f"Unsupported type {type(result)} for result: {result}")

  def _sample_preimages(self, img_seq_len, gts, label, sample_size):
    """Return one preimage list per example, each guaranteed to contain its gt.

    Lists longer than `sample_size` (when > 0) are randomly subsampled. Every
    returned list is a fresh copy, so the cached preimage tables are never mutated.
    """
    preimages = []
    for seq_len, gt, lbl in zip(img_seq_len, gts, label):
      preimg = self.get_preimage(seq_len.item(), lbl.item())
      if sample_size > 0 and len(preimg) > sample_size:
        preimg = random.sample(preimg, sample_size)
      else:
        preimg = list(preimg)
      if gt not in preimg:
        preimg.append(gt)
      preimages.append(preimg)
    return preimages

  def _targets(self, label, output_mapping):
    """Build the (batch, len(output_mapping)) 0/1 float target matrix.

    Entry [b, k] is 1.0 when output k is within `eval_result_eq` tolerance of
    label b. An empty `output_mapping` yields shape (batch, 0).
    """
    rows = [[1.0 if self.eval_result_eq(l.item(), m) else 0.0 for m in output_mapping] for l in label]
    return torch.tensor(rows)

  @staticmethod
  def _count_correct(y, y_pred):
    """Count rows where argmax(y_pred) equals argmax(y) and y has a positive entry.

    When several outputs match a label, only the first one (argmax of `y`)
    counts as a hit. Returns 0 when there are no outputs.
    """
    batch_size, num_outputs = y_pred.shape
    if num_outputs == 0:
      return 0
    y_index = torch.argmax(y, dim=1)
    y_pred_index = torch.argmax(y_pred, dim=1)
    hits = torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, torch.zeros(batch_size).bool())
    return torch.sum(hits).item()

  def _train_step(self, img_seq, img_seq_len, label, preimages):
    """Run forward/backward/optimizer step on one batch.

    Returns:
      (loss_value, correct_count, batch_size). loss_value is NaN, and no
      update is applied, when the network produced no candidate outputs.
    """
    (output_mapping, y_pred) = self.network(img_seq.to(self.device), img_seq_len.to(self.device), preimages=preimages)
    y_pred = y_pred.to("cpu")
    batch_size, num_outputs = y_pred.shape
    if num_outputs == 0:
      # BCE over an empty tensor is NaN; stepping on it would be meaningless.
      return float("nan"), 0, batch_size

    y = self._targets(label, output_mapping)
    loss = self.loss_fn(y_pred, y)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss.item(), self._count_correct(y, y_pred), batch_size

  def _prune_preimages(self, img_seq, img_seq_len, gts, preimages, args):
    """Extract ResNet features per symbol image and call `structural_pruning`."""
    if self.resnet_transform is None or self.feature_extractor is None:
      raise RuntimeError("Purification training requires Trainer(use_preimages=True, use_pruning=True)")
    batch_size = img_seq.shape[0]
    # (formula_length, batch, 1, H, W): one batched tensor per symbol position.
    img_seq_perm = img_seq.permute(1, 0, 2, 3, 4)
    images = [self.resnet_transform(img) for img in img_seq_perm]
    features = [self.feature_extractor(img.to(self.device)).tolist() for img in images]

    images_for_pruning = []
    features_for_pruning = []
    labels_for_pruning = []
    for b in range(batch_size):
      length = img_seq_len[b].item()
      images_for_pruning.append([img[b] for img in images[:length]])
      features_for_pruning.append([feature[b] for feature in features[:length]])
      labels_for_pruning.append(tuple(gts[b]))

    return structural_pruning(images_for_pruning, preimages, labels_for_pruning, features_for_pruning, None, args)

  def train_epoch(self, epoch):
    """Train one epoch without preimage constraints."""
    self.network.train()
    num_items = 0
    train_loss = 0
    total_correct = 0
    pbar = tqdm(self.train_loader, total=len(self.train_loader))
    for (batch_idx, (img_seq, img_seq_len, label, gts)) in enumerate(pbar):
      loss_value, correct_count, batch_size = self._train_step(img_seq, img_seq_len, label, None)
      if not math.isnan(loss_value):
        train_loss += loss_value

      num_items += batch_size
      total_correct += correct_count
      perc = 100. * total_correct / num_items
      avg_loss = train_loss / (batch_idx + 1)
      pbar.set_description(f"[Train Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)")

  def train_epoch_w_preimage(self, epoch, sample_size=200):
    """Train one epoch constraining symbol distributions by sampled preimages.

    Args:
      sample_size: max preimages kept per example (<= 0 keeps all); the
        ground truth is always included.
    """
    self.network.train()
    num_items = 0
    train_loss = 0
    total_correct = 0
    pbar = tqdm(self.train_loader, total=len(self.train_loader))

    for (batch_idx, (img_seq, img_seq_len, label, gts)) in enumerate(pbar):
      if self.preimages is not None:
        preimages = self._sample_preimages(img_seq_len, gts, label, sample_size)
      else:
        preimages = None

      loss_value, correct_count, batch_size = self._train_step(img_seq, img_seq_len, label, preimages)
      if not math.isnan(loss_value):
        train_loss += loss_value

      num_items += batch_size
      total_correct += correct_count
      perc = 100. * total_correct / num_items
      avg_loss = train_loss / (batch_idx + 1)
      pbar.set_description(f"[Train Preimage Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)")

  def train_epoch_w_purification(self, epoch, args, sample_size=200):
    """Train one epoch with preimages filtered by `structural_pruning`.

    Args:
      args: forwarded unchanged to `structural_pruning`.
      sample_size: max preimages sampled per example before pruning (<= 0 keeps all).
    """
    self.network.train()
    num_items = 0
    train_loss = 0
    total_correct = 0
    pbar = tqdm(self.train_loader, total=len(self.train_loader))
    num_og_preimages = 0
    num_pruned_preimages = 0
    num_og_gt_proofs = 0
    num_pruned_gt_proofs = 0
    t_begin_total_epoch = time.time()

    # batch_idx must not be reused by inner loops: it drives the average loss.
    for (batch_idx, (img_seq, img_seq_len, label, gts)) in enumerate(pbar):
      if self.preimages is not None:
        preimages = self._sample_preimages(img_seq_len, gts, label, sample_size)

        # Count ground truth proofs before pruning
        num_og_gt_proofs += sum(1 for gt, preimg in zip(gts, preimages) if gt in preimg)
        pruned_preimages = self._prune_preimages(img_seq, img_seq_len, gts, preimages, args)
        # Count ground truth proofs after pruning
        num_pruned_gt_proofs += sum(1 for gt, preimg in zip(gts, pruned_preimages) if gt in preimg)

        num_og_preimages += sum(len(preimg) for preimg in preimages)
        num_pruned_preimages += sum(len(preimg) for preimg in pruned_preimages)
        preimages = pruned_preimages
      else:
        preimages = None

      loss_value, correct_count, batch_size = self._train_step(img_seq, img_seq_len, label, preimages)
      if not math.isnan(loss_value):
        train_loss += loss_value

      num_items += batch_size
      total_correct += correct_count
      perc = 100. * total_correct / num_items
      avg_loss = train_loss / (batch_idx + 1)

      if self.preimages is not None:
        pbar.set_description(f"[Train Purif Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%), Pruned: {num_pruned_preimages}/{num_og_preimages} ({100. * num_pruned_preimages / num_og_preimages if num_og_preimages > 0 else 0:.2f}%), GT Proofs: {num_pruned_gt_proofs}/{num_og_gt_proofs}")
      else:
        pbar.set_description(f"[Train Purif Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)")

    total_epoch_time = time.time() - t_begin_total_epoch
    print(f"Total Epoch Time: {total_epoch_time}")

  def test_epoch(self, epoch):
    """Evaluate on the test loader, update best accuracies and checkpoint on lower loss.

    Overall accuracy compares argmax outputs with the label; CNN accuracy is
    per-symbol argmax accuracy reported by the network.
    """
    self.network.eval()
    num_items = 0
    test_loss = 0
    total_correct = 0
    total_cnn_correct = 0
    total_symbols = 0
    # Defaults so an empty test loader does not leave these undefined below.
    perc = 0.0
    cnn_perc = 0.0
    with torch.no_grad():
      pbar = tqdm(self.test_loader, total=len(self.test_loader))
      for batch_idx, (img_seq, img_seq_len, label, gts) in enumerate(pbar):
        (output_mapping, y_pred, cnn_correct) = self.network(img_seq.to(self.device), img_seq_len.to(self.device), preimages=None, labels=gts)
        y_pred = y_pred.to("cpu")
        batch_size, num_outputs = y_pred.shape
        y = self._targets(label, output_mapping)

        loss_value = self.loss_fn(y_pred, y).item() if num_outputs > 0 else float("nan")
        if not math.isnan(loss_value):
          test_loss += loss_value

        num_items += batch_size
        total_correct += self._count_correct(y, y_pred)
        total_cnn_correct += cnn_correct
        total_symbols += sum([length.item() for length in img_seq_len])

        perc = 100. * total_correct / num_items
        cnn_perc = 100. * total_cnn_correct / total_symbols if total_symbols > 0 else 0.0
        avg_loss = test_loss / (batch_idx + 1)
        pbar.set_description(f"[Test Epoch {epoch}] Loss: {avg_loss:.4f}, Overall: {total_correct}/{num_items} ({perc:.2f}%), CNN: {total_cnn_correct}/{total_symbols} ({cnn_perc:.2f}%)")

    # Track best accuracies first so a checkpoint saved below includes this epoch.
    self.best_accuracy = max(self.best_accuracy, perc)
    self.best_cnn_accuracy = max(self.best_cnn_accuracy, cnn_perc)

    if test_loss < self.min_test_loss:
      self.min_test_loss = test_loss
      # Save state dict instead of entire model to avoid pickling issues with blackbox timeout decorator
      torch.save({
        'model_state_dict': self.network.state_dict(),
        'optimizer_state_dict': self.optimizer.state_dict(),
        'test_loss': test_loss,
        'epoch': epoch,
        'best_accuracy': self.best_accuracy,
        'best_cnn_accuracy': self.best_cnn_accuracy
      }, os.path.join(self.model_root, self.model_name))

  def load_model(self, checkpoint_path):
    """
    Load model from saved checkpoint.

    Usage:
    trainer = Trainer(...)
    trainer.load_model("path/to/checkpoint.pkl")
    """
    checkpoint = torch.load(checkpoint_path, map_location=self.device)
    self.network.load_state_dict(checkpoint['model_state_dict'])
    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    self.min_test_loss = checkpoint.get('test_loss', float('inf'))
    self.best_accuracy = checkpoint.get('best_accuracy', 0.0)
    self.best_cnn_accuracy = checkpoint.get('best_cnn_accuracy', 0.0)
    print(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
    print(f"Best accuracy: {self.best_accuracy:.2f}%, Best CNN accuracy: {self.best_cnn_accuracy:.2f}%")

  def train(self, n_epochs):
    """Run `n_epochs` of plain training, each followed by a test epoch."""
    for epoch in range(1, n_epochs + 1):
      self.train_epoch(epoch)
      self.test_epoch(epoch)


if __name__ == "__main__":
  # Command line arguments.
  # NOTE: --loss-fn, --do-not-use-hash, --jit and --recompile are not read in this
  # file, but the whole `args` namespace is forwarded to structural_pruning.
  parser = ArgumentParser("hwf")
  parser.add_argument("--model-name", type=str, default="hwf.pkl")
  parser.add_argument("--n-epochs", type=int, default=50)
  parser.add_argument("--sample-count", type=int, default=100)
  parser.add_argument("--dataset-prefix", type=str, default="expr")
  parser.add_argument("--batch-size", type=int, default=16)
  parser.add_argument("--learning-rate", type=float, default=0.0001)
  parser.add_argument("--loss-fn", type=str, default="bce")
  parser.add_argument("--seed", type=int, default=1234)
  parser.add_argument("--do-not-use-hash", action="store_true")
  parser.add_argument("--cuda", action="store_true")
  parser.add_argument("--gpu", type=int, default=0)
  parser.add_argument("--jit", action="store_true")
  parser.add_argument("--recompile", action="store_true")
  parser.add_argument("--preimage", action="store_true", help="Use preimage for normal training (without purification)")
  parser.add_argument("--structured-pruning", action="store_true")
  parser.add_argument('--structure-k', help='k (number of nearest neighbours) forwarded to structural_pruning', default=10, type=int)
  parser.add_argument('--percent', help='percent hyperparameter forwarded to structural_pruning', default=0.001, type=float)
  parser.add_argument("--mock_proximity", default = False, action="store_true")
  parser.add_argument("--num-training-samples", type=int, default=None, help="Number of training samples to use (default: use all)")
  parser.add_argument("--max-length", type=int, default=7, help="Maximum sequence length to load from dataset")
  parser.add_argument("--exact-length", action="store_true", help="Filter for exact length instead of up-to length")
  args = parser.parse_args()

  # Parameters
  torch.manual_seed(args.seed)
  random.seed(args.seed)
  if args.cuda:
    if not torch.cuda.is_available():
      raise Exception("No cuda available")
    device = torch.device(f"cuda:{args.gpu}")
  else:
    device = torch.device("cpu")

  # Data and model directories are resolved relative to this script's location.
  script_path = os.path.abspath(__file__)
  data_dir = os.path.abspath(os.path.join(script_path, "../../../data/"))
  model_dir = os.path.abspath(os.path.join(script_path, "../../model/hwf"))
  os.makedirs(model_dir, exist_ok=True)
  train_loader, test_loader = hwf_loader(data_dir, batch_size=args.batch_size, prefix=args.dataset_prefix, max_length=args.max_length, exact_length=args.exact_length, num_train_samples=args.num_training_samples)

  # Training
  # Preimages are needed for both preimage-only and purification training.
  use_preimage_flag = args.preimage or args.structured_pruning

  trainer = Trainer(train_loader, test_loader, device, model_dir, args.model_name, args.learning_rate, args.sample_count,
                   use_preimages=use_preimage_flag, use_pruning=args.structured_pruning)

  if args.structured_pruning:
    print(f"Starting purification training with structural pruning for {args.n_epochs} epochs...")
    print(f"Purification settings: mock_proximity={args.mock_proximity}, structure_k={args.structure_k}, percent={args.percent}")
    if trainer.preimages is None:
      print("WARNING: Purification training enabled but no preimage available!")
    for epoch in range(1, args.n_epochs + 1):
      trainer.train_epoch_w_purification(epoch, args)
      trainer.test_epoch(epoch)
  elif args.preimage:
    print(f"Starting preimage training (without pruning) for {args.n_epochs} epochs...")
    if trainer.preimages is None:
      print("WARNING: Preimage training enabled but no preimage available!")
    for epoch in range(1, args.n_epochs + 1):
      trainer.train_epoch_w_preimage(epoch)
      trainer.test_epoch(epoch)
  else:
    print(f"Starting normal training for {args.n_epochs} epochs...")
    trainer.train(args.n_epochs)

  print(f"Best accuracy: {trainer.best_accuracy:.2f}%, Best CNN accuracy: {trainer.best_cnn_accuracy:.2f}%")