import numpy as np

from rpo.common.initializers import init_seeds
from rpo.common.logger import Logger
import gym
import tensorflow as tf
import os
import pickle
class BaseAlg:
    """Base algorithm class for training.

    Holds the environment, actor, critic and runner, drives the
    collect / update / evaluate loop in :meth:`learn`, and provides
    evaluation helpers. Subclasses must implement :meth:`_update` and
    :meth:`update_old_weights`.
    """

    def __init__(self, seed, env, actor, critic, runner, params):
        """Store the components, seed RNGs and create the logger.

        Args:
            seed: Seed passed to ``init_seeds`` together with ``env``.
            env: Environment wrapper used for rollouts and running stats.
            actor: Policy object.
            critic: Value-function object.
            runner: Batch collector.
            params: Configuration dict; must contain ``'ac_kwargs'`` and is
                passed whole to ``Logger``.
        """
        self.seed = seed
        self.env = env
        self.actor = actor
        self.critic = critic
        self.runner = runner
        self.params = params

        self.ac_kwargs = params['ac_kwargs']

        init_seeds(self.seed, env)
        self.logger = Logger(self.params)

    def _update(self, sim_total=0):
        """Algorithm-specific update step.

        :meth:`learn` calls this once per collected batch and expects a dict
        of scalars back (it adds keys to it and passes it to
        :meth:`record_values`).
        """
        raise NotImplementedError  # Algorithm specific

    def update_old_weights(self):
        """Algorithm-specific hook; not called from this base class."""
        raise NotImplementedError  # Algorithm specific

    def eval_episode(self):
        """Run one episode with the actor's mean action and return its raw return.

        Actions come from ``actor.sample(state, mean=True)``. The accumulated
        reward is the third value of ``env.get_raw()`` read after every step,
        not the reward returned by ``env.step``.
        """
        state = self.env.reset()
        # NOTE(review): result is unused; the call is kept in case get_raw()
        # has side effects in the env wrapper — confirm and drop if not.
        self.env.get_raw()
        episodic_return = 0
        done = False
        while not done:
            action = self.actor.sample(state, mean=True)
            state, _, done, _ = self.env.step(action)
            _, _, reward_raw_true = self.env.get_raw()
            episodic_return += reward_raw_true
        return episodic_return

    def eval_episodes(self, num_episodes=10):
        """Return the mean of :meth:`eval_episode` over ``num_episodes`` runs.

        Raises:
            ValueError: if ``num_episodes`` is less than 1.
        """
        if num_episodes < 1:
            raise ValueError(f"num_episodes must be >= 1, got {num_episodes}")
        return np.mean([self.eval_episode() for _ in range(num_episodes)])

    def record_values(self, info, step):
        """Write every ``key: value`` pair of ``info`` as a scalar at ``step``."""
        for k, v in info.items():
            self.logger.add_scalar(k, v, step)

    def _checkpoint(self):
        """Collect actor/critic weights and the env's state/reward running stats."""
        return {
            'actor_weights': self.actor.get_weights(),
            'critic_weights': self.critic.get_weights(),

            's_t': self.env.s_rms.t_last,
            's_mean': self.env.s_rms.mean,
            's_var': self.env.s_rms.var,

            'r_t': self.env.r_rms.t_last,
            'r_mean': self.env.r_rms.mean,
            'r_var': self.env.r_rms.var,
        }

    def learn(self, sim_size, no_op_batches, eval_freq=4096):
        """Train until at least ``sim_size`` environment steps are collected.

        First collects ``no_op_batches`` batches that only update the env's
        running statistics. Then each iteration collects a batch, calls
        :meth:`_update`, refreshes the running stats and advances the runner.
        Whenever the step count crosses the next multiple of ``eval_freq``,
        the policy is evaluated, scalars are logged and a checkpoint is saved.
        A final checkpoint goes to ``logger.log_final``.

        Args:
            sim_size: Minimum total number of environment steps to collect.
            no_op_batches: Number of warm-up batches (no update).
            eval_freq: Evaluation/checkpoint interval in environment steps.

        Raises:
            ValueError: if ``eval_freq`` is not positive.
        """
        if eval_freq <= 0:
            raise ValueError(f"eval_freq must be positive, got {eval_freq}")

        # Warm-up: batches are used only to update normalization statistics.
        for _ in range(no_op_batches):
            self.runner.generate_batch(self.env, self.actor)
            s_raw, rtg_raw = self.runner.get_env_info()
            self.env.update_rms(s_raw, rtg_raw)
            self.runner.reset()

        sim_total = 0
        next_eval = eval_freq
        while sim_total < sim_size:
            self.runner.generate_batch(self.env, self.actor)

            # Read before _update(), which presumably changes the actor.
            entropy = np.squeeze(self.actor.entropy())

            info = self._update()

            s_raw, rtg_raw = self.runner.get_env_info()
            self.env.update_rms(s_raw, rtg_raw)

            log_info = self.runner.get_log_info()

            sim_total += self.runner.steps_total
            kl_verify, length = self.runner.update(self.actor)

            # A threshold instead of `sim_total % eval_freq == 0`: the modulo
            # test never fires when the batch size does not divide eval_freq.
            if sim_total >= next_eval:
                next_eval = (sim_total // eval_freq + 1) * eval_freq

                episodic_return_test = self.eval_episodes()

                info["episodic_return_test"] = episodic_return_test
                info["episodic_return_train"] = log_info['J_tot']
                info["entropy"] = entropy
                info["kl_verify"] = kl_verify
                info["length"] = length

                # NOTE(review): prints sim_total + 1, which looks off by one.
                print(f"Env: {self.params['env_name']} Seed: {self.params['ac_seed']} Total T: {sim_total+1} Reward: {episodic_return_test}")

                self.record_values(info, sim_total)
                self.logger.save_current(
                    self._checkpoint(),
                    self.params["log_path"] + '/Model',
                    'steps_%s_seed_%s' % (str(sim_total), str(self.seed)),
                )

        self.logger.log_final(self._checkpoint())

    def dump(self, params):
        """Log ``params`` and return the result of ``logger.dump()``."""
        self.logger.log_params(params)
        return self.logger.dump()

    def save(self, params, log_path, log_name):
        """Log ``params`` and call ``logger.save(log_path, log_name)``."""
        self.logger.log_params(params)
        self.logger.save(log_path, log_name)

    def eval_advantage(self, S_traj, S_f_traj, A_traj, Adv_all, rtg_all,
                       s_reg_m, s_reg_v, r_reg_m, r_reg_v,
                       idx_low=1025, idx_high=4096):
        """Compare stored advantage/return estimates against fresh rollouts.

        Draws ``params["eval_sample_size"]`` indices uniformly from
        ``[idx_low, idx_high)``. For each index, the env is restored and
        rolled out via :meth:`eval_state_advantage`.

        Returns:
            ``(eval_advantage, Adv_sample, ratio_Adv, eval_V, rtg_sample,
            ratio_V)``, where each ratio is
            ``|evaluated - stored| / (|evaluated| + 1e-8)``.
        """
        sample_idx = np.random.randint(idx_low, idx_high, self.params["eval_sample_size"])
        S_sample = S_traj[sample_idx]
        # NOTE(review): indexes up to idx_high; raises IndexError if S_traj has
        # exactly idx_high rows. The value is only forwarded as `s1`, which
        # eval_state_advantage does not use.
        S_next_sample = S_traj[sample_idx + 1]
        S_f_sample = S_f_traj[sample_idx]
        A_sample = A_traj[sample_idx]
        Adv_sample = Adv_all[sample_idx]
        rtg_sample = rtg_all[sample_idx]

        S_m_sample = s_reg_m[sample_idx]
        S_v_sample = s_reg_v[sample_idx]
        R_m_sample = r_reg_m[sample_idx]
        R_v_sample = r_reg_v[sample_idx]

        eval_adv_list = []
        eval_v_list = []
        for s, s_f, a, s1, s_m, s_v, r_m, r_v in zip(
                S_sample, S_f_sample, A_sample, S_next_sample,
                S_m_sample, S_v_sample, R_m_sample, R_v_sample):
            adv, rtg = self.eval_state_advantage(s, s_f, a, s1, s_m, s_v, r_m, r_v)
            eval_adv_list.append(adv)
            eval_v_list.append(rtg)

        eval_advantage = np.array(eval_adv_list)
        ratio_Adv = np.abs(eval_advantage - Adv_sample) / (np.abs(eval_advantage) + 1e-8)

        eval_V = np.array(eval_v_list)
        ratio_V = np.abs(eval_V - rtg_sample) / (np.abs(eval_V) + 1e-8)

        return eval_advantage, Adv_sample, ratio_Adv, eval_V, rtg_sample, ratio_V

    def eval_state_advantage(self, s, s_f, a, s1, s_m, s_v, r_m, r_v):
        """Roll out from a restored state and score it with ``runner.eval_adv``.

        Resets the env, installs the given statistics via ``env.set_rms1``,
        restores the simulator from ``s_f``, takes action ``a`` and then
        follows ``actor.sample(state)`` (without ``mean=True``) until done.

        Args:
            s: Observation recorded for the start state (first trajectory entry).
            s_f: Flattened state passed to ``env.state_from_flattened``.
            a: First action to take.
            s1: Unused.
            s_m, s_v, r_m, r_v: Values passed to ``env.set_rms1``.

        Returns:
            The ``(adv, rtg)`` pair returned by ``runner.eval_adv``.
        """
        self.env.reset()

        self.env.set_rms1(s_m, s_v, r_m, r_v)
        self.env.state_from_flattened(s_f)
        S_ba = [s]
        R_ba = []
        d = []

        state, reward, done, _ = self.env.step(a)
        R_ba.append(reward)
        S_ba.append(state)
        d.append(done)
        while not done:
            action = self.actor.sample(state)
            state, reward, done, _ = self.env.step(action)
            R_ba.append(reward)
            S_ba.append(state)
            d.append(done)

        eva_adv, rtg = self.runner.eval_adv(S_ba, R_ba, d, self.critic)
        return eva_adv, rtg
