import argparse
from datetime import datetime
import gym
import numpy as np
import torch
import pickle
import random
from d4rl.infos import REF_MIN_SCORE, REF_MAX_SCORE
import os
import wandb
from tqdm import trange
from diffuser.utils.arrays import to_np
from diffuser.models.diffusion import GaussianInvDynDiffusion
from diffuser.models.temporal import TemporalUnet, AttTemporalUnet, TransformerNoise
from diffuser.models.encoder_transformer import EncoderTransformer
from diffuser.models.state_encoder_transformer import StateEncoderTransformer
from diffuser.training.all_trainer import AllTrainer
import warnings
warnings.simplefilter(action='ignore', category=DeprecationWarning)
from tensorboardX import SummaryWriter
import collections
from env import get_envs
from typing import Callable

# Command-line configuration. The parsed namespace is converted to the plain
# dict `variant`, which the rest of the script indexes by option name.
parser = argparse.ArgumentParser()
# Environment name without the version suffix; '-v2' is appended when the env is created.
parser.add_argument('--env_name', type=str, default='hopper-medium-expert')
# Sequence horizon: passed as `horizon` to the dataset, noise predictor and
# diffusion model, and as `max_length` to the encoder.
parser.add_argument('--K', type=int, default=20)
# Batch size of the training DataLoader (also handed to the trainer).
parser.add_argument('--batch_size', type=int, default=32)
# Learning rate of the AdamW optimizer over the diffusion model parameters.
parser.add_argument('--learning_rate', '-lr', type=float, default=2e-4)
# Seed applied to the Python/NumPy/PyTorch RNGs and to the first env.reset().
parser.add_argument('--seed', type=int, default=100)
# Number of outer training iterations; each runs --num_steps_per_iter trainer steps.
parser.add_argument('--max_iters', type=int, default=2000)
# Size of the vectors `w`/`w_std` and of the encoder output; also passed as `z_dim` to the models.
parser.add_argument('--z_dim', type=int, default=16)
parser.add_argument('--num_steps_per_iter', type=int, default=1000)
# Device that models and tensors are moved to.
parser.add_argument('--device', type=str, default='cuda')
# Forwarded to GaussianInvDynDiffusion as `condition_guidance_w` and `n_timesteps`.
parser.add_argument('--condition_guidance_w', type=float, default=1.2)
parser.add_argument('--n_timesteps', type=int, default=200)
# Forwarded as `repre_type` to the models and trainer; with 'none', evaluation
# conditions on a fixed scalar instead of `w`.
parser.add_argument('--repre_type', type=str, choices=['vec', 'dist', 'none'], default='dist')
# Forwarded to AllTrainer as `phi_norm_loss_ratio`.
parser.add_argument('--phi_norm_loss_ratio', type=float, default=0.1)
# Learning rate of the AdamW optimizer over `w` and `w_std`.
parser.add_argument('--w_lr', type=float, default=0.01)
# When set, the diffusion model is built with predict_epsilon=False.
parser.add_argument('--predict_x', action='store_true', default=False)
# When set, TransformerNoise is used as the noise predictor instead of TemporalUnet.
parser.add_argument('--use_transformer', action='store_true', default=False)
# Forwarded as `info_loss_weight` to both the diffusion model and the trainer.
parser.add_argument('--info_loss_weight', type=float, default=0.1)
# When set, returns_scale is 1 instead of the per-environment constant.
parser.add_argument('--no_scale', action='store_true', default=False)
# Forwarded to AllTrainer as `sg` (stop gradients).
parser.add_argument('--sg', action='store_true', default=False)
# Forwarded to SequenceDataset as `normalize`; evaluation then normalizes
# observations and un-normalizes actions through dataset.normalizer.
parser.add_argument('--normalize', action='store_true', default=False)
# Forwarded to SequenceDataset as `normalize_name`.
parser.add_argument('--normalize_name', type=str, default='CDF')
# Forwarded to GaussianInvDynDiffusion as `pw`.
parser.add_argument('--pw', type=str, choices=['gaussian', 'average', 'respective'], default='respective')
args = parser.parse_args()
variant = vars(args)

def seed(seed: int = 0):
    """Seed every random number generator this script relies on.

    Covers Python's ``random`` module, NumPy's global RNG, PyTorch's default
    generator and the generators of all CUDA devices.

    Args:
        seed: Integer seed applied to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
# Seed all RNGs before the environment, dataset and models are created below.
seed(variant['seed']) # --seed defaults to 100

def cycle(dl):
    """Yield the items of ``dl`` endlessly.

    Each time ``dl`` is exhausted it is iterated again from the start, so the
    returned generator never stops on its own.
    """
    while True:
        yield from dl

# Runtime settings and the environment whose offline dataset is used for training.
device = variant.get('device', 'cuda')
sg = variant['sg'] # stop gradients
env_name = variant['env_name']
# The '-v2' suffix is appended here, so --env_name must be given without a version.
env = gym.make(f'{env_name}-v2')
env.reset(seed=variant['seed'])
max_ep_len = 1000
# Observation/action sizes, read from the first axis of the env's spaces.
state_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]

from data.d4rl import get_dataset
# Raw trajectories for `env`, as returned by the project's D4RL loader.
dataset = get_dataset(env, max_traj_length=1000, include_next_obs=False, termination_penalty=0)
from data.sequence import SequenceDataset

# Constant handed to SequenceDataset as `returns_scale`. It is 1 with
# --no_scale, otherwise a per-environment-family value chosen by substring
# match on --env_name.
if variant['no_scale']:
    returns_scale = 1
elif 'halfcheetah' in variant['env_name']:
    returns_scale = 1200
elif 'walker' in variant['env_name']:
    returns_scale = 550  # disabled alternative: 550 if variant['K'] == 100 else 110
elif 'hopper' in variant['env_name']:
    returns_scale = 400  # disabled alternative: 400 if variant['K'] == 100 else 80
else:
    # Without this branch `returns_scale` would stay undefined and the
    # SequenceDataset call below would fail with an opaque NameError.
    raise ValueError(
        f"No returns_scale is defined for env_name={variant['env_name']!r}; "
        "expected a halfcheetah, walker or hopper task, or pass --no_scale."
    )

normalize_name = variant['normalize_name']
# Wrap the trajectories into fixed-horizon (K) sequences with returns attached;
# `dataset` is rebound to this wrapper and later used for its `.normalizer`.
dataset = SequenceDataset(dataset, horizon=variant['K'], max_traj_length=1000, include_returns=True,
                          returns_scale=returns_scale, discount=0.99, normalize=variant['normalize'],
                          use_padding=True, normalize_name=normalize_name)
# Randomly ordered sampling over the sequence dataset. Wrapping the DataLoader
# in cycle() turns it into an endless batch iterator that restarts whenever
# the underlying loader is exhausted.
data_sampler = torch.utils.data.RandomSampler(dataset)
dataloader = cycle(torch.utils.data.DataLoader(
        dataset,
        sampler=data_sampler,
        batch_size=variant['batch_size'],
        drop_last=True,
        num_workers=8,
    ))

print('=' * 50)
print(f'Starting new experiment: {env_name}')
print('=' * 50)
# Short module-level aliases for frequently used options.
pw = variant['pw']
K = variant['K']
batch_size = variant['batch_size']
z_dim = variant['z_dim'] # --z_dim defaults to 16
print(f'z_dim is: {z_dim}')
repre_type = variant['repre_type']
# D4RL reference returns for this task; eval_episodes uses them to rescale raw
# episode returns to (return - random) / (expert - random) * 100.
expert_score = REF_MAX_SCORE[f"{variant['env_name']}-v2"]
random_score = REF_MIN_SCORE[f"{variant['env_name']}-v2"]

# Separate environment instances that are rolled out side by side in each evaluation.
n_eval = 10
eval_envs = [gym.make(f'{env_name}-v2') for _ in range(n_eval)]

def eval_episodes(model, phi):
    """Roll out the ``n_eval`` evaluation environments in lockstep and report returns.

    At every step the current observations of all environments are passed as
    conditions to ``model.conditional_sample`` together with the conditioning
    tensor (as ``returns``). The first two states of each sampled sequence are
    concatenated and fed to ``model.inv_model`` to obtain the action to execute.
    Environments that have finished are no longer stepped; the rollout stops
    after 1000 steps or as soon as every environment is done.

    Args:
        model: Object exposing ``conditional_sample(conditions, returns=...)``
            and ``inv_model(...)`` (the script passes ``trainer.ema_model``).
        phi: Conditioning tensor (the script passes ``w`` of shape
            ``(z_dim,)``); it is repeated once per environment. Ignored when
            ``repre_type == 'none'``, where a fixed per-dataset scalar is used.

    Returns:
        dict with ``target_return_mean`` (mean raw return over environments)
        and ``target_norm_return_mean`` / ``target_norm_return_std`` (mean and
        std of the D4RL-normalized score).
    """
    use_normalizer = variant['normalize']

    if repre_type == 'none':
        # No learned representation: condition on a fixed scalar per dataset,
        # falling back to 0.9 for datasets without a dedicated entry.
        fixed_targets = {
            'walker2d-medium-replay': 0.65,
            'hopper-medium': 0.8,
            'hopper-medium-replay': 0.85,
        }
        target = fixed_targets.get(env_name, 0.9)
        phi = torch.tensor(target).to(device=device).unsqueeze(0).repeat(n_eval, 1)
    else:
        phi = phi.unsqueeze(0).repeat(n_eval, 1)

    episode_totals = [0.] * n_eval
    dones = [False] * n_eval

    obs_list = [eval_env.reset() for eval_env in eval_envs]
    if use_normalizer:
        # The initial observations must go through the same normalizer as every
        # later observation; otherwise the first model call sees raw inputs.
        obs_list = [dataset.normalizer.normalize(obs, 'observations') for obs in obs_list]

    for _ in trange(1000, desc='eval'):
        observation = np.array(obs_list)
        conditions = torch.from_numpy(observation).to(device=device, dtype=torch.float32)
        # Pure inference: keep both the sampler and the inverse-dynamics model
        # out of autograd so no graph is built.
        with torch.no_grad():
            samples = model.conditional_sample(conditions, returns=phi)  # conditioned on (s_t, R)
            obs_comb = torch.cat([samples[:, 0, :], samples[:, 1, :]], dim=-1)  # [s0, s1]
            obs_comb = obs_comb.reshape(-1, 2 * state_dim)
            action = model.inv_model(obs_comb)  # action from the inverse dynamics model
        action = to_np(action)
        if use_normalizer:
            action = dataset.normalizer.unnormalize(action, "actions")

        for i, eval_env in enumerate(eval_envs):
            if dones[i]:
                continue
            next_obs, reward, done, _ = eval_env.step(action[i])
            dones[i] = done
            episode_totals[i] += reward
            if use_normalizer:
                next_obs = dataset.normalizer.normalize(next_obs, 'observations')
            obs_list[i] = next_obs
            if done:
                # Report each finished env once rather than on every remaining step.
                print(f'Env {i} Done.')

        if all(dones):
            print('All Env Done.')
            break

    episode_returns = np.array(episode_totals)
    norm_ret = (episode_returns - random_score) / (expert_score - random_score) * 100
    print('norm_score: ', norm_ret, np.mean(norm_ret))
    return {
        'target_return_mean': np.mean(episode_returns),
        'target_norm_return_mean': np.mean(norm_ret),
        'target_norm_return_std': np.std(norm_ret),
    }

# Noise-prediction backbone: a transformer when --use_transformer is set,
# otherwise a temporal U-Net. Both are built for sequences of K states.
if variant['use_transformer']:
    noise_predictor = TransformerNoise(horizon=K, obs_dim=state_dim, z_dim=z_dim)
else:
    noise_predictor = TemporalUnet(horizon=K, transition_dim=state_dim, cond_dim=state_dim,
                                   dim_mults=(1,4,8),
                                   repre_type=repre_type,
                                   z_dim = z_dim,
                                  )
# Diffusion model wrapping the backbone. eval_episodes relies on its
# `conditional_sample` and `inv_model` members. --predict_x turns off epsilon
# prediction (predict_epsilon=False).
model = GaussianInvDynDiffusion(noise_predictor, horizon=K, observation_dim=state_dim, action_dim=act_dim,
                                condition_guidance_w=variant['condition_guidance_w'],
                                n_timesteps=variant['n_timesteps'],
                                predict_epsilon=not variant['predict_x'],
                                info_loss_weight=variant['info_loss_weight'],
                                repre_type=variant['repre_type'],
                                z_dim = z_dim,
                                pw = pw,
                               ).to(device=device)
# Optimizer for all parameters of the diffusion model (including the backbone passed in above).
optimizer = torch.optim.AdamW(model.parameters(), lr=variant['learning_rate'])

# Transformer encoder whose output size equals z_dim and whose maximum
# sequence length equals the horizon K; handed to the trainer as `en_model`.
en_model = StateEncoderTransformer(
    state_dim=state_dim,
    act_dim=act_dim,
    hidden_size=128,
    output_size=z_dim,
    max_length=K,
    max_ep_len=max_ep_len,
    num_hidden_layers=3,
    num_attention_heads=2,
    intermediate_size=4*128,
    max_position_embeddings=1024,
    repre_type=variant['repre_type'],
    hidden_act='relu',
    hidden_dropout_prob=0.1,
    attention_probs_dropout_prob=0.1,
)
en_model = en_model.to(device=device)
# The encoder has its own optimizer with a fixed learning rate, independent of --learning_rate.
et_optimizer = torch.optim.AdamW(en_model.parameters(), lr=1e-4, weight_decay=1e-4)

# Two free, trainable z_dim-sized vectors initialised from a standard normal.
# They are sampled on the CPU and then moved to `device`; requires_grad is
# enabled afterwards so that the on-device tensors are the ones being optimised.
# `w` is also the conditioning vector passed to eval_episodes.
# NOTE(review): `w_std` presumably parameterises a spread around `w` — confirm in AllTrainer.
w = torch.randn(z_dim).to(device=device)
w.requires_grad = True
w_std = torch.randn(z_dim).to(device=device)
w_std.requires_grad = True
# Dedicated optimizer for the two vectors, using --w_lr.
w_optimizer = torch.optim.AdamW([w, w_std], lr=variant["w_lr"], weight_decay=1e-4)


# Single trainer that receives the encoder, the diffusion model, the trainable
# vectors `w`/`w_std` and one optimizer for each of the three. `get_batch` is
# the endless batch iterator built above. The training loop later reads
# `trainer.ema_model` for evaluation.
trainer = AllTrainer(
    en_model=en_model,
    de_model=model,
    optimizer=optimizer,
    batch_size=batch_size,
    get_batch=dataloader,
    device=device,
    et_optimizer=et_optimizer,
    w=w,
    w_std=w_std,
    w_optimizer=w_optimizer,
    repre_type=variant['repre_type'],
    phi_norm_loss_ratio=variant["phi_norm_loss_ratio"],
    info_loss_weight=variant['info_loss_weight'],
    sg=sg,
)

# Run identifiers: the timestamp keeps the log and checkpoint directories of
# repeated runs with the same settings apart.
t = datetime.now().strftime('%Y%m%d%H%M%S')
name = f"{variant['env_name']}-s{variant['seed']}-{t}"
net_name = 'tfm' if variant['use_transformer'] else 'unet'
info_loss_weight = variant['info_loss_weight']
condition_w = variant['condition_guidance_w']
supfix = f'{info_loss_weight}info-{condition_w}guide'
normalize = variant['normalize']
writer = SummaryWriter(f'./tunelogs/{repre_type}-{net_name}-{name}-{supfix}-sg{sg}-k{K}-b{batch_size}-normal{normalize}{normalize_name}-scale{returns_scale}-z_dim{z_dim}-pw-{pw}')

folder = f"./saved_models/all_model_{repre_type}/{name}"
os.makedirs(folder, exist_ok=True)
# torch.save((model.state_dict(), en_model.state_dict(), w), f"./{folder}/params_0.pt")


def _evaluate_and_log(step):
    """Evaluate the EMA model conditioned on ``w`` and log the metrics at ``step``."""
    # Always switch to eval mode first so the final evaluation is run under
    # the same conditions as the periodic ones.
    trainer.ema_model.eval()
    eval_outputs = eval_episodes(trainer.ema_model, w)
    for key, value in eval_outputs.items():
        writer.add_scalar(f'evaluation/{key}', value, global_step=step)


try:
    for iteration in trange(variant['max_iters'], desc='epoch'):
        # Periodic evaluation every 200 iterations, skipping the untrained model at 0.
        if iteration % 200 == 0 and iteration != 0:
            # if iteration % 1000 == 0:
                # torch.save((trainer.ema_model.state_dict(), en_model.state_dict(), w), f"{folder}/params_{iteration}.pt")
            _evaluate_and_log(iteration)

        outputs = trainer.train_iteration(num_steps=variant['num_steps_per_iter'], iter_num=iteration + 1, print_logs=True)
        for key, value in outputs.items():
            writer.add_scalar(key, value, global_step=iteration)

    # Final evaluation after the last training iteration.
    _evaluate_and_log(variant['max_iters'])
    # torch.save((trainer.ema_model.state_dict(), en_model.state_dict(), w), f"{folder}/params_{variant['max_iters']}.pt")
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()