import sys
sys.path.append('..')
import gym, torch
import gym_compete
import numpy as np
from src.xfeat import MaskFeatExp
from src.xstep import DGaussianStepExp, DGPStepExp
from src.xstep_feat import DGaussianStepFeatExp, DGPStepFeatExp

sys.path.append('.')
sys.path.append('..')

# Turn on autograd anomaly detection for the whole run. This is a debugging
# aid: the PyTorch documentation notes that it slows execution, so it should
# be switched off once the experiment code is trusted.
torch.autograd.set_detect_anomaly(True)

# Setup env and load the target agent.
# env_name = 'multicomp/YouShallNotPassHumans-v0'
# env = gym.make(env_name)
# env.seed(1)


# The original policy is a TensorFlow model; craft a randomly initialised torch policy for debugging purposes.
# The input observation dimension is 380 and the output action dimension is 17.
class MuJoCoPolicy(torch.nn.Module):
    """Randomly initialised stand-in policy used for debugging.

    The network is ``Linear(input_dim, 64)`` followed directly by
    ``Linear(64, output_dim)`` (no activation in between). The two layers are
    registered in a ``Sequential`` under the names ``mlp_0`` and ``mlp_1``.

    Args:
        input_dim: number of features in one flattened observation.
        output_dim: number of action values produced per observation.
    """

    def __init__(self, input_dim, output_dim):
        super(MuJoCoPolicy, self).__init__()
        self.input_dim = input_dim
        self.model = torch.nn.Sequential()
        self.model.add_module('mlp_0', torch.nn.Linear(input_dim, 64))
        self.model.add_module('mlp_1', torch.nn.Linear(64, output_dim))

    def forward(self, x, hidden):
        """Map observation(s) to action(s).

        Args:
            x: tensor whose total number of elements is a multiple of
                ``input_dim``. A single observation yields a ``(1, output_dim)``
                result, exactly as before; a batch of ``k`` observations yields
                ``(k, output_dim)``.
            hidden: accepted for interface compatibility and ignored; this
                policy keeps no recurrent state.

        Returns:
            Tensor of shape ``(k, output_dim)``.
        """
        # ``reshape`` with a free leading dimension accepts batched input and
        # non-contiguous tensors, neither of which ``view(1, input_dim)`` handles.
        flat_obs = x.reshape(-1, self.input_dim)
        return self.model(flat_obs)


# Debug policy instance: 380-dimensional observations, 17-dimensional actions.
policy = MuJoCoPolicy(input_dim=380, output_dim=17)
# Indices 0..99; passed as train_idx / test_idx to the explainers below.
train_idx = np.arange(0, 100)
# Location handed to the explainers as traj_path.
traj_path = 'trajs/youshallnotpasshumans_v0'

## feature-level explanation: solving the mask.
# step_idx = np.asarray([a for a in range(179, 189)])
# BATCH_SIZE = 5
# N_EPOCHS = 20
# LR = 0.01
# DECAY = 0.1 # learning rate decay
# INITIALIZER = 'normal' # P initialize ['zero', 'one', 'uniform', 'normal']
# NORMALIZE_CHOICE = 'clip' # P normalize choice ['sigmoid', 'tanh', 'clip']
# UPSAMPLING_MODE = 'nearest' # up-sampling choice ['nearest', 'bilinear']
# EPSILON = 1e-8
# FUSED_CHOICE = None # fused choice ['mean', 'random', 'blur']
# NORM_CHOICE = 'l2' # loss norm choice  ['l2', 'l1', 'inf']
#
# REG_CHOICE = 'elasticnet' # regularization term to use ['l1', 'elasticnet']
# REG_COEF_1 = 1e-3 # coefficient of the shape regularization
# REG_COEF_2 = 1e-4 # coefficient of the smoothness regularization
# LAMBDA_PATIENCE = 20 # regularization multiple waiting epoch
# LAMBDA_MULTIPILER = 1.2 # regularization multiplier
# ITER_THRD = 1e-3  # early stop threshold
# EARLY_STOP = 50 # early stop waiting epoch
# DISP = 1 # display interval
#
# mask_explainer = MaskFeatExp(policy=policy, act_distribution='normal', input_shape=(380,),
#                              mask_shape=(380,), lr=LR, initializer=INITIALIZER, normalize_choice=NORMALIZE_CHOICE,
#                              upsampling_mode=UPSAMPLING_MODE, epsilon=EPSILON)
#
# p = mask_explainer.train(train_idx=train_idx, traj_path=traj_path, step_idx=step_idx, batch_size=BATCH_SIZE,
#                          n_epochs=N_EPOCHS, reg_choice=REG_CHOICE, reg_coef_1=REG_COEF_1, reg_coef_2=REG_COEF_2,
#                          temp=0.1, norm_choice=NORM_CHOICE, fused_choice=FUSED_CHOICE, lambda_patience=LAMBDA_PATIENCE,
#                          lambda_multiplier=LAMBDA_MULTIPILER, decay_weight=DECAY, iteration_threshold=ITER_THRD,
#                          early_stop_patience=EARLY_STOP, display_interval=DISP)

## Step-level explanation using the deep gaussian model
# HIDDENS = [64, 32, 8]
# LR = 0.001
# DECAY = 0.1 # learning rate decay
# BATCH_SIZE = 20
# N_EPOCHS = 50
# REG_WEIGHT = 0.0001 # sparse regularization on the regression weight
#
# step_explainer = DGaussianStepExp(seq_len=200, input_dim=380, hiddens=HIDDENS, input_channels=1,
#                                   likelihood_type='classification', lr=LR, encoder_type='MLP', num_class=2)
#
# model_name = 'exp_models/dgaussian'+'_'+str(LR)+'_'+str(BATCH_SIZE)+'_'+str(DECAY)+'_'+str(REG_WEIGHT)
#
# step_explainer.train(n_epoch=N_EPOCHS, train_idx=train_idx, batch_size=BATCH_SIZE, traj_path=traj_path,
#                      reg_weight=REG_WEIGHT, decay_weight=DECAY, save_path=model_name)
#
# step_explainer.load(load_path='exp_models/dgaussian_0.001_20_0.1_0.0001_epoch_10.data')
#
# step_explainer.test(test_idx=train_idx, batch_size=BATCH_SIZE, traj_path=traj_path)
#
# step_importance_score = step_explainer.get_explanations(class_id=0)
#
# print(np.argsort(step_importance_score)[::-1])

## Step-level explanation using the DGP model.
# HIDDENS = [64, 32, 8]
# LR = 0.01
# DECAY = 0.1 # learning rate decay
# BATCH_SIZE = 20
# N_EPOCHS = 10
# REG_WEIGHT = 0.001 # sparse regularization on the regression weight
# INDUCE_NUM = 300 # number of inducing points
#
# step_explainer = DGPStepExp(train_len=train_idx.shape[0], seq_len=200, input_dim=380, hiddens=HIDDENS, input_channels=1,
#                             likelihood_type='classification', lr=LR, optimizer_type='adam', n_epoch=N_EPOCHS,
#                             gamma=DECAY, num_inducing_points=INDUCE_NUM, encoder_type='MLP', num_class=2,
#                             lambda_1=REG_WEIGHT)
#
# model_name = 'exp_models/dgp'+'_'+str(LR)+'_'+str(BATCH_SIZE)+'_'+str(DECAY)+'_'+str(REG_WEIGHT)+'_'+str(INDUCE_NUM)
# # step_explainer.train(train_idx=train_idx, batch_size=BATCH_SIZE, traj_path=traj_path, save_path=model_name)
#
# step_explainer.load(load_path='exp_models/dgp_0.01_20_0.1_0.001_300_10_model.data')
#
# step_explainer.test(test_idx=train_idx, batch_size=BATCH_SIZE, traj_path=traj_path)
#
# step_importance_score = step_explainer.get_explanations(class_id=0)
#
# print(np.argsort(step_importance_score)[::-1])

## Feature and step-level explanation using the deep gaussian model + mask explanation
# HIDDENS = [64, 32, 8]
# LR = 0.01
# INITIALIZER = 'normal' # P initialize ['zero', 'one', 'uniform', 'normal']
# NORMALIZE_CHOICE = 'clip' # P normalize choice ['sigmoid', 'tanh', 'clip']
# UPSAMPLING_MODE = 'nearest' # up-sampling choice ['nearest', 'bilinear']
# EPSILON = 1e-8
# FUSED_CHOICE = None # fused choice ['mean', 'random', 'blur']
#
# STEP_BATCH_SIZE = 10
# MASK_BATCH_SIZE = 50
#
# N_EPOCHS = 10
# REG_CHOICE = 'elasticnet' # regularization term to use ['l1', 'elasticnet']
# REG_COEF_1 = 1e-3 # coefficient of the shape regularization
# REG_COEF_2 = 1e-3 # coefficient of the smoothness regularization
# NORM_CHOICE = 'l2' # loss norm choice  ['l2', 'l1', 'inf']
# DECAY = 0.1 # learning rate decay
# LAMBDA_PATIENCE = 20 # regularization multiple waiting epoch
# LAMBDA_MULTIPILER = 1.2 # regularization multiplier
# ITER_THRD = 1e-3  # early stop threshold
# EARLY_STOP = 50 # early stop waiting epoch
# DISP = 1 # display interval
# REG_WEIGHT = 0.001 # sparse regularization on the regression weight
#
# feat_step_explainer = DGaussianStepFeatExp(seq_len=200, input_dim=380, hiddens=HIDDENS, input_channels=1,
#                                            likelihood_type='classification', lr=LR, mask_shape=(380,),
#                                            act_distribution='normal',  policy=policy, encoder_type='MLP', num_class=2,
#                                            initializer=INITIALIZER, normalize_choice=NORMALIZE_CHOICE,
#                                            upsampling_mode=UPSAMPLING_MODE, fused_choice=FUSED_CHOICE,
#                                            epsilon=EPSILON)
#
# model_name = 'exp_models/mask_dgaussian'+'_'+str(LR)+'_'+str(FUSED_CHOICE)+'_'+str(INITIALIZER)+'_'+str(NORMALIZE_CHOICE)\
#              +'_'+str(UPSAMPLING_MODE)+'_'+str(REG_CHOICE)+'_'+str(REG_COEF_1)+'_'+str(REG_COEF_2)+'_'+str(REG_WEIGHT)\
#              +'_'+str(NORM_CHOICE)+'_'+str(STEP_BATCH_SIZE)+'_'+str(MASK_BATCH_SIZE)
#
# feat_step_explainer.train(n_epoch=N_EPOCHS, train_idx=train_idx, step_batch_size=STEP_BATCH_SIZE,
#                           mask_batch_size=MASK_BATCH_SIZE, traj_path=traj_path, reg_choice=REG_CHOICE,
#                           reg_coef_1=REG_COEF_1, reg_coef_2=REG_COEF_2, reg_coef_w=REG_WEIGHT, norm_choice=NORM_CHOICE,
#                           decay_weight=DECAY, lambda_patience=LAMBDA_PATIENCE, lambda_multiplier=LAMBDA_MULTIPILER,
#                           save_path=model_name)
#
# feat_step_explainer.load(load_path=
#                          'exp_models/mask_dgaussian_0.01_None_normal_clip_nearest_elasticnet_0.001_0.001'
#                          '_0.001_l2_10_50_10_model.data')
#
# feat_step_explainer.reward_pred_test(test_idx=train_idx, batch_size=STEP_BATCH_SIZE, traj_path=traj_path)
#
# step_importance_score, p = feat_step_explainer.get_explanations(class_id=0)
#
# print(np.argsort(step_importance_score)[::-1])

# Joint feature- and step-level explanation: DGP model + mask explanation.

# -- model / optimiser settings --
HIDDENS = [64, 32, 8]
LR = 0.01
DECAY = 0.1  # learning rate decay
INDUCE_NUM = 300  # number of inducing points
EPSILON = 1e-8

# -- mask (P) parameterisation --
INITIALIZER = 'normal'  # P initializer, one of: 'zero', 'one', 'uniform', 'normal'
NORMALIZE_CHOICE = 'clip'  # P normalization, one of: 'sigmoid', 'tanh', 'clip'
UPSAMPLING_MODE = 'nearest'  # up-sampling, one of: 'nearest', 'bilinear'
FUSED_CHOICE = None  # fused choice, one of: 'mean', 'random', 'blur'

# -- batching / schedule --
STEP_BATCH_SIZE = 10
MASK_BATCH_SIZE = 50
N_EPOCHS = 10

# -- loss and regularization --
NORM_CHOICE = 'l2'  # loss norm, one of: 'l2', 'l1', 'inf'
REG_CHOICE = 'elasticnet'  # regularization term, one of: 'l1', 'elasticnet'
REG_COEF_1 = 1e-3  # coefficient of the shape regularization
REG_COEF_2 = 1e-3  # coefficient of the smoothness regularization
REG_WEIGHT = 0.01  # sparse regularization on the regression weight
LAMBDA_PATIENCE = 20  # regularization multiple waiting epoch
LAMBDA_MULTIPILER = 1.2  # regularization multiplier (identifier spelling kept for existing uses)

# -- stopping / logging --
ITER_THRD = 1e-3  # early stop threshold
EARLY_STOP = 50  # early stop waiting epoch
DISP = 1  # display interval

# Build the joint step/feature explainer around the debug policy.
feat_step_explainer = DGPStepFeatExp(
    train_len=train_idx.shape[0],
    seq_len=200,
    input_dim=380,
    input_channels=1,
    hiddens=HIDDENS,
    encoder_type='MLP',
    likelihood_type='classification',
    num_class=2,
    num_inducing_points=INDUCE_NUM,
    lr=LR,
    policy=policy,
    act_distribution='normal',
    mask_shape=(380,),
    initializer=INITIALIZER,
    normalize_choice=NORMALIZE_CHOICE,
    upsampling_mode=UPSAMPLING_MODE,
    fused_choice=FUSED_CHOICE,
    epsilon=EPSILON,
)

# The save path encodes the hyper-parameters, joined by underscores in this order.
name_fields = [
    INDUCE_NUM, LR, FUSED_CHOICE, INITIALIZER, NORMALIZE_CHOICE, UPSAMPLING_MODE,
    REG_CHOICE, REG_COEF_1, REG_COEF_2, REG_WEIGHT, NORM_CHOICE,
    STEP_BATCH_SIZE, MASK_BATCH_SIZE,
]
model_name = 'exp_models/mask_dgp' + '_' + '_'.join(map(str, name_fields))

feat_step_explainer.train(
    n_epoch=N_EPOCHS,
    train_idx=train_idx,
    traj_path=traj_path,
    step_batch_size=STEP_BATCH_SIZE,
    mask_batch_size=MASK_BATCH_SIZE,
    norm_choice=NORM_CHOICE,
    reg_choice=REG_CHOICE,
    reg_coef_1=REG_COEF_1,
    reg_coef_2=REG_COEF_2,
    reg_coef_w=REG_WEIGHT,
    decay_weight=DECAY,
    lambda_patience=LAMBDA_PATIENCE,
    lambda_multiplier=LAMBDA_MULTIPILER,
    save_path=model_name,
)

# Reload a saved model from a fixed path (hard-coded, not derived from model_name).
saved_model_path = ('exp_models/mask_dgp_300_0.01_None_normal_clip_nearest_elasticnet_0.001_0.001'
                    '_0.01_l2_10_50_10_model.data')
feat_step_explainer.load(load_path=saved_model_path)

# Run the reward-prediction test; note it is given the training indices.
feat_step_explainer.reward_pred_test(test_idx=train_idx, batch_size=STEP_BATCH_SIZE, traj_path=traj_path)

# Fetch explanations for class 0 and print step indices by descending score.
step_importance_score, p = feat_step_explainer.get_explanations(class_id=0)
steps_by_importance = np.argsort(step_importance_score)[::-1]
print(steps_by_importance)






