"""
highly based on https://github.com/kzl/decision-transformer/blob/master/gym/experiment.py
"""

import gym
import numpy as np
import torch
import wandb
import math
import argparse
import pickle
import random
import sys
import os
import mujoco_py
import d4rl
from typing import Optional
from decision_transformer.evaluation.evaluate_episodes import  evaluate_episode_rtg_dist
from decision_transformer.models.decision_transformer_dist import DecisionTransformer_Dist
from decision_transformer.training.seq_trainer import SequenceTrainer


def _get_model_state_dict(model):
    """Return the state dict of ``model``, looking through a ``.module`` wrapper.

    If ``model`` exposes a ``module`` attribute, the state dict of that inner
    object is returned; otherwise ``model.state_dict()`` is returned directly.
    """
    inner = getattr(model, 'module', model)
    return inner.state_dict()


def save_checkpoint(
    path: str,
    model,
    optimizer,
    scheduler,
    iter_num: int,
    global_step: int,
    config_dict: dict,
    extra: Optional[dict] = None,
):
    """Serialize a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict with the keys ``model_state``, ``optimizer_state``,
    ``scheduler_state``, ``iter_num``, ``global_step``, ``config`` and ``extra``.

    Args:
        path: Destination file. Missing parent directories are created.
        model: Model whose weights are saved (a ``.module`` wrapper is
            unwrapped via ``_get_model_state_dict``).
        optimizer: Optimizer to snapshot, or ``None`` to store ``None``.
        scheduler: LR scheduler to snapshot, or ``None`` to store ``None``.
        iter_num: Iteration number recorded in the checkpoint.
        global_step: Global step counter recorded in the checkpoint.
        config_dict: Configuration stored under ``config``.
        extra: Optional additional payload; stored as ``{}`` when falsy.
    """
    # os.path.dirname() returns '' for a bare filename, and os.makedirs('')
    # raises FileNotFoundError, so only create the directory when there is one.
    parent_dir = os.path.dirname(path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Detach and move every entry to CPU so the file does not hold references
    # to the training device's tensors.
    model_state = _get_model_state_dict(model)
    model_state_cpu = {k: v.detach().cpu() for k, v in model_state.items()}

    ckpt = {
        'model_state': model_state_cpu,
        'optimizer_state': optimizer.state_dict() if optimizer is not None else None,
        'scheduler_state': scheduler.state_dict() if scheduler is not None else None,
        'iter_num': iter_num,
        'global_step': global_step,
        'config': config_dict,
        'extra': extra or {},
    }
    torch.save(ckpt, path)
    print(f'[checkpoint] saved to {path}')


def load_checkpoint(path: str, model, optimizer=None, scheduler=None, map_location='cpu'):
    """Restore model (and optionally optimizer/scheduler) state from ``path``.

    Args:
        path: Checkpoint file; expected to contain a dict with a
            ``model_state`` entry and optional ``optimizer_state`` /
            ``scheduler_state`` entries.
        model: Model to load weights into. If it exposes a ``.module``
            attribute, the weights are loaded into that inner module.
        optimizer: Optional optimizer to restore; skipped when ``None`` or when
            the checkpoint holds no optimizer state.
        scheduler: Optional scheduler to restore; skipped when ``None`` or when
            the checkpoint holds no scheduler state.
        map_location: Forwarded to ``torch.load``.

    Returns:
        The full checkpoint dict as loaded from disk.
    """
    # NOTE(review): checkpoints may carry non-tensor objects under 'extra';
    # recent torch versions default to weights_only=True in torch.load, which
    # can reject such payloads -- confirm against the torch version in use.
    ckpt = torch.load(path, map_location=map_location)

    # The stored weights are expected to carry no 'module.' prefix, so load
    # them into the inner module when the model is wrapped.
    target = model.module if hasattr(model, 'module') else model
    target.load_state_dict(ckpt['model_state'])

    if optimizer is not None and ckpt.get('optimizer_state') is not None:
        optimizer.load_state_dict(ckpt['optimizer_state'])
    if scheduler is not None and ckpt.get('scheduler_state') is not None:
        scheduler.load_state_dict(ckpt['scheduler_state'])
    print(f'[checkpoint] loaded from {path} (iter={ckpt.get("iter_num")}, step={ckpt.get("global_step")})')
    return ckpt



def discount_cumsum(x, gamma):
    """Return the discounted reverse cumulative sum of ``x`` along axis 0.

    ``out[t] = x[t] + gamma * out[t + 1]`` with ``out[-1] = x[-1]``.

    Args:
        x: NumPy array; the recursion runs over its first axis.
        gamma: Discount factor applied per step.

    Returns:
        Array with the same shape and dtype as ``x`` (``np.zeros_like``), so an
        integer ``x`` yields integer results. An empty ``x`` yields an empty
        array instead of raising ``IndexError``.
    """
    result = np.zeros_like(x)
    if x.shape[0] == 0:
        return result

    result[-1] = x[-1]
    for t in reversed(range(x.shape[0] - 1)):
        result[t] = x[t] + gamma * result[t + 1]
    return result


def experiment(
        exp_prefix,
        variant,
):
    """Train, evaluate and checkpoint a ``DecisionTransformer_Dist`` model.

    Args:
        exp_prefix: Prefix used to build the wandb group and run names.
        variant: Dict of settings. Required keys: ``env``, ``dataset``,
            ``model_type``, ``K``, ``batch_size``, ``num_eval_episodes``,
            ``embed_dim``, ``n_layer``, ``n_head``, ``activation_function``,
            ``dropout``, ``warmup_steps``, ``learning_rate``, ``weight_decay``,
            ``max_iters``, ``num_steps_per_iter``. Optional keys (with
            defaults): ``device`` ('cuda'), ``log_to_wandb`` (False),
            ``save_model`` (True), ``load_model`` (None), ``model_save_dir``
            ('saved_models_maze'), ``save_every`` (10), ``mode`` ('normal'),
            ``pct_traj`` (1.0), ``save_best_metric`` (None).

    Raises:
        NotImplementedError: If ``variant['env']`` is not one of
            'hopper', 'halfcheetah', 'walker2d' or 'maze2d'.
    """
    device = variant.get('device', 'cuda')
    log_to_wandb = variant.get('log_to_wandb', False)

    save_model = variant.get('save_model', True)
    load_model = variant.get('load_model', None)
    model_save_dir = variant.get('model_save_dir', 'saved_models_maze')
    save_every = variant.get('save_every', 10)

    env_name, dataset = variant['env'], variant['dataset']
    model_type = variant['model_type']
    group_name = f'{exp_prefix}-{env_name}-{dataset}'
    exp_prefix = f'{group_name}-{random.randint(int(1e5), int(1e6) - 1)}'

    # Checkpoint directory is computed once and reused by the training loop.
    model_dir = None
    if save_model:
        model_dir = os.path.join(model_save_dir, f'{env_name}_{dataset}_{model_type}_2')
        os.makedirs(model_dir, exist_ok=True)
        print(f"Models will be saved to: {model_dir}")

    # NOTE: the environment id is always built with a "-v1" suffix,
    # e.g. "maze2d-medium-v1".
    env_id = f"{env_name.lower()}-{dataset.lower()}-v1"
    env = gym.make(env_id)

    # Per-environment evaluation targets, return scale and episode length.
    if env_name == 'hopper':
        env_targets = [3600, 1800]
        scale = 1000.
        max_ep_len = 1000
    elif env_name == 'halfcheetah':
        env_targets = [12000, 6000]
        scale = 1000.
        max_ep_len = 1000
    elif env_name == 'walker2d':
        env_targets = [5000, 2500]
        scale = 1000.
        max_ep_len = 1000
    elif env_name == 'maze2d':
        max_ep_len = 999
        env_targets = [300, 200, 150, 100, 50, 20]
        scale = 10
    else:
        # Without this branch the function fails much later with a NameError
        # on env_targets / scale / max_ep_len.
        raise NotImplementedError(f'unsupported env: {env_name}')

    state_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # load dataset
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # load dataset files from a trusted source.
    dataset_path = f'data/{env_name}-{dataset}-v1.pkl'
    with open(dataset_path, 'rb') as f:
        trajectories = pickle.load(f)

    # save all path information into separate lists
    mode = variant.get('mode', 'normal')
    states, traj_lens, returns = [], [], []
    for path in trajectories:
        if mode == 'delayed':  # delayed: all rewards moved to end of trajectory
            path['rewards'][-1] = path['rewards'].sum()
            path['rewards'][:-1] = 0.
        states.append(path['observations'])
        traj_lens.append(len(path['observations']))
        returns.append(path['rewards'].sum())
    traj_lens, returns = np.array(traj_lens), np.array(returns)

    # used for input normalization
    states = np.concatenate(states, axis=0)
    state_mean, state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6

    num_timesteps = sum(traj_lens)

    print('=' * 50)
    print(f'Starting new experiment: {env_name} {dataset}')
    print(f'{len(traj_lens)} trajectories, {num_timesteps} timesteps found')
    print(f'Average return: {np.mean(returns):.2f}, std: {np.std(returns):.2f}')
    print(f'Max return: {np.max(returns):.2f}, min: {np.min(returns):.2f}')
    print('=' * 50)

    K = variant['K']
    batch_size = variant['batch_size']
    num_eval_episodes = variant['num_eval_episodes']
    pct_traj = variant.get('pct_traj', 1.)

    # only train on top pct_traj trajectories (for %BC experiment)
    num_timesteps = max(int(pct_traj*num_timesteps), 1)
    sorted_inds = np.argsort(returns)  # lowest to highest
    num_trajectories = 1
    timesteps = traj_lens[sorted_inds[-1]]
    ind = len(trajectories) - 2
    # Walk from the second-best trajectory downwards while the timestep
    # budget is not exceeded; the best trajectory is always kept.
    while ind >= 0 and timesteps + traj_lens[sorted_inds[ind]] <= num_timesteps:
        timesteps += traj_lens[sorted_inds[ind]]
        num_trajectories += 1
        ind -= 1
    sorted_inds = sorted_inds[-num_trajectories:]

    # used to reweight sampling so we sample according to timesteps instead of trajectories
    p_sample = traj_lens[sorted_inds] / sum(traj_lens[sorted_inds])

    def get_batch(batch_size=256, max_len=K):
        """Sample ``batch_size`` left-padded sub-sequences of length ``max_len``.

        Returns tensors ``(s, a, r, d, rtg, timesteps, mask)`` on ``device``.
        ``rtg`` holds one more step than the other sequences.
        """
        batch_inds = np.random.choice(
            np.arange(num_trajectories),
            size=batch_size,
            replace=True,
            p=p_sample,  # reweights so we sample according to timesteps
        )

        s, a, r, d, rtg, timesteps, mask = [], [], [], [], [], [], []
        for i in range(batch_size):
            traj = trajectories[int(sorted_inds[batch_inds[i]])]
            si = random.randint(0, traj['rewards'].shape[0] - 1)

            # get sequences from dataset
            s.append(traj['observations'][si:si + max_len].reshape(1, -1, state_dim))
            a.append(traj['actions'][si:si + max_len].reshape(1, -1, act_dim))
            r.append(traj['rewards'][si:si + max_len].reshape(1, -1, 1))
            if 'terminals' in traj:
                d.append(traj['terminals'][si:si + max_len].reshape(1, -1))
            else:
                d.append(traj['dones'][si:si + max_len].reshape(1, -1))
            timesteps.append(np.arange(si, si + s[-1].shape[1]).reshape(1, -1))
            timesteps[-1][timesteps[-1] >= max_ep_len] = max_ep_len-1  # padding cutoff
            rtg.append(discount_cumsum(traj['rewards'][si:], gamma=1.)[:s[-1].shape[1] + 1].reshape(1, -1, 1))
            # When the slice reaches the end of the trajectory there is no
            # "next" return-to-go, so append a zero to keep length tlen + 1.
            if rtg[-1].shape[1] <= s[-1].shape[1]:
                rtg[-1] = np.concatenate([rtg[-1], np.zeros((1, 1, 1))], axis=1)

            # padding and state + reward normalization
            tlen = s[-1].shape[1]
            s[-1] = np.concatenate([np.zeros((1, max_len - tlen, state_dim)), s[-1]], axis=1)
            s[-1] = (s[-1] - state_mean) / state_std
            a[-1] = np.concatenate([np.ones((1, max_len - tlen, act_dim)) * -10., a[-1]], axis=1)
            r[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), r[-1]], axis=1)
            d[-1] = np.concatenate([np.ones((1, max_len - tlen)) * 2, d[-1]], axis=1)
            rtg[-1] = np.concatenate([np.zeros((1, max_len - tlen, 1)), rtg[-1]], axis=1) / scale
            timesteps[-1] = np.concatenate([np.zeros((1, max_len - tlen)), timesteps[-1]], axis=1)
            mask.append(np.concatenate([np.zeros((1, max_len - tlen)), np.ones((1, tlen))], axis=1))

        s = torch.from_numpy(np.concatenate(s, axis=0)).to(dtype=torch.float32, device=device)
        a = torch.from_numpy(np.concatenate(a, axis=0)).to(dtype=torch.float32, device=device)
        r = torch.from_numpy(np.concatenate(r, axis=0)).to(dtype=torch.float32, device=device)
        d = torch.from_numpy(np.concatenate(d, axis=0)).to(dtype=torch.long, device=device)
        rtg = torch.from_numpy(np.concatenate(rtg, axis=0)).to(dtype=torch.float32, device=device)
        timesteps = torch.from_numpy(np.concatenate(timesteps, axis=0)).to(dtype=torch.long, device=device)
        mask = torch.from_numpy(np.concatenate(mask, axis=0)).to(device=device)

        return s, a, r, d, rtg, timesteps, mask

    def gmm_nll_loss(pi, mu, sigma, a, mask=None):
        """Mean negative log-likelihood of ``a`` under a diagonal Gaussian mixture.

        NOTE(review): ``mask`` is accepted but not used here; whether padded
        positions are filtered out before this call depends on the trainer --
        confirm against SequenceTrainer.
        """
        if a.dim() == 2:
            # Insert a component axis so ``a`` broadcasts against ``mu``.
            a = a.unsqueeze(1)
        var = sigma * sigma + 1e-12

        # compute log probability components
        log_prob_components = -0.5 * (
            torch.sum((a - mu)**2 / var, dim=-1) +
            torch.sum(torch.log(var), dim=-1) +
            mu.shape[-1] * math.log(2 * math.pi)
        )

        # logsumexp over the last (component) axis keeps the mixture
        # likelihood numerically stable.
        log_mix = torch.log(pi + 1e-12) + log_prob_components
        log_total_prob = torch.logsumexp(log_mix, dim=-1)
        nll = -log_total_prob
        return nll.mean()

    def gmm_loss_fn(s_hat, a_hat, r_hat, s, a, r, mask=None):
        """
        a_hat: GMM  (pi, mu, sigma)
        a: actions [B, T, act_dim]

        Falls back to mean squared error when ``a_hat`` is not a 3-tuple.
        """
        if isinstance(a_hat, tuple) and len(a_hat) == 3:
            pi, mu, sigma = a_hat
            return gmm_nll_loss(pi, mu, sigma, a, mask)
        else:
            return torch.mean((a_hat - a)**2)

    def eval_episodes(target_rew):
        """Build an evaluation function that rolls out ``num_eval_episodes`` episodes."""
        def fn(model):
            returns, lengths, scores = [], [], []
            for _ in range(num_eval_episodes):
                with torch.no_grad():
                    ret, length, norm_score = evaluate_episode_rtg_dist(
                        env,
                        state_dim,
                        act_dim,
                        model,
                        max_ep_len=max_ep_len,
                        scale=scale,
                        target_return=target_rew/scale,
                        mode=mode,
                        state_mean=state_mean,
                        state_std=state_std,
                        device=device,
                    )
                returns.append(ret)
                lengths.append(length)
                scores.append(norm_score)
            return {
                f'target_{target_rew}_return_mean': np.mean(returns),
                f'target_{target_rew}_return_std': np.std(returns),
                f'target_{target_rew}_length_mean': np.mean(lengths),
                f'target_{target_rew}_length_std': np.std(lengths),
                f'target_{target_rew}_norm_score_mean': np.mean(scores),
                f'target_{target_rew}_norm_score_std': np.std(scores),
            }
        return fn

    model = DecisionTransformer_Dist(
        state_dim=state_dim,
        act_dim=act_dim,
        max_length=K,
        max_ep_len=max_ep_len,
        hidden_size=variant['embed_dim'],
        n_layer=variant['n_layer'],
        n_head=variant['n_head'],
        n_inner=4*variant['embed_dim'],
        activation_function=variant['activation_function'],
        n_positions=1024,
        resid_pdrop=variant['dropout'],
        attn_pdrop=variant['dropout'],
    )

    global_step = 0
    best_metric_val = -float('inf')
    best_ckpt_path = None

    model = model.to(device=device)

    warmup_steps = variant['warmup_steps']
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=variant['learning_rate'],
        weight_decay=variant['weight_decay'],
    )
    # Linear warmup of the learning-rate multiplier up to 1.
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda steps: min((steps+1)/warmup_steps, 1)
    )

    # Resume only after the optimizer and scheduler exist: restoring their
    # state needs the objects, and referencing them earlier raised
    # UnboundLocalError.
    if load_model:
        ckpt = load_checkpoint(load_model, model, optimizer=optimizer, scheduler=scheduler, map_location='cpu')
        global_step = ckpt.get('global_step', 0)
        print(f'[resume] global_step set to {global_step}')

    trainer = SequenceTrainer(
        model=model,
        optimizer=optimizer,
        batch_size=batch_size,
        get_batch=get_batch,
        scheduler=scheduler,
        loss_fn=gmm_loss_fn,
        eval_fns=[eval_episodes(tar) for tar in env_targets],
    )

    if log_to_wandb:
        wandb.init(
            name=exp_prefix,
            group=group_name,
            project='decision-transformer-dist',
            config=variant
        )

    metric_key = variant.get('save_best_metric', None)
    max_iters = variant['max_iters']
    for iter_idx in range(max_iters):
        iter_num = iter_idx + 1
        outputs = trainer.train_iteration(num_steps=variant['num_steps_per_iter'], iter_num=iter_num, print_logs=True)
        global_step += variant['num_steps_per_iter']
        if log_to_wandb:
            wandb.log(outputs)

        # Metric used for "best" checkpointing: the configured key if the
        # trainer reported it, otherwise the mean over all normalized scores.
        current_metric_val = None
        if metric_key and metric_key in outputs:
            current_metric_val = outputs[metric_key]
        else:
            norm_keys = [k for k in outputs if k.endswith('_norm_score_mean')]
            if norm_keys:
                current_metric_val = float(np.mean([outputs[k] for k in norm_keys]))

        if not save_model:
            continue

        # periodic checkpoint (a falsy save_every disables it instead of
        # dividing by zero)
        if save_every and iter_num % save_every == 0:
            ckpt_path = os.path.join(model_dir, f'iter{iter_num}_step{global_step}.pt')
            save_checkpoint(
                path=ckpt_path,
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                iter_num=iter_num,
                global_step=global_step,
                config_dict=variant,
                extra={'last_outputs': outputs},
            )

        # save best model
        if current_metric_val is not None and current_metric_val > best_metric_val:
            best_metric_val = current_metric_val
            best_ckpt_path = os.path.join(model_dir, 'best.pt')
            save_checkpoint(
                path=best_ckpt_path,
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                iter_num=iter_num,
                global_step=global_step,
                config_dict=variant,
                extra={'best_metric_key': metric_key, 'best_metric_val': best_metric_val, 'last_outputs': outputs},
            )
            print(f'[best] metric={best_metric_val:.4f} saved to {best_ckpt_path}')

        # final checkpoint after the last iteration
        if iter_num == max_iters:
            last_ckpt_path = os.path.join(model_dir, 'last.pt')
            save_checkpoint(
                path=last_ckpt_path,
                model=model,
                optimizer=optimizer,
                scheduler=scheduler,
                iter_num=iter_num,
                global_step=global_step,
                config_dict=variant,
                extra={'last_outputs': outputs, 'best_ckpt_path': best_ckpt_path, 'best_metric_val': best_metric_val},
            )


def str2bool(value):
    """Parse a command-line boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This parser
    accepts the usual textual spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string such as
            ``true/false``, ``yes/no``, ``1/0``, ``on/off`` (case-insensitive).
            The empty string is treated as ``False``, as ``bool('')`` would be.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognized spelling.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    if text in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--env', type=str, default='antmaze')
    parser.add_argument('--dataset', type=str, default='medium-play')  # medium, medium-replay, medium-expert, expert
    parser.add_argument('--mode', type=str, default='normal')  # normal for standard setting, delayed for sparse
    parser.add_argument('--K', type=int, default=20)
    parser.add_argument('--pct_traj', type=float, default=1.)
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--model_type', type=str, default='dt_dist')  # dt for decision transformer, bc for behavior cloning
    parser.add_argument('--embed_dim', type=int, default=128)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=1)
    parser.add_argument('--activation_function', type=str, default='relu')
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--learning_rate', '-lr', type=float, default=1e-4)
    parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
    parser.add_argument('--warmup_steps', type=int, default=10000)
    parser.add_argument('--num_eval_episodes', type=int, default=100)
    parser.add_argument('--max_iters', type=int, default=10)
    parser.add_argument('--num_steps_per_iter', type=int, default=10000)
    parser.add_argument('--device', type=str, default='cuda')
    # str2bool instead of bool: with type=bool, "--save_model False" parsed as True.
    parser.add_argument('--log_to_wandb', '-w', type=str2bool, default=False)
    parser.add_argument('--save_model', type=str2bool, default=True)
    parser.add_argument('--model_save_dir', type=str, default='saved_models_maze')
    parser.add_argument('--save_every', type=int, default=2)
    parser.add_argument('--load_model', type=str, default=None)
    parser.add_argument('--save_best_metric', type=str, default='target_3600_norm_score_mean')

    args = parser.parse_args()

    experiment('gym-experiment', variant=vars(args))
