import torch.multiprocessing as mp
if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method('spawn', force=True)
import argparse
import os
import torch
import pickle
import numpy as np
from common_args import *
import random
from tqdm import tqdm
import transformers
from easydict import EasyDict
from dataset import Dataset_DIT as Dataset
transformers.set_seed(42)
from torch.utils.tensorboard import SummaryWriter
from net import action_learner
from eval import evaluation


# Per-env value assigned to args.n_trails before the training data file name is
# built (presumably the number of trials stored in that file — confirm against
# build_data_filename).
_TRAIN_N_TRAILS = {
    'ml1_pick_place': 1000,
    'cheetah_vel': 2000,
    'darkroom': 100,
}
# args.n_trails value used when resolving the evaluation data files.
_EVAL_N_TRAILS = 4
# Evaluate (and possibly checkpoint on offline reward) every this many epochs.
_EVAL_EVERY = 10


def _parse_args():
    """Parse command-line arguments in two passes.

    The env-specific dataset arguments (and ``--seed``) are only registered once
    ``--env`` is known, so the first pass must tolerate options the parser has
    not seen yet.

    Returns:
        EasyDict holding every parsed argument.
    """
    parser = argparse.ArgumentParser()
    add_dataset_args(parser)
    add_model_args(parser)
    add_train_args(parser)
    # parse_known_args: a plain parse_args() here aborts with "unrecognized
    # arguments" whenever an env-specific flag or --seed is on the command
    # line, because those options are only added below.
    known, _ = parser.parse_known_args()
    if known.env == 'ml1_pick_place':
        add_ml1_pick_place_dataset_args(parser)
    elif known.env == 'cheetah_vel':
        add_cheetah_vel_dataset_args(parser)
    elif known.env == 'darkroom':
        add_darkroom_dataset_args(parser)
    parser.add_argument('--seed', type=int, default=42)
    return EasyDict(vars(parser.parse_args()))


def _seed_everything(seed):
    """Seed torch (CPU and, if available, all CUDA devices), numpy and ``random``; force deterministic cuDNN."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)


def _make_log_name(args):
    """Return the run name: darkroom runs are keyed by ``p``, other envs by ``policy_quality``."""
    if args.env == 'darkroom':
        return f'env_{args.env}_H{args.horizon}_g{args.gamma}_p{args.p}_BC'
    return f'env_{args.env}_H{args.horizon}_g{args.gamma}_q{args.policy_quality}_BC'


def _load_eval_trajs(args):
    """Load the two pickled evaluation sets, in order [random, expert].

    NOTE: mutates ``args.p`` (darkroom) or ``args.policy_quality`` (other envs),
    because ``build_data_filename`` derives the file name from ``args``. After
    the call they hold the last (expert) value, which is what ``evaluation``
    later receives.
    """
    if args.env == 'darkroom':
        attr, values = 'p', [0.0, 1.0]
    else:
        attr, values = 'policy_quality', ['20', 'best']
    eval_trajs = []
    for value in values:
        setattr(args, attr, value)
        # Trusted local data files; pickle must never be used on external input.
        with open(build_data_filename('eval', args), 'rb') as f:
            eval_trajs.append(pickle.load(f))
    return eval_trajs


def _train_one_epoch(model, train_loader, optimizer, loss_fn, device, horizon,
                     writer, epoch, global_step):
    """Run one training epoch and log the per-step loss to TensorBoard.

    Returns:
        Tuple ``(mean loss over batches, updated global step)``.
    """
    model.test = False
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc=f'training on Epoch {epoch}', total=len(train_loader)):
        batch = {k: v.to(device) for k, v in batch.items()}
        # For each sample, select the context step given by query_idx.
        batch_idx = torch.arange(batch['context_states'].shape[0], device=device)
        batch['query_states'] = batch['context_states'][batch_idx, batch['query_idx']]
        pred_actions = model(batch)  # (batch_size, horizon, action_dim) per original comment
        true_actions = batch['context_actions'][batch_idx, batch['query_idx']]
        # The same target action for every horizon step; expand avoids the
        # memory copy that repeat() made.
        true_actions = true_actions[..., None, :].expand(-1, horizon, -1)
        optimizer.zero_grad()
        loss = loss_fn(pred_actions, true_actions)
        loss.backward()
        optimizer.step()
        loss_value = loss.item()  # one host/device sync per step instead of two
        writer.add_scalar('Loss/train', loss_value, global_step)
        total_loss += loss_value
        global_step += 1
    return total_loss / len(train_loader), global_step


def _evaluate_and_checkpoint(model, eval_trajs, device, args, writer, log_dir,
                             best_offline_reward):
    """Evaluate the model and checkpoint it when the mean offline reward improves.

    On improvement, saves ``offline_sota.pt`` and pickles all reward results to
    ``rewards_sota.pkl`` in ``log_dir``.

    Returns:
        The (possibly updated) best mean offline reward.
    """
    model.eval()
    model.test = True
    online_rewards, offline_rewards_random, offline_rewards_expert, optimal_rewards = evaluation(
        eval_trajs, model, device, 1, args, writer=writer)
    print(f'online_rewards: {online_rewards}')
    print(f'offline_rewards_random: {offline_rewards_random}')
    print(f'offline_rewards_expert: {offline_rewards_expert}')
    print(f'optimal_rewards: {optimal_rewards}')
    # Assumes the offline reward results support .mean() (numpy arrays / tensors) — TODO confirm.
    offline_rewards_mean = (offline_rewards_expert.mean() + offline_rewards_random.mean()) / 2
    if offline_rewards_mean > best_offline_reward:
        torch.save(model.state_dict(), f'{log_dir}/offline_sota.pt')
        rewards = {
            'online_rewards': online_rewards,
            'offline_rewards_random': offline_rewards_random,
            'offline_rewards_expert': offline_rewards_expert,
            'optimal_rewards': optimal_rewards,
        }
        with open(f'{log_dir}/rewards_sota.pkl', 'wb') as f:
            pickle.dump(rewards, f)
        return offline_rewards_mean
    return best_offline_reward


def main():
    """Train ``action_learner`` to regress the context action at ``query_idx`` (MSE loss).

    Saves the lowest-train-loss weights every epoch they improve, and evaluates
    every ``_EVAL_EVERY`` epochs, saving the best-offline-reward weights.
    """
    '''args'''
    os.makedirs('models', exist_ok=True)
    args = _parse_args()
    print("Args: ", args)
    device = torch.device(f'{args.device}')
    _seed_everything(args.seed)
    log_dir = f'logs_action/{_make_log_name(args)}'
    os.makedirs(log_dir, exist_ok=True)
    writer = SummaryWriter(log_dir=log_dir)

    try:
        '''dataset'''
        dataset_config = {
            'device': device,
            'horizon': args.horizon,
            # NOTE(review): hard-coded, while the run name encodes args.gamma —
            # confirm this is intended.
            'gamma': 0.8,
        }
        if args.env in _TRAIN_N_TRAILS:
            args.n_trails = _TRAIN_N_TRAILS[args.env]
        train_dataset = Dataset(build_data_filename('train', args), dataset_config)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
        # Eval files are resolved with a different n_trails; this mutation is
        # intentionally left on args (evaluation() receives it).
        args.n_trails = _EVAL_N_TRAILS
        eval_trajs = _load_eval_trajs(args)  # [random, expert]

        '''model'''
        model_config = {
            'test': False,
            'horizon': args.horizon,
            'n_embd': args.embd,
            'n_layer': args.layer,
            'n_head': args.head,
            'state_dim': args.state_dim,
            'action_dim': args.action_dim,
            'dropout': args.dropout,
            'device': device,
        }
        model = action_learner(model_config).to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)

        '''loss'''
        loss_fn = torch.nn.MSELoss()

        global_step = 0
        best_train_loss = 1e9
        best_offline_reward = -1e9
        for epoch in range(args.num_epochs):
            mean_train_loss, global_step = _train_one_epoch(
                model, train_loader, optimizer, loss_fn, device, args.horizon,
                writer, epoch, global_step)
            if mean_train_loss < best_train_loss:
                torch.save(model.state_dict(), f'{log_dir}/train_sota.pt')
                best_train_loss = mean_train_loss

            if (epoch + 1) % _EVAL_EVERY == 0:
                best_offline_reward = _evaluate_and_checkpoint(
                    model, eval_trajs, device, args, writer, log_dir,
                    best_offline_reward)
    finally:
        # Flush TensorBoard events even if training or evaluation raises.
        writer.close()
    print("Done.")


if __name__ == '__main__':
    main()
    