from dataset import JointDataset

import os
import yaml
import sys
import random
import numpy as np
import sklearn.metrics
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from copy import deepcopy
from softgym.envs.corl_baseline import GCFold
from flow_utils import flow2img
from torch.optim.lr_scheduler import ReduceLROnPlateau

import hydra
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning.utilities.seed as seed_utils
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint

import torch
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

from collections import namedtuple
from models import FlowNetPickSplit

# Immutable record with six fields: obs, goal, act, rew, nobs, done.
# namedtuple splits the space-separated string into those same field names.
Experience = namedtuple('Experience', 'obs goal act rew nobs done')

def train(epoch, cfg, model, *, train_loader=None, loss_fn=None,
          optimizer=None, plot_fn=None, device=None):
    """Run one training epoch of ``model`` and return the mean batch loss.

    Args:
        epoch: Current epoch index. Used in the per-batch log line and to
            decide whether to plot (``epoch % cfg['eval_interval'] == 0``).
        cfg: Mapping-like config. Must contain ``'eval_interval'``; may
            contain ``'max_iters'`` (``None`` or absent means no cap) to
            limit the number of batches processed this epoch.
        model: Module called as ``model(batch['depths'])`` to produce the
            flow prediction; switched to training mode here.
        train_loader: Iterable of batch dicts with keys ``'depths'``,
            ``'flow_lbl'`` and ``'loss_mask'``.
        loss_fn: Callable ``(flow_pred, flow_label, loss_mask)`` returning a
            scalar tensor that supports ``.backward()`` and ``.item()``.
        optimizer: Optimizer over ``model``'s parameters.
        plot_fn: Optional callable ``(epoch, batch_idx, batch, flow_pred)``
            invoked once with the last processed batch on plotting epochs.
        device: Device the batch tensors are moved to. Defaults to the device
            of ``model``'s first parameter so inputs always match the model.

    Returns:
        Mean training loss over the processed batches as a float, or
        ``nan`` if no batch was processed.

    Raises:
        ValueError: If ``train_loader``, ``loss_fn`` or ``optimizer`` is
            not provided.
    """
    from itertools import islice

    missing = [name for name, value in (('train_loader', train_loader),
                                        ('loss_fn', loss_fn),
                                        ('optimizer', optimizer))
               if value is None]
    if missing:
        raise ValueError(f"train() requires: {', '.join(missing)}")

    if device is None:
        device = next(model.parameters()).device

    max_iters = cfg.get('max_iters')
    # islice stops after exactly ``max_iters`` batches without fetching an
    # extra one from the loader.
    batches = train_loader if max_iters is None else islice(train_loader, max_iters)

    model.train()
    train_losses = []
    batch_idx, batch, flow_pred = -1, None, None
    for batch_idx, batch in enumerate(batches):
        depth_input = batch['depths'].to(device)
        flow_label = batch['flow_lbl'].to(device)
        loss_mask = batch['loss_mask'].to(device)

        flow_pred = model(depth_input)
        train_loss = loss_fn(flow_pred, flow_label, loss_mask)

        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        loss = train_loss.item()
        train_losses.append(loss)
        print(f"Train loss epoch {epoch} batch {batch_idx}: {loss}")

    # Nothing was processed: there is no batch to plot and no mean to take.
    if not train_losses:
        return float('nan')

    if plot_fn is not None and epoch % cfg['eval_interval'] == 0:
        plot_fn(epoch, batch_idx, batch, flow_pred)  # plot train

    return float(np.mean(train_losses))

def _list_frame_ids(split_dir):
    """Return sorted integer frame ids found in ``split_dir/rendered_images``.

    Only file names containing ``'before'`` are counted; the id is the text
    before the first underscore of the file name.
    """
    return sorted(int(fn.split('_')[0])
                  for fn in os.listdir(f'{split_dir}/rendered_images')
                  if 'before' in fn)


def _load_split(basepath, name, max_items):
    """Load the frame ids and ``<name>_idx.buf`` buffer of one data split.

    Args:
        basepath: Root directory that contains the split directory.
        name: Split directory name, also the prefix of its ``_idx.buf`` file.
        max_items: If not ``None``, keep only the first ``max_items`` frame
            ids and buffer entries.

    Returns:
        ``(frame_ids, buf)`` tuple.
    """
    split_dir = f'{basepath}/{name}'
    frame_ids = _list_frame_ids(split_dir)
    buf = torch.load(f'{split_dir}/{name}_idx.buf')
    if max_items is not None:
        frame_ids = frame_ids[:max_items]
        buf = buf[:max_items]
    return frame_ids, buf


def _find_resume_checkpoint(loadpath):
    """Return the path of the checkpoint to resume from under ``loadpath``.

    Looks in ``<loadpath>/default_default/0_0/checkpoints``. Prefers a
    ``.ckpt`` file whose name contains ``'last'``. Otherwise it picks the one
    with the largest integer after the final ``'='`` in its name, e.g.
    ``120`` for ``epoch=3-step=120.ckpt``.

    Raises:
        FileNotFoundError: If the directory contains no ``.ckpt`` files.
    """
    ckpt_dir = f'{loadpath}/default_default/0_0/checkpoints'
    ckpt_names = sorted(x for x in os.listdir(ckpt_dir) if x.endswith('.ckpt'))
    if not ckpt_names:
        raise FileNotFoundError(f"No .ckpt files found in {ckpt_dir}")
    last_names = [x for x in ckpt_names if 'last' in x]
    if last_names:
        return f'{ckpt_dir}/{last_names[0]}'
    newest = max(ckpt_names,
                 key=lambda x: int(x.split('=')[-1].replace('.ckpt', '')))
    return f'{ckpt_dir}/{newest}'


@hydra.main(config_name="config")
def main(cfg):
    """Hydra entry point: build train/val loaders and a model, then fit it.

    Config keys read here: ``seed``, ``basepath``, ``trainname``,
    ``valname``, ``max_train``, ``max_val``, ``batch``, ``n_workers``,
    ``netconf`` (``netconf.nettype`` must be ``'flownetpicksplit'``),
    ``loadpath`` (optional checkpoint run directory to resume from),
    ``csvlogs``, ``tblogs``, ``gpu``, ``epochs``, ``epochsperval``,
    ``benchmark`` and ``sanity_val_steps``.

    Raises:
        NotImplementedError: For any ``netconf.nettype`` other than
            ``'flownetpicksplit'``.
    """
    seed_utils.seed_everything(cfg.seed)

    # Load datasets
    camera_params = {'default_camera': {'pos': np.array([-0.0, 0.65, 0.0]),
                                        'angle': np.array([0, -np.pi / 2., 0.]),
                                        'width': 720,
                                        'height': 720}}
    # Train/val split needs to happen outside dataset class
    # to avoid random nobs switch from using val images
    trainfs, buf_train = _load_split(cfg.basepath, cfg.trainname, cfg['max_train'])
    if cfg['max_train'] is not None:
        print(f"Max training set: {len(trainfs)}")
    valfs, buf_val = _load_split(cfg.basepath, cfg.valname, cfg['max_val'])
    if cfg['max_val'] is not None:
        print(f"Max val set: {len(valfs)}")

    train_data = JointDataset(cfg, trainfs, buf_train, camera_params, stage='train')
    val_data = JointDataset(cfg, valfs, buf_val, camera_params, stage='val')
    train_loader = DataLoader(train_data, batch_size=cfg.batch, shuffle=True, num_workers=cfg.n_workers)
    val_loader = DataLoader(val_data, batch_size=cfg.batch, shuffle=False, num_workers=cfg.n_workers)

    if cfg.netconf.nettype != 'flownetpicksplit':
        raise NotImplementedError(f"Unsupported nettype: {cfg.netconf.nettype}")
    model = FlowNetPickSplit(cfg=cfg.netconf).cuda()

    ckpt_path = None
    if cfg.loadpath is not None:
        ckpt_path = _find_resume_checkpoint(cfg.loadpath)
        # Deserialize on CPU so loading does not depend on the GPU the
        # checkpoint was written from; load_state_dict copies the tensors
        # into the model's (CUDA) parameters.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(ckpt['state_dict'])

    csv_logger = pl_loggers.CSVLogger(save_dir=cfg.csvlogs)
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=cfg.tblogs)
    chkpt_cb = ModelCheckpoint(monitor='loss/val', save_top_k=-1, save_last=True)  # save checkpoint every val
    # NOTE(review): `gpus=` and `resume_from_checkpoint=` are pre-2.0
    # Lightning Trainer arguments; confirm against the installed version.
    trainer = pl.Trainer(gpus=cfg.gpu,
                         logger=[csv_logger, tb_logger],
                         max_epochs=cfg.epochs,
                         check_val_every_n_epoch=cfg.epochsperval,
                         callbacks=[chkpt_cb],
                         profiler='simple',
                         benchmark=cfg.benchmark,
                         num_sanity_val_steps=cfg.sanity_val_steps,
                         resume_from_checkpoint=ckpt_path,
                         )

    trainer.fit(model, train_loader, val_loader)

if __name__ == "__main__":
    # `main` is wrapped by @hydra.main, which builds `cfg` itself, so it is
    # invoked here without arguments.
    main()

    