from typing import List, Optional, Iterator, Dict
import os
from transformers import pipeline
import torch
from utils import chat, predict, truncate


class WorkerAgent:
    """Worker agent that processes individual chunks of text."""

    def __init__(self, model: str, system_prompt: str, max_new_tokens: int, tokenizer_kwargs: dict, pipeline_args: dict, chunk_size: int):
        """
        Initialize a worker agent.

        Args:
            model: Model identifier; forwarded to ``predict`` on every call.
            system_prompt: A two-item sequence ``(first_prompt, followup_prompt)``
                of ``str.format`` templates. The first is used when
                ``is_first_agent`` is true, the second otherwise. The templates
                are formatted with the keyword fields ``input_chunk``,
                ``prev_cu`` and ``query``.
                NOTE(review): the annotation says ``str`` but the value is
                unpacked into two prompts, so callers must pass a pair.
            max_new_tokens: Value stored under ``"max_new_tokens"`` in this
                agent's own generation kwargs.
            tokenizer_kwargs: Base kwargs forwarded to ``predict``. The dict is
                copied, never modified, so it can safely be shared between
                agents that use different ``max_new_tokens`` values.
            pipeline_args: Keyword arguments forwarded to
                ``transformers.pipeline`` to build ``self.client``.
            chunk_size: Length limit handed to ``truncate`` for every prompt
                (presumably measured in tokens -- confirm against
                ``utils.truncate``).
        """
        self.model = model
        self.first_system_prompt, self.system_prompt = system_prompt

        # Each agent needs its own max_new_tokens. Writing into the caller's
        # dict would leak this worker's limit to every other agent that was
        # handed the same dict, so work on a private copy instead.
        self.tokenizer_kwargs = self._with_max_new_tokens(tokenizer_kwargs, max_new_tokens)
        self.client = pipeline(**pipeline_args)
        self.chunk_size = chunk_size

    @staticmethod
    def _with_max_new_tokens(tokenizer_kwargs, max_new_tokens):
        """Return a shallow copy of ``tokenizer_kwargs`` with ``max_new_tokens`` set."""
        merged = dict(tokenizer_kwargs)
        merged["max_new_tokens"] = max_new_tokens
        return merged

    def _build_prompt(self, chunk, query, previous_cu, is_first_agent):
        """Format the first-agent or follow-up template for a single chunk."""
        template = self.first_system_prompt if is_first_agent else self.system_prompt
        return template.format(input_chunk=chunk, prev_cu=previous_cu, query=query)

    def batch_process_chunk(self, chunks, queries, previous_cus, is_first_agent):
        """
        Process several chunks in one batched ``predict`` call.

        Args:
            chunks: Text chunks, one per batch item.
            queries: Queries aligned with ``chunks``.
            previous_cus: Previous worker outputs aligned with ``chunks``;
                substituted for the ``prev_cu`` template field.
            is_first_agent: Selects the first-agent template when true.

        Returns:
            Whatever ``predict`` returns for the batch of prompts.

        Note:
            The three sequences are combined with ``zip``, so extra items in a
            longer sequence are silently ignored.
        """
        batch_messages = [
            self._build_prompt(chunk, query, previous_cu, is_first_agent)
            for chunk, query, previous_cu in zip(chunks, queries, previous_cus)
        ]
        batch_messages = [
            truncate(message, self.client.tokenizer, self.chunk_size)
            for message in batch_messages
        ]
        return predict(self.client, batch_messages, self.tokenizer_kwargs, self.model, is_batch=True)

    def process_chunk(self, chunk: str, query: str, previous_cu: Optional[str] = None, is_first_agent: bool = False) -> str:
        """
        Process a single chunk.

        Args:
            chunk: The text chunk to process.
            query: The query substituted into the template.
            previous_cu: Output of the previous worker, if any; substituted for
                the ``prev_cu`` template field.
            is_first_agent: Selects the first-agent template when true.

        Returns:
            Whatever ``predict`` returns for the single prompt.
        """
        message = self._build_prompt(chunk, query, previous_cu, is_first_agent)
        message = truncate(message, self.client.tokenizer, self.chunk_size)
        return predict(self.client, message, self.tokenizer_kwargs, self.model, is_batch=False)


class ManagerAgent:
    """Manager agent that synthesizes outputs from worker agents."""

    def __init__(self, model: str, system_prompt: str, tokenizer_kwargs: dict, pipeline_args: dict, chunk_size: int):
        """
        Initialize a manager agent.

        Args:
            model: Model identifier; forwarded to ``predict`` on every call.
            system_prompt: ``str.format`` template that is formatted with the
                keyword fields ``summary`` and ``query``.
            tokenizer_kwargs: Kwargs forwarded to ``predict``; stored by
                reference, not copied.
            pipeline_args: Keyword arguments forwarded to
                ``transformers.pipeline`` to build ``self.client``.
            chunk_size: Length limit handed to ``truncate`` for every prompt
                (presumably measured in tokens -- confirm against
                ``utils.truncate``).
        """
        self.model = model
        self.system_prompt = system_prompt
        self.tokenizer_kwargs = tokenizer_kwargs
        self.client = pipeline(**pipeline_args)
        self.chunk_size = chunk_size

    def _render(self, summary, query):
        """Fill the manager template with one summary/query pair."""
        return self.system_prompt.format(summary=summary, query=query)

    def _fit(self, prompt):
        """Pass a rendered prompt through ``truncate`` with this agent's limit."""
        return truncate(prompt, self.client.tokenizer, self.chunk_size)

    def batch_synthesize(self, batch_worker_outputs, queries):
        """
        Synthesize final answers for a batch in one ``predict`` call.

        Args:
            batch_worker_outputs: Worker outputs, one per batch item;
                substituted for the ``summary`` template field.
            queries: Queries aligned with ``batch_worker_outputs``.

        Returns:
            Whatever ``predict`` returns for the batch of prompts.
        """
        # Render every prompt first, then truncate them all, so the two phases
        # stay separate.
        rendered = []
        for worker_outputs, query in zip(batch_worker_outputs, queries):
            rendered.append(self._render(worker_outputs, query))
        fitted = [self._fit(prompt) for prompt in rendered]

        return predict(self.client, fitted, self.tokenizer_kwargs, self.model, is_batch=True)

    def synthesize(self, previous_cu: str, query: str) -> str:
        """
        Synthesize a final answer from a single worker output.

        Args:
            previous_cu: Worker output; substituted for the ``summary``
                template field.
            query: The query substituted into the template.

        Returns:
            Whatever ``predict`` returns for the single prompt.
        """
        prompt = self._fit(self._render(previous_cu, query))
        return predict(self.client, prompt, self.tokenizer_kwargs, self.model, is_batch=False)
