import itertools
import string
import scallopy
import wandb
from datetime import datetime
from torchvision import transforms, models
import os
import time
import random
from typing import *

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from argparse import ArgumentParser
from tqdm import tqdm

from purification import purification
from mipll_pruning_algorithm import structural_pruning

# from torchql import Query, Table, Database

from dolphin import Distribution, PreImage
from dolphin.provenances import get_provenance

import sys 

sys.path.append("./")
# from lavis.models import load_model_and_preprocess
# from lavis.models import load_model_and_preprocess

# Install the following: 
# pip install omegaconf
# https://github.com/facebookresearch/iopath
# pip install timm
# pip install webdataset
# pip install opencv-python
# pip install transformers
# pip install fairscale
# pip install einops
# pip install spacy
# pip install pycocoevalcap

# MNIST preprocessing: convert a PIL image to a float tensor, then normalize
# the single channel with the (mean, std) pair commonly used for MNIST.
mnist_img_transform = torchvision.transforms.Compose(
  [
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
  ]
)

class PreImageFeatureDataset(torch.utils.data.Dataset):
  """Map-style dataset that zips five parallel per-sample sequences.

  Item ``idx`` is ``(samples[idx], preimages[idx], labels[idx], targets[idx],
  features[idx])``. The length is taken from ``samples`` at construction
  time; the other sequences are assumed (not checked) to be at least as long.
  """

  def __init__(
      self,
      samples,
      preimages,
      labels,
      targets,
      features,
  ):
    self.samples = samples
    self.preimages = preimages
    self.labels = labels
    self.targets = targets
    self.features = features
    self.length = len(samples)

  def __len__(self):
    return self.length

  def __getitem__(self, idx):
    columns = (self.samples, self.preimages, self.labels, self.targets, self.features)
    return tuple(column[idx] for column in columns)

  @staticmethod
  def collate_fn(batch):
    """Transpose a batch of 5-tuples into a 5-tuple of lists (no tensor stacking)."""
    columns = tuple([] for _ in range(5))
    for item in batch:
      for position, column in enumerate(columns):
        column.append(item[position])
    return columns
class MNISTSumNDataset(torch.utils.data.Dataset):
  """MNIST digits grouped into fixed-size tuples, labelled by their sum or max.

  Item ``idx`` is ``(img_1, ..., img_n, digit_labels, result)`` built from
  ``sum_n`` consecutive entries of a shuffled index permutation. ``result`` is
  the digit sum when ``dataset == "msum"`` and, for any other ``dataset``
  value, the digit maximum (floored at 0).
  """
  def __init__(
    self,
    root: str,
    sum_n: int,
    train: bool = True,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    download: bool = False,
    num_training_samples: int = -1,
    dataset: str = "msum",
  ):
    """
    Args:
      root, train, transform, target_transform, download: forwarded to
        ``torchvision.datasets.MNIST``.
      sum_n: number of images per item.
      num_training_samples: when > 0, keep only the first
        ``num_training_samples * sum_n`` shuffled image indices.
      dataset: "msum" for sum targets; any other value gives max targets.
    """
    self.task = dataset
    self.mnist_dataset = torchvision.datasets.MNIST(
      root,
      train=train,
      transform=transform,
      target_transform=target_transform,
      download=download,
    )
    self.sum_n = sum_n
    # Random permutation of image indices (uses the global `random` state);
    # each item is a consecutive window of `sum_n` entries of this list.
    permutation = list(range(len(self.mnist_dataset)))
    random.shuffle(permutation)
    if num_training_samples > 0:
      permutation = permutation[:num_training_samples * sum_n]
    self.index_map = permutation

  def __len__(self):
    # Leftover indices that cannot fill a whole tuple are dropped.
    return len(self.index_map) // self.sum_n

  def __getitem__(self, idx):
    offset = idx * self.sum_n
    pairs = [self.mnist_dataset[self.index_map[offset + k]] for k in range(self.sum_n)]
    imgs = tuple(img for img, _ in pairs)
    labels = tuple(digit for _, digit in pairs)
    if self.task == "msum":
      result = sum(labels)
    else:
      result = max((0,) + labels)
    return (*imgs, labels, result)

  @staticmethod
  def collate_fn(batch):
    """Stack each image position into one tensor; keep digit labels as a list.

    Returns ``(imgs, labels, digits)``: a tuple of stacked per-position image
    tensors, the list of per-sample digit tuples, and a long tensor of targets.
    """
    width = len(batch[0])
    imgs = tuple(torch.stack([item[pos] for item in batch]) for pos in range(width - 2))
    digits = torch.stack([torch.tensor(item[width - 1]).long() for item in batch])
    labels = [item[width - 2] for item in batch]
    return (imgs, labels, digits)
  

# Main open questions:
# - What exactly is the classifier $f$? It seems to be the MNIST classifier without the symbolic layer, but this is unconfirmed.
# - Why is D' := D in Algorithm 1? D' should start empty and be populated with data by the purification process.
# - How should the thresholds be set, and what is the purpose of the step threshold?
# - What counts as pretraining f? Just one epoch of training?
# - When training, we can use the candidate label vectors to instantiate distributions. However, the distributions will still
#   consider all possible combinations of the candidate labels. Is this correct, or should the preimage also determine which
#   combinations are possible?
# - Shouldn't eps keep increasing? Otherwise eps always stays < eps_max and the loop never terminates.

# While debugging: track how many labels are removed from the candidate label vectors, which ones remain, and whether
# a gold label vector gets pruned.

class MNISTSumNDatasetPreImage(torch.utils.data.Dataset):
  """Materialized sum/max dataset where every item also carries its candidate proofs.

  Each item is ``(img_1, ..., img_n, digit_labels, target, proofs)``, with
  ``proofs`` initialised to ``pre_image[target]``. Proofs can later be swapped
  per sample through :meth:`update_proofs_samplewise`.
  """
  def __init__(
    self,
    sum_data: MNISTSumNDataset | List[Tuple],
    pre_image: Dict[int, List[Tuple]],
  ):
    # Eagerly read every item of `sum_data` and append its target's proofs.
    self.data = []
    for position in range(len(sum_data)):
      item = sum_data[position]
      target = item[-1]
      self.data.append((*item[:-2], item[-2], target, pre_image[target]))

  def __len__(self):
    return len(self.data)

  def __getitem__(self, idx):
    return self.data[idx]

  def get_proofs(self):
    """Return the proofs list of every item, in dataset order."""
    return [entry[-1] for entry in self.data]

  def update_proofs_samplewise(self, pre_image):
    """Return a new dataset whose i-th item has its proofs replaced by ``pre_image[i]``.

    The current dataset is left untouched.
    """
    replaced = [(*entry[:-1], pre_image[i]) for i, entry in enumerate(self.data)]
    fresh = MNISTSumNDatasetPreImage([], [])
    fresh.data = replaced
    return fresh

  @staticmethod
  def collate_fn(batch):
    """Return ``(imgs, labels, digits, proofs)``, stacking each image position."""
    width = len(batch[0])
    imgs = tuple(torch.stack([item[col] for item in batch]) for col in range(width - 3))
    digits = torch.stack([torch.tensor(item[width - 2]).long() for item in batch])
    proofs = [item[width - 1] for item in batch]
    labels = [item[width - 3] for item in batch]
    return (imgs, labels, digits, proofs)


def mnist_sum_n_loader(data_dir, sum_n, batch_size_train, batch_size_test, num_training_samples=-1, dataset="msum"):
  """Create shuffled train and test DataLoaders over ``MNISTSumNDataset``.

  Only the training split is subsampled by ``num_training_samples``; the
  test split always covers the full MNIST test set. MNIST is downloaded into
  ``data_dir`` if it is missing.

  Returns:
    ``(train_loader, test_loader)``.
  """
  def build(is_train, batch_size, **extra):
    split = MNISTSumNDataset(
      data_dir,
      sum_n,
      train=is_train,
      download=True,
      transform=mnist_img_transform,
      dataset=dataset,
      **extra,
    )
    return torch.utils.data.DataLoader(
      split,
      collate_fn=MNISTSumNDataset.collate_fn,
      batch_size=batch_size,
      shuffle=True,
    )

  train_loader = build(True, batch_size_train, num_training_samples=num_training_samples)
  test_loader = build(False, batch_size_test)
  return train_loader, test_loader


class MNISTNet(nn.Module):
  """Small CNN mapping (B, 1, 28, 28) images to (B, 10) digit probabilities.

  The 28x28 input size follows from ``view(-1, 1024)``: two 5x5 convolutions
  each followed by 2x2 max-pooling leave 64 channels of 4x4.
  """
  def __init__(self):
    super().__init__()
    # Registration order (conv1, conv2, fc1, fc2) determines parameter order.
    self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
    self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
    self.fc1 = nn.Linear(1024, 1024)
    self.fc2 = nn.Linear(1024, 10)

  def forward(self, x):
    # No activation between the convolutions; only max-pooling.
    hidden = F.max_pool2d(self.conv1(x), 2)
    hidden = F.max_pool2d(self.conv2(hidden), 2)
    hidden = hidden.view(-1, 1024)
    hidden = F.dropout(F.relu(self.fc1(hidden)), p=0.5, training=self.training)
    return F.softmax(self.fc2(hidden), dim=1)


class MNISTSumNNet(nn.Module):
  """Digit classifier shared across positions plus a Scallop reasoning layer.

  Every one of the ``sum_n`` input images is classified by the same
  ``MNISTNet``; the per-position digit probabilities are fed as facts of the
  relations ``digit_1 .. digit_n`` to a Scallop program that computes either
  their sum (``dataset == "msum"``) or, for any other ``dataset`` value, their
  maximum.
  """
  def __init__(self, k, sum_n, provenance, dataset = "msum"):
    """Build the digit classifier and compile the Scallop program.

    Args:
      k: top-k parameter passed to ``scallopy.ScallopContext``.
      sum_n: number of digit images per sample.
      provenance: Scallop provenance name (the CLI default is "damp").
      dataset: "msum" builds a sum program; anything else builds a max program.
    """
    super(MNISTSumNNet, self).__init__()

    # MNIST Digit Recognition Network
    # self.mnist_net = CIFAR10Net()
    # self.mnist_net = ResNet18()
    self.mnist_net = MNISTNet()
    self.dataset = dataset

    self.scl_ctx = scallopy.ScallopContext(provenance=provenance, k=k)
    # For each position x: declare relation `digit_x` over digits 0..9 and
    # create a rule variable named by one random letter plus x (so names are
    # unique per position). NOTE: this draws from the global `random` state.
    self.relations = [] 
    self.variables = []
    for x in range(1, sum_n+1):
         self.scl_ctx.add_relation(f'digit_{x}', int, input_mapping=list(range(10)))
         a =  f'{random.choice(string.ascii_letters)}{x}'
         self.variables += [a]
         self.relations += [f'digit_{x}({a})']
    
    if self.dataset == "msum":
      # Builds e.g. `sum_2(q1+Z2) = digit_1(q1), digit_2(Z2)` for sum_n=2.
      self.scl_ctx.add_rule(f'sum_{sum_n}({"+".join(self.variables)}) = {", ".join(self.relations)}')
      # The `sum_n` logical reasoning module
      self.scallop_prg = self.scl_ctx.forward_function(f'sum_{sum_n}', output_mapping=[(i,) for i in range((sum_n*9)+1)])
    else:
      print("Max")
      # Gather (position, digit) pairs, then aggregate the maximum digit.
      self.scl_ctx.add_relation('digits', (int, int))
      for i, (var, rel) in enumerate(zip(self.variables, self.relations)):
        self.scl_ctx.add_rule(f'digits({i}, {var}) = {rel}')
      self.scl_ctx.add_rule("maximum(x) = x := max(y: digits(_, y))")
      # NOTE(review): the output mapping reuses the sum range 0..sum_n*9 even
      # though max can only yield 0..9; the extra classes presumably get zero
      # probability. This keeps the width equal to bce_loss's num_classes.
      self.scallop_prg = self.scl_ctx.forward_function(f'maximum', output_mapping=[(i,) for i in range((sum_n*9)+1)])

  def forward(self, x: Tuple[torch.Tensor, torch.Tensor], labels=None, pre_image = None):
    """Classify every digit position and run the Scallop program on the results.

    Args:
      x: sequence of image batches, one per digit position (the annotation
        says a 2-tuple, but the code handles any length via ``len(x)``).
      labels: optional per-sample sequences of gold digits; when given, the
        count of correctly classified digits is returned as well.
      pre_image: optional per-sample list of candidate label vectors
        ("proofs"); when given, position i only receives facts for digits that
        occur at index i of at least one of that sample's proofs.

    Returns:
      The Scallop program output, or ``(output, correct)`` when ``labels`` is
      given. ``correct`` accumulates a tensor in the ``pre_image is None``
      branch and a Python int in the ``pre_image`` branch.
    """
    # Legacy Dolphin (Distribution-based) implementation, kept for reference
    # only; none of the commented code below is executed.
    # if labels is not None:
    #   labels = torch.tensor(labels).to(x[0].device)
    #   correct = 0
    # for i in range(len(x)):
    #   if PreImage.record_preimage:
    #     probs = torch.rand((1, 10), device="cpu")
    #     print("Processing preimage for sums up to", i+1, "digits")
    #   else:
    #     probs = self.mnist_net(x[i])
    #     if labels is not None:
    #       preds = probs.data.max(1, keepdim=True)[1]
    #       correct += preds.eq(labels[:,i].data.view_as(preds)).sum()

    #   if pre_image is None:
    #     digits = range(10)
    #     a = Distribution(probs, digits)
    #     if i == 0:
    #       res = a
    #     else:
    #       res = res + a
    #   else:
    #     batch_size = len(pre_image)
    #     # print(pre_image)
    #     a = []
        
    #     for idx in range(batch_size):
    #       label_vector = pre_image[idx]
    #       # print(labels[idx].tolist())
    #       # exit()
    #       # print(label_vector)
    #       # print(labels[idx])
    #       digits = list(set(( label[i] for label in label_vector )))
    #       # exit()
    #       digits.sort()
    #       # print(digits)
    #       a.append(Distribution(probs[idx, :].view(1, -1), range(10)).map_symbols(digits))
        
    #     a = Distribution.stack(a)
    #     if i == 0:
    #         res = a
    #     else:
    #         res = res + a

    # if PreImage.record_preimage:
    #   return res.preimage
    # if labels is None:
    #   return res.map_symbols(range(len(x) * 9 + 1)).get_probabilities() # Tensor b x 19
    # else:
    #   return res.map_symbols(range(len(x) * 9 + 1)).get_probabilities(), correct

    correct = 0
    parameters = {}
    if pre_image is None:
      # Unrestricted: each relation digit_{i+1} receives the full 10-way
      # probability tensor from the classifier.
      # correct = -1
      for i in range(len(x)):
        parameters[f"digit_{i+1}"] = self.mnist_net(x[i])
        if labels is not None:
          # Compare the per-digit argmax (moved to CPU) with the gold digit.
          label_batch = torch.tensor([ label[i] for label in labels ])
          prediction = parameters[f"digit_{i+1}"].data.max(1, keepdim=True)[1].to("cpu")
          correct += prediction.eq(label_batch.data.view_as(prediction)).sum()
    else:
      parameters = {}
      # Classify every position once; facts are then selected per sample.
      predictions = [ self.mnist_net(x[i]) for i in range(len(x)) ]
      for bidx in range(len(x[0])):
        proofs = pre_image[bidx]
        for i in range(len(x)):
          if labels is not None:
            label = labels[bidx][i]
            prediction = predictions[i][bidx].data.max(0, keepdim=True)[1].item()
            if prediction == label:
              correct += 1

          # One (probability, (digit,)) fact per distinct digit found at index
          # i of any proof, in first-occurrence order. Positions are restricted
          # independently, so the Scallop rule still combines every retained
          # digit across positions, not only the listed proofs.
          digit_probs = []
          added_digits = set()
          for proof in proofs:
            # proof: digit_1, digit_2, digit_3
            digit = proof[i]
            if digit not in added_digits:
              probability = predictions[i][bidx, digit]
              tup = (probability, (digit, ))
              digit_probs.append(tup)
              added_digits.add(digit)
          # parameters[f"digit_{i+1}"] ends up as a list with one fact list per sample.
          if f"digit_{i+1}" not in parameters:
            parameters[f"digit_{i+1}"] = []
          parameters[f"digit_{i+1}"].append(digit_probs)
      # print(parameters['digit_1'], len(parameters["digit_1"]), len(parameters["digit_1"][0]))
      # exit()

    if labels is not None:
      return self.scallop_prg(**parameters), correct
    return self.scallop_prg(**parameters)


def bce_loss(output, ground_truth, dataset="msum"):
  """Binary cross-entropy between predicted class probabilities and one-hot targets.

  Args:
    output: ``(batch, num_classes)`` tensor of probabilities in [0, 1].
    ground_truth: integer class index per sample (tensor or sequence of ints).
    dataset: unused; accepted so every loss function shares one signature.

  Returns:
    Scalar mean BCE over all ``batch * num_classes`` entries.
  """
  # Take the class count from the prediction itself. The previous code read the
  # script-level global `sum_n`, which only exists when this file runs as
  # __main__. The model's output width is (sum_n * 9) + 1 either way.
  num_classes = output.shape[-1]
  # as_tensor avoids the copy (and warning) that torch.tensor(tensor) causes.
  ground_truth = torch.as_tensor(ground_truth, dtype=torch.long, device=output.device)
  gt = F.one_hot(ground_truth, num_classes=num_classes).to(output.dtype)
  return F.binary_cross_entropy(output, gt)


def nll_loss(output, ground_truth, dataset="msum"):
  """Negative log-likelihood of the target class under predicted probabilities.

  The model emits probabilities (``bce_loss`` feeds the same output straight
  into ``binary_cross_entropy``), while ``F.nll_loss`` expects
  log-probabilities. The probabilities are therefore logged first. They are
  clamped away from 0 so an impossible target yields a large finite loss
  instead of ``inf``/``nan`` gradients.

  Args:
    output: ``(batch, num_classes)`` tensor of probabilities.
    ground_truth: integer class index per sample (tensor or sequence of ints);
      moved to ``output``'s device.
    dataset: unused; accepted so every loss function shares one signature.

  Returns:
    Scalar mean negative log-likelihood.
  """
  ground_truth = torch.as_tensor(ground_truth, dtype=torch.long, device=output.device)
  log_probs = torch.log(output.clamp_min(1e-12))
  return F.nll_loss(log_probs, ground_truth)


class Trainer():
  """Drive training and evaluation of ``MNISTSumNNet``.

  Two training modes are provided:

  * :meth:`train` -- standard training on the loaders given at construction.
  * :meth:`train_w_purification` -- each sample's candidate label vectors
    (its preimage under sum/max) are optionally pruned with
    ``structural_pruning`` before the forward pass.

  Test metrics are logged to wandb. The digit classifier is saved to
  ``model_dir`` whenever test accuracy improves.
  """

  def __init__(self, train_loader, test_loader, model_dir, learning_rate, loss, provenance, device, k, sum_n, args):
    """
    Args:
      train_loader, test_loader: DataLoaders yielding ``(imgs, labels, target)``.
      model_dir: directory where the best classifier checkpoint is written.
      learning_rate: Adam learning rate.
      loss: "nll" or "bce".
      provenance, k: forwarded to the Scallop context.
      device: torch device for the network.
      sum_n: number of digits per sample.
      args: parsed CLI namespace. ``dataset``, ``with_preimage`` and
        ``with_purification`` are read here; the whole namespace is kept for
        ``structural_pruning``.

    Raises:
      Exception: if ``loss`` is not a known loss name.
    """
    self.device = device
    self.args = args
    self.dataset = args.dataset
    self.model_dir = model_dir
    self.network = MNISTSumNNet(k, sum_n, provenance, dataset=self.dataset).to(self.device)
    self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
    self.train_loader = train_loader
    self.test_loader = test_loader
    self.best_loss = 10000000000
    self.provenance = provenance
    self.sum_n = sum_n
    self.best_acc = 0
    self.best_cls_acc = 0
    # Epoch with the best test accuracy so far; drives early stopping.
    # Initialised here so `train` cannot hit AttributeError when the
    # pre-training evaluation does not beat 0% accuracy.
    self.converged_epoch = 0
    self.with_preimage = args.with_preimage
    self.with_purification = args.with_purification
    if loss == "nll":
      self.loss = nll_loss
    elif loss == "bce":
      self.loss = bce_loss
    else:
      raise Exception(f"Unknown loss function `{loss}`")

  def train_epoch(self, epoch, train_loader=None):
    """Run one optimisation pass over ``train_loader`` (default: ``self.train_loader``).

    Batches are either ``(imgs, labels, target)`` or
    ``(imgs, labels, target, proofs)``. Per-batch proofs are passed to the
    network when present. Otherwise every sample gets the full candidate set
    (all ``10 ** sum_n`` digit assignments).
    """
    if train_loader is None:
      train_loader = self.train_loader
    # Build the exhaustive candidate list once and share it across the batch.
    # The network only reads it, so per-sample copies just waste memory.
    all_assignments = list(itertools.product(range(10), repeat=self.sum_n))
    default_preimage = [all_assignments] * train_loader.batch_size
    self.network.train()
    progress = tqdm(train_loader, total=len(train_loader))
    t_begin_epoch = time.time()
    for batch in progress:
      if len(batch) == 3:
        data, labels, target = batch
        preimage = default_preimage
      else:
        data, labels, target, preimage = batch
      imgs = tuple(data[x].to(self.device) for x in range(self.sum_n))
      target = target.to(self.device)
      self.optimizer.zero_grad()

      output = self.network(imgs, pre_image=preimage)
      loss = self.loss(output, target, dataset=self.dataset)
      loss.backward()
      self.optimizer.step()
      progress.set_description(f"[Train Epoch {epoch}] Loss: {loss.item():.4f}")

    total_epoch_time = time.time() - t_begin_epoch
    wandb.log(
      {
        "epoch": epoch,
        "total_epoch_time": total_epoch_time,
      }
    )
    print(f"Total Epoch Time: {total_epoch_time}")

  def _log_high_accuracy_epoch(self, epoch):
    """Append a timestamped (sum_n, epoch) record to this provenance's accuracy log."""
    file_path = f'torch_mnist_sum_n_{self.provenance}_epoch_count.log'
    formatted_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    # Append mode creates the file when it does not exist yet.
    with open(file_path, 'a') as file:
      file.write(f'sum n={self.sum_n}, epoch num={epoch}, {formatted_timestamp}\n')

  def test_epoch(self, epoch):
    """Evaluate on the test set, log metrics, and checkpoint on a new best accuracy.

    Task accuracy compares the argmax of the network output with the target.
    Classifier accuracy counts per-digit argmax matches over all ``sum_n``
    positions.
    """
    self.network.eval()
    num_items = len(self.test_loader.dataset)
    test_loss = 0
    correct = 0
    correct_classified = 0
    perc = 0.0
    correct_perc = 0.0
    with torch.no_grad():
      progress = tqdm(self.test_loader, total=len(self.test_loader))
      for (data, labels, target) in progress:
        # Use the instance's sum_n, not a script-level global.
        imgs = tuple(data[x].to(self.device) for x in range(self.sum_n))
        target = target.to(self.device)

        output, correct_ind = self.network(imgs, labels=labels)
        test_loss += self.loss(output, target, dataset=self.dataset).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum().item()
        correct_classified += int(correct_ind)
        correct_perc = 100. * correct_classified / (num_items * self.sum_n)
        perc = 100. * correct / num_items
        progress.set_description(f"[Test Epoch {epoch}] Total loss: {test_loss:.4f}, Accuracy: {correct}/{num_items} ({perc:.2f}%), Classifier Accuracy: {correct_classified}/{num_items * self.sum_n} ({correct_perc:.2f}%)")
      # Record the sum_n/epoch combination once per epoch when accuracy is
      # high. The previous code wrote one line per remaining batch once the
      # running accuracy crossed the threshold.
      if perc > 97.00:
        self._log_high_accuracy_epoch(epoch)
      wandb.log(
        {
          "epoch": epoch,
          "accuracy": perc,
          "per_class_accuracy": correct_perc,
        }
      )
      if test_loss < self.best_loss:
        self.best_loss = test_loss
      if perc > self.best_acc:
        self.converged_epoch = epoch
        self.best_acc = perc
        self.best_cls_acc = correct_perc
        checkpoint_path = os.path.join(self.model_dir, f'mnist_sum_n_{self.sum_n}_best.pt')
        print(f"Saving best model with accuracy {self.best_acc:.2f}% at {checkpoint_path}")
        torch.save(self.network.mnist_net, checkpoint_path)
      print(f"Best loss: {self.best_loss:.4f}")
      print(f"Best acc: {self.best_acc:.2f}%")
      print(f"Best classifier acc: {self.best_cls_acc:.2f}%")

  def _should_stop(self, epoch, n_epochs):
    """Return True once ``epoch`` finished epochs reach ``n_epochs`` or training has plateaued.

    Plateau means more than 10 epochs have passed since the best test accuracy.
    """
    if epoch >= n_epochs:
      return True
    print(f"Converged epoch: {self.converged_epoch}, current epoch: {epoch}")
    if epoch > self.converged_epoch + 10:
      print(f"Converged at epoch {self.converged_epoch}, stopping training")
      return True
    return False

  def train(self, n_epochs):
    """Train for at most ``n_epochs`` epochs with early stopping.

    The model is evaluated once before training (reported as epoch 0) and
    after every training epoch. The previous loop condition (``epoch >
    n_epochs``) ran ``n_epochs + 1`` training epochs.
    """
    self.test_epoch(0)
    for epoch in range(n_epochs):
      self.train_epoch(epoch)
      self.test_epoch(epoch)
      if self._should_stop(epoch + 1, n_epochs):
        break

  def get_preimage(self):
    """Map every possible task output to the digit assignments producing it.

    Returns:
      dict from sum (``"msum"``) or max (``"mmax"``) to a list of
      ``sum_n``-long digit lists, in ``itertools.product`` order.

    Raises:
      ValueError: for an unknown ``self.dataset``.
    """
    combos = itertools.product(range(10), repeat=self.sum_n)
    if self.dataset == "msum":
      reduce_fn = sum
      combos = tqdm(combos)
    elif self.dataset == "mmax":
      reduce_fn = max
    else:
      raise ValueError(f"Unknown dataset `{self.dataset}`; expected 'msum' or 'mmax'")
    preimage = {}
    for combo in combos:
      preimage.setdefault(reduce_fn(combo), []).append(list(combo))
    return preimage

  def _build_feature_extractor(self, pretrained_model):
    """Return a frozen ``(feature_extractor, input_transform)`` pair in eval mode on self.device.

    A falsy ``pretrained_model`` (None/False) selects an ImageNet ResNet-18
    without its classification head. Otherwise ``pretrained_model`` is a path
    to a saved MNIST classifier whose ``fc2`` layer is replaced by identity.
    """
    if not pretrained_model:
      transform = transforms.Compose([
        # Replicate the single grayscale channel to the 3 channels ResNet expects.
        transforms.Lambda(lambda x: x.expand(x.shape[0], 3, *x.shape[2:])),
        transforms.Resize(224),  # Resize to 224x224
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet normalization
                             std=[0.229, 0.224, 0.225]),
      ])
      model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
      # Drop the final fc layer; keep everything up to global pooling.
      extractor = torch.nn.Sequential(*(list(model.children())[:-1]))
    else:
      transform = lambda x: x
      # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
      model = torch.load(pretrained_model, weights_only=False)
      model.fc2 = nn.Identity()
      extractor = model
    for p in model.parameters():
      p.requires_grad = False
    extractor = extractor.to(self.device)
    extractor.eval()
    return extractor, transform

  def _purification_epoch(self, epoch, vlm, preimage, feature_extractor, feature_transform, preimage_only):
    """Run one training pass in which each sample's candidate proofs are optionally pruned.

    Returns:
      The summed training loss over the epoch.
    """
    pruned = 0
    total = 0
    retained_gt = 0
    total_samples = 0
    total_loss = 0
    self.network.train()
    progress = tqdm(self.train_loader, total=len(self.train_loader), desc="Training")
    for imgs_by_pos, labels, targets in progress:
      batch_size = len(labels)
      # Start from the full preimage of each sample's target.
      p = [preimage[tgt.item()] for tgt in targets]
      total += sum(len(proofs) for proofs in p)

      # Features for pruning: one extractor call per digit position.
      images = [feature_transform(img) for img in imgs_by_pos]
      with torch.no_grad():
        features = [feature_extractor(img.to(self.device)).tolist() for img in images]

      # Regroup per-position batches into per-sample lists.
      images_for_pruning = [[img[i] for img in images] for i in range(batch_size)]
      features_for_pruning = [[feature[i] for feature in features] for i in range(batch_size)]
      labels_for_pruning = [list(labels[i]) for i in range(batch_size)]

      if not preimage_only:
        p = structural_pruning(images_for_pruning, p, labels_for_pruning, features_for_pruning, vlm, self.args)
      pruned += sum(len(proofs) for proofs in p)

      # Diagnostics: how often the gold label vector survives pruning.
      for proofs, label in zip(p, labels):
        total_samples += 1
        if list(label) in proofs:
          retained_gt += 1

      self.optimizer.zero_grad()
      logits = self.network([d.to(self.device) for d in imgs_by_pos], pre_image=p)
      loss = self.loss(logits, targets.to(self.device), dataset=self.dataset)
      total_loss += loss.item()
      loss.backward()
      self.optimizer.step()
      progress.set_description(f"[Train P=None Epoch {epoch}] Loss: {loss.item():.4f} Retained Proofs: {pruned}/{total} ({100. * pruned / total:.2f}%) Retained GT: {retained_gt}/{total_samples} ({100. * retained_gt / total_samples:.2f}%)")
    return total_loss

  def train_w_purification(self, vlm, n_epochs, warmup_epochs=1, pretrained_model=False, preimage_only=False):
    """Train on ``self.train_loader``, pruning candidate label vectors every batch.

    Args:
      vlm: model forwarded to ``structural_pruning`` (may be None).
      n_epochs: maximum number of training epochs.
      warmup_epochs: currently unused.
      pretrained_model: path to a saved MNIST classifier to use as feature
        extractor; falsy (the default) selects an ImageNet ResNet-18.
      preimage_only: skip pruning and train on each target's full preimage.
    """
    feature_extractor, feature_transform = self._build_feature_extractor(pretrained_model)
    preimage = self.get_preimage()
    self.converged_epoch = 0
    for epoch in range(n_epochs):
      self._purification_epoch(epoch, vlm, preimage, feature_extractor, feature_transform, preimage_only)
      self.test_epoch(epoch)
      if self._should_stop(epoch + 1, n_epochs):
        break
    

if __name__ == "__main__":
  # Argument parser
  parser = ArgumentParser()
  parser.add_argument("--sum-n", type=int, default=5)
  parser.add_argument("--n-epochs", type=int, default=5)
  parser.add_argument("--batch-size-train", type=int, default=64)
  parser.add_argument("--batch-size-test", type=int, default=64)
  parser.add_argument("--learning-rate", type=float, default=0.001)
  parser.add_argument("--loss-fn", type=str, default="bce")
  parser.add_argument("--seed", type=int, default=1234)
  parser.add_argument("--provenance", type=str, default="damp") #, choices=['damp', 'dmmp', 'dtkp-am'])
  parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"])
  parser.add_argument("--topk", type=int, default=1)
  parser.add_argument("--only_preimage", action="store_true")
  parser.add_argument("--with_preimage", action="store_true")
  parser.add_argument("--with_purification", action="store_true")
  parser.add_argument("--num-training-samples", type=int, default=-1)
  parser.add_argument('--model_name', default='blip2', type=str) #choices=['clip', 'blip', 'blip2', 'albef']
  parser.add_argument('--model_type', default='pretrain', type=str)
  parser.add_argument('--structure-k', help='knn', default=10, type=int)
  parser.add_argument('--percent', help='knn', default=0.001, type=float)
  parser.add_argument("--mock_proximity", default = False, action="store_true")
  parser.add_argument("--gpu", default = 0, type=int)
  parser.add_argument("--dataset", default = "msum", type=str, choices=["msum", "mmax"])
  parser.add_argument("--pruning_model", default = None, type=str)
  parser.add_argument("--model-dir", default=None, type=str)
  args = parser.parse_args()

  print(args)

  # Parameters. These are module-level globals and may be read as such by other
  # code in this file (e.g. `sum_n`, `args`, `train_loader`); keep the names.
  sum_n = args.sum_n
  n_epochs = args.n_epochs
  batch_size_train = args.batch_size_train
  batch_size_test = args.batch_size_test
  learning_rate = args.learning_rate
  loss_fn = args.loss_fn
  provenance = args.provenance
  torch.manual_seed(args.seed)
  random.seed(args.seed)

  # Record everything that distinguishes runs (task, loss, lr, subset size),
  # so msum and mmax runs are not indistinguishable in wandb.
  config = {
    "sum_n": sum_n,
    "n_epochs": n_epochs,
    "batch_size_train": batch_size_train,
    "batch_size_test": batch_size_test,
    "learning_rate": learning_rate,
    "loss_fn": loss_fn,
    "provenance": provenance,
    "seed": args.seed,
    "dataset": args.dataset,
    "num_training_samples": args.num_training_samples,
    "experiment_type": "dolphin",
    "with_preimage": args.with_preimage,
    "with_purification": args.with_purification,
    "only_preimage": args.only_preimage,
    "pruning_model": args.pruning_model,
  }

  # TODO Try the same model with VQAR -- pretrained AlexNet?
  # Feature dimensions of candidate pre-trained extractors (currently unused).
  feature_dim_map = {'blip2': 768, 'clip': 512, 'blip': 768, 'albef': 768, 'resnet18_s': 512, 'resnet18_c': 512, 'resnet18_i': 512}
  # VLM loading is disabled, so structural pruning receives vlm=None.
  # vlm, vis_processors, txt_processors = load_model_and_preprocess(name=args.model_name+'_feature_extractor', model_type=args.model_type, is_eval=True, device='cpu')
  vlm = None

  # Data
  data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data"))
  model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../model/mnist_sum_{sum_n}')) if args.model_dir is None else args.model_dir
  os.makedirs(model_dir, exist_ok=True)

  # Dataloaders
  train_loader, test_loader = mnist_sum_n_loader(data_dir, sum_n, batch_size_train, batch_size_test, num_training_samples=args.num_training_samples, dataset=args.dataset)
  print(args.device, torch.cuda.is_available())
  if args.device == "cuda" and torch.cuda.is_available():
    print("Using CUDA")
    device_name = "cuda"
  elif args.device == "mps" and torch.backends.mps.is_available():
    device_name = "mps"
  else:
    device_name = "cpu"

  device = torch.device(device_name)
  timestamp = datetime.now()
  run_id = f'dolphin_sum{sum_n}_{args.seed}_{provenance}_{timestamp.strftime("%Y-%m-%d %H-%M-%S")}'

  wandb.init(
    project="PURIFICATION MNIST", config=config, id=run_id
  )

  # Create trainer and train
  trainer = Trainer(train_loader, test_loader, model_dir, learning_rate, loss_fn, provenance, device, args.topk, sum_n, args)
  if args.with_purification:
    trainer.train_w_purification(vlm, n_epochs, warmup_epochs=1, pretrained_model=args.pruning_model, preimage_only=args.only_preimage)
  else:
    trainer.train(n_epochs)