import os
from itertools import accumulate
from typing import List, Optional

import torch
import torch.distributed as dist
from torch import nn

from ..model import Attention, FeedForward, Transformer


def _get_global_rank() -> int:
    return int(os.environ.get("LOCAL_RANK", "0"))


def is_local():
    """Return True when this process's rank (from ``_get_global_rank``) is 0."""
    rank = _get_global_rank()
    return rank == 0


def local_break():
    """Open a debugger on rank 0 only, then make every rank wait at a barrier.

    The barrier keeps the other ranks parked until rank 0 leaves the debugger.
    """
    should_pause = is_local()
    if should_pause:
        breakpoint()
    dist.barrier()


def _get_world_size() -> int:
    return int(os.environ.get("LOCAL_WORLD_SIZE", "1"))


def _select_kv_heads(num_kv_heads, rank_group: list):
    """Return the half-open range ``(start, end)`` of KV-head indices this rank owns.

    Heads are spread as evenly as possible over ``rank_group``. Every member
    gets ``num_kv_heads // len(rank_group)`` heads, and the first
    ``num_kv_heads % len(rank_group)`` members get one extra. This process's
    slot is the index of its rank (from ``_get_global_rank``) in
    ``rank_group``; ``list.index`` raises ``ValueError`` when the rank is
    absent. Ranks beyond the head count get an empty range (``start == end``).
    """
    position = rank_group.index(_get_global_rank())
    group_size = len(rank_group)
    per_rank, leftover = divmod(num_kv_heads, group_size)
    # Each earlier member contributes per_rank heads, plus one more for each
    # of the first `leftover` members.
    start = position * per_rank + min(position, leftover)
    end = start + per_rank + (1 if position < leftover else 0)
    return start, end


def init_dist():
    """Bind this process's CUDA device and initialise the default NCCL process group.

    The rank comes from ``_get_global_rank`` (LOCAL_RANK) and the world size
    from ``_get_world_size`` (LOCAL_WORLD_SIZE). The same rank is also used
    as the CUDA device index.

    Returns:
        ``(rank, dist.group.WORLD)``.

    NOTE(review): because both values come from the LOCAL_* variables, this
    forms a correct group only when all ranks share one node — confirm before
    multi-node use.
    """
    rank = _get_global_rank()
    world = _get_world_size()

    torch.cuda.set_device(rank)
    device = torch.device(f'cuda:{rank}')
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world,
        device_id=device,
    )
    return rank, dist.group.WORLD


def _apply_tp_linear(linear: nn.Linear, style: str, weight_splits: Optional[List[int]] = None, rank_group=None, num_kv_heads = None, num_heads = None, head_dim = None) -> None:
    """Shard an attention projection in place along head boundaries.

    KV heads are divided over ``rank_group`` by ``_select_kv_heads``. This
    rank then keeps the ``num_heads // num_kv_heads`` query heads belonging to
    each of its KV heads, so no KV group is split across ranks.

    Args:
        linear: projection whose ``weight`` is replaced by this rank's shard.
            If present, ``scales`` and ``bias`` are re-sharded along dim 0,
            but only for ``"colwise"``.
        style: ``"colwise"`` shards dim 0 and updates ``out_features``;
            ``"rowwise"`` shards dim 1 and updates ``in_features``.
        weight_splits: for a fused QKV tensor, the three sizes ``[q, k, v]``
            along the shard dim. Each part is sharded by its own head range and
            the parts are re-concatenated. When empty/None, the whole tensor is
            sliced with the query-head range.
        rank_group: ranks in the tensor-parallel group; must contain this rank.
        num_kv_heads: total KV heads before sharding.
        num_heads: total query heads before sharding.
        head_dim: size of one head along the sharded dimension.
    """
    num_group = num_heads // num_kv_heads  # query heads per KV head
    kv_head_start, kv_head_end = _select_kv_heads(num_kv_heads, rank_group)
    # Convert head indices into element offsets. The query offsets assume the
    # num_group query heads of KV head j are stored contiguously
    # (NOTE(review): confirm against the checkpoint layout).
    q_start = kv_head_start * num_group * head_dim
    q_end = kv_head_end * num_group * head_dim
    kv_start = kv_head_start * head_dim
    kv_end = kv_head_end * head_dim

    dim_lookup = {
        "colwise": (0, "out_features"),
        "rowwise": (1, "in_features"),
    }
    assert style in dim_lookup
    shard_dim, size_attr = dim_lookup[style]

    def shard(x, dim, start, end):
        # Slice [start, end) along dim; only dims 0 and 1 are produced above.
        if dim == 0:
            return x[start:end]
        if dim == 1:
            return x[:, start:end]
        raise ValueError(f"unsupported shard dim: {dim}")

    def shard_qkv(qkv, dim, splits):
        # Shard q, k and v by their own head ranges, then re-fuse them.
        q, k, v = qkv.split(splits, dim=dim)
        q = shard(q, dim, q_start, q_end)
        k = shard(k, dim, kv_start, kv_end)
        v = shard(v, dim, kv_start, kv_end)
        return torch.cat((q, k, v), dim=dim)

    # scales and bias are indexed along dim 0 here, so they only change when
    # dim 0 (the output dimension) is the one being split.
    shards_outputs = style == "colwise"

    if weight_splits:
        assert len(weight_splits) == 3
        sharded_weight = shard_qkv(linear.weight, shard_dim, weight_splits)
        if hasattr(linear, "scales") and shards_outputs:
            linear.scales = shard_qkv(linear.scales, 0, weight_splits)
    else:
        sharded_weight = shard(linear.weight, shard_dim, q_start, q_end)
        if hasattr(linear, "scales") and shards_outputs:
            linear.scales = shard(linear.scales, 0, q_start, q_end)

    linear.weight = nn.Parameter(sharded_weight, requires_grad=False)

    # getattr: linear-like modules (e.g. ones carrying `scales`) may not
    # define a `bias` attribute at all, unlike nn.Linear.
    bias = getattr(linear, "bias", None)
    if bias is not None and shards_outputs:
        if weight_splits:
            sharded_bias = shard_qkv(bias, 0, weight_splits)
        else:
            sharded_bias = shard(bias, 0, q_start, q_end)
        linear.bias = nn.Parameter(sharded_bias, requires_grad=False)

    setattr(linear, size_attr, linear.weight.shape[shard_dim])



def _apply_tp_linear_gate_up(linear: nn.Linear, style: str, rank_group=None) -> None:
    """Shard a fused two-half projection (e.g. ``w13``) in place for this rank.

    ``linear.weight`` is split into two equal halves along the shard dim
    (presumably the two fused FFN input projections — confirm against
    FeedForward). Each half is divided over ``rank_group`` with
    ``torch.chunk``, and this rank's two pieces are concatenated again, so the
    rank keeps aligned slices of both halves.

    Args:
        linear: module with a ``weight`` tensor and optionally ``scales``;
            ``bias`` is not touched.
        style: ``"colwise"`` shards dim 0 and updates ``out_features``;
            ``"rowwise"`` shards dim 1 and updates ``in_features``.
            ``scales`` is re-sharded (along dim 0) only for ``"colwise"``.
        rank_group: ranks in the tensor-parallel group; this process's rank
            (``_get_global_rank``) must be a member.
    """
    global_rank = _get_global_rank()
    rank = rank_group.index(global_rank)
    world_size = len(rank_group)

    dim_lookup = {
        "colwise": (0, "out_features"),
        "rowwise": (1, "in_features"),
    }
    assert style in dim_lookup
    shard_dim, size_attr = dim_lookup[style]

    def shard(x, dim):
        # torch.chunk may return fewer than world_size pieces (a size-6 dim
        # chunked 4 ways gives 2, 2, 2), which made trailing ranks hit
        # IndexError. Those ranks own nothing, so give them an empty slice;
        # every other rank gets exactly the piece torch.chunk produced.
        pieces = torch.chunk(x, world_size, dim=dim)
        if rank < len(pieces):
            return pieces[rank]
        return x.narrow(dim, 0, 0)

    def shard_w13(w13, dim):
        # Shard each half separately so both halves stay paired per rank.
        half = w13.shape[dim] // 2
        w1, w3 = w13.split([half, half], dim=dim)
        return torch.cat((shard(w1, dim), shard(w3, dim)), dim=dim)

    sharded_weight = shard_w13(linear.weight, shard_dim)
    if hasattr(linear, "scales") and style == "colwise":
        linear.scales = shard_w13(linear.scales, 0)

    linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
    setattr(linear, size_attr, linear.weight.shape[shard_dim])



def _apply_tp_linear_mlp(linear: nn.Linear, style: str, rank_group=None) -> None:
    """Shard ``linear`` in place so this rank keeps one contiguous slice.

    Pieces come from ``torch.chunk`` over ``len(rank_group)`` ranks. This
    process takes the piece at its index in ``rank_group``.

    Args:
        linear: module with a ``weight`` tensor and optionally ``scales``;
            ``bias`` is not touched.
        style: ``"colwise"`` shards dim 0 and updates ``out_features``;
            ``"rowwise"`` shards dim 1 and updates ``in_features``.
            ``scales`` is re-sharded (along dim 0) only for ``"colwise"``.
        rank_group: ranks in the tensor-parallel group; this process's rank
            (``_get_global_rank``) must be a member.
    """
    global_rank = _get_global_rank()
    rank = rank_group.index(global_rank)
    world_size = len(rank_group)

    dim_lookup = {
        "colwise": (0, "out_features"),
        "rowwise": (1, "in_features"),
    }
    assert style in dim_lookup
    shard_dim, size_attr = dim_lookup[style]

    def shard(x, dim):
        # torch.chunk may return fewer than world_size pieces (a size-6 dim
        # chunked 4 ways gives 2, 2, 2), which made trailing ranks hit
        # IndexError. Those ranks own nothing, so give them an empty slice;
        # every other rank gets exactly the piece torch.chunk produced.
        pieces = torch.chunk(x, world_size, dim=dim)
        if rank < len(pieces):
            return pieces[rank]
        return x.narrow(dim, 0, 0)

    sharded_weight = shard(linear.weight, shard_dim)
    if hasattr(linear, "scales") and style == "colwise":
        linear.scales = shard(linear.scales, 0)

    linear.weight = nn.Parameter(sharded_weight, requires_grad=False)
    setattr(linear, size_attr, linear.weight.shape[shard_dim])



def _apply_tp_ffn(mlp: FeedForward, rank_group, group) -> None:
    """Shard a feed-forward module in place for tensor parallelism.

    The fused ``w13`` projection is split column-wise and ``w2`` row-wise.
    ``group`` is then stored as ``mlp.process_group`` (presumably used by the
    module's forward for the cross-rank reduction — not visible here).
    """
    for required in ("w13", "w2"):
        assert hasattr(mlp, required)

    _apply_tp_linear_gate_up(mlp.w13, "colwise", rank_group=rank_group)
    _apply_tp_linear_mlp(mlp.w2, "rowwise", rank_group=rank_group)
    mlp.process_group = group



def _apply_tp_attn(attn: Attention, rank_group, config, group) -> None:
    """Shard an attention module in place, then adopt per-rank sizes from ``config``.

    ``wqkv`` is sharded column-wise as a fused ``[q, k, v]`` tensor and ``wo``
    row-wise. Both use the module's pre-sharding head counts. Afterwards
    ``n_head``, ``dim``, ``n_local_heads`` (and the derived ``head_dim``) are
    copied from ``config``. That config is expected to already hold this
    rank's values, since ``apply_tp`` runs ``_apply_tp_Transformer`` first.
    """
    assert hasattr(attn, "wqkv")
    assert hasattr(attn, "wo")

    head_args = dict(
        rank_group=rank_group,
        num_kv_heads=attn.n_local_heads,
        num_heads=attn.n_head,
        head_dim=attn.head_dim,
    )
    kv_rows = attn.n_local_heads * attn.head_dim
    _apply_tp_linear(attn.wqkv, "colwise", [attn.dim, kv_rows, kv_rows], **head_args)
    _apply_tp_linear(attn.wo, "rowwise", **head_args)

    # head_dim is derived from the two fields assigned just before it.
    attn.n_head = config.n_head
    attn.dim = config.dim
    attn.head_dim = attn.dim // attn.n_head
    attn.n_local_heads = config.n_local_heads
    attn.process_group = group


def _apply_tp_Transformer(Transformer: Transformer, rank_group, process_group) -> None:
    """Localise model-level state for tensor parallelism.

    Rewrites ``Transformer.config`` so that ``n_head``, ``n_local_heads`` and
    ``dim`` describe only the heads this rank owns, as chosen by
    ``_select_kv_heads``. Each owned KV head keeps its full group of query
    heads. Also shards ``Transformer.output`` column-wise (presumably the
    vocabulary projection) and records the process group, its size and this
    rank on the model.
    """
    cfg = Transformer.config
    total_q_heads = cfg.n_head
    total_kv_heads = cfg.n_local_heads
    q_per_kv = total_q_heads // total_kv_heads
    first_kv, last_kv = _select_kv_heads(total_kv_heads, rank_group)
    owned_kv = last_kv - first_kv
    owned_q = owned_kv * q_per_kv
    # Read head_dim before n_head/dim are overwritten, in case it is derived
    # from them.
    owned_dim = cfg.head_dim * owned_q

    cfg.n_head = owned_q
    cfg.dim = owned_dim
    cfg.n_local_heads = owned_kv

    _apply_tp_linear_mlp(Transformer.output, "colwise", rank_group=rank_group)

    Transformer.process_group = process_group
    Transformer.world_size = dist.get_world_size(process_group)
    Transformer.rank = dist.get_rank(process_group)


def apply_tp(model: Transformer, rank_group, group) -> None:
    """Shard ``model`` in place for tensor parallelism over ``rank_group``.

    Model-level state is localised first. Each attention module then copies
    its per-rank head counts from the already-rewritten ``model.config``.
    After that, every layer's feed-forward and attention are sharded and
    tagged with ``group``.
    """
    _apply_tp_Transformer(model, rank_group, group)
    for layer in model.layers:
        ffn, attention = layer.feed_forward, layer.attention
        _apply_tp_ffn(ffn, rank_group, group)
        _apply_tp_attn(attention, rank_group, model.config, group)