import itertools
import os
import pickle
import json
import csv
import time
import random
from typing import *
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.multiprocessing import Pool
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import numpy as np

from argparse import ArgumentParser
from tqdm import tqdm
import math

import unstructured_dataset
import structured_dataset
import task_dataset
import task_program
import output
import blackbox
from constants import *
from mipll_pruning_algorithm import structural_pruning


class Dataset(torch.utils.data.Dataset):
    """``torch.utils.data.Dataset`` adapter that delegates to ``task_dataset.TaskDataset``."""

    def __init__(
            self,
            config: dict,
            train: bool,
            individual_labels: bool = False):
        """Build the wrapped task dataset; all three arguments are forwarded unchanged."""
        wrapped = task_dataset.TaskDataset(config, train, individual_labels=individual_labels)
        self.dataset = wrapped

    def __len__(self):
        # The adapter has exactly as many items as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, index):
        item = self.dataset.__getitem__(index)
        return item

    @staticmethod
    def collate_fn(batch):
        """Merge a list of per-sample items into one batch.

        Each item is indexed as ``(data_dict, input_configs, result[, individual_labels])``;
        the input configs carried by the first item drive collation for the whole batch.

        Returns:
            ``(imgs, results)``, or ``(imgs, results, individual_labels)`` when the
            items carry a fourth element. ``imgs`` maps each input's ``NAME`` to the
            output of that input's structured-dataset ``collate_fn``.
        """
        per_sample_data = [sample[0] for sample in batch]
        input_configs = batch[0][1]

        # Resolve every input's collator first, then collate each input column.
        collators = {}
        for input_cfg in input_configs:
            structured = structured_dataset.get_structured_dataset_static(input_cfg)
            collators[input_cfg[NAME]] = structured.collate_fn

        imgs = {}
        for input_cfg in input_configs:
            key = input_cfg[NAME]
            column = [sample_data[key] for sample_data in per_sample_data]
            imgs[key] = collators[key](column, input_cfg)

        results = [sample[2] for sample in batch]
        # A fourth element per item means per-input labels were requested.
        if len(batch[0]) != 4:
            return (imgs, results)
        individual_labels = [sample[3] for sample in batch]
        return (imgs, results, individual_labels)


def train_test_loader(configuration, batch_size_train, batch_size_test, num_training_samples=None, individual_labels=False):
    """Create the ``(train_loader, test_loader)`` pair for one task configuration.

    Args:
        configuration: task configuration forwarded to ``Dataset``.
        batch_size_train: batch size of the training loader.
        batch_size_test: batch size of the test loader.
        num_training_samples: when given, train on a random subset of this many
            indices, drawn with the module-level ``random`` (seeded by the script);
            values above the dataset size are clamped with a warning.
        individual_labels: forwarded to both datasets.

    Both loaders use ``Dataset.collate_fn``; both shuffle.
    """
    train_dataset = Dataset(configuration, train=True, individual_labels=individual_labels)

    train_kwargs = dict(collate_fn=Dataset.collate_fn, batch_size=batch_size_train)
    if num_training_samples is None:
        train_kwargs["shuffle"] = True
    else:
        available = len(train_dataset)
        if num_training_samples > available:
            print(f"Warning: Requested {num_training_samples} samples but only {available} available. Using all samples.")
            num_training_samples = available
        # A sampler and shuffle=True are mutually exclusive; the sampler shuffles the subset.
        chosen_indices = random.sample(range(available), num_training_samples)
        train_kwargs["sampler"] = torch.utils.data.SubsetRandomSampler(chosen_indices)
    train_loader = torch.utils.data.DataLoader(train_dataset, **train_kwargs)

    test_dataset = Dataset(configuration, train=False, individual_labels=individual_labels)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        collate_fn=Dataset.collate_fn,
        batch_size=batch_size_test,
        shuffle=True,
    )
    return train_loader, test_loader


class TaskNet(nn.Module):
    """Perception networks composed with a black-box symbolic program.

    One network is created per distinct unstructured-dataset ``name`` (inputs
    that share a dataset share the network). ``forward`` maps every input batch
    to a distribution, converts those into program inputs, and evaluates them
    with ``blackbox.BlackBoxFunction``.

    NOTE(review): the perception nets live in a plain dict/list, so they are not
    registered as ``nn.Module`` submodules (``state_dict``/``.to()`` on this
    object do not reach them).
    """

    def __init__(
            self,
            unstructured_datasets: List[unstructured_dataset.UnstructuredDataset],
            config: dict,
            fn: Callable,
            output_mapping: output.OutputMapping,
            sample_count: int,
            batch_size_train: int,
            caching: bool,
            preimage: dict = None,
            loss_aggregator: Optional[str] = None):
        """
        Args:
            unstructured_datasets: one per input; each provides ``name`` and ``net()``.
            config: per-input configuration entries (iterated and indexed by position).
            fn: the symbolic program evaluated by the black box.
            output_mapping: forwarded to the black box.
            sample_count: forwarded to the black box.
            batch_size_train: forwarded to the black box; also sizes ``self.pool``.
            caching: forwarded to the black box.
            preimage: optional ``{label: candidates}`` mapping used to build a
                ``preimage_batch`` when ``forward`` receives labels.
            loss_aggregator: black-box loss aggregator. When omitted it is read
                from a module-level ``task_config`` dict if one exists (the
                script defines it), otherwise ``ADD_MULT``.
        """
        super(TaskNet, self).__init__()

        self.config = config
        self.unstructured_datasets = unstructured_datasets
        self.structured_datasets = [
            structured_dataset.get_structured_dataset_static(input_cfg) for input_cfg in config]

        self.nets_dict = {}
        self.nets = self.get_nets_list()
        self.preimage = preimage
        # Pair the i-th structured dataset with the i-th network.
        self.forward_fns = [partial(sd.forward, self.nets[i])
                            for i, sd in enumerate(self.structured_datasets)]
        input_mappings = tuple(sd.get_input_mapping(config[i])
                               for i, sd in enumerate(self.structured_datasets))
        if loss_aggregator is None:
            # This used to be a bare reference to the script-level ``task_config``,
            # which raised NameError whenever the module was imported.
            script_task_config = globals().get("task_config") or {}
            loss_aggregator = script_task_config.get(LOSS_AGGREGATOR, ADD_MULT)
        self.eval_formula = \
            blackbox.BlackBoxFunction(function=fn,
                                      input_mappings=input_mappings,
                                      output_mapping=output_mapping,
                                      batch_size=batch_size_train,
                                      loss_aggregator=loss_aggregator,
                                      caching=caching,
                                      sample_count=sample_count)

        # Process pool sized to the training batch; in this file only close() uses it.
        self.pool = Pool(processes=batch_size_train)

    def get_nets_list(self):
        """Return one network per unstructured dataset, creating each name's net once."""
        nets = []
        for ud in self.unstructured_datasets:
            if ud.name not in self.nets_dict:
                self.nets_dict[ud.name] = ud.net()
            nets.append(self.nets_dict[ud.name])
        return nets

    def parameters(self):
        # NOTE: unlike nn.Module.parameters, this returns a list holding one
        # parameter iterator per distinct network.
        return [net.parameters() for net in self.nets_dict.values()]

    def task_test(self, args, x):
        # NOTE(review): ``self.sampling`` is never assigned in this class, so this
        # raises AttributeError as written -- TODO confirm the intended target.
        return self.sampling.sample_test(args, data=x)

    def _preimage_batch_for(self, labels):
        """Look up each label's candidates in ``self.preimage``.

        Returns None when there is no preimage or no labels; otherwise one entry
        per label, ``None`` for labels absent from the preimage (no filtering).
        """
        if self.preimage is None or labels is None:
            return None
        batch_size = labels.shape[0] if hasattr(labels, 'shape') else len(labels)
        batch = []
        for sample_idx in range(batch_size):
            label = labels[sample_idx]
            # Tensors hash by identity, so a scalar tensor would never match a key.
            if isinstance(label, torch.Tensor) and label.numel() == 1:
                label = label.item()
            batch.append(self.preimage.get(label))
        return batch

    def forward(self, x, labels=None, preimage_batch=None):
        """Run every perception network and evaluate the symbolic program.

        Args:
            x: a dict of per-input batches (consumed in insertion order) or a
                sequence indexed by input position.
            labels: optional program-level targets, used only to build a
                ``preimage_batch`` from ``self.preimage`` when none is given.
            preimage_batch: optional per-sample candidates forwarded to the
                black box; takes precedence over ``self.preimage``.

        Returns:
            the tuple produced by ``self.eval_formula`` extended with one final
            element: a list holding each network's argmax predictions.
        """
        if isinstance(x, dict):
            per_input = [x[key] for key in x]
        else:
            per_input = [x[i] for i in range(len(x))]
        distrs = [self.forward_fns[i](batch) for i, batch in enumerate(per_input)]

        # Most likely class per network, used for individual-accuracy tracking.
        individual_predictions = [torch.argmax(distr, dim=-1) for distr in distrs]

        if preimage_batch is None:
            preimage_batch = self._preimage_batch_for(labels)

        inputs = [self.structured_datasets[i].distrs_to_input(distrs[i], per_input[i], input_cfg)
                  for i, input_cfg in enumerate(self.config)]

        combined_result = self.eval_formula(*inputs, preimage_batch=preimage_batch)
        return combined_result + (individual_predictions,)

    def eval(self):
        """Put every perception network in eval mode; returns ``self`` like ``nn.Module.eval``."""
        for net in self.nets_dict.values():
            net.eval()
        self.training = False
        return self

    def train(self, mode: bool = True):
        """Set every perception network's mode (``nn.Module.train`` contract); returns ``self``."""
        for net in self.nets_dict.values():
            if mode:
                net.train()
            else:
                net.eval()
        self.training = mode
        return self

    def close(self):
        """Close the worker pool created in ``__init__``."""
        self.pool.close()

def get_preimage(task_name, sum_n, num_classes=10):
    """Enumerate, for every possible program output, the digit tuples producing it.

    Args:
        task_name: ``"sum"`` or ``"max"`` -- the function applied to the digits.
        sum_n: number of digits combined by the program (tuple length).
        num_classes: digits range over ``0 .. num_classes - 1`` (default 10).

    Returns:
        dict mapping each output value to the list of ``sum_n``-element digit
        lists that produce it, in ``itertools.product`` (lexicographic) order.

    Raises:
        ValueError: if ``task_name`` is not supported.
    """
    aggregators = {"sum": sum, "max": max}
    try:
        aggregate = aggregators[task_name]
    except KeyError:
        # Previously fell through and raised UnboundLocalError on `preimage`.
        raise ValueError(
            f"Unsupported task for preimage computation: {task_name!r} "
            f"(expected one of {sorted(aggregators)})") from None

    preimage = {}
    for combination in itertools.product(range(num_classes), repeat=sum_n):
        preimage.setdefault(aggregate(combination), []).append(list(combination))
    return preimage

class Trainer():
    """Optimise a ``TaskNet`` end-to-end through its black-box program.

    Tracks overall (program-output) accuracy and, when batches carry per-input
    ground-truth labels, the pooled accuracy of the individual classifiers.
    Training stops early once test accuracy exceeds ``min_convergence_acc`` and
    has not improved for ``patience`` consecutive evaluations.
    """

    def __init__(
            self,
            train_loader: torch.utils.data.DataLoader,
            test_loader: torch.utils.data.DataLoader,
            unstructured_datasets: List[unstructured_dataset.UnstructuredDataset],
            learning_rate: float,
            config: dict,
            fn: Callable,
            output_mapping: output.OutputMapping,
            sample_count: int,
            batch_size_train: int,
            caching: bool,
            task_name: str,
            use_preimage: bool = False):
        """
        Args:
            task_name: parsed as ``"<func>_<n>_<dataset>"`` (e.g. ``"sum_2_mnist"``)
                when ``use_preimage`` is set; ``<func>`` and ``<n>`` select the preimage.
            use_preimage: build a preimage with ``get_preimage`` and hand it to the network.
            The remaining arguments configure the loaders, ``TaskNet`` and Adam.
        """
        self.task_name = task_name
        self.use_preimage = use_preimage
        self.preimage = None
        if use_preimage:
            # Deliberately best-effort: a malformed task name or unsupported
            # function only disables preimage filtering, with a message.
            try:
                task_func, num_digits, _dataset = task_name.split("_")
                print(f"Initializing preimage for task: {task_func}, digits: {num_digits}")
                self.preimage = get_preimage(task_func, int(num_digits))
                if self.preimage:
                    print(f"Preimage loaded successfully with {len(self.preimage)} entries")
                else:
                    print("Warning: get_preimage returned None")
            except Exception as e:
                print(f"Error initializing preimage: {e}")
                print(f"Task name: {task_name} (expected format: 'tasktype_number_dataset')")
                self.preimage = None
        self.config = config
        self.network = TaskNet(unstructured_datasets=unstructured_datasets,
                               config=config,
                               fn=fn,
                               output_mapping=output_mapping,
                               sample_count=sample_count,
                               batch_size_train=batch_size_train,
                               caching=caching,
                               preimage=self.preimage)
        self.output_mapping = output_mapping
        # One Adam optimizer per distinct perception network.
        self.optimizers = [optim.Adam(
            net.parameters(), lr=learning_rate) for net in self.network.nets_dict.values()]
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.loss_fn = F.binary_cross_entropy
        self.best_test_score = 0.0
        self.best_test_epoch = 0
        self.epochs_without_improvement = 0
        self.patience = 5  # Number of epochs to wait before terminating
        # Early stopping is only allowed once test accuracy (percent) exceeds this.
        self.min_convergence_acc = 30.0
        self.best_individual_acc = 0.0  # Combined individual accuracy at the best epoch

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _unpack_batch(batch_data):
        """Split a collated batch into ``(data, target, individual_labels)``.

        Batches are ``(data, target)`` or ``(data, target, individual_labels)``;
        ``individual_labels`` is None for the two-element form.
        """
        if len(batch_data) == 3:
            data, target, individual_labels = batch_data
            return data, target, individual_labels
        data, target = batch_data
        return data, target, None

    @staticmethod
    def _count_correct(y, y_pred):
        """Count rows where ``y`` has positive mass and the argmaxes of ``y`` and ``y_pred`` agree."""
        has_target = torch.sum(y, dim=1) > 0
        agrees = torch.argmax(y, dim=1) == torch.argmax(y_pred, dim=1)
        return torch.sum(has_target & agrees).item()

    @staticmethod
    def _update_individual_counts(correct, total, predictions, labels, batch_size):
        """Accumulate per-classifier correct/total counts in place.

        ``correct``/``total`` are sized to ``len(predictions)`` on the first call
        (while still empty). Classifiers with index ``>= len(labels[0])`` have no
        ground truth and are skipped.
        """
        if not correct:
            correct.extend([0] * len(predictions))
            total.extend([0] * len(predictions))
        for clf_idx, clf_pred in enumerate(predictions):
            if clf_idx >= len(labels[0]):
                continue
            for sample_idx in range(batch_size):
                total[clf_idx] += 1
                true_label = labels[sample_idx][clf_idx]
                if isinstance(true_label, torch.Tensor):
                    true_label = true_label.item()
                if clf_pred[sample_idx].item() == true_label:
                    correct[clf_idx] += 1

    @staticmethod
    def _individual_accuracy(correct, total):
        """Pooled individual-classifier accuracy in percent, or None if nothing was counted."""
        counted = sum(total)
        if counted == 0:
            return None
        return 100. * sum(correct) / counted

    def _optimize(self, loss):
        """Backpropagate ``loss`` and step every optimizer."""
        for optimizer in self.optimizers:
            optimizer.zero_grad()
        loss.backward()
        for optimizer in self.optimizers:
            optimizer.step()

    def _print_summary(self, headline):
        """Print the end-of-training banner with the best scores seen."""
        print(f"\n{'='*60}")
        print(headline)
        print(f"Best test accuracy: {self.best_test_score:.2f}% (Epoch {self.best_test_epoch})")
        if self.best_individual_acc > 0:
            print(f"Best individual classifier accuracy: {self.best_individual_acc:.2f}%")
        print(f"{'='*60}")

    # ------------------------------------------------------------- plain epochs

    def train_epoch(self, epoch):
        """Run one optimisation pass over ``train_loader`` with a progress bar."""
        self.network.train()
        num_items = 0
        train_loss = 0
        total_correct = 0
        # Per-classifier counters, sized lazily on the first labelled batch.
        individual_correct = []
        individual_total = []

        progress = tqdm(self.train_loader, total=len(self.train_loader))
        for i, batch_data in enumerate(progress):
            data, target, individual_labels = self._unpack_batch(batch_data)
            has_individual_labels = individual_labels is not None

            (output_mapping, y_pred_sim, y_pred,
             individual_predictions) = self.network(data, labels=target)

            # Normalize label format
            batch_size = y_pred_sim.shape[0]
            norm_label, y = self.output_mapping.get_normalized_labels(
                y_pred_sim, target, output_mapping)

            # Loss and accuracy are skipped for batches with a falsy output mapping.
            if output_mapping:
                loss = self.loss_fn(y_pred_sim, norm_label)
                self._optimize(loss)
                if not math.isnan(loss.item()):
                    train_loss += loss.item()
                correct_count = self._count_correct(y, y_pred)
            else:
                correct_count = 0

            if has_individual_labels and individual_predictions:
                self._update_individual_counts(individual_correct, individual_total,
                                               individual_predictions, individual_labels,
                                               batch_size)

            num_items += batch_size
            total_correct += correct_count
            perc = 100. * total_correct / num_items
            avg_loss = train_loss / (i + 1)

            desc = f"[Train Epoch {epoch}] Avg loss: {avg_loss:.4f}, Overall: {total_correct}/{num_items} ({perc:.2f}%)"
            if has_individual_labels and individual_predictions and individual_total:
                individual_acc = self._individual_accuracy(individual_correct, individual_total)
                if individual_acc is not None:
                    desc += f" | Individual: {sum(individual_correct)}/{sum(individual_total)} ({individual_acc:.2f}%)"
            progress.set_description(desc)

    def test_epoch(self, epoch):
        """Evaluate on ``test_loader``, update best-score tracking, report it.

        Returns:
            True when training should stop (accuracy above ``min_convergence_acc``
            and no improvement for ``patience`` evaluations), else False.
        """
        self.network.eval()
        num_items = 0
        test_loss = 0
        total_correct = 0
        # Defaults so an empty test loader cannot raise NameError below.
        perc = 0.0
        has_individual_labels = False
        individual_predictions = []
        individual_correct = []
        individual_total = []

        with torch.no_grad():
            progress = tqdm(self.test_loader, total=len(self.test_loader))
            for i, batch_data in enumerate(progress):
                data, target, individual_labels = self._unpack_batch(batch_data)
                has_individual_labels = individual_labels is not None

                # No labels: evaluation never restricts outputs with the preimage.
                (output_mapping, y_pred_sim, y_pred,
                 individual_predictions) = self.network(data)

                # Normalize label format
                batch_size = y_pred_sim.shape[0]
                norm_label, y = self.output_mapping.get_normalized_labels(
                    y_pred_sim, target, output_mapping)

                loss = self.loss_fn(y_pred_sim, norm_label)
                if not math.isnan(loss.item()):
                    test_loss += loss.item()

                correct_count = self._count_correct(y, y_pred) if output_mapping else 0

                if has_individual_labels and individual_predictions:
                    self._update_individual_counts(individual_correct, individual_total,
                                                   individual_predictions, individual_labels,
                                                   batch_size)

                num_items += batch_size
                total_correct += correct_count
                perc = 100. * total_correct / num_items
                avg_loss = test_loss / (i + 1)

                desc = f"[Test Epoch {epoch}] Avg loss: {avg_loss:.4f}, Overall: {total_correct}/{num_items} ({perc:.2f}%)"
                if has_individual_labels and individual_predictions and individual_total:
                    individual_acc = self._individual_accuracy(individual_correct, individual_total)
                    if individual_acc is not None:
                        desc += f" | Individual: {sum(individual_correct)}/{sum(individual_total)} ({individual_acc:.2f}%)"
                progress.set_description(desc)

        current_individual_acc = 0.0
        if has_individual_labels and individual_predictions:
            acc = self._individual_accuracy(individual_correct, individual_total)
            if acc is not None:
                current_individual_acc = acc

        # Track best test score
        if perc > self.best_test_score:
            self.best_test_score = perc
            self.best_test_epoch = epoch
            self.epochs_without_improvement = 0
            self.best_individual_acc = current_individual_acc
        else:
            self.epochs_without_improvement += 1

        best_summary = f"[Test Epoch {epoch}] Current: {perc:.2f}% | Best: {self.best_test_score:.2f}% (Epoch {self.best_test_epoch}) | No improvement: {self.epochs_without_improvement}"
        if current_individual_acc > 0:
            best_summary += f" | Best Individual: {self.best_individual_acc:.2f}%"
        print(best_summary)

        # Only consider convergence once accuracy exceeds min_convergence_acc.
        if perc > self.min_convergence_acc and self.epochs_without_improvement >= self.patience:
            print(f"\n{'='*60}")
            print(f"Convergence reached! No improvement for {self.patience} epochs.")
            print(f"Final best test accuracy: {self.best_test_score:.2f}% (Epoch {self.best_test_epoch})")
            if self.best_individual_acc > 0:
                print(f"Best individual classifier accuracy: {self.best_individual_acc:.2f}%")
            print(f"{'='*60}")
            return True  # Signal to terminate training

        return False  # Continue training

    def train(self, n_epochs):
        """Alternate train/test for epochs ``1..n_epochs`` (stopping early on convergence)."""
        for epoch in range(1, n_epochs + 1):
            self.train_epoch(epoch)
            if self.test_epoch(epoch):
                break

        self._print_summary("Training completed!")
        self.network.close()

    # ------------------------------------------------------ purified training

    @staticmethod
    def _build_feature_extractor():
        """Return ``(transform, extractor)``: a frozen ImageNet ResNet-18 without its classifier head."""
        resnet_transform = transforms.Compose([
            # expand() needs dim 1 to be size 1 (or 3): repeat a single channel to RGB.
            transforms.Lambda(lambda x: x.expand(x.shape[0], 3, *x.shape[2:])),
            transforms.Resize(224),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet normalization
                                 std=[0.229, 0.224, 0.225]),
        ])
        model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        for param in model.parameters():
            param.requires_grad = False
        feature_extractor = torch.nn.Sequential(*(list(model.children())[:-1])).to(DEVICE)
        feature_extractor.eval()
        return resnet_transform, feature_extractor

    def _train_purification_epoch(self, epoch, vlm, args, resnet_transform,
                                  feature_extractor, data_names):
        """One training pass where each sample's preimage is pruned before use.

        Per batch: look up every target's preimage, embed the images with the
        frozen backbone, prune candidates with ``structural_pruning``, then train
        with the pruned candidates as ``preimage_batch``.
        """
        pruned = 0          # candidates kept after pruning this epoch
        total = 0           # candidates before pruning this epoch
        retained_GT = 0     # samples whose true label list survived pruning
        total_samples = 0
        total_correct = 0
        num_items = 0
        individual_correct = []
        individual_total = []

        self.network.train()
        progress = tqdm(self.train_loader, total=len(self.train_loader))
        for batch in progress:
            # Purified training needs batches that carry individual labels.
            data, target, labels = batch
            # Turn the per-input dict into a list ordered like the config.
            data = [data[name] for name in data_names]
            batch_size = len(target)

            candidates = [self.preimage[tgt] for tgt in target]
            total += sum(len(proofs) for proofs in candidates)

            images = [resnet_transform(img) for img in data]
            # The backbone is frozen; no autograd graph is needed for features.
            with torch.no_grad():
                features = [feature_extractor(img.to(DEVICE)).tolist() for img in images]

            # Regroup from per-input batches to per-sample lists.
            images_for_pruning = [[img[s] for img in images] for s in range(batch_size)]
            features_for_pruning = [[feat[s] for feat in features] for s in range(batch_size)]
            labels_for_pruning = [list(labels[s]) for s in range(batch_size)]

            candidates = structural_pruning(images_for_pruning, candidates, labels_for_pruning,
                                            features_for_pruning, vlm, args)
            pruned += sum(len(proofs) for proofs in candidates)

            for proofs, label in zip(candidates, labels):
                total_samples += 1
                if list(label) in proofs:
                    retained_GT += 1

            (output_mapping, y_pred_sim, y_pred,
             individual_predictions) = self.network(data, labels=target, preimage_batch=candidates)

            # Normalize label format
            batch_size = y_pred_sim.shape[0]
            norm_label, y = self.output_mapping.get_normalized_labels(
                y_pred_sim, target, output_mapping)

            # NaN marks a batch without a loss (falsy output mapping); previously
            # this path referenced an undefined or stale `loss`.
            batch_loss = float("nan")
            if output_mapping:
                loss = self.loss_fn(y_pred_sim, norm_label)
                self._optimize(loss)
                batch_loss = loss.item()
                correct_count = self._count_correct(y, y_pred)
            else:
                correct_count = 0

            if individual_predictions and labels is not None:
                self._update_individual_counts(individual_correct, individual_total,
                                               individual_predictions, labels, batch_size)

            num_items += batch_size
            total_correct += correct_count
            perc = 100. * total_correct / num_items

            desc = f"[Train Purif Epoch {epoch}] Loss: {batch_loss:.4f} Overall: {perc:.1f}%"
            if individual_predictions and individual_total:
                individual_acc = self._individual_accuracy(individual_correct, individual_total)
                if individual_acc is not None:
                    desc += f" | Individual: {individual_acc:.2f}%"
            proof_pct = 100. * pruned / total if total else 0.0
            gt_pct = 100. * retained_GT / total_samples if total_samples else 0.0
            desc += f" | Proofs: {pruned}/{total} ({proof_pct:.2f}%) | GT Retained: {retained_GT}/{total_samples} ({gt_pct:.2f}%)"
            progress.set_description(desc)

    def train_w_purification(self, vlm, n_epochs, args, warmup_epochs=None):
        """Train for up to ``n_epochs`` epochs on structurally pruned preimages.

        Evaluates once before training (reported as epoch 0), then alternates a
        purified training epoch with ``test_epoch`` for epochs ``1..n_epochs``,
        matching ``train``'s numbering, and stops early on convergence.

        Args:
            vlm: forwarded unchanged to ``structural_pruning``.
            n_epochs: maximum number of training epochs.
            args: parsed CLI namespace, forwarded to ``structural_pruning``.
            warmup_epochs: accepted for interface compatibility; unused.

        Raises:
            ValueError: if the trainer has no preimage.
        """
        if self.preimage is None:
            raise ValueError(
                "train_w_purification requires a preimage; create the Trainer with "
                "use_preimage=True and a task name like 'sum_2_mnist'")

        resnet_transform, feature_extractor = self._build_feature_extractor()
        data_names = [input_cfg[NAME] for input_cfg in self.config]

        self.converged_epoch = 0
        self.test_epoch(0)  # baseline before any purified training
        for epoch in range(1, n_epochs + 1):
            self._train_purification_epoch(epoch, vlm, args, resnet_transform,
                                           feature_extractor, data_names)
            if self.test_epoch(epoch):
                break

        self._print_summary("Purification training completed!")
        self.network.close()
            


if __name__ == "__main__":
    from argparse import ArgumentTypeError

    def _str2bool(value):
        """argparse ``type`` for booleans.

        ``type=bool`` is a trap: ``bool("False")`` is True, so ``--caching False``
        silently kept caching enabled.
        """
        lowered = str(value).strip().lower()
        if lowered in ("1", "true", "t", "yes", "y"):
            return True
        if lowered in ("0", "false", "f", "no", "n"):
            return False
        raise ArgumentTypeError(f"expected a boolean value, got {value!r}")

    # Argument parser
    parser = ArgumentParser("neuro-symbolic-dataset")
    parser.add_argument("--n-epochs", type=int, default=10)
    parser.add_argument("--n-samples", type=int, default=100)
    parser.add_argument("--num-training-samples", type=int, default=None,
                        help="Number of training samples to use. If None, uses all available samples.")
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--task", type=str, default='sum_2_mnist')
    parser.add_argument("--configuration", type=str,
                        default="configuration.json")
    parser.add_argument("--caching", type=_str2bool, default=True)
    parser.add_argument("--threaded", type=int, default=0)
    parser.add_argument("--purification", action="store_true", help="Use structural pruning for training")
    parser.add_argument("--preimage", action="store_true", help="Use preimage for normal training (without purification)")

    # learning rate
    parser.add_argument("--learning-rate", type=float, default=None)

    # Structural pruning arguments (defaults match dolphin_sum_n_algo1_structured_pruning.py)
    parser.add_argument("--mock-proximity", action="store_true", default=False,
                        help="Use mock proximity instead of real feature-based pruning")
    parser.add_argument("--structure-k", type=int, default=10,
                        help="Top-k most similar instances for structural pruning")
    parser.add_argument("--percent", type=float, default=0.001,
                        help="Percentage threshold for similarity (alternative to structure-k)")
    parser.add_argument("--model-name", type=str, default="blip2",
                        help="Model name for feature extraction (blip2, clip, blip, albef, resnet18, etc.)")

    args = parser.parse_args()

    dir_path = os.path.dirname(os.path.realpath(__file__))

    # environment init
    torch.multiprocessing.set_start_method('spawn')

    # Read json (closing the file handle once parsed)
    with open(os.path.join(dir_path, args.configuration)) as config_file:
        configuration = json.load(config_file)

    # Parameters
    n_epochs = args.n_epochs

    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Module-level on purpose: TaskNet reads LOSS_AGGREGATOR from it when present.
    task_config = configuration[args.task]

    # Initialize the train and test loaders
    batch_size_train = task_config[BATCH_SIZE_TRAIN]
    batch_size_test = task_config[BATCH_SIZE_TEST]
    train_loader, test_loader = train_test_loader(
        task_config, batch_size_train, batch_size_test, args.num_training_samples, individual_labels=True)

    # Set the output mapping
    om = output.get_output_mapping(task_config)

    # Create trainer and train
    py_func = task_config[PY_PROGRAM]
    learning_rate = task_config[LEARNING_RATE] if args.learning_rate is None else args.learning_rate
    fn = task_program.dispatcher[py_func]
    config = task_config[INPUTS]
    unstructured_datasets = [task_dataset.TaskDataset.get_unstructured_dataset(
        input_cfg, train=True) for input_cfg in task_config[INPUTS]]
    # Preimages are needed for both --preimage and --purification.
    use_preimage_flag = args.preimage or args.purification

    trainer = Trainer(train_loader=train_loader,
                      test_loader=test_loader,
                      unstructured_datasets=unstructured_datasets,
                      learning_rate=learning_rate,
                      config=config,
                      fn=fn,
                      output_mapping=om,
                      sample_count=args.n_samples,
                      batch_size_train=batch_size_train,
                      caching=args.caching,
                      task_name=args.task,
                      use_preimage=use_preimage_flag)

    # Choose training method based on arguments
    if args.purification:
        if trainer.preimage is None:
            # Purified training indexes the preimage for every target; fail fast
            # with a clear message instead of crashing inside the first batch.
            parser.error(f"--purification requires a preimage, but none could be built for task '{args.task}'")
        print(f"Starting purification training with structural pruning for {n_epochs} epochs...")
        print(f"Purification settings: mock_proximity={args.mock_proximity}, structure_k={args.structure_k}, percent={args.percent}, model_name={args.model_name}")
        vlm_model = None  # VLM model can be added here if needed
        trainer.train_w_purification(vlm_model, n_epochs, args)
    else:
        if args.preimage:
            print(f"Starting normal training with preimages for {n_epochs} epochs...")
        else:
            print(f"Starting normal training without preimages for {n_epochs} epochs...")
        trainer.train(n_epochs)
