import torch.multiprocessing as mp

from ICLR_darkroom.utils import build_miniworld_data_filename
# Select the 'spawn' start method only when no start method has been fixed
# yet, so a choice made earlier in this process is left untouched.
_configured_start_method = mp.get_start_method(allow_none=True)
if _configured_start_method is None:
    mp.set_start_method('spawn', force=True)
import argparse
import os
import torch
import numpy as np
from common_args import *
import random
from tqdm import tqdm
import transformers
from easydict import EasyDict
from dataset import Dataset_DIT as Dataset
# Seed the RNGs with a fixed value as soon as this module is imported.
# The main block below seeds torch / numpy / random again from --seed once
# the command-line arguments have been parsed.
transformers.set_seed(42)
from torch.utils.tensorboard import SummaryWriter
from net import reward_learner

if __name__ == '__main__':
    # Script entry point: parse arguments, build the train/eval loaders,
    # train a `reward_learner` with an MSE objective (plus an optional
    # Bellman-consistency term after epoch 50), log everything to
    # TensorBoard and write a checkpoint of the model after every epoch.

    # ---- args ----
    # exist_ok=True already tolerates an existing directory, so no separate
    # os.path.exists() check is needed.
    os.makedirs('models', exist_ok=True)
    parser = argparse.ArgumentParser()
    add_dataset_args(parser)
    add_model_args(parser)
    add_train_args(parser)
    parser.add_argument('--seed', type=int, default=42)
    args = EasyDict(vars(parser.parse_args()))
    print("Args: ", args)
    device = torch.device(f'{args.device}' if torch.cuda.is_available() else 'cpu')

    # Seed every RNG used by the script and ask cuDNN for deterministic
    # kernels (benchmark autotuning is switched off for the same reason).
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(args.seed)
    random.seed(args.seed)

    # ---- dataset ----
    dataset_config = {
        'device': device,
        'horizon': args.horizon,
        'gamma': args.gamma,
    }
    # `n_trails` is overridden per environment before the training file name
    # is built; any other environment keeps the value parsed from the CLI.
    # presumably build_data_filename encodes n_trails in the path - confirm.
    if args.env == 'ml1_pick_place':
        args.n_trails = 1000
    elif args.env == 'cheetah_vel':
        args.n_trails = 2000
    elif args.env == 'darkroom':
        args.n_trails = 100
    path_train = build_data_filename('train', args)
    # The evaluation file is always resolved with n_trails = 4.
    args.n_trails = 4
    path_eval = build_data_filename('eval', args)
    train_dataset = Dataset(path_train, dataset_config)
    eval_dataset = Dataset(path_eval, dataset_config)
    params = {'batch_size': 64, 'shuffle': True}
    train_loader = torch.utils.data.DataLoader(train_dataset, **params)
    params.update({'shuffle': False})
    eval_loader = torch.utils.data.DataLoader(eval_dataset, **params)

    # ---- model ----
    model_config = {
        'dropout': args.dropout,
        'horizon': args.horizon,
        'n_embd': args.embd,
        'n_layer': args.layer,
        'n_head': args.head,
        'state_dim': args.state_dim,
        'action_dim': args.action_dim,
        'device': device,
        'test': False,
        'type': args.type,
    }
    model = reward_learner(model_config).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)

    # ---- loss / logging ----
    loss_fn = torch.nn.MSELoss()
    # Frozen snapshot of `model` used to build Bellman targets. It stays None
    # (Bellman term disabled) until the first refresh below.
    bellman_model = None
    log_name = f'H{args.horizon}_g{args.gamma}_q{args.policy_quality}_{args.type}'
    # The latest weights are written here after every epoch so that a
    # finished (or interrupted) run leaves a usable model behind.
    checkpoint_path = f'models/reward_learner_{log_name}.pt'
    os.makedirs(f'logs_reward/{log_name}', exist_ok=True)
    writer = SummaryWriter(log_dir=f'logs_reward/{log_name}')
    train_global_step = 0
    test_global_step = 0

    # try/finally guarantees the TensorBoard writer is flushed and closed even
    # when training is interrupted or raises.
    try:
        for epoch in range(args.num_epochs):

            # TRAINING
            # `test` is a flag read by reward_learner; its exact effect is
            # defined in net.py.
            model.test = False
            model.train()

            # From epoch 60 on, every 10th epoch re-snapshots the current
            # weights into a fresh target network. Between refreshes the
            # previous snapshot keeps being used.
            if epoch > 50 and epoch % 10 == 0:
                bellman_model = reward_learner(model_config).to(device)
                bellman_model.load_state_dict(model.state_dict())
                bellman_model.test = False
                bellman_model.eval()

            for batch in tqdm(train_loader, desc=f'training on Epoch {epoch}', total=len(train_loader)):
                batch = {k: v.to(device) for k, v in batch.items()}

                # Prediction loss against the cumulative rewards.
                pred_rewards = model(batch)  # (batch_size, horizon, 1)
                true_rewards = batch['cumulative_rewards'][..., None]
                optimizer.zero_grad()
                loss = loss_fn(pred_rewards, true_rewards)
                writer.add_scalar('Loss/loss_pred', loss.item(), train_global_step)

                # Bellman loss (only once a target snapshot exists).
                if bellman_model is not None:
                    with torch.no_grad():
                        bellman_rewards = bellman_model(batch)
                    # Replace the first time step of the target prediction by
                    # zeros. zeros_like allocates directly on the right device
                    # with the right dtype instead of copying from the CPU.
                    # NOTE(review): slots 1.. keep their original time index,
                    # so this is not a one-step shift; whether the target
                    # should use [:, 1:] + trailing zero or leading zero +
                    # [:, :-1] depends on how cumulative_rewards is defined
                    # in dataset.py - confirm.
                    bellman_rewards = torch.cat(
                        (torch.zeros_like(bellman_rewards[:, :1]), bellman_rewards[:, 1:]),
                        dim=1,
                    )
                    # assumes context_rewards is (batch_size, horizon, 1);
                    # a (batch_size, horizon) tensor would broadcast - TODO confirm.
                    bellman_target = batch['context_rewards'] + args.gamma * bellman_rewards
                    bellman_loss = loss_fn(pred_rewards, bellman_target)
                    writer.add_scalar('Loss/loss_bellman', bellman_loss.item(), train_global_step)
                    loss = loss + bellman_loss

                # Every 100 steps, log a figure comparing prediction and target.
                if train_global_step % 100 == 0:
                    fig = draw_pred(
                        pred_rewards.detach().cpu().numpy(),
                        true_rewards.detach().cpu().numpy(),
                    )
                    writer.add_figure('train_prediction', fig, train_global_step)

                loss.backward()
                optimizer.step()
                writer.add_scalar('Loss/train', loss.item(), train_global_step)
                train_global_step += 1

            # EVALUATION
            # Only prediction figures are logged here, one per eval batch.
            model.test = True
            model.eval()
            with torch.no_grad():
                for batch in tqdm(eval_loader, desc=f'evaluating on Epoch {epoch}', total=len(eval_loader)):
                    batch = {k: v.to(device) for k, v in batch.items()}
                    pred_rewards = model(batch)
                    true_rewards = batch['cumulative_rewards'][..., None]
                    fig = draw_pred(pred_rewards.cpu().numpy(), true_rewards.cpu().numpy())
                    writer.add_figure('eval_prediction', fig, test_global_step)
                    test_global_step += 1

            # Persist the latest weights (overwrites the previous epoch's file).
            torch.save(model.state_dict(), checkpoint_path)
    finally:
        writer.close()

    print("Done.")
    