
import d4rl
import gym
import numpy as np
import torch

import argparse
import pickle
import os
import pathlib
from datetime import datetime

from utils import wandb_init
import wandb

from evaluation.evaluate_episodes import evaluate_episode_rtg
from scripts.CDT_trainer import Trainer as CDTTrainer
from scripts.DT_trainer import Trainer as DTTrainer
from models.decision_transformer_contrast import DecisionTransformerContrast_SIMCLR, \
    DecisionTransformerContrast_ProductSIMCLR, DecisionTransformerContrast_SIMCLR_v2
from models.ATDT_model import DecisionTransformer
from models.GDT_model import DecisionTransformer as GDT_model
from models.critic_model import Critic

from utils import D4RLTrajectoryDataset,logger, setup_logger,set_seed,init_env, initialize_q_network
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import h5py

import wandb
from datetime import datetime



def save_checkpoint(state, name):
    """Serialize ``state`` (e.g. a dict of state_dicts) to the file path ``name``.

    Thin wrapper around ``torch.save``; ``name`` is used verbatim as the target path.
    """
    torch.save(state, f=name)





def experiment(
        exp_prefix,
        variant,
):
    """Train one agent on a D4RL dataset and checkpoint the best actor.

    Steps: build the environment and run directory, score every dataset
    transition with the critic from ``initialize_q_network`` (loaded with
    ``load_state_dic=True``; presumably pretrained IQL weights), derive the
    return/advantage bins and the top/low sample buffers used for contrastive
    training, build the model selected by ``variant["model_type"]``, then run
    ``variant['max_iters']`` training iterations, saving a checkpoint whenever
    the best evaluation return mean does not decrease.

    Side effects: writes ``args.pkl``, logs and ``.pth`` checkpoints under
    ``os.path.join(variant['save_path'], exp_prefix)``; writes
    ``ret_ranges_{env}_{dataset}.pkl`` and ``adv_ranges_{env}_{dataset}.pkl``
    to the current working directory; logs final metrics to wandb.

    Args:
        exp_prefix: Run name; also the sub-directory of ``variant['save_path']``
            used for this run's outputs.
        variant: Parsed CLI configuration. Mutated in place: ``max_ep_len``,
            ``scale``, ``subgoal_dim``, ``condition_dim`` and ``env_targets`` are set.

    Raises:
        ValueError: if ``variant["model_type"]`` is not a supported model type.
    """
    device = variant.get('device', 'cuda')

    # Contrastive model variants share one constructor signature; dispatch by name.
    contrast_models = {
        'dt_contrast_simclr': DecisionTransformerContrast_SIMCLR,
        'dt_contrast_simclr_product': DecisionTransformerContrast_ProductSIMCLR,
        'dt_contrast_simclr_v2': DecisionTransformerContrast_SIMCLR_v2,
    }
    model_type = variant["model_type"]
    # Fail fast, before the expensive env/critic/dataset setup below.
    if model_type not in ('dt', 'gdt') and model_type not in contrast_models:
        raise ValueError(f"Unsupported model_type: {model_type!r}")

    env_name, dataset = variant['env'], variant['dataset']
    env, max_ep_len, scale, dversion, gym_name = init_env(env_name, dataset)

    variant['max_ep_len'] = max_ep_len
    variant['scale'] = scale

    # One joined path for both the existence check and mkdir, so save_path
    # works with or without a trailing separator. args.pkl is only written for new runs.
    run_dir = os.path.join(variant['save_path'], exp_prefix)
    if not os.path.exists(run_dir):
        pathlib.Path(run_dir).mkdir(parents=True, exist_ok=True)
        with open(os.path.join(run_dir, 'args.pkl'), 'wb') as f:
            pickle.dump(variant, f)

    setup_logger(exp_prefix, variant=variant, log_dir=run_dir)

    # TensorBoard logging is disabled; the trainer receives log_writer=None.
    writer = None

    env.seed(variant['seed'])
    set_seed(variant['seed'])

    state_dim = env.observation_space.shape[0]
    act_dim = env.action_space.shape[0]

    # Initialize the critic network
    critic = initialize_q_network(state_dim, act_dim, variant["iql_q_hiddens"], variant["iql_layernorm"], gym_name,
                                  load_state_dic=True)

    # load dataset
    dataset_path = f'data/{env_name}-{dataset}-v{dversion}.pkl'
    traj_dataset = D4RLTrajectoryDataset(dataset_path, variant['K'], scale,
                                         sample_size=variant['batch_size'] * variant['num_steps_per_iter'],
                                         normalize=variant["normalize"], critic=critic)
    traj_data_loader = DataLoader(traj_dataset, batch_size=variant['batch_size'], shuffle=True, pin_memory=True,
                                  drop_last=True)

    variant["subgoal_dim"] = state_dim
    variant["condition_dim"] = 1
    if "antmaze" in env_name:
        env_targets = [100 * traj_dataset.max_return]
        variant["subgoal_dim"] = 2
        if variant["conditioning"] == "subgoal":
            variant["condition_dim"] = 2
    else:
        env_targets = [traj_dataset.max_return * 2]

    variant['env_targets'] = env_targets

    # Score each trajectory with the critic and flatten all transitions.
    # The critic is only used for inference here, so no autograd graph is needed.
    all_states = []
    all_actions = []
    all_returns = []
    advantages = []
    with torch.no_grad():
        for traj in traj_dataset.trajectories:
            # Undo the dataset's state normalization before querying the critic.
            raw_obs = (traj["observations"] * traj_dataset.state_std) + traj_dataset.state_mean
            q1, q2 = critic(torch.FloatTensor(raw_obs), torch.FloatTensor(traj["actions"]))
            traj["q"] = torch.minimum(q1, q2).cpu().numpy().flatten()
            all_states.append(traj["observations"])
            all_actions.append(traj["actions"])
            all_returns.append(traj["returns_to_go"])
            advantages.append(traj["advantage"])

    all_states = np.concatenate(all_states, axis=0)
    all_actions = np.concatenate(all_actions, axis=0)
    all_returns = np.concatenate(all_returns, axis=0)
    advantages = np.concatenate(advantages, axis=0)
    mean_advantages = traj_dataset.mean_advantage

    num_bins = variant['num_bins']
    bin_ranges = np.linspace(all_returns.min(), max(env_targets) / scale, num_bins + 1)
    adv_ranges = np.linspace(advantages.min(), advantages.max(), num_bins + 1)
    with open(f"ret_ranges_{env_name}_{dataset}.pkl", "wb") as f:
        pickle.dump(bin_ranges, f)
    with open(f"adv_ranges_{env_name}_{dataset}.pkl", "wb") as f:
        pickle.dump(adv_ranges, f)

    # Thresholds: top 20% of returns/advantages, and the lowest
    # `avoidance_percentage` percent of each.
    top_20_percentile = np.percentile(all_returns, 80)
    low_return = np.percentile(all_returns, variant["avoidance_percentage"])
    top_adv = np.percentile(advantages, 80)
    low_adv = np.percentile(advantages, variant["avoidance_percentage"])

    def _subset(mask):
        # A single boolean mask selects every field, so states, actions,
        # returns and advantages in a buffer always stay row-aligned.
        return {"states": all_states[mask],
                "actions": all_actions[mask],
                "returns_to_go": all_returns[mask],
                "advantage": advantages[mask]}

    top_return_buffer = _subset(all_returns >= top_20_percentile)
    low_return_buffer = _subset(all_returns <= low_return)
    top_advantage_buffer = _subset(advantages >= top_adv)
    low_advantage_buffer = _subset(advantages <= low_adv)

    contrastive_data = {"advantages": advantages,
                        "advantage_bin_ranges": adv_ranges,
                        "mean_adv": mean_advantages,
                        "contrastive_type": variant["contrastive_type"],
                        "top_return_buffer": top_return_buffer,
                        "low_return_buffer": low_return_buffer,
                        "top_advantage_buffer": top_advantage_buffer,
                        "low_advantage_buffer": low_advantage_buffer,
                        "num_samples_simclr": variant["num_samples_simclr"],
                        "with_soft_contrastive": variant['with_soft_contrastive']
                        }

    mode = variant.get('mode', 'normal')
    # TODO: delayed reward mode

    # State statistics passed to evaluation and the trainer.
    if variant['normalize']:
        state_mean, state_std = traj_dataset.get_state_stats()
    else:
        state_mean, state_std = np.zeros(state_dim), np.ones(state_dim)

    K = variant['K']
    batch_size = variant['batch_size']
    num_eval_episodes = variant['num_eval_episodes']

    def eval_episodes(target_rew):
        """Build an evaluation closure for the list of target returns ``target_rew``."""
        def fn(model, critic):
            returns, lengths = [], []
            for _ in range(num_eval_episodes):
                with torch.no_grad():
                    ret, length, _ = evaluate_episode_rtg(
                        env,
                        state_dim,
                        act_dim,
                        model,
                        critic,
                        max_ep_len=max_ep_len,
                        scale=variant['scale'],
                        target_return=[t / variant['scale'] for t in target_rew],
                        mode=mode,
                        state_mean=state_mean,
                        state_std=state_std,
                        device=device,
                    )
                returns.append(ret)
                lengths.append(length)
            return {
                f'target_{target_rew}_return_mean': np.mean(returns),
                f'target_{target_rew}_return_std': np.std(returns),
                f'target_{target_rew}_length_mean': np.mean(lengths),
                f'target_{target_rew}_length_std': np.std(lengths),
                f'target_{target_rew}_normalized_score': env.get_normalized_score(np.mean(returns)),
            }
        return fn

    # Shrink the bounds by 1e-6 so predicted actions stay strictly inside the action space.
    action_range = [
        float(env.action_space.low.min()) + 1e-6,
        float(env.action_space.high.max()) - 1e-6,
    ]

    base_kwargs = dict(
        state_dim=state_dim,
        act_dim=act_dim,
        max_length=K,
        max_ep_len=max_ep_len,
        hidden_size=variant['embed_dim'],
        n_layer=variant['n_layer'],
        n_head=variant['n_head'],
        resid_pdrop=variant['dropout'],
        attn_pdrop=variant['dropout'],
        action_range=action_range,
    )
    if model_type == 'dt':
        model = DecisionTransformer(**base_kwargs)
        Trainer = DTTrainer
    elif model_type == 'gdt':
        model = GDT_model(**base_kwargs, w=variant["w"], b=variant["b"])
        Trainer = DTTrainer
    else:
        # NOTE: compress_dim follows embed_dim; the --compress_dim option is not used here.
        model = contrast_models[model_type](
            **base_kwargs,
            compress_dim=variant["embed_dim"],
            n_inner=4 * variant['embed_dim'],
            n_positions=1024,
        )
        Trainer = CDTTrainer

    model = model.to(device=device)
    critic = critic.to(device=device)

    def get_q_loss_mean():
        """Return |mean over batches of min(Q1, Q2)| on the full raw dataset.

        Currently unused: the trainer below receives ``q_loss_mean=0``. Data is read
        from ``env.get_dataset()``, falling back to ``data/d4rl_data/{gym_name}.hdf5``.
        """
        try:
            ds = env.get_dataset()
            obs = ds['observations']
            actions = ds['actions']
        except Exception as e:
            print("Error getting dataset from environment, using downloaded data:", e)
            local_path = f'data/d4rl_data/{gym_name}.hdf5'
            with h5py.File(local_path, 'r') as hdf_file:
                obs = hdf_file['observations'][:]
                actions = hdf_file['actions'][:]

        dataloader = DataLoader(TensorDataset(torch.Tensor(obs), torch.Tensor(actions)),
                                batch_size=256, shuffle=False)
        num_batches = len(dataloader)
        if num_batches == 0:
            return 0.0

        tqdm_bar = tqdm(dataloader)
        total_q_loss = 0.0
        with torch.no_grad():
            for obs_batch, act_batch in tqdm_bar:
                q1, q2 = critic(obs_batch.to(device), act_batch.to(device))
                batch_loss = torch.minimum(q1, q2).mean().item()
                total_q_loss += batch_loss
                tqdm_bar.set_description('Q Loss: {:.2g}'.format(batch_loss))

        return abs(total_q_loss / num_batches)

    trainer = Trainer(
        model=model,
        critic=None,
        batch_size=batch_size,
        tau=variant['tau'],
        discount=variant['discount'],
        dataloader=traj_data_loader,
        loss_fn=lambda s_hat, a_hat, r_hat, s, a, r: torch.mean((a_hat - a) ** 2),
        eval_fns=[eval_episodes([t]) for t in env_targets],
        eta=variant['eta'],
        lr=variant['learning_rate'],
        weight_decay=variant['weight_decay'],
        lr_decay=variant['lr_decay'],
        max_iters=variant['max_iters'],
        grad_norm=variant['grad_norm'],
        scale=scale,
        k_rewards=variant['k_rewards'],
        device=device,
        with_att_entropy_loss=variant["with_att_entropy_loss"],
        att_entropy_loss_weight=variant["att_entropy_loss_weight"],
        att_entropy_loss_weight_decay=variant["att_entropy_loss_weight_decay"],
        eval_every_n_epoch=variant["eval_every_n_epoch"],
        state_mean=state_mean,
        state_std=state_std,
        env_name=gym_name,
        q_loss_mean=0,
        q_scale=variant['q_scale'],
        q_min=variant['q_min'],
        num_steps=variant['num_steps_per_iter'],
        actor_optimizer=variant['actor_optimizer'],
        contrastive_frequency=variant['contrastive_frequency'],
        beta=variant['beta'],
        bin_ranges=bin_ranges,
        adv_ranges=adv_ranges,
        top_return_buffer=top_return_buffer,
        contrastive_data=contrastive_data,
    )

    # -inf guarantees the first evaluated iteration is checkpointed, whatever the return scale.
    best_ret = float('-inf')
    best_nor_ret = -1000
    best_iter = -1
    for iteration in range(variant['max_iters']):
        trainer.iter = iteration
        outputs = trainer.train_iteration(num_steps=variant['num_steps_per_iter'], logger=logger,
                                          iter_num=iteration + 1, log_writer=writer)
        # Presumably decays a trainer loss weight by `lambda`; see Trainer.scale_down_weight.
        trainer.scale_down_weight(variant['lambda'])
        ret = outputs['Best_return_mean']
        nor_ret = outputs['Best_normalized_score']
        if ret >= best_ret:
            state = {
                'epoch': iteration + 1,
                'actor': trainer.actor.state_dict(),
            }
            save_checkpoint(state, os.path.join(run_dir, 'epoch_{}_score{}.pth'.format(iteration + 1, nor_ret * 100)))
            best_ret = ret
            best_nor_ret = nor_ret
            best_iter = iteration + 1
        logger.log(
            f'Current best return mean is {best_ret}, normalized score is {best_nor_ret * 100}, Iteration {best_iter}')

        if variant['early_stop'] and iteration >= variant['early_epoch']:
            break
    logger.log(f'The final best return mean is {best_ret}')
    logger.log(f'The final best normalized return is {best_nor_ret * 100}')
    wandb.log({'best_return_mean': best_ret, 'best_normalized_score': best_nor_ret * 100, 'best_iter': best_iter})


if __name__ == '__main__':
    # CLI entry point: parse the run configuration, initialise wandb (with a code
    # artifact), run `experiment`, then close the wandb run.
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', type=str, default='gym-experiment')
    parser.add_argument('--seed', type=int, default=0)
    parser.add_argument('--env', type=str,
                        default='hopper')  # halfcheetah, hopper, walker2d, reacher2d, pen, hammer, door, relocate, kitchen, maze2d, antmaze
    parser.add_argument('--dataset', type=str, default='medium')  # medium, medium-replay, medium-expert, expert
    parser.add_argument('--mode', type=str, default='normal')  # normal for standard setting, delayed for sparse
    parser.add_argument('--K', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--embed_dim', type=int, default=128)
    parser.add_argument('--n_layer', type=int, default=3)
    parser.add_argument('--n_head', type=int, default=2)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--conditioning', type=str, default=None)

    parser.add_argument('--learning_rate', '-lr', type=float, default=3e-4)
    parser.add_argument('--weight_decay', '-wd', type=float, default=1e-4)
    parser.add_argument('--warmup_steps', type=int, default=10000)
    parser.add_argument('--num_eval_episodes', type=int, default=10)

    parser.add_argument('--max_iters', type=int, default=500)
    parser.add_argument('--num_steps_per_iter', type=int, default=100)

    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--save_path', type=str, default='cdt_runs/')
    parser.add_argument('--eval_every_n_epoch', type=int, default=10)

    parser.add_argument("--discount", default=0.99, type=float)
    parser.add_argument("--tau", default=0.01, type=float)
    parser.add_argument("--eta", default=5, type=float)
    parser.add_argument("--lambda", default=0.99, type=float)
    parser.add_argument("--max_q_backup", action='store_true', default=False)
    parser.add_argument("--lr_decay", action='store_true', default=False)
    parser.add_argument("--grad_norm", default=0.5, type=float)
    parser.add_argument("--early_stop", action='store_true', default=False)
    parser.add_argument("--early_epoch", type=int, default=100)
    parser.add_argument("--k_rewards", action='store_true', default=False)
    parser.add_argument("--sar", action='store_true', default=False)
    parser.add_argument("--normalize", action='store_true', default=True)
    # --normalize defaults to True, so it can never be switched off on its own;
    # --no_normalize writes False into the same `normalize` destination.
    parser.add_argument("--no_normalize", dest='normalize', action='store_false')

    parser.add_argument("--with_att_entropy_loss", action='store_true', default=False)
    parser.add_argument("--att_entropy_loss_weight", type=float, default=0.05)
    parser.add_argument("--att_entropy_loss_weight_decay", type=float, default=2)

    parser.add_argument("--q_scale", type=float, default=0.2)
    parser.add_argument("--q_min", type=float, default=0)

    parser.add_argument("--description", type=str, default="")
    parser.add_argument("--actor_optimizer", type=str, default="adam", choices=["adam", "lamb"],)

    parser.add_argument('--num_bins', type=int, default=100)
    parser.add_argument('--beta', type=float, default=0.01)
    parser.add_argument('--contrastive_frequency', type=int, default=10)
    parser.add_argument('--num_samples_simclr', type=int, default=5)
    parser.add_argument('--compress_dim', type=int, default=50)
    # "gdt" is the default, so it must also be an accepted explicit choice.
    parser.add_argument('--model_type', type=str, default="gdt",
                        choices=["dt_contrast_simclr", "dt_contrast_simclr_product", "dt", "gdt",
                                 "dt_contrast_simclr_v2"],)
    parser.add_argument('--contrastive_type', type=str, default="no_contrast",
                        choices=["advantage", "return", "no_contrast", "advantage_with_avoidance",
                                 "advantage_with_avoidance_and_bin_distance"],)
    parser.add_argument('--avoidance_percentage', type=int, default=20)
    parser.add_argument('--with_soft_contrastive', action='store_true', default=False)

    parser.add_argument("--iql_discount", type=float, default=0.99)
    parser.add_argument("--iql_layernorm", default=False, action='store_true')
    parser.add_argument("--iql_q_hiddens", type=int, default=2)
    parser.add_argument("--w", type=float, default=0.01)
    parser.add_argument("--b", type=float, default=-0.05)

    args = parser.parse_args()

    if (args.model_type == "dt") and (not args.contrastive_type == "no_contrast"):
        raise ValueError("contrastive_type should be 'no_contrast'")
    group = "CDT"
    if args.model_type == "dt":
        group = "DT_baseline"
    if args.model_type == "gdt":
        group = "GDT"

    # AntMaze overrides: larger layer-normed critic, no state normalization.
    if "antmaze" in args.env:
        args.iql_layernorm = True
        args.iql_q_hiddens = 3
        args.normalize = False

    # vars(args) is the namespace's own __dict__, so the update below is also
    # visible in the variant passed to `experiment`.
    config = vars(args)

    start_time = datetime.now().replace(microsecond=0)
    start_time_str = start_time.strftime("%y-%m-%d-%H-%M-%S")
    name = f"{config['env']}-{config['dataset']}-{start_time_str}"
    config.update({"name": name, "project": "ATDT", "group": group})
    wandb_init(config)
    # Snapshot the trainer and entry-script sources as a wandb code artifact.
    artifact = wandb.Artifact(name="my-code-artifact", type="code")
    artifact.add_file(local_path="scripts/CDT_trainer.py", name="CDT_trainer.py")
    artifact.add_file(local_path="scripts/CDT.py", name="CDT.py")
    wandb.run.log_artifact(artifact)
    experiment(name, variant=vars(args))

    wandb.finish()
