import torch
import numpy as np
import wandb
# import d4rl
import os

import utils
from .utils_iql import ReplayBuffer
from .IQL import IQL


class IQL_Wrapper:
    """Train and periodically evaluate an ``IQL`` policy.

    The wrapper builds a replay buffer from the dataset obtained through
    ``utils.format_dataset(env)``, runs ``policy.train`` for
    ``config.train.max_timesteps`` steps and, every ``config.train.eval_freq``
    steps, rolls the policy out in ``eval_env``.
    """

    def __init__(self, env, eval_env, config, agent_path, evaluations_path):
        """Store the environments, the run configuration and the output paths.

        Args:
            env: Environment used to obtain the dataset and the
                observation/action dimensions.
            eval_env: Environment the policy is rolled out in during evaluation.
            config: Run configuration; ``train`` reads ``system.cpu``,
                ``load_model``, ``save_model``, ``wandb.enable``,
                ``env.eval_env`` and ``train.{max_timesteps,batch_size,eval_freq}``.
            agent_path: Path passed to ``policy.load`` / ``policy.save``.
            evaluations_path: Path the list of mean evaluation rewards is
                saved to with ``torch.save``.
        """
        self.env = env
        self.eval_env = eval_env
        self.config = config
        self.agent_path = agent_path
        self.evaluations_path = evaluations_path

        # Observation entries that are zeroed before the policy sees them;
        # ``None`` when the environment does not define ``hidden_dims``.
        self.hidden_dims = getattr(env, 'hidden_dims', None)

        # IQL hyper-parameters.
        # NOTE(review): ``batch_size`` below is not read by ``train``, which
        # uses ``config.train.batch_size`` instead - confirm which is intended.
        self.args = utils.Dict2Class({
            'batch_size': 256,  # Batch size for both actor and critic
            'temperature': 3.0,  # Forwarded to the IQL constructor
            'expectile': 0.7,  # Forwarded to the IQL constructor
            'tau': 0.005,  # Forwarded to the IQL constructor
            'discount': 0.99,  # Discount factor
            'normalize': True})  # Normalize states with the replay-buffer statistics

    def eval_policy(self, policy, mean, std, eval_episodes=10):
        """Roll ``policy`` out in ``self.eval_env`` and collect episode returns.

        Args:
            policy: Object exposing ``select_action(state)``.
            mean: Value subtracted from every observation before acting.
            std: Value every centred observation is divided by.
            eval_episodes: Number of episodes to run.

        Returns:
            A list with the summed (undiscounted) reward of each episode.
        """
        rewards = []
        for _ in range(eval_episodes):
            state, done = self.eval_env.reset(), False
            current_reward = 0.
            while not done:
                # Work on a private copy: masking the object returned by the
                # environment in place could corrupt a buffer the environment
                # still owns, and fails outright for list/tuple observations.
                observation = np.array(state)
                if self.hidden_dims is not None:
                    observation[self.hidden_dims] = 0.0

                # The policy receives a single-row batch of the normalized state.
                normalized = (observation.reshape(1, -1) - mean) / std
                action = policy.select_action(normalized)
                state, reward, done, _ = self.eval_env.step(action)
                current_reward += reward

            rewards.append(current_reward)

        return rewards

    def train(self):
        """Run the training loop with periodic evaluation and checkpointing.

        At every evaluation point the running list of mean evaluation rewards
        is written to ``self.evaluations_path``; the policy is saved to
        ``self.agent_path`` when ``config.save_model`` is set, and metrics are
        sent to wandb when ``config.wandb.enable`` is set.
        """
        args = self.args

        # device:
        device = torch.device('cpu' if self.config.system.cpu else 'cuda')

        dataset = utils.format_dataset(self.env)

        state_dim = self.env.observation_space.shape[0]
        action_dim = self.env.action_space.shape[0]

        kwargs = {
            "state_dim": state_dim,
            "action_dim": action_dim,
            "discount": args.discount,
            "tau": args.tau,
            "temperature": args.temperature,
            "expectile": args.expectile,
            "device": device,
        }

        # Initialize policy
        policy = IQL(**kwargs)

        if self.config.load_model:
            policy.load(self.agent_path)

        replay_buffer = ReplayBuffer(state_dim, action_dim)
        replay_buffer.convert_D4RL(dataset)
        if args.normalize:
            mean, std = replay_buffer.normalize_states()
        else:
            # Identity normalization so ``eval_policy`` needs no special case.
            mean, std = 0, 1

        evaluations = []
        for t in range(int(self.config.train.max_timesteps)):
            policy.train(replay_buffer, self.config.train.batch_size)
            # Evaluate episode
            if (t + 1) % self.config.train.eval_freq == 0:
                print(f"Time steps: {t + 1}")
                all_rewards = self.eval_policy(policy, mean, std)
                avg_reward, std_reward, avg_norm_reward, std_norm_reward = utils.get_eval_statistics(
                    all_rewards, self.config.env.eval_env)
                eval_episodes = len(all_rewards)
                print("---------------------------------------")
                print(f"Epoch {t + 1}: Evaluation over {eval_episodes} episodes: {avg_reward:.3f} +- {std_reward:.3f}, Normalized score = {avg_norm_reward:.3f} +- {std_norm_reward:.3f}")
                print("---------------------------------------")

                if self.config.wandb.enable:
                    metrics = {'eval mean reward': avg_reward, 'avg_norm_reward': avg_norm_reward,
                               'eval std reward': std_reward, 'epochs': t + 1}
                    wandb.log(metrics)

                # Persist the evaluation history after every evaluation so a
                # partial run still leaves results on disk.
                evaluations.append(np.mean(all_rewards))
                torch.save(evaluations, self.evaluations_path)
                if self.config.save_model:
                    policy.save(self.agent_path)
