import pickle
import numpy as np
from dataset import SequenceDataset, SARPairs
from diffusion import GaussianDiffusion
from temporal import TemporalUnet
from model import ResidualNet, RewardMLPNet
import torch
import gym
from trainer import Trainer
import argparse
from DiffEnv import DiffEnv
from evaluation import evaluation
import d4rl

from SimpleSAC.sac_main import sac_train

# Fallback dtype and device that ``to_torch`` uses when the caller passes none.
DTYPE = torch.float32
DEVICE = 'cuda'

def to_torch(x, dtype=None, device=None):
    """Convert ``x`` to a torch tensor with ``dtype`` on ``device``.

    Dicts (including subclasses such as ``OrderedDict``) are converted
    value-by-value and returned as a plain ``dict``. Existing tensors are
    moved and cast. Anything else is handed to ``torch.tensor``.

    Args:
        x: a tensor, a dict of convertible values, or any array-like accepted
            by ``torch.tensor``.
        dtype: target dtype; falls back to the module-level ``DTYPE``.
        device: target device; falls back to the module-level ``DEVICE``.

    Returns:
        A tensor, or a dict of tensors when ``x`` is a dict.
    """
    dtype = dtype or DTYPE
    device = device or DEVICE
    # isinstance (not ``type(x) is dict``) so dict subclasses are recursed into
    # instead of being passed to torch.tensor, which cannot handle them.
    if isinstance(x, dict):
        return {k: to_torch(v, dtype, device) for k, v in x.items()}
    if torch.is_tensor(x):
        return x.to(device).type(dtype)
    return torch.tensor(x, dtype=dtype, device=device)
def to_device(x, device='cuda'):
    """Move ``x`` to ``device``.

    Tensors are moved with ``Tensor.to``. Dicts (including subclasses) are
    processed recursively and returned as a plain ``dict``. Any other value
    (numpy array, scalar, ``None``, ...) is returned unchanged, so no batch
    field is silently replaced by ``None``.

    Args:
        x: a tensor, a (nested) dict of values, or any other object.
        device: target torch device.

    Returns:
        ``x`` moved to ``device`` where applicable.
    """
    if torch.is_tensor(x):
        return x.to(device)
    if isinstance(x, dict):
        return {k: to_device(v, device) for k, v in x.items()}
    # Previously fell through and returned None implicitly; pass through instead.
    return x
def batch_to_device(batch, device='cuda'):
    """Rebuild ``batch`` with each field passed through ``to_device``.

    ``batch`` is expected to be namedtuple-like: it must expose ``_fields``,
    and its type must accept the field values positionally. Fields are
    visited in ``_fields`` order, and the result has the same type as
    ``batch``.
    """
    moved_fields = (
        to_device(getattr(batch, name), device) for name in batch._fields
    )
    return type(batch)(*moved_fields)

def evaluate(trainer, env, max_steps=100, target_return=0.95, device='cuda'):
    """Roll out the trainer's EMA diffusion policy in ``env`` and sum rewards.

    At each step the EMA model samples conditioned on the current observation
    (condition key ``0``) and on ``target_return``. The first two sampled time
    steps are concatenated and fed to ``inv_model`` (presumably an
    inverse-dynamics head); its argmax is used as a discrete action.

    Args:
        trainer: object exposing ``ema_model`` with ``conditional_sample`` and
            ``inv_model``.
        env: environment whose ``reset()`` returns the observation as its
            first element, and whose ``step`` returns
            ``(obs, reward, terminated, truncated, info)``.
        max_steps: maximum number of environment steps (default 100).
        target_return: return value the sampler is conditioned on.
        device: torch device for the conditioning tensors.

    Returns:
        Sum of rewards collected, including the reward of the final step.
    """
    model = trainer.ema_model
    returns = torch.full((1, 1), target_return, device=device)
    obs = env.reset()[0]
    total_reward = 0.0
    # Pure inference: don't build autograd graphs across the rollout.
    with torch.no_grad():
        for _ in range(max_steps):
            # Add a leading batch dimension of 1, as float32.
            obs_t = torch.tensor(
                np.asarray(obs, dtype=np.float32)[None], device=device
            )
            samples = model.conditional_sample({0: obs_t}, returns=returns)
            # assumes samples is (batch, horizon, obs_dim) -- TODO confirm;
            # the concatenation is (batch, 2 * obs_dim) for any obs_dim.
            obs_comb = torch.cat([samples[:, 0, :], samples[:, 1, :]], dim=-1)
            action = model.inv_model(obs_comb).argmax().item()
            obs, reward, terminated, truncated, _ = env.step(action)
            # Accumulate before checking for episode end so the terminal
            # step's reward is counted.
            total_reward += reward
            if terminated or truncated:
                break
    return total_reward

def train_diffuser(configs):
    """Build a ``Trainer`` around a TemporalUnet from a pickled dataset.

    NOTE(review): this function is not called anywhere in this file and
    looks unfinished -- see the note on ``diffusion`` below.

    Args:
        configs: namespace-like object; reads ``env_name``, ``data_path`` and
            ``horizon``.

    Returns:
        The constructed ``Trainer``. Training is not started here.
    """
    env = gym.make(configs.env_name)
    # Security: pickle.load can execute arbitrary code -- only load trusted files.
    with open(configs.data_path, 'rb') as f:
        data = pickle.load(f)
    # Use the configured horizon for both dataset and model. Previously the
    # model read the script-level ``horizon`` global, which only exists when
    # this file runs as __main__ and may differ from configs.horizon.
    horizon = configs.horizon
    dataset = SequenceDataset(data, horizon=horizon)
    # Previously read the script-level ``observation_dim`` global. Derive it
    # from the env instead; assumes a flat Box observation space -- TODO confirm.
    observation_dim = env.observation_space.shape[0]
    model = TemporalUnet(horizon=horizon, transition_dim=observation_dim,
                         cond_dim=observation_dim).to(device=DEVICE)
    # NOTE(review): ``model`` is not wrapped in a diffusion object (the
    # GaussianInvDynDiffusion wrapper is commented out and not imported), so
    # ``diffusion`` below still resolves to the module-level global that is
    # only built under __main__. Wire ``model`` in once the wrapper exists.
    # diffusion = GaussianInvDynDiffusion(model=model,horizon=horizon,observation_dim=observation_dim,action_dim=action_dim)
    trainer = Trainer(diffusion_model=diffusion, dataset=dataset)
    return trainer


def return_range(dataset, max_episode_steps):
    """Return ``(min, max)`` undiscounted episode return in ``dataset``.

    Episodes are split at ``dataset['terminals']`` or after
    ``max_episode_steps`` transitions. A trailing incomplete episode is
    counted towards the step total but not towards the returns.

    Args:
        dataset: mapping with ``'rewards'`` and ``'terminals'`` sequences.
        max_episode_steps: episode length cap used to split timeouts.

    Returns:
        Tuple ``(min_return, max_return)`` over complete episodes.

    Raises:
        ValueError: if no complete episode exists, or if the episode lengths
            do not account for every reward (e.g. mismatched lengths).
    """
    returns, lengths = [], []
    ep_ret, ep_len = 0., 0
    for r, d in zip(dataset['rewards'], dataset['terminals']):
        ep_ret += float(r)
        ep_len += 1
        if d or ep_len == max_episode_steps:
            returns.append(ep_ret)
            lengths.append(ep_len)
            ep_ret, ep_len = 0., 0
    # The trailing incomplete trajectory is excluded from ``returns`` but its
    # steps are still counted so the consistency check below holds.
    lengths.append(ep_len)
    if sum(lengths) != len(dataset['rewards']):
        raise ValueError('rewards and terminals do not describe the same transitions')
    if not returns:
        raise ValueError('dataset contains no complete episode')
    return min(returns), max(returns)


def terminal_penalty(dataset, gamma):
    """Compute a backward-propagated terminal penalty for every transition.

    Walks the dataset backwards. Terminal steps get -1. A step ``k``
    transitions before the next terminal gets ``max(r + p, p)`` with
    ``p = -gamma**k``. Steps after the last terminal use ``p = 0``.
    NOTE(review): not referenced anywhere in this file.

    Args:
        dataset: mapping with ``'rewards'`` and ``'terminals'`` sequences.
        gamma: per-step decay applied to the penalty.

    Returns:
        numpy array with one value per transition.
    """
    n = len(dataset['terminals'])
    returns = np.zeros(n)
    last = 0
    for i in reversed(range(n)):
        if dataset['terminals'][i]:
            returns[i] = -1
            last = -1 * gamma
        else:
            returns[i] = max(dataset['rewards'][i] + last, last)
            last = last * gamma
    return returns


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, required=True, help='env_name:')
    parser.add_argument('--savename', type=str, default='',
                        help='save name forwarded to Trainer / sac_train')
    parser.add_argument('--trainrl', action='store_true')
    parser.add_argument('--trainrw', action='store_true')
    # NOTE(review): --iql, --cql and --awac are parsed but not read in this file.
    parser.add_argument('--iql', action='store_true')
    parser.add_argument('--cql', action='store_true')
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--awac', action='store_true')
    args = parser.parse_args()

    # An empty --savename becomes None.
    save_name = args.savename if args.savename else None
    # NOTE(review): train_steps and discount are not read anywhere in this file.
    train_steps = 1000
    discount = 0.99
    # Also used below as the diffusion model's n_timesteps.
    horizon = 5
    env_name = args.env
    env = gym.make(env_name)
    dataset = d4rl.qlearning_dataset(env)

    # Goal handed to DiffEnv; how it is used is defined there.
    if "antmaze" in env_name:
        target = env.target_goal
    elif "maze2d" in env_name:
        target = np.array([1, 1])
    else:
        target = None

    # Disabled by default. If enabled, this name is rebound from the bool flag
    # to the numeric reward scale that DiffEnv receives below.
    reward_normalization = False

    if reward_normalization:
        max_episode_steps = 1000
        min_ret, max_ret = return_range(dataset, max_episode_steps)
        ret_span = max_ret - min_ret
        if ret_span == 0:
            raise ValueError('all episodes have the same return; cannot normalize rewards')
        dataset['rewards'] /= ret_span
        dataset['rewards'] *= max_episode_steps
        reward_normalization = max_episode_steps / ret_span

    observation_dim = len(dataset['observations'][0])
    action_dim = len(dataset['actions'][0])
    dataset = SARPairs(dataset)

    t_dim = 32

    # 0 or 1: extra dimension added to the diffused vector (presumably a
    # reward channel when set to 1).
    use_reward_in_diffusion = 0

    # Input = noisy target + (observation, action) condition + time embedding.
    model = ResidualNet(input_dim=observation_dim + use_reward_in_diffusion + (observation_dim + action_dim) + t_dim,
                        output_dim=observation_dim + use_reward_in_diffusion,
                        t_dim=t_dim,
                        hidden_dim=1024,
                        depth=6,
                        condition_dropout=0.1).to(device=DEVICE)

    rw_model = RewardMLPNet(input_dim=observation_dim * 2 + action_dim, output_dim=2, hidden_dim=256).to(device=DEVICE)
    diffusion = GaussianDiffusion(model=model, input_dim=observation_dim + use_reward_in_diffusion, condition_dim=observation_dim + action_dim, n_timesteps=horizon).to(device=DEVICE)

    if args.trainrl:
        diffusion_env = DiffEnv(dataset, diffusion, rw_model, env_name, reward_normalization, target=target)
        sac_train(env_name, diffusion_env, save_name=save_name)
    elif args.eval:
        diffusion_env = DiffEnv(dataset, diffusion, rw_model, env_name, reward_normalization, target=target)
        evaluation(diffusion_env, env_name)
        # reward_test(diffusion_env, env_name)
    else:
        trainer = Trainer(diffusion_model=diffusion,
                          rw_model=rw_model,
                          dataset=dataset,
                          envname=env_name,
                          log_freq=1000,
                          save_freq=100000,
                          train_batch_size=1024,
                          train_lr=1e-4,
                          savename=save_name
                          )
        if args.trainrw:
            trainer.train_reward(1000000)
        else:
            trainer.train(1000000)
