import os
import yaml
import pickle
import argparse
from copy import deepcopy

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Command-line options. Parsing runs at module import time, ahead of the
# heavy imports further down, so the GPU selection can be exported first.
parser = argparse.ArgumentParser()
parser.add_argument('--env-name', default='hopper-medium-v2', help='task environment name')
parser.add_argument('--force-generate', action='store_true', help='force regenerate dataset')
parser.add_argument('--gpu', default=0, type=int, help='gpu number')
args = parser.parse_args()

# Restrict CUDA to the requested GPU index. This is deliberately done before
# `import torch` below — presumably so the restriction is in place before
# CUDA is initialised; keep this ordering when editing the imports.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
print('Training on:', args.env_name)
print('Using GPU:', args.gpu)
print('\n')

import gym
import d4rl
import torch
import numpy as np
from tqdm import tqdm

from dataset import load_dataset
from utils import load_config, RewardModel, RL, DPO
from models import load_policy, MLP, FlowPolicy
from preflow import FlowMatching


def train_rlhf(config, device='cpu'):
    """Fit a reward model on the dataset, then fine-tune the policy with RL.

    Loads the dataset and the policy checkpoint at ``config.policy.path``,
    fits a ``RewardModel``, runs ``RL.train`` using that reward model, and
    writes the RL losses, the reward-model losses and the fine-tuned policy
    weights to disk under file names derived from ``args.env_name``.

    Args:
        config: Experiment configuration. The ``data``, ``env``, ``policy``,
            ``reward_model`` and ``rlhf`` sections are read here.
        device: Device string forwarded to the dataset, policy, reward model
            and RL trainer.

    Returns:
        None. Results are persisted to ``./results/losses/npy`` and
        ``./weights/rlhf``.
    """
    print('Training RLHF...')

    # Create the output directories before any training starts: np.save and
    # torch.save do not create missing parent directories, and failing at the
    # very end would throw away the whole training run.
    losses_dir = './results/losses/npy'
    weights_dir = './weights/rlhf'
    os.makedirs(losses_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)

    dataset = load_dataset(
        config,
        seg_len=config.data.seg_len,
        device=device,
        force_generate=args.force_generate,
        return_state=True,
    )
    policy = load_policy(
        config,
        path=config.policy.path,
        device=device,
    )

    # Stage 1: reward model. Only the loss history (second return value) is
    # kept; the fitted model itself is the `reward_model` object.
    reward_model = RewardModel(
        config.env.state_dim,
        config.env.action_dim,
        hidden_dim=config.reward_model.hidden_dim,
        num_layers=config.reward_model.num_layers,
        device=device,
    )
    _, reward_losses = reward_model.fit(
        dataset,
        num_epochs=config.reward_model.num_epochs,
    )

    # Stage 2: RL fine-tuning of the loaded policy against the reward model.
    rlhf = RL(policy, reward_model, device=device)
    finetuned_policy, losses = rlhf.train(
        dataset,
        num_epochs=config.rlhf.num_epochs,
        batch_size=config.rlhf.batch_size,
        learning_rate=config.rlhf.learning_rate,
        beta=config.rlhf.beta,
    )

    print('Training complete!')
    np.save(f'{losses_dir}/{args.env_name}_rlhf_losses.npy', losses)
    np.save(f'{losses_dir}/{args.env_name}_reward_losses.npy', reward_losses)
    torch.save(finetuned_policy.state_dict(), f'{weights_dir}/{args.env_name}.pth')


def train_dpo(config, device='cpu'):
    """Fine-tune the loaded policy with DPO and save losses and weights.

    Loads the dataset and the policy checkpoint at ``config.policy.path``,
    runs ``DPO.fit`` and writes the loss history and the fine-tuned policy
    weights to disk under file names derived from ``args.env_name``.

    Args:
        config: Experiment configuration. The ``data``, ``policy`` and
            ``dpo`` sections are read here.
        device: Device string forwarded to the dataset, policy and trainer.

    Returns:
        None. Results are persisted to ``./results/losses/npy`` and
        ``./weights/dpo``.
    """
    print('Training DPO...')

    # Create the output directories before training: np.save and torch.save
    # do not create missing parent directories, and failing after training
    # would discard the results.
    losses_dir = './results/losses/npy'
    weights_dir = './weights/dpo'
    os.makedirs(losses_dir, exist_ok=True)
    os.makedirs(weights_dir, exist_ok=True)

    dataset = load_dataset(
        config,
        seg_len=config.data.seg_len,
        device=device,
        force_generate=args.force_generate,
        return_state=True,
    )
    policy = load_policy(
        config,
        path=config.policy.path,
        device=device,
    )

    # Pass an independent copy as the second model so it is not the very same
    # object as the policy being optimised (previously `policy` was passed
    # twice, so both arguments shared parameters).
    # NOTE(review): assumes DPO's second positional argument is the reference
    # policy — confirm against utils.DPO.
    reference_policy = deepcopy(policy)
    dpo = DPO(policy, reference_policy, device=device)
    finetuned_policy, losses = dpo.fit(
        dataset,
        num_epochs=config.dpo.num_epochs,
        batch_size=config.dpo.batch_size,
        learning_rate=config.dpo.learning_rate,
        beta=config.dpo.beta,
    )

    print('Training complete!')
    np.save(f'{losses_dir}/{args.env_name}_dpo_losses.npy', losses)
    torch.save(finetuned_policy.state_dict(), f'{weights_dir}/{args.env_name}.pth')


def train_flow_matching(config, device='cpu'):
    """Train the conditional flow-matching model and save its loss history.

    Loads the dataset (with ``return_state=False``), builds a time-varying
    ``MLP``, wraps it in ``FlowMatching`` and calls ``fit`` with a
    per-element loss weighting. The checkpoint path and loss path handed to
    ``fit`` are derived from ``args.env_name``; the returned losses are also
    saved once more after training finishes.

    Args:
        config: Experiment configuration. Reads ``data.seg_len``,
            ``env.action_dim``, the ``model`` section, ``loss_scale``,
            ``num_epochs``, ``batch_size`` and ``learning_rate``.
        device: Device string forwarded to the dataset, model and trainer.

    Returns:
        None. Results are persisted to ``config.model.save_dir`` and
        ``./results/losses/npy``.
    """
    print('Training Flow Matching...')

    # Make sure both output locations exist before training starts; file
    # writes do not create missing parent directories, and a failure late in
    # the run would waste the training time.
    losses_dir = './results/losses/npy'
    os.makedirs(losses_dir, exist_ok=True)
    if config.model.save_dir:
        # An empty save_dir means "current directory"; nothing to create.
        os.makedirs(config.model.save_dir, exist_ok=True)
    losses_path = f'{losses_dir}/{args.env_name}_flow_losses.npy'

    dataset = load_dataset(
        config,
        seg_len=config.data.seg_len,
        device=device,
        force_generate=args.force_generate,
        return_state=False,
    )
    model = MLP(
        config.model.input_dim,
        hidden_dim=config.model.hidden_dim,
        context_dim=config.model.context_dim,
        time_varying=True,
    ).to(device)
    flow = FlowMatching(model, device=device)

    # Per-element loss weights: step i of the segment gets weight
    # config.loss_scale ** i, broadcast over the action dimensions and then
    # flattened to a vector of length seg_len * action_dim.
    step_weights = torch.tensor([
        config.loss_scale ** i
        for i in range(config.data.seg_len)
    ])
    loss_scale = torch.ones(config.data.seg_len, config.env.action_dim)
    loss_scale = (loss_scale * step_weights.unsqueeze(-1)).flatten().to(device)

    _, losses = flow.fit(
        dataset,
        num_epochs=config.num_epochs,
        batch_size=config.batch_size,
        learning_rate=config.learning_rate,
        loss_scale=loss_scale,
        conditional=True,
        save_path=os.path.join(config.model.save_dir, f'{args.env_name}.pth'),
        save_losses_path=losses_path,
    )

    print('Training complete!')
    # Write the final loss history to the same path that was handed to fit().
    np.save(losses_path, losses)


def main(env_name=None, beta=None):
    """Load the configuration and run the flow-matching training stage.

    Args:
        env_name: Optional environment name; when given it replaces the
            value parsed from ``--env-name`` on the module-level ``args``.
        beta: Optional value written to both ``config.rlhf.beta`` and
            ``config.dpo.beta`` after the configuration is loaded.
    """
    # An explicit argument takes precedence over the command-line value.
    if env_name is not None:
        args.env_name = env_name

    config = load_config(args)

    # One override is shared by the RLHF and DPO sections.
    if beta is not None:
        config.rlhf.beta = beta
        config.dpo.beta = beta

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    print('\n')
    print('Using device:', device)

    # Only the flow-matching stage is enabled; the other two stages are kept
    # here as commented-out toggles.
    # train_rlhf(config, device=device)
    # train_dpo(config, device=device)
    train_flow_matching(config, device=device)


# Script entry point: run main() with no overrides, so the environment name
# comes from the command-line arguments parsed at the top of the file.
if __name__ == "__main__":
    main()