import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

import sys
sys.path.append('/home/exx/projects/softagent/descriptors_softgym_baseline')
from flowim import Flow
from utils import remove_dups, generate_perlin_noise_2d
from softgym.envs.corl_baseline import GCFold
import torch.utils.data as data

import os
import cv2
import yaml
import random
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from PIL import Image
import hydra

from torch.utils.data import Dataset, DataLoader

import pyflex
from softgym.envs.corl_baseline import GCFold
import softgym.envs.tshirt_descriptor as td

class JointDataset(Dataset):
    """Dataset of (obs, nobs) depth pairs with flow, loss-mask and pick/place labels.

    Files are read from ``<cfg.basepath>/<cfg.trainname | cfg.valname>/`` under
    ``rendered_images/``, ``knots/`` and ``coords/``. Depth images are
    200 x 200; knot coordinates live in a 720 x 720 pixel frame.
    """

    def __init__(self, cfg, ids, buf, camera_params, stage='train'):
        """Set up paths, the augmentation flag and helpers.

        Args:
            cfg: config; reads ``basepath``, ``trainname``/``valname`` and ``aug``
                here, and ``switch_obs``, ``random_nobs_prob``, ``spatial_aug``,
                ``spatial_rot``, ``spatial_trans`` and ``debug.plot`` later.
            ids: sample ids; ``len(ids)`` is the dataset length and random
                nobs file ids are drawn from it.
            buf: indexable buffer whose items expose ``.obs`` and ``.act``
                (``act[0..3]`` = pick1, place1, pick2, place2).
            camera_params: forwarded to ``remove_dups``.
            stage: ``'train'`` uses ``cfg.trainname`` and ``cfg.aug``; any other
                value uses ``cfg.valname`` with augmentation disabled.
        """
        self.cfg = cfg
        self.camera_params = camera_params
        self.buf = buf

        self.flw = Flow()

        self.transform = T.Compose([T.ToTensor()])

        self.data_path = f'{cfg.basepath}/{cfg.trainname}' if stage == 'train' else f'{cfg.basepath}/{cfg.valname}'
        self.aug = cfg.aug if stage == 'train' else False
        self.ids = ids
        self.stage = stage

    def __len__(self):
        """Number of samples (``len(self.ids)``)."""
        return len(self.ids)

    def _load_knots(self, file_id, suffix):
        """Load raw knots for ``file_id``/``suffix`` and swap their two coordinate columns."""
        uv = np.load(f'{self.data_path}/knots/{file_id}_knots_{suffix}.npy')
        uv[:, [1, 0]] = uv[:, [0, 1]]  # knots axes are flipped in collect_data
        return uv

    def _load_obs_knots(self, file_id, suffix, depth):
        """Return obs knots with occluded duplicates removed.

        The filtered result is cached as ``*_knotsnodups_*.npy`` next to the
        raw knots, and reused on later calls.
        """
        cache_path = f'{self.data_path}/knots/{file_id}_knotsnodups_{suffix}.npy'
        if os.path.exists(cache_path):
            return np.load(cache_path, allow_pickle=True)

        coords = np.load(f'{self.data_path}/coords/{file_id}_coords_{suffix}.npy')
        uv_f = self._load_knots(file_id, suffix)

        # Remove occlusions; remove_dups works in the 720 x 720 knot frame.
        depth_rs = cv2.resize(depth, (720, 720))
        uv = remove_dups(self.camera_params, uv_f, coords, depth_rs, zthresh=0.001)
        np.save(cache_path, uv)
        return uv

    def __getitem__(self, index):
        """Load one (obs, nobs) pair and build its training targets.

        Returns:
            dict with ``'depths'`` (obs/nobs depth stacked as 2 channels),
            ``'flow_lbl'``, ``'loss_mask'`` and ``'cloth_mask'`` (all passed
            through ``self.transform``), plus ``'pick_lbl'`` = ``[pick1, pick2]``
            and ``'place_lbl'`` = ``[place1, place2]``.
        """
        # With augmentation on, optionally swap which side of the pair is obs.
        switch = self.aug and torch.rand(1) < self.cfg.switch_obs
        obs_suffix = 'after' if switch else 'before'
        nobs_suffix = 'before' if switch else 'after'

        # NOTE(review): the main pair is addressed by the raw dataset index,
        # while the random-nobs branch below draws file ids from self.ids;
        # they only agree when ids == range(len(ids)) — confirm intent.
        obs_id = str(index).zfill(6)

        # Load obs depth (200 x 200) and its occlusion-filtered knots.
        depth_o = np.load(f'{self.data_path}/rendered_images/{obs_id}_depth_{obs_suffix}.npy')
        cloth_mask = (depth_o != 0).astype(float)
        uv_o = self._load_obs_knots(obs_id, obs_suffix, depth_o)

        # Load nobs and knots: with probability random_nobs_prob use a random
        # sample's nobs instead of this pair's own nobs.
        nobs_id = obs_id
        if self.aug and torch.rand(1) < self.cfg.random_nobs_prob:
            # int() so a 1-element tensor index / numpy scalar id formats as a plain number.
            rand_i = int(self.ids[int(torch.randint(len(self.ids), (1,)))])
            nobs_id = str(rand_i).zfill(6)
        depth_n = np.load(f'{self.data_path}/rendered_images/{nobs_id}_depth_{nobs_suffix}.npy')
        uv_n_f = self._load_knots(nobs_id, nobs_suffix)

        # Pick/place labels (200 x 200 pixel frame) from the buffer.
        b = self.buf[index]
        plot_id = b.obs
        pick1 = b.act[0]
        place1 = b.act[1]
        pick2 = b.act[2]
        place2 = b.act[3]

        # Spatial aug: the same random rotation/translation for images, knots and labels.
        if self.aug and torch.rand(1) < self.cfg.spatial_aug:
            depth_o = Image.fromarray(depth_o)
            depth_n = Image.fromarray(depth_n)
            cloth_mask = Image.fromarray(cloth_mask)
            depth_o, depth_n, cloth_mask, uv_o, uv_n_f, \
            pick1, pick2, place1, place2 = self.spatial_aug(depth_o, depth_n, cloth_mask, uv_o, uv_n_f,
                                                            pick1, pick2, place1, place2)
            depth_o = np.array(depth_o)
            depth_n = np.array(depth_n)
        cloth_mask = np.array(cloth_mask, dtype=bool)

        # Knots outside the 720 x 720 frame are marked invalid (NaN).
        uv_o[uv_o < 0] = float('NaN')
        uv_o[uv_o >= 720] = float('NaN')

        # Get flow image
        flow_lbl = self.flw.get_image(uv_o, uv_n_f, mask=cloth_mask, depth_o=depth_o, depth_n=depth_n)

        # Loss mask: 1 at each valid obs knot, mapped from the 720 frame to the 200 frame.
        loss_mask = np.zeros((flow_lbl.shape[0], flow_lbl.shape[1]), dtype=np.float32)
        non_nan_idxs = np.rint(uv_o[~np.isnan(uv_o).any(axis=1)] / 719 * 199).astype(int)
        loss_mask[non_nan_idxs[:, 0], non_nan_idxs[:, 1]] = 1

        if self.cfg.debug.plot:
            self._plot_debug(depth_o, depth_n, flow_lbl, loss_mask,
                             pick1, pick2, place1, place2, plot_id)

        depths = np.stack([depth_o, depth_n], axis=2)

        depths = self.transform(depths)
        flow_lbl = self.transform(flow_lbl)
        loss_mask = self.transform(loss_mask)
        cloth_mask = self.transform(cloth_mask)

        sample = {'depths': depths, 'flow_lbl': flow_lbl, 'loss_mask': loss_mask, 'cloth_mask': cloth_mask, 'pick_lbl': [pick1, pick2], 'place_lbl': [place1, place2]}
        return sample

    def _plot_debug(self, depth_o, depth_n, flow_lbl, loss_mask,
                    pick1, pick2, place1, place2, plot_id):
        """Save a 4-panel debug figure (obs+labels, nobs, flow arrows, loss mask) to data_<plot_id>.png."""
        fig, ax = plt.subplots(1, 4, figsize=(32, 16))
        ax[0].imshow(depth_o)
        ax[0].scatter([pick1[1]], [pick1[0]], s=100, c='blue')
        ax[0].scatter([pick2[1]], [pick2[0]], s=100, c='blue')
        ax[0].scatter([place1[1]], [place1[0]], s=100, c='red')
        ax[0].scatter([place2[1]], [place2[0]], s=100, c='red')

        ax[1].imshow(depth_n)

        skip = 1
        h, w, _ = flow_lbl.shape
        ax[2].imshow(np.zeros((h, w)), alpha=0.5)
        # One arrow per pixel with any nonzero flow component (the old per-channel
        # np.where drew a duplicate arrow when both components were nonzero).
        ys, xs = np.nonzero(np.any(flow_lbl != 0, axis=2))
        ys, xs = ys[::skip], xs[::skip]
        ax[2].quiver(xs, ys,
                     flow_lbl[ys, xs, 1], flow_lbl[ys, xs, 0],
                     alpha=0.8, color='white', angles='xy', scale_units='xy', scale=1)

        ax[3].imshow(loss_mask)

        plt.tight_layout()
        plt.savefig(f'data_{plot_id}.png')
        plt.close()

    def aug_uv(self, uv, angle, dx, dy, size=719):
        """Rotate and translate (N, 2) pixel coordinates without modifying ``uv``.

        Points are rotated by ``angle`` degrees about (size/2, size/2). Then
        ``dx`` is added to column 1, ``dy`` to column 0, and the result is
        clipped to [0, size]. NaN entries stay NaN.

        Integer inputs are promoted to float64. The original in-place float
        arithmetic raised on them.
        """
        uvt = np.array(uv, copy=True)
        if not np.issubdtype(uvt.dtype, np.floating):
            uvt = uvt.astype(np.float64)
        rad = np.deg2rad(angle)
        R = np.array([
            [np.cos(rad), -np.sin(rad)],
            [np.sin(rad), np.cos(rad)]])
        uvt -= size / 2
        uvt = np.dot(R, uvt.T).T
        uvt += size / 2
        uvt[:, 1] += dx
        uvt[:, 0] += dy
        uvt = np.clip(uvt, 0, size)
        return uvt

    def _aug_point(self, point, angle, dx, dy):
        """Apply aug_uv to one pick/place point in the 200 x 200 frame and return it as ints."""
        out = self.aug_uv(point.astype(np.float64)[None, :], angle, dx, dy, size=199)
        return out.squeeze().astype(int)

    def spatial_aug(self, depth_o, depth_n, cloth_mask, uv_o, uv_n_f, pick1, pick2, place1, place2):
        """Apply one random rotation/translation consistently to images, knots and labels.

        The angle is drawn from [-cfg.spatial_rot, cfg.spatial_rot] and dx, dy
        from [-cfg.spatial_trans, cfg.spatial_trans], all via np.random. Images
        are transformed with TF.affine. Knots (720 frame) get the translation
        rescaled by 719/199. Pick/place points (200 frame) are returned as ints.
        """
        spatial_rot = self.cfg.spatial_rot
        spatial_trans = self.cfg.spatial_trans
        angle = np.random.randint(-spatial_rot, spatial_rot + 1)
        dx = np.random.randint(-spatial_trans, spatial_trans + 1)
        dy = np.random.randint(-spatial_trans, spatial_trans + 1)
        depth_o = TF.affine(depth_o, angle=angle, translate=(dx, dy), scale=1.0, shear=0)
        depth_n = TF.affine(depth_n, angle=angle, translate=(dx, dy), scale=1.0, shear=0)
        cloth_mask = TF.affine(cloth_mask, angle=angle, translate=(dx, dy), scale=1.0, shear=0)
        # Point coordinates rotate by -angle to follow TF.affine's image rotation.
        uv_o = self.aug_uv(uv_o, -angle, dx / 199 * 719, dy / 199 * 719)
        uv_n_f = self.aug_uv(uv_n_f, -angle, dx / 199 * 719, dy / 199 * 719)
        pick1 = self._aug_point(pick1, -angle, dx, dy)
        pick2 = self._aug_point(pick2, -angle, dx, dy)
        place1 = self._aug_point(place1, -angle, dx, dy)
        place2 = self._aug_point(place2, -angle, dx, dy)
        return depth_o, depth_n, cloth_mask, uv_o, uv_n_f, pick1, pick2, place1, place2
    

class Goals(data.Dataset):
    """Evaluation dataset that loads the fixed test goals.

    Each item is one goal sequence: a shared start configuration plus one or
    more goal steps, read from ``<goals_root>/<folder>/``.
    """

    _MODES = ('os', 'ms', 'rect', 'tsh', 'towel', 'large', 'debug')

    # Number of single-step towel goals named test_goal_<i>.
    _NUM_OS_GOALS = 40

    # Multi-step goal families per mode: (name prefix, number of steps).
    _SEQUENCE_GOALS = {
        'ms': (('ms_opp_corn_in', 2),
               ('ms_all_corn_in', 4),
               ('ms_two_side_horz', 2),
               ('ms_two_side_vert', 2),
               ('ms_double_tri', 2),
               ('ms_double_rect', 2)),
        'rect': (('rect_horz_fold', 1),
                 ('rect_vert_fold', 1),
                 ('rect_one_corn_in', 1),
                 ('rect_two_side_horz', 2),
                 ('rect_two_side_vert', 2)),
        'tsh': (('tsh_three_step', 3),
                ('tsh_across_horz', 1),
                ('tsh_across_vert', 1)),
        'large': (('large_two_side_horz', 2),),
    }

    # Start-state file name and goal sub-folder per mode.
    _START = {
        'rect': ('rect_open_0', 'rect'),
        'tsh': ('tsh_open_0', 'tshirt'),
        'large': ('large_open_0', ''),
    }
    _DEFAULT_START = ('open_2side_high', 'square_final')

    def __init__(self, config, mode='towel', goals_root='../goals'):
        """ Evaluation dataset/ loads the test goals
        params:
        config: current rollout/test config dict (stored as self.cfg)
        mode: towel (all towel goals), os (onestep towel), ms (multistep towel),
                rect (all rectangle goals), tsh (all tshirt goals),
                large (large towel goals (incomplete)), debug
        goals_root: directory containing the goal folders; defaults to the
                previously hard-coded, cwd-relative '../goals'
        raises:
        ValueError: if mode is not one of the modes above
        """
        self.cfg = config
        self.goals_root = goals_root

        self.eval_combos = []
        self.mode = mode

        if mode not in self._MODES:
            raise ValueError("invalid mode for goal loader")

        goals = []
        if mode in ('os', 'towel'):
            goals.extend([f'test_goal_{i}'] for i in range(self._NUM_OS_GOALS))
        # 'towel' = the one-step goals above followed by the multi-step towel goals.
        seq_mode = 'ms' if mode == 'towel' else mode
        for name, steps in self._SEQUENCE_GOALS.get(seq_mode, ()):
            goals.append([f'{name}_{i}' for i in range(steps)])
        if mode == 'debug':
            goals.append(['test_goal_3'])

        self.goals = goals

    def __len__(self):
        """Number of goal sequences."""
        return len(self.goals)

    @staticmethod
    def _read_depth(path):
        """Read a depth PNG's first channel divided by 255.

        Raises FileNotFoundError if the file is missing. Raises OSError if
        cv2 cannot decode it; cv2.imread returns None instead of raising.
        """
        if not os.path.isfile(path):
            raise FileNotFoundError(path)
        img = cv2.imread(path)
        if img is None:
            raise OSError(f'could not read image: {path}')
        return img[:, :, 0] / 255

    def __getitem__(self, index):
        """Return ``(depth_o, depth_ns, coords_o, coords_ns)`` for goal sequence ``index``.

        depth_o: float32 tensor (1, H, W) of the start depth image.
        depth_ns: float32 tensor (S, 1, H, W), one slice per goal step.
        coords_o / coords_ns: arrays loaded from ``particles/<name>.npy``
            (coords_ns is a list, one entry per goal step).
        """
        start_fn, folder = self._START.get(self.mode, self._DEFAULT_START)
        base = os.path.join(self.goals_root, folder)
        goal_fns = self.goals[index]

        depth_o = self._read_depth(os.path.join(base, f'{start_fn}_depth.png'))
        coords_o = np.load(os.path.join(base, 'particles', f'{start_fn}.npy'))

        depth_ns = []
        coords_ns = []
        for goal_fn in goal_fns:
            depth_ns.append(self._read_depth(os.path.join(base, f'{goal_fn}_depth.png')))
            coords_ns.append(np.load(os.path.join(base, 'particles', f'{goal_fn}.npy')))

        # Stack in numpy first; building a tensor from a list of ndarrays is slow.
        depth_o = torch.from_numpy(np.asarray(depth_o, dtype=np.float32)).unsqueeze(0)
        depth_ns = torch.from_numpy(np.stack(depth_ns).astype(np.float32)).unsqueeze(1)

        return depth_o, depth_ns, coords_o, coords_ns