import os

import numpy as np
import click
import json
import torch

from rlkit.envs import ENVS
from rlkit.envs.wrappers import NormalizedBoxEnv

from configs.default import default_config

from rlkit.torch.networks import stochastic_actor2, critic
from rlkit.data_management.env_replay_buffer import SimpleReplayBuffer
import torch.optim as optim

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

class SoftActorCritic():
    """Soft Actor-Critic agent with twin Q-functions and a learned temperature.

    The agent owns the environment, a replay buffer, two critics with their
    target copies, a stochastic policy and the three Adam optimizers (policy,
    critics, entropy temperature).
    """

    def __init__(self,
                 obs_dim,
                 action_dim,
                 net_size,
                 latent_dim,
                 device,
                 env,
                 policy_lr=1e-3,
                 qf_lr=1e-3,
                 **kwargs):
        """Build networks, optimizers and the replay buffer.

        Args:
            obs_dim: Flattened observation size.
            action_dim: Flattened action size.
            net_size: Hidden width passed to the critics and the policy.
            latent_dim: Forwarded to the policy as ``latent_dim``.
            device: Torch device the networks and ``log_alpha`` live on.
            env: Environment used for both training and evaluation rollouts.
            policy_lr: Learning rate for the policy and temperature optimizers.
            qf_lr: Learning rate for the shared critic optimizer.
            **kwargs: Must contain ``batch_size``, ``max_path_length``,
                ``reward_scale`` and ``discount``; other keys are ignored.

        Raises:
            KeyError: If one of the required ``kwargs`` entries is missing.
        """
        self.env = env
        self.replay_buffer_size = 1000000
        self.batch_size = kwargs['batch_size']
        self.max_path_length = kwargs['max_path_length']
        self.reward_scale = kwargs['reward_scale']
        self.discount = kwargs['discount']

        self.device = device

        self.qf1 = critic(input_dim = obs_dim+action_dim,
                          hidden_dim = net_size).to(self.device)
        self.qf2 = critic(input_dim = obs_dim+action_dim,
                          hidden_dim = net_size).to(self.device)
        self.target_qf1 = critic(input_dim = obs_dim+action_dim,
                          hidden_dim = net_size).to(self.device)
        self.target_qf2 = critic(input_dim = obs_dim+action_dim,
                          hidden_dim = net_size).to(self.device)
        # Targets start as exact copies of the online critics.
        self.target_qf1.load_state_dict(self.qf1.state_dict())
        self.target_qf2.load_state_dict(self.qf2.state_dict())

        self.policy = stochastic_actor2(obs_dim,
                                       action_dim,
                                       net_size,
                                       latent_dim=latent_dim).to(self.device)
        self.policy.shared_layer.requires_grad_(True)

        self.policy_optimizer = optim.Adam(self.policy.parameters(),
                                                lr=policy_lr)
        # One optimizer steps both critics on the summed loss.
        self.qf_optimizer = optim.Adam(list(self.qf1.parameters())+list(self.qf2.parameters()),
                                             lr=qf_lr)

        # Entropy target is -action_dim; alpha is kept as a plain float
        # snapshot of exp(log_alpha) so it never carries gradients.
        self.target_entropy = -np.prod((action_dim,)).item()
        self.log_alpha = torch.tensor(-0.5, requires_grad=True,device=self.device)
        self.alpha = self.log_alpha.exp().item()
        self.alpha_optimizer = optim.Adam([self.log_alpha], lr=policy_lr)

        self.replay_buffer = SimpleReplayBuffer(self.replay_buffer_size,
                                                obs_dim,
                                                action_dim)
        self.update_step = 0

    def alpha_update(self,obs):
        """Take one gradient step on ``log_alpha`` and refresh ``self.alpha``.

        Args:
            obs: Batch of observations the policy is sampled on.
        """
        with torch.no_grad():
            _, log_prob = self.policy.action(obs)
        alpha_loss = -(self.log_alpha * (log_prob + self.target_entropy).detach()).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.log_alpha, max_norm=0.1)
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.exp().item()

    def training(self,obs_buffer):
        """Run one training episode, updating the networks after each step.

        Every visited observation is appended to ``obs_buffer`` as
        ``{"obs": o}``. Gradient updates start once the replay buffer holds
        more than ``batch_size`` samples. ``self.update_step`` is reset at
        the start of every call.

        Args:
            obs_buffer: List that receives one dict per environment step.

        Returns:
            The sum of rewards collected during this episode.
        """
        o = self.env.reset()
        env_step = 0
        episode_return = 0
        self.update_step = 0
        while env_step<self.max_path_length:
            obs_buffer.append({"obs": o})
            env_step += 1
            a, _ = self.policy.select_action(o)
            next_o, r, d, env_info = self.env.step(a)
            episode_return += r
            self.replay_buffer.add_sample(o,a,r,d,next_o,**{'env_info':env_info})
            if self.replay_buffer.size() > self.batch_size:
                self.update()
                self.update_step += 1
            o = next_o
            if d:
                break
        return episode_return

    def _min_q(self, obs, actions):
        """Return the element-wise minimum of the two online critics."""
        q1 = self.qf1(obs, actions)
        q2 = self.qf2(obs, actions)
        return torch.min(q1, q2)

    def soft_target_update(self, main, target, tau: float = 0.005):
        """Polyak-average ``main`` into ``target`` in place.

        Each target parameter becomes ``tau * main + (1 - tau) * target``.
        """
        for main_param, target_param in zip(main.parameters(), target.parameters()):
            target_param.data.copy_(tau * main_param.data + (1.0 - tau) * target_param.data)

    def update(self):
        """Sample one batch and update critics, policy, targets and alpha."""
        random_batch = self.replay_buffer.sample_batch(self.batch_size)
        obs = torch.Tensor(random_batch['observations']).to(self.device)
        actions = torch.Tensor(random_batch['actions']).to(self.device)
        rewards = torch.Tensor(random_batch['rewards']).to(self.device)
        terms = torch.Tensor(random_batch['terminals']).to(self.device)
        next_obs = torch.Tensor(random_batch['next_observations']).to(self.device)

        # Bootstrapped target from the target critics; no gradients flow here.
        # NOTE(review): assumes rewards/terminals broadcast element-wise with
        # the critic output (e.g. both (batch, 1)) -- confirm against the
        # replay buffer and critic definitions.
        with torch.no_grad():
            next_action, next_log_prob = self.policy.action(next_obs)
            next_q1 = self.target_qf1(next_obs,next_action)
            next_q2 = self.target_qf2(next_obs,next_action)
            min_next_q = torch.min(next_q1,next_q2) - self.alpha * next_log_prob
            target_q = self.reward_scale*rewards + self.discount*(1-terms)*min_next_q

        # Critic step: summed squared error of both critics against the target.
        q1_pred = self.qf1(obs, actions)
        q2_pred = self.qf2(obs, actions)
        qf1_loss = (q1_pred-target_q).pow(2).mean()
        qf2_loss = (q2_pred-target_q).pow(2).mean()
        qf_loss = qf1_loss + qf2_loss
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        # Policy step: entropy-regularised objective on freshly sampled actions.
        cur_actions, log_prob = self.policy.action(obs)
        min_q = self._min_q(obs, cur_actions)
        policy_loss = (self.alpha * log_prob - min_q).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.soft_target_update(self.qf1,self.target_qf1)
        self.soft_target_update(self.qf2,self.target_qf2)
        self.alpha_update(obs)

    def save_model(self,test_env,eval_return):
        """Save the policy weights under ``./reference_data/{test_env}/``.

        The file name embeds the integer-truncated evaluation return. The
        output directory is created when it does not exist yet.
        """
        out_dir = f'./reference_data/{test_env}'
        os.makedirs(out_dir, exist_ok=True)
        torch.save(self.policy.state_dict(),f'{out_dir}/expert_policy({int(eval_return)}).pt')

    def evaluation(self):
        """Return the mean return of 10 deterministic evaluation episodes."""
        with torch.no_grad():
            n_eval_epi = 10
            returns = 0
            for _ in range(n_eval_epi):
                episode_returns = 0
                env_step = 0
                o = self.env.reset()
                while env_step<self.max_path_length:
                    env_step += 1
                    # NOTE(review): training unpacks select_action as a tuple,
                    # here the result is used directly -- confirm that the
                    # deterministic path returns a bare action.
                    a = self.policy.select_action(o,deterministic=True)
                    next_o, r, d, env_info = self.env.step(a)
                    episode_returns += r
                    o = next_o
                    if d:
                        break
                returns += episode_returns
        return returns/n_eval_epi

    
def experiment(test_env,variant):
    """Train a SAC expert on one fixed test task and dump reference data.

    The best-scoring policy is checkpointed whenever an evaluation improves
    on the previous best, and every observation visited during training is
    written to ``./reference_data/{test_env}/replay_obs_data.parquet``.

    Args:
        test_env: Task name; selects the fixed task set on the environment
            and the output sub-directory.
        variant: Config dict; reads ``env_name``, ``env_params``,
            ``sr_params.latent_action_dim``, ``net_size`` and ``algo_params``.
    """
    env = NormalizedBoxEnv(ENVS[variant['env_name']](**variant['env_params']))
    # Pin the environment to the single task the expert is trained on.
    if test_env == 'cheetah-vel':
        env.set_velocity(-2.0) # set velocity (-2)
    elif test_env == 'cheetah-dir':
        env.set_direction(-1) # set direction (backward)
    elif test_env == 'ant-goal':
        env.set_goal_position(1.5*np.pi,3) # set goal (angle = 1.5 pi, radius = 3)
    elif test_env == 'ant-dir':
        env.set_direction(1.5*np.pi) # set direction (angle = 1.5 pi)
    elif test_env =='humanoid-dir':
        env.set_direction(0) # set direction (angle = 0)
    elif 'params' in test_env:
        env.set_test_task()

    obs_dim = int(np.prod(env.observation_space.shape))
    action_dim = int(np.prod(env.action_space.shape))
    latent_action_dim = variant['sr_params']['latent_action_dim']
    net_size = variant['net_size']

    # device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    algorithm = SoftActorCritic(obs_dim=obs_dim,
                                action_dim=action_dim,
                                net_size=net_size,
                                latent_dim=latent_action_dim,
                                device=device,
                                env = env,
                                **variant['algo_params'])

    # Create the output directory up front so neither the checkpoints nor the
    # final parquet dump can fail on a missing path after hours of training.
    out_dir = f'./reference_data/{test_env}'
    os.makedirs(out_dir, exist_ok=True)

    epi_length = variant['algo_params']['max_path_length']
    prev_return = -np.inf
    obs_buffer = []
    count = 0  # consecutive evaluations without improvement
    for i in range(int(5000000/epi_length)):
        algorithm.training(obs_buffer)
        # Evaluate every 50 training episodes.
        if (i+1) % 50 == 0:
            eval_return = algorithm.evaluation()
            print(f"Steps: {(i+1)*epi_length}, Eval_return: {eval_return}")
            if eval_return > prev_return:
                algorithm.save_model(test_env,eval_return)
                prev_return = eval_return
                print(f'[{(i+1)*epi_length}] Model saved!')
                count = 0
            else:
                count += 1
                # Early stop after more than 10 stale evaluations; 'params'
                # tasks additionally require a best return above 400.
                if count > 10 and ('params' not in test_env or prev_return > 400):
                    break
        # Hard cap on the nominal step budget (episodes * max_path_length).
        if (i+1)*epi_length >= 1400000:
            break
    df = pd.DataFrame(obs_buffer)
    pq.write_table(pa.Table.from_pandas(df), f'{out_dir}/replay_obs_data.parquet')
        

def deep_update_dict(fr, to):
    """Recursively merge ``fr`` into ``to`` in place and return ``to``.

    Nested dicts are merged key by key; any other value in ``fr`` overwrites
    the entry in ``to``. A nested dict in ``fr`` whose key is missing from
    ``to`` (or maps to a non-dict there) is merged into a fresh dict, so
    ``fr`` may introduce new sub-sections without raising and without
    sharing its nested dicts with ``to``.

    Args:
        fr: Dict holding the overriding values.
        to: Dict that is updated in place.

    Returns:
        The same ``to`` object, after the update.
    """
    for k, v in fr.items():
        if isinstance(v, dict):
            target = to.get(k)
            if not isinstance(target, dict):
                # Nothing mergeable at this key yet: start from an empty dict
                # instead of failing on the missing / non-dict entry.
                target = to[k] = {}
            deep_update_dict(v, target)
        else:
            to[k] = v
    return to

@click.command()
@click.option('--test_env',default=None)
def main(test_env):
    """Load ./configs/<test_env>.json over the default config and train.

    Raises a usage error when --test_env is not given, since both the config
    path and the task selection depend on it.
    """
    if test_env is None:
        # Without a task name the config path would be './configs/None.json'.
        raise click.UsageError("Missing option '--test_env'.")
    config = f'./configs/{test_env}.json'
    with open(config) as f:
        exp_params = json.load(f)
    # The per-task parameters are merged into default_config in place.
    variant = deep_update_dict(exp_params, default_config)
    experiment(test_env,variant)

# Script entry point: main is a click command, so it reads --test_env from
# the command line when called without arguments.
if __name__ == "__main__":
    main()