import os
from copy import deepcopy

import numpy as np
import torch
import torch.nn.functional as F
from ml_collections import ConfigDict

from critics.test_cql import CQLCritic
from policy.sac_policy import SACPolicy
from infreastructure import pytorch_util as ptu
from infreastructure import sampler

class test(object):
    """Actor-critic agent pairing a tanh-Gaussian actor with a CQL-style critic.

    The actor network is a ``ptu.TanhGaussianPolicy`` whose updates are driven
    by a ``SACPolicy`` wrapper. The critic is a ``CQLCritic`` that holds twin Q
    networks, their target copies and a state-value (V) network.

    Args:
        env: Environment handed to the evaluation ``TrajSampler``. Its
            ``action_space`` provides the default target entropy.
        agent_params: Dict of agent settings. This constructor reads
            ``batch_size``, ``update_freq``, ``config``, ``ob_dim``, ``ac_dim``,
            ``policy_log_std_multiplier``, ``policy_log_std_offset``,
            ``orthogonal_init``, ``device``, ``max_trajs_length`` and
            ``eval_n_trajs``.
            NOTE: the dict is mutated in place. ``q_lr``, ``po_rl``, ``q_arch``
            and ``policy_arch`` are written into it.
        opt_class: ``'adam'`` or ``'sgd'``. Any other value raises ``KeyError``.
    """

    def __init__(self, env, agent_params, opt_class):
        self.env = env
        self.agent_params = agent_params
        self.batch_size = agent_params['batch_size']
        # Number of train() calls so far. Drives the target-update and beta schedules.
        self.total = 0
        self.update_freq = agent_params['update_freq']
        self.config = agent_params['config']
        # Copy learning rates / architectures from the config into agent_params.
        # This is presumably because SACPolicy / CQLCritic read them from there;
        # confirm against those classes.
        # NOTE(review): 'po_rl' looks like a typo for a policy-lr key, but other
        # code may read it under this exact name, so it is left unchanged.
        self.agent_params['q_lr'] = self.config.q_lr
        self.agent_params['po_rl'] = self.config.policy_lr
        self.agent_params['q_arch'] = self.config.q_arch
        self.agent_params['policy_arch'] = self.config.policy_arch
        self.ob_dim = agent_params['ob_dim']
        self.ac_dim = agent_params['ac_dim']
        self.p_arch = self.config.policy_arch
        self.q_arch = self.config.q_arch
        self.log_std_multiplier = agent_params['policy_log_std_multiplier']
        self.log_std_offset = agent_params['policy_log_std_offset']
        self.orthogonal_init = agent_params['orthogonal_init']
        self.po_rl = self.config.policy_lr
        # Unknown optimizer names raise KeyError.
        optimizer_class = {
            'adam': torch.optim.Adam,
            'sgd': torch.optim.SGD,
        }[opt_class]

        # Actor: the policy network plus the SACPolicy wrapper that updates it.
        self.policy = ptu.TanhGaussianPolicy(self.ob_dim, self.ac_dim, self.p_arch, self.log_std_multiplier,
                                             self.log_std_offset, self.orthogonal_init, self.config.p_dropout)
        self.actor = SACPolicy(self.agent_params, optimizer_class, self.policy, self.config.p_dropout)

        # Critic: twin Q networks, deep-copied target networks and a V network.
        self.q1_net = ptu.FullConnectedQFunction(self.ob_dim, self.ac_dim, self.q_arch, self.orthogonal_init,
                                                 self.config.dropout_rate)
        self.q2_net = ptu.FullConnectedQFunction(self.ob_dim, self.ac_dim, self.q_arch, self.orthogonal_init,
                                                 self.config.dropout_rate)
        self.q1_net_target = deepcopy(self.q1_net)
        self.q2_net_target = deepcopy(self.q2_net)
        self.v_net = ptu.FullConnecteNetwork(self.ob_dim, 1, self.q_arch, self.orthogonal_init)
        self.critic = CQLCritic(self.agent_params, optimizer_class, self.config, self.q1_net, self.q2_net,
                                self.q1_net_target, self.q2_net_target, self.v_net, self.config.start_beta)

        # Evaluation rollout machinery.
        self.sampler_policy = ptu.SamplerPolicy(self.policy, agent_params['device'])
        self.eval_sampler = sampler.TrajSampler(env, agent_params['max_trajs_length'])
        self.eval_n_trajs = agent_params['eval_n_trajs']
        # Current beta. It is annealed from config.start_beta toward
        # config.end_beta and pushed into the critic.
        self.beta = self.config.start_beta

        # Debug flags kept for compatibility. Nothing in this class updates them.
        self.max_ac = True
        self.min_ac = True

        # A non-negative target entropy is treated as "unset" and replaced by
        # -prod(action_space.shape).
        if self.config.target_entropy >= 0.0:
            self.config.target_entropy = -np.prod(self.eval_sampler.env.action_space.shape).item()

    def train(self, batch):
        """Run one critic update and one actor update on ``batch``.

        Also soft-updates the target networks every ``update_freq`` calls and
        steps the beta schedule every ``config.change_beta`` calls.

        Args:
            batch: Mapping with ``'observations'``, ``'actions'``, ``'rewards'``,
                ``'next_observations'`` and ``'dones'``. These are presumably
                torch tensors already on the agent's device; not checked here.

        Returns:
            Dict of diagnostics, mostly copied from the critic's update log,
            plus the actor loss under ``'p loss'``.
        """
        ob_n = batch['observations']
        ac_n = batch['actions']
        re_n = batch['rewards']
        next_ob_n = batch['next_observations']
        terminal_n = batch['dones']

        # Policy sample that is passed into the critic update.
        new_ac, log_pi = self.policy(ob_n)

        self.total += 1
        train_log, alpha = self.critic.update(ob_n, ac_n, next_ob_n, re_n, terminal_n, new_ac, log_pi)

        # Fresh policy sample for the actor step.
        # NOTE(review): the actor objective uses the *target* Q networks.
        # Standard SAC uses the online critics; confirm this is intended.
        new_ac, log_pi = self.policy(ob_n)
        q_new_values = torch.min(self.q1_net_target(ob_n, new_ac), self.q2_net_target(ob_n, new_ac))
        policy_loss = self.actor.update(alpha, q_new_values, log_pi)

        if self.total % self.update_freq == 0:
            self.critic.update_target_network(self.config.soft_target_update_rate)

        if self.beta != self.config.end_beta and self.total % self.config.change_beta == 0:
            self._anneal_beta()

        return {
            'Train average reward': train_log['average_reward'],
            'Train max reward': train_log['max_reward'],
            'Train min reward': train_log['min_reward'],
            'Alpha loss': train_log['alpha_loss'],
            'Alpha': train_log['alpha'],
            'q1 loss': train_log['q1_loss'],
            'q2 loss': train_log['q2_loss'],
            'p loss': policy_loss,
            'v_values': train_log['v_values'],
            'true r': train_log['r'],
            'v0_values': train_log['v0_values'],
            'next_v0_values': train_log['next_v0_values'],
            'q0_values': train_log['q0_values'],
        }

    def _anneal_beta(self):
        """Move ``self.beta`` one step toward ``config.end_beta`` and push it to the critic.

        The step size is ``abs(config.beta_rate)``. The direction comes from
        comparing ``start_beta`` with ``end_beta``, so the schedule converges
        regardless of the sign convention used for ``beta_rate``. Beta never
        overshoots ``end_beta``.
        """
        start, end = self.config.start_beta, self.config.end_beta
        step = abs(self.config.beta_rate)
        if start > end:
            self.beta = max(self.beta - step, end)
        else:
            self.beta = min(self.beta + step, end)
        self.critic.change_beta(self.beta)

    def eval(self):
        """Roll out the policy deterministically and summarise the trajectories.

        Returns:
            Dict with the mean undiscounted return, mean trajectory length, and
            mean normalized return (env's ``get_normalized_score`` x 100).
        """
        trajs = self.eval_sampler.sample(self.sampler_policy, self.eval_n_trajs, deterministic=True)
        returns = [np.sum(t['rewards']) for t in trajs]
        env = self.eval_sampler.env
        # Key spellings (including 'normalizd') are kept: downstream loggers use them.
        return {
            'Eval average reward': np.mean(returns),
            'average traj length': np.mean([len(t['rewards']) for t in trajs]),
            'average normalizd return': np.mean([env.get_normalized_score(r) for r in returns]) * 100,
        }

    def save(self, path, epoch):
        """Save both Q networks and the policy into directory ``path``.

        Files are named ``critic1_<epoch>.pt``, ``critic2_<epoch>.pt`` and
        ``policy_<epoch>.pt``. The directory is created if it does not exist.
        """
        # os.path.join replaces the former hard-coded '\' separator. That
        # separator only worked on Windows and relied on the invalid string
        # escapes '\p' and '\c'.
        if path:
            os.makedirs(path, exist_ok=True)
        self.q1_net.save(os.path.join(path, 'critic1_{}.pt'.format(epoch)))
        self.q2_net.save(os.path.join(path, 'critic2_{}.pt'.format(epoch)))
        self.policy.save(os.path.join(path, 'policy_{}.pt'.format(epoch)))

    def load(self, path):
        """Load policy weights from ``path``. The critics are not restored."""
        self.policy.load(path)