from . import VecEnvWrapper
import numpy as np
from .running_mean_std import RunningMeanStd
import torch
import matplotlib.pyplot as plt
import os
from utils import get_kMedoids, drawArrows
from dataset import loadEnvData
import copy
import pickle


class VecPretextNormalize(VecEnvWrapper):
    """
    A vectorized wrapper that normalizes the observations
    and returns from an environment.
    """

    def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8, config=None):
        """
        Wrap `venv` and initialise the state used for observation processing,
        reward normalization and pair collection.

        :param venv: the vectorized environment to wrap
        :param ob: if True, keep running statistics for observation normalization
        :param ret: if True, keep running statistics of the discounted return
        :param clipob: clipping bound applied to normalized observations
        :param cliprew: clipping bound applied to normalized rewards
        :param gamma: discount factor used when accumulating the return
        :param epsilon: constant added to the variance before taking the square root
        :param config: configuration object; pretextTrainBatchSize, sound_dim and
                       representationDim are read here, so None is not usable in practice
        """
        VecEnvWrapper.__init__(self, venv)

        # --- configuration, model and device ---
        self.config = config
        self.pretextModel = None  # will be set by RL.py
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # --- medoids and real-time plotting state (filled in by reset / process* methods) ---
        self.medoids = None
        self.new_sound_feat = None
        self.fig = None
        self.ax = None
        self.fileNum = 0
        self.quiver_img = None
        self.quiver_sound = None

        # --- buffers for collecting pairs used to update the representation ---
        # pair_candidates holds the raw info dicts; pair_candidates_batch holds the
        # matching arrays. Once pretextTrainBatchSize candidates are gathered they are
        # scored in collectData(), the kept ones move to save_obs, and save_obs is
        # written to disk by collectData() once it is long enough.
        batch_size = config.pretextTrainBatchSize
        self.pair_candidates = []
        self.pair_candidates_batch = {
            'sound_positive': np.zeros((batch_size,) + config.sound_dim),
            'image_feat': np.zeros((batch_size, config.representationDim)),
            'goal_sound_feat': np.zeros((batch_size, config.representationDim)),
        }
        self.pair_counter = 0
        self.save_obs = []

        # --- normalization state ---
        self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
        self.ret_rms = RunningMeanStd(shape=()) if ret else None
        self.clipob = clipob
        self.cliprew = cliprew
        self.gamma = gamma
        self.epsilon = epsilon
        self.ret = np.zeros(self.num_envs)

        self.origStepReward = np.zeros(self.num_envs)
        self.rl_obs_space = None

        # maps config.name to the method that processes that environment's observations
        self.processing_func = {
            'KukaConfig': self.processKuka,
            'KinovaGen3Config': self.processKuka,
            'TurtleBotConfig': self.processTurtleBot,
            'AI2ThorConfig': self.processAI2Thor,
        }

    def step_wait(self):
        obs, env_rews, news, infos = self.venv.step_wait()
        # process the observations and reward
        obs,rews=self.processing_func[self.config.name](obs, env_rews, news, infos)

        self.origStepReward=rews.copy()
        # normalize the reward
        self.ret = self.ret * self.gamma + rews
        if self.ret_rms:
            self.ret_rms.update(self.ret)
            rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
        self.ret[news] = 0.

        return obs, rews, news, infos

    def _obfilt(self, obs):
        if self.ob_rms and self.config.RLTrain:
            self.ob_rms.update(obs)
            obs = np.clip((obs - self.ob_rms.mean) / np.sqrt(self.ob_rms.var + self.epsilon), -self.clipob, self.clipob)
            return obs
        else:
            return obs

    def reset(self):
        if self.config.RSI_ver>1 and (self.config.calcMedoids or self.config.RLRealTimePlot):
            dataset=self.config.pretextDataset
            data_generator, ds = loadEnvData(data_dir=self.config.pretextDataDir,
                                                 config=self.config,
                                                 batch_size=self.config.pretextTrainBatchSize,
                                                 shuffle=True,
                                                 num_workers=self.config.pretextDataNumWorkers, # change it to 0 if multiprocessing error
                                                 drop_last=True,
                                                 loadNum=self.config.pretextDataFileLoadNum,
                                             dtype=dataset)
        if self.config.RSI_ver>1 and self.config.calcMedoids:
            self.medoids = get_kMedoids(data_generator=data_generator,
                                       torch_device=self.device,
                                       model=self.pretextModel,
                                       config=self.config,
                                       filter_empty=True)
            self.new_sound_feat=np.zeros((self.num_envs, self.config.representationDim))

        if self.config.RSI_ver>1 and self.config.RLRealTimePlot:
            plt.ion()
            kwargs = {'medoids':self.medoids} if self.config.calcMedoids else {}
            self.fig, self.ax = self.config.plotFunc(data_generator, self.device, self.pretextModel, self.config, **kwargs)
            self.fig.canvas.draw_idle()
            self.fig.canvas.start_event_loop(0.001) # without it, the plot will not be drawn

        self.ret = np.zeros(self.num_envs)
        obs = self.venv.reset()
        obs, _ = self.processing_func[self.config.name](obs, np.zeros((self.num_envs,)), np.array([True]*self.num_envs), ({},)*self.num_envs)
        
        return obs

    def processTurtleBot(self, O, envReward, done, infos):
        if self.config.RSI_ver==1:
            soundLabel = np.zeros((self.num_envs, self.config.taskNum))

            for i in range(self.num_envs):
                if O['ground_truth'][i] != self.config.taskNum:
                    soundLabel[i][O['ground_truth'][i]] = 1

            s={ 'robot_pose': O['robot_pose'], 'goal_sound': O['goal_sound'],
                'image': O['image'] / 255.,
                'soundLabel': soundLabel, 'inSight': O['inSight'], 'exi': O['exi']
            }
            reward=envReward

        elif self.config.RSI_ver==2:
            with torch.no_grad():
                d = \
                self.pretextModel(torch.from_numpy(O['image'] / 255.).float().to(self.device),
                                  torch.from_numpy(O['goal_sound']).float().to(self.device),
                                  torch.from_numpy(O['current_sound']).float().to(self.device))

            image_feat = d['image_feat'].to('cpu').numpy()
            goal_sound_feat = d['sound_feat_positive'].to('cpu').numpy()
            current_sound_feat = d['sound_feat_negative'].to('cpu').numpy()

            if self.config.RLTrain:
                self.collectData(infos, image_feat, goal_sound_feat)

            if self.config.RLTask == 'avoid':
                raise NotImplementedError
            elif self.config.RLTask=='approach':
                img_sound_dot = np.sum(image_feat[:, :self.config.representationDim] * goal_sound_feat, axis=1)
                sound_sound_dot = np.sum(current_sound_feat * goal_sound_feat, axis=1)
                exi=np.zeros_like(img_sound_dot)
                exi[img_sound_dot>0.5]=1.
                exi[img_sound_dot<=0.5]=0.
                exi=np.expand_dims(exi, 1)

                s = {'robot_pose': O['robot_pose'], 'goal_sound_feat': goal_sound_feat,
                     'image': O['image'] / 255.,
                     'image_feat': image_feat, 'exi': exi
                     }

            else:
                raise NotImplementedError

            if self.config.RLRealTimePlot:
                self.quiver_img, self.quiver_sound = drawArrows(self.ax, self.fig, v_img=image_feat,
                                                                v_sound=goal_sound_feat, quiver_img=self.quiver_img,
                                                                quiver_sound=self.quiver_sound)

            embReward = img_sound_dot + sound_sound_dot * self.config.RLRewardSoundSound
            reward = embReward + envReward

        elif self.config.RSI_ver == 3:
            with torch.no_grad():
                d = \
                    self.pretextModel(torch.from_numpy(O['image'] / 255.).float().to(self.device),
                                      torch.from_numpy(O['goal_sound']).float().to(self.device),
                                      torch.from_numpy(O['current_sound']).float().to(self.device))

            image_feat = d['image_feat'].to('cpu').numpy()
            goal_sound_feat = d['sound_feat_positive'].to('cpu').numpy()
            current_sound_feat = d['sound_feat_negative'].to('cpu').numpy()

            if self.config.RLTask == 'approach':
                img_sound_dot = np.sum(image_feat[:, :self.config.representationDim] * goal_sound_feat, axis=1)
                sound_sound_dot = np.sum(current_sound_feat * goal_sound_feat, axis=1)
                exi = np.zeros_like(img_sound_dot)
                exi[img_sound_dot > 0.5] = 1.
                exi[img_sound_dot <= 0.5] = 0.
                exi = np.expand_dims(exi, 1)

                s = {'robot_pose': O['robot_pose'], 'goal_sound_feat': goal_sound_feat,
                     'image': O['image'] / 255.,
                     'image_feat': image_feat, 'exi': exi
                     }
            else:
                raise NotImplementedError

            if self.config.RLRealTimePlot:
                self.quiver_img, self.quiver_sound = drawArrows(self.ax, self.fig, v_img=image_feat,
                                                                v_sound=goal_sound_feat, quiver_img=self.quiver_img,
                                                                quiver_sound=self.quiver_sound)

            embReward = img_sound_dot + sound_sound_dot * self.config.RLRewardSoundSound
            reward = embReward + envReward

        else:
            raise NotImplementedError

        obs = self._obfilt(s)

        return obs, reward

    def processKuka(self, O, envReward, done, infos):
        if self.config.RSI_ver == 1:

            soundLabel = np.zeros((self.num_envs, 4))
            for i in range(self.num_envs):
                if O['ground_truth'][i] != 4:
                    soundLabel[i][O['ground_truth'][i]] = 1

            s = {'robot_pose': O['robot_pose'], 'goal_sound': O['goal_sound'],
                 'image': O['image'] / 255.,
                 'soundLabel':soundLabel
                 }

            reward = envReward
        elif self.config.RSI_ver == 2:
            with torch.no_grad():
                d = \
                    self.pretextModel(torch.from_numpy(O['image'] / 255.).float().to(self.device),
                                      torch.from_numpy(O['goal_sound']).float().to(self.device),
                                      torch.from_numpy(O['current_sound']).float().to(self.device))
            image_feat = d['image_feat'].to('cpu').numpy()
            goal_sound_feat = d['sound_feat_positive'].to('cpu').numpy()
            current_sound_feat = d['sound_feat_negative'].to('cpu').numpy()

            if self.config.RLTrain:
                self.collectData(infos, image_feat, goal_sound_feat)

            if self.config.RLTask=='avoid':
                for i in range(self.num_envs):
                    if done[i]:
                        repeated_sound_feat=goal_sound_feat[None, i]
                        repeated_sound_feat=np.repeat(repeated_sound_feat, repeats=[self.config.taskNum,], axis=0)
                        cos_sim=np.sum(repeated_sound_feat*self.medoids, axis=1)
                        avoid_idx=np.argmax(cos_sim)
                        prob=np.ones((self.config.taskNum,))/(self.config.taskNum-1)
                        prob[avoid_idx]=0.
                        self.new_sound_feat[i]=self.medoids[np.random.choice(self.config.taskNum, p=prob)]
                img_sound_dot = np.sum(image_feat[:, :self.config.representationDim] * self.new_sound_feat, axis=1)
                sound_sound_dot = np.sum(current_sound_feat * self.new_sound_feat, axis=1)

                s = {'robot_pose': O['robot_pose'], 'goal_sound_feat': self.new_sound_feat,
                     'image': O['image'] / 255.,
                     'image_feat': image_feat,
                     }

            elif self.config.RLTask=='approach':
                img_sound_dot = np.sum(image_feat[:, :self.config.representationDim] * goal_sound_feat, axis=1)
                sound_sound_dot=np.sum(current_sound_feat * goal_sound_feat, axis=1)
                s = {'robot_pose': O['robot_pose'], 'goal_sound_feat': goal_sound_feat,
                     'image': O['image'] / 255.,
                     'image_feat': image_feat,
                     }
            else:
                raise NotImplementedError

            if self.config.RLRealTimePlot:
                self.quiver_img, self.quiver_sound = drawArrows(self.ax, self.fig, v_img=image_feat,
                                                                v_sound=goal_sound_feat, quiver_img=self.quiver_img,
                                                                quiver_sound=self.quiver_sound)

            embReward = img_sound_dot + sound_sound_dot * self.config.RLRewardSoundSound
            reward = embReward + envReward


        elif self.config.RSI_ver == 3:
            with torch.no_grad():
                d = \
                    self.pretextModel(torch.from_numpy(O['image'] / 255.).float().to(self.device),
                                      torch.from_numpy(O['goal_sound']).float().to(self.device),
                                      torch.from_numpy(O['current_sound']).float().to(self.device))
            image_feat = d['image_feat'].to('cpu').numpy()
            goal_sound_feat = d['sound_feat_positive'].to('cpu').numpy()
            current_sound_feat = d['sound_feat_negative'].to('cpu').numpy()

            if self.config.RLTrain:
                self.collectData(infos, image_feat, goal_sound_feat)

            if self.config.RLTask == 'approach':
                img_sound_dot = np.sum(image_feat[:, :self.config.representationDim] * goal_sound_feat, axis=1)
                sound_sound_dot = np.sum(current_sound_feat * goal_sound_feat, axis=1)
                s = {'robot_pose': O['robot_pose'], 'goal_sound_feat': goal_sound_feat,
                     'image': O['image'] / 255.,
                     'image_feat': image_feat,
                     }

            else:
                raise NotImplementedError

            if self.config.RLRealTimePlot:
                self.quiver_img, self.quiver_sound = drawArrows(self.ax, self.fig, v_img=image_feat,
                                                                v_sound=goal_sound_feat, quiver_img=self.quiver_img,
                                                                quiver_sound=self.quiver_sound)
            embReward = img_sound_dot + sound_sound_dot * self.config.RLRewardSoundSound
            reward = embReward + envReward
        else:
            raise NotImplementedError

        obs = self._obfilt(s)

        return obs, reward

    def processAI2Thor(self, O, envReward, done, infos):
        if self.config.RSI_ver == 1:

            s = {'occupancy': O['occupancy'] / 255.,
                 'image': O['image'] / 255.,
                 'soundLabel': O['soundLabel'], 'inSight': O['inSight'], 'exi': O['exi']
                 }

            if not self.config.RLUseSoundLabel:
                s['goal_sound']=O['goal_sound']

            reward = envReward

        elif self.config.RSI_ver == 2:
            with torch.no_grad():
                d = \
                self.pretextModel(torch.from_numpy(O['image'] / 255.).float().to(self.device),
                                  torch.from_numpy(O['goal_sound']).float().to(self.device),
                                  torch.from_numpy(O['current_sound']).float().to(self.device))

            image_feat = d['image_feat'].to('cpu').numpy()
            goal_sound_feat = d['sound_feat_positive'].to('cpu').numpy()
            current_sound_feat = d['sound_feat_negative'].to('cpu').numpy()

            if self.config.RLTrain:
                self.collectData(infos, image_feat, goal_sound_feat)

            img_sound_dot = np.sum(image_feat[:, :self.config.representationDim] * goal_sound_feat, axis=1)
            sound_sound_dot = np.sum(current_sound_feat * goal_sound_feat, axis=1)

            s = {'occupancy': O['occupancy'] / 255.,
                 'goal_sound_feat': goal_sound_feat,
                 'image': O['image'] / 255.,
                 'image_feat': image_feat,
                 }

            if self.config.RLRealTimePlot:
                self.quiver_img, self.quiver_sound = drawArrows(self.ax, self.fig, v_img=image_feat,
                                                                v_sound=goal_sound_feat, quiver_img=self.quiver_img,
                                                                quiver_sound=self.quiver_sound)


            embReward = img_sound_dot + sound_sound_dot * self.config.RLRewardSoundSound
            reward = embReward + envReward

        elif self.config.RSI_ver == 3:
            with torch.no_grad():
                d = \
                self.pretextModel(torch.from_numpy(O['image'] / 255.).float().to(self.device),
                                  torch.from_numpy(O['goal_sound']).float().to(self.device),
                                  torch.from_numpy(O['current_sound']).float().to(self.device))

            image_feat = d['image_feat'].to('cpu').numpy()
            goal_sound_feat = d['sound_feat_positive'].to('cpu').numpy()
            current_sound_feat = d['sound_feat_negative'].to('cpu').numpy()

            if self.config.RLTrain:
                self.collectData(infos, image_feat, goal_sound_feat)

            img_sound_dot = np.sum(image_feat[:, :self.config.representationDim] * goal_sound_feat, axis=1)
            sound_sound_dot = np.sum(current_sound_feat * goal_sound_feat, axis=1)
            s = {'occupancy': O['occupancy'] / 255.,
                 'goal_sound_feat': goal_sound_feat,
                 'image': O['image'] / 255.,
                 'image_feat': image_feat,
                 }

            if self.config.RLRealTimePlot:
                self.quiver_img, self.quiver_sound = drawArrows(self.ax, self.fig, v_img=image_feat,
                                                                v_sound=goal_sound_feat, quiver_img=self.quiver_img,
                                                                quiver_sound=self.quiver_sound)

            embReward = img_sound_dot + sound_sound_dot * self.config.RLRewardSoundSound
            reward = embReward + envReward

        else:
            raise NotImplementedError

        obs = self._obfilt(s)

        return obs, reward

    def collectData(self, infos, image_feat, goal_sound_feat):
        def score_fn(v_i, v_g, v_sp):
            vivg_relu=np.maximum(0,np.sum(v_i*v_g, axis=1))
            vivsp_relu=np.maximum(0,np.sum(v_i*v_sp, axis=1))
            return np.maximum(0,vivg_relu-vivsp_relu)

        for i, dictionary in enumerate(infos): # collect triplets from the Envs
            if dictionary:
                self.pair_candidates.append(dictionary)
                self.pair_candidates_batch['sound_positive'][self.pair_counter]=dictionary['sound_positive']
                self.pair_candidates_batch['image_feat'][self.pair_counter] = image_feat[i]
                self.pair_candidates_batch['goal_sound_feat'][self.pair_counter] = goal_sound_feat[i]
                self.pair_counter= self.pair_counter+1
                if len(self.pair_candidates)==self.config.pretextTrainBatchSize: # pass all the pair through pretextModelReuse
                    with torch.no_grad():
                        _, sound_feat_positive, _ = \
                            self.pretextModel(None,
                                                   torch.from_numpy(self.pair_candidates_batch['sound_positive']).float().to(self.device),
                                                   None
                                                   )
                    sound_feat_positive=sound_feat_positive.to('cpu').numpy()
                    # score those pairs and decides which of them to keep
                    scores=score_fn(self.pair_candidates_batch['image_feat'],
                                    self.pair_candidates_batch['goal_sound_feat'],
                                    sound_feat_positive)
                    rand_num=np.random.rand(self.config.pretextTrainBatchSize)
                    selected_pair_idx=np.arange(self.config.pretextTrainBatchSize)[rand_num<scores]

                    # append paris to self.save_obs
                    self.save_obs=self.save_obs+list(np.array(self.pair_candidates)[selected_pair_idx])

                    # reset counter and buffer
                    self.pair_candidates = []
                    self.pair_candidates_batch = {
                        'sound_positive': np.zeros((self.config.pretextTrainBatchSize,) + self.config.sound_dim),
                        'image_feat': np.zeros((self.config.pretextTrainBatchSize, self.config.representationDim)),
                        'goal_sound_feat': np.zeros((self.config.pretextTrainBatchSize, self.config.representationDim)),
                    }
                    self.pair_counter=0

        if len(self.save_obs)>=2000:
            filePath = os.path.join(self.config.pretextModelUpdateDataDir, 'train')
            if not os.path.isdir(filePath):
                os.makedirs(filePath)
            self.fileNum=self.fileNum+1
            filePath = os.path.join(filePath, 'data_' + str(self.fileNum) + '.pickle')
            with open(filePath, 'wb') as f:
                pickle.dump(self.save_obs, f, protocol=pickle.HIGHEST_PROTOCOL)
            self.save_obs.clear()
            print('Save new data to', filePath)
