import torch.multiprocessing as mp
# Select a multiprocessing start method only when none has been fixed yet, so
# a method already configured by the launching process is left untouched.
# This must run before anything below creates worker processes.
if mp.get_start_method(allow_none=True) is None:
    mp.set_start_method('spawn', force=True)  # or 'forkserver'

import os
import torch
import argparse
import numpy as np
import random
import pickle
from tqdm import tqdm
import torch.nn as nn
import transformers
# Fixed seed applied at import time through transformers.set_seed. The
# __main__ block seeds torch / numpy / random again from --seed before training.
transformers.set_seed(0)
from transformers import GPT2Config, GPT2Model

# Compute device: CUDA when torch reports it as available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Transformer(nn.Module):
    """GPT-2 based policy that predicts an action at every sequence position.

    Every transition is concatenated feature-wise, embedded with one linear
    layer, run through a ``GPT2Model`` and projected to ``action_dim`` outputs.
    Position 0 of the sequence holds the query state (its remaining features
    are zero-filled); the following positions hold the context transitions.

    ``model_type == 'DPTPR'`` selects the preference layout
    (state, preferred action, non-preferred action, next state); any other
    value selects the reward layout (state, action, next state, reward).
    """

    def __init__(self, config):
        """Build the backbone, the transition embedding and the action head.

        Args:
            config: Mapping with the keys ``test``, ``horizon``, ``n_embd``,
                ``n_layer``, ``n_head``, ``state_dim``, ``action_dim``,
                ``dropout`` and ``model_type``.
        """
        super().__init__()
        self.config = config
        self.test = config['test']
        self.horizon = config['horizon']
        self.n_embd = config['n_embd']
        self.n_layer = config['n_layer']
        self.n_head = config['n_head']
        self.state_dim = config['state_dim']
        self.action_dim = config['action_dim']
        self.dropout = config['dropout']
        self.model_type = config['model_type']

        # NOTE(review): the backbone is built with a single attention head;
        # config['n_head'] is stored on the module but never reaches
        # GPT2Config. Confirm whether that is intentional before wiring it
        # through, since doing so changes the trained architecture.
        gpt2_config = GPT2Config(
            n_positions=4 * (1 + self.horizon),
            n_embd=self.n_embd,
            n_layer=self.n_layer,
            n_head=1,
            resid_pdrop=self.dropout,
            embd_pdrop=self.dropout,
            attn_pdrop=self.dropout,
            use_cache=False,
        )
        self.transformer = GPT2Model(gpt2_config)

        # Width of one concatenated transition for the selected input layout.
        if self.model_type == 'DPTPR':
            transition_dim = 2 * self.state_dim + 2 * self.action_dim
        else:
            transition_dim = 2 * self.state_dim + self.action_dim + 1
        self.embed_transition = nn.Linear(transition_dim, self.n_embd)

        # The action head starts close to zero output.
        self.pred_actions = nn.Linear(self.n_embd, self.action_dim)
        nn.init.normal_(self.pred_actions.weight, mean=0.0, std=1e-3)
        nn.init.zeros_(self.pred_actions.bias)

    def forward(self, x):
        """Run the forward pass that matches ``self.model_type``.

        Dispatching at call time (instead of storing a bound method on the
        instance in ``__init__``) keeps the instance ``__dict__`` free of a
        reference to one particular object, so copies of the module that share
        or clone ``__dict__`` still run their own forward pass.
        """
        if self.model_type == 'DPTPR':
            return self.forward_DPTPR(x)
        return self.forward_DPT(x)

    @staticmethod
    def _prepend_zero_slot(zeros, width, context):
        """Put a zero row of ``width`` features in front of ``context`` (dim 1)."""
        return torch.cat([zeros[:, :, :width], context], dim=1)

    def _predict(self, seq):
        """Embed ``seq``, run the backbone and project to actions.

        Returns only the last position when ``self.test`` is true; otherwise
        returns every position except position 0 (the query slot).
        """
        stacked_inputs = self.embed_transition(seq)
        transformer_outputs = self.transformer(inputs_embeds=stacked_inputs)
        preds = self.pred_actions(transformer_outputs['last_hidden_state'])
        if self.test:
            return preds[:, -1, :]
        return preds[:, 1:, :]

    def forward_DPTPR(self, x):
        """Forward pass for the preference layout.

        Args:
            x: Mapping with ``query_states`` and ``zeros`` (indexed as 2-D,
                batch first) and ``context_states``, ``preferred_actions``,
                ``non_preferred_actions``, ``context_next_states``
                (concatenated along dim 1 as sequences).

        Returns:
            Predicted actions; see ``_predict`` for which positions are kept.
        """
        query_states = x['query_states'][:, None, :]
        zeros = x['zeros'][:, None, :]
        state_seq = torch.cat([query_states, x['context_states']], dim=1)
        pref_actions_seq = self._prepend_zero_slot(
            zeros, self.action_dim, x['preferred_actions'])
        npref_actions_seq = self._prepend_zero_slot(
            zeros, self.action_dim, x['non_preferred_actions'])
        next_state_seq = self._prepend_zero_slot(
            zeros, self.state_dim, x['context_next_states'])
        seq = torch.cat(
            [state_seq, pref_actions_seq, npref_actions_seq, next_state_seq], dim=2)
        return self._predict(seq)

    def forward_DPT(self, x):
        """Forward pass for the reward layout.

        Args:
            x: Mapping with ``query_states`` and ``zeros`` (indexed as 2-D,
                batch first) and ``context_states``, ``context_actions``,
                ``context_next_states``, ``context_rewards`` (concatenated
                along dim 1 as sequences).

        Returns:
            Predicted actions; see ``_predict`` for which positions are kept.
        """
        query_states = x['query_states'][:, None, :]
        zeros = x['zeros'][:, None, :]
        state_seq = torch.cat([query_states, x['context_states']], dim=1)
        actions = self._prepend_zero_slot(
            zeros, self.action_dim, x['context_actions'])
        next_state_seq = self._prepend_zero_slot(
            zeros, self.state_dim, x['context_next_states'])
        rewards = self._prepend_zero_slot(zeros, 1, x['context_rewards'])
        seq = torch.cat(
            [state_seq, actions, next_state_seq, rewards], dim=2)
        return self._predict(seq)



class MetaWorldDataset(torch.utils.data.Dataset):
    """Map-style dataset over pickled MetaWorld trajectories.

    Each item is one trajectory truncated to ``horizon`` steps, plus a query
    state sampled at random from that truncated context and the optimal action
    stored at the same step. ``model_type == 'DPTPR'`` yields preferred /
    non-preferred actions; any other value yields actions and rewards.
    """

    def __init__(self, data_path, horizon, model_type):
        """Load the pickled trajectories.

        Args:
            data_path: Path of the pickle file holding an indexable collection
                of trajectory mappings.
            horizon: Maximum number of context steps kept per trajectory.
            model_type: ``'DPTPR'`` for the preference layout, anything else
                for the reward layout.
        """
        # NOTE: pickle.load can execute arbitrary code embedded in the file;
        # only point this at trusted dataset files.
        with open(data_path, 'rb') as f:
            self.data = pickle.load(f)
        self.length = len(self.data)
        # Fixed feature sizes used when building the model configuration.
        self.state_dim = 39
        self.action_dim = 4
        self.horizon = horizon
        self.model_type = model_type

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        if self.model_type == 'DPTPR':
            return self.getitem_DPTPR(idx)
        return self.getitem_DPT(idx)

    def _field(self, trajectory, key):
        """Return ``trajectory[key]`` as a float32 tensor cut to ``horizon`` steps."""
        values = torch.tensor(np.asarray(trajectory[key]), dtype=torch.float32)
        return values[:self.horizon]

    @staticmethod
    def _sample_query(context_states, optimal_actions):
        """Pick one random context step; return its state and optimal action.

        The state loses only the sampled (leading) axis, so it stays 1-D even
        when there is a single state feature. The action keeps that axis.
        """
        choice = torch.randint(0, context_states.shape[0], (1,))
        return context_states[choice].squeeze(0), optimal_actions[choice]

    def getitem_DPTPR(self, idx):
        """Build the preference-layout sample for trajectory ``idx``."""
        trajectory = self.data[idx]
        context_states = self._field(trajectory, 'context_states')
        context_next_states = self._field(trajectory, 'context_next_states')
        preferred_actions = self._field(trajectory, 'preferred_actions')
        non_preferred_actions = self._field(trajectory, 'non_preferred_actions')
        optimal_actions = self._field(trajectory, 'optimal_actions')
        # Zero filler that the model slices from; it is larger than any slice
        # taken from it. The size is kept as-is so the sample shape is unchanged.
        zeros = torch.zeros(
            context_states.shape[-1] ** 2 + preferred_actions.shape[-1] + 1)
        query_state, optimal_action = self._sample_query(
            context_states, optimal_actions)

        return {
            'context_states': context_states,
            'context_next_states': context_next_states,
            'preferred_actions': preferred_actions,
            'non_preferred_actions': non_preferred_actions,
            'zeros': zeros,
            'query_states': query_state,
            'optimal_action': optimal_action,
        }

    def getitem_DPT(self, idx):
        """Build the reward-layout sample for trajectory ``idx``."""
        trajectory = self.data[idx]
        context_states = self._field(trajectory, 'context_states')
        context_next_states = self._field(trajectory, 'context_next_states')
        context_actions = self._field(trajectory, 'context_actions')
        # Rewards get a trailing feature axis so they can be concatenated with
        # the other per-step features.
        context_rewards = self._field(trajectory, 'context_rewards')[:, None]
        optimal_actions = self._field(trajectory, 'optimal_actions')
        # Zero filler that the model slices from; it is larger than any slice
        # taken from it. The size is kept as-is so the sample shape is unchanged.
        zeros = torch.zeros(
            context_states.shape[-1] ** 2 + context_actions.shape[-1] + 1)
        query_state, optimal_action = self._sample_query(
            context_states, optimal_actions)

        return {
            'context_states': context_states,
            'context_next_states': context_next_states,
            'context_actions': context_actions,
            'context_rewards': context_rewards,
            'zeros': zeros,
            'query_states': query_state,
            'optimal_action': optimal_action,
        }


def build_metaworld_data_filename(n_tasks, n_trajs, p_good, p_bad, model_type, mode=0):
    """Return the relative path of the pickled dataset for the given settings.

    ``mode`` 0 maps to the ``train`` split, 1 to ``test`` and any other value
    to ``eval``. The ``'DPTPR'`` name encodes both ``p_good`` and ``p_bad``;
    every other model type uses the ``DIT`` name, which encodes only ``p_good``.
    """
    if mode == 0:
        mode_str = 'train'
    elif mode == 1:
        mode_str = 'test'
    else:
        mode_str = 'eval'

    if model_type != 'DPTPR':
        return f'datasets_DPT/metaworld_DIT_tasks{n_tasks}_trajs{n_trajs}_p{p_good}_{mode_str}.pkl'
    return f'datasets_DPT/metaworld_tasks{n_tasks}_trajs{n_trajs}_pg{p_good}_pb{p_bad}_{mode_str}.pkl'



def parse_args(argv=None):
    """Parse command line arguments.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Train a Decision Preference Transformer on MetaWorld data")
    # Dataset parameters
    # Task ids are space-separated integers, e.g. ``--train_tasks 1 2 3``.
    # (``type=list`` would split the raw argument string into characters.)
    # Defaults: ids 0-49, with every multiple of ten held out for evaluation.
    parser.add_argument('--train_tasks', type=int, nargs='+',
                        default=[task for task in range(50) if task % 10 != 0],
                        help="Tasks to train on")
    parser.add_argument('--eval_tasks', type=int, nargs='+',
                        default=[0, 10, 20, 30, 40], help="Tasks to eval on")
    parser.add_argument('--max-episode-steps', type=int, default=200, help="Max episode steps")
    parser.add_argument('--n-trajs', type=int, default=1000, help="Number of trajectories to collect for each task")
    parser.add_argument('--p-good', type=int, default=80, help="Good policy probability")
    parser.add_argument('--p-bad', type=int, default=20, help="Bad policy probability")
    # Model parameters
    parser.add_argument('--model-type', type=str, default='DPT', choices=['DPT', 'DPTPR'], help="Model to use")
    parser.add_argument('--embedding-dim', type=int, default=256, help="Embedding dimension")
    parser.add_argument('--num-layers', type=int, default=6, help="Number of transformer layers")
    parser.add_argument('--num-heads', type=int, default=8, help="Number of attention heads")
    # Training parameters
    parser.add_argument('--learning-rate', type=float, default=1e-4, help="Learning rate")
    parser.add_argument('--batch-size', type=int, default=64, help="Batch size")
    parser.add_argument('--num-epochs', type=int, default=200, help="Number of training epochs")
    parser.add_argument('--seed', type=int, default=42, help="Random seed")
    return parser.parse_args(argv)



def construct_batch(batch, model):
    """Run ``model`` on ``batch`` and return the mean-squared-error loss.

    The single optimal action stored per sample in ``batch['optimal_action']``
    is used as the target for every position the model predicts.

    Args:
        batch: Mapping of tensors passed to ``model`` unchanged; must contain
            ``'optimal_action'`` with a size-1 sequence axis (dim 1).
        model: Callable returning per-position predictions for ``batch``.

    Returns:
        Scalar tensor holding the MSE averaged over all elements.
    """
    preds = model(batch)
    # Size the target from the predictions actually produced rather than from
    # a separately configured horizon, so the two can never disagree.
    label = batch['optimal_action'].expand_as(preds)
    return torch.nn.functional.mse_loss(preds, label, reduction='mean')


def eval_epoch(model, test_loader, args):
    """Return the mean loss of ``model`` over ``test_loader`` without training.

    Args:
        model: Module to evaluate. It is switched to eval mode for the
            duration of the call and restored to its previous mode afterwards.
        test_loader: Iterable of batches (mappings of tensors).
        args: Accepted for call-site compatibility; not read here.

    Returns:
        Sum of per-batch losses divided by the number of batches
        (0.0 when the loader yields no batches).
    """
    # Follow the device the model already lives on instead of assuming CUDA,
    # so the function also works on CPU-only machines.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    test_loss = 0.0
    num_batches = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in test_loader:
                # Move batch to the model's device
                batch = {k: v.to(target_device) for k, v in batch.items()}
                loss = construct_batch(batch, model)
                test_loss += loss.item()
                num_batches += 1
                if target_device.type == 'cuda':
                    torch.cuda.synchronize()
    finally:
        # Leave the module in whatever mode the caller had it in.
        model.train(was_training)
    # Avoid division by zero
    return test_loss / max(1, num_batches)



def train_epoch(model, train_loader, optimizer, args):
    """Train ``model`` for one pass over ``train_loader``.

    Args:
        model: Module being optimised.
        train_loader: Iterable of batches (mappings of tensors).
        optimizer: Optimizer stepping ``model``'s parameters.
        args: Accepted for call-site compatibility; not read here.

    Returns:
        Sum of per-batch losses divided by the number of batches
        (0.0 when the loader yields no batches).
    """
    # Follow the device the model already lives on instead of assuming CUDA,
    # so the function also works on CPU-only machines.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device
    on_cuda = target_device.type == 'cuda'

    train_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(train_loader)
    for batch_id, batch in enumerate(progress_bar):
        # Move batch to the model's device
        batch = {k: v.to(target_device) for k, v in batch.items()}
        loss = construct_batch(batch, model)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_value = loss.item()
        train_loss += loss_value
        num_batches += 1
        # Refresh the progress-bar label every 200 batches
        if batch_id % 200 == 0:
            progress_bar.set_description(f'{loss_value:.2f}')
        # Clear the CUDA cache periodically to prevent memory issues; these
        # calls are skipped entirely when the model is not on a CUDA device.
        if on_cuda and batch_id % 10 == 0:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
    # Avoid division by zero
    return train_loss / max(1, num_batches)



if __name__ == '__main__':
    # Parse arguments
    args = parse_args()
    print("Args: ", args)

    # Checkpoints go to save_path; create exactly that directory up front so
    # torch.save cannot fail on a missing folder.
    save_path = f'models/h{args.max_episode_steps}_{args.model_type}.pt'
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Set random seeds
    seed = args.seed
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(seed)
    random.seed(seed)

    # Build data paths
    train_task_ids = args.train_tasks
    test_task_ids = args.train_tasks # NOTE:for testing dataset used for training, we use the same tasks for training and testing
    eval_task_ids = args.eval_tasks
    train_path = build_metaworld_data_filename(
        len(train_task_ids), args.n_trajs, args.p_good, args.p_bad, args.model_type, mode=0)
    # The test split file name is built with 10 trajectories per task.
    test_path = build_metaworld_data_filename(
        len(test_task_ids), 10, args.p_good, args.p_bad, args.model_type, mode=1)

    # Load datasets - keep data on CPU
    print("Loading datasets...")
    train_dataset = MetaWorldDataset(train_path, args.max_episode_steps, args.model_type)
    test_dataset = MetaWorldDataset(test_path, args.max_episode_steps, args.model_type)
    print("finished loading datasets")

    # Define transformer configuration
    config = {
        'horizon': args.max_episode_steps,  # context length = max episode steps
        'state_dim': train_dataset.state_dim,
        'action_dim': train_dataset.action_dim,
        'n_layer': args.num_layers,
        'n_embd': args.embedding_dim,
        'n_head': args.num_heads,
        'dropout': 0.0,  # dropout disabled (a probability, not a flag)
        'test': False,
        'model_type': args.model_type,
    }

    # Place the model on the detected device (CUDA when available, else CPU)
    model = Transformer(config).to(device)

    # Pinned host memory only helps host-to-GPU copies, so enable it only
    # when the run actually targets CUDA.
    use_pinned_memory = device.type == 'cuda'
    train_loader_params = {
        'batch_size': args.batch_size,
        'num_workers': 0,  # 0 avoids multiprocessing issues
        'shuffle': True,
        'pin_memory': use_pinned_memory,
    }
    test_loader_params = {
        'batch_size': args.batch_size,
        'num_workers': 0,  # 0 avoids multiprocessing issues
        'shuffle': False,
        'pin_memory': use_pinned_memory,
    }

    train_loader = torch.utils.data.DataLoader(train_dataset, **train_loader_params)
    test_loader = torch.utils.data.DataLoader(test_dataset, **test_loader_params)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=1e-4)

    # Training and evaluation
    test_losses, train_losses = [], []

    for epoch in tqdm(range(args.num_epochs)):
        print(f'============ Epoch:{epoch} ============')

        # Evaluate first, so epoch 0 reports the untrained model
        model.test = False
        test_loss = eval_epoch(model, test_loader, args)
        print(f'test loss:{test_loss}')
        test_losses.append(test_loss)

        # Checkpoint BEFORE this epoch's training step: test_loss was measured
        # on the current weights, so saving now stores exactly the weights
        # that achieved it. Saving after train_epoch would store different
        # weights under the same "best" label.
        # NOTE(review): the weights produced by the final epoch are never
        # evaluated, so they are never checkpointed.
        if test_loss == min(test_losses):
            print('best model found in epoch: ', epoch, 'with test loss: ', test_loss)
            torch.save(model, save_path)

        # Train model
        train_loss = train_epoch(model, train_loader, optimizer, args)
        print(f'train loss:{train_loss}')
        train_losses.append(train_loss)