from models import FlowPickSplit, FlowPickNet
from dataset import Dataset
#from dataset_hdf5 import Dataset

import os
import sys
import yaml
import cv2
import matplotlib.pyplot as plt
import numpy as np
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from collections import namedtuple, deque, OrderedDict
import time

from utils import softargmax_coords, plot_flow

#from torchviz import make_dot

# Record type with fields obs, goal, act, rew, nobs, done (not referenced
# elsewhere in this file; presumably a stored transition — confirm with users).
Experience = namedtuple('Experience', 'obs goal act rew nobs done')

class Trainer(object):
    """Train the two-stage ``FlowPickSplit`` pick-point predictor.

    The first network predicts a heatmap for one pick point.  The arg-max of
    that heatmap is rendered as a Gaussian image and fed, together with the
    observation inputs selected by ``cfg['input_mode']``, to the second
    network, which predicts the other pick point.  Each network has its own
    Adam optimizer and its own loss.

    Checkpoints, metric arrays, plots and debug images are written under
    ``{cfg["output_dir"]}/{model_name}``.
    """

    # A heatmap cell (r, c) corresponds to coordinate (r * 10, c * 10), and
    # ground-truth pick coordinates are mapped to cells with ``// 10``.
    PIXELS_PER_CELL = 10

    def __init__(self, config, train_data, test_data=None):
        """Build the data loaders, the run directory and both networks.

        Args:
            config: parsed configuration dict.
            train_data: dataset whose batches unpack as
                ``(obs, nobs, flow, pick1, pick2)``.
            test_data: optional evaluation dataset with the same layout; when
                ``None`` no evaluation is performed.
        """
        self.cfg = config
        # Read the batch size from the passed-in config, not from a
        # module-level ``cfg`` that only exists when run as a script.
        batch = self.cfg['batch']
        self.train_loader = DataLoader(train_data, batch_size=batch, shuffle=True,
                                       num_workers=0, persistent_workers=False)
        self.test_loader = None
        if test_data is not None:
            self.test_loader = DataLoader(test_data, batch_size=batch, shuffle=True,
                                          num_workers=0, persistent_workers=False)

        self.create_dirs()
        self.init_model()

        # Not referenced elsewhere in this file; kept for external users.
        self.di = 0
        self.ups = nn.Upsample(size=(200, 200), mode='bilinear', align_corners=True)

    def init_model(self):
        """Create both ``FlowPickSplit`` networks on the GPU and their Adam optimizers.

        The first network gets 2 input channels and the second 3 (passed as the
        first constructor argument); both use ``cfg['lr']``.
        """
        assert self.cfg['model_type'] == 'split'
        self.first = FlowPickSplit(2, self.cfg["im_width"]).cuda()
        self.second = FlowPickSplit(3, self.cfg["im_width"], second=True).cuda()
        self.opt1 = torch.optim.Adam(self.first.parameters(), lr=self.cfg['lr'])
        self.opt2 = torch.optim.Adam(self.second.parameters(), lr=self.cfg['lr'])

    def create_dirs(self):
        """Derive ``self.model_name`` from the config and create the run directory.

        Creates ``images/``, ``evals/`` and ``weights/`` sub-directories and
        writes the config to ``config.yaml``.  In debug mode an existing run
        directory of the same name is deleted first; otherwise an existing run
        directory raises ``FileExistsError`` so earlier runs are not clobbered.
        """
        c = self.cfg
        # NOTE(review): there is no "_" between the fla{...} and aug_{...}
        # segments; kept unchanged so names of existing runs still match.
        self.model_name = (f'dual_{c["model_type"]}_{c["run_name"]}'
                           f'_l{c["loss_type"]}_ep{c["epochs"]}'
                           f'_lr{c["lr"]}_s{c["seed"]}'
                           f'_b{c["batch"]}_fla{int(c["flow_align"])}'
                           f'aug_{int(c["augment"])}_fl{c["flow"]}'
                           f'_{c["model_suffix"]}')
        run_dir = self._run_dir()
        if c["debug"] and os.path.exists(run_dir):
            import shutil
            shutil.rmtree(run_dir)
        # makedirs also creates a missing output_dir, but (exist_ok=False)
        # still refuses to reuse an existing run directory.
        os.makedirs(run_dir)
        for sub in ('images', 'evals', 'weights'):
            os.mkdir(f'{run_dir}/{sub}')
        with open(f'{run_dir}/config.yaml', 'w') as f:
            yaml.dump(self.cfg, f)

    def _run_dir(self):
        """Return the output directory of this run."""
        return f'{self.cfg["output_dir"]}/{self.model_name}'

    def get_pt(self, logits):
        """Return the coordinates of the highest-probability heatmap cell.

        Args:
            logits: tensor whose size(2) is the grid width ``W``; it is viewed as
                ``(N, 1, W*W)``, so it is assumed to be ``(N, 1, W, W)``.

        Returns:
            ``(u, v)``: integer tensors of shape ``(N,)`` with the row and column
            of the arg-max cell, each multiplied by ``PIXELS_PER_CELL``.
        """
        n = logits.size(0)
        w = logits.size(2)
        probs = torch.sigmoid(logits).view(n, 1, w * w)
        _, idx = torch.max(probs[:, 0], 1)
        # Unflatten with the actual grid width rather than a fixed 20.
        u = (idx // w) * self.PIXELS_PER_CELL
        v = (idx % w) * self.PIXELS_PER_CELL
        return u, v

    def get_gaussian(self, u, v, sigma=5, size=None):
        """Render one 2-D Gaussian per sample, min-max normalised to [0, 1].

        Args:
            u: ``(N,)`` row centres (indexes dim 2 of the output).
            v: ``(N,)`` column centres (indexes dim 3 of the output).
            sigma: standard deviation in grid cells.
            size: grid width; defaults to ``cfg["im_width"]``.

        Returns:
            Tensor of shape ``(N, 1, size, size)`` on ``u``'s device.  A map with
            no variation (e.g. ``size == 1``) is returned as all zeros instead of
            NaN.
        """
        if size is None:
            size = self.cfg["im_width"]

        grid = torch.arange(size, device=u.device).float()[None, :]
        denom = 2 * sigma ** 2
        gx = torch.exp(-(grid - u[:, None]) ** 2 / denom)
        gy = torch.exp(-(grid - v[:, None]) ** 2 / denom)
        g = torch.einsum('ni,no->nio', gx, gy)

        gmin = g.amin(dim=(1, 2))[:, None, None]
        span = g.amax(dim=(1, 2))[:, None, None] - gmin
        # Avoid 0/0 on a flat map: the numerator is zero there, so dividing by
        # one yields an all-zero map.
        span = torch.where(span > 0, span, torch.ones_like(span))
        return ((g - gmin) / span).unsqueeze(1)

    def calc_loss(self, logits1, logits2, pick1, pick2):
        """Compute the sigmoid BCE loss of each network against Gaussian labels.

        Labels are Gaussians (sigma 2 cells) centred on ``pick // PIXELS_PER_CELL``
        on a grid as wide as the logits.  With ``cfg['min_loss']`` the two
        networks are matched to the two labels per sample (1->a, 2->b or
        1->b, 2->a), whichever gives the smaller summed loss, so the loss does
        not depend on pick order.

        Args:
            logits1, logits2: ``(N, 1, W, W)`` logits of the two networks.
            pick1, pick2: ``(N, 2)`` pick coordinates (column 0 = row,
                column 1 = column), on any device.

        Returns:
            ``(loss1, loss2)`` scalar tensors; ``loss1`` depends only on
            ``logits1`` and ``loss2`` only on ``logits2``.
        """
        assert self.cfg['loss_type'] == 'sigmoid'

        w = logits1.size(2)
        cell = self.PIXELS_PER_CELL
        pick1 = pick1.to(logits1.device)
        pick2 = pick2.to(logits1.device)
        label_a = self.get_gaussian(pick1[:, 0] // cell, pick1[:, 1] // cell, sigma=2, size=w)
        label_b = self.get_gaussian(pick2[:, 0] // cell, pick2[:, 1] // cell, sigma=2, size=w)

        if not self.cfg['min_loss']:
            loss1 = F.binary_cross_entropy_with_logits(logits1, label_a)
            loss2 = F.binary_cross_entropy_with_logits(logits2, label_b)
            return loss1, loss2

        def per_sample(logits, label):
            # Mean BCE over C, H, W, one value per sample.
            bce = F.binary_cross_entropy_with_logits(logits, label, reduction='none')
            return bce.mean(dim=(1, 2, 3))

        loss_1a = per_sample(logits1, label_a)
        loss_1b = per_sample(logits1, label_b)
        loss_2a = per_sample(logits2, label_a)
        loss_2b = per_sample(logits2, label_b)

        keep_order = (loss_1a + loss_2b) < (loss_1b + loss_2a)
        loss1 = torch.where(keep_order, loss_1a, loss_1b).mean()
        loss2 = torch.where(keep_order, loss_2b, loss_2a).mean()
        return loss1, loss2

    def _make_inputs(self, obs, nobs, flow, extra=None):
        """Concatenate network inputs along dim 1 according to ``cfg['input_mode']``.

        ``'noflow'`` uses ``obs, nobs``; ``'obsflow'`` uses ``obs, flow``; any
        other mode uses ``flow`` alone.  ``extra`` (if given) is appended last.
        """
        mode = self.cfg['input_mode']
        if mode == 'noflow':
            parts = [obs, nobs]
        elif mode == 'obsflow':
            parts = [obs, flow]
        else:
            parts = [flow]
        if extra is not None:
            parts.append(extra)
        return torch.cat(parts, dim=1)

    def _forward(self, obs, nobs, flow):
        """Run both stages; return ``(logits1, u1, v1, logits2)``.

        The second network receives a Gaussian rendered at the first network's
        arg-max cell.  That point is an integer index, so no gradient flows
        from the second stage into the first.
        """
        logits1 = self.first(self._make_inputs(obs, nobs, flow))
        u1, v1 = self.get_pt(logits1)
        pick1_gau = self.get_gaussian(u1, v1)
        logits2 = self.second(self._make_inputs(obs, nobs, flow, pick1_gau))
        return logits1, u1, v1, logits2

    def _save_weights(self, step):
        """Save both networks' and both optimizers' state dicts tagged with ``step``."""
        wdir = f'{self._run_dir()}/weights'
        torch.save(self.first.state_dict(), f'{wdir}/first_{step}.pt')
        torch.save(self.second.state_dict(), f'{wdir}/second_{step}.pt')
        torch.save(self.opt1.state_dict(), f'{wdir}/opt1_{step}.pt')
        torch.save(self.opt2.state_dict(), f'{wdir}/opt2_{step}.pt')

    def _save_curves(self, tr_losses, tr_metrics, ev_losses, ev_metrics):
        """Write the recorded loss/metric rows to ``evals/*.npy``."""
        edir = f'{self._run_dir()}/evals'
        np.save(f'{edir}/tr_losses.npy', tr_losses)
        np.save(f'{edir}/tr_metrics.npy', tr_metrics)
        np.save(f'{edir}/ev_losses.npy', ev_losses)
        np.save(f'{edir}/ev_metrics.npy', ev_metrics)

    def train(self):
        """Run the training loop with periodic eval, logging and checkpoints.

        Every ``eval_freq`` steps the test set is evaluated (if there is one).
        Every ``log_freq`` steps the running training losses and metric are
        averaged and recorded.  Every ``save_freq`` steps weights, optimizer
        state and the recorded curves are saved.

        After the loop:
          * the final weights are saved if the last step was not already
            checkpointed;
          * the test set is evaluated with debug images (if present);
          * the curves are saved and plotted.
        """
        tr_losses_list, tr_metrics_list = [], []
        ev_losses_list, ev_metrics_list = [], []
        running_loss1, running_loss2, running_m = [], [], []

        start = time.time()
        step = 0
        epoch = 0

        self.first.train()
        self.second.train()

        # NOTE(review): this makes cfg['epochs'] + 1 passes over the data;
        # confirm the extra pass is intended.
        for epoch in range(self.cfg['epochs'] + 1):
            for i, (obs, nobs, flow, pick1, pick2) in enumerate(self.train_loader):
                obs, nobs, flow = obs.cuda(), nobs.cuda(), flow.cuda()

                logits1, u1, v1, logits2 = self._forward(obs, nobs, flow)
                loss1, loss2 = self.calc_loss(logits1, logits2, pick1, pick2)

                # The two losses have disjoint graphs, so each network is
                # updated independently by its own optimizer.
                self.opt1.zero_grad()
                loss1.backward()
                self.opt1.step()

                self.opt2.zero_grad()
                loss2.backward()
                self.opt2.step()

                # running train metrics
                u2, v2 = self.get_pt(logits2)
                m = self.calc_metrics(u1, v1, u2, v2, pick1, pick2)
                running_loss1.append(loss1.item())
                running_loss2.append(loss2.item())
                running_m.append(m.item())

                print(f"Step: {step} Ep: {epoch} batch: {i} loss_1: {loss1.item()} loss_2: {loss2.item()}")

                if step % self.cfg['eval_freq'] == 0 and self.test_loader is not None:
                    eval_loss_1, eval_loss_2, eval_m = self.eval(epoch, step)
                    ev_losses_list.append([step, eval_loss_1, eval_loss_2])
                    ev_metrics_list.append([step, eval_m])

                if step % self.cfg['log_freq'] == 0:
                    tr_losses_list.append([step, np.mean(running_loss1), np.mean(running_loss2)])
                    tr_metrics_list.append([step, np.mean(running_m)])
                    running_loss1, running_loss2, running_m = [], [], []

                if step % self.cfg['save_freq'] == 0:
                    self._save_weights(step)
                    self._save_curves(tr_losses_list, tr_metrics_list, ev_losses_list, ev_metrics_list)

                step += 1

        # Keep the final weights even when the last step was not a save step.
        last_step = step - 1
        if last_step >= 0 and last_step % self.cfg['save_freq'] != 0:
            self._save_weights(last_step)

        # Final eval is only possible when a test set was provided.
        if self.test_loader is not None:
            eval_loss_1, eval_loss_2, eval_m = self.eval(epoch, step, viz=True)
            ev_losses_list.append([step, eval_loss_1, eval_loss_2])
            ev_metrics_list.append([step, eval_m])

        self._save_curves(tr_losses_list, tr_metrics_list, ev_losses_list, ev_metrics_list)
        self.plot(np.array(tr_losses_list), np.array(tr_metrics_list),
                  np.array(ev_losses_list), np.array(ev_metrics_list))
        print("total time:", time.time() - start)

    def eval(self, epoch, step, viz=False, viz_freq=500):
        """Evaluate both networks on the test set without gradients.

        Args:
            epoch, step: used only for logging and debug-image names.
            viz: when true, write debug images for every ``viz_freq``-th batch.
            viz_freq: batch interval for debug images.

        Returns:
            ``(mean loss1, mean loss2, mean metric)``, each averaged over batches.
            Both networks are put back in train mode afterwards, even on error.
        """
        print("Running eval...")

        self.first.eval()
        self.second.eval()

        eval_losses_1, eval_losses_2, eval_ms = [], [], []
        # len(loader) also counts the final partial batch.
        eval_batches = len(self.test_loader)

        print("eval size:", len(self.test_loader.dataset))
        try:
            with torch.no_grad():
                for i, (obs, nobs, flow, pick1, pick2) in enumerate(self.test_loader):
                    obs, nobs, flow = obs.cuda(), nobs.cuda(), flow.cuda()

                    logits1, u1, v1, logits2 = self._forward(obs, nobs, flow)
                    loss1, loss2 = self.calc_loss(logits1, logits2, pick1, pick2)
                    u2, v2 = self.get_pt(logits2)
                    m = self.calc_metrics(u1, v1, u2, v2, pick1, pick2)

                    eval_ms.append(m.item())
                    eval_losses_1.append(loss1.item())
                    eval_losses_2.append(loss2.item())

                    print(f"eval: {i}/{eval_batches} loss: {loss1.item()} {loss2.item()} m: {m.item()}")

                    if viz and i % viz_freq == 0:
                        self.debug_viz(logits1, pick1, logits2, pick2, obs, nobs, flow, step, i)
        finally:
            self.first.train()
            self.second.train()

        avg_eval_loss_1 = np.mean(eval_losses_1)
        avg_eval_loss_2 = np.mean(eval_losses_2)
        avg_eval_m = np.mean(eval_ms)
        print(f"Eval at step: {step} Epoch: {epoch} avg loss: {avg_eval_loss_1}, {avg_eval_loss_2} avg metric: {avg_eval_m}")

        return avg_eval_loss_1, avg_eval_loss_2, avg_eval_m

    def calc_metrics(self, u1, v1, u2, v2, pick1, pick2):
        """Return the mean pick-point distance under the better assignment.

        For each sample, the Euclidean distances are summed for the two
        prediction-to-label pairings (1->pick1, 2->pick2 and 1->pick2,
        2->pick1); the smaller sum is kept.  The result is averaged over the
        batch and halved, giving the mean distance per pick, in the same units
        as ``pick1``/``pick2``.

        Returns:
            A 0-dim float tensor on ``u1``'s device.
        """
        device = u1.device
        pick1 = pick1.float().to(device)
        pick2 = pick2.float().to(device)
        u1, v1 = u1.float(), v1.float()
        u2, v2 = u2.float().to(device), v2.float().to(device)

        def dist(u, v, pick):
            return torch.sqrt((u - pick[:, 0]) ** 2 + (v - pick[:, 1]) ** 2)

        in_order = dist(u1, v1, pick1) + dist(u2, v2, pick2)
        swapped = dist(u1, v1, pick2) + dist(u2, v2, pick1)
        return torch.minimum(in_order, swapped).mean() / 2

    def plot(self, tr_losses, tr_metrics, ev_losses, ev_metrics):
        """Save loss (log scale) and metric curves to ``evals/train.png``.

        Each argument holds rows ``[step, value, ...]`` as built by ``train``.
        Empty inputs (e.g. no test set) are skipped instead of raising.
        """
        def rows(a):
            a = np.asarray(a)
            return a if a.ndim == 2 and a.shape[0] > 0 else None

        tr_losses, tr_metrics = rows(tr_losses), rows(tr_metrics)
        ev_losses, ev_metrics = rows(ev_losses), rows(ev_metrics)

        fig, ax = plt.subplots(1, 2, figsize=(10, 3))
        ax[0].set_title("loss")
        ax[0].set_xlabel("step")
        ax[0].set_ylabel("loss (log scale)")
        if tr_losses is not None:
            ax[0].plot(tr_losses[:, 0], tr_losses[:, 1], label='train_1')
            ax[0].plot(tr_losses[:, 0], tr_losses[:, 2], label='train_2')
        if ev_losses is not None:
            ax[0].plot(ev_losses[:, 0], ev_losses[:, 1], label='val_1')
            ax[0].plot(ev_losses[:, 0], ev_losses[:, 2], label='val_2')
        ax[0].set_yscale('log')
        ax[0].legend()

        ax[1].set_title("metric")
        ax[1].set_xlabel("step")
        ax[1].set_ylabel("avg dist")
        if tr_metrics is not None:
            ax[1].plot(tr_metrics[:, 0], tr_metrics[:, 1], label='tr m')
        if ev_metrics is not None:
            ax[1].plot(ev_metrics[:, 0], ev_metrics[:, 1], label='val m')
        ax[1].legend()

        plt.tight_layout()
        plt.savefig(f'{self._run_dir()}/evals/train.png')
        plt.close()

    @staticmethod
    def _draw_pick(img, pred, gt):
        """Draw ``gt`` as a green star and ``pred`` as a circle on ``img`` in place.

        Both points are ``(row, col)``; OpenCV expects ``(x, y) = (col, row)``.
        """
        cv2.drawMarker(img, (gt[1], gt[0]), (0, 200, 0), markerType=cv2.MARKER_STAR,
                       markerSize=10, thickness=2, line_type=cv2.LINE_AA)
        cv2.circle(img, (pred[1], pred[0]), 6, (0, 0, 200), 2)

    def debug_viz(self, logits1, pick1, logits2, pick2, obs, nobs, flow, step=None, ev_step=None):
        """Save per-sample input and heatmap figures for one evaluation batch.

        For each sample two files are written to ``images/``:
          * ``st_{step}_id_{k}_inputs.png``: obs with ground-truth picks as
            stars and predictions as circles, nobs, and the two flow channels;
          * ``st_{step}_id_{k}_debug_outputs.png``: logits, sigmoid heatmaps
            and one-hot cell labels for both networks.

        Here ``k = ev_step * cfg['batch'] + i``.  Using the loader batch size
        keeps ids unique even for a smaller final batch; ``ev_step=None`` is
        treated as batch 0.
        """
        n = logits1.size(0)
        w = logits1.size(2)
        cell = self.PIXELS_PER_CELL

        u1s, v1s = self.get_pt(logits1)
        u2s, v2s = self.get_pt(logits2)
        probs1 = torch.sigmoid(logits1)
        probs2 = torch.sigmoid(logits2)

        id_base = (ev_step or 0) * self.cfg['batch']
        img_dir = f'{self._run_dir()}/images'

        for i in range(n):
            img = cv2.cvtColor(obs[i, 0].detach().cpu().numpy(), cv2.COLOR_GRAY2RGB)
            nimg = cv2.cvtColor(nobs[i, 0].detach().cpu().numpy(), cv2.COLOR_GRAY2RGB)

            # Plain ints: numpy indexing and OpenCV points reject float tensors.
            gt1 = (int(pick1[i, 0]), int(pick1[i, 1]))
            gt2 = (int(pick2[i, 0]), int(pick2[i, 1]))

            label1 = np.zeros((w, w))
            label1[gt1[0] // cell, gt1[1] // cell] = 1
            label2 = np.zeros((w, w))
            label2[gt2[0] // cell, gt2[1] // cell] = 1

            self._draw_pick(img, (int(u1s[i]), int(v1s[i])), gt1)
            self._draw_pick(img, (int(u2s[i]), int(v2s[i])), gt2)

            sample_id = id_base + i

            # inputs viz
            fig, ax = plt.subplots(1, 4, figsize=(6, 2))
            fig.suptitle('inputs')
            ax[0].imshow(img)
            ax[0].set_title('obs')
            ax[1].imshow(nimg)
            ax[1].set_title('nobs')
            ax[2].imshow(flow[i, 0].detach().cpu().numpy(), interpolation='none')
            ax[2].set_title('flow')
            ax[3].imshow(flow[i, 1].detach().cpu().numpy(), interpolation='none')
            ax[3].set_title('flow')
            plt.tight_layout()
            plt.savefig(f'{img_dir}/st_{step}_id_{sample_id}_inputs.png')
            plt.close()

            # outputs viz
            fig, ax = plt.subplots(2, 3)
            fig.suptitle('outputs')
            ax[0, 0].imshow(logits1[i, 0].detach().cpu().numpy())
            ax[0, 0].set_title('logits 1')
            ax[0, 1].imshow(probs1[i, 0].detach().cpu().numpy())
            ax[0, 1].set_title('heatmap 1')
            ax[0, 2].imshow(label1)
            ax[0, 2].set_title('label 1')

            ax[1, 0].imshow(logits2[i, 0].detach().cpu().numpy())
            ax[1, 0].set_title('logits 2')
            ax[1, 1].imshow(probs2[i, 0].detach().cpu().numpy())
            ax[1, 1].set_title('heatmap 2')
            ax[1, 2].imshow(label2)
            ax[1, 2].set_title('label 2')

            plt.tight_layout()
            plt.savefig(f'{img_dir}/st_{step}_id_{sample_id}_debug_outputs.png')
            plt.close()


if __name__ == '__main__':
    torch.cuda.empty_cache()
    with open('config.yaml') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)

    # Seed numpy and torch (CPU and CUDA) from the config.
    seed = cfg['seed']
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Index buffers live at <dataset_path>/<name>/<name>_idx.buf and are
    # truncated to at most cfg['max_buf'] entries.  torch.load unpickles the
    # file, so only point dataset_path at trusted data.
    buf_train, buf_test = (
        torch.load(f"{cfg['dataset_path']}/{name}/{name}_idx.buf")[:cfg['max_buf']]
        for name in (cfg['run_name'], cfg['test_name'])
    )

    # Runs whose name contains 'pickplace' enable the dataset's pick_place flag.
    pick_place = 'pickplace' in cfg['run_name']

    train_data, test_data = (
        Dataset(cfg, buf, mode=mode, pick_pt=cfg['pick'], pick_place=pick_place)
        for buf, mode in ((buf_train, 'train'), (buf_test, 'test'))
    )

    tr = Trainer(cfg, train_data, test_data)
    tr.train()