import numpy as np
import matplotlib.pyplot as plt
from models import FlowPickSplit, FlowPickNet, FlowNetSmall
import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from collections import namedtuple
from dataset import Dataset, Goals
from torch.utils.data import DataLoader
import cv2
from softgym.envs.bimanual_env import BimanualEnv
from softgym.envs.bimanual_tshirt import BimanualTshirtEnv
from utils import remove_dups, flow_afwarp, avg_flow_angle, get_harris, plot_flow
from flow import Flow
import pyflex
import os
import time 
#from edge_masker import EdgeMasker
from copy import deepcopy
import softgym.envs.tshirt_descriptor as td
#from raft import RAFT
#import torch
#import argparse


# One transition record: observation, goal, action, reward, next observation
# and the episode-termination flag.
Experience = namedtuple('Experience', 'obs goal act rew nobs done')

def ccw(A,B,C):
    """Orientation test for the ordered triple of 2-D points (A, B, C).

    Compares the two cross-product terms of AB and AC and returns True when
    ``(C[1]-A[1])*(B[0]-A[0])`` is strictly greater than
    ``(B[1]-A[1])*(C[0]-A[0])``. Collinear triples therefore return False.
    """
    origin_x, origin_y = A[0], A[1]
    lhs = (C[1] - origin_y) * (B[0] - origin_x)
    rhs = (B[1] - origin_y) * (C[0] - origin_x)
    return lhs > rhs

def intersect(A,B,C,D):
    """Return whether segment AB crosses segment CD.

    The segments cross when A and B fall on different sides of CD and C and D
    fall on different sides of AB, using ``ccw`` as the side test.
    """
    ab_split_by_cd = ccw(A,C,D) != ccw(B,C,D)
    # The second side test is only evaluated when the first one succeeds.
    return ab_split_by_cd and ccw(A,B,C) != ccw(A,B,D)


class EnvRollout(object):
    def __init__(self, cfgs):
        """Build the simulation environment, load the networks and prepare outputs.

        Args:
            cfgs: configuration mapping. Keys read here: 'seed', 'cloth_type',
                'img_type', 'horizon', 'headless', 'goals', 'output_dir' and
                'run_name'. ``load_model`` and ``Goals`` read further keys.

        Raises:
            ValueError: if ``cfgs['cloth_type']`` is neither 'towel' nor 'tshirt'.
        """
        self.cfgs = cfgs

        # Seed every RNG this script uses so rollouts are repeatable.
        seed = cfgs['seed']
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)

        cloth_type = cfgs['cloth_type']
        use_depth = cfgs['img_type'] == 'depth'
        if cloth_type == 'towel':
            self.env = BimanualEnv(use_depth=use_depth,
                    use_cached_states=False,
                    horizon=cfgs['horizon'],
                    use_desc=False,
                    action_repeat=1,
                    headless=cfgs['headless'],
                    rect=(cfgs['goals']=='rect'))
        elif cloth_type == 'tshirt':
            self.env = BimanualTshirtEnv(use_depth=use_depth,
                    use_cached_states=False,
                    use_desc=False,
                    horizon=cfgs['horizon'],
                    action_repeat=1,
                    headless=cfgs['headless'])
        else:
            # Fail fast: without an environment every later call would die with
            # an unrelated AttributeError on self.env.
            raise ValueError(f"unsupported cloth_type: {cloth_type!r} (expected 'towel' or 'tshirt')")

        self.load_model()

        goal_data = Goals(self.cfgs, mode=cfgs['goals'])
        self.goal_loader = DataLoader(goal_data, batch_size=1, shuffle=False, num_workers=0)

        # makedirs(exist_ok=True) also creates missing parents and guarantees the
        # 'unscaled' sub-directory exists even when save_dir was created earlier.
        self.save_dir = f'{self.cfgs["output_dir"]}/{self.cfgs["run_name"]}/rollout'
        os.makedirs(f'{self.save_dir}/unscaled', exist_ok=True)

        self.ups = nn.Upsample(size=(200,200), mode='bilinear', align_corners=True)

        # NOTE: self.corners is intentionally not initialised here (the call to
        # get_corner_particles is disabled); get_true_corner_mask needs it to be
        # set by the caller first.
        #self.em = EdgeMasker(self.env, self.cfgs['cloth_type'], tshirtmap_path=None, edgethresh=cfgs['edgethresh'])
        #self.corners = self.get_corner_particles()


    def get_corner_particles(self):
        """Return the indices of the particles closest to four fixed image points.

        The environment is reset, particle pixel coordinates are fetched, their
        two columns are swapped and the values are rescaled by 199/719. For each
        of the targets (25,25), (25,175), (175,175) and (175,25) the index of
        the nearest particle is returned, in that order, as a 4-tuple.
        """
        # Reset for its side effect only; the returned state is not used.
        self.env.reset(given_goal=None, given_goal_pos=None)

        pix = td.particle_uv_pos(self.env.camera_params,None)
        pix[:,[1,0]] = pix[:,[0,1]]  # swap the two coordinate columns in place
        pix = (pix/719) * 199

        targets = ([25,25], [25,175], [175,175], [175,25])
        nearest = [
            np.linalg.norm(pix - np.array(target), axis=1).argmin()
            for target in targets
        ]
        return tuple(nearest)

    def get_particle_uv(self,idx):
        """Return the rescaled pixel coordinates ``(u, v)`` of particle ``idx``.

        Coordinates come from ``td.particle_uv_pos``; the two columns are
        swapped and the values rescaled by 199/719 before indexing.
        """
        pix = td.particle_uv_pos(self.env.camera_params,None)
        pix[:,[1,0]] = pix[:,[0,1]]  # swap the two coordinate columns in place
        scaled = (pix/719) * 199
        u,v = scaled[idx]
        return u,v


    def get_true_corner_mask(self, clothmask, r=4):
        true_corners = np.zeros((200,200))
        for c in self.corners:
            b,a = self.get_particle_uv(c)
            h,w = true_corners.shape
            y,x = np.ogrid[-a:h-a, -b:w-b]
            cmask = x*x + y*y <= r*r
            true_corners[cmask] = 1

        true_corners = true_corners * clothmask
        return true_corners


    def load_model(self):
        """Load the pick networks, the optional place networks and the flow net.

        Weights are read from
        ``<output_dir>/<run>/weights/{first,second}_<load_iter>.pt``. The place
        networks are loaded from the run named by ``cfgs['placenet']`` when that
        entry is not the empty string; ``self.use_placenet`` records the choice.
        The flow network is restored from ``./flow_model/<cfgs['flow']>``. Every
        network is moved to CUDA and put in eval mode.
        """
        load_iter = self.cfgs['load_iter']
        # 'obsflow' feeds observation + 2-channel flow; other modes use 2 channels.
        chan = 3 if self.cfgs['input_mode'] == 'obsflow' else 2

        def restore(run, prefix, in_chan, **extra):
            # Build one FlowPickSplit, load its checkpoint and switch to eval.
            weights = f"{self.cfgs['output_dir']}/{run}/weights/{prefix}_{load_iter}.pt"
            net = FlowPickSplit(in_chan, self.cfgs["im_width"], **extra).cuda()
            net.load_state_dict(torch.load(weights))
            net.eval()
            return net

        run_name = self.cfgs['run_name']
        self.first = restore(run_name, 'first', chan)
        # The second network takes one extra channel.
        self.second = restore(run_name, 'second', chan+1, second=True)

        self.use_placenet = self.cfgs['placenet'] != ''
        if self.use_placenet:
            placenet = self.cfgs['placenet']
            self.place1 = restore(placenet, 'first', chan)
            self.place2 = restore(placenet, 'second', chan+1, second=True)

        # flow model
        self.flw = FlowNetSmall(input_channels=2).cuda()
        checkpt = torch.load(f'./flow_model/{self.cfgs["flow"]}')
        self.flw.load_state_dict(checkpt['state_dict'])
        self.flw.eval()


    def get_obs_tensor(self):
        """Return the current depth image as a float tensor of shape (1, 1, 200, 200)."""
        _, raw_depth = self.env.get_rgbd()
        small = cv2.resize(raw_depth, (200, 200))
        # Two leading singleton axes are added in front of the 200x200 image.
        return torch.FloatTensor(small)[None, None]


    def get_rgbd(self):
        """Return ``(img, depth, mask)`` for the current simulator frame.

        ``depth`` is the fourth channel of the vertically flipped sensor
        render, ``mask`` marks pixels with depth > 0, and ``img`` comes from
        ``self.env.get_image``.
        """
        height, width = self.env.camera_height, self.env.camera_width
        frame = np.array(pyflex.render_sensor()).reshape(height, width, 4)
        frame = frame[::-1, :, :]  # flip rows
        img = self.env.get_image(height, width)
        depth = frame[:, :, 3]
        return img, depth, depth > 0


    def run(self):
        """Roll out the loaded policy on every entry of ``self.goal_loader``.

        For each batch the environment is reset, the particle positions are
        overwritten with the batch's starting coordinates, and one
        pick-and-place action is executed per sub-goal. The depth image before
        and after every action and an annotated colour image are written under
        ``self.save_dir``; all executed actions are saved to ``actions.npy``.

        Returns:
            (total_metrics, normed_metrics): two lists with one entry per
            goal, holding the two values returned by ``calc_metric``.
        """
        actions = []
        total_metrics = []
        normed_metrics = []
        for i, b in enumerate(self.goal_loader):
            start, goals, coords_o, coords_n = b
            #start, goal, flow, pt1, pt2, coords_o, coords_n = b

            # Reset, then force the simulator into the starting particle state.
            state = self.env.reset(given_goal=None, given_goal_pos=None)
            pyflex.set_positions(coords_o[0])
            pyflex.step()
            done = False
            
            # One action per sub-goal along axis 1 of `goals`.
            for j in range(goals.shape[1]):
                goal = goals[0,j].unsqueeze(0)
                self.env.render(mode='rgb_array')
                obs = self.get_obs_tensor()
                
                # Keep the depth image as returned by the env for later inspection.
                _, unsdepth = self.env.get_rgbd()
                np.save(f'{self.save_dir}/unscaled/{i}_{j}_obs.npy', unsdepth)

                # Flow network input: observation and goal stacked on the channel axis.
                inp = torch.cat([obs,goal],dim=1)
                flow = self.flw(inp.cuda())
                #_, flow = self.flw.module(obs.cuda(), goal.cuda(), test_mode=True)

                # mask flow
                # Zero both flow channels wherever the observed depth is exactly 0.
                flow[0,0,:,:][obs[0,0,:,:] == 0] = 0
                flow[0,1,:,:][obs[0,0,:,:] == 0] = 0

                action, unmasked_pred = self.get_action(obs, goal, flow)
                actions.append(action)
                next_state, reward, done, _ = self.env.step(action, pickplace=True, on_table=self.cfgs['on_table'])
                self.env.render(mode='rgb_array')
                
                # Draw the executed action on the post-action colour image.
                img = cv2.cvtColor(next_state["color"], cv2.COLOR_RGB2BGR)
                img = self.action_viz(img, action, unmasked_pred)

                if self.cfgs['unfold']:
                    cv2.imwrite(f'{self.save_dir}/{i}_{j}_unfold.png',img)
                else:
                    cv2.imwrite(f'{self.save_dir}/{i}_{j}_fold.png',img)

                _, unsdepth = self.env.get_rgbd()
                np.save(f'{self.save_dir}/unscaled/{i}_{j}_nobs.npy', unsdepth)

            # Compare the final particle positions with the last entry of coords_n.
            # NOTE(review): presumably coords_n is a sequence of per-step goal
            # particle coordinates — confirm against the Goals dataset.
            metrics,normed = self.calc_metric(start=coords_o[0,:,:3], goal=coords_n[-1][0,:,:3])
            total_metrics.append(metrics)
            normed_metrics.append(normed)
            print(f"goal {i}: {metrics},  {normed} (norm)")

        print("configs \n")
        print(self.cfgs)

        print("\nmean metrics: ",np.mean(total_metrics))
        print("normalized metrics: ",np.mean(normed_metrics))

        np.save(f"{self.save_dir}/actions.npy",actions)
        return total_metrics, normed_metrics


    def get_gaussian(self, u, v, sigma=5, size=None):
        if size is None:
            size = self.cfgs["im_width"]

        x0, y0 = torch.Tensor([u]).cuda(), torch.Tensor([v]).cuda()
        x0 = x0[:, None]
        y0 = y0[:, None]

        N = 1
        num = torch.arange(size).float()
        x, y = torch.vstack([num]*N).cuda(), torch.vstack([num]*N).cuda()
        gx = torch.exp(-(x-x0)**2/(2*sigma**2))
        gy = torch.exp(-(y-y0)**2/(2*sigma**2))
        g = torch.einsum('ni,no->nio', gx, gy)

        gmin = g.amin(dim=(1,2))
        gmax = g.amax(dim=(1,2))
        g = (g - gmin[:,None,None])/(gmax[:,None,None] - gmin[:,None,None])
        g = g.unsqueeze(1)

        if False:
            import matplotlib.pyplot as plt
            for i in range(g.shape[0]):
                plt.imshow(g[i].squeeze().detach().cpu().numpy())
                plt.show()

        return g


    def get_pt(self, logits, min_r=3):
        # select 2 pts with NMS
        N = logits.size(0)
        W = logits.size(2)

        if self.cfgs['prob_type'] == 'sigmoid':
            probs = torch.sigmoid(logits)
            probs = probs.view(N,1,W*W)
        else:
            probs = F.softmax(logits.flatten(-2), -1)

        val,idx = torch.max(probs[:,0], 1)
        u = (idx // 20) * 10
        v = (idx % 20) * 10

        return u.item(),v.item()


    def nearest_to_mask(self, u, v, depth):
        mask_idx = np.argwhere(depth)
        nearest_idx = mask_idx[((mask_idx - [u,v])**2).sum(1).argmin()]

        return nearest_idx


    def get_action(self, obs, nobs, flow):
        """Predict a two-arm pick-and-place action from observation, goal and flow.

        The network input depends on ``cfgs['input_mode']``: 'noflow' stacks
        ``obs`` and ``nobs``, 'obsflow' stacks ``obs`` and ``flow``, anything
        else uses ``flow`` alone. The first network predicts pick 1; the second
        network receives the same input plus a Gaussian heatmap at pick 1 and
        predicts pick 2. Both picks are snapped to the nearest non-zero pixel
        of ``obs[0, 0]``. Place points come from the place networks when
        ``self.use_placenet`` is set (swapped if the two pick->place segments
        cross), otherwise from the flow at the snapped pick points.

        All forward passes run under ``torch.no_grad()`` since this is pure
        inference. Setting ``cfgs['debug_viz']`` to a truthy value shows debug
        figures; it is off by default.

        Args:
            obs: current observation tensor; channel 0 of item 0 is the depth image.
            nobs: goal observation tensor.
            flow: flow tensor; item 0 holds the two flow channels.

        Returns:
            (action, preds): ``action`` is ``np.array([pick1, place1, pick2,
            place2])`` and ``preds`` is ``np.array([pred_1, pred_2])`` with the
            raw pick predictions before snapping to the mask.
        """
        obs = obs.cuda()
        nobs = nobs.cuda()
        flow = flow.cuda()

        input_mode = self.cfgs['input_mode']
        if input_mode == 'noflow':
            base_inputs = [obs, nobs]
        elif input_mode == 'obsflow':
            base_inputs = [obs, flow]
        else:
            base_inputs = [flow]

        def with_heatmap(u, v):
            # Input for a second-stage network: base channels plus a Gaussian
            # heatmap marking the first-stage point.
            return torch.cat(base_inputs + [self.get_gaussian(u, v)], dim=1)

        with torch.no_grad():
            x1 = torch.cat(base_inputs, dim=1)

            # pick points
            logits1 = self.first(x1)
            pick_u1, pick_v1 = self.get_pt(logits1)
            logits2 = self.second(with_heatmap(pick_u1, pick_v1))
            pick_u2, pick_v2 = self.get_pt(logits2)

            # Snap both predictions onto the observed (non-zero depth) region.
            depth_arr = obs.cpu().detach().numpy()[0,0]
            pick1 = self.nearest_to_mask(pick_u1, pick_v1, depth_arr)
            pick2 = self.nearest_to_mask(pick_u2, pick_v2, depth_arr)
            pickmask_u1, pickmask_v1 = pick1
            pickmask_u2, pickmask_v2 = pick2

            # place points
            if self.use_placenet:
                place_u1, place_v1 = self.get_pt(self.place1(x1))
                place_u2, place_v2 = self.get_pt(self.place2(with_heatmap(place_u1, place_v1)))

                # swap if intersecting
                if intersect((pickmask_u1,pickmask_v1),(place_u1,place_v1),(pickmask_u2,pickmask_v2),(place_u2,place_v2)):
                    place1 = np.array([place_u2, place_v2])
                    place2 = np.array([place_u1, place_v1])
                else:
                    place1 = np.array([place_u1, place_v1])
                    place2 = np.array([place_u2, place_v2])
            else:
                # flow may still carry autograd history from the caller, so
                # detach before converting to numpy.
                flow_arr = flow.cpu().detach().numpy()[0]
                place_u1, place_v1 = self.get_flow_place_pt(pickmask_u1, pickmask_v1, flow_arr)
                place_u2, place_v2 = self.get_flow_place_pt(pickmask_u2, pickmask_v2, flow_arr)
                place1 = np.array([place_u1, place_v1])
                place2 = np.array([place_u2, place_v2])

        pred_1 = np.array([pick_u1, pick_v1])
        pred_2 = np.array([pick_u2, pick_v2])

        # single action threshold: picks closer than the threshold collapse to a
        # one-arm action by zeroing the second arm.
        if self.cfgs['s_pick_thres'] > 0 and np.linalg.norm(pick1-pick2) < self.cfgs['s_pick_thres']:
            pick2 = [0,0]
            place2 = [0,0]

        if self.cfgs.get('debug_viz', False):
            self._debug_action_viz(obs, nobs, flow, logits1, logits2,
                                   (pred_1, pred_2), (pick1, place1, pick2, place2))

        return np.array([pick1, place1, pick2, place2]), np.array([pred_1, pred_2])

    def _debug_action_viz(self, obs, nobs, flow, logits1, logits2, preds, action):
        """Show matplotlib figures of the inputs and the pick heatmaps (debug only)."""
        pick1, place1, pick2, place2 = action

        # obs and nobs images
        img = cv2.cvtColor(obs[0,0,:,:].detach().cpu().numpy(), cv2.COLOR_GRAY2RGB)
        nimg = cv2.cvtColor(nobs[0,0,:,:].detach().cpu().numpy(), cv2.COLOR_GRAY2RGB)

        # Raw predictions in blue, mask-snapped picks in yellow, places in magenta.
        groups = ((preds, (0,0,1)), ((pick1, pick2), (1,1,0)), ((place1, place2), (1,0,1)))
        for points, colour in groups:
            for u, v in points:
                cv2.circle(img, (int(v),int(u)), 6, colour, 2)

        # inputs viz
        fig, ax = plt.subplots(1,3, figsize=(6,3))
        fig.suptitle('inputs')
        ax[0].imshow(img)
        ax[0].set_title('obs')
        ax[1].imshow(nimg)
        ax[1].set_title('nobs')
        plot_flow(ax[2], flow[0].permute(1, 2, 0).detach().cpu().numpy(), skip=10)
        plt.show()
        plt.close()

        # outputs viz
        panels = (('logits 1', logits1), ('heatmap 1', torch.sigmoid(logits1)),
                  ('logits 2', logits2), ('heatmap 2', torch.sigmoid(logits2)))
        fig, ax = plt.subplots(1,4)
        fig.suptitle('outputs')
        for axis, (title, tensor) in zip(ax, panels):
            axis.imshow(tensor[0,0,:,:].clone().detach().cpu().numpy())
            axis.set_title(title)
        plt.show()
        plt.close()

    def calc_metric(self, start, goal):
        """Return the mean particle distance to ``goal`` and its normalised value.

        The current particle positions are read from pyflex (first three
        columns of the (-1, 4) reshaped array). The second returned value is
        the current mean distance divided by the mean distance between
        ``start`` and ``goal``.
        """
        current = pyflex.get_positions().reshape(-1, 4)[:,:3]

        def mean_dist_to_goal(points):
            return np.linalg.norm(goal - points, axis=1).mean()

        start_dist = mean_dist_to_goal(start)
        curr_dist = mean_dist_to_goal(current)
        return curr_dist, curr_dist/start_dist


    def get_flow_place_pt(self, u,v, flow):
        flow_u_idxs = np.argwhere(flow[0,:,:])
        flow_v_idxs = np.argwhere(flow[1,:,:])
        nearest_u_idx = flow_u_idxs[((flow_u_idxs - [u,v])**2).sum(1).argmin()]
        nearest_v_idx = flow_v_idxs[((flow_v_idxs - [u,v])**2).sum(1).argmin()]

        flow_u = flow[0,nearest_u_idx[0],nearest_u_idx[1]]
        flow_v = flow[1,nearest_v_idx[0],nearest_v_idx[1]]

        new_u = np.clip(u + flow_u, 0, 199)
        new_v = np.clip(v + flow_v, 0, 199)

        return new_u,new_v


    def rotate(self, obs, goal, angle):
        """Apply the same ``TF.affine`` rotation by ``angle`` to ``obs`` and ``goal``.

        No translation, scaling or shear is applied. Returns ``(obs, goal)``.
        """
        affine_args = dict(angle=angle, translate=(0, 0), scale=1.0, shear=0)
        rotated_obs = TF.affine(obs, **affine_args)
        rotated_goal = TF.affine(goal, **affine_args)
        return rotated_obs, rotated_goal


    def aug_uv(self, uv, angle, dx, dy, size=719):
        uvt = deepcopy(uv)
        rad = np.deg2rad(angle)
        R = np.array([
            [np.cos(rad), -np.sin(rad)],
            [np.sin(rad), np.cos(rad)]])
        uvt -= size / 2
        uvt = np.dot(R, uvt.T).T
        uvt += size / 2
        uvt[:, 1] += dx
        uvt[:, 0] += dy

        uvt = np.clip(uvt, 0, size)
        return uvt


    def pad(self, im, pad_amount):
        """Pad ``im`` by ``pad_amount`` on every side, then resize to 200x200.

        Uses the functional torchvision API (already imported as ``TF``). The
        previous implementation referenced ``transforms``, which this module
        never imports, so every call raised ``NameError``. ``TF.pad`` and
        ``TF.resize`` use the same defaults as ``transforms.Pad`` and
        ``transforms.Resize`` (zero fill, constant padding, bilinear resize).
        """
        padded_im = TF.pad(im, pad_amount)
        padded_im = TF.resize(padded_im, [200, 200])
        return padded_im

    def action_viz(self, img, action, unmasked_pred):
        ''' Draw the executed pick-and-place action on a cv2 image.

            img: cv2 image (drawn on in place and returned)
            action: pick1, place1, pick2, place2
            unmasked_pred: pick1_pred, pick2_pred (currently not drawn)'''
        pick1, place1, pick2, place2 = action
        # Unpacked only to keep the same shape requirement; the raw
        # predictions are not rendered.
        pick1_pred, pick2_pred = unmasked_pred

        green = (0,200,0)
        # draw the masked action: a circle at each pick and an arrow to its place
        for pick, place in ((pick1, place1), (pick2, place2)):
            pick_u, pick_v = pick
            place_u, place_v = place
            # cv2 expects (x, y), i.e. (column, row).
            tail = (int(pick_v), int(pick_u))
            head = (int(place_v), int(place_u))
            cv2.circle(img, tail, 6, green, 2)
            cv2.arrowedLine(img, tail, head, green, 3)

        return img



if __name__ == '__main__':
    # Evaluate a range of checkpoints of one run on a fixed goal set, report
    # the best checkpoint and save/plot the per-checkpoint metrics.
    goals = 'towel' #os #ms #towel #tsh #rect
    if goals == 'tsh':
        cloth_type = 'tshirt'
    else:
        cloth_type = 'towel'
    edgethresh = 10 if cloth_type == 'tshirt' else 5
    cfgs = {
        'seed': 0,
        'im_width': 200,
        'img_type': 'depth',
        'on_table': True,
        'horizon': 1,
        #'run_name': 'dual_split_biman_towel_actpickplace_n10000_h2_co1_am0.9_tc0.5_cam0.65_lsigmoid_ep150_lr0.0001_s0_b10_fla0aug_0_fl600000_raft-towel-spatialaug-0.65-flip.pth__nojit_flowonly_raft',
        'run_name': 'dual_split_biman_towel_actpickplace_n10000_h2_co1_am0.9_tc0.5_cam0.65_lsigmoid_ep150_lr0.0001_s0_b10_fla0aug_1_flepoch=613.ckpt__nojit_noflow_aug',
        'load_iter': 0,
        'goal_folder': '',
        'dataset_path': '/data',
        'output_dir': '/home/[username]/bimanual_flow/output',
        'flow': 'epoch=613.ckpt', #epoch=613.ckpt #lact_epoch=513.ckpt #rect_epoch=872.ckpt #tsh_epoch=683.ckpt
        'actmask': 'none', # none, edge, corner
        'res_name': '',
        'res_iter': 0,
        'unfold': False,
        'prob_type': 'sigmoid',
        'load_finetuned': False,
        'flow_align': False,
        'headless': True,
        'cloth_type': cloth_type,
        'edgethresh': edgethresh,
        'placenet': 'dual_split_biman_towel_actpickplace_n10000_h2_co1_am0.9_tc0.5_cam0.65_lsigmoid_ep150_lr0.0001_s0_b10_fla0aug_1_flepoch=613.ckpt__nojit_noflow_aug_place', # ''
        #'placenet': '',
        'input_mode': 'noflow', #noflow #obsflow #flowonly
        'goals': goals,
        's_pick_thres': 30 #30
    }

    avg_metrics = []  # rows of [checkpoint iteration, mean metric, mean normalised metric]
    full_mean = []    # per-goal metrics, one list per evaluated checkpoint

    cfgs['unfold'] = False
    env = EnvRollout(cfgs)

    # Sweep checkpoints; ones that fail to load with EOFError are skipped.
    for i in range(0, 300001, 5000):
        print(f"loading {i}")
        print("folding:")
        try:
            env.cfgs['load_iter'] = i
            env.load_model()
            fold_mean, fold_norm = env.run()

            full_mean.append(fold_mean)
            print(f"mean: {np.mean(fold_mean)} norm: {np.mean(fold_norm)}")
            
            avg_metrics.append([i, np.mean(fold_mean), np.mean(fold_norm)])
        except EOFError:
            print("EOFError. skipping...")

    # Without this guard an empty sweep fails below with an opaque IndexError
    # when slicing the empty metrics array.
    if not avg_metrics:
        raise SystemExit("no checkpoint could be evaluated; nothing to summarise")

    # Best checkpoint = lowest mean metric.
    avg_metrics = np.array(avg_metrics)
    idx = avg_metrics[:,1].argmin()
    print(f"\nmin: {avg_metrics[idx,0]} fold mean: {avg_metrics[idx,1]} norm {avg_metrics[idx,2]}")

    print("\nall goals:")
    for m in full_mean[idx]:
        print(m)

    if cfgs['goals'] == 'towel':
        # NOTE(review): the index ranges below presumably encode how the towel
        # goal set is ordered (single-step vs multi-step, one-arm vs two-arm) —
        # confirm against the Goals dataset before changing it.
        full_mean = np.array(full_mean)
        print(f"\none-step: {np.mean(full_mean[idx,:40])}")
        print(f"mul-step: {np.mean(full_mean[idx,40:])}")
        print(f"one-arm: {np.mean(full_mean[idx,np.r_[:24,32:42,44]])}")
        print(f"two-arm: {np.mean(full_mean[idx,np.r_[24:32,42,43,45]])}")

    np.save(f'{env.save_dir}/fold_metrics.npy',avg_metrics)
    np.save(f'{env.save_dir}/all_goals.npy',full_mean)

    fig, ax = plt.subplots(1,2, figsize=(10,3))
    ax[0].set_title("mean")
    ax[0].set_xlabel("step")
    ax[0].set_ylabel("mean metric")
    ax[0].plot(avg_metrics[:,0], avg_metrics[:,1], label='fold')
    ax[0].legend()
    ax[1].set_title("norm")
    ax[1].set_xlabel("step")
    ax[1].set_ylabel("norm metric")
    ax[1].plot(avg_metrics[:,0], avg_metrics[:,2], label='fold')
    plt.savefig(f'{env.save_dir}/fold_eval.png')
    plt.show()
