"""
These functionalities should be refactored,
part of this goes into each model,
part of this goes to a unified script.
"""
import os
import time
import numpy as np
import torch
import hydra

from pathlib import Path
from torch.utils.data import DataLoader
from pytorch_lightning import seed_everything

from graphwm.data.datamodule import worker_init_fn
from graphwm.data.utils import dict_collate_fn
from graphwm.common import PROJECT_ROOT
from graphwm.model import GNS, PnR

# evaluation
# from sklearn.metrics import r2_score
# from scipy.stats import wasserstein_distance

# Registry from the model token in a run-directory name
# ('<data>_<model>_...') to the model class used to load its checkpoint.
MODELS = dict(
    gns=GNS,
    pnr=PnR,
)

def _find_checkpoint(model_dir):
  """Return the path (as str) of the checkpoint to load from ``model_dir``.

  Prefers ``last.ckpt`` when present. Otherwise picks the ``*.ckpt`` file with
  the highest epoch, parsed from filenames of the form ``<key>=<epoch>[-...]``
  (e.g. ``epoch=12-step=3400.ckpt`` or ``epoch=12.ckpt``).

  Raises:
    FileNotFoundError: if ``model_dir`` contains no ``*.ckpt`` file.
    ValueError: if a checkpoint filename does not follow the expected pattern.
  """
  last_ckpt = model_dir / 'last.ckpt'
  if last_ckpt.exists():
    return str(last_ckpt)

  ckpts = list(model_dir.glob('*.ckpt'))
  if not ckpts:
    raise FileNotFoundError(f'No checkpoint (*.ckpt) found in {model_dir}')

  def _epoch(ckpt):
    # 'epoch=12-step=3400.ckpt' -> stem 'epoch=12-step=3400' -> 'epoch=12' -> 12
    try:
      return int(ckpt.stem.split('-')[0].split('=')[1])
    except (IndexError, ValueError) as err:
      raise ValueError(f'Cannot parse epoch from checkpoint name: {ckpt.name}') from err

  return str(max(ckpts, key=_epoch))


@hydra.main(config_path=str(PROJECT_ROOT / "conf"), config_name="eval")
def main(cfg):
  """Roll out a trained model on an evaluation dataset and save the results.

  The run directory ``cfg.model_dir`` is expected to be named
  ``<data>_<model>[_...]``; the ``<model>`` token selects the class in
  ``MODELS``. Rollouts for up to ``cfg.num_batches`` batches are saved to
  ``<cfg.save_dir>/<folder>/seed_<cfg.random_seed>.pt``, where ``<folder>``
  encodes the PnR step settings, or is ``rollouts`` for other models.
  """
  seed_everything(cfg.random_seed)

  model_dir = Path(cfg.model_dir)
  # Only the model token of '<data>_<model>_...' is needed here.
  _, modelclass = model_dir.parts[-1].split('_')[:2]
  if modelclass not in MODELS:
    raise ValueError(f'Unknown model {modelclass!r} parsed from {model_dir}; '
                     f'expected one of {sorted(MODELS)}')
  ckpt_path = _find_checkpoint(model_dir)
  print(f'load checkpoint: {ckpt_path}')

  model = MODELS[modelclass].load_from_checkpoint(ckpt_path)
  # prepare data: sequence/dilation/grouping come from the loaded model's hparams
  dataset = hydra.utils.instantiate(cfg.data,
                                    seq_len=model.hparams.seq_len,
                                    dilation=model.hparams.dilation,
                                    grouping=model.hparams.cg_level)
  data_loader = DataLoader(dataset, shuffle=False, batch_size=cfg.batch_size,
                           num_workers=cfg.get('num_workers', 8),
                           worker_init_fn=worker_init_fn, collate_fn=dict_collate_fn)

  model = model.to('cuda')
  model.eval()
  save_dir = Path(cfg.save_dir)
  os.makedirs(save_dir, exist_ok=True)

  # Override the PnR step settings stored in the checkpoint with cfg.ld_kwargs.
  if modelclass == 'pnr':
    model.hparams.step_per_sigma = cfg.ld_kwargs.step_per_sigma
    model.hparams.step_size = cfg.ld_kwargs.step_size

  outputs = []
  n_rollouts = 0  # actual number of rollouts produced (last batch may be short)
  start = time.perf_counter()
  with torch.no_grad():
    for idx, batch in enumerate(data_loader):
      if idx == cfg.num_batches:
        break
      batch = {k: v.cuda() for k, v in batch.items()}
      output = model.simulate(batch, cfg.rollout_length // model.hparams.dilation,
                              save_positions=cfg.save_pos, save_frequency=cfg.save_frequency)
      # Store CPU copies of the inputs alongside the rollout outputs.
      output.update({k: v.detach().cpu() for k, v in batch.items()})
      outputs.append(output)
      # rg specific: the batch size is read from the 'n_keypoint' entry.
      n_rollouts += batch['n_keypoint'].shape[0]
  # CUDA kernels run asynchronously; wait for them so the timing is complete.
  torch.cuda.synchronize()
  elapsed = time.perf_counter() - start

  if not outputs:
    raise RuntimeError(f'No batches were evaluated (num_batches={cfg.num_batches}); '
                       'nothing to save.')

  outputs = {k: torch.cat([d[k] for d in outputs]) for k in outputs[-1]}
  outputs['time_elapsed'] = elapsed
  outputs['model_params'] = model.hparams
  outputs['eval_cfg'] = cfg

  if modelclass == 'pnr':
    folder_name = f'nsteps{cfg.ld_kwargs.step_per_sigma}_stepsize_{cfg.ld_kwargs.step_size}'
  else:
    folder_name = 'rollouts'

  out_dir = save_dir / folder_name
  os.makedirs(out_dir, exist_ok=True)
  torch.save(outputs, out_dir / f'seed_{cfg.random_seed}.pt')

  # TODO: compute evaluation metrics here.
  print(f'Finished {n_rollouts} rollouts of {cfg.rollout_length} steps.')

if __name__ == "__main__":
  main()
