import os
import sys
from collections import deque
from pathlib import Path

import gymnasium as gym  # To load the new GYM API
import numpy as np
import setproctitle
import torch
from tensorboardX import SummaryWriter

from src.config import get_config
from src.utils.replayMemory import Memory
from src.utils.zfilter import ZFilter


class MujocoRunner:
    """Train a single-agent actor-critic policy on one Gymnasium environment.

    The runner builds the actor/critic pair for the configured algorithm,
    repeatedly collects whole episodes until at least
    ``env_steps_before_update`` environment steps are gathered, trains once on
    that batch, and periodically saves checkpoints and logs scalar metrics.

    Args:
        config: dict with keys ``all_args`` (parsed CLI namespace), ``envs``
            (a single Gymnasium environment), ``device`` (torch device),
            ``running_state`` (observation filter callable or ``None``) and
            ``run_dir`` (``pathlib.Path`` for this run's outputs).
    """

    def __init__(self, config):
        """Read settings, create output folders/writer and build policy, critic and trainer.

        Raises:
            ValueError: if ``algorithm_name`` is unknown or ``episode_length``
                is not positive.
        """
        self.all_args = config['all_args']
        self.envs = config['envs']
        self.device = config['device']
        self.running_state = config["running_state"]
        self.env_name = self.all_args.env_name
        self.algorithm_name = self.all_args.algorithm_name
        self.num_env_steps = self.all_args.num_env_steps
        self.episode_length = self.all_args.episode_length
        if self.episode_length <= 0:
            # With no steps per episode, run() could never advance its step
            # counter and would loop forever.
            raise ValueError(f'episode_length must be positive, got {self.episode_length}')
        self.hidden_size = self.all_args.hidden_size
        # Collect at least one environment step between updates.
        self.env_steps_before_update = self.all_args.env_steps_before_update \
            if self.all_args.env_steps_before_update > 0 else 1

        self.save_interval = self.all_args.save_interval
        self.log_interval = self.all_args.log_interval

        self.run_dir = config["run_dir"]
        self.log_dir = str(self.run_dir / 'logs')
        os.makedirs(self.log_dir, exist_ok=True)
        self.save_dir = str(self.run_dir / 'models')
        os.makedirs(self.save_dir, exist_ok=True)
        self.all_args.__dict__['run_dir'] = self.run_dir

        # Returns of the most recent 100 individual episodes.
        self.recent_episode_returns = deque(maxlen=100)

        self.writter = SummaryWriter(self.log_dir)
        if self.all_args.use_wandb:
            from src.utils.wandb_logger import WandbWriter
            self.writter = WandbWriter(self.writter, self.all_args)

        # Algorithms are imported lazily so only the selected one is loaded.
        if self.algorithm_name == 'smac':
            from src.algorithms.ac.smac import ActorCritic as TrainAlgo
        elif self.algorithm_name == 'actor-critic':
            from src.algorithms.ac.ac import ActorCritic as TrainAlgo
        elif self.algorithm_name == 'actor-critic-sgd':
            from src.algorithms.ac.ac_sgd import ActorCritic as TrainAlgo
        elif self.algorithm_name == 'actor-critic-npg':
            from src.algorithms.ac.ac_npg import ActorCritic as TrainAlgo
        elif self.algorithm_name == 'ac-kfac':
            from src.algorithms.ac.ac_KFAC import ActorCritic as TrainAlgo
        else:
            raise ValueError(f'Unknown algorithm: {self.algorithm_name}')
        # Every algorithm shares the same critic network.
        from src.algorithms.ac.acPolicy import Critic

        input_dim = self.envs.observation_space.shape[0]
        if isinstance(self.envs.action_space, gym.spaces.Discrete):
            from src.algorithms.ac.acPolicy import DiscreteActor as Policy
            output_dim = self.envs.action_space.n
        else:
            from src.algorithms.ac.acPolicy import Actor as Policy
            output_dim = self.envs.action_space.shape[0]

        self.policy = Policy(input_dim, output_dim,
                             activation=self.all_args.activation,
                             hidden_size=self.hidden_size).to(self.device)
        # The critic must live on the same device as the policy and the batch
        # tensors built in train_model().
        self.critic = Critic(input_dim, hidden_size=self.hidden_size).to(self.device)

        self.trainer = TrainAlgo(self.all_args, self.policy, self.critic, device=self.device)

    def save(self):
        """Write the trainer's actor and critic state dicts to ``save_dir``."""
        torch.save(self.trainer.policy.state_dict(), os.path.join(self.save_dir, "actor.pt"))
        torch.save(self.trainer.critic.state_dict(), os.path.join(self.save_dir, "critic.pt"))

    def _to_tensor(self, values):
        """Stack ``values`` into a float32 tensor on the runner's device."""
        return torch.as_tensor(np.stack(values), dtype=torch.float32).to(self.device)

    def run(self):
        """Alternate experience collection and training until ``num_env_steps`` steps.

        Each iteration collects whole episodes, each capped at ``episode_length``
        steps, until at least ``env_steps_before_update`` steps are gathered.
        It then trains once on that batch. The model is saved and metrics
        (averaged since the previous log) are written every ``save_interval``
        iterations, on the first iteration, and on the final iteration.
        """
        steps = 0
        num_env_steps = self.num_env_steps
        # Accumulated since the last log and cleared after each log:
        # per-episode mean step reward, per-iteration mean return and length.
        avg_reward = []
        episode_returns = []
        episode_lens = []
        iteration = 0

        use_xla = "xla" in str(self.device)
        if use_xla:
            import torch_xla.core.xla_model as xm

        # Seeding numpy/torch does not seed the environment's own RNG. Pass the
        # seed to the first reset only; later resets use seed=None (see the
        # Gymnasium Env.reset contract).
        reset_seed = getattr(self.all_args, 'seed', None)

        while steps < num_env_steps:
            iteration += 1
            steps_per_iteration = 0
            episode_return = []  # return of each episode in this iteration
            episode_len = []     # length of each episode in this iteration

            if self.all_args.use_linear_lr_decay and steps > 0:
                frac = 1.0 - (steps - 1.0) / num_env_steps
                lr_now = frac * self.all_args.lr
                lr_critic_now = frac * self.all_args.critic_lr if self.all_args.use_critic_linear_lr_decay else None
                self.trainer.set_lr(lr_now, lr_critic_now)

            memory = Memory()
            while steps_per_iteration < self.env_steps_before_update:
                episode_return.append(0)
                episode_len.append(0)
                step_rewards = []

                state, _ = self.envs.reset(seed=reset_seed)
                reset_seed = None
                if self.running_state is not None:
                    state = self.running_state(state)

                for _ in range(self.episode_length):
                    steps += 1
                    steps_per_iteration += 1

                    with torch.no_grad():
                        action = self.trainer.policy.select_action(self._to_tensor(state))
                        if use_xla:
                            xm.mark_step()
                        action = action.cpu().numpy()

                    next_state, reward, terminated, truncated, _ = self.envs.step(action)

                    episode_return[-1] += reward
                    episode_len[-1] += 1
                    step_rewards.append(reward)

                    if self.running_state is not None:
                        next_state = self.running_state(next_state)

                    # The mask reflects termination only; a truncated step keeps mask 1.
                    mask = 1.0 - terminated
                    memory.push(state, action, mask, next_state, reward)

                    if terminated or truncated:
                        break
                    state = next_state

                # Average reward per timestep of this episode.
                avg_reward.append(np.mean(step_rewards))

            episode_returns.append(np.mean(episode_return))
            episode_lens.append(np.mean(episode_len))

            train_info = self.train_model(memory.sample())

            # Store individual episode returns so the "last 100" metric covers
            # the last 100 episodes rather than the last 100 iterations.
            self.recent_episode_returns.extend(episode_return)

            if iteration % self.save_interval == 0 or iteration <= 1 or steps >= num_env_steps:
                self.save()
                self.log_train(train_info, steps, np.mean(avg_reward),
                               np.mean(episode_returns), np.mean(episode_lens))
                avg_reward = []
                episode_returns = []
                episode_lens = []
            print("updates {}/{} steps, score: {}\n".format(steps, num_env_steps, np.mean(episode_return)))

    def train_model(self, batch):
        """Convert a sampled batch to float32 tensors and run one trainer update.

        Args:
            batch: object exposing ``state``, ``action``, ``reward`` and ``mask``
                sequences, as returned by ``Memory.sample()``.

        Returns:
            Whatever ``self.trainer.train`` returns; ``log_train`` treats it as
            a dict of scalar metrics.
        """
        states = self._to_tensor(batch.state)
        actions = self._to_tensor(batch.action)
        rewards = self._to_tensor(batch.reward)
        masks = self._to_tensor(batch.mask)
        return self.trainer.train(states, actions, rewards, masks)

    def log_train(self, train_infos, total_num_steps, episode_rewards, episode_return, episode_len):
        """Add episode statistics and learning rates to ``train_infos``, then write each entry.

        Note that ``train_infos`` is mutated in place.

        Args:
            train_infos: dict of metric name -> scalar returned by the trainer.
            total_num_steps: global environment step used as the x-axis.
            episode_rewards: mean per-step reward over the logging window.
            episode_return: mean episode return over the logging window.
            episode_len: mean episode length over the logging window.
        """
        train_infos["episode_rewards"] = episode_rewards
        train_infos["episode_returns"] = episode_return
        train_infos["episode_returns_last_100"] = np.mean(list(self.recent_episode_returns))
        train_infos["episode_len"] = episode_len
        train_infos["lr"] = self.trainer.lr
        train_infos["critic_lr"] = self.trainer.critic_lr
        for k, v in train_infos.items():
            self.writter.add_scalars(k, {k: v}, total_num_steps)


def detect_device(n_training_threads, cuda_deterministic):
    """Return the torch device to train on, preferring XLA, then CUDA, then CPU.

    Args:
        n_training_threads: value passed to ``torch.set_num_threads`` on the
            CUDA and CPU paths (the XLA path leaves it untouched).
        cuda_deterministic: when CUDA is chosen, disable cuDNN benchmarking and
            force deterministic cuDNN kernels.

    Returns:
        An XLA device, ``torch.device("cuda:0")`` or ``torch.device('cpu')``.
    """
    try:
        import torch_xla.core.xla_model as xla_model
        return xla_model.xla_device()
    except (ImportError, RuntimeError):
        pass  # no usable XLA backend; fall through to CUDA/CPU

    has_cuda = torch.cuda.is_available()
    print("choose to use gpu..." if has_cuda else "choose to use cpu...")
    torch.set_num_threads(n_training_threads)

    if not has_cuda:
        return torch.device('cpu')

    if cuda_deterministic:
        # Trade cuDNN autotuning speed for run-to-run reproducibility.
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    return torch.device("cuda:0")


def main(args):
    """Entry point: parse ``args``, build the env and runner, train, and clean up.

    Outputs go to ``results/<env_name>/<task_name>/<algorithm_name><add>/runN``,
    located three directories above this file. ``N`` is one more than the
    largest existing ``run<digits>`` entry, or 1 if there is none. The
    environment is always closed. Once the runner has been created, its writer
    is also exported and closed, even if training raises.
    """
    parser = get_config()
    all_args = parser.parse_known_args(args)[0]

    device = detect_device(all_args.n_training_threads, all_args.cuda_deterministic)

    envs = make_train_env(all_args)
    runner = None
    try:
        state_dim = envs.observation_space.shape[0]
        # Robotics envs skip the ZFilter (make_train_env already normalizes
        # their observations). A spec's entry_point may be a callable rather
        # than a string, so check its type before splitting.
        entry_point = getattr(envs.spec, "entry_point", None)
        if isinstance(entry_point, str) and entry_point.split(".")[0] == "gymnasium_robotics":
            running_state = None
        else:
            running_state = ZFilter((state_dim,), clip=5)

        algo = str(all_args.algorithm_name) + str(all_args.add)
        algo_dir = (Path(__file__).resolve().parent.parent.parent / "results"
                    / all_args.env_name / all_args.task_name / algo)
        algo_dir.mkdir(parents=True, exist_ok=True)

        # Next run index = 1 + largest existing "run<digits>" entry. Other names
        # (e.g. "run_old") are ignored instead of crashing int().
        run_nums = [int(entry.name[3:]) for entry in algo_dir.iterdir()
                    if entry.name.startswith('run') and entry.name[3:].isdecimal()]
        run_dir = algo_dir / f"run{max(run_nums, default=0) + 1}"
        run_dir.mkdir(parents=True, exist_ok=True)

        setproctitle.setproctitle(
            str(all_args.algorithm_name) + "-" + str(all_args.env_name))

        torch.manual_seed(all_args.seed)
        torch.cuda.manual_seed_all(all_args.seed)
        np.random.seed(all_args.seed)

        config = {
            "all_args": all_args,
            "envs": envs,
            "num_agents": 1,
            "device": device,
            "running_state": running_state,
            "run_dir": run_dir
        }

        runner = MujocoRunner(config)
        runner.run()
    finally:
        envs.close()
        if runner is not None:
            runner.writter.export_scalars_to_json(str(runner.log_dir + '/summary.json'))
            runner.writter.close()


def make_train_env(all_args):
    """Create the single Gymnasium environment selected by ``all_args``.

    ``all_args.env_name`` selects the suite ("mujoco", "box2d", "classic" or
    "mujoco-robotics"). ``all_args.task_name`` selects the task within it and
    is matched case-insensitively, so the historical spellings such as
    "HalfCheetah" still work. Robotics tasks are additionally wrapped with
    FlattenObservation, ClipAction, NormalizeObservation and
    NormalizeReward (with ``gamma=all_args.gamma``).

    Raises:
        ValueError: if the suite or the task name is unknown.
    """
    env_ids = {
        # Current versions in MuJoCo-3.
        "mujoco": {
            "hopper": "Hopper-v5",
            "halfcheetah": "HalfCheetah-v5",
            "humanoid": "Humanoid-v5",
            "swimmer": "Swimmer-v5",
            "walker": "Walker2d-v5",
            # v5 like the others; Ant-v2 needs the legacy mujoco-py backend.
            "ant": "Ant-v5",
            "pusher": "Pusher-v5",
        },
        # Current versions in Gymnasium.
        "box2d": {
            "bipedalwalker": "BipedalWalker-v3",
            "lunarlander": "LunarLander-v3",
        },
        "classic": {
            "acrobot": "Acrobot-v1",
            "cartpole": "CartPole-v1",
            "pendulum": "Pendulum-v1",
        },
        # Current versions in MuJoCo-3 (gymnasium_robotics).
        "mujoco-robotics": {
            "reach": "FetchReachDense-v3",
            "push": "FetchPushDense-v3",
            "picknplace": "FetchPickAndPlaceDense-v3",
            "slide": "FetchSlideDense-v3",
            "handreach": "HandReachDense-v2",
            "handblock": "HandManipulateBlockDense-v1",
            "handegg": "HandManipulateEggDense-v1",
            "handpen": "HandManipulatePenDense-v1",
        },
    }

    suite = env_ids.get(all_args.env_name)
    if suite is None:
        raise ValueError(f'Unknown env {all_args.env_name}')
    env_id = suite.get(str(all_args.task_name).lower())
    if env_id is None:
        raise ValueError(f'Unknown task {all_args.task_name} for env {all_args.env_name}')

    if all_args.env_name != "mujoco-robotics":
        return gym.make(env_id)

    import gymnasium_robotics as gym_robo  # For Robotics Environments
    from gymnasium.wrappers import FlattenObservation
    gym.register_envs(gym_robo)  # Import Robotics Envs
    env = FlattenObservation(gym.make(env_id))
    env = gym.wrappers.ClipAction(env)
    env = gym.wrappers.NormalizeObservation(env)
    env = gym.wrappers.NormalizeReward(env, gamma=all_args.gamma)
    return env


if __name__ == "__main__":
    # Forward every command-line argument after the script name.
    cli_args = sys.argv[1:]
    main(cli_args)
