import os

import torch
import numpy as np
import wandb
# import d4rl

from .buffer import ReplayBuffer
from .model import TD3_BC
import utils


class TD3_BC_Wrapper:
    """Drive offline training and periodic evaluation of a ``TD3_BC`` policy.

    The dataset is obtained from ``utils.format_dataset(env)``, loaded into a
    ``ReplayBuffer`` and used for gradient updates; ``eval_env`` is rolled out
    at a fixed interval to report the policy's returns.
    """

    def __init__(self, env, eval_env, config, agent_path, evaluations_path):
        """Store environments, configuration and output paths; set hyperparameters.

        Args:
            env: Environment handed to ``utils.format_dataset`` and queried for
                its ``action_space``. If it exposes ``hidden_dims``, those
                observation entries are zeroed during evaluation.
            eval_env: Environment stepped during evaluation rollouts.
            config: Run configuration. This class reads ``system.cpu``,
                ``load_model``, ``save_model``, ``train.max_timesteps``,
                ``train.batch_size``, ``train.eval_freq``, ``env.eval_env``
                and ``wandb.enable``.
            agent_path: Path passed to ``policy.load`` / ``policy.save``.
            evaluations_path: Path the evaluation history is written to.
        """
        self.env = env
        self.eval_env = eval_env
        self.config = config
        self.agent_path = agent_path
        self.evaluations_path = evaluations_path

        # Indices of observation entries to blank out at evaluation time;
        # ``None`` when the environment does not define any.
        self.hidden_dims = getattr(env, 'hidden_dims', None)

        # NOTE(review): ``expl_noise`` and ``batch_size`` below are not read
        # anywhere in this class; ``train`` takes the batch size from
        # ``config.train.batch_size`` instead.
        # TD3
        self.args = utils.Dict2Class({
            "expl_noise": 0.1,  # Std of Gaussian exploration noise
            "batch_size": 256,  # Batch size for both actor and critic
            "discount": 0.99,  # Discount factor
            "tau": 0.005,  # Target network update rate
            "policy_noise": 0.2,  # Noise added to target policy during critic update
            "noise_clip": 0.5,  # Range to clip target policy noise
            "policy_freq": 2,  # Frequency of delayed policy updates
            # TD3 + BC
            "alpha": 2.5,  # Forwarded unchanged to TD3_BC as ``alpha``
            "normalize": True,  # Normalize buffer states before training
        })

    def eval_policy(self, policy, mean, std, eval_episodes=10):
        """Roll out ``policy`` on ``self.eval_env`` and collect episode returns.

        Each observation is copied, masked at ``self.hidden_dims`` (when set),
        flattened to shape ``(1, -1)`` and normalized as ``(obs - mean) / std``
        before being passed to ``policy.select_action``.

        Args:
            policy: Object exposing ``select_action(state)``.
            mean: Offset subtracted from every flattened observation.
            std: Scale every centered observation is divided by.
            eval_episodes: Number of episodes to run.

        Returns:
            A list with one summed (undiscounted) reward per episode.
        """
        rewards = []
        for _ in range(eval_episodes):
            state, done = self.eval_env.reset(), False
            current_reward = 0.
            while not done:
                # Work on a copy: the environment may hand out a reference to
                # its own observation buffer, and masking that in place would
                # alter the environment's state. The copy also lets list-like
                # observations be indexed with an index list.
                obs = np.array(state)
                if self.hidden_dims is not None:
                    obs[self.hidden_dims] = 0.0

                obs = (obs.reshape(1, -1) - mean) / std
                action = policy.select_action(obs)
                state, reward, done, _ = self.eval_env.step(action)
                current_reward += reward

            rewards.append(current_reward)

        return rewards

    def train(self):
        """Train a ``TD3_BC`` policy on the offline dataset with periodic evaluation.

        Runs ``config.train.max_timesteps`` updates. Every
        ``config.train.eval_freq`` updates the policy is evaluated, statistics
        are printed (and logged to wandb when enabled), the history of mean
        returns is saved to ``self.evaluations_path`` and, if
        ``config.save_model`` is set, the policy is saved to
        ``self.agent_path``.
        """
        args = self.args

        # device: CUDA availability is not checked here; the config decides.
        device = torch.device('cpu' if self.config.system.cpu else 'cuda')

        dataset = utils.format_dataset(self.env)
        if len(dataset['observations'].shape) == 3:
            # Collapse the two trailing axes so each observation is one vector.
            for key in ('observations', 'next_observations'):
                shape = dataset[key].shape
                dataset[key] = dataset[key].reshape(shape[0], shape[1] * shape[2])

        state_dim = dataset['observations'].shape[1]
        action_dim = self.env.action_space.shape[0]
        # Only the first action dimension's upper bound is consulted.
        max_action = float(self.env.action_space.high[0])

        kwargs = {
            "state_dim": state_dim,
            "action_dim": action_dim,
            "max_action": max_action,
            "discount": args.discount,
            "tau": args.tau,
            # TD3 (noise terms are scaled by the action bound)
            "policy_noise": args.policy_noise * max_action,
            "noise_clip": args.noise_clip * max_action,
            "policy_freq": args.policy_freq,
            # TD3 + BC
            "alpha": args.alpha,
            "device": device
        }

        # Initialize policy
        policy = TD3_BC(**kwargs)

        if self.config.load_model:
            policy.load(self.agent_path)

        replay_buffer = ReplayBuffer(state_dim, action_dim)
        replay_buffer.convert_D4RL(dataset)
        if args.normalize:
            mean, std = replay_buffer.normalize_states()
        else:
            # Identity normalization so eval_policy needs no special case.
            mean, std = 0, 1

        evaluations = []
        for t in range(int(self.config.train.max_timesteps)):
            policy.train(replay_buffer, self.config.train.batch_size)
            # Evaluate episode
            if (t + 1) % self.config.train.eval_freq == 0:
                print(f"Time steps: {t + 1}")
                all_rewards = self.eval_policy(policy, mean, std)
                avg_reward, std_reward, avg_norm_reward, std_norm_reward = utils.get_eval_statistics(all_rewards, self.config.env.eval_env)
                eval_episodes = len(all_rewards)
                print("---------------------------------------")
                print(f"Epoch {t + 1}: Evaluation over {eval_episodes} episodes: {avg_reward:.3f} +- {std_reward:.3f}, Normalized score = {avg_norm_reward:.3f} +- {std_norm_reward:.3f}")
                print("---------------------------------------")

                if self.config.wandb.enable:
                    metrics = {'eval mean reward': avg_reward, 'avg_norm_reward': avg_norm_reward,
                               'eval std reward': std_reward, 'epochs': t + 1}
                    wandb.log(metrics)

                # The full history is rewritten at every evaluation so a
                # partial run still leaves a usable file behind.
                evaluations.append(np.mean(all_rewards))
                torch.save(evaluations, self.evaluations_path)
                if self.config.save_model:
                    policy.save(self.agent_path)
