from model.generators.gran import GRAN

import torch
import torch.optim as optim
import torch.nn as nn
from easydict import EasyDict as edict

def test():
  config = {
    'generator': {
      'search_space': 'custom',
      'name': 'GRAN',
      'device': 'cpu',
      'num_mix_component': 20,
      'is_sym': False,
      'block_size': 1,
      'sample_stride': 1,
      'max_num_nodes': 5,
      'num_node_labels': 0,
      'num_edge_labels': 8,
      'hidden_dim': 128,
      'embedding_dim': 128,
      'num_GNN_layers': 7,
      'num_GNN_prop': 1,
      'num_canonical_order': 1,
      'dimension_reduce': True,
      'has_attention': True,
      'edge_weight': 1.0e+0,
    },
    'nas': {
      'baseline': 0
    }
  }
  config = edict(config)
  generator = GRAN(config)
  print("ok!")
  print(generator({'is_sampling': True, 'batch_size': 2}))

  data = [{'adj': torch.tensor([[[[0., 0., 0., 0., 0.],
  [1., 0., 0., 0., 0.],
  [0., 2., 0., 0., 0.],
  [4., 0., 3., 0., 0.],
  [5., 0., 0., 6., 0.]]]]),
  'att_idx': torch.tensor([1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1]),
  'edges': torch.tensor([[ 1,  2 ],
    [ 2,  1 ],
    [ 3,  4 ],
    [ 3,  5 ],
    [ 4,  3 ],
    [ 4,  5 ],
    [ 5,  3 ],
    [ 5,  4 ],
    [ 6,  7 ],
    [ 6,  9 ],
    [ 7,  6 ],
    [ 7,  8 ],
    [ 7,  9 ],
    [ 8,  7 ],
    [ 8,  9 ],
    [ 9,  6 ],
    [ 9,  7 ],
    [ 9,  8 ],
    [10, 11],
    [10, 13],
    [10, 14],
    [11, 10],
    [11, 12],
    [11, 14],
    [12, 11],
    [12, 13],
    [12, 14],
    [13, 10],
    [13, 12],
    [13, 14],
    [14, 10],
    [14, 11],
    [14, 12],
    [14, 13]]),
  'label': torch.tensor([0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 1., 0., 0., 1., 0.]),
  'edge_label': torch.tensor([0., 1., 0., 0., 2., 0., 4., 0., 3., 0., 5., 0., 0., 6., 0.]),
  'edge_feat': torch.tensor([
    [1., 0., 0., 0., 0., 0., 0., 0.,], # 1 2
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [0., 1., 0., 0., 0., 0., 0., 0.,], # 3 4 5
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [0., 1., 0., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [0., 1., 0., 0., 0., 0., 0., 0.,], # 6 7 8 9
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [0., 1., 0., 0., 0., 0., 0., 0.,],
    [0., 0., 1., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [0., 0., 1., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [0., 1., 0., 0., 0., 0., 0., 0.,], # 10 11 12 13 14
    [0., 0., 0., 0., 1., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [0., 1., 0., 0., 0., 0., 0., 0.,],
    [0., 0., 1., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [0., 0., 1., 0., 0., 0., 0., 0.,],
    [0., 0., 0., 1., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [0., 0., 0., 0., 1., 0., 0., 0.,],
    [0., 0., 0., 1., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
    [1., 0., 0., 0., 0., 0., 0., 0.,],
  ]),
  'node_idx_feat': torch.tensor([0, 1, 0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0]),
  'node_ids': torch.tensor([0, 0, 1, 0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4]),
  'node_idx_gnn': torch.tensor([[ 0,  0 ],
    [ 2,  1 ],
    [ 2,  2 ],
    [ 5,  3 ],
    [ 5,  4 ],
    [ 5,  5 ],
    [ 9,  6 ],
    [ 9,  7 ],
    [ 9,  8 ],
    [ 9,  9 ],
    [14, 10],
    [14, 11],
    [14, 12],
    [14, 13],
    [14, 14]]),
  'num_nodes_gt': torch.tensor([5]),
  'subgraph_idx': torch.tensor([0, 1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4]),
  'subgraph_idx_base': torch.tensor([0, 5])} ]
  print(f"target: {data[0]['adj']}")

  epochs = 500
  optimizer = optim.SGD(generator.parameters(), lr=0.2, momentum=0.9, weight_decay=0.001)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs+300)
  generator.train()
  for i in range(epochs):
    optimizer.zero_grad()
    loss = generator(data[0])
    loss.backward()
    if i % 100 == 0:
      print(f"tloss {loss.item()}")
    optimizer.step()
    scheduler.step()
  data[0]['baseline'] = 1
  generator(data[0])

  print(generator({'is_sampling': True, 'batch_size': 2}))


