import torch.multiprocessing as mp
# Select the 'spawn' start method for torch.multiprocessing, but only when no
# start method has been chosen yet (get_start_method(allow_none=True) returns
# None in that case). This runs at import time, before the remaining imports.
# NOTE(review): the DataLoaders in this script use num_workers=0, so this
# presumably only matters if worker processes are enabled later — confirm.
if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method('spawn', force=True)  # or 'forkserver'

import argparse
import os
import time
import torch
import numpy as np
import random
import pickle
from tqdm import tqdm


class MetaWorldDataset(torch.utils.data.Dataset):
    """Dataset of MetaWorld trajectories with preferred / non-preferred action labels.

    The pickle at ``data_path`` must hold a sequence of trajectory dicts, each
    with the keys ``'context_states'``, ``'context_next_states'``,
    ``'preferred_actions'`` and ``'non_preferred_actions'``, whose values are
    per-step sequences of vectors. Every item is truncated to ``horizon`` steps.

    Args:
        data_path: Path to the pickled sequence of trajectory dicts.
        horizon: Maximum number of context steps kept per trajectory.

    Raises:
        ValueError: If the file contains no trajectories.
    """

    def __init__(self, data_path, horizon):
        # SECURITY NOTE: pickle.load can execute arbitrary code stored in the
        # file; only point this at dataset files you trust.
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)
        self.length = len(self.data)
        if self.length == 0:
            # Without this guard the dimension probe below dies with a bare IndexError.
            raise ValueError(f'no trajectories found in {data_path!r}')
        # Vector sizes come from the last axis of the first step of the first trajectory.
        sample = self.data[0]
        self.state_dim = sample['context_states'][0].shape[-1]
        self.action_dim = sample['preferred_actions'][0].shape[-1]
        self.horizon = horizon

    def __len__(self):
        return self.length

    def _sequence(self, trajectory, key):
        """Return ``trajectory[key]`` truncated to ``self.horizon`` steps as a float32 tensor."""
        # Truncate before converting so steps past the horizon are never copied.
        # torch.tensor always copies, so the result never aliases self.data.
        return torch.tensor(np.asarray(trajectory[key][:self.horizon]), dtype=torch.float32)

    def __getitem__(self, idx):
        """Return one training example as a dict of float32 tensors.

        Assuming each step is a 1-D vector, sequence entries have shape
        ``(T, dim)`` with ``T = min(len(trajectory), horizon)``. One step is
        sampled uniformly (using the global torch RNG) as the query:
        ``'query_states'`` is that step's state with shape ``(state_dim,)``,
        and the two ``*_action_label`` entries are that step's actions with
        shape ``(1, action_dim)``.
        """
        trajectory = self.data[idx]
        context_states = self._sequence(trajectory, 'context_states')
        context_next_states = self._sequence(trajectory, 'context_next_states')
        preferred_actions = self._sequence(trajectory, 'preferred_actions')
        non_preferred_actions = self._sequence(trajectory, 'non_preferred_actions')
        # Zero vector passed through to the model, sized state_dim**2 + action_dim + 1
        # (the original author notes this is larger than strictly required).
        zeros = torch.zeros(context_states.shape[-1] ** 2 + preferred_actions.shape[-1] + 1)
        # Sample a query step; indexing with a 1-element tensor keeps a leading axis of size 1.
        choice = torch.randint(0, context_states.shape[0], (1,))
        query_state = context_states[choice]
        preferred_action_label = preferred_actions[choice]
        non_preferred_action_label = non_preferred_actions[choice]

        return {
            'context_states': context_states,
            'context_next_states': context_next_states,
            'preferred_actions': preferred_actions,
            'non_preferred_actions': non_preferred_actions,
            'zeros': zeros,
            # squeeze(0), not squeeze(): drop only the sampling axis so that a
            # state_dim of 1 still yields shape (1,) rather than a 0-d scalar.
            'query_states': query_state.squeeze(0),
            'preferred_action_label': preferred_action_label,
            'non_preferred_action_label': non_preferred_action_label,
        }


def build_metaworld_data_filename(n_tasks, n_trajs, p_good, p_bad, mode=0):
    """Return the relative path of a pickled MetaWorld dataset split.

    ``mode`` selects the split suffix: 0 -> ``train``, 1 -> ``test``, and
    any other value -> ``eval``.
    """
    if mode == 0:
        split = 'train'
    elif mode == 1:
        split = 'test'
    else:
        split = 'eval'
    stem = f'metaworld_tasks{n_tasks}_trajs{n_trajs}_pg{p_good}_pb{p_bad}_{split}'
    return f'datasets/{stem}.pkl'


def parse_args():
    """Parse command-line arguments for training.

    Task lists are given as space-separated integers, e.g.
    ``--train_tasks 1 2 3``. Hyphenated aliases (``--train-tasks``,
    ``--eval-tasks``) are accepted as well, matching the other options.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser(description="Train a Decision Preference Transformer on MetaWorld data")

    # Dataset parameters.
    # type=int + nargs='+' parses "--train_tasks 1 12" as [1, 12]
    # (type=list would split a CLI value into single characters: "12" -> ['1', '2']).
    parser.add_argument('--train_tasks', '--train-tasks', type=int, nargs='+',
                        # Every task id in 0..49 except the multiples of 10 used as eval tasks.
                        default=[task for task in range(50) if task % 10 != 0],
                        help="Tasks to train on")
    parser.add_argument('--eval_tasks', '--eval-tasks', type=int, nargs='+',
                        default=[0, 10, 20, 30, 40], help="Tasks to eval on")
    parser.add_argument('--max-episode-steps', type=int, default=200, help="Max episode steps")
    parser.add_argument('--n-trajs', type=int, default=1000, help="Number of trajectories to collect for each task")
    parser.add_argument('--p-good', type=int, default=100, help="Good policy probability")
    parser.add_argument('--p-bad', type=int, default=20, help="Bad policy probability")

    # Model parameters
    parser.add_argument('--embedding-dim', type=int, default=256, help="Embedding dimension")
    parser.add_argument('--num-layers', type=int, default=6, help="Number of transformer layers")
    parser.add_argument('--num-heads', type=int, default=8, help="Number of attention heads")

    # Training parameters
    parser.add_argument('--lambda-pr', type=float, default=10, help="Lambda for preference loss")
    parser.add_argument('--beta', type=float, default=0.01, help="Beta for preference loss")
    parser.add_argument('--learning-rate', type=float, default=1e-4, help="Learning rate")
    parser.add_argument('--batch-size', type=int, default=64, help="Batch size")
    parser.add_argument('--num-epochs', type=int, default=200, help="Number of training epochs")
    parser.add_argument('--seed', type=int, default=42, help="Random seed")

    # Evaluation parameters
    parser.add_argument('--eval-interval', type=int, default=10, help="Evaluation interval")
    parser.add_argument('--Hpes', type=int, default=50, help="Horizon")

    return parser.parse_args()


def construct_pref_batch(batch, model):
    """Score the model's per-step predictions against both action labels.

    The single preferred / non-preferred action label of every sample is tiled
    along axis 1 (one copy per context step), and the element-wise squared
    error between each tiled label and ``model(batch)`` is returned unreduced.

    Returns:
        ``(mse_pr, mse_npr)``: unreduced squared errors with respect to the
        preferred and non-preferred labels.
    """
    squared_error = torch.nn.functional.mse_loss
    steps_pr = batch['preferred_actions'].shape[1]
    steps_npr = batch['non_preferred_actions'].shape[1]
    # Labels are (batch, 1, action_dim) as built by MetaWorldDataset, so repeating
    # along axis 1 yields one label copy per context step.
    tiled_pr = batch['preferred_action_label'].repeat(1, steps_pr, 1)
    tiled_npr = batch['non_preferred_action_label'].repeat(1, steps_npr, 1)
    predictions = model(batch)  # expected (batch_size, horizon, action_dim)
    return (
        squared_error(tiled_pr, predictions, reduction='none'),
        squared_error(tiled_npr, predictions, reduction='none'),
    )


def pref_loss(mse_pr, mse_npr, beta=0.01, lambda_pr=10):
    """Preference loss ``-mean(log sigmoid(beta * (mse_npr - lambda_pr * mse_pr)))``.

    The loss is small when predictions are far from the non-preferred labels
    (large ``mse_npr``) and close to the preferred ones (small ``mse_pr``).

    Args:
        mse_pr: Element-wise squared errors against the preferred labels.
        mse_npr: Element-wise squared errors against the non-preferred labels
            (same shape as ``mse_pr``).
        beta: Temperature applied to the logit.
        lambda_pr: Weight of the preferred-label error.

    Returns:
        Scalar tensor: the mean over all elements.
    """
    logits = beta * (mse_npr - lambda_pr * mse_pr)  # the bigger the better
    # logsigmoid is the numerically stable form of log(sigmoid(x)): for large
    # negative logits sigmoid underflows to 0 and log(0) = -inf would make the
    # loss infinite.
    return -torch.nn.functional.logsigmoid(logits).mean()  # the smaller the better


def eval_epoch(model, test_loader, args, device='cuda'):
    """Return the mean preference loss of ``model`` over ``test_loader``.

    The model is switched to eval mode for the pass (disabling dropout and
    other train-only behavior), and its previous train/eval mode is restored
    afterwards, even if an exception is raised. No gradients are tracked.

    Args:
        model: Module invoked as ``model(batch)`` via ``construct_pref_batch``.
        test_loader: Iterable of dict batches of tensors.
        args: Namespace providing ``beta`` and ``lambda_pr`` for ``pref_loss``.
        device: Device every batch tensor is moved to (default ``'cuda'``).

    Returns:
        Mean per-batch loss, or 0.0 if the loader yields no batches.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        with torch.no_grad():
            for batch in test_loader:
                # non_blocking lets copies from pinned host memory overlap compute.
                batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
                mse_pr, mse_npr = construct_pref_batch(batch, model)
                loss = pref_loss(mse_pr, mse_npr, beta=args.beta, lambda_pr=args.lambda_pr)
                # .item() blocks until the device result is ready, so no explicit
                # torch.cuda.synchronize() is needed (and none is issued on CPU).
                total_loss += loss.item()
                num_batches += 1
    finally:
        model.train(was_training)
    # max(1, ...) avoids division by zero for an empty loader.
    return total_loss / max(1, num_batches)


def train_epoch(model, train_loader, optimizer, args, device='cuda', clip_value=None, empty_cache_every=10):
    """Run one optimization pass over ``train_loader`` and return the mean loss.

    Args:
        model: Module invoked as ``model(batch)`` via ``construct_pref_batch``.
        train_loader: Iterable of dict batches of tensors.
        optimizer: Optimizer over ``model``'s parameters.
        args: Namespace providing ``beta`` and ``lambda_pr`` for ``pref_loss``.
        device: Device every batch tensor is moved to (default ``'cuda'``).
        clip_value: If not None, clamp every gradient element to
            ``[-clip_value, clip_value]`` before each optimizer step.
        empty_cache_every: Release cached CUDA memory (and synchronize) every
            this many batches; 0 or None disables it. It is skipped entirely
            when CUDA is unavailable.

    Returns:
        Mean per-batch training loss, or 0.0 if the loader yields no batches.
    """
    # Ensure train-mode behavior (e.g. dropout) regardless of the model's prior mode.
    model.train()
    progress_bar = tqdm(train_loader)
    total_loss = 0.0
    num_batches = 0

    for batch_id, batch in enumerate(progress_bar):
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        mse_pr, mse_npr = construct_pref_batch(batch, model)
        optimizer.zero_grad()
        loss = pref_loss(mse_pr, mse_npr, beta=args.beta, lambda_pr=args.lambda_pr)
        loss.backward()
        if clip_value is not None:
            torch.nn.utils.clip_grad_value_(model.parameters(), clip_value)
        optimizer.step()
        loss_value = loss.item()
        total_loss += loss_value
        num_batches += 1

        # Refreshing the progress bar text every batch is needlessly noisy.
        if batch_id % 200 == 0:
            progress_bar.set_description(f'{loss_value:.2f}')

        # Optional periodic release of cached CUDA blocks (kept from the original
        # script); guarded so CPU-only machines do not crash.
        if empty_cache_every and batch_id % empty_cache_every == 0 and torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    # max(1, ...) avoids division by zero for an empty loader.
    return total_loss / max(1, num_batches)


if __name__ == '__main__':
    # Parse arguments
    args = parse_args()
    print("Args: ", args)

    # Create the checkpoint directory; exist_ok makes this a no-op if present.
    os.makedirs('models', exist_ok=True)

    # Seed every RNG used here (torch CPU/CUDA, numpy, stdlib random) and force
    # deterministic cuDNN kernels for reproducibility.
    seed = args.seed
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

    # Build data paths.
    # NOTE: the test split is built from the *training* tasks, so train and test
    # share the same task ids. args.eval_tasks is not used by this script.
    train_task_ids = args.train_tasks
    test_task_ids = args.train_tasks
    train_path = build_metaworld_data_filename(
        len(train_task_ids), args.n_trajs, args.p_good, args.p_bad, mode=0)
    # The test file is looked up with a fixed trajectory count of 10.
    test_path = build_metaworld_data_filename(
        len(test_task_ids), 10, args.p_good, args.p_bad, mode=1)

    # Load datasets; the data stays on the CPU and batches are moved per step.
    print("Loading datasets...")
    train_dataset = MetaWorldDataset(train_path, args.max_episode_steps)
    test_dataset = MetaWorldDataset(test_path, args.max_episode_steps)
    print("finished loading datasets")

    # Transformer configuration; the context horizon is the max episode length.
    config = {
        'horizon': args.max_episode_steps,
        'state_dim': train_dataset.state_dim,
        'action_dim': train_dataset.action_dim,
        'n_layer': args.num_layers,
        'n_embd': args.embedding_dim,
        'n_head': args.num_heads,
        'dropout': False,
        'test': False,
    }

    # Project-local model definition.
    from net import Transformer
    model = Transformer(config).to('cuda')

    # num_workers=0 keeps data loading in the main process (avoids
    # multiprocessing issues); pin_memory speeds up host-to-GPU copies.
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, num_workers=0, shuffle=True, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=args.batch_size, num_workers=0, shuffle=False, pin_memory=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)

    checkpoint_path = f'models/h{args.max_episode_steps}_lpr{args.lambda_pr}_beta{args.beta}.pt'
    test_losses, train_losses = [], []
    best_test_loss = float('inf')

    for epoch in tqdm(range(args.num_epochs)):
        print(f'============ Epoch:{epoch} ============')

        # Evaluate the current weights before this epoch's update, so epoch 0
        # reports the loss of the untrained model.
        model.test = False
        test_loss = eval_epoch(model, test_loader, args)
        print(f'test loss:{test_loss}')
        test_losses.append(test_loss)

        # Checkpoint now, while the weights are exactly the ones that produced
        # test_loss. Saving after train_epoch would store weights one update
        # ahead of the loss that selected them.
        if test_loss <= best_test_loss:
            best_test_loss = test_loss
            print('best model found in epoch: ', epoch, 'with test loss: ', test_loss)
            torch.save(model, checkpoint_path)

        # Train model
        train_loss = train_epoch(model, train_loader, optimizer, args)
        print(f'train loss:{train_loss}')
        train_losses.append(train_loss)

    # The loop evaluates before training, so the last epoch's update has not
    # been scored yet; evaluate (and possibly checkpoint) it here.
    if args.num_epochs > 0:
        model.test = False
        final_test_loss = eval_epoch(model, test_loader, args)
        print(f'final test loss:{final_test_loss}')
        if final_test_loss <= best_test_loss:
            best_test_loss = final_test_loss
            print('best model found after final epoch with test loss: ', final_test_loss)
            torch.save(model, checkpoint_path)