import torch

from copy import deepcopy
import numpy as np
import os, time, wandb

from ..Revalued.diff_state_agent import Agent as RevaluedAgent
from ..BC.diff_state_agent import Agent as BCAgent
from ..Revalued.action_agent import Agent as RevaluedActionAgent
from ..TD3.agent import Agent as TD3Agent
from ..SAC_N.agent import Agent as SACAgent
from inverse_model import InverseModel
from utils.base_agent import BaseAgent


class Agent(BaseAgent):
    """Hybrid agent mixing a pretrained offline agent with an online learner.

    Two sub-agents are held:

    * ``online_agent`` -- a ``RevaluedActionAgent`` when ``dm_suite`` is set,
      otherwise a ``TD3Agent``; it is trained in :meth:`learn`.
    * ``offline_agent`` -- a ``RevaluedAgent`` on which ``load_model`` is
      called right after construction; its actor is also updated in
      :meth:`learn` via ``update_actor``.

    During non-deterministic action selection the offline agent is used with
    a rate given by ``offline_sample_rate`` clamped to
    ``[min_sample_rate, max_sample_rate]``. ``offline_sample_rate`` itself is
    raised by ``increment_sample_rate`` on a fixed schedule in :meth:`learn`.

    Attributes that are not defined on this wrapper are forwarded to
    ``online_agent`` by :meth:`__getattr__`. This is how ``rng``,
    ``replay_buffer``, ``log_dict``, ``device`` and ``update_ratio`` are
    resolved unless ``BaseAgent`` defines them -- NOTE(review): confirm
    against ``BaseAgent``.
    """

    # Argument passed to ``offline_agent.load_model`` (presumably the
    # training step of the checkpoint to restore -- confirm).
    _OFFLINE_CHECKPOINT_STEP = 3000000
    # ``offline_sample_rate`` is incremented whenever
    # ``total_it * update_ratio`` is a multiple of this value.
    _SAMPLE_RATE_INCREMENT_PERIOD = 100000

    def __init__(self, obs_dims, action_dims, algo_name='combined', **kwargs):
        """Build the online learner and load the offline agent.

        Args:
            obs_dims: Observation dimensions, forwarded to both sub-agents.
            action_dims: Action dimensions, forwarded to both sub-agents.
            algo_name: Not used by this class; kept for interface
                compatibility with other agents.
            **kwargs: Must contain ``dm_suite``, ``offline_sample_rate``,
                ``max_sample_rate``, ``min_sample_rate`` and
                ``increment_sample_rate``. All entries are also forwarded to
                the sub-agent constructors.

        Raises:
            KeyError: if one of the required ``kwargs`` entries is missing.
        """
        self.dm_suite = kwargs['dm_suite']

        if self.dm_suite:
            self.online_agent = RevaluedActionAgent(obs_dims=obs_dims, action_dims=action_dims, use_data=False, **kwargs)
        else:
            kwargs['critic_factor'] = 2
            self.online_agent = TD3Agent(obs_dims=obs_dims, action_dims=action_dims, use_data=False, **kwargs)
            # Alternative online learner (disabled):
            # self.online_agent = SACAgent(obs_dims=obs_dims, action_dims=action_dims, use_data=False, **kwargs)

        # ``**kwargs`` is unpacked into a fresh dict for each callee, so these
        # overrides only reach the offline agent built below.
        kwargs['sample_type'] = 'double_q'
        kwargs['critic_factor'] = 5
        self.offline_sample_rate = kwargs['offline_sample_rate']
        self.max_sample_rate = kwargs['max_sample_rate']
        self.min_sample_rate = kwargs['min_sample_rate']
        self.increment_sample_rate = kwargs['increment_sample_rate']

        self.offline_agent = RevaluedAgent(obs_dims=obs_dims, action_dims=action_dims, **kwargs)
        self.offline_agent.load_model(self._OFFLINE_CHECKPOINT_STEP)

        # Number of ``learn`` calls so far (incremented even when the buffer
        # is still too small to sample from).
        self.total_it = 0

    def __getattr__(self, attr):
        """Forward attribute lookups that failed on ``self`` to ``online_agent``.

        Python only calls this after normal attribute lookup fails.
        ``online_agent`` is read from the instance ``__dict__`` directly. Going
        through ``self.online_agent`` would re-enter this method whenever the
        attribute is not set yet, and recurse until ``RecursionError``. That
        case arises during ``copy.deepcopy``/unpickling, where a bare instance
        is probed for hooks such as ``__setstate__``.

        Raises:
            AttributeError: if neither this wrapper nor ``online_agent``
                defines ``attr``. This keeps ``hasattr`` and
                ``getattr(obj, name, default)`` behaving normally.
        """
        try:
            online_agent = self.__dict__['online_agent']
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None
        try:
            return getattr(online_agent, attr)
        except AttributeError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {attr!r}"
            ) from None

    def _clamped_offline_sample_rate(self):
        """Return ``offline_sample_rate`` clipped to ``[min_sample_rate, max_sample_rate]``."""
        return max(self.min_sample_rate, min(self.max_sample_rate, self.offline_sample_rate))

    def choose_action(self, state, diff_state=None, **kwargs):
        """Select an action from either the online or the offline agent.

        When ``kwargs['deterministic']`` is truthy, the online agent is always
        used and no random number is drawn. Otherwise the online agent is used
        when ``self.rng.uniform()`` exceeds the clamped offline sample rate.
        In the remaining cases the offline agent is used.

        Args:
            state: Current observation, passed through unchanged.
            diff_state: Only forwarded to the offline agent.
            **kwargs: Forwarded to whichever sub-agent is chosen.

        Returns:
            Whatever the chosen sub-agent's ``choose_action`` returns.
        """
        if kwargs.get('deterministic', False) or self.rng.uniform() > self._clamped_offline_sample_rate():
            return self.online_agent.choose_action(state, **kwargs)
        return self.offline_agent.choose_action(state, diff_state, **kwargs)

    def choose_offline_action(self, state, diff_state=None, **kwargs):
        """Always select an action from the offline agent."""
        return self.offline_agent.choose_action(state, diff_state=diff_state, **kwargs)

    def choose_online_action(self, state, **kwargs):
        """Always select an action from the online agent."""
        return self.online_agent.choose_action(state, **kwargs)

    def choose_diff_state(self, state, **kwargs):
        """Ask the offline agent for a ``diff_state`` given ``state``.

        ``state`` is converted to a float32 tensor on ``self.device`` and
        dimension 0 is squeezed (a no-op unless it has size 1). ``kwargs`` is
        accepted for interface compatibility and ignored.
        """
        state = torch.tensor(state, dtype=torch.float).to(self.device).squeeze(0)
        return self.offline_agent.choose_diff_state(state)

    def learn(self, sample_range=None, **kwargs):
        """Run one update step for both sub-agents.

        Samples one batch from the replay buffer. It trains ``online_agent``
        on it and calls ``offline_agent.update_actor`` with the same batch.
        Both agents' logs are merged into ``self.log_dict``, and the offline
        sample rate is bumped on schedule.

        Args:
            sample_range: Not used; kept for interface compatibility.
            **kwargs: Not used.

        Returns:
            The value returned by ``online_agent.learn``, or ``None`` while
            the replay buffer holds fewer entries than its batch size.
        """
        self.total_it += 1

        buffer = self.replay_buffer
        # NOTE(review): warm-up compares against ``buffer.batch_size`` but
        # sampling uses ``online_agent.batch_size``; confirm they agree.
        if buffer.mem_cntr < buffer.batch_size:
            return None

        *samples, _batch_idx = buffer.sample(rng=self.online_agent.rng,
                                             batch_size=self.online_agent.batch_size)

        loss = self.online_agent.learn(samples=samples)
        # The return value (named ``idm_loss`` originally) is not used here.
        self.offline_agent.update_actor(samples=samples, agent=self.online_agent)

        # NOTE(review): if ``log_dict`` is resolved via ``__getattr__``, the
        # first update merges ``online_agent.log_dict`` into itself.
        self.log_dict.update(self.online_agent.log_dict)
        self.log_dict.update(self.offline_agent.log_dict)

        if (self.total_it * self.update_ratio) % self._SAMPLE_RATE_INCREMENT_PERIOD == 0:
            # Unbounded here; ``choose_action`` clamps the effective rate.
            self.offline_sample_rate += self.increment_sample_rate

        return loss

