import itertools
import wandb
from datetime import datetime
import os

from torchvision import transforms

from purification import purification
import time
import random
from typing import *

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from argparse import ArgumentParser
from tqdm import tqdm

# from torchql import Query, Table, Database

from dolphin import Distribution, PreImage
from dolphin.provenances import get_provenance

# Per-channel normalization statistics shared by the train and test pipelines.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: random 32x32 crop (after 4-pixel padding) and random
# horizontal flip, followed by tensor conversion and normalization.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

# Evaluation pipeline: tensor conversion and normalization only.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)

class MNISTSumNDataset(torch.utils.data.Dataset):
  """Dataset that groups ``sum_n`` images into one sum-/max-of-labels sample.

  Despite the MNIST naming, the underlying images come from
  ``torchvision.datasets.CIFAR10``. Each item is the tuple
  ``(img_1, ..., img_n, labels, target)`` where ``labels`` is the tuple of the
  ``n`` individual class labels and ``target`` is their sum (``"msum"``) or
  their maximum (``"mmax"``).
  """

  def __init__(
    self,
    root: str,
    sum_n: int,
    train: bool = True,
    transform: Optional[Callable] = None,
    target_transform: Optional[Callable] = None,
    download: bool = False,
    num_training_samples: int = -1,
    dataset: str = "msum",
  ):
    """Load the image dataset and build a shuffled index over it.

    Args:
      root: Directory handed to ``torchvision.datasets.CIFAR10``.
      sum_n: Number of images combined into a single sample.
      train: Selects the train or the test split.
      transform: Image transform. When ``None`` (the default) the
        module-level ``transform_train`` / ``transform_test`` is used,
        depending on ``train``.
      target_transform: Forwarded to the underlying dataset.
      download: Forwarded to the underlying dataset.
      num_training_samples: When positive, keep only this many grouped
        samples (``num_training_samples * sum_n`` images).
      dataset: ``"msum"`` for the sum task or ``"mmax"`` for the max task.
    """
    # Previously the `transform` argument was accepted but silently ignored;
    # it is now honored, with the old behavior kept as the default.
    if transform is None:
      transform = transform_train if train else transform_test
    self.mnist_dataset = torchvision.datasets.CIFAR10(
      root,
      train=train,
      transform=transform,
      target_transform=target_transform,
      download=download,
    )
    self.sum_n = sum_n
    # Random assignment of images to groups (uses the global `random` state).
    self.index_map = list(range(len(self.mnist_dataset)))
    random.shuffle(self.index_map)
    if num_training_samples > 0:
      self.index_map = self.index_map[:num_training_samples * self.sum_n]
    self.dataset = dataset

  def __len__(self):
    """Return the number of complete groups of ``sum_n`` images."""
    return len(self.index_map) // self.sum_n

  def __getitem__(self, idx):
    """Return ``(img_1, ..., img_n, labels, target)`` for group ``idx``.

    Raises:
      ValueError: if ``self.dataset`` is neither ``"msum"`` nor ``"mmax"``.
    """
    if self.dataset not in ("msum", "mmax"):
      # An explicit error instead of `assert`, which is stripped under -O.
      raise ValueError(f"Unknown dataset `{self.dataset}`")

    first = idx * self.sum_n
    samples = [self.mnist_dataset[self.index_map[first + i]] for i in range(self.sum_n)]
    imgs = tuple(img for img, _ in samples)
    labels = tuple(digit for _, digit in samples)

    if self.dataset == "msum":
      target = sum(labels)
    else:
      # The running maximum starts at 0, as in the original accumulation loop.
      target = max((0, *labels))
    return (*imgs, labels, target)

  @staticmethod
  def collate_fn(batch):
    """Collate items into ``(imgs, targets, labels)``.

    ``imgs`` is a tuple with one stacked tensor per image position,
    ``targets`` is a long tensor with one entry per item, and ``labels`` is
    the list of per-item label tuples (left un-stacked).
    """
    n_imgs = len(batch[0]) - 2
    imgs = tuple(torch.stack([item[pos] for item in batch]) for pos in range(n_imgs))
    digits = torch.stack([torch.tensor(item[n_imgs + 1]).long() for item in batch])
    labels = [item[n_imgs] for item in batch]
    return (imgs, digits, labels)
  

# Open questions:
# - What exactly is the classifier $f$? It seems to be the MNIST classifier without the symbolic layer, but this is not certain.
# - Why is D' := D in Algo 1? D' should be empty and get populated with data via the purification process...
# - How should the thresholds be set? And what is the purpose of the step threshold?
# - What counts as pretraining f? Just one epoch of training?
# - When training, we can use the candidate label vectors to instantiate distributions. However, the distributions will
#   still consider all possible combinations of candidate label vectors. Is this correct, or does the preimage also
#   determine the possible combinations?
# - Shouldn't eps keep increasing? Otherwise eps will always be < eps_max and the loop will never terminate.
#
# While debugging: trace the number of labels that are removed from the candidate label vectors, which ones stay, and
# whether we prune out a gold label vector.

class MNISTSumNDatasetPreImage(torch.utils.data.Dataset):
  """Materialized copy of a sum-n dataset with candidate proofs per sample.

  Every stored item is ``(img_1, ..., img_n, labels, target, proofs)``, where
  ``proofs`` is the pre-image entry looked up by the sample's ``target``.
  """

  def __init__(
    self,
    sum_data: MNISTSumNDataset | List[Tuple],
    pre_image: Dict[int, List[Tuple]],
  ):
    """Read every item of ``sum_data`` once and attach ``pre_image[target]``."""
    self.data = []
    for position in range(len(sum_data)):
      *imgs, labels, target = sum_data[position]
      self.data.append((*imgs, labels, target, pre_image[target]))

  def __len__(self):
    """Return the number of stored samples."""
    return len(self.data)

  def __getitem__(self, idx):
    """Return the stored tuple at position ``idx``."""
    return self.data[idx]

  def get_proofs(self):
    """Return the proofs entry (last field) of every stored sample, in order."""
    return [entry[-1] for entry in self.data]

  def update_proofs_samplewise(self, pre_image):
    """Return a new dataset whose i-th sample carries ``pre_image[i]`` as proofs.

    ``self`` is left untouched; all other fields of each sample are reused.
    """
    refreshed = [
      (*entry[:-1], pre_image[position])
      for position, entry in enumerate(self.data)
    ]
    # Build an empty instance, then install the refreshed samples directly.
    clone = MNISTSumNDatasetPreImage([], [])
    clone.data = refreshed
    return clone

  @staticmethod
  def collate_fn(batch):
    """Collate items into ``(imgs, targets, labels, proofs)``.

    ``imgs`` is a tuple with one stacked tensor per image position and
    ``targets`` a long tensor; ``labels`` and ``proofs`` stay plain lists with
    one entry per item.
    """
    width = len(batch[0])
    imgs = tuple(
      torch.stack([item[col] for item in batch]) for col in range(width - 3)
    )
    digits = torch.stack([torch.tensor(item[width - 2]).long() for item in batch])
    labels = [item[width - 3] for item in batch]
    proofs = [item[width - 1] for item in batch]
    return (imgs, digits, labels, proofs)


def mnist_sum_n_loader(data_dir, sum_n, batch_size_train, batch_size_test, num_training_samples=-1, dataset="msum"):
  """Build the train and test ``DataLoader`` pair for the sum-n task.

  Both splits are created with ``download=True`` and both loaders shuffle.
  ``num_training_samples`` is only applied to the training split.

  Returns:
    ``(train_loader, test_loader)``.
  """
  def _shuffled_loader(split, batch_size):
    # Both loaders share the dataset's collate function and shuffle their data.
    return torch.utils.data.DataLoader(
      split,
      collate_fn=MNISTSumNDataset.collate_fn,
      batch_size=batch_size,
      shuffle=True,
    )

  train_split = MNISTSumNDataset(
    data_dir,
    sum_n,
    train=True,
    download=True,
    num_training_samples=num_training_samples,
    dataset=dataset,
  )
  train_loader = _shuffled_loader(train_split, batch_size_train)

  test_split = MNISTSumNDataset(
    data_dir,
    sum_n,
    train=False,
    download=True,
    dataset=dataset,
  )
  test_loader = _shuffled_loader(test_split, batch_size_test)

  return train_loader, test_loader


class MNISTNet(nn.Module):
  """Two-convolution classifier for single-channel images with 10 outputs.

  The flattened feature size of 1024 corresponds to 64 channels of 4x4, which
  is what a 28x28 input yields after two (5x5 conv, 2x2 max-pool) stages.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
    self.fc1 = nn.Linear(in_features=1024, out_features=1024)
    self.fc2 = nn.Linear(in_features=1024, out_features=10)

  def forward(self, x):
    """Return class probabilities (softmax over dim 1) for the batch ``x``."""
    features = F.max_pool2d(self.conv1(x), 2)
    features = F.max_pool2d(self.conv2(features), 2)
    flat = features.view(-1, 1024)
    hidden = F.relu(self.fc1(flat))
    # Dropout is only active while the module is in training mode.
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    logits = self.fc2(hidden)
    return F.softmax(logits, dim=1)
  
class CIFAR10Net(nn.Module):
  """Plain convolutional classifier for 3-channel images with 10 outputs.

  Three stages of two (3x3 conv + ReLU) layers, each stage ending in a 2x2
  max-pool, followed by a three-layer fully connected head. The head expects
  256 x 4 x 4 features, i.e. a 32x32 input.
  """

  def __init__(self):
    super().__init__()

    def conv_relu(c_in, c_out):
      # 3x3 convolution that preserves the spatial size, plus its activation.
      return [nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1), nn.ReLU()]

    # (in_channels, out_channels, pool_afterwards) for every conv layer.
    conv_plan = [
      (3, 32, False),
      (32, 64, True),     # output: 64 x 16 x 16
      (64, 128, False),
      (128, 128, True),   # output: 128 x 8 x 8
      (128, 256, False),
      (256, 256, True),   # output: 256 x 4 x 4
    ]
    layers = []
    for c_in, c_out, pool_afterwards in conv_plan:
      layers.extend(conv_relu(c_in, c_out))
      if pool_afterwards:
        layers.append(nn.MaxPool2d(2, 2))

    layers.extend([
      nn.Flatten(),
      nn.Linear(256*4*4, 1024),
      nn.ReLU(),
      nn.Linear(1024, 512),
      nn.ReLU(),
      nn.Linear(512, 10),
    ])
    self.network = nn.Sequential(*layers)

  def forward(self, xb):
    """Return class probabilities (softmax over dim 1) for the batch ``xb``."""
    logits = self.network(xb)
    return F.softmax(logits, dim=1)

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions, each followed by batch norm.

    The shortcut is the identity unless the stride or channel count changes,
    in which case it is a strided 1x1 convolution followed by batch norm.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # An empty Sequential acts as the identity shortcut.
        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)


class Bottleneck(nn.Module):
    """Residual block with a 1x1 -> 3x3 -> 1x1 convolution stack.

    The last convolution widens the output to ``expansion * planes`` channels.
    The shortcut is the identity unless the stride or channel count changes,
    in which case it is a strided 1x1 convolution followed by batch norm.
    """

    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # Only the 3x3 convolution carries the block's stride.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # An empty Sequential acts as the identity shortcut.
        projection = []
        if stride != 1 or in_planes != out_planes:
            projection = [
                nn.Conv2d(in_planes, out_planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = F.relu(self.bn2(self.conv2(hidden)))
        out = self.bn3(self.conv3(hidden))
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """ResNet with a 3x3 stem and four residual stages, returning probabilities.

    Args:
        block: Residual block class; must expose an ``expansion`` attribute
            and accept ``(in_planes, planes, stride)``.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Size of the output distribution.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Channel count fed into the next block; advanced by `_make_layer`.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # (planes, stride of the first block) for layer1 .. layer4.
        stage_plan = [(64, 1), (128, 2), (256, 2), (512, 2)]
        for position, (planes, first_stride) in enumerate(stage_plan):
            stage = self._make_layer(block, planes, num_blocks[position], stride=first_stride)
            setattr(self, f"layer{position + 1}", stage)

        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``."""
        stage = []
        for block_stride in [stride] + [1]*(num_blocks-1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class probabilities (softmax over dim 1) for the batch ``x``."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = F.avg_pool2d(out, 4)
        flat = pooled.view(pooled.size(0), -1)
        logits = self.linear(flat)
        return F.softmax(logits, dim=1)


def ResNet18():
    """Build a ``ResNet`` of ``BasicBlock`` stages with depths [2, 2, 2, 2]."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(block=BasicBlock, num_blocks=stage_depths)


def ResNet34():
    """Build a ``ResNet`` of ``BasicBlock`` stages with depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(block=BasicBlock, num_blocks=stage_depths)


def ResNet50():
    """Build a ``ResNet`` of ``Bottleneck`` stages with depths [3, 4, 6, 3]."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(block=Bottleneck, num_blocks=stage_depths)


def ResNet101():
    """Build a ``ResNet`` of ``Bottleneck`` stages with depths [3, 4, 23, 3]."""
    stage_depths = [3, 4, 23, 3]
    return ResNet(block=Bottleneck, num_blocks=stage_depths)


def ResNet152():
    """Build a ``ResNet`` of ``Bottleneck`` stages with depths [3, 8, 36, 3]."""
    stage_depths = [3, 8, 36, 3]
    return ResNet(block=Bottleneck, num_blocks=stage_depths)

class MNISTSumNNet(nn.Module):
  """Per-image classifier combined symbolically through dolphin distributions.

  Each input image is classified by ``self.mnist_net``; the resulting
  per-image distributions are folded together with ``self.lam`` (``max`` for
  ``"mmax"``, addition otherwise).
  """

  def __init__(self, dataset="msum"):
    super().__init__()

    # Image classifier. Despite the attribute name this is a ResNet18; the
    # file also defines CIFAR10Net and MNISTNet as alternatives.
    self.mnist_net = ResNet18()
    self.dataset = dataset
    # Binary operation used to combine the symbols of two distributions.
    if dataset == "mmax":
      self.lam = lambda x, y : max(x, y)
    else:
      self.lam = lambda x, y : x + y

  def forward(self, x: Tuple[torch.Tensor, torch.Tensor], labels=None, pre_image = None):
    """Combine the per-image predictions of ``x`` into one distribution.

    Args:
      x: Sequence with one image batch per position. While
        ``PreImage.record_preimage`` is set only ``len(x)`` is used, because
        random probabilities replace the classifier output.
      labels: Optional per-sample label tuples; when given, the number of
        correct per-image predictions is counted and returned as well.
      pre_image: Optional per-sample list of candidate label vectors; when
        given, each image's distribution is restricted (via ``map_symbols``)
        to the digits that occur at its position in those candidates.

    Returns:
      ``res.preimage`` while recording the pre-image; otherwise the
      probabilities over ``range(len(x) * 9 + 1)``, followed by the
      correct-prediction count when ``labels`` was given.
    """
    count_correct = labels is not None
    if count_correct:
      labels = torch.tensor(labels).to(x[0].device)
      correct = 0

    for i in range(len(x)):
      if PreImage.record_preimage:
        # No classifier call while recording: a random placeholder suffices.
        probs = torch.rand((1, 10), device="cpu")
        print("Processing preimage for sums up to", i+1, "digits")
      else:
        probs = self.mnist_net(x[i])
        if count_correct:
          preds = probs.data.max(1, keepdim=True)[1]
          correct += preds.eq(labels[:,i].data.view_as(preds)).sum()

      if pre_image is None:
        a = Distribution(probs, range(10))
      else:
        per_sample = []
        for idx in range(len(pre_image)):
          # Digits that appear at position i in any candidate label vector.
          digits = sorted({label[i] for label in pre_image[idx]})
          row = Distribution(probs[idx, :].view(1, -1), range(10))
          per_sample.append(row.map_symbols(digits))
        a = Distribution.stack(per_sample)

      # The first image seeds the result; later ones are folded in.
      res = a if i == 0 else res.apply(a, self.lam)

    if PreImage.record_preimage:
      return res.preimage

    # presumably a (batch, len(x) * 9 + 1) tensor, e.g. b x 19 for two digits
    probabilities = res.map_symbols(range(len(x) * 9 + 1)).get_probabilities()
    if labels is None:
      return probabilities
    return probabilities, correct


def bce_loss(output, ground_truth):
  """Binary cross-entropy between ``output`` and the one-hot ``ground_truth``.

  Args:
    output: Probabilities with one column per class.
    ground_truth: Long tensor of class indices, one per row of ``output``.

  Returns:
    The mean binary cross-entropy as a scalar tensor.
  """
  # The class count comes from `output` itself rather than from the global
  # `sum_n`, which only exists when this file is run as a script.
  gt = torch.nn.functional.one_hot(ground_truth, num_classes=output.shape[-1]).float()
  return F.binary_cross_entropy(output, gt)


def nll_loss(output, ground_truth):
  """Negative log-likelihood of ``ground_truth`` under probabilities ``output``.

  Args:
    output: Probabilities with one column per class.
    ground_truth: Long tensor of class indices, one per row of ``output``.

  Returns:
    The mean negative log-likelihood as a scalar tensor.
  """
  # F.nll_loss expects log-probabilities, so take the log first. The clamp
  # keeps log(0) from producing -inf for classes with zero probability.
  log_probs = torch.log(output.clamp_min(1e-12))
  return F.nll_loss(log_probs, ground_truth)


class Trainer():
  """Train and evaluate ``MNISTSumNNet`` on the sum-/max-of-n task.

  Holds the network, its Adam optimizer, the selected loss function and the
  best test metrics observed so far (``best_loss``, ``best_acc``,
  ``best_acc_class``).
  """

  def __init__(self, train_loader, test_loader, model_dir, learning_rate, loss, provenance, device, k, sum_n, args, dataset="msum"):
    """Configure the provenance, network, optimizer and loss.

    Args:
      train_loader: Loader yielding ``(imgs, targets, labels)`` batches.
      test_loader: Loader yielding ``(imgs, targets, labels)`` batches.
      model_dir: Directory for model artifacts (stored, not written here).
      learning_rate: Learning rate of the Adam optimizer.
      loss: ``"nll"`` or ``"bce"``.
      provenance: Provenance name passed to ``get_provenance``.
      device: Torch device the network and batches are moved to.
      k: Value assigned to ``Distribution.provenance.k``.
      sum_n: Number of images per sample.
      args: Object exposing ``with_preimage`` and ``with_purification``.
      dataset: ``"msum"`` or ``"mmax"``.

    Raises:
      Exception: if ``loss`` names an unknown loss function.
    """
    self.device = device
    self.model_dir = model_dir
    # The provenance is configured globally on the Distribution class.
    Distribution.provenance = get_provenance(provenance)
    Distribution.provenance.k = k
    # Forward the task to the network so that "mmax" combines digits with
    # max; without the argument the network always computes the sum.
    self.network = MNISTSumNNet(dataset).to(self.device)
    self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
    self.train_loader = train_loader
    self.test_loader = test_loader
    self.best_loss = 10000000000
    self.provenance = provenance
    self.sum_n = sum_n
    self.best_acc = 0
    self.best_acc_class = 0
    self.with_preimage = args.with_preimage
    self.with_purification = args.with_purification
    self.dataset = dataset
    if loss == "nll":
      self.loss = nll_loss
    elif loss == "bce":
      self.loss = bce_loss
    else:
      raise Exception(f"Unknown loss function `{loss}`")

  def train_epoch(self, epoch, train_loader=None):
    """Run one optimization pass over ``train_loader`` and log its duration.

    Args:
      epoch: Epoch number, used for the progress bar and the wandb log.
      train_loader: Loader to train on; defaults to ``self.train_loader``.
        Batches are either ``(imgs, targets, labels)`` or
        ``(imgs, targets, labels, preimage)``.
    """
    if train_loader is None:
      train_loader = self.train_loader
    self.network.train()
    progress = tqdm(train_loader, total=len(train_loader))
    t_begin_epoch = time.time()
    for batch in progress:
      preimage = None
      if len(batch) == 3:
        data, target, labels = batch
      else:
        data, target, labels, preimage = batch
      imgs = tuple(data[x].to(self.device) for x in range(self.sum_n))
      target = target.to(self.device)
      self.optimizer.zero_grad()

      # The per-image accuracy count is not needed while training.
      output, _ = self.network(imgs, labels=labels, pre_image=preimage)
      loss = self.loss(output, target)
      loss.backward()
      self.optimizer.step()
      progress.set_description(f"[Train Epoch {epoch}] Loss: {loss.item():.4f}")

    total_epoch_time = time.time() - t_begin_epoch
    wandb.log(
      {
        "epoch": epoch,
        "total_epoch_time": total_epoch_time,
      }
    )
    print(f"Total Epoch Time: {total_epoch_time}")

  def test_epoch(self, epoch):
    """Evaluate on ``self.test_loader`` and update the best-so-far metrics.

    Logs the task accuracy and the per-image classifier accuracy to wandb.
    When the final task accuracy exceeds 97%, one line is appended to
    ``torch_mnist_sum_n_<provenance>_epoch_count.log``.
    """
    self.network.eval()
    num_items = len(self.test_loader.dataset)
    test_loss = 0
    correct = 0
    correct_classified = 0
    # Defined up front so an empty test loader reports 0% instead of failing.
    perc = 0.
    correct_perc = 0.
    with torch.no_grad():
      progress = tqdm(self.test_loader, total=len(self.test_loader))
      for (data, target, labels) in progress:
        # Uses the trainer's own `sum_n` rather than a script-level global.
        imgs = tuple(data[x].to(self.device) for x in range(self.sum_n))
        target = target.to(self.device)

        output, correct_ind = self.network(imgs, labels=labels)
        test_loss += self.loss(output, target).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum()
        correct_classified += correct_ind
        # Running accuracies over the whole test set.
        correct_perc = 100. * correct_classified / (num_items * self.sum_n)
        perc = 100. * correct / num_items
        progress.set_description(f"[Test Epoch {epoch}] Total loss: {test_loss:.4f}, Accuracy: {correct}/{num_items} ({perc:.2f}%), Classifier Accuracy: {correct_classified}/{num_items * self.sum_n} ({correct_perc:.2f}%)")

    if perc > 97.00:
      # Record the sum_n + epoch combination once per epoch, based on the
      # final accuracy. Append mode creates the file when it does not exist.
      file_path = f'torch_mnist_sum_n_{self.provenance}_epoch_count.log'
      formatted_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
      with open(file_path, 'a') as file:
        file.write(f'sum n={self.sum_n}, epoch num={epoch}, {formatted_timestamp}\n')

    wandb.log(
      {
        "epoch": epoch,
        "accuracy": perc,
        "per_class_accuracy": correct_perc,
      }
    )
    if test_loss < self.best_loss:
      self.best_loss = test_loss
    # Track each best value on its own so that neither can ever decrease.
    if perc > self.best_acc:
      self.best_acc = perc
    if correct_perc > self.best_acc_class:
      self.best_acc_class = correct_perc
    print(f"Best loss: {self.best_loss:.4f}")
    print(f"Best acc: {self.best_acc:.2f}%")

  def train(self, n_epochs):
    """Evaluate once, then alternate training and evaluation for ``n_epochs``."""
    self.test_epoch(0)
    for epoch in range(1, n_epochs + 1):
      self.train_epoch(epoch)
      self.test_epoch(epoch)

  def get_preimage(self):
    """Return a mapping from each task output to its candidate label vectors.

    For ``"msum"`` the mapping is recorded by running the network inside a
    ``PreImage`` context; for ``"mmax"`` it is enumerated directly over all
    digit combinations.

    Raises:
      ValueError: if ``self.dataset`` is neither ``"msum"`` nor ``"mmax"``.
    """
    if self.dataset == "msum":
      with PreImage() as record_preimage:
        print("Record Preimage:", record_preimage.record_preimage)
        # Only the number of inputs matters while recording.
        preimage = self.network(list(range(self.sum_n)))
        for key, value in preimage.items():
          preimage[key] = [ list(v.values()) for v in value ]
        return preimage
    if self.dataset == "mmax":
      # Group every combination of `sum_n` digits by its maximum.
      preimage = {}
      for comb in itertools.product(range(10), repeat=self.sum_n):
        preimage.setdefault(max(comb), []).append(list(comb))
      return preimage
    raise ValueError(f"Unknown dataset `{self.dataset}`")

  def train_w_purification(self, n_epochs, warmup_epochs=1, synthetic=None):
    """Train on the pre-image dataset, optionally purifying it every epoch.

    Training stops once the best test accuracy has not improved for more
    than 3 consecutive epochs; ``n_epochs`` is not an upper bound, it only
    marks the one epoch before which purification is skipped.

    Args:
      n_epochs: Epoch number at which the purification step is skipped.
      warmup_epochs: Currently unused; kept for interface compatibility.
      synthetic: Currently unused; kept for interface compatibility.
    """
    preimage = self.get_preimage()

    # Materialize the training set together with its candidate proofs.
    train_w_preimage = MNISTSumNDatasetPreImage(self.train_loader.dataset, preimage)

    def _preimage_loader(dataset):
      return torch.utils.data.DataLoader(
        dataset,
        collate_fn=MNISTSumNDatasetPreImage.collate_fn,
        batch_size=self.train_loader.batch_size,
        shuffle=True
      )

    train_w_preimage_loader = _preimage_loader(train_w_preimage)

    # Baseline evaluation before any training.
    self.test_epoch(0)

    eps = 0.999
    eps_max = 1
    eps_step = 0.01  # only logged; not passed to `purification`

    wandb.log(
      {
        "eps": eps,
        "eps_max": eps_max,
        "eps_step": eps_step,
      }
    )

    # There is no warm-up phase, so the warm-up summary printed at the end
    # always reports these initial values.
    converged_epoch = 0
    best_acc = 0
    best_acc_class = 0

    epoch = 1
    max_acc_epoch = 0
    max_acc = 0
    while True:
      if self.with_purification and epoch != n_epochs:
        # Purification always starts from the full pre-image dataset.
        train_w_preimage_purified = purification(train_w_preimage, eps, eps_max, MNISTSumNDatasetPreImage.collate_fn, self.train_loader.batch_size, self.sum_n, self.network.mnist_net, self.device)
        train_w_preimage_loader = _preimage_loader(train_w_preimage_purified)
      print("Training with Preimage")

      self.train_epoch(epoch, train_loader=train_w_preimage_loader)
      self.test_epoch(epoch)
      if self.best_acc > max_acc:
        max_acc = self.best_acc
        max_acc_epoch = epoch
      # Early stopping: more than 3 epochs without a new best accuracy.
      if epoch - max_acc_epoch > 3:
        break
      epoch += 1

    print("Converged at epoch", converged_epoch, "with accuracy", f"{float(best_acc):.2f} ({float(best_acc_class):.2f})")
    print("After purification, the model achieved", f"{float(self.best_acc):.2f} ({float(self.best_acc_class):.2f})")


    # self.test_epoch(0)
    # for epoch in range(1, n_epochs + 1):
    #   self.train_epoch(epoch)
    #   self.test_epoch(epoch)

if __name__ == "__main__":
  # Command-line options as (flag, argparse keyword arguments), in help order.
  cli_options = [
    ("--sum-n", dict(type=int, default=30)),
    ("--n-epochs", dict(type=int, default=5)),
    ("--batch-size-train", dict(type=int, default=64)),
    ("--batch-size-test", dict(type=int, default=64)),
    ("--learning-rate", dict(type=float, default=0.001)),
    ("--loss-fn", dict(type=str, default="bce")),
    ("--seed", dict(type=int, default=1234)),
    ("--provenance", dict(type=str, default="damp", choices=['damp', 'dmmp', 'dtkp-am'])),
    ("--device", dict(type=str, default="cuda", choices=["cpu", "cuda", "mps"])),
    ("--topk", dict(type=int, default=1)),
    ("--only_preimage", dict(action="store_true")),
    ("--with_preimage", dict(action="store_true")),
    ("--with_purification", dict(action="store_true")),
    ("--num-training-samples", dict(type=int, default=-1)),
    ("--dataset", dict(type=str, default="msum", choices=["msum", "mmax"])),
    ("--synthetic", dict(type=int, default=None)),
  ]
  parser = ArgumentParser()
  for flag, options in cli_options:
    parser.add_argument(flag, **options)
  args = parser.parse_args()

  print(args)

  # Run parameters. These stay module-level names on purpose: other code in
  # this file may read them as globals (e.g. `sum_n`).
  sum_n = args.sum_n
  n_epochs = args.n_epochs
  batch_size_train = args.batch_size_train
  batch_size_test = args.batch_size_test
  learning_rate = args.learning_rate
  loss_fn = args.loss_fn
  provenance = args.provenance

  # Seed both RNGs before the datasets are built and shuffled.
  torch.manual_seed(args.seed)
  random.seed(args.seed)

  config = dict(
    sum_n=sum_n,
    n_epochs=n_epochs,
    batch_size_train=batch_size_train,
    batch_size_test=batch_size_test,
    provenance=provenance,
    seed=args.seed,
    experiment_type="dolphin",
    with_preimage=args.with_preimage,
    with_purification=args.with_purification,
  )

  # Data and model directories are resolved relative to this script.
  script_path = os.path.abspath(__file__)
  data_dir = os.path.abspath(os.path.join(script_path, "../../data"))
  model_dir = os.path.abspath(os.path.join(script_path, f'../../model/mnist_sum_{sum_n}'))
  os.makedirs(model_dir, exist_ok=True)

  # Dataloaders
  train_loader, test_loader = mnist_sum_n_loader(
    data_dir,
    sum_n,
    batch_size_train,
    batch_size_test,
    num_training_samples=args.num_training_samples,
    dataset=args.dataset,
  )

  # Use the requested accelerator only when it is available; otherwise CPU.
  # The checks are wrapped in lambdas so only the requested one is evaluated.
  accelerator_checks = {
    "cuda": lambda: torch.cuda.is_available(),
    "mps": lambda: torch.backends.mps.is_available(),
  }
  is_available = accelerator_checks.get(args.device)
  if is_available is not None and is_available():
    device_name = args.device
  else:
    device_name = "cpu"

  device = torch.device(device_name)
  timestamp = datetime.now()
  run_id = f'dolphin_sum{sum_n}_{args.seed}_{provenance}_{timestamp.strftime("%Y-%m-%d %H-%M-%S")}'

  wandb.init(
    project="PURIFICATION CIFAR-10", config=config, id=run_id, mode="disabled"
  )

  # Create the trainer, then run the requested mode.
  trainer = Trainer(train_loader, test_loader, model_dir, learning_rate, loss_fn, provenance, device, args.topk, sum_n, args, dataset=args.dataset)
  if args.only_preimage:
    trainer.get_preimage()
  elif args.with_purification or args.with_preimage:
    trainer.train_w_purification(n_epochs, warmup_epochs=1, synthetic=args.synthetic)
  else:
    trainer.train(n_epochs)