from dataset import FlowDataset, FlowDatasetTest

import os
import yaml
import sys
import random
import numpy as np
import sklearn.metrics
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from copy import deepcopy
from softgym.envs.corl_baseline import GCFold
from flow_utils import flow2img
from torch.optim.lr_scheduler import ReduceLROnPlateau

import hydra
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning.utilities.seed as seed_utils
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

import torch
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

from models import FlowNetSmall

def train(epoch, cfg, model, train_loader=None, opt=None, loss_fn=None,
          plot_fn=None):
    """Run one training epoch and return the mean per-batch loss.

    This is a free function, so everything it needs is passed in explicitly
    (an earlier version read these collaborators from an undefined ``self``
    and raised ``NameError`` on the first line).

    Args:
        epoch: Epoch index; used in the progress print and in the
            ``eval_interval`` check that decides whether to plot.
        cfg: Mapping-like config. ``cfg['eval_interval']`` is read when
            ``plot_fn`` is given; ``cfg['max_iters']`` is optional and stops
            the epoch early once the batch index exceeds it.
        model: Network invoked as ``model(depth_input)``. It is switched to
            train mode, and batch tensors are moved to the device of its
            first parameter.
        train_loader: Iterable of batches. Each batch is indexed with the
            keys ``'depths'``, ``'flow_lbl'`` and ``'loss_mask'``.
        opt: Optimizer exposing ``zero_grad()`` and ``step()``.
        loss_fn: Callable ``loss_fn(flow_pred, flow_label, loss_mask)``
            returning a scalar tensor that supports ``backward()``.
        plot_fn: Optional callable ``plot_fn(epoch, j, batch, flow_pred)``
            invoked with the last processed batch when
            ``epoch % cfg['eval_interval'] == 0``.

    Returns:
        The mean of the per-batch loss values, or NaN when ``train_loader``
        yields no batches.

    Raises:
        ValueError: If ``train_loader``, ``opt`` or ``loss_fn`` is missing.
    """
    required = (("train_loader", train_loader), ("opt", opt), ("loss_fn", loss_fn))
    missing = [name for name, value in required if value is None]
    if missing:
        raise ValueError(f"train() is missing required argument(s): {', '.join(missing)}")

    # Follow the model's device instead of hard-coding .cuda(), so inputs can
    # never end up on a different device than the weights.
    first_param = next(iter(model.parameters()), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    model.train()
    train_losses = []
    for j, batch in enumerate(train_loader):
        depth_input = batch['depths'].to(device)
        flow_label = batch['flow_lbl'].to(device)

        flow_pred = model(depth_input)
        loss_mask = batch['loss_mask'].to(device)
        train_loss = loss_fn(flow_pred, flow_label, loss_mask)

        opt.zero_grad()
        train_loss.backward()
        opt.step()

        loss = train_loss.item()
        train_losses.append(loss)
        print(f"Train loss epoch {epoch} batch {j}: {loss}")

        # NOTE(review): the comparison is strict, so batches with index
        # 0..max_iters+1 are processed before stopping -- confirm intended.
        if 'max_iters' in cfg and j > cfg['max_iters']:
            break

    if not train_losses:
        # Nothing was iterated: there is no batch to plot and no loss to average.
        return float('nan')

    if plot_fn is not None and epoch % cfg['eval_interval'] == 0:
        plot_fn(epoch, j, batch, flow_pred)  # plot the last training batch

    avg_train_loss = np.mean(train_losses)
    return avg_train_loss

def _list_sample_ids(dataset_path, limit, label):
    """Return the sorted integer sample ids found under ``dataset_path``.

    Ids are parsed from the leading ``_``-separated token of every file name
    in ``<dataset_path>/rendered_images`` that contains ``'before'``.

    Args:
        dataset_path: Dataset root containing a ``rendered_images`` directory.
        limit: Maximum number of ids to keep (smallest first), or ``None``
            to keep all of them.
        label: Word used in the size message printed when ``limit`` is set.

    Returns:
        A sorted list of ints, truncated to ``limit`` entries if given.
    """
    ids = sorted(
        int(fn.split('_')[0])
        for fn in os.listdir(f'{dataset_path}/rendered_images')
        if 'before' in fn
    )
    if limit is not None:
        ids = ids[:limit]
        print(f"Max {label} set: {len(ids)}")
    return ids


@hydra.main(config_name="config")
def main(cfg):
    """Train ``FlowNetSmall`` on the configured datasets with PyTorch Lightning.

    Seeds all RNGs from ``cfg.seed``, builds train/val ``FlowDataset``
    instances from the sample ids found on disk, and fits the model with a
    ``pl.Trainer`` that logs to CSV and TensorBoard and keeps the three best
    checkpoints according to the ``'loss/val'`` metric.

    Args:
        cfg: Config object. Keys read here: ``seed``, ``basepath``,
            ``trainname``, ``valname``, ``max_train``, ``max_val``, ``batch``,
            ``n_workers``, ``netconf``, ``csvlogs``, ``tblogs``, ``gpu``,
            ``epochs`` and ``epochsperval``. It is also passed whole to
            ``FlowDataset``.
    """
    seed_utils.seed_everything(cfg.seed)

    # Load datasets
    camera_params = {'default_camera': { 'pos': np.array([-0.0, 0.65, 0.0]),
                     'angle': np.array([0, -np.pi/2., 0.]),
                     'width': 720,
                     'height': 720}}
    # Train/val split needs to happen outside the dataset class
    # (original note: to avoid the random knobs switch using val images).
    trainpath = f'{cfg.basepath}/{cfg.trainname}'
    valpath = f'{cfg.basepath}/{cfg.valname}'
    trainfs = _list_sample_ids(trainpath, cfg['max_train'], 'training')
    valfs = _list_sample_ids(valpath, cfg['max_val'], 'val')

    train_data = FlowDataset(cfg, trainfs, camera_params, stage='train')
    val_data = FlowDataset(cfg, valfs, camera_params, stage='val')
    train_loader = DataLoader(train_data, batch_size=cfg.batch, shuffle=True, num_workers=cfg.n_workers)
    val_loader = DataLoader(val_data, batch_size=cfg.batch, shuffle=False, num_workers=cfg.n_workers)

    flownet = FlowNetSmall(**cfg.netconf)

    csv_logger = pl_loggers.CSVLogger(save_dir=cfg.csvlogs)
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=cfg.tblogs)
    # Keep the three checkpoints with the best logged 'loss/val' value.
    chkpt_cb = ModelCheckpoint(monitor='loss/val', save_top_k=3)
    trainer = pl.Trainer(gpus=cfg.gpu,
                         logger=[csv_logger, tb_logger],
                         max_epochs=cfg.epochs,
                         check_val_every_n_epoch=cfg.epochsperval,
                         callbacks=[chkpt_cb])

    trainer.fit(flownet, train_loader, val_loader)

# Script entry point. main() is called without arguments because the
# @hydra.main decorator applied to it is what supplies `cfg`.
if __name__ == '__main__':
    main()