from dataset.nas_data import NASData
from model.evaluators.oracle import Oracle
from model.generators.gran import GRAN

import torch
import torch.optim as optim
from easydict import EasyDict as edict

DEBUG_PRINT = 0

def test():
  # test the dataset and gran simultaenously to simulate one nas iteration
  config = {
    'seed': 3,
    'generator': {
      'search_space': 'darts',
      'name': 'GRAN',
      'device': 'cpu',
      'num_mix_component': 20,
      'is_sym': False,
      'block_size': 1,
      'sample_stride': 1,
      'max_num_nodes': 6,
      'num_node_labels': 0,
      'num_edge_labels': 8,
      'hidden_dim': 24,
      'embedding_dim': 24,
      'num_GNN_layers': 7,
      'num_GNN_prop': 1,
      'num_canonical_order': 1,
      'dimension_reduce': True,
      'has_attention': True,
      'edge_weight': 1.0e+0,
    },
    'nas': {
      'baseline': 0,
      'ewma_alpha': 0.5,
      'reward': 'cdf',
      'keep_top': 10000,
    }
  }
  config = edict(config)

  dataset = NASData(config)
  dataset.append(data())

  generator = GRAN(config)

  train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=2,
    shuffle=False,
    num_workers=0,
    collate_fn=dataset.collate_fn,
    drop_last=False
  )
  params = filter(lambda p: p.requires_grad, generator.parameters())
  optimizer = optim.Adam(params, lr=0.002, weight_decay=0.001)
  for epoch in range(70):
    generator.train()

    train_loss = 0.
    for batch in train_loader:
      optimizer.zero_grad()

      loss = generator(batch)
      loss.backward()

      optimizer.step()

      train_loss += loss.item()

    if epoch % 5 == 4:
      print(f"Epoch={epoch}, Train loss={train_loss/len(train_loader)}")

  samples = generator.sample(3, debug=DEBUG_PRINT)
  for adj, _ in samples:
    print(adj)

  dataset.append([(g, 0.9) for g in samples])
  print(len(dataset))
  # adj1 = dataset.collate_fn([dataset[0], dataset[8]])
  adj1 = dataset.collate_fn([dataset[0], dataset[8]])
  adj1['debug'] = DEBUG_PRINT
  with torch.no_grad():
    print(generator(adj1))




def data():
  return [
    ((torch.tensor([[0., 1., 3., 0., 2., 3.],
                    [0., 0., 3., 3., 0., 0.],
                    [0., 0., 0., 3., 0., 3.],
                    [0., 0., 0., 0., 4., 0.],
                    [0., 0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0., 0.]]), torch.tensor([0., 0., 0., 0., 0., 0.])), 0.9),
    ((torch.tensor([[0., 1., 2., 3., 2., 5.],
                    [0., 0., 4., 0., 0., 0.],
                    [0., 0., 0., 3., 0., 0.],
                    [0., 0., 0., 0., 3., 0.],
                    [0., 0., 0., 0., 0., 5.],
                    [0., 0., 0., 0., 0., 0.]]), torch.tensor([0., 0., 0., 0., 0., 0.])), 0.9),
    ((torch.tensor([[0., 1., 3., 2., 4., 3.],
                    [0., 0., 3., 3., 0., 0.],
                    [0., 0., 0., 0., 3., 3.],
                    [0., 0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0., 0.]]), torch.tensor([0., 0., 0., 0., 0., 0.])), 0.9),
    ((torch.tensor([[0., 1., 3., 4., 2., 5.],
                    [0., 0., 3., 3., 0., 0.],
                    [0., 0., 0., 0., 3., 0.],
                    [0., 0., 0., 0., 0., 5.],
                    [0., 0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0., 0.]]), torch.tensor([0., 0., 0., 0., 0., 0.])), 0.9),
    ((torch.tensor([[0., 1., 3., 2., 3., 5.],
                    [0., 0., 4., 0., 0., 0.],
                    [0., 0., 0., 4., 0., 0.],
                    [0., 0., 0., 0., 3., 0.],
                    [0., 0., 0., 0., 0., 5.],
                    [0., 0., 0., 0., 0., 0.]]), torch.tensor([0., 0., 0., 0., 0., 0.])), 0.9),
    ((torch.tensor([[0., 1., 3., 3., 3., 5.],
                    [0., 0., 3., 3., 0., 0.],
                    [0., 0., 0., 0., 3., 0.],
                    [0., 0., 0., 0., 0., 5.],
                    [0., 0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0., 0.]]), torch.tensor([0., 0., 0., 0., 0., 0.])), 0.9),
    ((torch.tensor([[0., 1., 3., 3., 2., 5.],
                    [0., 0., 3., 0., 0., 0.],
                    [0., 0., 0., 3., 0., 0.],
                    [0., 0., 0., 0., 4., 5.],
                    [0., 0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0., 0.]]), torch.tensor([0., 0., 0., 0., 0., 0.])), 0.9),
    ((torch.tensor([[0., 1., 3., 3., 4., 5.],
                    [0., 0., 3., 0., 0., 0.],
                    [0., 0., 0., 3., 0., 0.],
                    [0., 0., 0., 0., 4., 5.],
                    [0., 0., 0., 0., 0., 0.],
                    [0., 0., 0., 0., 0., 0.]]), torch.tensor([0., 0., 0., 0., 0., 0.])), 0.9),
  ]

