#!/usr/bin/env python3
"""
BigBird convergence-rate experiment with SEQUENCE PARALLELISM.

Distributes tokens across multiple GPUs:
- Each GPU holds a portion of the sequence
- Layer 0: Embeddings computed locally on each GPU (word + position + type)
- Layers 1-N: All-gather hidden states, run BigBird's native block sparse
  attention layer, each GPU keeps only its output portion

Uses BigBird's native block_sparse attention which is O(n) memory, not O(n²).

Usage:
    python run_convergence_experiment_parallel.py --test         # Basic test ~40k tokens
    python run_convergence_experiment_parallel.py --test_large   # N=122 (~500k tokens)
    python run_convergence_experiment_parallel.py --N 122 --layer_id 1   # flags are parsed, but no run is wired up yet: only a usage hint is printed
"""

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import AutoTokenizer, BigBirdModel, AutoConfig
import math
import numpy as np
from tqdm import tqdm
from typing import List, Tuple, Optional, Dict, Union, Any
import argparse
import gc

# ============== MODEL CONFIGURATIONS ==============
# Per-size lookup table used by get_model_config(): the Hugging Face
# checkpoint id plus the architecture numbers this script reads from it.
# In both entries hidden_size == num_attention_heads * attention_head_size.
MODEL_CONFIGS = {
    "base": {
        "model_id": "google/bigbird-roberta-base",
        "hidden_size": 768,
        "num_attention_heads": 12,
        "attention_head_size": 64,
        "num_layers": 12,
    },
    "large": {
        "model_id": "google/bigbird-roberta-large",
        "hidden_size": 1024,
        "num_attention_heads": 16,
        "attention_head_size": 64,
        "num_layers": 24,
    },
}

# Seed torch's global RNG once, at import time, so runs are repeatable.
torch.manual_seed(0)


def get_model_config(model_size: str = "base") -> dict:
    """Return the MODEL_CONFIGS entry registered under ``model_size``.

    Raises:
        ValueError: if ``model_size`` is not a key of MODEL_CONFIGS.
    """
    config = MODEL_CONFIGS.get(model_size)
    if config is None:
        raise ValueError(f"Unknown model size: {model_size}")
    return config


# ============== POSITION EMBEDDING INTERPOLATION ==============

def interpolate_pos_embeddings(model, new_max_length):
    """Extend a model's position-embedding table via linear interpolation.

    The existing table (one row per position) is resampled along the
    position axis to ``new_max_length`` rows with ``align_corners=True``, so
    the first and last rows are preserved exactly. The resized table
    replaces ``model.embeddings.position_embeddings`` as a frozen embedding
    and ``model.config.max_position_embeddings`` is updated to match.

    Args:
        model: object exposing ``embeddings.position_embeddings`` (an
            ``nn.Embedding``) and ``config.max_position_embeddings``.
        new_max_length: desired number of positions.

    Returns:
        The same ``model`` object. It is returned untouched when
        ``new_max_length`` does not exceed the current table length.

    NOTE(review): only the embedding table and the config value are updated
    here. If the embeddings module caches anything else sized by the old
    maximum (e.g. a position-id buffer), that is not resized -- confirm
    before using the model's own forward pass beyond the old length.
    """
    old_table = model.embeddings.position_embeddings.weight.data
    old_length = old_table.shape[0]

    if new_max_length <= old_length:
        return model

    # Interpolate in float32: this avoids relying on half-precision CPU
    # kernels (the table may be fp16 and still on the CPU at this point) and
    # rounds only once when casting back to the original dtype.
    resized = F.interpolate(
        old_table.T.unsqueeze(0).float(),
        size=new_max_length,
        mode='linear',
        align_corners=True,
    ).squeeze(0).T

    # .T yields a strided view; store a contiguous table for lookups.
    resized = resized.contiguous().to(dtype=old_table.dtype)

    # Build the module directly from the table instead of allocating a
    # randomly-initialised embedding and overwriting it afterwards.
    model.embeddings.position_embeddings = nn.Embedding.from_pretrained(
        resized, freeze=True
    )
    model.config.max_position_embeddings = new_max_length
    print(f"Extended positions: {old_length} -> {new_max_length}")
    return model


# ============== SEQUENCE PARALLEL UTILITIES ==============

class SequenceParallelContext:
    """Helper that shards tensors along one dimension across CUDA devices.

    Despite the name this is a plain helper object, not a ``with``-style
    context manager. It runs in a single process and moves tensors between
    devices with ``Tensor.to``; ``torch.distributed`` is not involved.

    Attributes:
        world_size: number of devices the sequence is split over.
        devices: ``cuda:0`` .. ``cuda:{world_size - 1}``.
    """

    def __init__(self, world_size: Optional[int] = None):
        """Create the context.

        Args:
            world_size: number of GPUs to use; ``None`` (or 0) means every
                GPU reported by ``torch.cuda.device_count()``.

        Raises:
            RuntimeError: if the resulting world size is smaller than 1
                (previously this only surfaced later as a division by zero
                inside ``split_sequence``).
        """
        self.world_size = world_size or torch.cuda.device_count()
        if self.world_size < 1:
            raise RuntimeError(
                "SequenceParallelContext needs at least one CUDA device "
                f"(world_size={self.world_size})"
            )
        self.devices = [torch.device(f"cuda:{i}") for i in range(self.world_size)]
        print(f"SequenceParallelContext: {self.world_size} GPUs")
        for i in range(self.world_size):
            print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")

    def split_sequence(self, tensor: torch.Tensor, dim: int = 1) -> List[torch.Tensor]:
        """Split ``tensor`` along ``dim`` into one contiguous chunk per device.

        Every chunk spans ``ceil(len / world_size)`` entries except the
        trailing ones, which may be shorter or empty. Exactly ``world_size``
        chunks are returned and chunk ``i`` is moved to ``devices[i]``.

        Args:
            tensor: tensor to shard.
            dim: dimension to split along; any valid dimension index,
                including negative ones.
        """
        seq_len = tensor.shape[dim]
        chunk_size = math.ceil(seq_len / self.world_size)

        chunks = []
        for i in range(self.world_size):
            # Clamp both ends so chunks past the end of the data are empty.
            start = min(i * chunk_size, seq_len)
            end = min((i + 1) * chunk_size, seq_len)
            # narrow() slices the requested dimension, whatever it is.
            chunks.append(tensor.narrow(dim, start, end - start).to(self.devices[i]))
        return chunks

    def gather_sequence(self, chunks: List[torch.Tensor], dim: int = 1, target_device: int = 0) -> torch.Tensor:
        """Concatenate ``chunks`` along ``dim`` on ``devices[target_device]``."""
        destination = self.devices[target_device]
        return torch.cat([chunk.to(destination) for chunk in chunks], dim=dim)

    def all_gather_chunks(self, chunks: List[torch.Tensor], dim: int = 1) -> List[torch.Tensor]:
        """Return one full concatenated copy of ``chunks`` per device.

        Entry ``i`` of the result lives on ``devices[i]``. Every chunk is
        copied to every device, so all full copies exist at the same time.
        """
        return [
            torch.cat([chunk.to(device) for chunk in chunks], dim=dim)
            for device in self.devices[:self.world_size]
        ]


class SequenceParallelBigBird:
    """
    BigBird model with sequence parallelism for very long sequences.

    One full copy of the model is loaded onto every GPU and the token
    sequence is cut into contiguous per-GPU chunks:
    - Layer 0: each GPU embeds only its own chunk (word + position + token
      type, followed by LayerNorm).
    - Layers > 0: the chunks are all-gathered so every GPU holds the full
      sequence, every GPU runs the complete encoder layer on it, and then
      keeps only the slice of the output that belongs to its chunk.

    For layers > 0 the layer computation is therefore replicated on every
    GPU; only the hidden states held between layers are sharded.

    ``layer_id`` convention: ``layer_id == k`` means "hidden states after k
    encoder layers", i.e. the input to ``encoder.layer[k]``. This matches
    ``extract_qkv(k)``, which returns the projections of ``encoder.layer[k]``.
    """

    def __init__(
        self,
        model_size: str = "base",
        max_length: Optional[int] = None,
        attention_type: str = "block_sparse",
        torch_dtype: torch.dtype = torch.float16,
    ):
        """Load one model replica per visible GPU.

        Args:
            model_size: key into MODEL_CONFIGS.
            max_length: longest sequence to support; values above 4096
                trigger position-embedding interpolation on every replica.
            attention_type: forwarded to ``model.set_attention_type``.
            torch_dtype: dtype the checkpoints are loaded in.
        """
        self.config_dict = get_model_config(model_size)
        self.model_id = self.config_dict["model_id"]
        self.hidden_size = self.config_dict["hidden_size"]
        self.num_heads = self.config_dict["num_attention_heads"]
        self.head_size = self.config_dict["attention_head_size"]
        self.num_layers = self.config_dict["num_layers"]
        self.max_length = max_length
        self.attention_type = attention_type
        self.torch_dtype = torch_dtype

        self.ctx = SequenceParallelContext()
        self.n_gpus = self.ctx.world_size

        # Load model on each GPU
        self.models = self._load_models()

    def _load_models(self) -> List["BigBirdModel"]:
        """Load an independent copy of the checkpoint onto each GPU.

        If the plain ``from_pretrained`` call fails, a second attempt is made
        against revision ``refs/pr/2`` with safetensors (forcing a fresh
        download unless ``HF_HUB_OFFLINE=1``).
        """
        print(f"Loading {self.model_id} on {self.n_gpus} GPUs...")

        offline_mode = os.environ.get("HF_HUB_OFFLINE", "0") == "1"
        models = []

        for i in range(self.n_gpus):
            print(f"  Loading on GPU {i}...")
            try:
                model = BigBirdModel.from_pretrained(
                    self.model_id,
                    torch_dtype=self.torch_dtype
                )
            except Exception as e:
                # Deliberate best-effort retry; a second failure propagates.
                print(f"  First load failed: {e}")
                fallback_kwargs = {"revision": "refs/pr/2", "use_safetensors": True}
                if not offline_mode:
                    fallback_kwargs["force_download"] = True
                model = BigBirdModel.from_pretrained(
                    self.model_id,
                    torch_dtype=self.torch_dtype,
                    **fallback_kwargs
                )

            model.eval()
            model.set_attention_type(self.attention_type)

            # Every replica is extended independently (done before the move
            # to the GPU, i.e. on the CPU copy of the weights).
            if self.max_length and self.max_length > 4096:
                model = interpolate_pos_embeddings(model, self.max_length)

            model.to(self.ctx.devices[i])
            models.append(model)

        print(f"Models loaded on all {self.n_gpus} GPUs")
        return models

    def _block_size(self) -> int:
        """Block size of the sparse attention (64 if the config lacks one)."""
        return getattr(self.models[0].config, "block_size", 64)

    def _embed_chunks(self, input_ids: torch.Tensor) -> List[torch.Tensor]:
        """Embed each GPU's slice of ``input_ids``; return the per-GPU chunks.

        Position ids continue across chunks: the running offset is advanced
        by each chunk's real length, so it always agrees with however
        ``ctx.split_sequence`` cut the sequence.

        Raises:
            ValueError: if the sequence is longer than the position table.
        """
        seq_len = input_ids.shape[1]
        max_positions = self.models[0].embeddings.position_embeddings.num_embeddings
        if seq_len > max_positions:
            # Fail with a readable error instead of an out-of-range lookup.
            raise ValueError(
                f"Sequence length {seq_len} exceeds the position embedding "
                f"table ({max_positions} positions)"
            )

        id_chunks = self.ctx.split_sequence(input_ids, dim=1)

        embed_chunks = []
        start_pos = 0
        for ids, model in zip(id_chunks, self.models):
            chunk_len = ids.shape[1]
            emb = model.embeddings

            with torch.no_grad():
                # Word embeddings
                token_embeds = emb.word_embeddings(ids)

                # Position embeddings, offset by everything before this chunk
                position_ids = torch.arange(
                    start_pos, start_pos + chunk_len,
                    device=ids.device
                ).unsqueeze(0)
                position_embeds = emb.position_embeddings(position_ids)

                # Token type embeddings (every token is type 0)
                token_type_embeds = emb.token_type_embeddings(torch.zeros_like(ids))

                # Combine and normalize
                embeddings = token_embeds + position_embeds + token_type_embeds
                embeddings = emb.LayerNorm(embeddings)

            embed_chunks.append(embeddings)
            start_pos += chunk_len

        return embed_chunks

    def get_layer0_embeddings_parallel(self, input_ids: torch.Tensor) -> torch.Tensor:
        """
        Get layer 0 embeddings with sequence parallelism.

        Each GPU computes embeddings for its portion of the sequence.
        Results are gathered back onto GPU 0.

        Args:
            input_ids: token ids, indexed as ``[batch, seq]``.

        Returns:
            Embeddings for the whole sequence, concatenated along dim 1.
        """
        return self.ctx.gather_sequence(
            self._embed_chunks(input_ids), dim=1, target_device=0
        )

    def _layer_mask_kwargs(
        self,
        model,
        attention_mask: torch.Tensor,
        dtype: torch.dtype,
        device: torch.device,
    ) -> Dict[str, Any]:
        """Build the mask keyword arguments one encoder layer call needs.

        The mask is cast to the hidden-state dtype so that multiplying or
        adding it inside the attention code does not promote activations.
        """
        mask = attention_mask.to(device=device, dtype=dtype)

        if self.attention_type == "block_sparse":
            # Block sparse attention does not use `attention_mask`; it needs
            # the derived band/from/to/blocked masks instead.
            blocked_encoder_mask, band_mask, from_mask, to_mask = (
                model.create_masks_for_block_sparse_attn(mask, self._block_size())
            )
            return {
                "attention_mask": None,
                "band_mask": band_mask,
                "from_mask": from_mask,
                "to_mask": to_mask,
                "blocked_encoder_mask": blocked_encoder_mask,
            }

        # Full attention: additive mask, 0 for real tokens, very negative
        # for padding, broadcast over heads and query positions.
        extended = (1.0 - mask[:, None, None, :]) * torch.finfo(dtype).min
        return {"attention_mask": extended}

    def get_layerN_embeddings_parallel(
        self,
        input_ids: torch.Tensor,
        layer_id: int,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Get embeddings at layer N with sequence parallelism.

        Strategy:
        1. Compute layer 0 embeddings in parallel (local)
        2. For each transformer layer up to layer_id:
           a. All-gather hidden states so each GPU has full sequence
           b. Run BigBird's native encoder layer on the full sequence
           c. Each GPU keeps only its portion of the output

        Args:
            input_ids: token ids, ``[batch, seq]``. With block sparse
                attention ``seq`` must be a multiple of the block size
                (``get_embeddings`` takes care of the padding).
            layer_id: number of encoder layers to apply.
            attention_mask: optional ``[batch, seq]`` mask, 1 for real
                tokens and 0 for padding. Defaults to all ones.

        Returns:
            Hidden states after ``layer_id`` layers, gathered on GPU 0.

        Raises:
            ValueError: if ``layer_id`` is outside ``0..num_layers`` or the
                mask shape does not match ``input_ids``.
        """
        batch_size, seq_len = input_ids.shape
        chunk_size = math.ceil(seq_len / self.n_gpus)

        available_layers = len(self.models[0].encoder.layer)
        if not 0 <= layer_id <= available_layers:
            raise ValueError(
                f"layer_id must be in [0, {available_layers}], got {layer_id}"
            )

        if attention_mask is None:
            attention_mask = torch.ones(
                batch_size, seq_len, dtype=torch.long, device=input_ids.device
            )
        elif tuple(attention_mask.shape) != (batch_size, seq_len):
            raise ValueError(
                f"attention_mask shape {tuple(attention_mask.shape)} does not "
                f"match input_ids shape {(batch_size, seq_len)}"
            )

        print(f"  Computing layer {layer_id} embeddings with sequence parallelism")
        print(f"  Sequence length: {seq_len}, split across {self.n_gpus} GPUs")
        print(f"  Chunk size per GPU: ~{chunk_size}")

        # Layer 0: embed per GPU and keep the chunks where they are (no
        # gather-to-GPU-0 followed by an immediate re-split).
        print("  Step 1: Computing layer 0 embeddings (parallel)...")
        hidden_chunks = self._embed_chunks(input_ids)

        # [start, end) of every GPU's slice, taken from the real chunk sizes.
        bounds = []
        offset = 0
        for chunk in hidden_chunks:
            bounds.append((offset, offset + chunk.shape[1]))
            offset += chunk.shape[1]

        # Masks are identical for every layer, so build them once per GPU.
        mask_kwargs = []
        if layer_id > 0:
            hidden_dtype = hidden_chunks[0].dtype
            with torch.no_grad():
                mask_kwargs = [
                    self._layer_mask_kwargs(model, attention_mask, hidden_dtype, device)
                    for model, device in zip(self.models, self.ctx.devices)
                ]

        # Process through transformer layers using the native BigBird layer
        for layer_idx in range(layer_id):
            print(f"  Step 2.{layer_idx}: Processing transformer layer {layer_idx} (block sparse attention)...")

            # All-gather hidden states so each GPU has full sequence
            hidden_full_per_gpu = self.ctx.all_gather_chunks(hidden_chunks, dim=1)

            new_hidden_chunks = []

            for gpu_idx, model in enumerate(self.models):
                layer = model.encoder.layer[layer_idx]
                start_pos, end_pos = bounds[gpu_idx]

                # Take the full copy out of the list so `del` below really
                # drops the last reference to it.
                full_hidden = hidden_full_per_gpu[gpu_idx]
                hidden_full_per_gpu[gpu_idx] = None

                with torch.no_grad():
                    layer_output = layer(
                        full_hidden,
                        head_mask=None,
                        output_attentions=False,
                        **mask_kwargs[gpu_idx],
                    )

                    # layer_output is a tuple, first element is hidden states
                    full_output = layer_output[0]

                    # Keep only this GPU's portion
                    local_output = full_output[:, start_pos:end_pos, :].clone()

                new_hidden_chunks.append(local_output)

                # Clear memory aggressively
                del full_hidden, layer_output, full_output
                torch.cuda.empty_cache()

            hidden_chunks = new_hidden_chunks
            del hidden_full_per_gpu
            gc.collect()
            torch.cuda.empty_cache()

        # Gather final hidden states
        return self.ctx.gather_sequence(hidden_chunks, dim=1, target_device=0)

    def _pad_to_block_size(self, input_ids: torch.Tensor, block_size: int = 64) -> Tuple[torch.Tensor, int]:
        """Right-pad ``input_ids`` so its length is a multiple of ``block_size``.

        Padding uses token id 0.
        NOTE(review): 0 is assumed to be the checkpoint's pad token id --
        confirm against the tokenizer/config if another checkpoint is used.

        Returns:
            ``(padded_ids, pad_len)``; the input itself and 0 when no
            padding is required.
        """
        remainder = input_ids.shape[1] % block_size
        if remainder == 0:
            return input_ids, 0

        pad_len = block_size - remainder
        return F.pad(input_ids, (0, pad_len), value=0), pad_len

    def get_embeddings(self, input_ids: torch.Tensor, layer_id: int) -> torch.Tensor:
        """Get embeddings at specified layer.

        The sequence is padded to a multiple of the attention block size,
        the padded positions are masked out of attention for ``layer_id > 0``
        and the padding is stripped from the result again.

        Args:
            input_ids: token ids, ``[batch, seq]``.
            layer_id: 0 for the embedding output, ``k`` for the hidden
                states after ``k`` encoder layers.

        Returns:
            Hidden states for the original (unpadded) positions, on GPU 0.
        """
        original_len = input_ids.shape[1]

        # Pad to block size for block sparse attention
        padded_ids, pad_len = self._pad_to_block_size(
            input_ids, block_size=self._block_size()
        )

        if pad_len > 0:
            print(f"  Padded sequence: {original_len} -> {padded_ids.shape[1]} (+{pad_len})")

        if layer_id == 0:
            embeddings = self.get_layer0_embeddings_parallel(padded_ids)
        else:
            # 1 = real token, 0 = padding added above.
            attention_mask = torch.ones_like(padded_ids)
            attention_mask[:, original_len:] = 0
            embeddings = self.get_layerN_embeddings_parallel(
                padded_ids, layer_id, attention_mask=attention_mask
            )

        # Remove padding from output
        if pad_len > 0:
            embeddings = embeddings[:, :original_len, :]

        return embeddings

    def extract_qkv(self, layer_id: int = 0):
        """Get Q, K, V projection parameters of ``encoder.layer[layer_id]``.

        Parameters are read from the first GPU's replica and detached.

        Returns:
            ``(Wq, Wk, Wv, bq, bk, bv)``; a bias is ``None`` when the
            corresponding projection has no bias.
        """
        attn_self = self.models[0].encoder.layer[layer_id].attention.self
        projections = (attn_self.query, attn_self.key, attn_self.value)

        weights = [proj.weight.detach() for proj in projections]
        biases = [
            proj.bias.detach() if proj.bias is not None else None
            for proj in projections
        ]

        return (*weights, *biases)


# ============== TESTING ==============

def test_sequence_parallel():
    """Smoke-test SequenceParallelBigBird on a ~40k-token input.

    Builds the base model with room for 10 * 4096 positions, tokenizes a
    repeated sentence, and prints the output shape for layers 0, 1 and 2.
    Returns the constructed model.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("TESTING SEQUENCE PARALLEL BIGBIRD")
    print(banner)

    gpu_count = torch.cuda.device_count()
    print(f"Available GPUs: {gpu_count}")
    if gpu_count < 2:
        print("WARNING: Only 1 GPU available. Sequence parallelism works but no speedup.")

    # Model sized for 40k tokens
    max_length = 4096 * 10
    print(f"\nInitializing with max_length={max_length}")
    sp_model = SequenceParallelBigBird(
        model_size="base",
        max_length=max_length,
        torch_dtype=torch.float16,
    )

    # Synthetic input: one sentence repeated, truncated to max_length
    print("\nCreating test tokens...")
    tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base")
    sentence = "Hello world. This is a test of sequence parallelism. "
    encoded = tokenizer(
        sentence * 5000,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encoded['input_ids'].to("cuda:0")
    print(f"Input shape: {input_ids.shape}")

    # Exercise the embedding path (0) and the encoder-layer path (1, 2)
    for layer in (0, 1, 2):
        print(f"\n--- Testing Layer {layer} Embeddings ---")
        layer_output = sp_model.get_embeddings(input_ids, layer_id=layer)
        print(f"Layer {layer} output shape: {layer_output.shape}")

    print("\n" + banner)
    print("TEST PASSED!")
    print(banner)

    return sp_model


def test_large_sequence():
    """Run layer 1 on a very long input (N=122, i.e. 4096 * 122 positions).

    Builds the base model with that many positions, tokenizes a repeated
    short string truncated to the same length, and prints the layer-1
    output shape. Returns the constructed model.
    """
    rule = "=" * 70
    print("\n" + rule)
    print("TESTING LARGE SEQUENCE (N=122)")
    print(rule)

    # ~500k tokens
    max_length = 122 * 4096
    print(f"Target max_length: {max_length}")

    sp_model = SequenceParallelBigBird(
        model_size="base",
        max_length=max_length,
        torch_dtype=torch.float16,
    )

    # Synthetic input long enough to be truncated at max_length
    tokenizer = AutoTokenizer.from_pretrained("google/bigbird-roberta-base")
    repeated = "Test. " * (max_length // 2)
    encoded = tokenizer(
        repeated,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = encoded['input_ids'].to("cuda:0")
    print(f"Input shape: {input_ids.shape}")

    print("\n--- Testing Layer 1 with 500k tokens ---")
    layer1 = sp_model.get_embeddings(input_ids, layer_id=1)
    print(f"Layer 1 output shape: {layer1.shape}")

    print("\nSUCCESS!")
    return sp_model


if __name__ == "__main__":
    # Command-line entry point: pick one of the two smoke tests.
    parser = argparse.ArgumentParser()
    cli_options = (
        ("--test", {"action": "store_true", "help": "Run basic test"}),
        ("--test_large", {"action": "store_true", "help": "Test with N=122"}),
        ("--N", {"type": int, "default": 10, "help": "Sequence multiplier"}),
        ("--layer_id", {"type": int, "default": 1, "help": "Target layer"}),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # NOTE: --N and --layer_id are accepted but no branch below reads them.
    if args.test:
        test_sequence_parallel()
    elif args.test_large:
        test_large_sequence()
    else:
        print("Use --test for basic test or --test_large for N=122 test")
