from __future__ import annotations
import os
import copy
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List, Optional
from src.memorykit import MemorySystem
from src.agents import CodeAgent, MemoryCurator, MemoryRetriever
from src.providers import OpenAILLM, OpenAIEmbedder, AzureLLM, AzureEmbedder
from src.configs.agent import AgentConfig
from src.utils import save_json, save_file, normalize_op_name
from src.verification import Verifier
from src.log import logger_setup
from examples.dataset import dataset, ablation_sub_dataset

# Module-level logger named after this module; shared by the framework class and main().
logger = logging.getLogger(__name__)

class AscendCAgentFramework:
    """Run the generate -> verify -> memory-update loop over a set of operators.

    The constructor wires together an LLM provider, an embedding provider, a
    ``MemorySystem`` with its retriever and curator, a ``CodeAgent`` and a
    ``Verifier``. ``process_queries`` fans operators out over a thread pool;
    ``process_query`` handles a single operator.
    """

    def __init__(self, config: AgentConfig):
        """Initialize the framework.

        Args:
            config: Configuration object
        """
        self.config: AgentConfig = config
        # Output directory of the current experiment. It is assigned by
        # process_queries(); process_query() refuses to run while it is None.
        self.save_dir: Optional[str] = None

        # A config that exposes ``api_version`` selects the Azure provider.
        if hasattr(config.llm_config, 'api_version'):
            self.llm_provider = AzureLLM(config.llm_config)
        else:
            self.llm_provider = OpenAILLM(config.llm_config)

        if hasattr(config.embedder_config, 'api_version'):
            self.embedding_provider = AzureEmbedder(config.embedder_config)
        else:
            self.embedding_provider = OpenAIEmbedder(config.embedder_config)

        self.memory_system = MemorySystem(
            embedder=self.embedding_provider,
            config=self.config.memory_config
        )
        self.memory_retriever = MemoryRetriever(
            memory_system=self.memory_system,
            config=self.config.memory_config,
        )
        self.memory_curator = MemoryCurator(
            llm=self.llm_provider,
            memory_system=self.memory_system,
            config=self.config.memory_config,
        )

        self.memory_curator.memory_init()

        self.code_agent = CodeAgent(
            retriever=self.memory_retriever,
            llm=self.llm_provider,
            config=self.config.code_agent_config
        )
        self.code_verifier = Verifier(self.config.verify_config)

    def _model_dir_name(self) -> str:
        """Return the model name used as an output directory component."""
        model_name = self.config.llm_config.model_name_or_path
        if '/' in model_name:
            # processing openrouter model: keep the second '/'-separated segment
            return model_name.split('/')[1]
        return model_name

    def process_query(self, op: str, query: str) -> Dict:
        """Run up to ``config.max_iters`` generation rounds for one operator.

        In every round the code agent is asked for a response. When the
        response carries code, the code is saved, verified, handed to the
        memory curator, and the round's record is appended to the history,
        which is then re-written to ``op_result.json``. An exception raised
        inside a round is logged and the loop proceeds with the next round.

        Args:
            op: Operator name; also used as the output sub-directory name.
            query: Query text passed to the code agent and the verifier.

        Returns:
            Dict with the keys ``op``, ``op_norm`` and ``iterations``; the
            latter holds one record per round that produced code and
            completed without raising.

        Raises:
            RuntimeError: If ``save_dir`` has not been set yet. Failing here
                avoids spending a generation call per round only to fail
                when the result is written.
        """
        # getattr keeps the check meaningful even if __init__ was bypassed.
        if getattr(self, "save_dir", None) is None:
            raise RuntimeError(
                "save_dir is not set; call process_queries() before process_query()"
            )

        op_norm = normalize_op_name(op)
        history_infos = {"op": op, "op_norm": op_norm, "iterations": []}
        for i in range(self.config.max_iters):
            try:
                history_info = {}
                response: Dict = self.code_agent.generate(query, op_norm)
                logger.info(
                    "Iteration %s/%s, Action: %s\nPlan: \n%s\nCode: \n%s",
                    i + 1,
                    self.config.max_iters,
                    response.get('action', ''),
                    response.get('plan', ''),
                    response.get('code', ''),
                )
                code_result = response.get("code", "")
                action = response.get("action", "")

                # A response without code leaves nothing to save or verify.
                if not code_result:
                    continue

                history_info.update(response)
                history_info['iter'] = i+1
                save_file(code_result, os.path.join(self.save_dir, op, f"{action}_{i+1}.txt"))
                verified_result = self.code_verifier.run([(code_result, query, op)])[0]
                history_info.update(verified_result)

                history_info.update({"q_update_infos": self.memory_curator.summarize_and_update(op_norm, query, history_info)})

                # Store a snapshot so later changes to history_info cannot
                # alter the recorded round.
                history_infos['iterations'].append(copy.deepcopy(history_info))
                save_json(history_infos, os.path.join(self.save_dir, op, "op_result.json"))
                logger.info(history_info)
            except Exception as e:
                # Best effort: one failed round must not abort the remaining ones.
                logger.error(
                    "Error in process_query: op=%s, query=%s, iter=%s, error=%s\n%s",
                    op,
                    query,
                    i + 1,
                    repr(e),
                    traceback.format_exc(),
                )
        return history_infos

    def process_queries(self, iter_num: int, exp_name: str, ops: List[str], queries: List[str], max_workers: int | None = None) -> List[Dict | None]:
        """Process several operators concurrently for one outer iteration.

        Args:
            iter_num: Outer-iteration index; it is also assigned to
                ``memory_retriever.candidate_pool_multiplier``.
            exp_name: Experiment name used as an output directory component.
            ops: Operator names.
            queries: Query texts, aligned with ``ops``.
            max_workers: Thread pool size, forwarded to ``ThreadPoolExecutor``.

        Returns:
            One entry per operator in the order of ``ops``: the dict returned
            by ``process_query`` or ``{"error": <traceback text>}`` when the
            worker raised. An empty list when ``ops`` is empty.

        Raises:
            ValueError: If ``ops`` and ``queries`` differ in length.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        try:
            if len(ops) != len(queries):
                raise ValueError("ops and queries must have the same length")

            if not ops:
                return []
            self.memory_retriever.candidate_pool_multiplier = iter_num
            results: List[Dict | None] = [None] * len(ops)

            model_name = self._model_dir_name()

            # When resuming at an outer iteration other than the first, load
            # the curator's running info from the previous iteration's
            # output directory.
            if iter_num == self.config.outer_start and iter_num != 1:
                last_exp_name = "_".join(exp_name.split("_")[:-1]) + '_' + str(iter_num-1)
                last_save_dir = os.path.join(self.config.project_root_path, 'outputs', model_name, last_exp_name, self.config.timestamp)
                self.memory_curator.load_running_info(last_save_dir)

            self.save_dir = os.path.join(self.config.project_root_path, 'outputs', model_name, exp_name, self.config.timestamp)

            with ThreadPoolExecutor(max_workers=max_workers) as ex:
                future_to_idx = {
                    ex.submit(self.process_query, op, query): idx
                    for idx, (op, query) in enumerate(zip(ops, queries))
                }
                # Futures complete in arbitrary order; the stored index puts
                # every result back into its input position.
                for fut in as_completed(future_to_idx):
                    idx = future_to_idx[fut]
                    try:
                        results[idx] = fut.result()
                    except Exception:
                        error_trace = traceback.format_exc()
                        logger.error("process_query failed for op=%s\n%s", ops[idx], error_trace)
                        results[idx] = {"error": error_trace}
            self.memory_system.dump(dir=self.save_dir)
            self.memory_curator.save_running_info(self.save_dir)
        except Exception:
            logger.error("Error in process_queries: %s", traceback.format_exc())
            raise
        return results
    
def get_ops(levels: Optional[List[str]] = None) -> List[Dict]:
    """Get operators for the specified levels.

    Args:
        levels: List of levels to include (e.g., ['level1'], ['level1', 'level2']).
            Defaults to ``['level1']`` when omitted or ``None``.

    Returns:
        List of operator dicts with name, category, level, and ref_src.
        An operator whose reference source cannot be loaded is skipped after
        its traceback has been printed.
    """
    # Build the default per call instead of sharing one list object between
    # all calls through the signature.
    selected_levels = ['level1'] if levels is None else levels
    example_dir = "examples/KernelBench"
    ops = []

    for op_name, op_info in dataset.items():
        if op_info['level'] not in selected_levels:
            continue
        try:
            op = {"name": op_name, "category": op_info['category'], "level": op_info['level']}
            ref_src_path = os.path.join(example_dir, op_info['level'], f"{op_name}.py")
            # Explicit encoding: the platform default depends on the locale.
            with open(ref_src_path, 'r', encoding='utf-8') as f:
                op["ref_src"] = f.read()
            ops.append(op)
        except Exception:
            # Best effort: report the problem and continue with the next operator.
            traceback.print_exc()
    return ops


def get_sub_ops(allowed_ops: Optional[List[str]] = None) -> List[Dict]:
    """Get the operators from the dataset whose names are in ``allowed_ops``.

    Reference sources are read from ``examples/<category>/<op_name>.py``.

    Args:
        allowed_ops: Operator names to include. When ``None``, a built-in
            list of operator names is used.

    Returns:
        List of operator dicts with name, category, level, and ref_src, in
        dataset order. An operator whose reference source cannot be loaded
        is skipped after its traceback has been printed.
    """
    example_dir = "examples"
    ops = []
    if allowed_ops is None:
        allowed_ops = [
            "leaky_relu", "subtract_with_bias_broadcast", "tanh", "swish", "sigmoid", "relu", "power_broadcast", "hardtanh",
            "hardsigmoid", "batched_matrix_multiplication", "hinge_loss", "four_dim_tensor_matrix_multiplication", "tall_skinny_matrix_multiplication",
            "three_dim_tensor_matrix_multiplication", "softsign", "matmul_with_small_k_dimension",
            "conv3d_group_norm_min_clamp_dropout", "conv3d_scaling_tanh_multiply_sigmoid", "conv_transpose2d_mish_add_hardtanh_scaling",
            "sparse_attention", "scaled_dot_product_attention_inference", "multi_query_attention"
        ]
    # Build the lookup once so each membership test is O(1).
    allowed_names = set(allowed_ops)

    # Prepare the operators pending testing
    for op_name, op_info in dataset.items():
        if op_name not in allowed_names:
            continue
        try:
            op = {"name": op_name, "category": op_info['category'], "level": op_info['level']}
            ref_src_path = os.path.join(example_dir, op_info['category'], f"{op_name}.py")
            # Explicit encoding: the platform default depends on the locale.
            with open(ref_src_path, 'r', encoding='utf-8') as f:
                op["ref_src"] = f.read()
            ops.append(op)
        except Exception:
            # Best effort: report the problem and continue with the next operator.
            traceback.print_exc()

    return ops

def get_ablation_ops() -> List[Dict]:
    """Get operators for ablation study.

    Returns:
        List of operator dicts with name, category, level, and ref_src.
        An operator that is missing from ``dataset`` or whose reference
        source cannot be loaded is skipped after its traceback has been
        printed.
    """
    ops = []
    example_dir = "examples/KernelBench"

    for op_name in ablation_sub_dataset:
        try:
            # The lookup is inside the try so that a name missing from
            # ``dataset`` is skipped like an unreadable file instead of
            # aborting the whole load.
            op_info = dataset[op_name]
            op = {"name": op_name, "category": op_info['category'], "level": op_info['level']}
            ref_src_path = os.path.join(example_dir, op_info['level'], f"{op_name}.py")
            # Explicit encoding: the platform default depends on the locale.
            with open(ref_src_path, 'r', encoding='utf-8') as f:
                op["ref_src"] = f.read()
            ops.append(op)
        except Exception:
            # Best effort: report the problem and continue with the next operator.
            traceback.print_exc()
    return ops

def main(ops: List[Dict], config_file: str = 'agent.test_module.yaml'):
    """Run the agent on a list of operators.

    For every outer iteration from ``cfg.outer_start`` to ``cfg.outer_iters``
    (inclusive) all operators are processed, each operator dict is updated
    in place with its result and written to ``op_result.json``, and
    low-utility memories are pruned when the memory config enables it.
    The verifier is closed even if an iteration raises.

    Args:
        ops: List of operator dicts with name, category, level, and ref_src
        config_file: Path to the config file (relative to src/configs/presets/)
    """
    cfg: AgentConfig = AgentConfig.from_yaml_file(config_file)
    logger_setup(os.path.join(cfg.logs_path,f'agent.{cfg.timestamp}.log'))
    framework = AscendCAgentFramework(cfg)

    op_names = [op['name'] for op in ops]
    op_ref_srcs = [op['ref_src'] for op in ops]
    exp_name = cfg.exp_name

    try:
        for i in range(cfg.outer_start, cfg.outer_iters + 1):
            logger.info("Processing %s ops on iteration %s...", len(op_names), i)
            results = framework.process_queries(i, f"{exp_name}_{i}", op_names, op_ref_srcs, max_workers=cfg.max_workers)
            # process_queries returns one result per operator, in input order.
            for op, result in zip(ops, results):
                op.update(result)
                save_json(op, filepath=os.path.join(framework.save_dir, op['name'], "op_result.json"))

            if cfg.memory_config.enable_utility_pruning:
                framework.memory_curator.prune_low_utility_memories()
    finally:
        # Close the verifier on failure as well as on success.
        framework.code_verifier.close()

    logger.info("All iterations finished!")

if __name__ == "__main__":
    # Script entry point: pick an operator set and a preset config, then run.
    # The commented-out pairs below are alternative runs.

    # ops = get_ops(levels=['level1', 'level2'])
    # main(ops, config_file="1_18/agent.gpt52_l1-2_mixed_scratch.yaml")

    # ops = get_ops(levels=['level1'])
    # main(ops, config_file="1_18/agent.gpt52_l1.yaml")

    # ops = get_ops(levels=['level2'])
    # main(ops, config_file="1_18/agent.gpt52_l2_with_L1_Mem.yaml")
    # main(ops, config_file="1_18/agent.gpt52_l2_scratch.yaml")

    # ops = get_ops(levels=['level1', 'level2'])
    # main(ops, config_file="1_18/agent.deepseek_l1-2_mixed_scratch.yaml")

    # ops = get_ops(levels=['level1'])
    # main(ops, config_file="1_18/agent.qwen_l1.yaml")

    # ops = get_ablation_ops()
    # main(ops, config_file="1_25/deepseek_mem_ablation.yaml")

    # ops = get_ops(levels=['level2'])
    # main(ops, config_file="1_25/gpt5.2_l2_withl1_mem_q_ablation.yaml")

    # Active run: level1 and level2 operators with the mixed-scratch preset.
    main(
        get_ops(levels=['level1', 'level2']),
        config_file="1_28/agent.qwen_l1-2_mixed_scratch.yaml",
    )