from datetime import datetime
import argparse
import gym
gym.logger.set_level(40)
import numpy as np
import torch
import wandb
import pickle
import random
from diffuser.models.encoder_transformer import EncoderTransformer
from diffuser.training.encoder_trainer import EncoderTrainer
from d4rl.infos import REF_MIN_SCORE, REF_MAX_SCORE
import os
from tqdm import trange
import warnings
warnings.simplefilter(action='ignore', category=DeprecationWarning)
from tensorboardX import SummaryWriter
def seed(seed: int = 0):
    """Seed every random number generator this script draws from.

    Covers NumPy's global RNG, PyTorch (CPU and all CUDA devices) and the
    standard-library ``random`` module, in that order.

    Args:
        seed: value handed unchanged to each generator's seeding function.
    """
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

def _str2bool(value):
    """Parse a boolean command-line value.

    ``type=bool`` is a trap: ``bool('False')`` is ``True`` because every
    non-empty string is truthy, so such a flag could never be switched off from
    the command line. Accepts true/false, t/f, yes/no, y/n, 1/0
    (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: for any other string.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


parser = argparse.ArgumentParser()
parser.add_argument('--env_name', type=str, default='hopper-medium-expert')
parser.add_argument('--K', type=int, default=32)
parser.add_argument('--pct_traj', type=float, default=1.)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--embed_dim', type=int, default=128)
parser.add_argument('--n_layer', type=int, default=3)
parser.add_argument('--n_head', type=int, default=2)
parser.add_argument('--activation_function', type=str, default='relu')
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4)
parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
parser.add_argument('--warmup_steps', type=int, default=10000)
parser.add_argument('--num_eval_episodes', type=int, default=10)
parser.add_argument('--seed', type=int, default=3333)
parser.add_argument('--max_iters', type=int, default=1000)
parser.add_argument('--z_dim', type=int, default=16)
parser.add_argument('--foresee', type=int, default=200)
parser.add_argument('--num_steps_per_iter', type=int, default=100)
parser.add_argument('--device', type=str, default='cuda')
# Boolean switches: parsed with _str2bool so e.g. `--subepisode False` really
# disables the option (with type=bool any non-empty string meant True).
parser.add_argument('--train-no-change', type=_str2bool, default=True)
parser.add_argument('--eval-no-change', type=_str2bool, default=True)
parser.add_argument('--subepisode', type=_str2bool, default=True)
parser.add_argument('--phi_norm_loss_ratio', type=float, default=0.1)
parser.add_argument('--w_lr', type=float, default=0.01)
parser.add_argument('--repre_type', type=str, choices=['vec', 'dist', 'vq_vec'], default='vec')
args = parser.parse_args()


def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along axis 0.

    ``out[t] = x[t] + gamma * x[t+1] + gamma**2 * x[t+2] + ...``; with
    ``gamma=1`` this is the undiscounted return-to-go from every step.

    Args:
        x: array whose first axis is time (e.g. per-step rewards).
        gamma: discount factor.

    Returns:
        Array with the same shape and dtype as ``x`` (``np.zeros_like``).
        An empty input yields an empty output instead of raising IndexError.
    """
    # Named `out` rather than reusing the function's own name for the local.
    out = np.zeros_like(x)
    if x.shape[0] == 0:
        return out
    out[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        out[t] = x[t] + gamma * out[t + 1]
    return out

# Experiment configuration as a plain dict. Seed every RNG before anything
# random (env reset, data sampling, model init) happens.
variant=vars(args)
seed(variant['seed']) # --seed (default 3333)
device = variant.get('device', 'cuda')

env_name = variant['env_name']
# Episode-length cap: get_batch asserts window timesteps stay below it (and
# clamps them to max_ep_len - 1); it is also passed to the encoder.
max_ep_len = 1000
scale = 1000.  # normalization for rewards/returns
# In this script the environment is only queried for the observation/action
# sizes; no rollouts happen here. Requires a Gym version whose reset()
# accepts a `seed=` keyword.
env = gym.make(f'{env_name}-v2')
env.reset(seed=variant['seed'])
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

# load dataset
# The dataset is a pickled list of trajectory dicts read from
# <dir_path>/data/<env_name>-v2.pkl. This script reads the keys
# 'observations', 'actions' and 'rewards' from each dict.
# NOTE(review): pickle.load can execute arbitrary code; only load dataset
# files you produced yourself or otherwise trust.
dir_path = variant.get('dirpath', '.') # current path
# NOTE: the argument parser defines no --dirpath option, so dir_path is always
# '.' unless `variant` gets populated some other way.
# dir_path = variant.get('dirpath', './PrefDiffuser')
dataset_path = f'{dir_path}/data/{env_name}-v2.pkl'
with open(dataset_path, 'rb') as f:
    trajectories = pickle.load(f)

# save all path information into separate lists
# Per trajectory: its states, its length in steps, and its undiscounted return.
states, traj_lens, returns = [], [], []
for path in trajectories:
    states.append(path['observations'])
    traj_lens.append(len(path['observations']))
    returns.append(path['rewards'].sum())
traj_lens, returns = np.array(traj_lens), np.array(returns)

# used for input normalization
# Per-dimension mean/std over every state in the dataset. The + 1e-6 keeps a
# constant dimension from dividing by zero when get_batch normalizes states.
states = np.concatenate(states, axis=0)
state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

num_timesteps = sum(traj_lens)

# Summary of the loaded dataset.
print('=' * 50)
print(f'Starting new experiment: {env_name}')
print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
print('=' * 50)

K = variant['K'] # window length sampled by get_batch (default 32)
batch_size = variant['batch_size'] # 64
num_eval_episodes = variant['num_eval_episodes'] # 10; not read again in this script
pct_traj = variant.get('pct_traj', 1.) # 1. = use every trajectory

z_dim = variant['z_dim'] # encoder output size (default 16)
print(f'z_dim is: {z_dim}')
print(f"reward foresee is: {variant['foresee']}")

# D4RL reference max/min scores for this task; this script only prints them.
expert_score = REF_MAX_SCORE[f"{variant['env_name']}-v2"]
random_score = REF_MIN_SCORE[f"{variant['env_name']}-v2"]
print(f"max score is: {expert_score}, min score is {random_score}")

# only train on top pct_traj trajectories (for %BC experiment)
# Greedily keep the highest-return trajectories while their combined length
# stays within pct_traj * (total timesteps). The single best trajectory is
# always kept, even if it alone exceeds that budget.
num_timesteps = max(int(pct_traj*num_timesteps), 1)
sorted_inds = np.argsort(returns)  # lowest to highest
num_trajectories = 1
timesteps = traj_lens[sorted_inds[-1]]
ind = len(trajectories) - 2  # next-best trajectory; walk downward by return
while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps:
    timesteps += traj_lens[sorted_inds[ind]]
    num_trajectories += 1
    ind -= 1
# Indices of the kept trajectories, still ordered from lowest to highest return.
sorted_inds = sorted_inds[-num_trajectories:]

# used to reweight sampling so we sample according to timesteps instead of trajectories
p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds])

def get_batch(batch_size=256, max_len=K):
    """Sample a batch of fixed-length windows from the selected trajectories.

    Trajectories are drawn with replacement from ``sorted_inds`` with
    probability ``p_sample`` (proportional to length). Within each one, a start
    index ``si`` is drawn uniformly. Windows shorter than ``max_len`` are
    right-padded.

    The return target depends on ``variant``:
      * ``train_no_change`` and not ``subepisode``: the whole trajectory's
        return, repeated across the window;
      * ``train_no_change`` and ``subepisode``: the sum of rewards
        ``[si, si + foresee)``, repeated across the window;
      * otherwise: the per-step return-to-go from each step to the trajectory
        end. A window touching the end gets a terminal 0 appended.

    Args:
        batch_size: number of windows to sample.
        max_len: window length; defaults to the module-level ``K``.

    Returns:
        ``(s, a, rtg, timesteps, mask)`` tensors on ``device``:
          * ``s``: (batch, max_len, state_dim) float32. Zero-padded, then
            normalized with ``state_mean``/``state_std``.
          * ``a``: (batch, max_len, act_dim) float32, padded with -10.
          * ``rtg``: (batch, max_len + 1, 1) float32, divided by ``scale``.
          * ``timesteps``: (batch, max_len) long.
          * ``mask``: (batch, max_len) float64, 1 for real steps and 0 for
            padding.
    """
    batch_inds = np.random.choice(np.arange(num_trajectories), size=batch_size, replace=True, p=p_sample)
    # Loop-invariant configuration.
    train_no_change = variant['train_no_change']
    subepisode = variant['subepisode']
    foresee = variant['foresee']

    s, a, rtg, timesteps, mask = [], [], [], [], []
    for traj_ind in batch_inds:
        traj = trajectories[int(sorted_inds[traj_ind])]
        si = random.randint(0, traj['rewards'].shape[0] - 1)

        # Raw window starting at si; it may be shorter than max_len near the end.
        s.append(traj['observations'][si:si + max_len].reshape(1, -1, state_dim))
        a.append(traj['actions'][si:si + max_len].reshape(1, -1, act_dim))
        tlen = s[-1].shape[1]
        timesteps.append(np.arange(si, si + tlen).reshape(1, -1))
        assert not (timesteps[-1] >= max_ep_len).any()
        # Clamp only has an effect when asserts are stripped (python -O).
        timesteps[-1][timesteps[-1] >= max_ep_len] = max_ep_len - 1  # padding cutoff

        if train_no_change:
            # Constant target: one scalar return repeated tlen + 1 times.
            if not subepisode:
                rewards = traj['rewards'][0:]
            else:
                rewards = traj['rewards'][si:si + foresee]
            window_rtg = discount_cumsum(rewards, gamma=1.)[0].reshape(1, 1, 1).repeat(tlen + 1, axis=1)
        else:
            window_rtg = discount_cumsum(traj['rewards'][si:], gamma=1.)[:tlen + 1].reshape(1, -1, 1)
            if window_rtg.shape[1] <= tlen:
                # The window reaches the trajectory end, so there is no step
                # after it and its return-to-go is 0. Append that terminal
                # zero so every window has tlen + 1 targets.
                window_rtg = np.concatenate([window_rtg, np.zeros((1, 1, 1))], axis=1)

        # Right-pad everything to max_len and normalize states/returns.
        pad = max_len - tlen
        rtg.append(np.concatenate([window_rtg, np.zeros((1, pad, 1))], axis=1) / scale)
        s[-1] = np.concatenate([s[-1], np.zeros((1, pad, state_dim))], axis=1)
        s[-1] = (s[-1] - state_mean) / state_std
        a[-1] = np.concatenate([a[-1], np.ones((1, pad, act_dim)) * -10.], axis=1)
        timesteps[-1] = np.concatenate([timesteps[-1], np.zeros((1, pad))], axis=1)
        mask.append(np.concatenate([np.ones((1, tlen)), np.zeros((1, pad))], axis=1))

    s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)
    a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
    rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)
    timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
    mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)
    return s, a, rtg, timesteps, mask

# Encoder network producing z_dim-sized outputs. Depth, head count and the
# AdamW learning rate / weight decay are taken from the CLI options
# --n_layer, --n_head, --learning_rate and --weight_decay, so those flags take
# effect. Their defaults (3, 2, 1e-4, 1e-4) match the values this script
# always used.
en_model = EncoderTransformer(
    state_dim=state_dim,
    act_dim=act_dim,
    hidden_size=variant['embed_dim'],
    output_size=z_dim,
    max_ep_len=max_ep_len,
    repre_type=variant['repre_type'],
    num_hidden_layers=variant['n_layer'],
    num_attention_heads=variant['n_head'],
    intermediate_size=4*variant['embed_dim'],
    max_position_embeddings=1024,
    hidden_act=variant['activation_function'],
    hidden_dropout_prob=variant['dropout'],
    attention_probs_dropout_prob=variant['dropout'],
)
en_model = en_model.to(device=device)
et_optimizer = torch.optim.AdamW(
    en_model.parameters(),
    lr=variant['learning_rate'],
    weight_decay=variant['weight_decay'],
)

# Learnable z_dim vector `w`, optimized by its own AdamW (lr = --w_lr).
# For the 'dist' representation, w_std (initialized as 0.01 * N(0, 1)) is
# learned too. Otherwise w_std is a fixed tensor that is still handed to the
# trainer.
w = torch.randn(z_dim).to(device=device)
w.requires_grad = True
w_std = 0.01 * torch.randn(z_dim).to(device=device)
if (variant['repre_type'] == 'vec') or (variant['repre_type'] == 'vq_vec'):
    w_optimizer = torch.optim.AdamW([w], lr=variant["w_lr"], weight_decay=1e-4)
elif variant['repre_type'] == 'dist':
    w_std.requires_grad = True
    w_optimizer = torch.optim.AdamW([w, w_std], lr=variant["w_lr"], weight_decay=1e-4)
else:
    # argparse `choices` normally prevents this; fail loudly instead of a
    # NameError on w_optimizer further down.
    raise ValueError(f"unsupported repre_type: {variant['repre_type']!r}")

# Optional: resume from an earlier checkpoint instead of starting from scratch.
# Checkpoints written by this script are tuples of
# (encoder state_dict, [w]) or (encoder state_dict, [w, w_std]) for 'dist'.
# NOTE(review): the two w/w_std lines below rebind the names *after*
# w_optimizer was built over the original tensors, so the optimizer would keep
# updating the old ones. Copying in place (e.g. w.data.copy_(...)) looks safer;
# confirm before using.
# encoder_path = f'saved_models/encoder_dist/{env_name}-3333-20231012020348/params_100.pt'
# saved_model = torch.load(os.path.join(dir_path, encoder_path), map_location=device)
# en_model.load_state_dict(saved_model[0])
# w = saved_model[1][0]
# w_std = saved_model[1][1]

# The trainer receives the encoder, both optimizers, the w / w_std tensors and
# get_batch as its data source.
trainer = EncoderTrainer(
    en_model=en_model,
    et_optimizer=et_optimizer,
    w=w,
    w_std=w_std,
    w_optimizer=w_optimizer,
    batch_size=batch_size,
    get_batch=get_batch,
    device=device,
    repre_type=variant['repre_type'],
    phi_norm_loss_ratio=variant["phi_norm_loss_ratio"]
)

# Run name: <env>-<seed>-<timestamp>, shared by the TensorBoard log directory
# and the checkpoint folder.
t = datetime.now().strftime('%Y%m%d%H%M%S')
name = f"{env_name}-{variant['seed']}-{t}"
repre_type = variant['repre_type']

writer = SummaryWriter(f'./logs/train_encoder-{repre_type}-{name}')
# wandb.init(name=f'train_encoder-{repre_type}-{t}', group=f'{env_name}', project='PrefDiffuser', config=variant)

folder = f"{dir_path}/saved_models/encoder_{repre_type}/{name}"
# makedirs: the parent saved_models/encoder_<type> directory may not exist yet.
os.makedirs(folder, exist_ok=True)
# NOTE(review): assumes the trainer updates w / w_std in place; if it replaced
# them with new tensors, these saved references would go stale.
save_w = [w, w_std] if variant['repre_type'] == 'dist' else [w]
save_every = 200
max_iters = variant['max_iters']


def _save_checkpoint(step):
    """Write (encoder state_dict, save_w) to <folder>/params_<step>.pt."""
    # os.path.join keeps absolute dir_path values working ("./" + "/abs" would not).
    torch.save((en_model.state_dict(), save_w), os.path.join(folder, f"params_{step}.pt"))


# params_<N>.pt always holds the model after N completed training iterations;
# params_0.pt is the untrained initialization and is never overwritten.
_save_checkpoint(0)
for it in trange(max_iters, desc='train_iteration'):
    step = it + 1  # number of completed iterations after this call
    outputs = trainer.train_iteration(num_steps=variant['num_steps_per_iter'], iter_num=step, print_logs=True)
    # wandb.log(outputs)
    for key, value in outputs.items():
        writer.add_scalar(key, value, global_step=it)
    if repre_type == 'dist':
        # wandb.log({'w': trainer.w.mean(), 'w_std': trainer.w_std.mean()})
        writer.add_scalar('training/w', trainer.w.mean().item(), global_step=it)
        writer.add_scalar('training/w_std', trainer.w_std.mean().item(), global_step=it)
    if step % save_every == 0:
        _save_checkpoint(step)
# Final checkpoint, unless the loop already wrote it on its last iteration.
if max_iters % save_every != 0:
    _save_checkpoint(max_iters)
# Flush pending TensorBoard events to disk.
writer.close()