import argparse
import gym
import numpy as np
from itertools import count
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
from torch.distributions import Categorical


# Command-line configuration. Parsed once at import time into the module-level
# `args`, which attack_clean and the __main__ block read directly.
parser = argparse.ArgumentParser(description='PyTorch mountaincar attack')
parser.add_argument('--seed', type=int, default=20, metavar='N',
					help='random seed (default: 20)')
parser.add_argument('--render', action='store_true',
					help='render the environment')
parser.add_argument('--num_evals', type=int, default=250, metavar='N',
					help='Number of evaluations')
parser.add_argument('--store_all_rewards', action='store_true',
					help='store all rewards (vs just sum)')
parser.add_argument('--checkpoint', type=str,
					help='checkpoint path')
parser.add_argument('--log-interval', type=int, default=1, metavar='N',
					help='interval between training status logs (default: 1)')
parser.add_argument('--attack_eps', type=float, default=0.2, metavar='N',
					help='Attack epsilon, total')
parser.add_argument('--attack_step', type=float, default=0.01, metavar='N',
					help='Attack step size')
# int, not float: the value is used as a range() bound in attack_clean, and a
# float (e.g. from `--steps 50`) would raise TypeError there.
parser.add_argument('--steps', type=int, default=100, metavar='N',
					help='max number of steps to take in each attack')
parser.add_argument('--norm_coeff', type=float, default=1., metavar='N',
					help='norm coefficient')
parser.add_argument('--num_smoothing_points', type=int, default=128, metavar='N',
					help='Number of points to use for smoothing')
parser.add_argument('--sigma', type=float, default=.2, metavar='N',
					help='Smoothing std. dev.')
args = parser.parse_args()

class BooleanRewardWrapper(gym.RewardWrapper):
	"""Reward wrapper that collapses every reward to 0 or 1.

	A reward that compares ``<= 0`` becomes 0; any other reward becomes 1.
	"""

	def __init__(self, env):
		super(BooleanRewardWrapper, self).__init__(env)

	def reward(self, reward):
		# Keep the `<= 0` comparison (not `> 0`) so non-comparable values
		# such as NaN still map to 1.
		if reward <= 0:
			return 0
		return 1
def soft_smooth_fun(fun, batch, noise):
	"""Average ``fun`` over noise perturbations of each batch element.

	Every element of ``batch`` is paired with every row of ``noise`` (the same
	noise rows are shared by all elements). ``fun`` is evaluated once on the
	flattened (B * N) stack, and its outputs are averaged over the noise axis.

	Args:
		fun: callable taking a tensor of shape (B * N, *batch.shape[1:]) and
			returning a tensor whose leading dimension is B * N.
		batch: tensor of shape (B, ...).
		noise: tensor of shape (N, ...), added to each batch element.

	Returns:
		Tensor of shape (B, *fun_output.shape[1:]): per-element mean of ``fun``.
	"""
	n_batch = batch.shape[0]
	n_noise = noise.shape[0]
	# (B, 1, ...) + (1, N, ...) -> (B, N, ...)
	perturbed = batch.unsqueeze(1) + noise.unsqueeze(0)
	# Merge the batch and noise axes so `fun` sees one flat batch.
	flat = perturbed.reshape([n_batch * n_noise] + list(batch.shape[1:]))
	evaluated = fun(flat)
	per_sample = evaluated.reshape([n_batch, n_noise] + list(evaluated.shape[1:]))
	return per_sample.mean(dim=1)
def attack_clean(state, model, carryover_budget_squared, num_smoothing_points=128, clean_prev_obs_tens=None, dirty_prev_obs_tens=None):
	"""Craft an L2-bounded adversarial observation against a frame-stacked agent.

	Runs up to ``args.steps`` normalised gradient-descent steps on the current
	observation. The objective is the noise-smoothed critic value plus
	``args.norm_coeff`` times the squared perturbation size. After each step,
	the perturbation is projected onto the L2 ball whose radius is the square
	root of the carried-over budget. The result is then clamped per coordinate
	to [-1.2, 0.6] x [-0.07, 0.07].

	Args:
		state: current clean observation. It is given a leading batch dim and
			indexed as ``state[0, 0]`` / ``state[0, 1]``, so presumably a
			length-2 array (TODO confirm).
		model: loaded agent exposing ``.actor`` and ``.critic``; the first
			element of the critic's output is used as the value to minimise.
		carryover_budget_squared: remaining squared L2 budget; a value <= 0
			means no perturbation is allowed.
		num_smoothing_points: number of noise samples (std ``args.sigma``)
			averaged over when smoothing the critic value.
		clean_prev_obs_tens: list of previous clean observations (tensors), or
			None on the first step of an episode. The slicing below
			(``x[:, :10]``) presumably expects 4 frames x 2 values — TODO confirm.
		dirty_prev_obs_tens: list of previous perturbed observations in the same
			layout; only used when ``clean_prev_obs_tens`` is given.

	Returns:
		(adv_state, remaining): the perturbed observation on the CPU, and the
		leftover squared budget ``budget**2 - ||adv_state - state||**2`` as a
		tensor.

	Note:
		Reads the module-level ``args`` (steps, sigma, norm_coeff, attack_step)
		and requires a CUDA device.
	"""
	history_given = clean_prev_obs_tens is not None
	if history_given:
		clean_prev_obs_tens = torch.stack(clean_prev_obs_tens).reshape(-1).cuda().unsqueeze(0)
		dirty_prev_obs_tens = torch.cat(dirty_prev_obs_tens, dim=0).reshape(-1).cuda().unsqueeze(0)

	state = torch.tensor(state, device='cuda').float().unsqueeze(0)
	# Untouched copy of the clean observation: the centre of the L2 ball.
	ori_state = state.detach().clone()
	state = state.detach()
	actor = model.actor.cuda()
	critic = model.critic.cuda()

	if history_given:
		# The smoothed vector is [clean history, clean current frame, attacked frame].
		prefix = torch.cat([clean_prev_obs_tens, state], dim=1).detach()
		# Loop-invariant: the same perturbed history is paired with every noise
		# sample (assumes a batch of one state).
		dirty_rep = dirty_prev_obs_tens.repeat(num_smoothing_points, 1)

		def q_value(x):
			# The critic sees the (noisy) clean stack; the actor sees the
			# perturbed history followed by the (noisy) attacked frame.
			return critic(x[:, :10], actor(torch.cat([dirty_rep, x[:, 10:]], dim=1)))[0]
	else:
		prefix = state.detach()

		def q_value(x):
			# First step of an episode: each stack is one frame repeated 5 times.
			return critic(x[:, :2].repeat(1, 5), actor(x[:, 2:].repeat(1, 5)))[0]

	if carryover_budget_squared <= 0:
		budget = 0
	else:
		budget = math.sqrt(carryover_budget_squared)

	if budget == 0:
		# With no budget, every iteration would be projected straight back onto
		# the clean observation (which the clamp bounds are assumed to contain).
		# Skip the expensive optimisation loop and return it directly.
		return state[0].cpu(), budget**2 - (state - ori_state).norm()**2

	# --steps may arrive as a float; range() requires an int.
	step_count = int(args.steps)
	noise = None
	for _ in range(step_count):
		state.requires_grad_(True)
		to_smooth = torch.cat([prefix, state], dim=1)
		if noise is None:
			# Draw the smoothing noise once and reuse it on every step, so all
			# steps optimise the same Monte-Carlo estimate.
			noise = torch.randn([num_smoothing_points] + list(to_smooth.shape[1:]), device='cuda') * args.sigma
		out = soft_smooth_fun(q_value, to_smooth, noise)

		# Minimise the smoothed value while penalising large perturbations.
		cost = out[0] + args.norm_coeff * (state - ori_state).norm()**2
		grad, = torch.autograd.grad(inputs=state, outputs=cost)
		if grad.norm() / state.norm() < 0.001:
			# Relative gradient is negligible: treat the attack as converged.
			state = state.detach_()
			break
		with torch.no_grad():
			# Normalised gradient step, then projection onto the L2 ball.
			state = state - args.attack_step * grad / grad.norm()
			delta = state - ori_state
			if delta.norm() > budget:
				state = ori_state + delta * budget / delta.norm()
			state[0, 0].clamp_(-1.2, 0.6)
			state[0, 1].clamp_(-0.07, 0.07)
		actor.zero_grad()
		critic.zero_grad()

	return state[0].cpu(), budget**2 - (state - ori_state).norm()**2

from stable_baselines3.common.vec_env import SubprocVecEnv
from stable_baselines3 import DDPG



if __name__ == '__main__':
	# Rewards are collapsed to 0/1 by BooleanRewardWrapper.
	env = BooleanRewardWrapper(gym.make("MountainCarContinuous-v0"))

	env.seed(args.seed)
	# Seed torch (CPU and CUDA) as well, so the attack and smoothing noise is
	# reproducible for a given --seed; previously only the env was seeded.
	torch.manual_seed(args.seed)
	reward_accum = []
	model = DDPG.load(args.checkpoint)
	for i_episode in range(args.num_evals):
		state = env.reset()
		ep_reward = 0
		policy_rewards = []
		# Squared L2 attack budget shared by every step of this episode.
		carryover = args.attack_eps * args.attack_eps
		frame_hist = []  # attacked + noised observations shown to the policy
		state_hist = []  # clean observations, as tensors

		for t in range(1, 1000):
			if t == 1:
				observation, carryover = attack_clean(
					state, model, carryover,
					num_smoothing_points=args.num_smoothing_points)
			else:
				observation, carryover = attack_clean(
					state, model, carryover,
					clean_prev_obs_tens=state_hist[-4:],
					dirty_prev_obs_tens=frame_hist[-4:],
					num_smoothing_points=args.num_smoothing_points)
			# Add Gaussian noise (std --sigma) to the observation the policy sees.
			observation += torch.randn_like(observation) * args.sigma
			# On the first step the single frame fills the whole 5-frame history.
			copies = 5 if t == 1 else 1
			frame_hist.extend([observation] * copies)
			state_hist.extend([torch.tensor(state)] * copies)

			# Policy input: the last 5 frames, flattened, as a NumPy array.
			stacked_obs = torch.stack(frame_hist[-5:]).reshape(-1).numpy()
			action, _ = model.predict(stacked_obs, deterministic=True)
			state, reward, done, _ = env.step(action)
			if args.render:
				env.render()
			policy_rewards.append(reward)
			ep_reward += reward
			if done:
				print(t)
				break
		if args.store_all_rewards:
			reward_accum.append(policy_rewards)
		else:
			print(ep_reward)
			reward_accum.append(ep_reward)
		# Guard against --log-interval 0, which would raise ZeroDivisionError.
		if args.log_interval > 0 and i_episode % args.log_interval == 0:
			print('Episode {}\t'.format(i_episode), flush=True)
	# NOTE: the '_threshold_' tag records --norm_coeff; the filename format is
	# kept unchanged so existing result files keep their names.
	save_path = (args.checkpoint + '_smooth_evals_' + str(args.num_evals)
		+ '_attack_eps_' + str(args.attack_eps)
		+ '_steps_' + str(args.steps)
		+ '_attack_step_' + str(args.attack_step)
		+ '_threshold_' + str(args.norm_coeff)
		+ '_num_smoothing_points_' + str(args.num_smoothing_points)
		+ '.pth')
	torch.save(reward_accum, save_path)
