import torch
import numpy as np

import logging
from rich import print

# Name the logger after the module rather than its file path: logging treats
# dots as hierarchy separators, so a path ending in ".py" yields a bogus
# parent/child logger pair that is awkward to configure.
logger = logging.getLogger(__name__)

import numpy as np

def add_agent_identifiers(input_ids, 
                          position_ids, 
                          num_agents, 
                          tokenizer,
                          shift = 1000):
    """Append a per-agent identifier prompt to existing model inputs.

    For every agent ``i`` the text ``"You are agent {i}:\\n"`` is tokenized
    with ``tokenizer.encode``. The resulting tokens are appended to
    ``input_ids`` interleaved by token offset (token 0 of every agent, then
    token 1 of every agent, ...). The token at offset ``j`` of agent ``i``
    gets position id ``prompt_token_len + j + i * shift``.

    Args:
        input_ids: 2-D tensor whose single row holds the prompt followed by
            ``num_agents - 1`` extra slots (the layout produced by
            ``get_model_inputs``).
        position_ids: Position ids matching ``input_ids``.
        num_agents: Number of agents to create identifiers for.
        tokenizer: Object exposing ``encode(text) -> sequence of token ids``.
        shift: Position-id offset separating consecutive agents.

    Returns:
        dict with the extended ``input_ids`` and ``position_ids`` tensors.
    """
    # The last ``num_agents - 1`` columns are the extra per-agent slots, so
    # removing them gives the length of the shared prompt.
    prompt_token_len = len(input_ids[0]) - num_agents + 1

    # NOTE(review): encode() runs with the tokenizer's defaults; if the
    # tokenizer adds special tokens (e.g. BOS) by default they end up inside
    # every identifier -- confirm this is intended.
    identifier_token_ids = [
        tokenizer.encode(f"You are agent {i}:\n") for i in range(num_agents)
    ]

    # ``default=0`` keeps ``num_agents == 0`` a no-op instead of a ValueError.
    max_len = max((len(ids) for ids in identifier_token_ids), default=0)

    flat_token_ids = []
    flat_pos_ids = []
    for offset in range(max_len):
        for agent_idx, ids in enumerate(identifier_token_ids):
            # Shorter identifiers simply contribute nothing at later offsets.
            if offset < len(ids):
                flat_token_ids.append(ids[offset])
                flat_pos_ids.append(prompt_token_len + offset + agent_idx * shift)

    extra_input_ids = torch.tensor(
        flat_token_ids, dtype=input_ids.dtype, device=input_ids.device
    ).unsqueeze(0)
    # Build the new positions with the dtype/device of ``position_ids`` (not
    # of ``input_ids``) so the concatenation neither promotes the dtype nor
    # fails on a device mismatch.
    extra_pos_ids = torch.tensor(
        flat_pos_ids, dtype=position_ids.dtype, device=position_ids.device
    ).unsqueeze(0)

    return dict(
        input_ids=torch.cat([input_ids, extra_input_ids], dim=-1),
        position_ids=torch.cat([position_ids, extra_pos_ids], dim=-1),
    )
    

def get_model_inputs(prmopt, 
                     tokenizer, 
                     num_agents, 
                     system_prompt = None,
                     shift = 1000):
    """Tokenize a chat prompt and add one start slot per agent.

    The user prompt (optionally preceded by a system message) is rendered
    through ``tokenizer.apply_chat_template`` with the generation prompt
    appended. The last prompt token is then repeated ``num_agents - 1``
    times, giving ``num_agents`` copies in total. The original copy keeps
    its position id; the ``i``-th extra copy gets that id plus ``i * shift``.

    Args:
        prmopt: User message text (parameter name kept as spelled for
            caller compatibility).
        tokenizer: Object exposing ``apply_chat_template``; presumably it
            returns a ``(1, seq_len)`` tensor here -- the indexing below
            relies on that.
        num_agents: Number of agents that will decode in parallel.
        system_prompt: Optional system message placed before the user turn.
        shift: Position-id offset separating consecutive agents.

    Returns:
        dict with ``input_ids`` and ``position_ids`` tensors.
    """
    conversation = []
    if system_prompt is not None:
        conversation.append({"role": "system", "content": system_prompt})
    conversation.append({"role": "user", "content": prmopt})

    prompt_ids = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        enable_thinking=True,
        return_tensors="pt",
    )
    prompt_positions = torch.arange(prompt_ids.size(1)).unsqueeze(0)

    extra_slots = num_agents - 1

    # Every extra agent starts from a copy of the final prompt token ...
    tail_tokens = prompt_ids[0, -1:].repeat(1, extra_slots)
    # ... located at the final prompt position pushed out by i * shift.
    agent_offsets = torch.arange(1, num_agents).unsqueeze(0) * shift
    tail_positions = prompt_positions[0, -1:].repeat(1, extra_slots) + agent_offsets

    return dict(
        input_ids=torch.cat([prompt_ids, tail_tokens], dim=1),
        position_ids=torch.cat([prompt_positions, tail_positions], dim=1),
    )


class GroupThinkGenerator:
    """Greedy decoder that advances several agents in lock-step.

    All agents' tokens are fed to the model as one sequence (batch size 1)
    and share a single ``past_key_values`` cache. Agents are told apart only
    by their position ids, which are offset by multiples of ``shift``.
    """

    def __init__(self, model, tokenizer, num_agents, shift=1000, 
                 interleave_prompt=False):
        """Store the model (switched to eval mode), tokenizer and settings.

        Args:
            model: Causal LM called as ``model(input_ids=..., position_ids=...,
                past_key_values=..., use_cache=True)``.
            tokenizer: Tokenizer providing ``eos_token_id``/``pad_token_id``,
                ``decode`` and ``apply_chat_template``.
            num_agents: Number of agents decoded in parallel.
            shift: Position-id offset separating consecutive agents.
            interleave_prompt: Stored on the instance; not read anywhere in
                this class.
        """
        self.model = model.eval()
        self.tokenizer = tokenizer
        self.num_agents = num_agents
        self.shift = shift
        self.interleave_prompt = interleave_prompt
        self.eos_token_id = tokenizer.eos_token_id
        # Some tokenizers define no pad token; ``torch.full`` cannot fill with
        # None, so fall back to the EOS id in that case.
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id
        self.pad_token_id = pad_token_id

    @torch.no_grad()
    def generate(self, 
                 query, 
                 max_extra_tokens=10, 
                 system_prompt=None,
                 include_aggent_identifier=False):
        """Greedily decode up to ``max_extra_tokens`` tokens for every agent.

        Args:
            query: User prompt text.
            max_extra_tokens: Maximum number of decoding rounds; each active
                agent emits one token per round.
            system_prompt: Optional system message.
            include_aggent_identifier: When true, append a
                ``"You are agent {i}:\\n"`` prefix for every agent before
                decoding (parameter name kept as spelled for compatibility).

        Returns:
            A list with one entry per agent, each a list of generated token
            ids with every id equal to ``pad_token_id`` removed. An agent
            stops after emitting ``eos_token_id``.
        """
        device = self.model.device

        model_inputs = get_model_inputs(query, tokenizer=self.tokenizer, num_agents=self.num_agents, system_prompt=system_prompt, shift=self.shift)
        if include_aggent_identifier:
            model_inputs = add_agent_identifiers(**model_inputs, num_agents=self.num_agents, tokenizer=self.tokenizer, shift=self.shift)

        input_ids = model_inputs['input_ids'].to(device)
        position_ids = model_inputs['position_ids'].to(device)
        active_agents = torch.ones(self.num_agents, dtype=torch.bool, device=device)

        # Decoding the prompt is only worth doing when it will be emitted.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("input string: \n%s\n%s", self.tokenizer.decode(input_ids[0].cpu().detach()), '='*80)
            logger.debug("input ids: \n%s", input_ids)

        # Pre-filled with PAD so rounds an agent never reaches are dropped by
        # the PAD filter at the end.
        generated_ids_by_round = torch.full((1, self.num_agents, max_extra_tokens), self.pad_token_id, dtype=torch.long, device=device)

        past_key_values = None
        output_ids = None
        for k in range(max_extra_tokens):
            # Lazy %-formatting: tensors are only rendered when DEBUG is on.
            logger.debug("round %d: position_ids = %s", k, position_ids)
            if not active_agents.any():
                break

            if k == 0:
                outputs = self.model(
                    input_ids=input_ids,
                    position_ids=position_ids,
                    use_cache=True
                )
                # From now on only the per-agent positions are needed.
                # NOTE(review): this assumes the last ``num_agents`` inputs
                # hold exactly one token per agent; with agent identifiers of
                # unequal token length that may not hold -- confirm.
                position_ids = position_ids[:, -self.num_agents:]
            else:
                # Feed only the tokens/positions of agents still generating.
                active_position_ids = position_ids[:, active_agents]
                active_input_ids = output_ids[:, active_agents]

                logger.debug("\tactive pos_ids = %s;\n\tactive_input_ids = %s", active_position_ids, active_input_ids)

                outputs = self.model(
                    input_ids=active_input_ids,
                    position_ids=active_position_ids,
                    past_key_values=past_key_values,
                    use_cache=True
                )

            past_key_values = outputs.past_key_values

            # The trailing logits correspond to the active agents, in order.
            num_active = int(active_agents.sum())
            active_output_ids = outputs.logits[:, -num_active:, :].argmax(-1)

            # Full-width row: PAD for inactive agents, new token for the rest.
            output_ids = torch.full((1, self.num_agents), self.pad_token_id, dtype=torch.long, device=device)
            output_ids[:, active_agents] = active_output_ids
            logger.debug("\toutput_ids: %s", active_output_ids)

            # An agent that just produced EOS takes no part in later rounds.
            active_agents = active_agents & (output_ids[0] != self.eos_token_id)

            # Advance every agent by one position; inactive agents get a
            # sentinel (-100) since their positions are never fed again.
            position_ids = position_ids + 1
            position_ids[:, ~active_agents] = -100

            generated_ids_by_round[0, :, k] = output_ids[0]

        generated_ids = generated_ids_by_round.cpu().detach().numpy()[0]
        logger.debug("%s", generated_ids)
        # Filter out PAD tokens from each agent's sequence
        return [[token for token in seq if token != self.pad_token_id] for seq in generated_ids]

if __name__ == '__main__':
    # Demo: run a few agents on a single prompt and print what each one said.
    #%%
    import logging
    import os

    from transformers import AutoTokenizer, AutoModelForCausalLM

    # Checkpoint directory. Set GROUPTHINK_MODEL_PATH to run on another
    # machine without editing this file; the default is the local checkpoint
    # this script was using.
    # Alternative checkpoint: "/Users/fengtingliao/external/model_hf/Qwen3-0.6B"
    model_path = os.environ.get(
        "GROUPTHINK_MODEL_PATH",
        "/Users/fengtingliao/external/model_hf/Qwen2.5-0.5B-Instruct",
    )

    model = AutoModelForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # logging.basicConfig(level=logging.DEBUG)
    logging.basicConfig(level=logging.INFO)

    num_agents = 3
    # Alternative system prompts:
    # system_prompt = f"You are a multi-agent system. There are {num_agents} agents in the system. Each agent can see each other's prompt. You are encouraged to get the agents to collaborate with each other."
    # system_prompt = None
    system_prompt = "You are a helpful assistant."
    prompt = "What's the distance between LA and NY?"
    # prompt = "Tell me a joke about scientists."
    # prompt = "If x = 3, y = 2, then x^2 + y^2 equals?"

    generator = GroupThinkGenerator(model, tokenizer, num_agents=num_agents)
    generated_ids = generator.generate(
        prompt,
        max_extra_tokens=128,
        system_prompt=system_prompt,
        include_aggent_identifier=False,
    )
    generated_strings = [tokenizer.decode(generated) for generated in generated_ids]
    #%%

    for i, gstr in enumerate(generated_strings):
        print(f"agent {i}:\n{gstr}")

    # %%
    print(generated_strings)

