import os
import json
import datetime
import numpy as np
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from env import ENV
import shutil

class BASETrainer:
    """Base trainer shared by the concrete trainers.

    Builds a training and an evaluation environment, stores common run
    parameters and, unless ``args.disable_wandb`` is set, creates the result
    directories, a TensorBoard ``SummaryWriter`` and a backup copy of key
    source files.

    NOTE(review): several helpers read attributes this class never assigns
    (``self.agent``, ``self.memory``, ``self.start_learning``); presumably a
    subclass sets them before calling the helpers — confirm in subclasses.
    """

    # Environment families seeded through the legacy ``env.seed()`` call;
    # every other family is seeded via ``env.reset(seed=...)``.
    _LEGACY_SEED_ENVS = ("adroit", "maze")

    def __init__(self, args):
        """Create environments, record run parameters and set up output dirs.

        ``args`` is mutated in place: ``obs_shape``, ``action_space`` and
        ``action_dim`` are added, and for ``args.env == "neorl"`` the
        ``data_type`` is appended to ``args.env_name`` (only when output is
        enabled).
        """
        # init env
        self.env = ENV[args.env](args.env_name)
        self.env.action_space.seed(args.seed)

        self.eval_env = ENV[args.env](args.env_name)
        self.eval_env.action_space.seed(args.seed)

        if args.env in self._LEGACY_SEED_ENVS:
            self.env.seed(args.seed)
            self.eval_env.seed(args.seed)
        else:
            self.env.reset(seed=args.seed)
            self.eval_env.reset(seed=args.seed)

        args.obs_shape = self.env.observation_space.shape
        args.action_space = self.env.action_space
        # np.prod returns a NumPy scalar (float 1.0 for an empty shape); cast
        # so downstream code always receives a plain Python int.
        args.action_dim = int(np.prod(args.action_space.shape))

        # running parameters
        self.batch_size = args.batch_size
        self.eval_n_episodes = args.eval_n_episodes
        self.device = args.device
        self.seed = args.seed
        self.args = args

        # resume
        self.records = self._default_records()
        self.start_it = 0

        # log and backup
        if args.disable_wandb:
            # NOTE(review): model_dir/record_dir/logger are left unset here, so
            # _save() cannot be used in this mode.
            print("\n>>> Disabled uploading files <<<")
            return
        dtime = datetime.datetime.now().strftime("%y-%m%d-%H%M%S")
        if args.env == "neorl":
            args.env_name += f"-{args.data_type}"
        run_dir = "./result/{}/{}/{}/{}".format(args.env, args.env_name, args.algo, dtime)
        self.model_dir = os.path.join(run_dir, "model")
        self.record_dir = os.path.join(run_dir, "record")
        self.log_dir = os.path.join(run_dir, "log")
        self.file_dir = os.path.join(run_dir, "file")
        for directory in (self.model_dir, self.record_dir, self.log_dir, self.file_dir):
            os.makedirs(directory, exist_ok=True)
        print("\nTraining data and model saved in \nmodel >>>{}\nrecord >>>{}\nlogs >>>{}\nbackup >>>{}\n".format(self.model_dir, self.record_dir, self.log_dir, self.file_dir))

        self.logger = SummaryWriter(self.log_dir)
        self.files_to_backup = [
            "./agent/admpo.py",
            "./runner/online_trainer.py",
            "./runner/base_trainer.py",
            "./main4online.py"
        ]
        self._save_file(self.files_to_backup, self.file_dir)

    def _save_file(self, src, dest):
        """Copy each existing path in ``src`` into directory ``dest``.

        Missing paths are reported and skipped rather than raising.
        """
        for file in src:
            if os.path.exists(file):
                # copy2 also preserves file metadata (timestamps)
                shutil.copy2(file, dest)
                print(f"Backed up: {file}")
            else:
                print(f"Not existed: {file}")

    def _default_records(self):
        """Return a fresh, empty training-record dictionary."""
        return {
            "step": [],
            "loss": {
                "model": [],
                "actor": [],
                "critic1": [],
                "critic2": [],
            },
            "alpha": [],
            "reward_mean": [],
            "reward_std": [],
            "reward_min": [],
            "reward_max": [],
        }

    def _load_resume(self):
        """Restore agent weights, start step and records from a checkpoint.

        The checkpoint at ``self.args.resume_path`` is loaded into
        ``self.agent``. The start step is parsed from a file name of the form
        written by ``_save`` (``agent_step-<step>-seed-<seed>.pth``); any other
        name resumes from step 0. Records are read from
        ``record_seed-<seed>.txt`` in the ``record`` directory next to the
        checkpoint's ``model`` directory (or in the checkpoint's own directory
        when that directory is not named ``model``). If the file is missing,
        empty default records are used.

        Raises:
            ValueError: if ``self.args.resume_path`` is None.
            FileNotFoundError: if ``self.args.resume_path`` does not exist.
        """
        resume_path = self.args.resume_path
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if resume_path is None:
            raise ValueError("resume=True but resume_path is None")
        if not os.path.exists(resume_path):
            raise FileNotFoundError("resume_path not exists!")

        print(f"Resume model from <<< {resume_path}")
        self.agent.load_model(resume_path)

        fname = os.path.basename(resume_path)
        try:
            self.start_it = int(fname.split("-")[1])
        except (IndexError, ValueError):
            # Not an "agent_step-<step>-..." name: restart the step counter.
            self.start_it = 0

        model_dir = os.path.dirname(resume_path)
        # Swap only the final "model" component for "record". A substring
        # replace would also rewrite unrelated components (e.g. an algo
        # directory called "model_based").
        if os.path.basename(model_dir) == "model":
            record_dir = os.path.join(os.path.dirname(model_dir), "record")
        else:
            record_dir = model_dir
        record_path = os.path.join(record_dir, f"record_seed-{self.seed}.txt")

        if os.path.exists(record_path):
            with open(record_path, "r") as f:
                self.records = json.load(f)
        else:
            self.records = self._default_records()

        print(f">>> Resume from step {self.start_it}")

    def _warm_up(self):
        """Fill the replay buffer with ``self.start_learning`` random transitions.

        Returns:
            The observation to continue from after the last stored transition.
        """
        obs, _ = self.env.reset()

        for _ in tqdm(range(self.start_learning), desc="Warming up"):
            action = self.env.action_space.sample()
            next_obs, reward, terminated, truncated, _ = self.env.step(action)
            self.memory.store(obs, action, reward, next_obs, terminated, truncated)

            obs = next_obs
            if terminated or truncated:
                obs, _ = self.env.reset()

        return obs

    def _eval_policy(self):
        """Run ``self.eval_n_episodes`` deterministic episodes on ``eval_env``.

        Returns:
            A list with the undiscounted return of each episode.
        """
        episode_rewards = []
        for _ in range(self.eval_n_episodes):
            obs, _ = self.eval_env.reset()
            episode_reward = 0
            done = False
            while not done:
                action, _ = self.agent.act(obs, deterministic=True)
                obs, reward, terminated, truncated, _ = self.eval_env.step(action)
                done = terminated or truncated
                episode_reward += reward
            episode_rewards.append(episode_reward)
        return episode_rewards

    def _save(self, records, global_step):
        """Save the agent checkpoint and dump ``records`` as JSON.

        The checkpoint goes to ``agent_step-<step>-seed-<seed>.pth`` in
        ``self.model_dir``, which is the name ``_load_resume`` parses. The
        records go to ``record_seed-<seed>.txt`` in ``self.record_dir``.
        """
        file_path = os.path.join(self.model_dir, "agent_step-{}-seed-{}.pth".format(global_step, self.seed))
        self.agent.save_model(file_path)
        with open(os.path.join(self.record_dir, "record_seed-{}.txt".format(self.seed)), "w") as f:
            json.dump(records, f)
