import torch.multiprocessing as mp
if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method('spawn', force=True)
import argparse
import os
import torch
import pickle
import numpy as np
from common_args import *
import random
from tqdm import tqdm
import transformers
from easydict import EasyDict
from prompt_dt.dataset import Dataset
transformers.set_seed(42)
from torch.utils.tensorboard import SummaryWriter
from prompt_dt.prompt_decision_transformer import PromptDecisionTransformer
from prompt_dt.eval import evaluation

# Number of training trajectories per env. It is read by build_data_filename
# through args.n_trails (the attribute name is kept as-is for that reason).
_TRAIN_N_TRAILS = {
    'ml1_pick_place': 1000,
    'cheetah_vel': 2000,
    'darkroom': 100,
}


def _parse_args():
    """Parse CLI arguments in two passes so env-specific flags can be registered.

    The first pass only needs ``--env`` to decide which ``add_*_dataset_args``
    group to register. It uses ``parse_known_args`` so that flags which are
    added afterwards (the env-specific ones and ``--seed``) are not rejected as
    unrecognized. A plain ``parse_args`` there would exit with an error.

    Returns:
        EasyDict: all parsed arguments.
    """
    parser = argparse.ArgumentParser()
    add_dataset_args(parser)
    add_model_args(parser)
    add_train_args(parser)
    known_args, _ = parser.parse_known_args()
    if known_args.env == 'ml1_pick_place':
        add_ml1_pick_place_dataset_args(parser)
    elif known_args.env == 'cheetah_vel':
        add_cheetah_vel_dataset_args(parser)
    elif known_args.env == 'darkroom':
        add_darkroom_dataset_args(parser)
    parser.add_argument('--seed', type=int, default=42)
    return EasyDict(vars(parser.parse_args()))


def _seed_everything(seed):
    """Seed torch (CPU and, if available, CUDA), numpy and ``random``; make cuDNN deterministic."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)


def _build_log_name(args):
    """Return the run name used for the ``logs_action/<name>`` output directory."""
    if args.env == 'darkroom':
        return f'env_{args.env}_H{args.horizon}_g{args.gamma}_p{args.p}_prompt_dt'
    return f'env_{args.env}_H{args.horizon}_g{args.gamma}_q{args.policy_quality}_prompt_dt'


def _load_eval_trajs(args):
    """Load the evaluation trajectory pickles, ordered [random, expert].

    NOTE(review): this mutates ``args``. It sets ``n_trails`` to 4 and leaves
    ``p`` (darkroom) or ``policy_quality`` (other envs) at the last loaded
    value. ``args`` is later passed to ``evaluation`` in this mutated state, so
    confirm whether ``evaluation`` relies on it before changing this.
    """
    args.n_trails = 4
    if args.env == 'darkroom':
        attr_name, attr_values = 'p', [0.0, 1.0]
    else:
        attr_name, attr_values = 'policy_quality', ['20', 'best']
    eval_trajs = []
    for value in attr_values:
        setattr(args, attr_name, value)
        path_eval = build_data_filename('eval', args)
        with open(path_eval, 'rb') as f:
            eval_trajs.append(pickle.load(f))
    return eval_trajs


def _build_model(args, device):
    """Construct the PromptDecisionTransformer with this script's fixed hyper-parameters."""
    model = PromptDecisionTransformer(
        state_dim=args.state_dim,
        act_dim=args.action_dim,
        max_length=20,
        max_ep_len=1000,
        hidden_size=128,
        n_layer=3,
        n_head=1,
        n_inner=4 * 128,
        activation_function='relu',
        n_positions=1024,
        resid_pdrop=0.1,
        attn_pdrop=0.1,
        device=device,
    )
    return model.to(device=device)


def _train_one_epoch(model, train_loader, optimizer, loss_fn, writer, epoch,
                     global_step, action_dim, device):
    """Run one optimization pass over ``train_loader``.

    Logs the per-step loss to ``Loss/train``.

    Returns:
        tuple: (sum of per-batch losses, updated global step counter).
    """
    model.train()
    epoch_loss = 0.0
    for batch in tqdm(train_loader, desc=f'training on Epoch {epoch}', total=len(train_loader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        _, pred_actions, _ = model(states=batch['states'],
                                   actions=batch['actions'],
                                   returns_to_go=batch['returns_to_go'],
                                   timesteps=batch['timesteps'],
                                   prompt=[batch['prompt_states'],
                                           batch['prompt_actions'],
                                           batch['prompt_returns_to_go'],
                                           batch['prompt_timesteps']])
        # Flatten all leading dims so the MSE averages over every predicted action vector.
        # NOTE(review): no mask is applied, so padded positions (if any) are included in the loss.
        true_actions = batch['actions'].reshape(-1, action_dim)
        pred_actions = pred_actions.reshape(-1, action_dim)
        optimizer.zero_grad()
        loss = loss_fn(pred_actions, true_actions)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()  # one host sync per step instead of two
        writer.add_scalar('Loss/train', loss_value, global_step)
        epoch_loss += loss_value
        global_step += 1
    return epoch_loss, global_step


def _evaluate_and_checkpoint(model, eval_trajs, device, epoch, args, writer,
                             log_dir, best_offline_reward):
    """Evaluate the model, print rewards, and checkpoint on a new best offline reward.

    The score is the mean of the random-prompt and expert-prompt offline reward
    means. On improvement, ``offline_sota.pt`` and ``rewards_sota.pkl`` are
    written into ``log_dir``.

    Returns:
        The (possibly updated) best mean offline reward.
    """
    model.eval()
    online_rewards, offline_rewards_random, offline_rewards_expert, optimal_rewards = evaluation(
        eval_trajs, model, device, epoch, args, writer=writer)
    print(f'online_rewards: {online_rewards}')
    print(f'offline_rewards_random: {offline_rewards_random}')
    print(f'offline_rewards_expert: {offline_rewards_expert}')
    print(f'optimal_rewards: {optimal_rewards}')
    offline_rewards_mean = (offline_rewards_expert.mean() + offline_rewards_random.mean()) / 2
    if offline_rewards_mean > best_offline_reward:
        torch.save(model.state_dict(), os.path.join(log_dir, 'offline_sota.pt'))
        best_offline_reward = offline_rewards_mean
        rewards = {
            'online_rewards': online_rewards,
            'offline_rewards_random': offline_rewards_random,
            'offline_rewards_expert': offline_rewards_expert,
            'optimal_rewards': optimal_rewards,
        }
        with open(os.path.join(log_dir, 'rewards_sota.pkl'), 'wb') as f:
            pickle.dump(rewards, f)
    return best_offline_reward


def main():
    """Train a prompt decision transformer, evaluating and checkpointing along the way.

    Outputs go to ``logs_action/<log_name>``:
    - TensorBoard events.
    - ``train_sota.pt``, written whenever the mean epoch training loss improves.
    - ``offline_sota.pt`` and ``rewards_sota.pkl``, written whenever the
      offline evaluation score improves. Evaluation runs every 10 epochs.
    """
    os.makedirs('models', exist_ok=True)

    # args
    args = _parse_args()
    print("Args: ", args)
    device = torch.device(str(args.device))
    _seed_everything(args.seed)

    log_name = _build_log_name(args)
    log_dir = os.path.join('logs_action', log_name)
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)
    try:
        # dataset
        # NOTE(review): gamma is hard-coded to 0.8 here while the log name uses
        # args.gamma; confirm these are meant to differ.
        # NOTE(review): K=20 presumably must match max_length=20 in _build_model.
        dataset_config = {
            'device': device,
            'horizon': args.horizon,
            'gamma': 0.8,
            'K': 20,
        }
        if args.env in _TRAIN_N_TRAILS:
            args.n_trails = _TRAIN_N_TRAILS[args.env]
        train_dataset = Dataset(build_data_filename('train', args), dataset_config)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
        eval_trajs = _load_eval_trajs(args)

        # model, optimizer, loss
        model = _build_model(args, device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
        loss_fn = torch.nn.MSELoss()

        train_global_step = 0
        best_epoch_train_loss = 1e9
        best_epoch_offline_reward = -1e9
        for epoch in range(args.num_epochs):
            epoch_train_loss, train_global_step = _train_one_epoch(
                model, train_loader, optimizer, loss_fn, writer, epoch,
                train_global_step, args.action_dim, device)
            mean_train_loss = epoch_train_loss / len(train_loader)
            if mean_train_loss < best_epoch_train_loss:
                torch.save(model.state_dict(), os.path.join(log_dir, 'train_sota.pt'))
                best_epoch_train_loss = mean_train_loss

            if (epoch + 1) % 10 == 0:
                best_epoch_offline_reward = _evaluate_and_checkpoint(
                    model, eval_trajs, device, epoch, args, writer,
                    log_dir, best_epoch_offline_reward)
    finally:
        # Always flush/close TensorBoard events, even if training fails.
        writer.close()
    print("Done.")


if __name__ == '__main__':
    main()
    