import d4rl, gym
import numpy as np
import torch
import torch.nn.functional as F

from model import NewPriorModel
from sac.model import QNetwork
from sac.utils import soft_update, hard_update
import matplotlib.pyplot as plt
from torch.distributions.normal import Normal
from model import NewPriorModel
import os

def loss(prior, obs, acs, rews):
	"""Return the weight-averaged negative log-likelihood of ``acs`` under ``prior``.

	``prior(obs)`` is unpacked as ``(mean, std)`` of a diagonal Gaussian over
	actions; per-dimension log-densities are summed over dim 1, negated, scaled
	by the per-sample weights ``rews`` and averaged.

	Actions and weights are moved to ``prior.device`` and cast to float32.
	``obs`` is passed through untouched — presumably the prior moves it to its
	device itself (TODO confirm in NewPriorModel).
	"""
	device = prior.device
	actions = acs.to(device).float()
	weights = rews.to(device).float()
	mean, stddev = prior(obs)
	dist = Normal(mean, stddev)
	nll = dist.log_prob(actions).sum(dim=1).neg()
	return torch.mean(nll * weights)

def evaluate(prior, val_obs, val_acs, val_rews):
	"""Compute the prior's ``loss`` on held-out numpy arrays without autograd.

	Args:
		prior: model accepted by ``loss``.
		val_obs: numpy array of validation observations.
		val_acs: numpy array of validation actions.
		val_rews: numpy array of per-sample weights.

	Returns:
		The validation loss as a Python float.
	"""
	obs_t, acs_t, rews_t = (torch.from_numpy(arr) for arr in (val_obs, val_acs, val_rews))
	with torch.no_grad():
		return loss(prior, obs_t, acs_t, rews_t).mean().item()
def get_q_vals(args, data, env):
	"""Fit a twin Q-network on an offline dataset and return min(Q1, Q2) per row.

	The critic is trained with a SARSA-style target: the bootstrap action for
	row ``i`` is the dataset action at row ``i + 1`` (row 0 for the last row),
	not an action from a learned policy. Bootstrap values come from a target
	network that tracks the critic by Polyak averaging.

	Args:
		args: namespace providing ``n_epoch_qfn`` (number of passes over the data).
		data: dict with ``observations``, ``actions``, ``next_observations``,
			``rewards`` and ``terminals`` arrays (``train`` passes the output of
			``d4rl.qlearning_dataset``).
		env: gym environment; only its observation/action dimensions are used.

	Returns:
		Squeezed numpy array of min(Q1, Q2) from the trained online critic for
		every (observation, action) row of ``data``.
	"""
	obs, acs, next_obss, rs, terms = (
		data['observations'], data['actions'], data['next_observations'],
		data['rewards'], data['terminals'],
	)
	# NOTE(review): rows are consecutive only within an episode; at episode
	# boundaries acs[i + 1] belongs to the next episode. Terminal rows are masked
	# out below, but time-limit boundaries are not.
	next_acs = np.roll(acs, -1, axis=0)
	n = len(obs)
	inds = np.arange(n)

	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	obs_dim, act_dim = env.observation_space.shape[0], env.action_space.shape[0]
	critic = QNetwork(obs_dim, act_dim, 128).to(device)
	critic_target = QNetwork(obs_dim, act_dim, 128).to(device)
	hard_update(critic_target, critic)
	gamma = 0.99
	tau = 0.005  # Polyak coefficient for the target network

	optim = torch.optim.Adam(critic.parameters(), lr=3e-4)
	batch_size = 256
	for epoch in range(args.n_epoch_qfn):
		np.random.shuffle(inds)  # fresh minibatch composition every epoch
		losses = []
		for start_pos in range(0, n, batch_size):
			this_inds = inds[start_pos:start_pos + batch_size]
			s = torch.from_numpy(obs[this_inds]).to(device).float()
			a = torch.from_numpy(acs[this_inds]).to(device).float()
			next_a = torch.from_numpy(next_acs[this_inds]).to(device).float()
			next_s = torch.from_numpy(next_obss[this_inds]).to(device).float()
			r = torch.from_numpy(rs[this_inds]).to(device).float().unsqueeze(1)
			# 1.0 - term works whether terminals are stored as bool or float.
			not_done = 1.0 - torch.from_numpy(terms[this_inds]).to(device).float().unsqueeze(1)
			with torch.no_grad():
				# Bootstrap from the slowly-moving target network, not the online critic.
				qf1_next_target, qf2_next_target = critic_target(next_s, next_a)
				min_qf_next_target = torch.min(qf1_next_target, qf2_next_target)
				next_q_value = r + not_done * gamma * min_qf_next_target

			qf1, qf2 = critic(s, a)
			qf1_loss = F.mse_loss(qf1, next_q_value)
			qf2_loss = F.mse_loss(qf2, next_q_value)
			# One backward/step on the summed loss. Stepping after qf1_loss and then
			# back-propagating qf2_loss through the earlier graph can trip autograd's
			# in-place check and advances Adam twice per batch.
			optim.zero_grad()
			(qf1_loss + qf2_loss).backward()
			optim.step()
			soft_update(critic_target, critic, tau)
			losses.append(qf1_loss.item())
		print(f"EPOCH {epoch} LOSS {np.mean(losses)}")
	return _batched_min_q(critic, obs, acs, device).numpy().squeeze()


def _batched_min_q(critic, obs, acs, device, batch_size=4096):
	# Evaluate min(Q1, Q2) over the whole dataset in chunks to bound device memory.
	chunks = []
	with torch.no_grad():
		for start in range(0, len(obs), batch_size):
			s = torch.from_numpy(obs[start:start + batch_size]).to(device).float()
			a = torch.from_numpy(acs[start:start + batch_size]).to(device).float()
			q1, q2 = critic(s, a)
			chunks.append(torch.min(q1, q2).cpu())
	if not chunks:
		return torch.empty(0)
	return torch.cat(chunks)

def train_prior(prior, num_epochs, obss, acss, traj_rews, train_inds, val_inds, batch_size=256):
	"""Fit ``prior`` by minibatch gradient descent on the weighted NLL ``loss``.

	Minibatches are taken in the order given by ``train_inds`` (no reshuffling
	here). After every epoch the loss on ``val_inds`` is computed via
	``evaluate`` and printed together with the mean training loss.

	Args:
		prior: model exposing ``optim`` (an optimizer over its parameters) and
			accepted by ``loss``.
		num_epochs: number of passes over ``train_inds``.
		obss, acss, traj_rews: numpy arrays of observations, actions and
			per-sample weights, indexed by ``train_inds`` / ``val_inds``.
		train_inds, val_inds: index arrays selecting the two splits.
		batch_size: minibatch size.
	"""
	n_train = len(train_inds)
	print(f"Training Prior with {n_train} samples")
	for epoch in range(num_epochs):
		batch_losses = []
		for lo in range(0, n_train, batch_size):
			batch = train_inds[lo:lo + batch_size]
			batch_obs, batch_rews, batch_acs = (
				torch.from_numpy(arr[batch]).float()
				for arr in (obss, traj_rews, acss)
			)
			prior.optim.zero_grad()
			batch_loss = loss(prior, batch_obs, batch_acs, batch_rews)
			batch_loss.backward()
			prior.optim.step()
			batch_losses.append(batch_loss.item())
		val_loss = evaluate(prior, obss[val_inds], acss[val_inds], traj_rews[val_inds])
		print("EPOCH: ", epoch, "TRAIN LOSS: ", np.mean(batch_losses), "VAL LOSS: ", val_loss)

def rollout_prior(p, env, num_episodes=100):
	"""Execute the prior's mean action in ``env`` and return the average return.

	Args:
		p: prior model; ``p(obs)`` is unpacked as ``(mean, std)`` and the mean is
			executed as the action.
		env: gym-style environment with the 4-tuple ``step`` API
			(``obs, reward, done, info``).
		num_episodes: number of episodes to average over.

	Returns:
		Mean undiscounted episode return over ``num_episodes`` episodes.
	"""
	returns = []
	# Pure inference: no autograd graph needs to be built per step.
	with torch.no_grad():
		for _ in range(num_episodes):
			o = env.reset()
			done = False
			ep_ret = 0
			while not done:
				# Training feeds observations as float32; env observations may be float64.
				ac, _ = p(torch.from_numpy(o).float())
				o, rew, done, _ = env.step(ac.cpu().numpy())
				ep_ret += rew
			returns.append(ep_ret)
	return np.mean(returns)

def train(args):
	"""Train a Q-weighted behaviour prior on a D4RL dataset and save it.

	1. Create ``args.save_path`` (fails if it already exists).
	2. Load ``args.env_name`` and its D4RL q-learning dataset.
	3. Fit a Q-function (``get_q_vals``) and turn the Q-values into per-sample
	   weights ``exp(Q / scale)``.
	4. Split the data 90/10 into disjoint train/validation sets and fit a
	   ``NewPriorModel`` on it.
	5. Print the prior's average rollout return and save it to ``args.save_path``.
	"""
	# Create the output directory up front so a bad or existing path fails before
	# training rather than after it; makedirs also creates missing parents.
	os.makedirs(args.save_path)

	env_name = args.env_name
	env = gym.make(env_name)
	data = d4rl.qlearning_dataset(env.env)
	# NOTE: these are Q(s, a) estimates, not advantages (no baseline is subtracted).
	advs = np.asarray(get_q_vals(args, data, env))
	obss = data['observations']
	acss = data['actions']

	# Scale Q-values before exponentiating. Dividing by a non-positive maximum
	# would invert the weight ordering (or divide by zero), so fall back to the
	# largest magnitude in that case.
	scale = advs.max()
	if scale <= 0:
		scale = np.abs(advs).max()
	if scale == 0:
		scale = 1.0
	normalized_advs = np.exp(advs / scale)

	inds = np.random.permutation(len(obss))
	num_train = int(len(inds) * 0.9)
	# Disjoint split so the validation loss actually measures generalisation.
	train_inds, val_inds = inds[:num_train], inds[num_train:]
	p = NewPriorModel(env.observation_space.shape[0], env.action_space.shape[0], 256)
	train_prior(p, args.n_epoch_prior, obss, acss, normalized_advs, train_inds, val_inds, batch_size=256)
	print("Average rollout of prior", rollout_prior(p, env))
	p.save(args.save_path)

if __name__ == '__main__':
	import argparse
	parser = argparse.ArgumentParser(description="Train a Q-weighted behaviour prior on a D4RL dataset.")
	parser.add_argument('env_name', type=str, help="D4RL environment id, e.g. halfcheetah-medium-replay-v0")
	parser.add_argument('--n_epoch_qfn', type=int, default=30, help="epochs of Q-function training")
	parser.add_argument('--n_epoch_prior', type=int, default=50, help="epochs of prior training")
	# Must be a filesystem path: the prior is saved into this newly created directory.
	parser.add_argument('--save_path', type=str, required=True, help="directory to create and save the prior into")
	args = parser.parse_args()
	train(args)


# python3 train_prior.py 'halfcheetah-medium-replay-v0' --n_epoch_qfn 10 --n_epoch_prior 100 --save_path 'exp/adv_prior/test_script_cheetah'