"""
This script tests/showcases a quirk in how DataLoader interacts with
multiprocessing. For some reason, DataLoader starts its n worker processes at
the beginning of each epoch (not during initialization). With the 'fork' start
method this is no issue, but with 'spawn' (which is much, much slower on
Linux) it is a very noticeable performance hit. Here we compare:

(1) using fork, new workers each epoch
(2) using spawn, new workers each epoch
(3) using spawn, reusing workers

We see that (1) is much faster than (2) (by 10+ seconds PER EPOCH), but only a
bit faster than (3) (by ~10 seconds overall).

The fixed, worker-reusing FastDataLoader (imported from utils.torch_helper) was
taken from an online source.
"""


import sys
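# TODO: this machine-specific repo path shouldn't be hardcoded (presumably it is
# needed so that `utils.torch_helper` below is importable).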
sys.path.append("/Users/jeffr/src/Probabilistic-NAS/")
# The `__main__` entry point must be in this file for multiprocessing init
# reasons (with 'spawn', worker processes re-import the main module).
import time

from utils.torch_helper import FastDataLoader

import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
  """Tiny two-layer network used as a cheap training target.

  Input must have a last dimension of size 1; the output has a last
  dimension of size 1. There is no activation between the two linear
  layers, so the model as a whole is an affine map.
  """

  def __init__(self):
    super().__init__()
    self.l1 = nn.Linear(1, 5)  # 1 feature -> 5 hidden units
    self.l2 = nn.Linear(5, 1)  # 5 hidden units -> 1 output

  def forward(self, x):
    """Return ``l2(l1(x))``."""
    hidden = self.l1(x)
    return self.l2(hidden)


class Dataset:
  """Map-style dataset that simulates slow item loading and collation.

  Each ``__getitem__`` and each ``collate_fn`` call sleeps for a fixed time,
  so that data-loading work is expensive enough for DataLoader worker
  behaviour to show up in wall-clock timings.

  Args:
    data: Indexable sequence served by ``__getitem__``.
    item_delay: Seconds to sleep in every ``__getitem__`` call.
    collate_delay: Seconds to sleep in every ``collate_fn`` call.
  """

  def __init__(self, data, *, item_delay=0.02, collate_delay=0.05):
    self.data = data
    self.item_delay = item_delay
    self.collate_delay = collate_delay

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index):
    """Return ``data[index]`` after sleeping ``item_delay`` seconds."""
    time.sleep(self.item_delay)
    return self.data[index]

  def collate_fn(self, batch):
    """Turn a list of samples into a float32 column tensor.

    For a list of scalars the result has shape ``(len(batch), 1)``.
    Sleeps ``collate_delay`` seconds first to simulate batching cost.
    """
    time.sleep(self.collate_delay)
    return torch.tensor(batch).float().view(-1, 1)

def main(fast=False, *, epochs=300, num_workers=2, verbose=False):
  """Train a ``Net`` on a small, slow-loading toy dataset.

  The ``__main__`` block times calls to this function to compare data
  loader implementations.

  Args:
    fast: If True, use ``FastDataLoader`` (from ``utils.torch_helper``);
      otherwise use ``torch.utils.data.DataLoader``.
    epochs: Number of passes over the dataset.
    num_workers: Number of worker processes passed to the loader.
    verbose: If True, print per-batch load time and optimisation-step time.
  """
  net = Net()
  dataset = Dataset(list(range(20)))

  loader_cls = FastDataLoader if fast else torch.utils.data.DataLoader
  train_loader = loader_cls(
    dataset,
    batch_size=5,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=dataset.collate_fn,
    drop_last=False,
  )

  optimizer = optim.Adam(net.parameters(), lr=0.002, weight_decay=0.001)
  prev = time.perf_counter()
  for _ in range(epochs):
    for batch in train_loader:
      if verbose:
        # Time spent waiting for this batch; on the first batch of an epoch
        # this can include worker start-up.
        print("batch load time", time.perf_counter() - prev)
      prev = time.perf_counter()

      optimizer.zero_grad()
      out = net(batch)
      # Regress the network output onto the square of its input.
      loss = ((out - batch ** 2) ** 2).mean()
      loss.backward()
      optimizer.step()

      if verbose:
        print("forward time", time.perf_counter() - prev)
      prev = time.perf_counter()


if __name__ == '__main__':
  # The start method may be given as the first CLI argument, e.g. 'fork'.
  # With 'spawn' the default DataLoader is very slow at the beginning of each
  # epoch; 'fork' is quite a bit faster.
  start_method = sys.argv[1] if len(sys.argv) > 1 else 'spawn'
  mp.set_start_method(start_method)

  for label, fast in (("DEFAULT", False), ("FAST", True)):
    start = time.perf_counter()
    main(fast=fast)
    print(f"TOTAL {label} {time.perf_counter() - start}")

