import datetime
import os
import time

import hydra
import safety_gymnasium
import torch
import wandb
from PIL import Image

import utils
from replay_buffer import ReplayBuffer


class Workspace:
    """Train and periodically evaluate a Hydra-configured agent on a Safety-Gymnasium env.

    Construction builds the environment, agent and replay buffer from ``cfg``
    and opens an offline wandb run; :meth:`run` executes the training loop and
    :meth:`evaluate` performs deterministic evaluation rollouts. Checkpoints and
    evaluation GIFs are written under ``<cwd>/data/<agent name>/<start time>/``.
    """

    def __init__(self, cfg):
        """Validate ``cfg`` and build the env, agent, replay buffer, output dir and wandb run.

        Args:
            cfg: Hydra config. Reads ``risk_level``, ``seed``, ``device``, ``env``,
                ``agent`` (``name`` and ``params``), ``replay_buffer_capacity`` and
                ``restart_path``. ``agent.params.obs_dim``, ``action_dim`` and
                ``action_range`` are filled in from the environment before the
                agent is instantiated.

        Raises:
            ValueError: if ``risk_level`` is outside [0, 1] or ``seed`` is -1
                (the unset default).
        """
        self.cfg = cfg
        self.curr_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        self.work_dir = os.getcwd()
        print(f'Workspace: {self.work_dir}')
        # Explicit checks rather than `assert`, which is stripped under `python -O`.
        if not 0 <= cfg.risk_level <= 1:
            raise ValueError(f"risk_level must be between 0 and 1 (inclusive), got: {cfg.risk_level}")
        if cfg.seed == -1:
            raise ValueError(f"seed must be provided, got default seed: {cfg.seed}")
        self.device = torch.device(cfg.device)
        self.env = safety_gymnasium.make(cfg.env, render_mode='rgb_array')
        # The agent's dimensions come from the environment, so patch them into
        # the config before instantiation.
        cfg.agent.params.obs_dim = self.env.observation_space.shape[0]
        cfg.agent.params.action_dim = self.env.action_space.shape[0]
        # A single scalar range spanning all action dims.
        # NOTE(review): assumes every dimension shares the same bounds — confirm.
        cfg.agent.params.action_range = [
            float(self.env.action_space.low.min()),
            float(self.env.action_space.high.max())
        ]
        self.agent = hydra.utils.instantiate(cfg.agent)
        self.replay_buffer = ReplayBuffer(self.env.observation_space.shape,
                                          self.env.action_space.shape,
                                          int(cfg.replay_buffer_capacity),
                                          self.device)
        # Global environment-step counter; also used as the wandb step.
        self.step = 0
        # Any restart_path other than the placeholder "dummy" is loaded as a checkpoint.
        if cfg.restart_path != "dummy":
            self.agent.load(cfg.restart_path)

        utils.make_dir(self.work_dir, "data", cfg.agent.name, self.curr_time)

        wandb.init(
            mode="offline",
            project="GMM-SSAC",
            name=self.curr_time
        )

    def _run_dir(self):
        """Return this run's output directory: ``<work_dir>/data/<agent name>/<curr_time>``."""
        return os.path.join(self.work_dir, "data", self.cfg.agent.name, self.curr_time)

    def _checkpoint_path(self, episode_count):
        """Return the path passed to ``agent.save`` for training episode ``episode_count``."""
        return os.path.join(self._run_dir(), f"episode_{episode_count}")

    def run(self):
        """Train for ``cfg.num_train_steps`` environment steps, then evaluate once more.

        For the first ``cfg.num_seed_steps`` steps actions come from
        ``utils.sample_action``; afterwards they come from the agent's stochastic
        policy and ``agent.update`` is called once per step. Every transition is
        added to the replay buffer.

        At each episode boundary after the first, the finished episode's reward,
        cost and wall-clock duration are logged to wandb. The running cost rate
        (cumulative cost / steps) is logged as well. :meth:`evaluate` is called
        once the step count has reached the next multiple of
        ``cfg.eval_frequency``. After the loop a checkpoint is saved and a final
        evaluation is run.
        """
        eval_every = self.cfg.eval_frequency
        next_eval_step = eval_every
        episode, episode_reward, episode_cost, total_cost = 0, 0, 0, 0
        # Start "done" so the first loop iteration resets the environment.
        terminated, truncated = True, True
        start_time = time.time()
        while self.step < self.cfg.num_train_steps:
            if terminated or truncated:
                # No episode has finished on the very first iteration; skip
                # reporting so no spurious zero-reward point is logged.
                if self.step > 0:
                    print(f"Episode: {episode}, Reward: {episode_reward}, Cost: {episode_cost}")
                    wandb.log({'train/duration': time.time() - start_time}, step=self.step)
                    start_time = time.time()
                    # A threshold instead of `step % eval_frequency == 0`: episodes
                    # that end early rarely land exactly on a multiple, which would
                    # silently skip evaluations.
                    if self.step >= next_eval_step:
                        wandb.log({'eval/episode': episode}, step=self.step)
                        self.evaluate(episode)
                        # Keep evaluation time out of the next episode's duration.
                        start_time = time.time()
                        next_eval_step = (self.step // eval_every + 1) * eval_every
                    wandb.log({'train/episode_reward': episode_reward}, step=self.step)
                    wandb.log({'train/episode_cost': episode_cost}, step=self.step)
                    wandb.log({'train/cost_rate': total_cost / self.step}, step=self.step)

                episode += 1
                episode_reward = 0
                episode_cost = 0
                # evaluate() steps self.env too, so this reset also follows any evaluation.
                observation, _ = self.env.reset()
                terminated, truncated = False, False

                wandb.log({'train/episode': episode}, step=self.step)

            if self.step < self.cfg.num_seed_steps:
                action = utils.sample_action(self.env)
            else:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(observation, sample=True)
                self.agent.update(self.replay_buffer, self.step)

            next_observation, reward, cost, terminated, truncated, _ = self.env.step(action)
            terminated = float(terminated)
            # Gymnasium reports time-limit cut-offs through `truncated`, so
            # `terminated` already excludes them. The old
            # `episode_step + 1 == max_episode_steps` check could only wrongly
            # clear a genuine terminal that happened on the final step.
            terminated_no_max = terminated
            episode_reward += reward
            episode_cost += cost
            total_cost += cost
            self.replay_buffer.add(observation, action, reward, cost, next_observation, terminated,
                                   terminated_no_max)
            observation = next_observation
            self.step += 1

        # Save before the final evaluation so the trained weights survive even if
        # rendering/GIF writing fails; evaluate() saves to the same path again.
        self.agent.save(self._checkpoint_path(episode))
        wandb.log({'eval/episode': episode}, step=self.step)
        self.evaluate(episode)

    def evaluate(self, episode_count):
        """Roll out the deterministic policy, log mean reward/cost, write GIFs and a checkpoint.

        Runs ``cfg.num_eval_episodes`` episodes with ``agent.act(obs, sample=False)``
        on ``self.env``, the same environment used for training. The caller
        must reset that environment before continuing training. Each rollout
        is written to ``eval_<episode_count>_<i>.gif`` (1-based ``i``) in the
        run directory. ``eval/mean_reward`` and ``eval/mean_cost`` are logged at
        the current training step. Finally the agent is saved to
        ``episode_<episode_count>``.

        Args:
            episode_count: training episode index used to name the output files.
        """
        num_episodes = self.cfg.num_eval_episodes
        reward_sum, cost_sum = 0, 0
        for eval_idx in range(num_episodes):
            obs, _ = self.env.reset()
            terminated, truncated = False, False
            episode_reward = 0
            episode_cost = 0
            frames = []

            while not terminated and not truncated:
                with utils.eval_mode(self.agent):
                    action = self.agent.act(obs, sample=False)
                obs, reward, cost, terminated, truncated, _ = self.env.step(action)
                frames.append(Image.fromarray(self.env.render()))
                episode_reward += reward
                episode_cost += cost

            reward_sum += episode_reward
            cost_sum += episode_cost

            # At least one step always runs, so `frames` is non-empty here.
            gif_path = os.path.join(self._run_dir(), f"eval_{episode_count}_{eval_idx + 1}.gif")
            # 40 ms per frame, looping forever.
            frames[0].save(gif_path, save_all=True, append_images=frames[1:], duration=40, loop=0)

        # With evaluation disabled (num_eval_episodes == 0) there is nothing to
        # average; skip logging instead of dividing by zero.
        if num_episodes > 0:
            wandb.log({
                'eval/mean_reward': reward_sum / num_episodes,
                'eval/mean_cost': cost_sum / num_episodes
            }, step=self.step)

        self.agent.save(self._checkpoint_path(episode_count))


@hydra.main(config_path='config/train.yaml', strict=True)
def main(cfg):
    """Hydra entry point: seed via ``utils.set_seed_everywhere``, then build and run a Workspace.

    Args:
        cfg: the composed Hydra config supplied by the ``@hydra.main`` decorator.
    """
    utils.set_seed_everywhere(cfg.seed)
    Workspace(cfg).run()


# Script entry point. main() is called without arguments because the
# @hydra.main decorator supplies its `cfg` parameter.
if __name__ == '__main__':
    main()
