from dataset import FlowDataset, FlowDatasetTest

import os
import yaml
import sys
import random
import numpy as np
import sklearn.metrics
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from copy import deepcopy
from softgym.envs.corl_baseline import GCFold
from flow_utils import flow2img
from torch.optim.lr_scheduler import ReduceLROnPlateau

import hydra
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning.utilities.seed as seed_utils
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

import torch
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

from models import FlowNetSmall

def _checkpoint_epoch(ckpt_name):
    """Return the epoch number encoded in a checkpoint filename.

    Takes everything before the first ``-`` and strips the ``epoch=`` prefix
    and ``.ckpt`` suffix. For example, ``epoch=12-step=340.ckpt`` -> 12 and
    ``epoch=7.ckpt`` -> 7.

    Raises:
        ValueError: if what remains is not an integer.
    """
    head = ckpt_name.split('-')[0]
    return int(head.replace('epoch=', '').replace('.ckpt', ''))


def get_checkpoint(loadpath, map_location=None):
    """Load the highest-epoch checkpoint that sits beside ``loadpath``.

    Checkpoints are looked up in
    ``<loadpath>/../default_default/0_0/checkpoints/``. Files whose name
    contains ``last`` are ignored, as are non-``.ckpt`` files and names with
    no parseable epoch number. The file with the largest epoch wins; ties
    are broken by the lexicographically largest filename. The chosen
    filename is printed.

    Args:
        loadpath: Directory whose parent contains ``default_default/``.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            tensors on the device they were saved from; pass ``'cpu'`` to
            load a GPU checkpoint on a CPU-only machine.

    Returns:
        The object returned by ``torch.load`` for the selected file.

    Raises:
        FileNotFoundError: if the directory holds no usable epoch checkpoint.
    """
    ckpt_dir = os.path.join(loadpath, '..', 'default_default', '0_0', 'checkpoints')

    candidates = []
    for name in os.listdir(ckpt_dir):
        # 'last' checkpoints are excluded so the pick is by epoch number.
        if 'last' in name or not name.endswith('.ckpt'):
            continue
        try:
            candidates.append((_checkpoint_epoch(name), name))
        except ValueError:
            # Not an epoch-numbered checkpoint; skip rather than crash.
            continue

    if not candidates:
        raise FileNotFoundError(f'No epoch checkpoints found in {ckpt_dir}')

    # Numeric comparison on the epoch, so epoch=12 beats epoch=3.
    _, ckpt_name = max(candidates)
    print(ckpt_name)
    return torch.load(os.path.join(ckpt_dir, ckpt_name), map_location=map_location)

@hydra.main(config_name="config")
def main(cfg):
    """Evaluate a trained FlowNetSmall on the flow test set and report mean EPE.

    ``cfg`` is supplied by Hydra. This function reads ``seed``, ``n_workers``,
    ``netconf``, ``loadpath``, ``csvlogs``, ``tblogs`` and ``gpu`` from it, and
    also passes the whole config to ``FlowDatasetTest``.

    Side effects: prints the selected checkpoint name (via ``get_checkpoint``)
    and the average EPE, and writes ``epes.npy`` to the current working
    directory.
    """
    seed_utils.seed_everything(cfg.seed)

    # Camera used to render the test scenes: 720x720, angle (0, -pi/2, 0).
    # An alternative position previously tried was [-0.0, 0.45, 0.0].
    default_cam = dict(
        pos=np.array([-0.0, 0.65, 0.0]),
        angle=np.array([0, -np.pi / 2.0, 0.0]),
        width=720,
        height=720,
    )
    camera_params = {'default_camera': default_cam}

    dataset = FlowDatasetTest(cfg, camera_params)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=cfg.n_workers)

    model = FlowNetSmall(**cfg.netconf)
    ckpt = get_checkpoint(cfg.loadpath)
    model.load_state_dict(ckpt['state_dict'])

    loggers = [
        pl_loggers.CSVLogger(save_dir=cfg.csvlogs),
        pl_loggers.TensorBoardLogger(save_dir=cfg.tblogs),
    ]
    trainer = pl.Trainer(gpus=cfg.gpu, logger=loggers)

    # NOTE(review): `test_dataloaders` is the older Lightning keyword (later
    # renamed `dataloaders`); keep it in sync with the pinned PL version.
    results = trainer.test(model, test_dataloaders=loader)

    # NOTE(review): the first entry of trainer.test()'s result list is skipped.
    # Confirm this matches what FlowNetSmall's test hooks produce. If only one
    # dict comes back, the slice is empty and np.mean yields NaN.
    epe_entries = results[1:]
    avg_epe = np.mean([entry['loss'] for entry in epe_entries])
    print(f'Average epe: {avg_epe}')
    np.save('epes.npy', epe_entries)

if __name__ == "__main__":
    # The @hydra.main decorator assembles cfg, so main() takes no arguments here.
    main()

    