import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

import os
import cv2
import yaml
import random
import socket
import imutils
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
from PIL import Image
from collections import namedtuple
from softgym.envs.corl_baseline import GCFold

from torch.utils.data import Dataset, DataLoader

# One replay-buffer transition record. QNetDataset reads the fields `obs`,
# `goal`, `act`, `rew`, `nobs` and `done` from its buffer entries (presumably
# instances of this tuple -- confirm against the code that writes the buffer).
# In that class `obs` and `nobs` are used as integer frame indices that are
# formatted into file names, not as image arrays.
Experience = namedtuple('Experience', ('obs', 'goal', 'act', 'rew', 'nobs', 'done'))

class TestGoalsDataset(Dataset):
    """Evaluation (start, goal) pairs listed under ``cfg['eval_combos']``.

    Every file name referenced by a combo is loaded once at construction and
    kept in ``self.eval_ims`` as an ``(image, positions)`` tuple.
    """

    def __init__(self, config):
        """Resolve this host's output path and preload all eval entries.

        Args:
            config: Mapping with an ``'eval_combos'`` list of
                ``(start_fn, goal_fn)`` pairs and a per-hostname entry that
                provides ``'out_path'``.
        """
        self.cfg = config

        # Hosts whose name contains 'compute' all share the 'seuss' entry.
        host_key = socket.gethostname()
        host_key = 'seuss' if 'compute' in host_key else host_key
        self.out_path = self.cfg[host_key]['out_path']

        # Cache each distinct name once; the start name is visited first.
        self.eval_ims = {}
        for start_fn, goal_fn in self.cfg['eval_combos']:
            for name in (start_fn, goal_fn):
                if name not in self.eval_ims:
                    self.eval_ims[name] = self.get_eval(name)

    def get_eval(self, fn):
        """Load the depth image and particle positions stored under *fn*.

        Returns:
            ``(im, pos)``: ``im`` is whatever ``cv2.imread`` yields for
            ``goals/<fn>_depth.png``; ``pos`` is the first three columns of
            ``goals/particles/<fn>.npy``.
        """
        depth_im = cv2.imread(f'{self.out_path}/goals/{fn}_depth.png')
        particles = np.load(f'{self.out_path}/goals/particles/{fn}.npy')
        return depth_im, particles[:, :3]

    def __len__(self):
        """Return the number of configured eval combos."""
        combos = self.cfg['eval_combos']
        return len(combos)

    def __getitem__(self, index):
        """Return ``(start_pos, goal_im, goal_pos)`` for the combo at *index*."""
        start_fn, goal_fn = self.cfg['eval_combos'][index]
        # The start image is cached alongside its positions but not returned.
        _start_im, start_pos = self.eval_ims[start_fn]
        goal_entry = self.eval_ims[goal_fn]
        goal_im, goal_pos = goal_entry
        return start_pos, goal_im, goal_pos

class TrainGoalsDataset(Dataset):
    """Training goals split into a 'flat' pool and an 'unfold' pool.

    Each pool maps a file name to an ``(image, positions, uv)`` tuple and is
    filled at construction from ``cfg['flat_combos']`` and
    ``cfg['unfold_combos']`` respectively.
    """

    def __init__(self, config):
        """Resolve this host's output path and preload both goal pools.

        Args:
            config: Mapping with ``'flat_combos'`` / ``'unfold_combos'`` lists
                of ``(start_fn, goal_fn)`` pairs, a ``'flat_prob'`` value used
                by ``sample()``, and a per-hostname entry with ``'out_path'``.
        """
        self.cfg = config

        # Hosts whose name contains 'compute' all share the 'seuss' entry.
        host_key = socket.gethostname()
        host_key = 'seuss' if 'compute' in host_key else host_key
        self.out_path = self.cfg[host_key]['out_path']

        # The pools are filled independently: a name that appears in both
        # combo lists is loaded once per pool.
        self.flat_goals = {}
        self._fill_pool(self.flat_goals, self.cfg['flat_combos'])

        self.unfold_goals = {}
        self._fill_pool(self.unfold_goals, self.cfg['unfold_combos'])

    def _fill_pool(self, pool, combos):
        """Load every name referenced by *combos* into *pool*, once each."""
        for start_fn, goal_fn in combos:
            for name in (start_fn, goal_fn):
                if name not in pool:
                    pool[name] = self.get_goal(name)

    def get_goal(self, fn):
        """Load the depth image, particle positions and uv array for *fn*.

        Returns:
            ``(im, pos, uv)``: ``im`` is whatever ``cv2.imread`` yields for
            ``goals/<fn>_depth.png``; ``pos`` is the first three columns of
            ``goals/particles/<fn>.npy``; ``uv`` is the full content of
            ``goals/particles/<fn>_uv.npy``.
        """
        depth_im = cv2.imread(f'{self.out_path}/goals/{fn}_depth.png')
        particles = np.load(f'{self.out_path}/goals/particles/{fn}.npy')
        uv = np.load(f'{self.out_path}/goals/particles/{fn}_uv.npy')
        return depth_im, particles[:, :3], uv

    def sample(self):
        """Draw a random combo and return ``(start_pos, goal_im, goal_pos)``.

        A uniform draw below ``cfg['flat_prob']`` selects the flat pool,
        otherwise the unfold pool; the combo is then picked uniformly from the
        selected pool's combo list.
        """
        # Pool choice first, combo index second: the order of RNG calls matters
        # for reproducibility under a fixed seed.
        pick_flat = np.random.uniform() < self.cfg['flat_prob']
        if pick_flat:
            combos, pool = self.cfg['flat_combos'], self.flat_goals
        else:
            combos, pool = self.cfg['unfold_combos'], self.unfold_goals

        start_fn, goal_fn = combos[np.random.randint(len(combos))]
        # Only the start positions and the goal image/positions are returned;
        # the start image and both uv arrays are discarded.
        _start_im, start_pos, _start_uv = pool[start_fn]
        goal_im, goal_pos, _goal_uv = pool[goal_fn]
        return start_pos, goal_im, goal_pos

class QNetDataset(Dataset):
    """Dataset over a saved replay buffer whose frames are stored as PNG files.

    ``init_buffer()`` must be called after construction and before the dataset
    is indexed or its length is taken: it loads the buffer, resolves
    ``out_path`` and writes the frames that ``__getitem__`` reads back.
    """

    def __init__(self, config, model_name, camera_params, aug=True):
        """Store the configuration; no buffer is loaded here.

        Args:
            config: Mapping read for ``'her'``, ``'W'``, ``'max_buf'``,
                ``'run_name'`` and a per-hostname entry providing
                ``'buffer_folder'`` and ``'out_path'``.
            model_name: Name of the sub-folder of ``<out_path>/output`` that
                holds the ``online_data`` frames.
            camera_params: Accepted for interface compatibility; not used.
            aug: When false, ``__getitem__`` never applies the spatial
                augmentation.
        """
        self.cfg = config
        self.model_name = model_name
        self.aug = aug

        # Kept from the original; __getitem__ does not apply it yet.
        self.transform = T.Compose([T.ToTensor()])

    def __len__(self):
        """Return the number of buffer entries (requires ``init_buffer()``)."""
        return len(self.buf)

    def _online_dir(self):
        """Return the directory that holds the rendered obs/nobs frames."""
        return f'{self.out_path}/output/{self.model_name}/online_data'

    def _frame_path(self, prefix, idx):
        """Return the PNG path of frame *idx* for *prefix* ('obs' or 'nobs')."""
        return f'{self._online_dir()}/{prefix}_depth_{str(idx).zfill(6)}.png'

    def _read_frame(self, prefix, idx):
        """Read one frame from disk.

        Raises:
            FileNotFoundError: if ``cv2.imread`` cannot read the file. It
                signals that by returning None, which would otherwise surface
                later as an unrelated error or leak into the sample.
        """
        path = self._frame_path(prefix, idx)
        img = cv2.imread(path)
        if img is None:
            raise FileNotFoundError(f'could not read image: {path}')
        return img

    def __getitem__(self, index):
        """Return the transition at *index* as a dict.

        Keys are ``'obs'``, ``'nobs'``, ``'goal'`` (images), ``'act'``,
        ``'rew'``, ``'done'`` and ``'idx'`` (the entry's ``obs`` index). With
        ``cfg['her']`` set, the goal, reward and done flag are relabelled.
        With ``aug`` enabled, one shared random rotation and shift is applied
        to all three images with probability 0.9.
        """
        b = self.buf[index]
        idx = b.obs
        # Both frames of a transition are stored under its obs index.
        obs = self._read_frame('obs', idx)
        nobs = self._read_frame('nobs', idx)
        goal = b.goal
        act = b.act
        rew = b.rew
        done = b.done

        # 50 percent chance of sampling a random nobs, otherwise use actual nobs as goal
        if self.cfg['her']:
            if np.random.uniform() < 0.5:
                # NOTE(review): copy_buffer() names the nobs file after the
                # entry's *obs* index, while this lookup uses the sampled
                # entry's *nobs* index. The two agree only if obs == nobs for
                # every entry -- confirm against the buffer writer.
                n = random.sample(self.buf, 1)[0].nobs
                goal = self._read_frame('nobs', n)

                # Check if the random nobs happens to be this nobs
                reached = bool((nobs == goal).all())
                rew = 1 if reached else 0
                done = reached
            else:
                goal = nobs
                rew = 1
                done = True

        # Honour the constructor flag: previously the augmentation ran even
        # when the dataset was built with aug=False.
        if self.aug and np.random.uniform() < 0.9:
            angle = np.random.uniform(-5, 5)
            shift = (np.random.uniform(-5, 5), np.random.uniform(-5, 5))
            obs = self.augment(obs, angle, shift)
            nobs = self.augment(nobs, angle, shift)
            goal = self.augment(goal, angle, shift)

        # TODO convert to Tensors

        sample = {'obs': obs, 'nobs': nobs, 'goal': goal, 'act': act, 'rew': rew, 'done': done, 'idx': idx}
        return sample

    def augment(self, img, angle, shift):
        """Rotate *img* by *angle*, then translate it by the ``shift`` pair."""
        img = imutils.rotate(img, angle)
        img = imutils.translate(img, *shift)
        return img

    def init_buffer(self):
        """Load the buffer, create the frame folder and render all frames.

        Sets ``buffer_folder``, ``buf`` (truncated to ``cfg['max_buf']``
        entries), ``out_path`` and ``buf_ind``.
        """
        hostname = socket.gethostname()
        if 'compute' in hostname:
            hostname = 'seuss'
        host_cfg = self.cfg[hostname]
        run_name = self.cfg['run_name']
        self.buffer_folder = f"{host_cfg['buffer_folder']}/{run_name}"

        # NOTE: torch.load unpickles the file; only load trusted buffers.
        self.buf = torch.load(f'{self.buffer_folder}/{run_name}_idx.buf')
        self.buf = self.buf[:self.cfg['max_buf']]

        self.out_path = host_cfg["out_path"]
        # makedirs also creates a missing output/<model_name> parent and does
        # not fail if the folder appears between a check and the creation.
        os.makedirs(self._online_dir(), exist_ok=True)

        self.buf_ind = self.copy_buffer()

    def _render_depth(self, rel_path, size):
        """Load a depth array and return it as a 3-channel uint8 image.

        The array is scaled by 255, cast to uint8, stacked three times along
        the last axis and resized to ``size`` x ``size``.
        """
        depth = np.load(os.path.join(self.buffer_folder, rel_path))
        depth = (depth * 255).astype(np.uint8)
        frame = np.dstack([depth, depth, depth])
        return cv2.resize(frame, (size, size))

    def copy_buffer(self):
        """Render every buffer entry's before/after depth map to a PNG.

        Returns:
            The largest ``obs`` index plus one.

        Raises:
            AssertionError: if that value differs from the number of entries.
        """
        W = self.cfg['W']
        max_ind = 0
        count = 0
        for count, b in enumerate(self.buf, 1):
            obs_ind = b.obs
            nobs_ind = b.nobs

            obs = self._render_depth("rendered_images/%06d_depth_before.npy" % (obs_ind), W)
            nobs = self._render_depth("rendered_images/%06d_depth_after.npy" % (nobs_ind), W)

            # Both files are keyed by the obs index; __getitem__ relies on it.
            cv2.imwrite(self._frame_path('obs', obs_ind), obs)
            cv2.imwrite(self._frame_path('nobs', obs_ind), nobs)
            max_ind = max(max_ind, obs_ind)

        assert max_ind + 1 == count, (
            f'largest obs index + 1 ({max_ind + 1}) != number of entries ({count})')
        return max_ind + 1

if __name__ == '__main__':
    # Debug entry point: build the datasets from config.yaml and print one
    # randomly sampled training goal.
    with open('config.yaml') as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)

    env = GCFold(use_depth=True,
                use_cached_states=False,
                horizon=5,
                use_desc=False,
                action_repeat=1,
                headless=True)
    
    model_name = 'debug'
    data = QNetDataset(cfg, model_name, env.camera_params)
    # The buffer is only attached by init_buffer(); without this call
    # len(data) raises AttributeError because self.buf does not exist yet.
    data.init_buffer()
    # Clamp the sample size so a buffer with fewer than two entries does not
    # make random.sample raise ValueError.
    batch = random.sample(range(len(data)), min(2, len(data)))
    # print(data[batch[0]])

    train_goals = TrainGoalsDataset(cfg)
    print(train_goals.sample())

    # loader = DataLoader(data, batch_size=2, shuffle=False, num_workers=0)
    # for i, batch in enumerate(loader):
    #     print(batch['obs'].shape)
    #     break



