import pdb
import os
import sys
sys.path.append('../pytorch-soft-actor-critic')
sys.path.append('../sac')

import torch
import torch.nn as nn
import numpy as np

from sac import SAC

class LinearPolicy(nn.Module):
    ''' deterministic linear state->action map used for simulating data

    The single layer is kept under the attribute name ``linear`` so its
    parameters are stored as ``linear.weight`` / ``linear.bias`` in the
    state dict (the names a saved checkpoint must provide).
    '''
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.linear = torch.nn.Linear(state_dim, action_dim)

    def forward(self, state):
        ''' apply the linear layer to tensor ``state``; return a flat numpy array '''
        return self.linear(state).data.numpy().flatten()

    def select_action(self, state, memory, evaluate=True):
        ''' return the linear action for ``state`` divided by 4

        ``state`` is converted to a float32 tensor first. ``memory`` and
        ``evaluate`` are accepted for interface compatibility and unused.
        '''
        state_tensor = torch.tensor(state).float()
        return self.forward(state_tensor) / 4

class TestPolicy(nn.Module):
    ''' two-component Gaussian-mixture policy for simulating test data

    Each call picks one of two components uniformly at random and samples an
    action of shape ``(action_dim,)`` with mean ``-0.4 * max_action`` or
    ``+0.4 * max_action`` and standard deviation ``0.1 * max_action``.
    The state is ignored.
    '''
    def __init__(self, max_action, action_dim):
        # nn.Module.__init__ must run so that Module methods
        # (parameters(), eval(), repr, ...) work on this object.
        super(TestPolicy, self).__init__()
        self.action_dim = action_dim
        self.max_action = max_action

    def select_action(self, state, memory, evaluate=True):
        ''' sample one action from the mixture

        ``state``, ``memory`` and ``evaluate`` are accepted for interface
        compatibility with the other policies and are unused.
        Returns a numpy array of shape ``(action_dim,)``.
        '''
        comp = np.random.choice(2)
        means = [-self.max_action+self.max_action*6/10, self.max_action-self.max_action*6/10]
        stds = [self.max_action*2/20, self.max_action*2/20]
        # Draw one value per action dimension so the output shape matches
        # action_dim even when max_action is a scalar.
        a = np.random.normal(means[comp], stds[comp], size=self.action_dim)
        return a

class TrainPolicy(nn.Module):
    ''' three-component Gaussian-mixture policy for simulating training data

    Each call picks one of three components uniformly at random and samples
    an action of shape ``(action_dim,)`` with mean ``-0.7 * max_action``,
    ``0`` or ``+0.7 * max_action`` and standard deviation
    ``0.05 * max_action``. The state is ignored.
    '''
    def __init__(self, max_action, action_dim):
        # nn.Module.__init__ must run so that Module methods
        # (parameters(), eval(), repr, ...) work on this object.
        super(TrainPolicy, self).__init__()
        self.action_dim = action_dim
        self.max_action = max_action

    def select_action(self, state, memory, evaluate=True):
        ''' sample one action from the mixture

        ``state``, ``memory`` and ``evaluate`` are accepted for interface
        compatibility with the other policies and are unused.
        Returns a numpy array of shape ``(action_dim,)``.
        '''
        comp = np.random.choice(3)
        means = [-self.max_action+self.max_action*3/10, np.zeros(self.action_dim),
            self.max_action-self.max_action*3/10]
        stds = [self.max_action*1/20, self.max_action*1/20, self.max_action*1/20]
        # Fix the sample shape: without `size`, the outer components gave a
        # scalar (for scalar max_action) while the middle one gave a vector.
        a = np.random.normal(means[comp], stds[comp], size=self.action_dim)
        return a

def load_policy(args, state_dim, action_dim, env, store_dir):
    ''' load the behaviour policy selected by ``args.policy_type``

    Args:
        args: namespace with ``policy_type``. 'LinearRand' also reads
            ``noise_seed``. Any other type (except 'PureRand') reads ``env``
            and is passed whole to ``SAC``.
        state_dim: state dimension passed to the policy constructor.
        action_dim: action dimension (used by 'LinearRand').
        env: environment whose ``action_space`` is passed to ``SAC``.
        store_dir: directory holding the saved checkpoint files.

    Returns:
        A one-element list with the loaded policy, or an empty list for
        'PureRand'.
    '''
    if args.policy_type == 'LinearRand':
        filename = args.policy_type + '_seed' + str(args.noise_seed) + '.pt'
        filepath = os.path.join(store_dir, filename)
        policy = LinearPolicy(state_dim, action_dim)
        # LinearPolicy must live on CPU (it converts outputs with .numpy()),
        # so map the checkpoint to CPU. This also lets GPU-saved weights load
        # on machines without CUDA.
        policy.load_state_dict(torch.load(filepath, map_location='cpu'))
        return [policy]
    if args.policy_type == 'PureRand':
        # NOTE(review): empty list presumably means "no learned policy";
        # confirm how the caller interprets it.
        return []
    # Every other policy type is treated as a pre-trained SAC agent.
    actor = 'sac_actor_' + args.env + '_' + '.pt'
    critic = 'sac_critic_' + args.env + '_' + '.pt'
    policy_path = os.path.join(store_dir, actor)
    critic_path = os.path.join(store_dir, critic)
    agent = SAC(state_dim, env.action_space, args)
    agent.load_model(policy_path, critic_path)
    return [agent]


