

import contextlib
import copy
from dataclasses import dataclass

import pytest
import torch
import torch.distributed
from flash_attn.bert_padding import index_first_axis, rearrange, unpad_input
from torch.distributed import init_device_mesh
from transformers import AutoModelForCausalLM, LlamaConfig, PretrainedConfig, Qwen2Config

from verl.models.transformers.monkey_patch import apply_monkey_patch
from verl.protocol import DataProto
from verl.utils.distributed import initialize_global_process_group
from verl.utils.model import compute_position_id_with_mask, create_random_mask
from verl.utils.ulysses import (
    gather_outputs_and_unpad,
    get_ulysses_sequence_parallel_world_size,
    set_ulysses_sequence_parallel_group,
    ulysses_pad_and_slice_inputs,
)
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager

@dataclass
class SequenceParallelConfig:
    """A single parametrization case for the Ulysses sequence-parallel test.

    Instances are built by ``test_configs`` and consumed by
    ``test_hf_casual_fwd_bwd``, which expects the run to succeed when
    ``is_valid`` is true and to raise ``AssertionError`` otherwise.
    """

    # Hugging Face model configuration the test model is built from.
    config: PretrainedConfig
    # Ulysses sequence-parallel degree used for the run.
    sp_size: int
    # Whether the (config, sp_size) combination is expected to run cleanly.
    is_valid: bool

def test_configs():
    """Return the model / sequence-parallel combinations exercised by the test.

    Each entry pairs a small (2-layer) model config with a Ulysses
    sequence-parallel degree. ``is_valid=False`` marks a combination that the
    test expects to fail with ``AssertionError``; presumably this is because
    28 attention heads cannot be split evenly across 8 sequence-parallel
    ranks (the same config with ``sp_size=4`` is marked valid).

    Despite the ``test_`` prefix this is a parametrization factory, not a
    test. The ``__test__`` flag set right after the definition keeps pytest
    from collecting it.
    """
    return [
        SequenceParallelConfig(
            LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32), sp_size=8, is_valid=True
        ),
        SequenceParallelConfig(
            Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584),
            sp_size=4,
            is_valid=True,
        ),
        SequenceParallelConfig(
            Qwen2Config(num_hidden_layers=2, num_attention_heads=28, num_key_value_heads=4, hidden_size=3584),
            sp_size=8,
            is_valid=False,
        ),
        SequenceParallelConfig(
            Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=4, is_valid=True
        ),
        SequenceParallelConfig(
            Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4), sp_size=8, is_valid=True
        ),
    ]

# The name matches pytest's ``test_*`` pattern; without this flag pytest would
# collect the factory as a test and warn that it returned a non-None value.
test_configs.__test__ = False

def sync_model_parameters_global(layer):
    """Overwrite every parameter of ``layer`` with global rank 0's values.

    Each parameter is broadcast in place from rank 0 over the default process
    group. Because ``torch.distributed.broadcast`` is a collective, every rank
    must call this with a module exposing the same parameters in the same
    order.
    """
    source_rank = 0
    for param in layer.parameters():
        # Broadcasting into ``.data`` keeps the in-place overwrite out of autograd.
        torch.distributed.broadcast(tensor=param.data, src=source_rank)

@pytest.mark.parametrize("test_config", test_configs())
def test_hf_casual_fwd_bwd(test_config):
    """Run the Ulysses forward/backward comparison for one parametrized case.

    Valid cases must complete without error; invalid cases must raise
    ``AssertionError``.

    Environment preconditions are checked *before* entering the
    ``pytest.raises`` context. Otherwise an unrelated assertion, such as the
    GPU-count check inside ``_hf_casual_fwd_bwd``, could be mistaken for the
    expected failure of an invalid case.
    """
    if torch.cuda.device_count() < 2:
        pytest.skip("need at least 2 gpus for test")

    if not torch.distributed.is_initialized():
        initialize_global_process_group()

    world_size = torch.distributed.get_world_size()
    sp_size = test_config.sp_size
    if world_size % sp_size != 0:
        # The (dp, sp) device mesh needs world_size == dp_size * sp_size with dp_size >= 1.
        pytest.skip(f"world size {world_size} is not divisible by sp_size {sp_size}")

    context = contextlib.nullcontext() if test_config.is_valid else pytest.raises(AssertionError)
    with context:
        _hf_casual_fwd_bwd(test_config.config, sp_size, world_size // sp_size)

def _hf_casual_fwd(config, sp_size, dp_size):
    """Check that a Ulysses sequence-parallel forward matches a plain forward.

    Procedure:

    1. Build a randomly initialised causal LM from ``config`` (bf16,
       flash-attention-2), apply the Ulysses monkey patch, and broadcast
       rank 0's weights to every rank.
    2. Run the same padding-removed batch twice:

       - sliced across the ``sp`` mesh dimension and gathered back, and
       - unsliced, with the sequence-parallel group cleared.

    3. Compare the means of the two logit tensors.

    This is a forward-only check. Both passes run under ``torch.no_grad()``
    so that no autograd graph (including the full-vocabulary logits) is kept
    alive.

    Args:
        config: model config passed to ``AutoModelForCausalLM.from_config``.
        sp_size: Ulysses sequence-parallel degree.
        dp_size: data-parallel degree; the device mesh has shape ``(dp_size, sp_size)``.
    """
    assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"

    ulysses_device_mesh = init_device_mesh(
        device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp")
    )
    sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)

    batch_size = 1
    seqlen = 128

    with torch.device("cuda"):
        model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        apply_monkey_patch(model, sp_size)
        model = model.to(device="cuda")
        # Each rank initialised its own random weights; make them all use rank 0's.
        sync_model_parameters_global(model)

    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
    attention_mask = create_random_mask(
        input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8
    )
    position_ids = compute_position_id_with_mask(attention_mask)

    model_inputs = {
        "input_ids": input_ids.cuda(),
        "attention_mask": attention_mask.cuda(),
        "position_ids": position_ids.int().cuda(),
    }

    model_inputs = DataProto.from_dict(model_inputs)

    # Forward-only comparison: skip autograd bookkeeping for both passes.
    with torch.no_grad():
        with sharding_manager:
            # NOTE(review): presumably makes all ranks of an sp group share one batch — confirm in
            # FSDPUlyssesShardingManager.preprocess_data.
            model_inputs = sharding_manager.preprocess_data(model_inputs)
            input_ids = model_inputs.batch["input_ids"]
            attention_mask = model_inputs.batch["attention_mask"]
            position_ids = model_inputs.batch["position_ids"]

            # Drop padding tokens and lay the remaining tokens out as a single packed row.
            input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
            input_ids_rmpad = input_ids_rmpad.transpose(0, 1)

            position_ids_rmpad = index_first_axis(
                rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
            ).transpose(0, 1)

            # Presumably pads the packed row for the sp degree and keeps this rank's slice;
            # pad_size is handed back to the gather below so the padding can be removed.
            input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
                input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()
            )

            logits_split_in_seq = model(
                input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False
            ).logits

            logits_full = gather_outputs_and_unpad(
                logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size
            )

        # Disable sequence parallelism and run the same packed batch unsplit as the reference.
        set_ulysses_sequence_parallel_group(None)
        logits_rmpad_local = model(
            input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False
        ).logits

    mean_local = logits_rmpad_local.mean()
    mean_full = logits_full.mean()
    torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)

def _hf_casual_fwd_bwd(config, sp_size, dp_size):
    """Compare Ulysses sequence-parallel forward *and* backward with a plain run.

    Procedure:

    1. Build a randomly initialised causal LM from ``config`` (bf16,
       flash-attention-2), apply the Ulysses monkey patch, and broadcast
       rank 0's weights to every rank.
    2. Run a sequence-parallel forward on the padding-removed batch.
    3. Deep-copy the model (before any backward, so the copy starts without
       gradients) and run the unsliced forward on the copy, with the
       sequence-parallel group cleared.
    4. Backpropagate each logits mean.
    5. Compare the two means, and the layer-0 ``q_proj`` weight gradients of
       both models.

    Args:
        config: model config passed to ``AutoModelForCausalLM.from_config``.
        sp_size: Ulysses sequence-parallel degree.
        dp_size: data-parallel degree; the device mesh has shape ``(dp_size, sp_size)``.
    """
    assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"

    ulysses_device_mesh = init_device_mesh(
        device_type="cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp")
    )
    sharding_manager = FSDPUlyssesShardingManager(ulysses_device_mesh)

    batch_size = 1
    seqlen = 128

    with torch.device("cuda"):
        model = AutoModelForCausalLM.from_config(
            config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
        )
        apply_monkey_patch(model, sp_size)
        model = model.to(device="cuda")
        # Each rank initialised its own random weights; make them all use rank 0's.
        sync_model_parameters_global(model)

    input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device="cuda")
    attention_mask = create_random_mask(
        input_ids=input_ids, max_ratio_of_left_padding=0, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.8
    )
    position_ids = compute_position_id_with_mask(attention_mask)

    model_inputs = {
        "input_ids": input_ids.cuda(),
        "attention_mask": attention_mask.cuda(),
        "position_ids": position_ids.int().cuda(),
    }

    model_inputs = DataProto.from_dict(model_inputs)

    with sharding_manager:
        # NOTE(review): presumably makes all ranks of an sp group share one batch — confirm in
        # FSDPUlyssesShardingManager.preprocess_data.
        model_inputs = sharding_manager.preprocess_data(model_inputs)
        input_ids = model_inputs.batch["input_ids"]
        attention_mask = model_inputs.batch["attention_mask"]
        position_ids = model_inputs.batch["position_ids"]

        # Drop padding tokens and lay the remaining tokens out as a single packed row.
        input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)

        position_ids_rmpad = index_first_axis(
            rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
        ).transpose(0, 1)

        # Presumably pads the packed row for the sp degree and keeps this rank's slice;
        # pad_size is handed back to the gather below so the padding can be removed.
        input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
            input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()
        )

        logits_split_in_seq = model(
            input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded, use_cache=False
        ).logits

        logits_full = gather_outputs_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)

    # Reference run: sequence parallelism disabled, identical weights in an
    # independent copy so the two backward passes accumulate separate grads.
    set_ulysses_sequence_parallel_group(None)
    input_ids_full = copy.deepcopy(input_ids_rmpad)
    position_ids_full = copy.deepcopy(position_ids_rmpad)
    model_no_sp = copy.deepcopy(model)
    logits_rmpad_local = model_no_sp(
        input_ids_full, position_ids=position_ids_full, use_cache=False
    ).logits

    mean_local = logits_rmpad_local.mean()
    mean_full = logits_full.mean()

    mean_full.backward()
    mean_local.backward()

    grad = model.model.layers[0].self_attn.q_proj.weight.grad
    grad_full = model_no_sp.model.layers[0].self_attn.q_proj.weight.grad
    torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)
    # assert_close treats two ``None`` inputs as equal, so a gradient path that
    # is broken in both models would otherwise pass the comparison silently.
    assert grad is not None, "no gradient reached q_proj in the sequence-parallel model"
    assert grad_full is not None, "no gradient reached q_proj in the reference model"
    torch.testing.assert_close(grad, grad_full, atol=1e-2, rtol=1e-5)

if __name__ == "__main__":
    # Propagate pytest's exit status. Otherwise running this file directly
    # (e.g. under a distributed launcher) exits 0 even when tests fail.
    raise SystemExit(pytest.main([__file__, "-svv"]))
