import torch
import numpy as np
import wandb

from .par_bc import PARBC, ReplayBuffer
import utils
from .configs import config as algo_config

class PAR_BC_Wrapper:
    """Train and evaluate a PAR-BC agent on the configured environment.

    Online transitions collected from ``env`` fill a source replay buffer,
    while the offline dataset returned by ``utils.format_dataset(env)`` fills a
    target replay buffer. ``PARBC.train`` is called once per environment step
    with both buffers.
    """

    # (substring of config.env.train_env, key in algo_config), checked in order.
    # Note the preset key for walker2d is spelled 'walker2D'.
    _ENV_PRESETS = (
        ('halfcheetah', 'halfcheetah'),
        ('walker2d', 'walker2D'),
        ('hopper', 'hopper'),
    )

    def __init__(self, env, eval_env, config, agent_path, evaluations_path):
        """Store environments/config and select the hyper-parameter preset.

        Args:
            env: Training environment; also the source of the offline dataset.
            eval_env: Environment used by ``eval_policy``.
            config: Experiment config; ``config.env.train_env`` selects the
                preset from ``algo_config``.
            agent_path: Stored on the instance.
            evaluations_path: Stored on the instance.

        Raises:
            NotImplementedError: If ``config.env.train_env`` matches none of the
                known environment families.
        """
        self.env = env
        self.eval_env = eval_env
        self.config = config
        self.agent_path = agent_path
        self.evaluations_path = evaluations_path

        # Observation entries zeroed before the policy sees an evaluation state
        # (None when the training env does not define `hidden_dims`).
        self.hidden_dims = getattr(env, 'hidden_dims', None)

        train_env_name = self.config.env.train_env.lower()
        for family, preset_key in self._ENV_PRESETS:
            if family in train_env_name:
                self.args = algo_config[preset_key]
                break
        else:
            raise NotImplementedError(
                f"No PAR-BC hyper-parameter preset for train_env "
                f"'{self.config.env.train_env}'"
            )

    def eval_policy(self, policy, eval_episodes=10):
        """Roll out ``policy`` on ``eval_env`` and return per-episode returns.

        Each observation is copied into a fresh array before the entries in
        ``self.hidden_dims`` (if any) are zeroed, so the array returned by the
        environment is never modified. The state is passed to
        ``policy.select_action`` with shape ``(1, -1)``.

        Args:
            policy: Object exposing ``select_action(state)``.
            eval_episodes: Number of episodes to run.

        Returns:
            list: Sum of rewards collected in each episode.
        """
        rewards = []
        for _ in range(eval_episodes):
            state, done = self.eval_env.reset(), False
            episode_return = 0.
            while not done:
                # Copy before masking: the env may hand back an internal buffer,
                # and lists/tuples cannot be indexed by `hidden_dims`.
                obs = np.array(state)
                if self.hidden_dims is not None:
                    obs[self.hidden_dims] = 0.0

                action = policy.select_action(obs.reshape(1, -1))
                state, reward, done, _ = self.eval_env.step(action)
                episode_return += reward

            rewards.append(episode_return)

        return rewards

    @staticmethod
    def _flatten_observations(dataset):
        """Flatten 3-D ``observations``/``next_observations`` to 2-D in place."""
        if len(dataset['observations'].shape) == 3:
            for key in ('observations', 'next_observations'):
                values = dataset[key]
                dataset[key] = values.reshape(values.shape[0], values.shape[1] * values.shape[2])

    @staticmethod
    def _flat_obs(obs):
        """Return a flattened copy of a single environment observation."""
        # Matches the per-state layout produced by `_flatten_observations`,
        # which determines the replay buffers' `state_dim`.
        return np.array(obs).reshape(-1)

    def _evaluate_and_log(self, policy, step):
        """Evaluate ``policy`` and report the statistics to stdout and, if enabled, wandb."""
        print(f"Time steps: {step}")
        all_rewards = self.eval_policy(policy)
        avg_reward, std_reward, avg_norm_reward, std_norm_reward = utils.get_eval_statistics(all_rewards, self.config.env.eval_env)
        eval_episodes = len(all_rewards)
        print("---------------------------------------")
        print(f"Epoch {step}: Evaluation over {eval_episodes} episodes: {avg_reward:.3f} +- {std_reward:.3f}, Normalized score = {avg_norm_reward:.3f} +- {std_norm_reward:.3f}")
        print("---------------------------------------")

        if self.config.wandb.enable:
            metrics = {'eval mean reward': avg_reward, 'avg_norm_reward': avg_norm_reward,
                       'eval std reward': std_reward, 'epochs': step}
            wandb.log(metrics)

    def train(self):
        """Run the PAR-BC training loop for ``config.train.max_timesteps`` steps.

        Each step acts in ``env`` with Gaussian exploration noise (std
        ``0.2 * max_action``, result clipped to ``[-max_action, max_action]``),
        stores the transition in the source buffer, and calls ``policy.train``
        with the source buffer and the offline target buffer. Every
        ``config.train.eval_freq`` steps the policy is evaluated via
        ``eval_policy`` and the results are printed (and logged to wandb when
        ``config.wandb.enable`` is set).
        """
        args = self.args

        device = torch.device('cpu' if self.config.system.cpu else 'cuda')

        dataset = utils.format_dataset(self.env)

        policy = PARBC(args, device)

        self._flatten_observations(dataset)

        state_dim = dataset['observations'].shape[1]
        action_dim = self.env.action_space.shape[0]
        max_action = float(self.env.action_space.high[0])
        exploration_std = max_action * 0.2

        src_replay_buffer = ReplayBuffer(state_dim, action_dim, device)
        tar_replay_buffer = ReplayBuffer(state_dim, action_dim, device)

        # The target domain is offline: its buffer is filled once from the dataset.
        tar_replay_buffer.convert_D4RL(dataset)

        # NOTE(review): eval_policy zeroes `self.hidden_dims`, but online training
        # states are used unmasked here — confirm this asymmetry is intended.
        src_state, src_done = self._flat_obs(self.env.reset()), False
        src_episode_timesteps = 0

        for t in range(int(self.config.train.max_timesteps)):
            src_episode_timesteps += 1

            # Policy action plus Gaussian exploration noise, akin to TD3.
            policy_action = policy.select_action(np.array(src_state), test=False)
            noise = np.random.normal(0, exploration_std, size=action_dim)
            src_action = (policy_action + noise).clip(-max_action, max_action)

            src_next_state, src_reward, src_done, _ = self.env.step(src_action)
            src_next_state = self._flat_obs(src_next_state)
            # A done at the step limit is a time-out rather than a terminal
            # state, so it is stored as 0 (TD3 convention).
            src_done_bool = float(src_done) if src_episode_timesteps < args['max_episode_steps'] else 0

            src_replay_buffer.add(src_state, src_action, src_next_state, src_reward, src_done_bool)

            src_state = src_next_state

            policy.train(src_replay_buffer, tar_replay_buffer, args['batch_size'])

            if src_done:
                src_state, src_done = self._flat_obs(self.env.reset()), False
                src_episode_timesteps = 0

            if (t + 1) % self.config.train.eval_freq == 0:
                self._evaluate_and_log(policy, t + 1)


