import numpy as np
import torch
import torch.utils.data as data
import torchvision.transforms.functional as TF
import torchvision.transforms as transforms
from PIL import Image
from utils import remove_dups, flow_afwarp, avg_flow_angle, get_gaussian
from flow import Flow
import cv2
from copy import deepcopy
import os.path as osp
import os
import matplotlib.pyplot as plt
from models import FlowNetSmall

#from raft import RAFT
#import torch
#import argparse


class AddGaussianNoise(object):
    """Transform that adds element-wise Gaussian noise to a tensor.

    Args:
        mean: Constant offset added to every element.
        std: Scale applied to the standard-normal noise sample.
    """

    def __init__(self, mean=0., std=1.):
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Return ``tensor + randn * std + mean``; the input is left untouched."""
        # Draw the noise on the input's device so CUDA tensors work as well.
        # The default dtype is kept on purpose so type promotion (e.g. for
        # integer inputs) behaves as it always did.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return f"{self.__class__.__name__}(mean={self.mean}, std={self.std})"


class Dataset(data.Dataset):
    """Dataset of (depth before, depth after, flow, action points) samples.

    Every item of ``buf`` must expose ``.obs`` (the sample index used to build
    the on-disk file names, zero-padded to six digits) and ``.act`` (the action
    array the pick/place points are sliced from).

    Args:
        config: Mapping read for ``dataset_path``, ``run_name``/``test_name``,
            ``flow``, ``misalign``, ``flow_align``, ``augment``,
            ``rand_depth_offset`` and ``gaussian_noise``.
        buf: Indexable collection of samples (see above).
        mode: ``'train'`` selects ``run_name``; anything else ``test_name``.
        single: Return a single pick point instead of two action points.
        pick_place: If True, ``act[0..3]`` are read as pick1, place1, pick2,
            place2. Otherwise only pick points are sliced from a flat ``act``
            and place points are unavailable.
        pick_pt: Return the pick points (True) or the place points (False).
        coords: Additionally return the before/after ``coords`` arrays.
        ret_all: Return pick and place points, coords and RGB images.
    """

    def __init__(self, config, buf, mode='train', single=False, pick_place=True, pick_pt=True, coords=False, ret_all=False):
        self.cfg = config
        self.mode = mode
        self.single = single
        self.pick_place = pick_place
        self.resize_act = True
        self.buf = buf

        if mode == 'train':
            self.data_path = f"{self.cfg['dataset_path']}/{self.cfg['run_name']}"
        else:
            self.data_path = f"{self.cfg['dataset_path']}/{self.cfg['test_name']}"

        if self.cfg['flow'] == 'gt':
            self.gt_flow = True
            self.flw = Flow()
            cache_dir = "flow_gt"
        else:
            self.gt_flow = False
            self.flw = FlowNetSmall(input_channels=2).cuda()
            checkpt = torch.load(f'./flow_model/{self.cfg["flow"]}')
            self.flw.load_state_dict(checkpt['state_dict'])
            self.flw.eval()
            cache_dir = "flow_lr"
        # exist_ok avoids the check-then-create race when several processes
        # (e.g. DataLoader workers or parallel runs) construct the dataset.
        os.makedirs(osp.join(self.data_path, cache_dir), exist_ok=True)

        self.pick_pt = pick_pt
        self.ret_all = ret_all
        self.coords = coords
        self.camera_params = {'default_camera':
                              {'pos': np.array([-0.0, 0.45, 0.0]),
                               'angle': np.array([0, -np.pi/2., 0.]),
                               'width': 720,
                               'height': 720}}

    def __len__(self):
        return len(self.buf)

    def __getitem__(self, buf_idx):
        """Load one sample and apply the configured transforms.

        Returns:
            ``single``: ``(depth_o, depth_n, flow_lbl, pick)``.
            ``ret_all``: ``(depth_o, depth_n, flow_lbl, pick1, pick2, place1,
            place2, coords_o, coords_n, rgb_o, rgb_n)``.
            ``coords``: ``(depth_o, depth_n, flow_lbl, pt1, pt2, coords_o,
            coords_n)``.
            Otherwise: ``(depth_o, depth_n, flow_lbl, pt1, pt2)``.

        Raises:
            ValueError: If place points are requested (``pick_pt=False`` or
                ``ret_all=True``) while ``pick_place=False``.
        """
        b = self.buf[buf_idx]
        index = b.obs
        stem = str(index).zfill(6)

        # Points that a mode does not provide stay None; every transform below
        # skips None instead of failing on an unbound name.
        pick1 = pick2 = place1 = place2 = None
        if self.single:
            # The single pick point travels through the transforms as pick1.
            pick1 = b.act[0] if self.pick_place else b.act[2:4]
        elif self.pick_place:
            pick1, place1, pick2, place2 = b.act[0], b.act[1], b.act[2], b.act[3]
        else:
            if self.ret_all or not self.pick_pt:
                raise ValueError("place points are only available when pick_place=True")
            pick1, pick2 = b.act[3:5], b.act[5:7]

        depth_n = np.load(f'{self.data_path}/rendered_images/{stem}_depth_after.npy')
        depth_o = np.load(f'{self.data_path}/rendered_images/{stem}_depth_before.npy')

        if self.gt_flow:
            flow_lbl = self._load_gt_flow(stem, depth_o)
        elif self.cfg['misalign']:
            angle = np.random.randint(-2, 3)
            dx = np.random.randint(-2, 3)
            dy = np.random.randint(-2, 3)
            # TF.affine only accepts PIL images or tensors, so go through a
            # (1, H, W) tensor and back to a 2-D array.
            depth_t = torch.FloatTensor(depth_o).unsqueeze(0)
            depth_t = TF.affine(depth_t, angle=angle, translate=(dx, dy), scale=1.0, shear=0)
            depth_o = depth_t.squeeze(0).numpy()
            # The misaligned flow is random per call, hence never cached.
            flow_lbl = self._predict_flow(depth_o, depth_n)
        else:
            flow_path = osp.join(self.data_path, "flow_lr", f"{stem}_flow.npy")
            if osp.exists(flow_path):
                flow_lbl = np.load(flow_path)
            else:
                flow_lbl = self._predict_flow(depth_o, depth_n)
                np.save(flow_path, flow_lbl.numpy())

        depth_o = torch.FloatTensor(depth_o).unsqueeze(0)
        depth_n = torch.FloatTensor(depth_n).unsqueeze(0)

        if self.gt_flow:
            # ground-truth flow is stored channels-last; move channels first
            flow_lbl = flow_lbl.transpose([2, 0, 1])

        if not isinstance(flow_lbl, torch.Tensor):
            flow_lbl = torch.FloatTensor(flow_lbl)

        # mask flow: zero both channels wherever the "before" depth is 0
        background = depth_o[0] == 0
        flow_lbl[0][background] = 0
        flow_lbl[1][background] = 0

        # align to flow dir
        if self.cfg['flow_align']:
            # pad so that a rotation of the 200x200 image never crops content
            diag_len = 200*np.sqrt(2)
            pad_amount = int(np.ceil((diag_len - 200)/2))
            padded_scale = 200/(200+2*pad_amount)

            depth_o = self.pad(depth_o, pad_amount)
            depth_n = self.pad(depth_n, pad_amount)
            flow_lbl = self.pad(flow_lbl, pad_amount)

            pick1, pick2, place1, place2 = [
                None if pt is None else (pt + pad_amount)*padded_scale
                for pt in (pick1, pick2, place1, place2)]

            if self.cfg['augment']:
                # NOTE(review): together with the block after the resize this
                # augments twice when both flags are set; kept as in the
                # original -- confirm that it is intended.
                depth_o, depth_n, flow_lbl, pick1, pick2, place1, place2 = self._random_aug(
                    depth_o, depth_n, flow_lbl, pick1, pick2, place1, place2)

            # Get average flow angle
            flow_im = flow_lbl.detach().permute(1, 2, 0).numpy()
            avg_angle = avg_flow_angle(flow_im)

            # Rotate flow image and depth image so average flow angle points right to left
            theta = np.deg2rad(180) - avg_angle
            flow_aligned = flow_afwarp(flow_im, np.rad2deg(-theta), 0, 0)
            depth_o, depth_n, pick1, pick2, place1, place2 = self.spatial_aug(
                depth_o, depth_n, pick1, pick2, place1, place2, np.rad2deg(theta), 0, 0)
            flow_lbl = torch.FloatTensor(flow_aligned).permute(2, 0, 1)

            # resize back to 200x200; flow magnitudes shrink by the same factor
            depth_o = self.resize(depth_o)
            depth_n = self.resize(depth_n)
            flow_lbl = self.resize(flow_lbl, factor=padded_scale)

        if self.cfg['augment']:
            depth_o, depth_n, flow_lbl, pick1, pick2, place1, place2 = self._random_aug(
                depth_o, depth_n, flow_lbl, pick1, pick2, place1, place2)

        if self.cfg['rand_depth_offset']:
            # the same offset is applied to both depth images
            offset = np.random.uniform(-0.1, 0.1)
            depth_o += offset
            depth_n += offset

        if self.cfg['gaussian_noise']:
            depth_noise = AddGaussianNoise(0., 0.007)
            flow_noise = AddGaussianNoise(0., 0.6)
            depth_o = depth_noise(depth_o)
            depth_n = depth_noise(depth_n)
            flow_lbl = flow_noise(flow_lbl)

        # return
        if self.single:
            return depth_o, depth_n, flow_lbl, pick1

        if self.ret_all:
            coords_o = np.load(f'{self.data_path}/coords/{stem}_coords_before.npy')
            coords_n = np.load(f'{self.data_path}/coords/{stem}_coords_after.npy')

            rgb_o = cv2.imread(f'{self.data_path}/images/{stem}_rgb_before.png')
            rgb_o = cv2.cvtColor(rgb_o, cv2.COLOR_BGR2RGB)
            rgb_n = cv2.imread(f'{self.data_path}/images/{stem}_rgb_after.png')
            rgb_n = cv2.cvtColor(rgb_n, cv2.COLOR_BGR2RGB)

            return depth_o, depth_n, flow_lbl, pick1, pick2, place1, place2, coords_o, coords_n, rgb_o, rgb_n

        pt1, pt2 = (pick1, pick2) if self.pick_pt else (place1, place2)
        if self.coords:
            coords_o = np.load(f'{self.data_path}/coords/{stem}_coords_before.npy')
            coords_n = np.load(f'{self.data_path}/coords/{stem}_coords_after.npy')
            return depth_o, depth_n, flow_lbl, pt1, pt2, coords_o, coords_n
        return depth_o, depth_n, flow_lbl, pt1, pt2

    def _load_gt_flow(self, stem, depth_o):
        """Return the channels-last ground-truth flow, computing and caching it on a miss."""
        flow_path = osp.join(self.data_path, "flow_gt", f"{stem}_flow.npy")
        if osp.exists(flow_path):
            return np.load(flow_path)

        uv_n_f = np.load(f'{self.data_path}/knots/{stem}_knots_after.npy')
        uv_n_f[:, [1, 0]] = uv_n_f[:, [0, 1]]  # knots axes are flipped in collect_data
        coords_o = np.load(f'{self.data_path}/coords/{stem}_coords_before.npy')
        uv_o_f = np.load(f'{self.data_path}/knots/{stem}_knots_before.npy')
        uv_o_f[:, [1, 0]] = uv_o_f[:, [0, 1]]  # knots axes are flipped in collect_data

        # Remove occlusions
        depth_o_rs = cv2.resize(depth_o, (720, 720))
        uv_o, _ = remove_dups(self.camera_params, uv_o_f, coords_o, depth_o_rs, zthresh=0.001)

        # Remove out of bounds
        uv_o[uv_o < 0] = float('NaN')
        uv_o[uv_o >= 720] = float('NaN')

        # Get flow image and save it for the next epoch
        flow_lbl = self.flw.get_image(uv_o, uv_n_f)
        np.save(flow_path, flow_lbl)
        return flow_lbl

    def _predict_flow(self, depth_o, depth_n):
        """Run the learned flow network on a depth pair and return a CPU tensor."""
        inp = torch.stack([torch.FloatTensor(depth_o), torch.FloatTensor(depth_n)]).unsqueeze(0)
        # Pure inference: without no_grad the output would carry autograd
        # history and could not be converted with .numpy() later on.
        with torch.no_grad():
            flow = self.flw(inp.cuda())
        return flow.squeeze().cpu()

    def _random_aug(self, depth_o, depth_n, flow_lbl, pick1, pick2, place1, place2):
        """Apply one random rotation/translation to the images, points and flow."""
        angle = np.random.randint(-5, 6)
        dx = np.random.randint(-5, 6)
        dy = np.random.randint(-5, 6)
        depth_o, depth_n, pick1, pick2, place1, place2 = self.spatial_aug(
            depth_o, depth_n, pick1, pick2, place1, place2, angle, dx, dy)
        flow_im = flow_lbl.permute(1, 2, 0).detach().numpy()
        flow_rot = flow_afwarp(flow_im, -angle, 0, 0)
        flow_lbl = torch.FloatTensor(flow_rot).permute(2, 0, 1)
        return depth_o, depth_n, flow_lbl, pick1, pick2, place1, place2

    def spatial_aug(self, depth_o, depth_n, pick1, pick2, place1, place2, angle, dx, dy):
        """Apply the same affine transform to both depth images and all points.

        Points passed as ``None`` are returned as ``None``; the others come
        back as integer arrays.
        """
        depth_o = TF.affine(depth_o, angle=angle, translate=(dx, dy), scale=1.0, shear=0)
        depth_n = TF.affine(depth_n, angle=angle, translate=(dx, dy), scale=1.0, shear=0)
        pick1, pick2, place1, place2 = [
            self._aug_point(pt, angle, dx, dy) for pt in (pick1, pick2, place1, place2)]
        return depth_o, depth_n, pick1, pick2, place1, place2

    def _aug_point(self, pt, angle, dx, dy):
        """Transform one point with ``aug_uv`` (negated angle, size 199); None passes through."""
        if pt is None:
            return None
        moved = self.aug_uv(np.asarray(pt).astype(np.float64)[None, :], -angle, dx, dy, size=199)
        return moved.squeeze().astype(int)

    def aug_uv(self, uv, angle, dx, dy, size=719):
        """Rotate points about ``size / 2``, translate them, and clip to ``[0, size]``.

        Args:
            uv: Array of 2-D points, one per row. The input is not modified.
            angle: Rotation angle in degrees.
            dx: Offset added to column 1.
            dy: Offset added to column 0.
            size: Largest valid coordinate; also defines the rotation centre.

        Returns:
            A new float64 array with the transformed points.
        """
        # Work on a float copy so integer inputs are handled as well.
        uvt = np.array(uv, dtype=np.float64)
        rad = np.deg2rad(angle)
        R = np.array([
            [np.cos(rad), -np.sin(rad)],
            [np.sin(rad), np.cos(rad)]])
        uvt -= size / 2
        uvt = np.dot(R, uvt.T).T
        uvt += size / 2
        uvt[:, 1] += dx
        uvt[:, 0] += dy

        return np.clip(uvt, 0, size)

    def pad(self, im, pad_amount):
        """Pad ``im`` by ``pad_amount`` pixels on every side."""
        trans = transforms.Pad(pad_amount)
        padded_im = trans(im)
        return padded_im

    def resize(self, im, size=(200,200), factor=1.0):
        """Resize ``im`` to ``size`` and multiply the result by ``factor``."""
        # interpolation=0 is Pillow's NEAREST constant
        trans = transforms.Resize(size, interpolation=0)
        resized_im = factor * trans(im)
        return resized_im


class Goals(data.Dataset):
    """Dataset of goal configurations, each a list of one or more goal steps.

    Every item pairs a fixed start state with the states of one goal. The
    files are read from ``goal_dir`` as ``<name>_depth.png`` and
    ``particles/<name>.npy``.

    Args:
        config: Stored as ``self.cfg``; not otherwise read by this class.
        mode: One of ``'os'``, ``'ms'``, ``'rect'``, ``'tsh'`` or ``'towel'``
            (``'towel'`` is the ``'os'`` goals followed by the ``'ms'``
            goals). Note that the default ``'all'`` is not accepted.
        goal_dir: Directory that holds the goal files.

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """

    _VALID_MODES = ('os', 'ms', 'rect', 'tsh', 'towel')

    # Number of single-step "test_goal_<i>" goals used by 'os' and 'towel'.
    _NUM_SINGLE_STEP = 40

    # Multi-step goal families: (file name prefix, number of steps).
    _MULTI_STEP = {
        'ms': (('ms_opp_corn_in', 2),
               ('ms_all_corn_in', 4),
               ('ms_two_side_horz', 2),
               ('ms_two_side_vert', 2),
               ('ms_double_tri', 2),
               ('ms_double_rect', 2)),
        'rect': (('rect_horz_fold', 1),
                 ('rect_vert_fold', 1),
                 ('rect_one_corn_in', 1),
                 ('rect_two_side_horz', 2),
                 ('rect_two_side_vert', 2)),
        'tsh': (('tsh_three_step', 3),
                ('tsh_across_horz', 1),
                ('tsh_across_vert', 1)),
    }

    # Start state per mode; every other mode starts from 'open_2side_high'.
    _START_FNS = {'rect': 'rect_open_0', 'tsh': 'tsh_open_0'}

    def __init__(self, config, mode='all', goal_dir='../goals'):
        if mode not in self._VALID_MODES:
            raise ValueError("invalid mode for goal loader")

        self.cfg = config
        self.eval_combos = []
        self.mode = mode
        self.goal_dir = goal_dir

        goals = []
        if mode in ('os', 'towel'):
            goals.extend([f'test_goal_{i}'] for i in range(self._NUM_SINGLE_STEP))
        # 'towel' reuses the 'ms' families; 'os' has no multi-step goals.
        family = 'ms' if mode == 'towel' else mode
        for name, steps in self._MULTI_STEP.get(family, ()):
            goals.append([f'{name}_{i}' for i in range(steps)])
        self.goals = goals

    def __len__(self):
        return len(self.goals)

    def __getitem__(self, index):
        """Load the start state and every step of goal ``index``.

        Returns:
            ``(depth_o, depth_ns, coords_o, coords_ns)``: the start depth
            tensor with a leading channel axis, the stacked goal depth tensors
            with a channel axis inserted at position 1, the start particle
            array, and a list with one particle array per goal step.

        Raises:
            FileNotFoundError: If a depth image cannot be read.
        """
        start_fn = self._START_FNS.get(self.mode, 'open_2side_high')
        depth_o, coords_o = self._load_state(start_fn)

        depth_ns = []
        coords_ns = []
        for goal_fn in self.goals[index]:
            depth_n, coords_n = self._load_state(goal_fn)
            depth_ns.append(depth_n)
            coords_ns.append(coords_n)

        # convert to tensor; stacking first avoids the slow element-wise
        # conversion of a Python list of arrays
        depth_o = torch.FloatTensor(depth_o).unsqueeze(0)
        depth_ns = torch.FloatTensor(np.stack(depth_ns)).unsqueeze(1)

        return depth_o, depth_ns, coords_o, coords_ns

    def _load_state(self, fn):
        """Return ``(depth scaled by 1/255, particle array)`` for the state named ``fn``."""
        depth_path = f'{self.goal_dir}/{fn}_depth.png'
        img = cv2.imread(depth_path)
        if img is None:
            # cv2.imread signals an unreadable file by returning None
            raise FileNotFoundError(f"could not read goal depth image: {depth_path}")
        depth = img[:, :, 0] / 255
        coords = np.load(f'{self.goal_dir}/particles/{fn}.npy')
        return depth, coords