import warnings

# Ignore DeprecationWarnings for the rest of the process; installed before the
# imports below so warnings emitted while they load are suppressed as well.
warnings.filterwarnings('ignore', category=DeprecationWarning)

import os

# Set before numpy/torch are imported below. Presumably works around an MKL
# threading-layer conflict (Intel vs. GNU OpenMP) — TODO confirm.
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
#os.environ['MUJOCO_GL'] = 'egl'
#os.environ['CUDA_VISIBLE_DEVICES'] = '1'

from PIL import Image

from pathlib import Path

import hydra
import numpy as np
import utils
import torch
from dm_env import specs

import metaworld_env as mw

from logger import Logger
from replay_buffer import ReplayBufferStorage, make_replay_loader
from video import TrainVideoRecorder, VideoRecorder
import wandb
import math
import re
import sys

#sys.path.append("../")

# Let cuDNN benchmark convolution algorithms and reuse the fastest one per
# input shape; per the PyTorch reproducibility notes this trades determinism
# for speed.
torch.backends.cudnn.benchmark = True


def make_agent(obs_spec, action_spec, cfg):
    """Instantiate the agent described by the Hydra config ``cfg``.

    ``cfg`` is mutated in place: ``obs_shape`` and ``action_shape`` are
    filled from the given specs before ``hydra.utils.instantiate`` runs.
    """
    for field_name, spec in (('obs_shape', obs_spec),
                             ('action_shape', action_spec)):
        setattr(cfg, field_name, spec.shape)
    return hydra.utils.instantiate(cfg)


class Workspace:
    """Training/evaluation driver for a Meta-World task.

    Builds the train/eval environments, replay buffer, logger and agent from
    the Hydra config ``cfg`` and runs the interaction/update loop in
    :meth:`train`. The current working directory (``Path.cwd()``) is used for
    logs, the replay buffer (``buffer/``), train videos and ``snapshot.pt``.
    """

    # Debug switch for `train`: when True, every step calls `visualization`
    # (encoder heatmap) instead of evaluating / updating the agent.
    visualize_features = False

    def __init__(self, cfg):
        self.work_dir = Path.cwd()
        print(f'workspace: {self.work_dir}')
        self.cfg = cfg
        if self.cfg.use_wandb:
            exp_name = '_'.join([cfg.task_name, str(cfg.seed)])
            group_name = self._wandb_group(cfg.agent._target_)
            wandb.init(project="DrM",
                       group=group_name,
                       name=exp_name,
                       config=cfg)
        utils.set_seed_everywhere(cfg.seed)
        self.device = torch.device(cfg.device)
        # Hyper-parameters of the annealed discount / n-step schedules
        # (see the `discount` and `nstep` properties).
        self._discount = cfg.discount
        self._discount_alpha = cfg.discount_alpha
        self._discount_alpha_temp = cfg.discount_alpha_temp
        self._discount_beta = cfg.discount_beta
        self._discount_beta_temp = cfg.discount_beta_temp
        self._nstep = cfg.nstep
        self._nstep_alpha = cfg.nstep_alpha
        self._nstep_alpha_temp = cfg.nstep_alpha_temp
        self.setup()
        self.agent = make_agent(self.train_env.observation_spec(),
                                self.train_env.action_spec(), self.cfg.agent)
        self.timer = utils.Timer()
        self._global_step = 0
        self._global_episode = 0

    @staticmethod
    def _wandb_group(target):
        """Derive the wandb group name from an agent ``_target_`` path.

        Returns the text between the first and last dot
        (``'agents.drm.DrMAgent'`` -> ``'drm'``). Targets with fewer than two
        dots fall back to everything before the last dot, or the whole string
        when there is no dot, instead of crashing on a failed regex match.
        """
        match = re.search(r'\.(.+)\.', target)
        if match:
            return match.group(1)
        return target.rsplit('.', 1)[0]

    def setup(self):
        """Create the logger, train/eval envs, replay storage and loader."""
        # create logger
        self.logger = Logger(self.work_dir,
                             use_tb=self.cfg.use_tb,
                             use_wandb=self.cfg.use_wandb)
        # create envs
        self.train_env = mw.make(self.cfg.task_name, self.cfg.frame_stack,
                                 self.cfg.action_repeat, self.cfg.seed)
        self.eval_env = mw.make(self.cfg.task_name, self.cfg.frame_stack,
                                self.cfg.action_repeat, self.cfg.seed)
        # create replay buffer
        data_specs = (self.train_env.observation_spec(),
                      self.train_env.action_spec(),
                      specs.Array((1, ), np.float32, 'reward'),
                      specs.Array((1, ), np.float32, 'discount'))

        self.replay_storage = ReplayBufferStorage(data_specs,
                                                  self.work_dir / 'buffer')
        print("[Agent] ", self.cfg.agent._target_)
        # The loader starts from the schedules' step-0 values (exp(0) == 1):
        # n-step = floor(nstep + nstep_alpha),
        # discount = discount - discount_alpha - discount_beta.
        self.replay_loader, self.buffer = make_replay_loader(
            self.work_dir / 'buffer', self.cfg.replay_buffer_size,
            self.cfg.batch_size,
            self.cfg.replay_buffer_num_workers, self.cfg.save_snapshot,
            math.floor(self._nstep + self._nstep_alpha),
            self._discount - self._discount_alpha - self._discount_beta,
            self.cfg.agent._target_)
        self._replay_iter = None

        # Eval video recording (VideoRecorder) is currently disabled.
        self.train_video_recorder = TrainVideoRecorder(
            self.work_dir if self.cfg.save_train_video else None)

    @property
    def global_step(self):
        return self._global_step

    @property
    def global_episode(self):
        return self._global_episode

    @property
    def global_frame(self):
        return self.global_step * self.cfg.action_repeat

    @property
    def replay_iter(self):
        # Created lazily so the loader is only iterated once updates start.
        if self._replay_iter is None:
            self._replay_iter = iter(self.replay_loader)
        return self._replay_iter

    @property
    def discount(self):
        """Annealed discount at the current step.

        ``discount - alpha * exp(-step / alpha_temp) - beta * exp(-step /
        beta_temp)``; for positive alpha/beta it rises toward ``discount``.
        """
        step = self.global_step
        return (self._discount
                - self._discount_alpha * math.exp(-step / self._discount_alpha_temp)
                - self._discount_beta * math.exp(-step / self._discount_beta_temp))

    @property
    def nstep(self):
        """Annealed n-step: ``floor(nstep + alpha * exp(-step / alpha_temp))``."""
        return math.floor(self._nstep + self._nstep_alpha *
                          math.exp(-self.global_step / self._nstep_alpha_temp))

    def update_buffer(self):
        """Push the scheduled n-step into the replay buffer.

        NOTE(review): not called from `train` in this file, so the n-step
        schedule stays at its initial value unless invoked elsewhere — confirm
        this is intended. The discount schedule update is disabled.
        """
        # self.buffer.update_discount(self.discount)
        self.buffer.update_nstep(self.nstep)

    def eval(self):
        """Run ``cfg.num_eval_episodes`` episodes on ``eval_env`` and log means.

        Actions are taken with ``eval_mode=True``. Logs mean success rate,
        mean return and mean episode length (in env frames) under ``eval``.
        """
        step, episode, total_reward, total_sr = 0, 0, 0, 0
        eval_until_episode = utils.Until(self.cfg.num_eval_episodes)

        while eval_until_episode(episode):
            episode_sr = False
            time_step = self.eval_env.reset()
            while not time_step.last():
                with torch.no_grad(), utils.eval_mode(self.agent):
                    action = self.agent.act(time_step.observation,
                                            self.global_step,
                                            eval_mode=True)
                time_step = self.eval_env.step(action)
                # An episode counts as successful if any step succeeds.
                episode_sr = episode_sr or time_step.success
                total_reward += time_step.reward
                step += 1

            total_sr += episode_sr
            episode += 1

        # Avoid ZeroDivisionError when no eval episodes are configured.
        num_episodes = max(episode, 1)
        with self.logger.log_and_dump_ctx(self.global_frame, ty='eval') as log:
            log('episode_success_rate', total_sr / num_episodes)
            log('episode_reward', total_reward / num_episodes)
            log('episode_length',
                step * self.cfg.action_repeat / num_episodes)
            log('episode', self.global_episode)
            log('step', self.global_step)

    def train(self):
        """Main loop: act, step the env, store transitions, eval and update.

        Stop/seed/eval predicates are built from frame counts in ``cfg`` and
        ``cfg.action_repeat`` via ``utils.Until`` / ``utils.Every``. Agent
        updates happen after the seed phase every ``cfg.update_every_steps``.
        """
        # predicates
        train_until_step = utils.Until(self.cfg.num_train_frames,
                                       self.cfg.action_repeat)
        seed_until_step = utils.Until(self.cfg.num_seed_frames,
                                      self.cfg.action_repeat)
        eval_every_step = utils.Every(self.cfg.eval_every_frames,
                                      self.cfg.action_repeat)

        episode_step, episode_reward, episode_sr = 0, 0, False
        time_step = self.train_env.reset()
        self.replay_storage.add(time_step)
        self.train_video_recorder.init(time_step.observation)
        metrics = None
        while train_until_step(self.global_step):
            if time_step.last():
                self._global_episode += 1
                self.train_video_recorder.save(f'{self.global_frame}.mp4')
                # wait until all the metrics schema is populated
                if metrics is not None:
                    # log stats
                    elapsed_time, total_time = self.timer.reset()
                    episode_frame = episode_step * self.cfg.action_repeat
                    with self.logger.log_and_dump_ctx(self.global_frame,
                                                      ty='train') as log:
                        log('fps', episode_frame / elapsed_time)
                        log('total_time', total_time)
                        log('episode_success_rate', float(episode_sr))
                        log('episode_reward', episode_reward)
                        log('episode_length', episode_frame)
                        log('episode', self.global_episode)
                        log('buffer_size', len(self.replay_storage))
                        log('step', self.global_step)

                # Reset exactly once per episode: each reset's first time step
                # is added to the replay storage, so a second reset would
                # prepend a stale extra step to the next stored episode.
                time_step = self.train_env.reset()
                self.replay_storage.add(time_step)
                if self.cfg.save_train_video:
                    self.train_video_recorder.init(time_step.observation)
                # try to save snapshot
                if self.cfg.save_snapshot:
                    self.save_snapshot()
                episode_sr = False
                episode_step = 0
                episode_reward = 0

            # sample action
            with torch.no_grad(), utils.eval_mode(self.agent):
                action = self.agent.act(time_step.observation,
                                        self.global_step,
                                        eval_mode=False)

            if self.visualize_features:
                self.visualization(time_step)
            else:
                # try to evaluate
                if eval_every_step(self.global_step):
                    self.logger.log('eval_total_time', self.timer.total_time(),
                                    self.global_frame)
                    self.eval()
                # try to update the agent
                if (not seed_until_step(self.global_step)
                        and self.global_step % self.cfg.update_every_steps == 0):
                    metrics = self.agent.update(self.replay_iter,
                                                self.global_step)
                    self.logger.log_metrics(metrics, self.global_frame,
                                            ty='train')

            # take env step
            time_step = self.train_env.step(action)
            episode_reward += time_step.reward
            # Same success bookkeeping as `eval`: any successful step counts.
            episode_sr = episode_sr or time_step.success
            self.replay_storage.add(time_step)
            if self.cfg.save_train_video:
                render_img = self.train_env._env.render((784, 784))
                self.train_video_recorder.record(render_img)
            episode_step += 1
            self._global_step += 1

    @staticmethod
    def _minmax_normalize(x):
        """Scale ``x`` linearly onto [0, 1]; constant or empty input -> zeros."""
        x = np.asarray(x)
        if x.size == 0:
            return np.zeros(x.shape)
        lo, hi = np.min(x), np.max(x)
        span = hi - lo
        if span == 0:
            return np.zeros(x.shape)
        return (x - lo) / span

    def visualization(self, time_step, resize=784, size=784, aug=0.10714):
        """Debug helper: encoder-activation heatmap plus a cropped env render.

        Args:
            time_step: current time step; its ``observation`` is fed to the
                agent's encoder.
            resize: side length (pixels) both output images are resized to.
            size: side length of the environment render.
            aug: fraction of ``size`` cropped from every border of the render.

        Returns:
            ``(heatmap, frame, activation)``: the green heatmap image, the
            cropped render (both PIL images, ``resize`` x ``resize``) and the
            channel-averaged activation map min-max normalised to [0, 1].
        """
        # Tuning notes: 0.10714 [cnn_out 3000 0.6 Green]
        # [cnn_out 4800 0.65 Green][cnn_out 6000 0.4 red]
        obs = torch.as_tensor(time_step.observation, device=self.device)
        with torch.no_grad():
            cnn_out = self.agent.encoder(obs.unsqueeze(0))
        # Assumes the encoder emits 32 x 35 x 35 features (flattened or not)
        # — TODO confirm for other encoders / input sizes.
        cnn_out = cnn_out.reshape((32, 35, 35)).mean(0).detach().cpu().numpy()
        cnn_norm = self._minmax_normalize(cnn_out)
        zero_c = np.zeros_like(cnn_out)
        img_hot = np.dstack([zero_c, cnn_out, zero_c])  # green only, (35, 35, 3)
        img_hot = np.clip(img_hot * 3500, 0, 255)
        img_hot = Image.fromarray(img_hot.astype(np.uint8))
        resized_image = img_hot.resize((resize, resize))
        print(self.global_step, np.max(np.array(img_hot)))

        frame = self.train_env._env.render((size, size))
        c = round(size * aug)
        frame = frame[c:size - c, c:size - c, :]
        frame_img = Image.fromarray(frame.astype(np.uint8)).resize(
            (resize, resize))
        # merge = Image.blend(resized_image, frame_img, 0.6)
        return resized_image, frame_img, cnn_norm

    def save_snapshot(self):
        """Pickle agent, timer and step/episode counters to work_dir/snapshot.pt."""
        snapshot = self.work_dir / 'snapshot.pt'
        keys_to_save = ['agent', 'timer', '_global_step', '_global_episode']
        payload = {k: self.__dict__[k] for k in keys_to_save}
        with snapshot.open('wb') as f:
            torch.save(payload, f)

    def load_snapshot(self, snapshot=None):
        """Restore the attributes written by :meth:`save_snapshot`.

        Args:
            snapshot: path of the snapshot file. Defaults to
                ``self.work_dir / 'snapshot.pt'``, the file `save_snapshot`
                writes; pass e.g. ``'{task_name}-snapshot.pt'`` explicitly to
                load a task-specific file.

        The payload pickles whole Python objects (agent, timer), so it is
        loaded with ``weights_only=False``; only load snapshots you trust.
        """
        import inspect  # local: only needed for the torch-version check

        if snapshot is None:
            snapshot = self.work_dir / 'snapshot.pt'
        snapshot = Path(snapshot)
        load_kwargs = {}
        # torch>=2.6 defaults to weights_only=True, which rejects this
        # payload; torch<1.13 does not accept the argument at all.
        if 'weights_only' in inspect.signature(torch.load).parameters:
            load_kwargs['weights_only'] = False
        with snapshot.open('rb') as f:
            payload = torch.load(f, **load_kwargs)
        for k, v in payload.items():
            self.__dict__[k] = v


@hydra.main(config_path='cfgs', config_name='config')
def main(cfgs):
    """Hydra entry point: build a Workspace, resume if a snapshot exists, train."""
    # Workspace is taken from the `train_mw` module rather than this file's
    # namespace; presumably so pickled snapshots reference
    # `train_mw.Workspace` instead of `__main__.Workspace` — TODO confirm.
    from train_mw import Workspace as W
    snapshot_path = Path.cwd() / 'snapshot.pt'
    # Task-specific alternative:
    # snapshot_path = snapshot_path / '{}-snapshot.pt'.format(cfgs.task_name)
    workspace = W(cfgs)
    resume = snapshot_path.exists()
    if resume:
        print(f'resuming: {snapshot_path}')
        workspace.load_snapshot()
    workspace.train()


# Script entry point; the @hydra.main decorator on `main` supplies the config.
if __name__ == '__main__':
    main()
