import os
import tempfile
import shutil
import torch
import dill
import pytest
from torch.utils.data import DataLoader, RandomSampler
from verl.utils.dataset.rl_dataset import RLHFDataset, collate_fn
from verl.utils import hf_tokenizer
import numpy as np

@pytest.fixture
def test_data():
    """Return eight chat-style records numbered 1 through 8.

    Each record holds a single-turn user ``prompt`` (a list with one
    role/content message) and a plain-string ``completion``.
    """
    return [
        {
            "prompt": [{"role": "user", "content": f"Test prompt {index}"}],
            "completion": f"Test completion {index}",
        }
        for index in range(1, 9)
    ]

@pytest.fixture
def temp_dir():
    """Yield the path of a fresh temporary directory, removing it afterwards."""
    scratch_path = tempfile.mkdtemp()
    yield scratch_path
    # Teardown: everything after the yield runs once the test has finished.
    shutil.rmtree(scratch_path)

@pytest.fixture
def dataloader(temp_dir, test_data):
    """Build a seeded, shuffling DataLoader over an RLHFDataset.

    The ``test_data`` records are written to a parquet file inside
    ``temp_dir``; that file backs an ``RLHFDataset`` which is then wrapped
    in a ``DataLoader`` (batch size 2, incomplete last batch dropped) whose
    ``RandomSampler`` draws from a generator seeded with 42.
    """
    import pandas as pd

    # Persist the in-memory records as the parquet file the dataset reads.
    parquet_file = os.path.join(temp_dir, "test_data.parquet")
    pd.DataFrame(test_data).to_parquet(parquet_file)

    # NOTE(review): presumably fetches the tokenizer from the Hugging Face
    # hub (or a local cache) -- confirm network expectations for CI.
    qwen_tokenizer = hf_tokenizer("Qwen/Qwen2.5-0.5B-Instruct")
    rlhf_dataset = RLHFDataset(
        parquet_files=[parquet_file],
        tokenizer=qwen_tokenizer,
        prompt_key="prompt",
        max_prompt_length=128,
        filter_prompts=True,
        return_raw_chat=False,
        truncation='error',
    )

    # A dedicated, explicitly seeded generator keeps the shuffle order
    # reproducible and independent of the global torch RNG.
    seeded_rng = torch.Generator()
    seeded_rng.manual_seed(42)
    shuffling_sampler = RandomSampler(rlhf_dataset, generator=seeded_rng)

    return DataLoader(
        dataset=rlhf_dataset,
        batch_size=2,
        drop_last=True,
        collate_fn=collate_fn,
        sampler=shuffling_sampler,
    )

def _save_and_reload(dataloader, temp_dir):
    """Serialize ``dataloader`` with dill under ``temp_dir`` and load it back.

    Returns the deserialized dataloader. The caller still has to invoke
    ``dataset.resume_dataset_state()`` on the result before iterating.
    """
    save_path = os.path.join(temp_dir, "dataloader.pt")
    torch.save(dataloader, save_path, pickle_module=dill)
    return torch.load(save_path, pickle_module=dill)


def _assert_same_configuration(loaded_dataloader, dataloader):
    """Assert both dataloaders agree on length, batch size and sampler seed."""
    assert len(loaded_dataloader) == len(dataloader)
    assert loaded_dataloader.batch_size == dataloader.batch_size
    assert (
        loaded_dataloader.sampler.generator.initial_seed()
        == dataloader.sampler.generator.initial_seed()
    )


def _assert_batches_equal(orig_batch, loaded_batch):
    """Assert two collated batches have the same keys and equal values.

    Tensors are compared with ``torch.equal``, numpy arrays with
    ``np.array_equal`` and everything else with ``==``.
    """
    assert orig_batch.keys() == loaded_batch.keys()
    for key, orig_value in orig_batch.items():
        loaded_value = loaded_batch[key]
        if isinstance(orig_value, torch.Tensor):
            assert torch.equal(orig_value, loaded_value)
        elif isinstance(orig_value, np.ndarray):
            assert np.array_equal(orig_value, loaded_value)
        else:
            assert orig_value == loaded_value


def _assert_next_batches_equal(dataloader_iter, loaded_dataloader_iter, num_batches):
    """Pull ``num_batches`` batches from each iterator and compare them pairwise."""
    for _ in range(num_batches):
        orig_batch = next(dataloader_iter)
        loaded_batch = next(loaded_dataloader_iter)
        _assert_batches_equal(orig_batch, loaded_batch)


def _run_epochs(dataloader, epochs):
    """Iterate ``dataloader`` to exhaustion ``epochs`` times, printing completions."""
    for _ in range(epochs):
        for batch in dataloader:
            print(batch['completion'])


def test_basic_save_load(temp_dir, dataloader):
    """A freshly built dataloader yields the same batches after a save/load."""
    loaded_dataloader = _save_and_reload(dataloader, temp_dir)
    # NOTE(review): presumably rebuilds dataset state that is not carried
    # through serialization -- confirm against RLHFDataset.
    loaded_dataloader.dataset.resume_dataset_state()

    _assert_same_configuration(loaded_dataloader, dataloader)
    _assert_next_batches_equal(iter(dataloader), iter(loaded_dataloader), num_batches=3)


def test_save_load_after_some_epochs(temp_dir, dataloader):
    """Saving after several full epochs still reproduces the next epoch's batches."""
    # Advance the sampler's generator by running complete epochs first.
    _run_epochs(dataloader, epochs=3)

    loaded_dataloader = _save_and_reload(dataloader, temp_dir)
    loaded_dataloader.dataset.resume_dataset_state()

    _assert_same_configuration(loaded_dataloader, dataloader)
    _assert_next_batches_equal(iter(dataloader), iter(loaded_dataloader), num_batches=3)


def test_save_load_for_mid_epoch_resume(temp_dir, dataloader):
    """Resuming in the middle of an epoch yields the same remaining batches.

    The current epoch is made reproducible by seeding the sampler generator
    with the epoch index; the reloaded dataloader is re-seeded the same way
    and fast-forwarded past the batches that were already consumed.
    """
    # Iterate the dataloader for 3 full epochs.
    epochs = 3
    _run_epochs(dataloader, epochs)

    consumed_batches = 2
    # Seed before creating the iterator so the epoch's shuffle order does not
    # depend on the sampler deferring its draw until the first ``next()``.
    dataloader.sampler.generator.manual_seed(epochs)
    dataloader_iter = iter(dataloader)
    for _ in range(consumed_batches):
        next(dataloader_iter)

    loaded_dataloader = _save_and_reload(dataloader, temp_dir)
    # The generator was saved after the shuffle order was drawn; re-seed it
    # so the reloaded dataloader reproduces this epoch's order.
    loaded_dataloader.sampler.generator.manual_seed(epochs)
    loaded_dataloader.dataset.resume_dataset_state()

    _assert_same_configuration(loaded_dataloader, dataloader)

    # Skip the batches the original dataloader had consumed before the save.
    loaded_dataloader_iter = iter(loaded_dataloader)
    for _ in range(consumed_batches):
        next(loaded_dataloader_iter)

    # Compare every batch left in the epoch (2 for the 8-row fixture).
    remaining_batches = len(dataloader) - consumed_batches
    _assert_next_batches_equal(dataloader_iter, loaded_dataloader_iter, remaining_batches)

            
