import sys
import os
import torch
import torch.nn as nn          # noqa: F401  (kept for completeness)
import torch.nn.functional as F  # noqa: F401

from log_utils import rank_log, get_logger, verify_min_gpu_count
from llama2 import Transformer, ModelArgs

from torch.distributed.device_mesh import init_device_mesh
from torch.distributed._tensor import Shard, Replicate
from torch.distributed.tensor.parallel import (
    parallelize_module,
    ColwiseParallel,
    RowwiseParallel,
    PrepareModuleInput,
    SequenceParallel,
)

import torch.cuda.nvtx as nvtx

# ---- GPU check ----------------------------------------------------
# Stop early (sys.exit() with no argument, i.e. exit status 0) when the host
# cannot provide the minimum number of GPUs this example needs.
_min_gpu_count = 2
_have_enough_gpus = verify_min_gpu_count(min_gpus=_min_gpu_count)
if not _have_enough_gpus:
    print(f"Unable to locate sufficient {_min_gpu_count} GPUs to run this example. Exiting.")
    sys.exit()
# -------------------------------------------------------------------

logger = get_logger()

# World size / rank are provided by torchrun through the environment.
_rank        = int(os.environ["RANK"])
_world_size  = int(os.environ["WORLD_SIZE"])
# Tensor-parallel degree; overridable via the TP_SIZE env var (default 2).
tp_size = int(os.environ.get("TP_SIZE", "2"))
print(f"Starting PyTorch 2‑D TP example on rank {_rank}.")
# Explicit checks rather than `assert`, which is removed under `python -O`.
if tp_size < 1:
    raise ValueError(f"TP size must be a positive integer, got {tp_size}")
if _world_size % tp_size != 0:
    raise ValueError(
        f"World size {_world_size} must be divisible by TP size {tp_size}"
    )

# -------------------------------------------------------------------
# Device mesh: a 2-D (dp x tp) mesh is created, but only the "tp" dimension
# is used to shard the model; the "dp" coordinate only seeds the input PRNG.
dp_size = _world_size // tp_size
_mesh_shape = (dp_size, tp_size)
device_mesh = init_device_mesh("cuda", _mesh_shape, mesh_dim_names=("dp", "tp"))
rank_log(_rank, logger, f"Device mesh created: {device_mesh=}")

# 1-D sub-meshes sliced out of the 2-D mesh by dimension name.
tp_mesh, dp_mesh = device_mesh["tp"], device_mesh["dp"]
dp_rank = dp_mesh.get_local_rank()

# -------------------------------------------------------------------
# Build the Llama-2 style model (dim=3072, 24 layers, 24 heads -- not a small
# model) on this rank's GPU and initialise its weights; it is sharded over
# tp_mesh further below.
simple_llama2_config = ModelArgs(
    dim=3072,
    n_layers=24,
    n_heads=24,
    vocab_size=32000,
)
model = Transformer.from_model_args(simple_llama2_config)
model = model.to("cuda")
model.init_weights()

def calculate_model_size(model):
    """Return the size of the parameters held on this rank, in MiB.

    DTensor parameters (anything exposing ``to_local()``, as produced by
    ``parallelize_module``) are measured through their local shard. A
    DTensor's ``numel()`` reports the global, unsharded size, so counting it
    directly would make the "after TP" size equal the "before TP" size.
    Bytes per element come from each tensor's ``element_size()`` instead of
    assuming float32.

    Args:
        model: object exposing ``parameters()`` (e.g. an ``nn.Module``).

    Returns:
        float: parameter bytes held on this rank divided by 1024 ** 2.
    """
    total_bytes = 0
    for param in model.parameters():
        # DTensor -> this rank's shard; plain tensors are measured as-is.
        local = param.to_local() if hasattr(param, "to_local") else param
        total_bytes += local.numel() * local.element_size()
    return total_bytes / (1024 ** 2)

# Report per-rank parameter memory and module structure BEFORE tensor
# parallelism is applied, for comparison with the "after TP" report.
model_size_mb = calculate_model_size(model)
print(f"Rank {_rank}: Model size before TP: {model_size_mb:.2f} MB")
print(f"Rank {_rank}: Model before TP parallelisation:\n{model}\n")

# First/last layers.
#
# Sequence-parallel variant (disabled); it must be enabled together with the
# commented-out SP per-block plan, which expects Shard(1) activations:
# model = parallelize_module(
#     model,
#     tp_mesh,
#     {
#         "tok_embeddings": RowwiseParallel(input_layouts=Replicate(), output_layouts=Shard(1)),
#         "norm": SequenceParallel(),
#         "output": ColwiseParallel(input_layouts=Shard(1), output_layouts=Replicate()),
#     },
# )
#
# Active plan (plain TP): tok_embeddings uses RowwiseParallel with replicated
# token-id inputs; output uses ColwiseParallel with its result redistributed
# to Replicate, so every rank ends up with the full logits.
_edge_layer_plan = {
    "tok_embeddings": RowwiseParallel(input_layouts=Replicate()),
    "output": ColwiseParallel(output_layouts=Replicate()),
}
model = parallelize_module(model, tp_mesh, _edge_layer_plan)



# Per-transformer-block TP plan.
#
# Sequence-parallel variant (disabled). It must be enabled together with the
# commented-out SP plan for tok_embeddings / norm / output, because its
# Shard(1) input layouts assume sequence-sharded activations between blocks:
#
# layer_tp_plan = {
#     "attention_norm": SequenceParallel(),
#     "attention": PrepareModuleInput(
#         input_layouts=(Shard(1), None),
#         desired_input_layouts=(Replicate(), None),
#     ),
#     "attention.wq": ColwiseParallel(),
#     "attention.wk": ColwiseParallel(),
#     "attention.wv": ColwiseParallel(),
#     "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
#     "ffn_norm": SequenceParallel(),
#     "feed_forward": PrepareModuleInput(
#         input_layouts=(Shard(1),),
#         desired_input_layouts=(Replicate(),),
#     ),
#     "feed_forward.w1": ColwiseParallel(),
#     "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
#     "feed_forward.w3": ColwiseParallel(),
# }
tp_degree = tp_mesh.size()
for block in model.layers:
    # Plain TP: q/k/v and w1/w3 are ColwiseParallel, wo and w2 are
    # RowwiseParallel (the Colwise -> Rowwise pairing of Megatron-style TP).
    layer_tp_plan = {
        "attention.wq": ColwiseParallel(),
        "attention.wk": ColwiseParallel(),
        "attention.wv": ColwiseParallel(),
        "attention.wo": RowwiseParallel(),
        "feed_forward.w1": ColwiseParallel(),
        "feed_forward.w2": RowwiseParallel(),
        "feed_forward.w3": ColwiseParallel(),
    }

    # Column-sharding wq/wk/wv leaves each rank with 1/tp_degree of the
    # heads. Presumably Attention.forward reshapes q/k/v using n_heads /
    # n_kv_heads (verify against llama2.Attention), so these must describe
    # the local shard. Reject head counts that would be floored silently.
    attn = block.attention
    if attn.n_heads % tp_degree or attn.n_kv_heads % tp_degree:
        raise ValueError(
            f"n_heads ({attn.n_heads}) and n_kv_heads ({attn.n_kv_heads}) "
            f"must be divisible by the TP degree ({tp_degree})"
        )
    attn.n_heads     //= tp_degree
    attn.n_kv_heads  //= tp_degree

    parallelize_module(block, tp_mesh, layer_tp_plan)

# Report the per-rank parameter size and module structure once TP is applied.
model_size_mb = calculate_model_size(model)
print(
    f"Rank {_rank}: Model size after TP: {model_size_mb:.2f} MB",
    f"Rank {_rank}: Model after TP parallelisation:\n{model}\n",
    sep="\n",
)

# -------------------------------------------------------------------
# Optimiser & toy training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3, foreach=False)

print(f"Rank {_rank}: Starting training ...")
num_iterations = 5
batch_size     = 4
seq_len        = 3072
# Draw token ids from the model's own vocabulary so they stay in range of
# tok_embeddings if the config changes.
vocab_size     = simple_llama2_config.vocab_size

# CUDA timing events are re-recorded every iteration and only read after the
# torch.cuda.synchronize() below, so one set can be reused for all iterations.
(
    forward_start_time,
    forward_end_time,
    backward_start_time,
    backward_end_time,
    optimizer_start_time,
    optimizer_end_time,
) = (torch.cuda.Event(enable_timing=True) for _ in range(6))

for i in range(num_iterations):
    nvtx.range_push(f"Itr {i}")

    # The seed depends only on (iteration, dp_rank). Every rank of a TP group
    # therefore draws the same batch, while distinct (iteration, dp replica)
    # pairs never share a seed; a plain `i + dp_rank` would collide, e.g.
    # (1, 0) vs (0, 1).
    torch.manual_seed(i * dp_size + dp_rank)
    inp = torch.randint(vocab_size, (batch_size, seq_len), device="cuda")  # (batch_size, seq_len)

    nvtx.range_push("Forward pass")
    forward_start_time.record()
    output = model(inp)
    forward_end_time.record()
    nvtx.range_pop()
    print(f"Rank {_rank}: Iter {i} forward pass complete")

    # Backward and optimizer work are timed separately so the "Backward"
    # figure does not include optimizer.step() / zero_grad().
    nvtx.range_push("Backward pass")
    backward_start_time.record()
    output.sum().backward()
    backward_end_time.record()
    nvtx.range_pop()

    nvtx.range_push("Optimizer step")
    optimizer_start_time.record()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    optimizer_end_time.record()
    nvtx.range_pop()

    # Events are recorded asynchronously; wait for the GPU before reading them.
    torch.cuda.synchronize()
    forward_time = forward_start_time.elapsed_time(forward_end_time)
    backward_time = backward_start_time.elapsed_time(backward_end_time)
    optimizer_time = optimizer_start_time.elapsed_time(optimizer_end_time)
    iteration_time = forward_start_time.elapsed_time(optimizer_end_time)
    nvtx.range_pop()  # closes the "Itr {i}" range
    print(
        f"Rank {_rank}: Iter {i} complete | Forward: {forward_time:.2f} ms"
        f" | Backward: {backward_time:.2f} ms"
        f" | Optimizer: {optimizer_time:.2f} ms"
        f" | Iteration: {iteration_time:.2f} ms\n"
    )

print(f"Rank {_rank}: Training successfully completed!")
