from __future__ import annotations
import os
import warnings
from typing import Optional, Tuple

import torch
import torch.distributed as dist
from torch import nn

def _dist_ready() -> bool:
    return dist.is_available() and dist.is_initialized()

def _tp_rank_world(tp_group: Optional[dist.ProcessGroup]) -> Tuple[int, int]:
    """Return ``(rank, world_size)`` for the tensor-parallel group.

    Args:
        tp_group: Process group to query. ``None`` queries the default group.

    Returns:
        ``(0, 1)`` when ``torch.distributed`` is not ready; otherwise the
        calling process's rank and the group size as reported by
        ``torch.distributed``.
    """
    if _dist_ready():
        if tp_group is not None:
            return dist.get_rank(tp_group), dist.get_world_size(tp_group)
        return dist.get_rank(), dist.get_world_size()
    # Single-process fallback: behave like rank 0 of a one-rank world.
    return 0, 1

def _local_cuda_device() -> torch.device:
    if torch.cuda.is_available() and "LOCAL_RANK" in os.environ:
        return torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
    return torch.device("cpu")

def _chunk_range(total: int, rank: int, world: int) -> Tuple[int, int]:
    assert total % world == 0, f"dim={total} must be divisible by tp_size={world}"
    per = total // world
    s = rank * per
    e = s + per
    return s, e

def _all_reduce_(t: torch.Tensor, tp_group: Optional[dist.ProcessGroup] = None):
    """Sum ``t`` in place across the tensor-parallel group and return it.

    The collective is skipped (and ``t`` returned untouched) when
    ``torch.distributed`` is not ready or the group has a single rank.

    Args:
        t: Tensor to reduce; modified in place by the collective.
        tp_group: Process group to reduce over (``None`` = default group).

    Returns:
        The same tensor object ``t``.
    """
    multi_rank = _dist_ready() and _tp_rank_world(tp_group)[1] > 1
    if not multi_rank:
        return t
    dist.all_reduce(t, op=dist.ReduceOp.SUM, group=tp_group)
    return t

def _all_gather_lastdim(x_local: torch.Tensor, tp_group: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    """Gather ``x_local`` from every rank and concatenate along the last dim.

    Args:
        x_local: This rank's tensor. Receive buffers are allocated with this
            tensor's shape, so every rank is expected to pass the same shape.
        tp_group: Process group to gather over (``None`` = default group).

    Returns:
        ``x_local`` itself when ``torch.distributed`` is not ready or the
        group has a single rank; otherwise a new tensor that concatenates the
        gathered tensors on ``dim=-1``.
    """
    if not _dist_ready():
        return x_local
    world = _tp_rank_world(tp_group)[1]
    if world == 1:
        return x_local
    # Make the send buffer contiguous first and allocate the receive buffers
    # from it: ``empty_like`` on a dense but non-contiguous tensor keeps that
    # tensor's strides, which would give receive buffers a different layout
    # from the contiguous send buffer.
    send = x_local.contiguous()
    parts = [torch.empty_like(send) for _ in range(world)]
    dist.all_gather(parts, send, group=tp_group)
    return torch.cat(parts, dim=-1)

def _is_llama_attention(mod: nn.Module) -> bool:
    return all(hasattr(mod, n) for n in ("q_proj", "k_proj", "v_proj", "o_proj"))

def _is_llama_mlp(mod: nn.Module) -> bool:
    return all(hasattr(mod, n) for n in ("gate_proj", "up_proj", "down_proj"))

@torch.no_grad()
def _colwise_shard_linear_to_device(lin: nn.Linear, rank: int, world: int, device: torch.device):
    """Shard ``lin`` along its output features and place the shard on ``device``.

    The weight (and the bias, when present) is copied to CPU, the rows
    belonging to ``rank`` are sliced out, and the slice is moved to ``device``
    and installed as a new non-trainable parameter. ``lin.out_features`` is
    updated to the shard size. ``lin`` is modified in place.

    Args:
        lin: Linear layer to shard.
        rank: This process's index within the tensor-parallel group.
        world: Number of ranks the output dimension is split across.
        device: Destination device for the shard.
    """
    full_weight = lin.weight.data.detach().to("cpu", non_blocking=False)
    start, stop = _chunk_range(full_weight.size(0), rank, world)

    def _rows_as_param(cpu_tensor: torch.Tensor) -> nn.Parameter:
        # Slice on CPU first so only the shard is ever copied to ``device``.
        shard = cpu_tensor[start:stop].contiguous().to(device, non_blocking=False)
        return nn.Parameter(shard, requires_grad=False)

    lin.weight = _rows_as_param(full_weight)
    if lin.bias is not None:
        full_bias = lin.bias.data.detach().to("cpu", non_blocking=False)
        lin.bias = _rows_as_param(full_bias)
    lin.out_features = stop - start

@torch.no_grad()
def _rowwise_shard_linear_to_device(lin: nn.Linear, rank: int, world: int, device: torch.device):
    """Shard ``lin`` along its input features and place the shard on ``device``.

    The weight is copied to CPU, the columns belonging to ``rank`` are sliced
    out, and the slice is moved to ``device`` and installed as a new
    non-trainable parameter. ``lin.in_features`` is updated to the shard size.
    ``lin`` is modified in place.

    Each rank then produces a partial output that must be summed across ranks
    (``apply_tp`` in this file does so with all-reduce forward hooks on the
    enclosing attention/MLP modules). A bias kept on every rank would be
    counted ``world`` times in that sum, so the real bias is kept on rank 0
    only and replaced with zeros on all other ranks.

    Args:
        lin: Linear layer to shard.
        rank: This process's index within the tensor-parallel group.
        world: Number of ranks the input dimension is split across.
        device: Destination device for the shard.
    """
    w_cpu = lin.weight.data.detach().to("cpu", non_blocking=False)
    s, e = _chunk_range(w_cpu.size(1), rank, world)
    w_shard = w_cpu[:, s:e].contiguous().to(device, non_blocking=False)
    lin.weight = nn.Parameter(w_shard, requires_grad=False)
    if lin.bias is not None:
        b_cpu = lin.bias.data.detach().to("cpu", non_blocking=False)
        if rank != 0:
            # Non-zero ranks contribute no bias; a fresh zero tensor is used
            # so the original bias storage is never modified in place.
            b_cpu = torch.zeros_like(b_cpu)
        # The bias is not sharded (it spans the full output dim) but must live
        # on the same device as the weight shard.
        lin.bias = nn.Parameter(b_cpu.to(device, non_blocking=False), requires_grad=False)
    lin.in_features = e - s

@torch.no_grad()
def _replicate_remaining_to_device(model: nn.Module, device: torch.device):
    model.to(device)

def _require_heads_divisible(cfg, world: int) -> None:
    """Raise ``ValueError`` if a head count on ``cfg`` cannot be split evenly across ``world`` ranks."""
    for field in ("num_attention_heads", "num_key_value_heads"):
        count = getattr(cfg, field, None)
        if isinstance(count, int) and count % world != 0:
            raise ValueError(f"{field}={count} must be divisible by tp_size={world}")


def _make_tp_reduce_hook(tp_group: Optional[dist.ProcessGroup]):
    """Build a forward hook that all-reduces a module's output (or the first item of a tuple output)."""

    def _reduce_hook(mod, args, out):
        if isinstance(out, tuple):
            # Only the first element is reduced; the remaining items are
            # passed through unchanged.
            x, *rest = out
            _all_reduce_(x, tp_group)
            return (x, *rest)
        _all_reduce_(out, tp_group)
        return out

    return _reduce_hook


def apply_tp(
    model: nn.Module,
    tp_group: Optional[dist.ProcessGroup] = None,
    *,
    shard_lm_head_colwise: bool = True,
    attach_hooks: bool = True,
    cpu_shard_then_move: bool = True,
    target_device: Optional[torch.device] = None,
    replicate_remaining: bool = True,
) -> nn.Module:
    """Shard a Llama-like model in place for tensor-parallel inference.

    For every layer in ``model.model.layers`` the attention ``q_proj``,
    ``k_proj``, ``v_proj`` and the MLP ``gate_proj``, ``up_proj`` are sharded
    along their output features, while ``o_proj`` and ``down_proj`` are
    sharded along their input features. Head-count attributes on the
    attention modules and on ``model.config`` (when present) are divided by
    the tensor-parallel world size.

    The model structure and head counts are validated *before* anything is
    modified, so a rejected model is left untouched.

    Args:
        model: Model exposing ``model.layers``, each layer having
            ``self_attn`` and ``mlp`` sub-modules.
        tp_group: Tensor-parallel process group (``None`` = default group).
        shard_lm_head_colwise: Shard ``model.lm_head`` along its output
            features when it is an ``nn.Linear``; otherwise it is only moved
            to the target device and a warning is emitted.
        attach_hooks: Register forward hooks that all-reduce the output of
            every ``self_attn`` and ``mlp`` and, if ``lm_head`` was sharded,
            all-gather its output along the last dimension.
        cpu_shard_then_move: If true and any parameter is on CUDA, move the
            whole model to CPU (and empty the CUDA cache) before sharding; if
            false, only a warning is emitted.
        target_device: Device for the shards; defaults to the result of
            ``_local_cuda_device()``.
        replicate_remaining: Move the whole model to the target device after
            sharding.

    Returns:
        ``model`` unchanged when the world size is 1 (a warning is emitted);
        otherwise ``model.eval()`` after in-place sharding.

    Raises:
        RuntimeError: If a layer lacks the expected q/k/v/o or gate/up/down
            projections.
        ValueError: If ``config.num_attention_heads`` or
            ``config.num_key_value_heads`` is not divisible by the world size.
    """
    rank, world = _tp_rank_world(tp_group)
    if world == 1:
        warnings.warn("tp_size=1: apply_tp() does nothing.")
        return model

    device = target_device if target_device is not None else _local_cuda_device()

    layers = getattr(model, "model").layers
    cfg = getattr(model, "config", None)

    # Validate everything up front: failing half-way through the loop below
    # would leave the model with some layers sharded and others not.
    for layer in layers:
        if not _is_llama_attention(layer.self_attn) or not _is_llama_mlp(layer.mlp):
            raise RuntimeError("Expected Llama-like modules (q/k/v/o + gate/up/down).")
    # Floor-dividing a head count that is not a multiple of ``world`` would
    # silently produce a wrong per-rank head count.
    _require_heads_divisible(cfg, world)

    if cpu_shard_then_move:
        if any(p.is_cuda for p in model.parameters()):
            warnings.warn("Model is on CUDA; moving to CPU first for CPU-side sharding.")
            model.to("cpu")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    else:
        warnings.warn("cpu_shard_then_move=False: VRAM may not be reduced.")

    for layer in layers:
        attn = layer.self_attn
        mlp = layer.mlp

        for proj in (attn.q_proj, attn.k_proj, attn.v_proj):
            _colwise_shard_linear_to_device(proj, rank, world, device)
        _rowwise_shard_linear_to_device(attn.o_proj, rank, world, device)

        for proj in (mlp.gate_proj, mlp.up_proj):
            _colwise_shard_linear_to_device(proj, rank, world, device)
        _rowwise_shard_linear_to_device(mlp.down_proj, rank, world, device)

        # NOTE(review): some attention implementations may store the head
        # count under a different attribute name (e.g. ``num_heads``) -
        # confirm against the target model class.
        if hasattr(attn, "num_attention_heads"):
            attn.num_attention_heads //= world
        if hasattr(attn, "num_key_value_heads"):
            attn.num_key_value_heads //= world

    lm_head = getattr(model, "lm_head", None)
    has_linear_head = isinstance(lm_head, nn.Linear)
    head_sharded = shard_lm_head_colwise and has_linear_head
    if head_sharded:
        _colwise_shard_linear_to_device(lm_head, rank, world, device)
    else:
        if has_linear_head:
            lm_head.to(device)
        warnings.warn("lm_head not sharded colwise.")

    if hasattr(model, "get_input_embeddings"):
        model.get_input_embeddings().to(device)

    if cfg is not None:
        if hasattr(cfg, "pretraining_tp"):
            cfg.pretraining_tp = 1
        if hasattr(cfg, "num_attention_heads"):
            cfg.num_attention_heads //= world
        if hasattr(cfg, "num_key_value_heads"):
            cfg.num_key_value_heads //= world

    if attach_hooks:
        # One hook object serves every module: it only closes over ``tp_group``.
        reduce_hook = _make_tp_reduce_hook(tp_group)
        for layer in layers:
            # Row-sharded o_proj/down_proj yield partial sums; reducing the
            # module output combines them across ranks.
            layer.self_attn.register_forward_hook(reduce_hook)
            layer.mlp.register_forward_hook(reduce_hook)

        if head_sharded:
            def _lm_head_gather_hook(mod, args, out, _tp_group=tp_group):
                # Each rank holds a slice of the output features; gather them
                # back along the last dimension.
                return _all_gather_lastdim(out, _tp_group)
            lm_head.register_forward_hook(_lm_head_gather_hook)

    if replicate_remaining:
        _replicate_remaining_to_device(model, device)

    return model.eval()
