from typing import Optional, Iterator, Dict
from .agents import WorkerAgent, ManagerAgent
from .coa_utils import split_into_chunks, get_data_specific_prompt, batch_split_into_chunks
import logging
import json
import numpy as np
from multiprocessing import Pool
from .goa_utils import construct_subgraphs, get_embedding_model
import gc
import torch
from sentence_transformers import SentenceTransformer
import os
import json 

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def memory_stats():
    """Print the current CUDA memory usage in MiB.

    Two lines are written to stdout: the memory currently allocated by
    tensors, then the memory reserved by PyTorch's caching allocator.
    """
    mib = 1024 ** 2
    print(torch.cuda.memory_allocated() / mib)
    # `memory_cached` is the deprecated name of `memory_reserved`.
    print(torch.cuda.memory_reserved() / mib)


class GraphOfAgnets:
    """Main class for the Chain of Agents implementation."""
    
    def __init__(
        self,
        debug: bool = False,
        worker_model: str = "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",  # Together AI model
        manager_model: str = "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",  # Together AI model
        chunk_size: int = 500,
        max_gen: int = 128,
        goa_cluster_size: int = 4,
        goa_no_context: bool = False,
        worker_prompt: Optional[str] = None,
        manager_prompt: Optional[str] = None,
        device: Optional[str] = None,
        dataset: Optional[str] = None,
        prompt_mode: Optional[str] = None,
        goa_mode: Optional[str] = None,
        model_kwargs: Optional[dict] = None,
        tokenizer_kwargs: Optional[dict] = None,
        pipeline_args: Optional[dict] = None,
        ablation_type: Optional[str] = 'None',
        summary_tag: Optional[bool] = True,
        use_coa: Optional[bool] = False,
        log_summary_dir: Optional[str] = None
    ):
        """
        Initialize the agent pipeline.

        Args:
            debug: If True, worker and manager outputs are appended to JSONL
                files under ``log_summary_dir`` (which is then required).
            worker_model: Model to use for worker agents
            manager_model: Model to use for manager agent
            chunk_size: Maximum tokens per chunk. The stored value is reduced
                by 1000 (instruction budget) and by the worker generation
                budget, then floored at 2000, so small requested values are
                raised to 2000.
            max_gen: Stored as ``manager_max_gen``.
            goa_cluster_size: Cluster-size argument handed to
                ``construct_subgraphs`` (and the number of random clusters in
                the ``rand_cluster`` ablation).
            goa_no_context: If True, ``process`` does not delegate to
                ``contextual_process``.
            worker_prompt: Custom system prompt for workers
            manager_prompt: Custom system prompt for manager
            device: Device for the embedding model; defaults to ``'cuda'``
                when available, otherwise ``'cpu'``.
            dataset: Forwarded to ``get_data_specific_prompt``.
            prompt_mode: Forwarded to ``get_data_specific_prompt``.
            goa_mode: Deprecated; must be falsy or the string ``'None'``.
            model_kwargs: Stored as ``model_kwargs``; not forwarded to the
                agents by the code in this class.
            tokenizer_kwargs: Passed to the worker and manager agents.
            pipeline_args: Passed to the worker and manager agents.
            ablation_type: Ablation selector; the string ``'None'`` disables
                ablations.
            summary_tag: If True, worker outputs are reduced to the text
                inside ``<summary>...</summary>`` when both tags are present.
            use_coa: If True, no embedding model is loaded here and
                ``process`` runs a single sequential chain.
            log_summary_dir: Directory for debug logs (created if missing).

        Raises:
            ValueError: If ``goa_mode`` is provided, or if ``debug`` is set
                without ``log_summary_dir``.
        """
        # Validate arguments first so a bad call fails before prompts are
        # resolved and before the embedding model is loaded.
        # TODO: remove the goa_mode parameter
        if goa_mode and goa_mode != 'None':
            raise ValueError("DO NOT PROVIDE GOA_MODE")
        if debug and log_summary_dir is None:
            raise ValueError("Pass the directory for logging communication chains")

        default_worker_prompt, default_manager_prompt = get_data_specific_prompt(dataset, prompt_mode, ablation_type, summary_tag)
        print(f"WORKER PROMPT for DATASET {dataset}: {default_worker_prompt}")
        print(f"MANAGER PROMPT for DATASET {dataset}: {default_manager_prompt}")

        if isinstance(default_manager_prompt, tuple):
            default_manager_prompt = default_manager_prompt[0]
            print("WARNING: Manager prompt is a tuple, using the first element only.")

        # Prompts and models
        self.worker_prompt = worker_prompt or default_worker_prompt
        self.manager_prompt = manager_prompt or default_manager_prompt
        self.worker_model = worker_model
        self.manager_model = manager_model
        self.device = device
        if self.device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.model_kwargs = model_kwargs
        self.tokenizer_kwargs = tokenizer_kwargs
        self.pipeline_args = pipeline_args
        self.manager_max_gen = max_gen

        # The worker generation budget is chosen from the *requested* chunk size.
        if chunk_size <= 2000:
            self.worker_max_gen = 256
        elif chunk_size <= 4000:
            self.worker_max_gen = 512
        else:
            self.worker_max_gen = 1024

        self.summary_tag = summary_tag

        ## goa
        self.use_coa = use_coa
        # The embedding model is only needed for graph construction; plain
        # chain-of-agents runs skip the load.
        if self.use_coa:
            self.embedding_model = None
        else:
            self.embedding_model = SentenceTransformer('BAAI/bge-m3', device=self.device)
        self.goa_cluster_size = goa_cluster_size
        self.goa_no_context = goa_no_context
        self.goa_mode = goa_mode

        # The sentinel string 'None' means "no ablation".
        self.ablation_type = ablation_type if ablation_type != 'None' else None

        # Adjust the chunk size for the worker model (considering the instruction and response)
        # Here, we assume that the instruction takes about 1000 tokens, so we reduce the chunk size accordingly
        self.chunk_size = max(chunk_size - 1000 - self.worker_max_gen, 2000)  # Ensure chunk size is at least 2000

        self.debug = debug
        self.log_summary_dir = log_summary_dir
        if self.debug:
            os.makedirs(log_summary_dir, exist_ok=True)

        logger.info("Initialized Chain of Agents with %s workers and %s manager", worker_model, manager_model)
        
    @torch.no_grad()
    def process(self, input_text, query, return_summary=False, return_pmi_score=False, get_embedding=False):
        """
        Process a single input text and query.

        Unless ``use_coa`` or ``goa_no_context`` is set, the call is delegated
        to ``contextual_process``. Otherwise the text is split into chunks
        (grouped by ``construct_subgraphs`` when ``use_coa`` is False), each
        chain of chunks is summarised sequentially by a worker agent that
        receives the previous summary, and the manager agent synthesises the
        final answer from the resulting summary text.

        Args:
            input_text: The input text to process
            query: The query to use for processing
            return_summary: Whether to return the summary or not
            return_pmi_score: If True, skip the manager and return the cosine
                similarity between the embedded worker summary and the
                embedded query instead.
            get_embedding: If True, load the embedding model when it has not
                been loaded yet.

        Returns:
            * ``final_outputs`` (manager output) by default;
            * ``(final_outputs, worker_final_outputs)`` if ``return_summary``;
            * ``(None, total_worker_outputs, pmi_score)`` if
              ``return_pmi_score``, where ``total_worker_outputs`` is the list
              of per-chain summaries.
        """
        if self.debug:
            # Drawn before the dispatch below so the global NumPy RNG stream
            # is consumed exactly as before.
            example_number = np.random.randint(0, 10**8)
            worker_log_file_name = os.path.join(self.log_summary_dir, f'{example_number}_workers.jsonl')
            manager_log_file_name = os.path.join(self.log_summary_dir, f'{example_number}_manager.jsonl')

        if not (self.use_coa or self.goa_no_context):
            return self.contextual_process(input_text, query, return_summary, return_pmi_score)

        chunks = split_into_chunks(input_text, self.chunk_size, self.worker_model)

        num_cluster = 1
        if not self.use_coa:
            chunks, num_cluster = construct_subgraphs(
                chunks, query, self.embedding_model, self.goa_cluster_size, is_batch=False)

        # The embedding model is not loaded in __init__ when use_coa is set,
        # but the PMI score below needs it.
        if (get_embedding or return_pmi_score) and self.embedding_model is None:
            self.embedding_model = SentenceTransformer('BAAI/bge-m3', device=self.device)

        # NOTE(review): assumes construct_subgraphs returns a list of
        # per-cluster chunk lists when num_cluster > 1 (as the cluster loop
        # expects) — confirm. If the chunks come back as a flat list of
        # strings they are processed as one chain.
        clustered = num_cluster > 1 and len(chunks) > 0 and not isinstance(chunks[0], str)

        def extract_summary(raw_output):
            """Return the text inside <summary>...</summary>, or the raw output if untagged."""
            if not self.summary_tag:
                return raw_output
            open_at = raw_output.find('<summary>')
            if open_at == -1:
                return raw_output
            start_idx = open_at + len('<summary>')
            # Look for the closing tag only after the opening one.
            end_idx = raw_output.find('</summary>', start_idx)
            if end_idx == -1:
                return raw_output
            return raw_output[start_idx:end_idx].strip()

        def run_chain(agent, chain_chunks, cluster_idx=None):
            """Feed chain_chunks through the worker in order; return the last summary."""
            previous_cu = "Not available"
            for worker_num, current_xs in enumerate(chain_chunks):
                worker_outputs = agent.process_chunk(
                    current_xs, query, previous_cu, worker_num == 0)
                previous_cu = extract_summary(worker_outputs)

                if cluster_idx is None:
                    print('-'*100)
                    print(f'\t\tWorker {worker_num}')
                    print(f"WORKER INPUT MESSAGE (cropped): {current_xs[:300]}")
                    print(f"OUTPUT (cropped): {previous_cu[:500]} \n\n {'-' * 30}")

                if worker_fout is not None:
                    if cluster_idx is None:
                        item = {'worker_num': worker_num, 'output': previous_cu}
                    else:
                        item = {'worker_num': worker_num, 'cluster_idx': cluster_idx, 'output': previous_cu}
                    worker_fout.write(json.dumps(item, ensure_ascii=False) + '\n')
            return previous_cu

        worker_fout = open(worker_log_file_name, 'a', encoding='utf-8') if self.debug else None
        try:
            if clustered:
                # The generation budget is shared across the clusters.
                adjusted_max_gen = self.worker_max_gen // num_cluster
                worker = WorkerAgent(self.worker_model, self.worker_prompt, adjusted_max_gen, self.tokenizer_kwargs, self.pipeline_args, self.chunk_size)

                ## TODO: run parallel inferences
                total_worker_outputs = []
                for cluster_idx, cluster_chunks in enumerate(chunks):
                    print(f"Processing cluster {cluster_idx + 1}/{num_cluster} with {len(cluster_chunks)} chunks")
                    total_worker_outputs.append(run_chain(worker, cluster_chunks, cluster_idx))

                worker_final_outputs = "\n".join(total_worker_outputs)
            else:
                worker = WorkerAgent(self.worker_model, self.worker_prompt, self.worker_max_gen, self.tokenizer_kwargs, self.pipeline_args, self.chunk_size)
                worker_final_outputs = run_chain(worker, chunks)
                # Keep the PMI return shape identical to the clustered path.
                total_worker_outputs = [worker_final_outputs]
        finally:
            # Close the log even if a worker call raises.
            if worker_fout is not None:
                worker_fout.close()

        print('^'*100)
        print('\tBEFORE DELETING WORKER')
        memory_stats()
        del worker
        gc.collect()
        torch.cuda.empty_cache()
        print('^' * 100)
        print('\tAFTER DELETING WORKER')
        memory_stats()

        if return_pmi_score:
            emb1 = self.embedding_model.encode([worker_final_outputs])[0]
            emb2 = self.embedding_model.encode([query])[0]
            # Cosine similarity between the summary and the query embeddings.
            pmi_score = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
            return None, total_worker_outputs, pmi_score

        manager = ManagerAgent(self.manager_model, self.manager_prompt, self.tokenizer_kwargs, self.pipeline_args, self.chunk_size)

        final_outputs = manager.synthesize(worker_final_outputs, query)

        if self.debug:
            with open(manager_log_file_name, "a", encoding="utf-8") as f:
                record = {
                    "input": worker_final_outputs,
                    "query": query,
                    "output": final_outputs
                }
                json.dump(record, f, ensure_ascii=False)
                f.write('\n')

        print('^' * 100)
        print('\tBEFORE DELETING MANGER')
        memory_stats()
        del manager
        gc.collect()
        torch.cuda.empty_cache()
        print('^' * 100)
        print('\tAFTER DELETING MANGER')
        memory_stats()

        if return_summary:
            return final_outputs, worker_final_outputs
        return final_outputs
        
    @torch.no_grad()
    def contextual_process(self, input_text, query, return_summary=False, return_pmi_score=False):
        """
        Process one input with query-guided chunk ordering inside each cluster.

        The text is split into chunks and grouped into clusters (randomly for
        the ``rand_cluster`` ablation, otherwise by ``construct_subgraphs``).
        Within each cluster the worker repeatedly picks the unvisited chunk
        whose embedding, prefixed with the running summary, is most similar
        to the query, and summarises it. The manager agent then synthesises
        the final answer from the per-cluster summaries.

        Args:
            input_text: The input text to process
            query: The query to use for processing
            return_summary: Whether to also return the worker summary text
            return_pmi_score: If True, skip the manager and return the cosine
                similarity between the embedded worker summary and the
                embedded query instead.

        Returns:
            * ``final_outputs`` (manager output) by default;
            * ``(final_outputs, worker_final_outputs)`` if ``return_summary``;
            * ``(None, total_worker_outputs, pmi_score)`` if
              ``return_pmi_score``, where ``total_worker_outputs`` is the list
              of per-cluster summaries.
        """
        if self.debug:
            example_number = np.random.randint(0, 10**8)
            worker_log_file_name = os.path.join(self.log_summary_dir, f'{example_number}_workers.jsonl')
            manager_log_file_name = os.path.join(self.log_summary_dir, f'{example_number}_manager.jsonl')

        chunks = split_into_chunks(input_text, self.chunk_size, self.worker_model)

        print(f"Number of chunks: {len(chunks)}")

        # Empty string when no ablation is active, so `'x' in ablation` is a safe test.
        ablation = self.ablation_type or ''

        ## Cluster
        if 'rand_cluster' in ablation:
            print('Finding random clusters for ablation study')
            requested_clusters = int(self.goa_cluster_size)
            cluster_idx = np.random.randint(0, requested_clusters, size=len(chunks))

            buckets = [[] for _ in range(requested_clusters)]
            for chunk, bucket_id in zip(chunks, cluster_idx):
                buckets[bucket_id].append(chunk)
            # Drop clusters that received no chunk.
            chunks = [bucket for bucket in buckets if bucket]
            num_cluster = len(chunks)
            print(f"Cluster indices: {cluster_idx}")
            print(f"Cluster chunks: {num_cluster}")
            print(f"length of cluster chunks: {len(chunks)}")
            if num_cluster <= 1:
                # A single cluster is handled as a flat chunk list; zero
                # chunks fall back to an empty single chain instead of
                # dividing the generation budget by zero.
                chunks = chunks[0] if chunks else []
                num_cluster = 1
        else:
            chunks, num_cluster = construct_subgraphs(
                        chunks, query, self.embedding_model, self.goa_cluster_size, is_batch=False, search_method='cluster_only',
                        ablation_type=self.ablation_type)

        def get_closest_chunk(chunks, query, embedding_model, visited_nodes=None, previous_cu=None):
            """
            Get the index of the unvisited chunk most similar to the query.

            Args:
                chunks: List of text chunks
                query: Query string
                embedding_model: Pre-trained embedding model
                visited_nodes: Indices of chunks that were already processed
                previous_cu: Running summary; when available it is prepended
                    to every candidate chunk before embedding

            Returns:
                int: Index into ``chunks`` of the selected chunk
            """
            visited = set(visited_nodes or ())
            valid_idx = [i for i in range(len(chunks)) if i not in visited]

            if previous_cu is None or previous_cu == "Not available":
                contextual_chunks = [chunks[i] for i in valid_idx]
            else:
                # If previous_cu is available, we can use it to refine the search
                contextual_chunks = [f"{previous_cu}\n\n{chunks[i]}" for i in valid_idx]

            # NOTE(review): the sparse / ColBERT keyword arguments and the
            # compute_lexical_matching_score / colbert_score helpers used by
            # the next three branches look like the FlagEmbedding BGE-M3
            # interface, while __init__ builds a SentenceTransformer —
            # confirm these ablations run with the configured model.
            if 'lexical' in ablation:
                query_embedding = embedding_model.encode([query], return_sparse=True)['lexical_weights']
                chunk_embeddings = embedding_model.encode(contextual_chunks, return_sparse=True)['lexical_weights']
                similarities = embedding_model.compute_lexical_matching_score(query_embedding, chunk_embeddings)[0]
            elif 'multivec' in ablation:
                query_embedding = embedding_model.encode([query], return_colbert_vecs=True)['colbert_vecs'][0]
                chunk_embeddings = embedding_model.encode(contextual_chunks, return_colbert_vecs=True)['colbert_vecs']
                similarities = np.array([embedding_model.colbert_score(query_embedding, chunk_embedding) for chunk_embedding in chunk_embeddings])
            elif 'hybrid_search' in ablation:
                query_embedding = embedding_model.encode([query], return_dense=True, return_sparse=True, return_colbert_vecs=True)
                chunk_embeddings = embedding_model.encode(contextual_chunks, return_dense=True, return_sparse=True, return_colbert_vecs=True)

                dense_sim = (query_embedding['dense_vecs'] @ chunk_embeddings['dense_vecs'].T)[0]

                sparse_sim = embedding_model.compute_lexical_matching_score(
                    query_embedding['lexical_weights'], chunk_embeddings['lexical_weights'])[0]

                mult_qe = query_embedding['colbert_vecs'][0]
                mult_sim = np.array([embedding_model.colbert_score(mult_qe, mult_c) for mult_c in chunk_embeddings['colbert_vecs']])

                # Unweighted sum of the three similarity signals.
                similarities = dense_sim + sparse_sim + mult_sim
            else:
                # Dense cosine similarity.
                query_embedding = embedding_model.encode([query])[0]
                chunk_embeddings = embedding_model.encode(contextual_chunks)
                similarities = np.dot(chunk_embeddings, query_embedding) / (
                    np.linalg.norm(chunk_embeddings, axis=1) * np.linalg.norm(query_embedding))

                print(f"query_embedding shape : {query_embedding.shape}")
                print(f"chunk_embeddings shape : {chunk_embeddings.shape}")
                print(f"similarities shape : {similarities.shape}")

            print("DEBUG: Similarities:", similarities)
            closest_idx = np.argmax(similarities)
            return valid_idx[closest_idx]

        def extract_summary(raw_output):
            """Return the text inside <summary>...</summary>, or the raw output if untagged."""
            if not self.summary_tag:
                return raw_output
            open_at = raw_output.find('<summary>')
            if open_at == -1:
                return raw_output
            start_idx = open_at + len('<summary>')
            # Look for the closing tag only after the opening one.
            end_idx = raw_output.find('</summary>', start_idx)
            if end_idx == -1:
                return raw_output
            return raw_output[start_idx:end_idx].strip()

        def run_chain(agent, pool, cluster_idx=None):
            """Visit every chunk of pool once, in query-guided order; return the last summary."""
            verbose = cluster_idx is None
            previous_cu = "Not available"
            visited_nodes = []
            for worker_num in range(len(pool)):
                if worker_num == 0 and 'from_medoid' in ablation:
                    # presumably the clustering step places the medoid first — verify against construct_subgraphs
                    closest_chunk_idx = 0
                elif 'rand_perm' in ablation:
                    # Random visiting order; applied to every chain so the
                    # ablation also takes effect when there are several clusters.
                    closest_chunk_idx = np.random.choice([i for i in range(len(pool)) if i not in visited_nodes])
                else:
                    # Get the closest chunk to the query
                    closest_chunk_idx = get_closest_chunk(pool, query, self.embedding_model, visited_nodes, previous_cu)
                visited_nodes.append(closest_chunk_idx)

                current_xs = pool[closest_chunk_idx]

                if verbose:
                    print(f"Closest chunk index: {closest_chunk_idx}")
                    print(f"Visited nodes: {visited_nodes}")
                    print(f"num chunks: {len(pool)}")
                    print(f"Current chunk len: {len(current_xs)}")

                worker_outputs = agent.process_chunk(
                    current_xs, query, previous_cu, worker_num == 0)
                previous_cu = extract_summary(worker_outputs)

                if worker_fout is not None:
                    if cluster_idx is None:
                        item = {'worker_num': worker_num, 'output': previous_cu}
                    else:
                        item = {'worker_num': worker_num, 'cluster_idx': cluster_idx, 'output': previous_cu}
                    worker_fout.write(json.dumps(item, ensure_ascii=False) + '\n')

                if verbose:
                    print('-'*100)
                    print(f'\t\tWorker {worker_num}')
                    print(f"WORKER INPUT MESSAGE (cropped): {current_xs[:300]}")
                    print(f"OUTPUT (cropped): {previous_cu[:500]} \n\n {'-' * 30}")
            return previous_cu

        worker_fout = open(worker_log_file_name, 'a', encoding='utf-8') if self.debug else None
        try:
            if num_cluster <= 1:
                worker = WorkerAgent(self.worker_model, self.worker_prompt, self.worker_max_gen, self.tokenizer_kwargs, self.pipeline_args, self.chunk_size)
                worker_final_outputs = run_chain(worker, chunks)
                # Keep the PMI return shape identical to the clustered path.
                total_worker_outputs = [worker_final_outputs]
            else:
                # The generation budget is shared across the clusters.
                adjusted_max_gen = self.worker_max_gen // num_cluster
                worker = WorkerAgent(self.worker_model, self.worker_prompt, adjusted_max_gen,
                                     self.tokenizer_kwargs, self.pipeline_args, self.chunk_size)

                ## TODO: run parallel inferences
                total_worker_outputs = []
                for cluster_idx, cluster_chunks in enumerate(chunks):
                    print(f"Processing cluster {cluster_idx + 1}/{num_cluster} with {len(cluster_chunks)} chunks")
                    total_worker_outputs.append(run_chain(worker, cluster_chunks, cluster_idx))

                # Label by the number of summaries actually produced so a
                # mismatch with num_cluster cannot raise or drop entries.
                summary_count = len(total_worker_outputs)
                worker_final_outputs = "".join(
                    f"Summary ({i_cls + 1}/{summary_count}): {summary}\n"
                    for i_cls, summary in enumerate(total_worker_outputs))
        finally:
            # Close the log even if a worker call raises.
            if worker_fout is not None:
                worker_fout.close()

        print('^'*100)
        print('\tBEFORE DELETING WORKER')
        memory_stats()
        del worker
        gc.collect()
        torch.cuda.empty_cache()
        print('^' * 100)
        print('\tAFTER DELETING WORKER')
        memory_stats()

        if return_pmi_score:
            emb1 = self.embedding_model.encode([worker_final_outputs])[0]
            emb2 = self.embedding_model.encode([query])[0]
            # Cosine similarity between the summary and the query embeddings.
            pmi_score = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2))
            return None, total_worker_outputs, pmi_score

        manager = ManagerAgent(self.manager_model, self.manager_prompt, self.tokenizer_kwargs, self.pipeline_args, self.chunk_size)

        final_outputs = manager.synthesize(worker_final_outputs, query)

        if self.debug:
            with open(manager_log_file_name, "a", encoding="utf-8") as f:
                record = {
                    "input": worker_final_outputs,
                    "query": query,
                    "output": final_outputs
                }
                json.dump(record, f, ensure_ascii=False)
                f.write('\n')

        print('^' * 100)
        print('\tBEFORE DELETING MANGER')
        memory_stats()
        del manager
        gc.collect()
        torch.cuda.empty_cache()
        print('^' * 100)
        print('\tAFTER DELETING MANGER')
        memory_stats()

        if return_summary:
            return final_outputs, worker_final_outputs
        return final_outputs
        