import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

import sys
from flowim import Flow
from utils import remove_dups, generate_perlin_noise_2d
from softgym.envs.corl_baseline import GCFold

import os
import cv2
import yaml
import random
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from PIL import Image
import hydra

from torch.utils.data import Dataset, DataLoader

import pyflex
from softgym.envs.corl_baseline import GCFold
import softgym.envs.tshirt_descriptor as td

class FlowDatasetTest(Dataset):
    """Test single-step and multi-step sim image goals.

    Each item is a (start, goal) pair taken from ``eval_combos``. The start
    goal supplies the observation depth image and its de-duplicated particle
    UVs; the goal supplies the next depth image and all of its particle UVs.
    Particle UVs missing from disk are computed by loading the saved particle
    positions into the simulator, and are then cached next to the goal files.
    """

    def __init__(self, config, camera_params, base_path=None):
        """Build the evaluation pairs and a headless simulator for projection.

        Args:
            config: Config mapping; ``config['debug']['plot']`` enables plots.
            camera_params: Camera parameters forwarded to
                ``td.particle_uv_pos`` and ``remove_dups``.
            base_path: Directory that ``../goals`` is resolved against.
                Defaults to Hydra's original working directory.
        """
        self.cfg = config
        self.camera_params = camera_params
        self.flw = Flow()
        self.transform = T.Compose([T.ToTensor()])

        # 40 single-step goals from the same start state, followed by
        # multi-step sequences in which each step starts from the previous goal.
        self.eval_combos = [
            ['open_2side_high', f'test_goal_{i}'] for i in range(40)
        ] + [
            ['open_2side_high', 'ms_opp_corn_in_0'],
            ['ms_opp_corn_in_0', 'ms_opp_corn_in_1'],
            ['open_2side_high', 'ms_all_corn_in_0'],
            ['ms_all_corn_in_0', 'ms_all_corn_in_1'],
            ['ms_all_corn_in_1', 'ms_all_corn_in_2'],
            ['ms_all_corn_in_2', 'ms_all_corn_in_3'],
            ['open_2side_high', 'ms_double_rect_0'],
            ['ms_double_rect_0', 'ms_double_rect_1'],
            ['open_2side_high', 'ms_double_tri_0'],
            ['ms_double_tri_0', 'ms_double_tri_1'],
            ['open_2side_high', 'ms_two_side_horz_0'],
            ['ms_two_side_horz_0', 'ms_two_side_horz_1'],
            ['open_2side_high', 'ms_two_side_vert_0'],
            ['ms_two_side_vert_0', 'ms_two_side_vert_1'],
        ]
        if base_path is None:
            self.basepath = hydra.utils.get_original_cwd()
        else:
            self.basepath = base_path

        self.env = GCFold(use_depth=True,
            use_cached_states=False,
            horizon=1,
            use_desc=False,
            cam_height=0.65,
            action_repeat=1,
            headless=True)
        self.env.reset()

    def __len__(self):
        return len(self.eval_combos)

    def __getitem__(self, index):
        """Return the sample for ``eval_combos[index]``.

        Returns:
            dict with ``'depths'`` (2-channel float tensor: start, goal),
            ``'flow_lbl'``, ``'loss_mask'`` and ``'cloth_mask'`` tensors.
        """
        start_fn, goal_fn = self.eval_combos[index]
        goals_dir = f'{self.basepath}/../goals'
        particles_dir = f'{goals_dir}/particles'

        # Observation: depth image plus particle UVs with duplicates removed.
        depth_o = self._read_depth(f'{goals_dir}/{start_fn}_depth.png')  # 200 x 200
        cloth_mask = (depth_o != 0).astype(float)
        uv_o_path = f'{particles_dir}/{start_fn}_uvnodups.npy'
        if os.path.exists(uv_o_path):
            uv_o = np.load(uv_o_path)
        else:
            coords_o = np.load(f'{particles_dir}/{start_fn}.npy')
            uv_o_f = self._load_or_project_uv(particles_dir, start_fn, coords_o)
            depth_o_rs = cv2.resize(depth_o, (720, 720))
            uv_o = remove_dups(self.camera_params, uv_o_f, coords_o, depth_o_rs, zthresh=0.005)
            np.save(uv_o_path, uv_o)

        # Goal: depth image plus all particle UVs.
        depth_n = self._read_depth(f'{goals_dir}/{goal_fn}_depth.png')
        uv_n_f = self._load_or_project_uv(particles_dir, goal_fn)

        # Invalidate observation UVs outside the 720 x 720 frame so they can
        # neither wrap around (negative) nor overflow the loss-mask indices.
        uv_o[uv_o < 0] = float('NaN')
        uv_o[uv_o >= 720] = float('NaN')

        # Ground-truth flow and the pixels that carry supervision.
        flow_lbl = self.flw.get_image(uv_o, uv_n_f, mask=cloth_mask)
        loss_mask = self._loss_mask(uv_o, flow_lbl.shape[:2])

        if self.cfg['debug']['plot']:
            self._plot_debug(depth_o, depth_n, flow_lbl, loss_mask)

        depths = np.stack([depth_o, depth_n], axis=2)
        sample = {
            'depths': self.transform(depths).float(),
            'flow_lbl': self.transform(flow_lbl),
            'loss_mask': self.transform(loss_mask),
            'cloth_mask': self.transform(cloth_mask),
        }
        return sample

    @staticmethod
    def _read_depth(path):
        """Read a depth PNG and return its first channel divided by 255.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``path``.
        """
        img = cv2.imread(path)
        if img is None:  # cv2.imread signals failure by returning None
            raise FileNotFoundError(f'Could not read depth image: {path}')
        return img[:, :, 0] / 255

    def _load_or_project_uv(self, particles_dir, name, coords=None):
        """Return cached particle UVs for ``name``, projecting them in sim if absent.

        ``coords`` (particle positions) is loaded from ``{name}.npy`` only when
        it is needed and was not supplied. Newly projected UVs are saved to
        ``{name}_uv.npy``.
        """
        uv_path = f'{particles_dir}/{name}_uv.npy'
        if os.path.exists(uv_path):
            return np.load(uv_path)
        if coords is None:
            coords = np.load(f'{particles_dir}/{name}.npy')
        self.env.reset()
        pyflex.set_positions(coords)
        pyflex.step()
        uv = td.particle_uv_pos(self.camera_params, None)
        np.save(uv_path, uv)
        return uv

    @staticmethod
    def _loss_mask(uv, shape, uv_size=720):
        """Return a float32 mask of ``shape`` with 1 at each pixel hit by a non-NaN UV.

        Column 0 of ``uv`` is the row and column 1 the column, in a
        ``uv_size`` x ``uv_size`` frame; coordinates are rescaled linearly so
        that 0 and ``uv_size - 1`` map to the first and last pixel. UVs outside
        the frame must already have been set to NaN.
        """
        mask = np.zeros(shape, dtype=np.float32)
        valid = uv[~np.isnan(uv).any(axis=1)]
        rows = np.rint(valid[:, 0] / (uv_size - 1) * (shape[0] - 1)).astype(int)
        cols = np.rint(valid[:, 1] / (uv_size - 1) * (shape[1] - 1)).astype(int)
        mask[rows, cols] = 1
        return mask

    @staticmethod
    def _plot_debug(depth_o, depth_n, flow_lbl, loss_mask):
        """Show both depth images, the flow label as arrows, and the loss mask."""
        _, ax = plt.subplots(1, 4, figsize=(32, 16))
        ax[0].imshow(depth_o)
        ax[1].imshow(depth_n)

        skip = 1
        h, w, _ = flow_lbl.shape
        ax[2].imshow(np.zeros((h, w)), alpha=0.5)
        ys, xs, _ = np.where(flow_lbl != 0)
        ax[2].quiver(xs[::skip], ys[::skip],
                     flow_lbl[ys[::skip], xs[::skip], 1], flow_lbl[ys[::skip], xs[::skip], 0],
                     alpha=0.8, color='white', angles='xy', scale_units='xy', scale=1)

        ax[3].imshow(loss_mask)

        plt.tight_layout()
        plt.show()

class FlowDataset(Dataset):
    """Train/val dataset of depth-image pairs with ground-truth knot flow.

    Each sample pairs the ``before`` and ``after`` depth renders of one
    collected sample, optionally augmented. With ``aug`` enabled, the pair may
    be swapped, the second image may be taken from a random other sample, and
    a shared random rotation/translation may be applied. The sample is returned
    together with the flow label between the knots, a loss mask and a cloth
    mask.
    """

    def __init__(self, cfg, ids, camera_params, stage='train'):
        """
        Args:
            cfg: Config providing ``basepath``, ``trainname``/``valname``,
                ``aug`` and the augmentation settings read by ``__getitem__``.
            ids: Sample ids; id ``k`` selects the zero-padded files
                ``{k:06d}_*`` under ``data_path``.
            camera_params: Camera parameters forwarded to ``remove_dups``.
            stage: ``'train'`` reads ``cfg.trainname`` and uses ``cfg.aug``;
                any other value reads ``cfg.valname`` with augmentation off.
        """
        self.cfg = cfg
        self.camera_params = camera_params

        self.flw = Flow()

        self.transform = T.Compose([T.ToTensor()])

        split = cfg.trainname if stage == 'train' else cfg.valname
        self.data_path = f'{cfg.basepath}/{split}'
        self.aug = cfg.aug if stage == 'train' else False
        self.ids = ids
        self.stage = stage

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, index):
        """Return the sample for ``self.ids[index]``.

        Returns:
            dict with ``'depths'`` (2-channel float tensor: obs, nobs),
            ``'flow_lbl'``, ``'loss_mask'`` and ``'cloth_mask'`` tensors.
        """
        # Optionally swap which render plays obs and which plays nobs.
        switch = self.aug and torch.rand(1) < self.cfg.switch_obs
        obs_suffix = 'after' if switch else 'before'
        nobs_suffix = 'before' if switch else 'after'

        # Load obs depth and its de-duplicated knots. Files are named by the
        # sample id, not by the positional index into ``ids``.
        fid = self._file_id(self.ids[index])
        depth_o = np.load(f'{self.data_path}/rendered_images/{fid}_depth_{obs_suffix}.npy')
        cloth_mask = (depth_o != 0).astype(float)  # 200 x 200
        uv_o = self._load_visible_knots(fid, obs_suffix, depth_o)

        # With probability cfg.random_nobs_prob, pair obs with a random sample's nobs.
        if self.aug and torch.rand(1) < self.cfg.random_nobs_prob:
            # .item() turns the 1-element tensor into a plain int index, so this
            # works whether ids is a list, range, ndarray or tensor.
            nobs_fid = self._file_id(self.ids[torch.randint(len(self.ids), (1,)).item()])
        else:
            nobs_fid = fid
        depth_n = np.load(f'{self.data_path}/rendered_images/{nobs_fid}_depth_{nobs_suffix}.npy')
        uv_n_f = self._load_knots(nobs_fid, nobs_suffix)

        # Spatial aug
        if self.aug and torch.rand(1) < self.cfg.spatial_aug:
            depth_o, depth_n, cloth_mask, uv_o, uv_n_f = self.spatial_aug(
                Image.fromarray(depth_o), Image.fromarray(depth_n),
                Image.fromarray(cloth_mask), uv_o, uv_n_f)
            depth_o = np.array(depth_o)
            depth_n = np.array(depth_n)
        cloth_mask = np.array(cloth_mask, dtype=bool)

        # Invalidate obs knots that fall outside the 720 x 720 UV frame.
        uv_o[uv_o < 0] = float('NaN')
        uv_o[uv_o >= 720] = float('NaN')

        flow_lbl = self.flw.get_image(uv_o, uv_n_f, mask=cloth_mask, depth_o=depth_o, depth_n=depth_n)
        loss_mask = self._loss_mask(uv_o, flow_lbl.shape[:2])

        if self.cfg.debug.plot and self.stage == 'train':
            self._plot_debug(depth_o, depth_n, flow_lbl, loss_mask)

        depths = np.stack([depth_o, depth_n], axis=2)
        sample = {
            'depths': self.transform(depths).float(),
            'flow_lbl': self.transform(flow_lbl),
            'loss_mask': self.transform(loss_mask),
            'cloth_mask': self.transform(cloth_mask),
        }
        return sample

    @staticmethod
    def _file_id(sample_id):
        """Format a sample id (int, numpy or 0-d tensor scalar) as a 6-digit file prefix."""
        return f'{int(sample_id):06d}'

    def _load_knots(self, fid, suffix):
        """Load the knot UVs for ``fid``/``suffix`` in (row, col) column order."""
        uv = np.load(f'{self.data_path}/knots/{fid}_knots_{suffix}.npy')
        uv[:, [1, 0]] = uv[:, [0, 1]]  # knots axes are flipped in collect_data
        return uv

    def _load_visible_knots(self, fid, suffix, depth):
        """Return knot UVs with duplicates removed, caching the result on disk.

        ``remove_dups`` receives the knots, the particle coordinates and
        ``depth`` resized to 720 x 720. Its output is saved as
        ``{fid}_knotsnodups_{suffix}.npy`` and reused on later calls.
        """
        nodups_path = f'{self.data_path}/knots/{fid}_knotsnodups_{suffix}.npy'
        if os.path.exists(nodups_path):
            return np.load(nodups_path, allow_pickle=True)
        coords = np.load(f'{self.data_path}/coords/{fid}_coords_{suffix}.npy')
        uv_f = self._load_knots(fid, suffix)
        depth_rs = cv2.resize(depth, (720, 720))
        uv = remove_dups(self.camera_params, uv_f, coords, depth_rs, zthresh=0.001)
        np.save(nodups_path, uv)
        return uv

    @staticmethod
    def _loss_mask(uv, shape, uv_size=720):
        """Return a float32 mask of ``shape`` with 1 at each pixel hit by a non-NaN UV.

        Column 0 of ``uv`` is the row and column 1 the column, in a
        ``uv_size`` x ``uv_size`` frame; coordinates are rescaled linearly so
        that 0 and ``uv_size - 1`` map to the first and last pixel. UVs outside
        the frame must already have been set to NaN.
        """
        mask = np.zeros(shape, dtype=np.float32)
        valid = uv[~np.isnan(uv).any(axis=1)]
        rows = np.rint(valid[:, 0] / (uv_size - 1) * (shape[0] - 1)).astype(int)
        cols = np.rint(valid[:, 1] / (uv_size - 1) * (shape[1] - 1)).astype(int)
        mask[rows, cols] = 1
        return mask

    @staticmethod
    def _plot_debug(depth_o, depth_n, flow_lbl, loss_mask):
        """Show both depth images, the flow label as arrows, and the loss mask."""
        _, ax = plt.subplots(1, 4, figsize=(32, 16))
        ax[0].imshow(depth_o)
        ax[1].imshow(depth_n)

        skip = 1
        h, w, _ = flow_lbl.shape
        ax[2].imshow(np.zeros((h, w)), alpha=0.5)
        ys, xs, _ = np.where(flow_lbl != 0)
        ax[2].quiver(xs[::skip], ys[::skip],
                     flow_lbl[ys[::skip], xs[::skip], 1], flow_lbl[ys[::skip], xs[::skip], 0],
                     alpha=0.8, color='white', angles='xy', scale_units='xy', scale=1)

        ax[3].imshow(loss_mask)

        plt.tight_layout()
        plt.show()

    def aug_uv(self, uv, angle, dx, dy):
        """Rotate (row, col) UVs about the 720 x 720 frame centre, then shift them.

        Args:
            uv: (N, 2) array of (row, col) coordinates; it is not modified.
            angle: Rotation in degrees, applied with the matrix
                ``[[cos, -sin], [sin, cos]]`` to centred (row, col) vectors.
            dx: Offset added to the column coordinate.
            dy: Offset added to the row coordinate.

        Returns:
            A new float64 array holding the transformed coordinates.
        """
        rad = np.deg2rad(angle)
        rot = np.array([
            [np.cos(rad), -np.sin(rad)],
            [np.sin(rad), np.cos(rad)]])
        center = 719 / 2
        # Work on a float64 copy so integer inputs (e.g. raw knots) do not
        # break the float arithmetic below.
        uvt = np.array(uv, dtype=np.float64)
        uvt = (uvt - center) @ rot.T + center
        uvt[:, 1] += dx
        uvt[:, 0] += dy
        return uvt

    def spatial_aug(self, depth_o, depth_n, cloth_mask, uv_o, uv_n_f):
        """Apply one random rotation/translation to the images and both UV sets.

        The angle and pixel shift are drawn uniformly from
        ``[-cfg.spatial_rot, cfg.spatial_rot]`` and
        ``[-cfg.spatial_trans, cfg.spatial_trans]``. Images are transformed
        with ``TF.affine``. The UVs get the shift rescaled from the 200-pixel
        image to the 720 frame, and the negated angle; NOTE(review): the
        negation is presumably there to match ``TF.affine``'s rotation
        direction in (row, col) coordinates, so confirm this before changing it.
        """
        spatial_rot = self.cfg.spatial_rot
        spatial_trans = self.cfg.spatial_trans
        angle = np.random.randint(-spatial_rot, spatial_rot+1)
        dx = np.random.randint(-spatial_trans, spatial_trans+1)
        dy = np.random.randint(-spatial_trans, spatial_trans+1)
        depth_o = TF.affine(depth_o, angle=angle, translate=(dx, dy), scale=1.0, shear=0)
        depth_n = TF.affine(depth_n, angle=angle, translate=(dx, dy), scale=1.0, shear=0)
        cloth_mask = TF.affine(cloth_mask, angle=angle, translate=(dx, dy), scale=1.0, shear=0)
        uv_o = self.aug_uv(uv_o, -angle, dx/199*719, dy/199*719)
        uv_n_f = self.aug_uv(uv_n_f, -angle, dx/199*719, dy/199*719)
        return depth_o, depth_n, cloth_mask, uv_o, uv_n_f