import os
from copy import deepcopy

import numpy as np
import torch
import wandb
from tqdm import trange

from .SimpleSAC.mixed_replay_buffer import MixedReplayBuffer
from .SimpleSAC.model import FullyConnectedQFunction, SamplerPolicy, TanhGaussianPolicy
from .SimpleSAC.sampler import StepSampler, TrajSampler
from .SimpleSAC.sim2real_sac import Sim2realSAC
from .SimpleSAC.utils_h2o import (Timer, WandBLogger, define_flags_with_default,
                                  get_user_flags, prefix_metrics, print_flags,
                                  set_random_seed)
from .Network.Weight_net import ConcatDiscriminator
import utils


class H2O_Wrapper:
    """Build and train a ``Sim2realSAC`` agent from a training and an evaluation env.

    Rollouts for training are collected from ``env`` with a ``StepSampler``;
    the policy is periodically evaluated on ``eval_env`` with a ``TrajSampler``.
    Hyper-parameters are fixed in ``self.args``; run-level switches (device,
    wandb logging, checkpointing) are read from ``config``.
    """

    def __init__(self, env, eval_env, config, agent_path, evaluations_path):
        """Store the environments/config and set up the default hyper-parameters.

        Args:
            env: Environment used for training rollouts and handed to the
                replay buffer.
            eval_env: Environment used for evaluation; its observation and
                action spaces size the networks and the replay buffer.
            config: Run configuration. ``train`` reads ``system.cpu``,
                ``wandb.enable``, ``save_model`` and ``env.eval_env`` from it.
            agent_path: Directory in which ``model.pkl`` is written.
            evaluations_path: Stored on the instance; not used by the code
                visible in this class.
        """
        self.env = env
        self.eval_env = eval_env
        self.config = config
        self.agent_path = agent_path
        self.evaluations_path = evaluations_path

        # Hard-wired switch: set to False to ignore any hidden dims the env exposes.
        use_hidden_dims = True
        self.hidden_dims = None
        if use_hidden_dims:
            # `hidden_dims` takes precedence over `obs_hidden_dims` when both exist.
            if hasattr(env, 'hidden_dims'):
                self.hidden_dims = self.env.hidden_dims
            elif hasattr(env, 'obs_hidden_dims'):
                self.hidden_dims = self.env.obs_hidden_dims

        print(f'H2O created with use_hidden_dims = {use_hidden_dims}, and hidden_dims = {self.hidden_dims}')

        # NOTE(review): "save_model" and "seed" below are not read anywhere in
        # this class -- checkpointing is controlled by `config.save_model` and
        # no seeding call is made here. Confirm whether that is intended.
        self.args = utils.Dict2Class({
            "batch_ratio": 0.5,
            "replaybuffer_ratio": 10,
            "real_residual_ratio": 1.0,
            "dis_dropout": False,
            "max_traj_length": 1000,
            "seed": 42,
            "save_model": False,
            "batch_size": 256,

            "reward_scale": 1.0,
            "reward_bias": 0.0,
            "clip_action": 1.0,
            "joint_noise_std": 0.0,

            "policy_arch": '256-256',
            "qf_arch": '256-256',
            "orthogonal_init": False,
            "policy_log_std_multiplier": 1.0,
            "policy_log_std_offset": -1.0,

            # train and evaluate policy
            "n_epochs": 1000,
            "bc_epochs": 0,
            "n_rollout_steps_per_epoch": 1000,
            "n_train_step_per_epoch": 1000,
            "eval_period": 10,
            "eval_n_trajs": 5,
        })

    def _save_checkpoint(self, sac, epoch):
        """Write the agent and the epoch index to ``<agent_path>/model.pkl``.

        The whole ``sac`` object is handed to ``torch.save`` (not just a
        state dict), and the file is overwritten on every call.
        """
        save_data = {'sac': sac, 'epoch': epoch}
        torch.save(save_data, os.path.join(self.agent_path, 'model.pkl'))

    def _evaluate(self, sac, eval_sampler, sampler_policy, epoch):
        """Run one evaluation round and return the metrics it produced.

        Samples ``eval_n_trajs`` deterministic trajectories, evaluates the
        discriminators, prints a summary, and returns a dict with the
        discriminator losses plus (only when wandb is enabled) the reward
        statistics.
        """
        args = self.args

        trajs = eval_sampler.sample(
            sampler_policy, args.eval_n_trajs, deterministic=True
        )

        eval_dsa_loss, eval_dsas_loss = sac.discriminator_evaluate()
        eval_metrics = {
            'eval_dsa_loss': eval_dsa_loss,
            'eval_dsas_loss': eval_dsas_loss,
        }

        # Undiscounted return of every evaluation trajectory.
        all_rewards = [np.sum(t['rewards']) for t in trajs]
        avg_reward, std_reward, avg_norm_reward, std_norm_reward = utils.get_eval_statistics(
            all_rewards, self.config.env.eval_env)

        # Despite the "Epoch" label in the message, this is the epoch index
        # multiplied by the number of gradient steps per epoch.
        current_time_step = epoch * args.n_train_step_per_epoch
        eval_episodes = args.eval_n_trajs
        print("---------------------------------------")
        print(
            f"Epoch {current_time_step}: Evaluation over {eval_episodes} episodes: "
            f"{avg_reward:.3f} +- {std_reward:.3f}, "
            f"Normalized score = {avg_norm_reward:.3f} +- {std_norm_reward:.3f}")
        print("---------------------------------------")

        if self.config.wandb.enable:
            eval_metrics.update(
                {'eval mean reward': avg_reward, 'avg_norm_reward': avg_norm_reward,
                 'eval std reward': std_reward,
                 'epochs': current_time_step})
        return eval_metrics

    def train(self):
        """Run the full training loop.

        Each epoch: (1) roll out ``n_rollout_steps_per_epoch`` steps from the
        training env into the replay buffer, (2) take
        ``n_train_step_per_epoch`` gradient steps, (3) at epoch 0 and every
        ``eval_period`` epochs, evaluate on the evaluation env and optionally
        checkpoint. Metrics are sent to wandb once per epoch when
        ``config.wandb.enable`` is set. A final checkpoint is written after
        the loop when ``config.save_model`` is set.
        """
        args = self.args

        # device: CUDA unless the config explicitly requests CPU
        device = torch.device('cpu' if self.config.system.cpu else 'cuda')

        # a step sampler for "simulated" training
        train_sampler = StepSampler(self.env, args.max_traj_length, self.hidden_dims)
        # a trajectory sampler for "real-world" evaluation
        eval_sampler = TrajSampler(self.eval_env, args.max_traj_length, self.hidden_dims)

        # replay buffer
        num_state = self.eval_env.observation_space.shape[0]
        num_action = self.eval_env.action_space.shape[0]
        replay_buffer = MixedReplayBuffer(self.env, args.reward_scale, args.reward_bias,
                                          args.clip_action, num_state, num_action,
                                          device=device,
                                          buffer_ratio=args.replaybuffer_ratio,
                                          residual_ratio=args.real_residual_ratio)

        # discriminators: one over (s, a) inputs, one over (s, a, s') inputs
        d_sa = ConcatDiscriminator(num_state + num_action, 256, 2, device,
                                   dropout=args.dis_dropout).float().to(device)
        d_sas = ConcatDiscriminator(2 * num_state + num_action, 256, 2, device,
                                    dropout=args.dis_dropout).float().to(device)

        # agent: network sizes come from the env held by the eval sampler
        obs_dim = eval_sampler.env.observation_space.shape[0]
        act_dim = eval_sampler.env.action_space.shape[0]

        policy = TanhGaussianPolicy(
            obs_dim,
            act_dim,
            arch=args.policy_arch,
            log_std_multiplier=args.policy_log_std_multiplier,
            log_std_offset=args.policy_log_std_offset,
            orthogonal_init=args.orthogonal_init,
        )

        qf1 = FullyConnectedQFunction(
            obs_dim,
            act_dim,
            arch=args.qf_arch,
            orthogonal_init=args.orthogonal_init,
        )
        target_qf1 = deepcopy(qf1)

        qf2 = FullyConnectedQFunction(
            obs_dim,
            act_dim,
            arch=args.qf_arch,
            orthogonal_init=args.orthogonal_init,
        )
        target_qf2 = deepcopy(qf2)

        cql = Sim2realSAC.get_default_config()

        # A non-negative default is replaced by minus the action dimensionality.
        if cql.target_entropy >= 0.0:
            cql.target_entropy = -np.prod(eval_sampler.env.action_space.shape).item()

        sac = Sim2realSAC(cql, policy, qf1, qf2, target_qf1, target_qf2, d_sa, d_sas, replay_buffer,
                          device, dynamics_model=None)
        sac.torch_to_device(device)

        # sampling policy is always the current policy: \pi
        sampler_policy = SamplerPolicy(policy, device)

        # Batch composition is constant for the whole run, so compute it once.
        real_batch_size = int(args.batch_size * (1 - args.batch_ratio))
        sim_batch_size = int(args.batch_size * args.batch_ratio)
        last_batch_idx = args.n_train_step_per_epoch - 1

        # Sentinel so the final checkpoint below does not hit an unbound
        # `epoch` when `n_epochs` is 0 and the loop body never runs.
        epoch = -1

        # train and evaluate for n_epochs
        # NOTE(review): the three Timer results are never read or logged.
        for epoch in trange(args.n_epochs):
            metrics = {}

            # Rollout from the simulator
            with Timer() as rollout_timer:
                # rollout and append simulated trajectories to the replay buffer
                train_sampler.sample(
                    sampler_policy, args.n_rollout_steps_per_epoch,
                    deterministic=False, replay_buffer=replay_buffer, joint_noise_std=args.joint_noise_std
                )
                metrics['epoch'] = epoch

            # Train from the mixed data
            with Timer() as train_timer:
                for batch_idx in trange(args.n_train_step_per_epoch):
                    step_metrics = sac.train(real_batch_size, sim_batch_size)
                    # Only the last gradient step of the epoch is reported.
                    if batch_idx == last_batch_idx:
                        metrics.update(prefix_metrics(step_metrics, 'sac'))

            # Evaluate in the real world
            with Timer() as eval_timer:
                if epoch == 0 or (epoch + 1) % args.eval_period == 0:
                    metrics.update(
                        self._evaluate(sac, eval_sampler, sampler_policy, epoch)
                    )

                    if self.config.save_model:
                        self._save_checkpoint(sac, epoch)

            if self.config.wandb.enable:
                wandb.log(metrics)

        if self.config.save_model:
            self._save_checkpoint(sac, epoch)
