from typing import Callable, List, Optional, Tuple

import torch
from transformers import DynamicCache

# Default compute device: CUDA when it is available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Placeholder pad id; group_think_cross_attend_generate normally overwrites
# this global from the tokenizer's pad (or EOS) token id.
PAD_TOKEN_ID = 0


# -------------------------
# Per-path state container
# -------------------------
class PathState:
    """Mutable bookkeeping for a single generation path.

    All fields are public and are reassigned by the generation loop:
        ids: token ids of this path so far, expected to be a ``(1, seq_len)``
            tensor.
        past: this path's key/value cache, read as ``past[layer][0]`` (keys)
            and ``past[layer][1]`` (values).
        length: number of this path's tokens held in ``past`` after the
            shared prefix.
        finished: starts False; set to True once the path stops generating.
    """

    def __init__(self, ids, past, length):
        self.ids, self.past, self.length = ids, past, length
        self.finished = False


# -------------------------
# Utilities for past KV handling
# -------------------------
def split_batched_past(batched_past):
    """Break a batched per-layer KV cache into one cache per batch row.

    Args:
        batched_past: per-layer ``(key, value)`` pairs whose tensors share a
            leading batch axis; the batch size is read from the first
            layer's key tensor.

    Returns:
        A list with one element per batch row. Each element is a tuple of
        per-layer ``(key, value)`` pairs sliced to that row, each keeping a
        batch axis of size 1.
    """

    def take_row(tensor, row):
        # A length-1 slice (rather than an integer index) keeps the batch axis.
        return tensor[row : row + 1].contiguous()

    batch_size = batched_past[0][0].shape[0]
    return [
        tuple(
            (take_row(layer[0], row), take_row(layer[1], row))
            for layer in batched_past
        )
        for row in range(batch_size)
    ]


def _as_single_batch(tensor):
    """Return a key/value tensor with an explicit leading batch axis of size 1.

    A 3-D tensor is treated as missing its batch axis (presumably
    ``(heads, seq, head_dim)`` — TODO confirm this layout actually occurs)
    and gets one prepended. A tensor whose leading axis is not 1 cannot be
    stitched into the single-row shared memory, so it is rejected here with
    a clear message instead of failing later inside ``torch.cat``.
    """
    if tensor.dim() == 3:
        return tensor.unsqueeze(0)
    if tensor.shape[0] != 1:
        raise ValueError(
            "expected a key/value tensor with batch size 1, "
            f"got shape {tuple(tensor.shape)}"
        )
    return tensor


def build_global_memory(prefix_past, path_states: "List[PathState]", prefix_len: int):
    """Build one shared KV memory laid out as prefix || suffix(path0) || suffix(path1) || ...

    Args:
        prefix_past: per-layer ``(key, value)`` pairs for the shared prefix,
            indexable as ``prefix_past[layer]``. Sequence axis is 2.
        path_states: paths whose suffixes are appended, in order. Each must
            expose ``length`` (number of suffix tokens) and ``past`` (that
            path's own ``prefix || suffix`` cache, indexable like
            ``prefix_past``).
        prefix_len: number of prefix tokens at the start of every path cache.

    Returns:
        ``(batched_mem, offsets, mem_len)``:
        ``batched_mem`` is a tuple of per-layer ``(key, value)`` tensors
        already expanded to batch size ``len(path_states)``;
        ``offsets[i]`` is the ``(start, end)`` span of path ``i``'s suffix
        inside the memory; ``mem_len`` is ``prefix_len + sum(lengths)``.

    Raises:
        ValueError: if a key/value tensor has a batch size other than 1.
        AssertionError: if a path cache holds fewer than
            ``prefix_len + length`` tokens (the slice would be truncated).
    """
    suffix_lens = [p.length for p in path_states]
    num_paths = len(path_states)
    mem_len = prefix_len + sum(suffix_lens)

    # Path i's suffix occupies [start, end) in the concatenated memory.
    offsets = []
    start = prefix_len
    for s_len in suffix_lens:
        offsets.append((start, start + s_len))
        start += s_len

    mem_layers = []
    for layer_idx in range(len(prefix_past)):
        prefix_k, prefix_v = prefix_past[layer_idx]
        keys = [_as_single_batch(prefix_k)]
        vals = [_as_single_batch(prefix_v)]
        for p, s_len in zip(path_states, suffix_lens):
            path_k = _as_single_batch(p.past[layer_idx][0])
            path_v = _as_single_batch(p.past[layer_idx][1])
            # Each path cache is prefix || suffix; keep only the suffix. An
            # empty slice (s_len == 0) is a valid (1, H, 0, D) tensor, and
            # torch.cat copies, so no .contiguous() is needed here.
            keys.append(path_k[:, :, prefix_len : prefix_len + s_len, :])
            vals.append(path_v[:, :, prefix_len : prefix_len + s_len, :])

        mem_k = torch.cat(keys, dim=2)
        mem_v = torch.cat(vals, dim=2)
        # Slicing past the end of a too-short path cache silently truncates;
        # check every layer, not just the first.
        assert mem_k.shape[2] == mem_len, "Memory length mismatch!"

        # Every path reads the same memory, so replicate it per batch row.
        mem_layers.append(
            (
                mem_k.expand(num_paths, -1, -1, -1).contiguous(),
                mem_v.expand(num_paths, -1, -1, -1).contiguous(),
            )
        )

    return tuple(mem_layers), offsets, mem_len


def split_global_memory(
    returned_batched_past, prefix_len, offsets: List[Tuple[int, int]], mem_len: int
):
    """Reconstruct per-path caches from the batched cache returned by the model.

    Row ``b`` of every layer in ``returned_batched_past`` holds
    ``memory || new_tokens(b)``, where the memory was laid out as
    ``prefix || suffix(path0) || suffix(path1) || ...``. Each path's private
    cache is rebuilt as ``prefix || suffix(path_b) || new_tokens(b)``,
    dropping the other paths' suffixes.

    Args:
        returned_batched_past: per-layer ``(key, value)`` pairs indexable as
            ``returned_batched_past[layer][0/1]``, batch on axis 0 and
            sequence on axis 2.
        prefix_len: number of shared prefix tokens at the start of the memory.
        offsets: ``(start, end)`` span of each path's suffix in the memory,
            one entry per batch row.
        mem_len: length of the shared memory; every position from ``mem_len``
            onward is treated as newly appended by that row (normally one
            token per decoding step).

    Returns:
        A list with one tuple of per-layer ``(key, value)`` tensors per batch
        row, each with a batch axis of size 1.

    Raises:
        ValueError: if ``len(offsets)`` differs from the batch size.
    """
    layers_kv = [
        (returned_batched_past[i][0], returned_batched_past[i][1])
        for i in range(len(returned_batched_past))
    ]
    batch = layers_kv[0][0].shape[0]
    if len(offsets) != batch:
        raise ValueError("Offsets length must equal batch size")

    per_path_pasts = []
    for b, (start, end) in enumerate(offsets):
        layers = []
        for all_k, all_v in layers_kv:
            row_k = all_k[b : b + 1]
            row_v = all_v[b : b + 1]
            # An empty span (end <= start) slices to a (1, H, 0, D) tensor,
            # and torch.cat returns a fresh tensor, so no copies are needed.
            k_concat = torch.cat(
                [row_k[:, :, :prefix_len], row_k[:, :, start:end], row_k[:, :, mem_len:]],
                dim=2,
            )
            v_concat = torch.cat(
                [row_v[:, :, :prefix_len], row_v[:, :, start:end], row_v[:, :, mem_len:]],
                dim=2,
            )
            layers.append((k_concat, v_concat))
        per_path_pasts.append(tuple(layers))
    return per_path_pasts


# -------------------------
# Core generator
# -------------------------
_PATH_END_TAG = "</Path>"
# The end-tag check runs after every new token, so the tag can only be
# completed within the newest few tokens. Decoding a short tail keeps the
# per-step cost constant instead of re-decoding the whole path; 16 tokens
# comfortably covers the 7-character tag.
_PATH_END_WINDOW = 16


def _layers_from_cache(cache):
    """Return ``cache`` as a tuple of per-layer ``(key, value)`` tensor pairs.

    Accepts anything supporting ``len()`` and ``cache[layer][0/1]`` indexing
    (legacy tuples as well as ``DynamicCache``).
    """
    return tuple((cache[i][0], cache[i][1]) for i in range(len(cache)))


def _cache_from_layers(layers):
    """Build a new ``DynamicCache`` seeded with per-layer ``(key, value)`` pairs.

    A fresh cache object is returned on every call, so a model that extends
    the cache it is given cannot grow any other cache object.
    """
    cache = DynamicCache()
    for layer_idx, (k, v) in enumerate(layers):
        cache.update(k, v, layer_idx)
    return cache


def _prefill_prefix(model, tokenizer, system_prompt, prompt, run_device):
    """Encode the chat-formatted prompt plus ``<Parallel>``; return ``(prefix_layers, prefix_len)``."""
    text = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        tokenize=False,
        add_generation_prompt=True,
    )
    text += "<Parallel>"
    # Chat templates typically render the model's special tokens (e.g. BOS)
    # into the text already; adding them again would duplicate them.
    prefix_ids = tokenizer(
        text, return_tensors="pt", add_special_tokens=False
    ).input_ids.to(run_device)
    prefix_out = model(input_ids=prefix_ids, use_cache=True)
    return _layers_from_cache(prefix_out.past_key_values), prefix_ids.shape[1]


def _init_path(model, tokenizer, prefix_layers, prefix_len, path_idx, shift, run_device):
    """Create the ``PathState`` for path ``path_idx`` by prefilling its seed text.

    Every seed token except the last is encoded here on top of a private copy
    of the prefix cache. The last seed token is left for the first decoding
    step, matching the loop's invariant that the final token of ``ids`` is not
    yet in the cache (so ``length == ids.shape[1] - 1``).
    """
    seed_text = f"<Path>\nThinker {path_idx + 1}:\n"
    # No special tokens: a BOS in the middle of the sequence would be wrong.
    seed_ids = tokenizer(
        seed_text, return_tensors="pt", add_special_tokens=False
    ).input_ids.to(run_device)
    cached_len = seed_ids.shape[1] - 1

    if cached_len > 0:
        start_pos = prefix_len + path_idx * shift
        position_ids = torch.arange(
            start_pos, start_pos + cached_len, device=run_device
        ).unsqueeze(0)
        seed_out = model(
            input_ids=seed_ids[:, :-1],
            # Private cache per path: HF models extend a DynamicCache passed
            # as past_key_values in place, which would otherwise leak every
            # path's seed into the shared prefix.
            past_key_values=_cache_from_layers(prefix_layers),
            position_ids=position_ids,
            use_cache=True,
        )
        past = _layers_from_cache(seed_out.past_key_values)
    else:
        past = prefix_layers

    return PathState(seed_ids.clone(), past, cached_len)


@torch.no_grad()
def group_think_cross_attend_generate(
    model,
    tokenizer,
    prompt: str,
    system_prompt: Optional[str] = None,
    num_paths: int = 4,
    shift: int = 3000,
    max_path_tokens: int = 256,
    verbose: bool = True,
    step_callback: Optional[Callable] = None,
):
    """Greedily decode several reasoning paths that cross-attend to one another.

    The prompt is rendered with the tokenizer's chat template, ``<Parallel>``
    is appended, and the result is encoded once as a shared prefix. Each of
    ``num_paths`` paths is seeded with ``"<Path>\\nThinker {k}:\\n"`` and then
    decoded one greedy token per step. At every step all paths read a single
    shared memory ``prefix || suffix(path0) || suffix(path1) || ...`` (built
    by ``build_global_memory``), so each path sees the others' progress.

    Path ``i`` uses position ids ``prefix_len + i * shift + j`` for its
    ``j``-th token, so ``shift`` should be at least ``max_path_tokens``;
    otherwise consecutive paths reuse position ids.

    Args:
        model: causal LM called with ``input_ids``, ``past_key_values``,
            ``position_ids``, ``attention_mask`` and ``use_cache``; its output
            must expose ``logits`` and ``past_key_values``. Inputs are placed
            on ``model.device`` when it exists, else on the module ``device``.
        tokenizer: tokenizer providing ``apply_chat_template``, ``decode``,
            ``pad_token_id`` and ``eos_token_id``.
        prompt: user message.
        system_prompt: system message; a default group-think prompt is used
            when None.
        num_paths: number of parallel paths.
        shift: position-id gap between consecutive paths.
        max_path_tokens: a path stops once its ``ids`` (seed tokens included)
            reach this length.
        verbose: print a line when a path finishes.
        step_callback: optional ``callback(path_states, tokenizer, step)``
            invoked after every decoding step.

    Returns:
        The list of ``PathState`` objects, one per path, all finished. A path
        also finishes as soon as its text contains ``</Path>``.

    Side effects:
        Sets the module-level ``PAD_TOKEN_ID`` from the tokenizer's pad token
        id, falling back to its EOS id.
    """
    global PAD_TOKEN_ID
    # Finished paths still occupy a batch row; they are fed this id and their
    # outputs are discarded, so any valid id works. Keep the module default
    # when the tokenizer has neither (torch.tensor cannot hold None).
    for candidate in (tokenizer.pad_token_id, tokenizer.eos_token_id):
        if candidate is not None:
            PAD_TOKEN_ID = candidate
            break

    if system_prompt is None:
        system_prompt = "You are participating in a group think session where multiple thinkers are answering the question in parallel resulting in concurrent thinking paths."

    run_device = getattr(model, "device", device)

    prefix_layers, prefix_len = _prefill_prefix(
        model, tokenizer, system_prompt, prompt, run_device
    )
    path_states = [
        _init_path(model, tokenizer, prefix_layers, prefix_len, i, shift, run_device)
        for i in range(num_paths)
    ]

    step = 0
    while not all(p.finished for p in path_states):
        global_mem, offsets, mem_len = build_global_memory(
            prefix_layers, path_states, prefix_len
        )
        num_rows = len(path_states)

        # Active paths feed their newest token (not yet cached); finished
        # paths feed a pad token whose results are ignored below.
        input_ids = torch.cat(
            [
                torch.tensor([[PAD_TOKEN_ID]], device=run_device)
                if p.finished
                else p.ids[:, -1:].to(run_device)
                for p in path_states
            ],
            dim=0,
        )
        # The fed token sits right after the path's cached suffix.
        position_ids = torch.tensor(
            [[prefix_len + i * shift + p.length] for i, p in enumerate(path_states)],
            device=run_device,
        )
        # The mask must span past + current tokens; every row may attend to
        # the entire shared memory plus its own new token.
        attn_mask = torch.ones(
            (num_rows, mem_len + 1), dtype=torch.long, device=run_device
        )

        out = model(
            input_ids=input_ids,
            past_key_values=_cache_from_layers(global_mem),
            attention_mask=attn_mask,
            position_ids=position_ids,
            use_cache=True,
        )
        next_tokens = torch.argmax(out.logits[:, -1, :], dim=-1)
        new_pasts = split_global_memory(
            _layers_from_cache(out.past_key_values), prefix_len, offsets, mem_len
        )

        for i, p in enumerate(path_states):
            if p.finished:
                continue
            p.past = new_pasts[i]
            p.ids = torch.cat([p.ids, next_tokens[i].view(1, 1)], dim=1)
            p.length += 1
            tail = tokenizer.decode(
                p.ids[0, -_PATH_END_WINDOW:], skip_special_tokens=False
            )
            if _PATH_END_TAG in tail or p.ids.shape[1] >= max_path_tokens:
                p.finished = True
                if verbose:
                    print(f"[engine] Path {i+1} finished.")

        # Let interactive callers observe progress after every step.
        if step_callback:
            step_callback(path_states, tokenizer, step)

        step += 1

    return path_states
