from dataset import FlowDataset, OnlineFlowDataset, FlowDatasetTest, FlowDatasetReal

import os
import yaml
import sys
import argparse
import random
import numpy as np
import sklearn.metrics
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from copy import deepcopy
from softgym.envs.corl_baseline import GCFold
from flow_utils import flow2img

import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F

from models import FlowNetSmall

class Evaluator(object):
    """Evaluate saved FlowNet checkpoints on validation (and optionally test) data.

    For every ``model_<epoch>.pt`` file under ``<model_path>/weights`` the model
    is loaded, a masked flow loss is averaged over each loader, and the first
    batches are rendered to ``<model_path>/eval_images``.
    """

    def __init__(self, model_path, cfg, val_data, test_data=None):
        self.model_path = model_path
        self.cfg = cfg
        self.val_loader = DataLoader(val_data, batch_size=cfg['batch'], shuffle=False,
                                     num_workers=self.cfg['num_workers'])
        # Test samples are evaluated one at a time.
        if test_data is not None:
            self.test_loader = DataLoader(test_data, batch_size=1, shuffle=False,
                                          num_workers=self.cfg['num_workers'])
        else:
            self.test_loader = None

        self.init_model()

    def init_model(self):
        """Build the network selected by ``cfg['model']`` on the GPU.

        Raises:
            NotImplementedError: for any model name other than 'flownets'.
        """
        if self.cfg['model'] == 'flownets':
            self.model = FlowNetSmall(input_channels=2).cuda()
        else:
            raise NotImplementedError(f"unsupported model: {self.cfg['model']!r}")

    def loss(self, input_flow, target_flow, mask):
        """Masked flow error.

        The masked residual is grouped per channel across the whole batch. The
        ``cfg['loss']``-order norm of each channel is taken, and the per-channel
        norms are averaged.

        Args:
            input_flow: predicted flow, channel-first ``(B, C, H, W)``.
            target_flow: label flow, same shape as ``input_flow``.
            mask: broadcastable to ``input_flow``; zero entries are ignored.

        Returns:
            A scalar tensor.
        """
        masked = (target_flow*mask - input_flow*mask)
        b, c, h, w = input_flow.size()
        # Move channels to the front before flattening. Reshaping (B, C, H, W)
        # directly to (C, B*H*W) would mix samples and channels whenever B > 1,
        # so val (batched) and test (batch of 1) losses would not be comparable.
        per_channel = masked.permute(1, 0, 2, 3).reshape(c, b * h * w)
        return torch.linalg.norm(per_channel, ord=self.cfg['loss'], dim=1).mean()

    @staticmethod
    def _list_epochs(weights_dir):
        """Return the sorted epochs of all ``model_<epoch>.pt`` files in `weights_dir`.

        Files whose epoch part is not a plain integer (e.g. ``model_best.pt``)
        are skipped instead of aborting the whole evaluation.
        """
        prefix, suffix = 'model_', '.pt'
        epochs = []
        for name in os.listdir(weights_dir):
            stem = name[len(prefix):-len(suffix)]
            if name.startswith(prefix) and name.endswith(suffix) and stem.isdigit():
                epochs.append(int(stem))
        return sorted(epochs)

    @staticmethod
    def _plot_quiver(ax, flow, mask, title, max_arrows=500):
        """Draw roughly `max_arrows` flow arrows at the pixels where `mask` is True.

        `flow` is ``(H, W, C)`` with ``C >= 2``. Channel 1 is used as the
        horizontal arrow component and channel 0 as the vertical one.
        """
        ys, xs = np.nonzero(mask)
        step = int(np.ceil(len(ys) / max_arrows)) if len(ys) else 1
        ys, xs = ys[::step], xs[::step]
        ax.set_title(title)
        ax.imshow(np.zeros(flow.shape[:2]))
        ax.quiver(xs, ys, flow[ys, xs, 1], flow[ys, xs, 0],
                  alpha=0.8, color='white', angles='xy', scale_units='xy', scale=1)

    def plot_batch(self, epoch, batch_idx, batch, flow_pred, batchtype='train'):
        """Save a 3x2 diagnostic figure for the first sample of a batch.

        Rows of the figure:
            1. Observed depth (overlaid with channel 0 of the loss mask) and goal depth.
            2. The label flow image and a downsampled label quiver.
            3. The predicted flow image and a downsampled prediction quiver,
               restricted to non-zero cloth-mask pixels.

        The figure is written to
        ``<model_path>/eval_images/<batchtype>_epoch<epoch>_batch<batch_idx>.png``.

        Args:
            epoch, batch_idx, batchtype: only used to build the file name.
            batch: dict holding channel-first 'depths', 'flow_lbl', 'cloth_mask'
                and 'loss_mask' tensors.
            flow_pred: model output for the batch, channel-first ``(B, C, H, W)``.
        """
        def to_hwc(t):
            return t.detach().permute(1, 2, 0).cpu().numpy()

        depth_input = to_hwc(batch['depths'][0])
        depth_o = depth_input[:, :, 0].squeeze()
        depth_n = depth_input[:, :, 1].squeeze()
        flow_lbl = to_hwc(batch['flow_lbl'][0])
        pred = to_hwc(flow_pred[0])
        cloth_mask = batch['cloth_mask'][0].detach().squeeze().cpu().numpy()
        loss_mask = to_hwc(batch['loss_mask'][0])

        fig, ax = plt.subplots(3, 2, figsize=(8, 12))
        ax[0][0].set_title('depth obs')
        ax[0][0].imshow(depth_o)
        ax[0][0].imshow(loss_mask[:, :, 0], alpha=0.5)
        ax[0][1].set_title('depth goal')
        ax[0][1].imshow(depth_n)

        ax[1][0].set_title('label')
        ax[1][0].imshow(flow2img(flow_lbl))
        # One entry per pixel with any non-zero channel. A per-channel np.where
        # would list the same pixel twice and draw duplicate arrows.
        self._plot_quiver(ax[1][1], flow_lbl, np.any(flow_lbl != 0, axis=-1),
                          "dsampled label quiver")

        ax[2][0].set_title('pred')
        ax[2][0].imshow(flow2img(pred))
        # Only draw predicted arrows on non-zero cloth-mask pixels.
        pred_mask = np.any(pred != 0, axis=-1) & (cloth_mask != 0)
        self._plot_quiver(ax[2][1], pred, pred_mask, "dsampled pred quiver")

        print(f"batch type: {batchtype}")
        plt.tight_layout()
        plt.savefig(f'{self.model_path}/eval_images/{batchtype}_epoch{epoch}_batch{batch_idx}.png', bbox_inches='tight')
        plt.close(fig)

    def eval(self, epoch, loader, num_plots=30):
        """Return the mean loss of the current model over `loader`.

        The first `num_plots` batches are also rendered with `plot_batch`.
        Batches from ``self.val_loader`` are labelled 'val'; all others are
        labelled 'test'.

        Returns:
            The mean per-batch loss, or NaN if `loader` yields no batches.

        Raises:
            ValueError: if `loader` is None (e.g. no test data was provided).
        """
        if loader is None:
            raise ValueError('loader is None; was test_data provided?')
        self.model.eval()
        batchtype = 'val' if loader is self.val_loader else 'test'
        losses = []
        with torch.no_grad():
            for i, batch in enumerate(loader):
                depth_input = batch['depths'].cuda()
                flow_label = batch['flow_lbl'].cuda()
                loss_mask = batch['loss_mask'].cuda()

                flow_pred = self.model(depth_input)
                losses.append(self.loss(flow_pred, flow_label, loss_mask).item())

                if i < num_plots:
                    self.plot_batch(epoch, i, batch, flow_pred, batchtype=batchtype)

        return np.mean(losses) if losses else float('nan')

    def plot_metrics(self, epoch, avg_train_losses, avg_val_losses, avg_test_losses=None):
        """Plot loss curves and save them to ``<model_path>/metrics.png``.

        Train losses are plotted against their index. Val and test losses are
        plotted at multiples of ``cfg['eval_interval']``. `epoch` is unused and
        kept only for interface compatibility.
        """
        if avg_test_losses is None:
            avg_test_losses = []
        interval = self.cfg['eval_interval']
        val_xs = [x * interval for x in range(len(avg_val_losses))]

        fig, ax = plt.subplots(1, 1, figsize=(16, 16))
        if len(avg_val_losses) != 0:
            ax.set_title(f'Min val loss at {np.argmin(avg_val_losses)*interval}: {np.min(avg_val_losses)}')
        ax.plot(range(len(avg_train_losses)), avg_train_losses, label='train loss')
        ax.plot(val_xs, avg_val_losses, label='val loss')
        if len(avg_test_losses) != 0:
            test_xs = [x * interval for x in range(len(avg_test_losses))]
            ax.plot(test_xs, avg_test_losses, label='test loss')
        ax.legend()

        plt.tight_layout()
        # self.model_name was never assigned, so the old output path always raised
        # AttributeError. Write next to the other evaluation outputs instead.
        plt.savefig(f'{self.model_path}/metrics.png')
        plt.close(fig)

    def run(self):
        """Evaluate every checkpoint under ``<model_path>/weights`` in epoch order.

        Recreates ``<model_path>/eval_images`` from scratch. Saves the per-epoch
        val losses to ``avg_val_losses1.npy``. When a test loader exists, the test
        losses are saved to ``avg_test_losses1.npy``.
        """
        import shutil

        image_dir = f'{self.model_path}/eval_images'
        # Start from an empty directory so plots from a previous run don't linger.
        if os.path.exists(image_dir):
            shutil.rmtree(image_dir)
        os.makedirs(image_dir)

        weights_dir = f'{self.model_path}/weights'
        avg_val_losses = []
        avg_test_losses = []
        for epoch in self._list_epochs(weights_dir):
            self.model.load_state_dict(torch.load(f'{weights_dir}/model_{epoch}.pt'))
            avg_val_losses.append(self.eval(epoch, self.val_loader))
            if self.test_loader is not None:
                avg_test_losses.append(self.eval(epoch, self.test_loader))

        np.save(f'{self.model_path}/avg_val_losses1.npy', avg_val_losses)
        if self.test_loader is not None:
            np.save(f'{self.model_path}/avg_test_losses1.npy', avg_test_losses)

def list_checkpoint_epochs(weights_dir):
    """Return the sorted epochs of all ``model_<epoch>.pt`` files in `weights_dir`.

    Files whose epoch part is not a plain integer (e.g. ``model_best.pt``) are skipped.
    """
    prefix, suffix = 'model_', '.pt'
    epochs = []
    for name in os.listdir(weights_dir):
        stem = name[len(prefix):-len(suffix)]
        if name.startswith(prefix) and name.endswith(suffix) and stem.isdigit():
            epochs.append(int(stem))
    return sorted(epochs)


def plot_real_prediction(depth_o, depth_n, flow_pred, save_path, max_arrows=500):
    """Save a 2x2 figure for one real depth pair and its predicted flow.

    Panels: observed depth, goal depth, predicted flow image, and a downsampled
    quiver of the prediction. Real data carries no label or cloth mask, so the
    quiver is drawn at every pixel where the prediction is non-zero.

    Args:
        depth_o, depth_n: 2-D depth images, passed straight to ``imshow``.
        flow_pred: predicted flow as an ``(H, W, C)`` numpy array with ``C >= 2``.
            Channel 1 is the horizontal arrow component and channel 0 the vertical one.
        save_path: output image path.
        max_arrows: approximate upper bound on the number of quiver arrows.
    """
    fig, ax = plt.subplots(2, 2, figsize=(8, 8))
    ax[0][0].set_title('depth obs')
    ax[0][0].imshow(depth_o)
    ax[0][1].set_title('depth goal')
    ax[0][1].imshow(depth_n)

    ax[1][0].set_title('pred')
    ax[1][0].imshow(flow2img(flow_pred))

    ys, xs = np.nonzero(np.any(flow_pred != 0, axis=-1))
    step = int(np.ceil(len(ys) / max_arrows)) if len(ys) else 1
    ys, xs = ys[::step], xs[::step]
    ax[1][1].set_title("dsampled pred quiver")
    ax[1][1].imshow(np.zeros(flow_pred.shape[:2]))
    ax[1][1].quiver(xs, ys, flow_pred[ys, xs, 1], flow_pred[ys, xs, 0],
                    alpha=0.8, color='white', angles='xy', scale_units='xy', scale=1)

    plt.tight_layout()
    plt.savefig(save_path, bbox_inches='tight')
    plt.close(fig)


if __name__ == '__main__':
    # Visualise a trained checkpoint's flow predictions on real depth pairs.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', dest="model_path", required=True)
    parser.add_argument('--epoch', type=int, default=None,
                        help='checkpoint epoch to load (default: latest in <model_path>/weights)')
    args = parser.parse_args()
    model_path = args.model_path

    with open(f'{model_path}/config.yaml') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)

    seed = cfg['seed']
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Only env.camera_params is used below.
    env = GCFold(use_depth=True,
                 use_cached_states=False,
                 horizon=5,
                 use_desc=False,
                 action_repeat=1,
                 headless=True)

    real_data = FlowDatasetReal(cfg, env.camera_params)

    # Load trained weights. Previously a randomly initialised network was visualised.
    weights_dir = f'{model_path}/weights'
    epoch = args.epoch
    if epoch is None:
        epochs = list_checkpoint_epochs(weights_dir)
        if not epochs:
            parser.error(f'no model_<epoch>.pt checkpoints found in {weights_dir}')
        epoch = epochs[-1]

    model = FlowNetSmall(input_channels=2).cuda()
    model.load_state_dict(torch.load(f'{weights_dir}/model_{epoch}.pt'))
    model.eval()

    out_dir = f'{model_path}/eval_images'
    os.makedirs(out_dir, exist_ok=True)

    with torch.no_grad():
        for i, (depth_o, depth_n) in enumerate(real_data):
            depth_input = torch.tensor(np.stack([depth_o, depth_n]), dtype=torch.float32, device='cuda')
            # Add a batch dim for the model, then turn the channel-first
            # prediction into an (H, W, C) numpy array for plotting.
            flow_pred = model(depth_input.unsqueeze(0))[0].permute(1, 2, 0).cpu().numpy()
            plot_real_prediction(depth_o, depth_n, flow_pred,
                                 f'{out_dir}/real_epoch{epoch}_{i}.png')
    
