import os

import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
import wandb
# import d4rl
from tqdm import trange

from .models.transition_model import TransitionModel
from .models.transition_model_w_sim import TransitionModelWithSimulator
from .models.policy_models import MLP, ActorProb, Critic, DiagGaussian
from .algo.sac import SACPolicy
from .algo.mopo import MOPO
from .common.buffer import ReplayBuffer
from .common.logger import Logger
from .trainer import Trainer
from .static_fns import get_static_fn
from .configs import get_algo_config
import utils
from .common.util import calc_sim_next_obs


class MOPO_Wrapper:
    """
    Wrapper that configures and runs MOPO training (dynamics model + SAC policy).

    Based on the code in: https://github.com/junming-yang/mopo

    Parameters
    ----------
    env : training environment; passed to ``utils.format_dataset`` to obtain
        the offline dataset and, when ``use_sim_model`` is set, used as the
        simulator.
    eval_env : environment handed to the trainer for evaluation; its
        ``action_space`` is given to the SAC policy.
    config : project configuration object; this class reads
        ``config.env.train_env``, ``config.env.eval_env``, ``config.system.cpu``,
        ``config.wandb.enable`` and ``config.simulator.transform_list``.
    agent_path : save path forwarded to the dynamics model and the trainer.
    evaluations_path : directory given to the TensorBoard ``SummaryWriter``.
    """

    # Ordered (substring, reward_penalty_coef, rollout_length) rules; the first
    # substring found in the training-environment name wins.  More specific
    # names ('-medium-replay', '-medium-expert') must precede '-medium'
    # because matching is by substring.
    _ENV_HYPERPARAMS = (
        # HalfCheetah
        ('halfcheetah-random', 0.5, 5),
        ('halfcheetah-medium-replay', 1.0, 5),
        ('halfcheetah-medium-expert', 1.0, 5),
        ('halfcheetah-medium', 1.0, 1),
        # Hopper
        ('hopper-medium-replay', 1.0, 5),
        ('hopper-medium-expert', 1.0, 5),
        ('hopper-medium', 5.0, 5),
        # Walker2d
        ('walker2d-random', 1.0, 1),
        ('walker2d-medium-replay', 1.0, 1),
        ('walker2d-medium-expert', 2.0, 1),
        ('walker2d-medium', 5.0, 5),
    )

    # (reward_penalty_coef, rollout_length) used when no rule matches.
    _DEFAULT_HYPERPARAMS = (1.0, 1)

    def __init__(self, env, eval_env, config, agent_path, evaluations_path):
        self.env = env
        self.eval_env = eval_env
        self.config = config
        self.agent_path = agent_path
        self.evaluations_path = evaluations_path
        self.use_sim_model = False
        self.get_args()

        # Optional attribute of the environment; None when it is absent.
        self.hidden_dims = getattr(env, 'hidden_dims', None)

    @classmethod
    def _select_env_hyperparams(cls, train_env):
        """Return ``(reward_penalty_coef, rollout_length)`` for ``train_env``.

        The first rule in ``_ENV_HYPERPARAMS`` whose substring occurs in
        ``train_env`` is used; ``_DEFAULT_HYPERPARAMS`` is returned otherwise.
        """
        for marker, penalty_coef, rollout_length in cls._ENV_HYPERPARAMS:
            if marker in train_env:
                return penalty_coef, rollout_length
        return cls._DEFAULT_HYPERPARAMS

    def get_args(self):
        """Populate ``self.args`` with the default hyper-parameters and apply
        the per-environment reward-penalty coefficient and rollout length."""
        # Originals
        self.args = utils.Dict2Class({
            "algo-name": "mopo",
            "actor-lr": 3e-4,
            "critic-lr": 3e-4,
            "gamma": 0.99,
            "tau": 0.005,
            "alpha": 0.2,
            'auto-alpha': True,
            'target-entropy': -3,
            'alpha-lr': 3e-4,

            # dynamics model's arguments
            "dynamics-lr": 0.001,
            "n-ensembles": 7,
            "n-elites": 5,
            "reward-penalty-coef": 1.0,
            "rollout-length": 1,
            "rollout-batch-size": 50000,
            "rollout-freq": 1000,
            "model-retain-epochs": 5,
            "real-ratio": 0.05,

            "epoch": 1000,
            "step-per-epoch": 1000,
            "eval_episodes": 10,
            "batch-size": 256,
            "logdir": "log",
            "log-freq": 1000
        })

        penalty_coef, rollout_length = self._select_env_hyperparams(self.config.env.train_env)
        self.args.reward_penalty_coef = penalty_coef
        self.args.rollout_length = rollout_length

        print(f'Initialized MOPO algorithm on environment {self.config.env.train_env} with reward-penalty-coefficient = {self.args.reward_penalty_coef} and rollout-length = {self.args.rollout_length}')

    def _load_sim_predictions(self, dataset):
        """Return the simulator's next-observation prediction for every
        transition in ``dataset``.

        The predictions are read from a cache file under ``data/sim_data/`` when
        it exists; otherwise they are computed with ``calc_sim_next_obs`` and
        written to that file.  Columns listed in ``self.hidden_dims`` are zeroed
        before returning.
        """
        sim_data_dir = 'data/sim_data/'
        transform_list = self.config.simulator.transform_list
        if len(transform_list) > 0:
            # Encode the simulator transforms in the cache file name.
            transform_path = '_'.join('_'.join(str(val) for val in transform) for transform in transform_list)
            sim_file_path = f'{self.config.env.train_env}_{transform_path}_next_sim_obs_dataset.pkl'
        else:
            sim_file_path = f'{self.config.env.train_env}_next_sim_obs_dataset.pkl'

        sim_data_path = os.path.join(sim_data_dir, sim_file_path)
        try:
            # NOTE(review): torch.load unpickles the file, so only load caches
            # produced by this code.  Newer PyTorch releases default to
            # weights_only=True, which may reject a pickled NumPy array --
            # confirm against the installed version.
            next_sim_obs_dataset = torch.load(sim_data_path)
        except FileNotFoundError:
            print('MOPO-Plus: sim predictions not found. Computing on all observations in the dataset...')
            observations = dataset['observations']
            actions = dataset['actions']
            next_sim_obs_dataset = np.zeros_like(observations)
            for i in trange(observations.shape[0]):
                next_sim_obs_dataset[i] = calc_sim_next_obs(self.env, observations[i], actions[i])

            os.makedirs(sim_data_dir, exist_ok=True)
            torch.save(next_sim_obs_dataset, sim_data_path)
        else:
            print(f'MOPO-Plus: loaded simulator predictions on dataset from file {sim_data_path}')

        remove_hidden_dims_from_sim = True
        print(f'HyCORL with remove_hidden_dims_from_sim = {remove_hidden_dims_from_sim}, and self.hidden_dims = {self.hidden_dims}')
        # NOTE(review): the truthiness test assumes hidden_dims is None or a
        # list/tuple; a multi-element NumPy array would raise here -- confirm.
        if self.hidden_dims and remove_hidden_dims_from_sim:
            next_sim_obs_dataset[:, self.hidden_dims] = 0.0

        return next_sim_obs_dataset

    def _build_sac_policy(self, obs_dim, action_dim, device):
        """Create the actor, the two critics, their optimizers and the SAC policy.

        When ``args.auto_alpha`` is set, ``args.target_entropy`` is resolved and
        ``args.alpha`` is replaced by the ``(target_entropy, log_alpha,
        alpha_optim)`` tuple that is passed to ``SACPolicy``.
        """
        args = self.args

        # Networks are created in a fixed order so that seeded runs initialise
        # the same parameters.
        actor_backbone = MLP(input_dim=obs_dim, hidden_dims=[256, 256])
        critic1_backbone = MLP(input_dim=obs_dim + action_dim, hidden_dims=[256, 256])
        critic2_backbone = MLP(input_dim=obs_dim + action_dim, hidden_dims=[256, 256])
        dist = DiagGaussian(
            latent_dim=getattr(actor_backbone, "output_dim"),
            output_dim=action_dim,
            unbounded=True,
            conditioned_sigma=True
        )

        actor = ActorProb(actor_backbone, dist, device)
        critic1 = Critic(critic1_backbone, device)
        critic2 = Critic(critic2_backbone, device)
        actor_optim = torch.optim.Adam(actor.parameters(), lr=args.actor_lr)
        critic1_optim = torch.optim.Adam(critic1.parameters(), lr=args.critic_lr)
        critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)

        if args.auto_alpha:
            # Only fall back to -action_dim when no target was configured; an
            # explicit target entropy of 0 must be honoured.
            target_entropy = args.target_entropy if args.target_entropy is not None else -action_dim
            args.target_entropy = target_entropy

            log_alpha = torch.zeros(1, requires_grad=True, device=device)
            alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
            args.alpha = (target_entropy, log_alpha, alpha_optim)

        return SACPolicy(
            actor,
            critic1,
            critic2,
            actor_optim,
            critic1_optim,
            critic2_optim,
            action_space=self.eval_env.action_space,
            dist=dist,
            tau=args.tau,
            gamma=args.gamma,
            alpha=args.alpha,
            device=device
        )

    def train(self):
        """Build every component, pretrain the dynamics model on the dataset
        and then train the policy."""
        args = self.args

        # device:
        device = torch.device('cpu') if self.config.system.cpu else torch.device('cuda')

        # create env and dataset
        dataset = utils.format_dataset(self.env)

        if self.use_sim_model:
            dataset['next_sim_observations'] = self._load_sim_predictions(dataset)

        # Logged before args.alpha / args.target_entropy are rewritten below.
        if self.config.wandb.enable:
            wandb.config.update(vars(args))

        obs_dim = dataset['observations'].shape[1]
        action_dim = dataset['actions'].shape[1]

        # log
        writer = SummaryWriter(self.evaluations_path)
        writer.add_text("args", str(args))
        logger = Logger(writer)

        # import configs; the task is the part of the eval env name before the first '-'
        task = self.config.env.eval_env.split('-')[0]
        static_fns = get_static_fn(task)
        algo_config = get_algo_config(task)

        # create policy
        sac_policy = self._build_sac_policy(obs_dim, action_dim, device)

        # create dynamics model
        if self.use_sim_model:
            dynamics_model_class = TransitionModelWithSimulator
        else:
            dynamics_model_class = TransitionModel

        dynamics_model = dynamics_model_class(obs_dim=obs_dim,
                                              action_dim=action_dim,
                                              static_fns=static_fns,
                                              lr=args.dynamics_lr,
                                              save_path=self.agent_path,
                                              penalty_coeff=args.reward_penalty_coef,
                                              device=device,
                                              **algo_config["transition_params"],
                                              )

        # create buffer
        offline_buffer = ReplayBuffer(
            buffer_size=len(dataset["observations"]),
            obs_shape=obs_dim,
            obs_dtype=np.float32,
            action_dim=action_dim,
            action_dtype=np.float32,
            use_sim_model=self.use_sim_model
        )
        offline_buffer.load_dataset(dataset)
        model_buffer = ReplayBuffer(
            buffer_size=args.rollout_batch_size * args.rollout_length * args.model_retain_epochs,
            obs_shape=obs_dim,
            obs_dtype=np.float32,
            action_dim=action_dim,
            action_dtype=np.float32
        )

        # create MOPO algo
        algo = MOPO(
            sac_policy,
            dynamics_model,
            offline_buffer=offline_buffer,
            model_buffer=model_buffer,
            reward_penalty_coef=args.reward_penalty_coef,
            rollout_length=args.rollout_length,
            batch_size=args.batch_size,
            real_ratio=args.real_ratio,
            logger=logger,
            use_sim_model=self.use_sim_model,
            simulator=self.env,
            config=self.config,
            **algo_config["mopo_params"]
        )
        # create trainer
        trainer = Trainer(
            algo,
            eval_env=self.eval_env,
            epoch=args.epoch,
            step_per_epoch=args.step_per_epoch,
            rollout_freq=args.rollout_freq,
            logger=logger,
            log_freq=args.log_freq,
            env_type=self.config.env.eval_env,
            save_path=self.agent_path,
            eval_episodes=args.eval_episodes,
            use_wandb=self.config.wandb.enable,
            hidden_dims=self.hidden_dims
        )

        # pretrain dynamics model on the whole dataset
        trainer.train_dynamics()

        # begin train
        trainer.train_policy()


class MOPOPlus_Wrapper(MOPO_Wrapper):
    """
    MOPO-Plus variant of ``MOPO_Wrapper``: same constructor, but with
    ``use_sim_model`` switched on and its own per-environment hyper-parameters.

    Based on the code in: https://github.com/junming-yang/mopo
    """
    def __init__(self, env, eval_env, config, agent_path, evaluations_path):
        super().__init__(env, eval_env, config, agent_path, evaluations_path)
        # Set after the parent constructor, which initialises the flag to False.
        self.use_sim_model = True

    def get_args(self):
        """Fill ``self.args`` with the default hyper-parameters, then apply the
        MOPO-Plus reward-penalty coefficient and rollout length for the
        configured training environment (defaults are kept when no rule matches)."""
        # Originals
        defaults = {
            "algo-name": "mopo",
            "actor-lr": 3e-4,
            "critic-lr": 3e-4,
            "gamma": 0.99,
            "tau": 0.005,
            "alpha": 0.2,
            'auto-alpha': True,
            'target-entropy': -3,
            'alpha-lr': 3e-4,

            # dynamics model's arguments
            "dynamics-lr": 0.001,
            "n-ensembles": 7,
            "n-elites": 5,
            "reward-penalty-coef": 1.0,
            "rollout-length": 1,
            "rollout-batch-size": 50000,
            "rollout-freq": 1000,
            "model-retain-epochs": 5,
            "real-ratio": 0.05,

            "epoch": 1000,
            "step-per-epoch": 1000,
            "eval_episodes": 10,
            "batch-size": 256,
            "logdir": "log",
            "log-freq": 1000
        }
        self.args = utils.Dict2Class(defaults)

        # Each rule is (substring, reward_penalty_coef, rollout_length).  Within
        # a group the first substring found in the environment name wins, so
        # the more specific '-medium-replay' / '-medium-expert' entries come
        # before '-medium'.
        halfcheetah_hopper_rules = (
            # HalfCheetah
            ('halfcheetah-medium-replay', 0.0, 5),
            ('halfcheetah-medium-expert', 0.0, 5),
            ('halfcheetah-medium', 0.0, 5),
            # Hopper
            ('hopper-medium-replay', 1.0, 5),
            ('hopper-medium-expert', 1.0, 5),
            ('hopper-medium', 5.0, 5),
        )
        walker2d_rules = (
            ('walker2d-random', 1.0, 1),
            ('walker2d-medium-replay', 0.0, 10),
            ('walker2d-medium-expert', 0.0, 10),
            ('walker2d-medium', 0.0, 10),
        )

        # The two groups are evaluated independently, one after the other; a
        # match in the second group overrides a match in the first.
        for rule_group in (halfcheetah_hopper_rules, walker2d_rules):
            for marker, penalty_coef, horizon in rule_group:
                if marker in self.config.env.train_env:
                    self.args.reward_penalty_coef = penalty_coef
                    self.args.rollout_length = horizon
                    break

        print(f'Initialized MOPO-Plus algorithm on environment {self.config.env.train_env} with reward-penalty-coefficient = {self.args.reward_penalty_coef} and rollout-length = {self.args.rollout_length}')
