import torch
from Algorithms.gvf import GVF
from Algorithms.network import Network
from Algorithms.master_slave_network import MasterSlaveNetwork
from Algorithms.cnn import CNN
from Algorithms.master_slave_cnn import MasterSlaveCNN
from Algorithms.replay_buffer import ReplayBuffer


class Learner:
    def __init__(self, config):
        """Build the online/target networks, replay buffer, optimizer and GVF from ``config``.

        Args:
            config: dict of hyper-parameters. Every key read below is required
                (a missing key raises ``KeyError``), except ``'input_size'``,
                which is only read when the instance does not already have an
                ``input_size`` attribute (e.g. one set by a subclass before
                calling this constructor).

        Raises:
            ValueError: if ``config['optimizer']`` is not ``'RMSprop'`` or ``'Adam'``.
            KeyError: if a required key is missing, or if ``config['net_type']`` /
                ``config['cumulant_net']`` is not one of the known network types.
        """
        # Validate the optimizer choice before allocating any networks, so a bad
        # config fails fast instead of leaving ``self.optimizer`` undefined.
        optimizer_classes = {'RMSprop': torch.optim.RMSprop, 'Adam': torch.optim.Adam}
        optimizer_name = config['optimizer']
        if optimizer_name not in optimizer_classes:
            raise ValueError(
                f"Unsupported optimizer {optimizer_name!r}; "
                f"expected one of {sorted(optimizer_classes)}")

        self.main_task_on = config['main_task_on']
        self.main_task_ind = config['main_task_ind']

        self.num_actions = config['num_actions']
        self.output_size = config['output_size']
        layer_number = config['layer_number']
        self.hidden_size = config['hidden_size']
        self.num_aux_tasks = config['num_aux_tasks']
        self.aux_type = config['aux_type']
        self.generate_and_test = config['generate_and_test']

        # A subclass may set ``input_size`` before calling this constructor;
        # otherwise it has to come from the config.
        if not hasattr(self, 'input_size'):
            self.input_size = config['input_size']

        # Maps config['net_type'] / config['cumulant_net'] to a network class.
        net_dict = {'fully_connected': Network, 'cnn': CNN,
                    'master_slave': MasterSlaveNetwork, 'master_slave_cnn': MasterSlaveCNN}

        network_config = {'input_size': self.input_size,
                          'layer_number': layer_number,
                          'hidden_size': self.hidden_size,
                          'output_size': self.output_size,
                          'num_actions': self.num_actions,
                          'num_aux_tasks': self.num_aux_tasks,
                          'mins': config['mins'],
                          'maxes': config['maxes'],
                          'head_activation': config['head_activation']}

        self.network = net_dict[config['net_type']](network_config)
        self.target_network = net_dict[config['net_type']](network_config)

        # The target network starts as an exact copy of the online network's weights.
        self.target_network.load_state_dict(self.network.state_dict())
        self.target_net_update_frequency = config['target_net_update_frequency']
        self.replay_buffer = ReplayBuffer(config['buffer_capacity'],
                                          self.input_size,
                                          config['buffer_batch_size'],
                                          config['replay_start_size'],
                                          config['HER_sample_state_constant'])

        self.epsilon = config['epsilon']
        self.epsilon_annealing = config['epsilon_annealing']
        self.end_epsilon = config['end_epsilon']
        self.alpha = config['alpha']
        self.gamma = config['gamma']
        self.num_replay = config['num_replay']

        # Only the online network's parameters are optimized; alpha is the learning rate.
        self.optimizer = optimizer_classes[optimizer_name](self.network.parameters(), lr=self.alpha)

        self.criterion = torch.nn.MSELoss()
        self.aux_weight_loss = config['aux_weight_loss']

        gvf_config = {
            'obs_size': config['obs_size'],
            'buffer_batch_size': config['buffer_batch_size'],
            'test_freq': config['test_freq'],
            'replace_rate': config['replace_rate'],
            'age_threshold': config['age_threshold'],
            'feature_size': config['hidden_size'],
            'env_name': config['env_name'],
            'pinball_random_goal_radius': config['pinball_random_goal_radius'],
            'pinball_random_goal_increment': config['pinball_random_goal_increment']
        }

        # These auxiliary-task types need a separate cumulant network; other
        # types pass cumulant_net=None to the GVF.
        if config['aux_type'] in ['random_obs', 'random_pixel', 'GT', 'MS', 'MSGT']:

            # Master-slave variants size the cumulant net by the main network's
            # per-auxiliary-task feature width instead of the full hidden size.
            if config['aux_type'] in ['MS', 'MSGT']:
                hidden_size = self.network.feature_per_aux
            else:
                hidden_size = self.hidden_size

            cumulant_net_config = {'input_size': self.input_size,
                                    'layer_number': layer_number,
                                    'hidden_size': hidden_size,
                                    'output_size': config['num_aux_tasks'],
                                    'num_actions': config['num_actions'],
                                    'mins': config['mins'],
                                    'maxes': config['maxes']}
            cumulant_net = net_dict[config['cumulant_net']](cumulant_net_config)
            gvf_config['cumulant_net'] = cumulant_net

        else:
            gvf_config['cumulant_net'] = None
        self.gvf = GVF(self.aux_type, self.num_aux_tasks, gvf_config)
        self.step = 0






