import os
import yaml
import json
import time
import pandas as pd

from LLMClient import LLMClient
from API_Manager import API_Pool
from prompts_EN import (
    EXTRACTOR_SYSTEM_PROMPT, make_extractor_user_prompt, 
    VERIFIER_SYSTEM_PROMPT, make_verifier_user_prompt,
    Maintainer_Analyze_and_Merge_Prompt
)

"""
Data Provider:
    load and process data
"""
class Data_Provider:
    """Load a tabular dataset from disk and expose its ``msg_list`` column.

    The reader is chosen from the file extension: ``.csv`` files go through
    ``pandas.read_csv`` and ``.jsonl`` files through ``pandas.read_json``
    with ``lines=True``. Any other extension raises ``ValueError``.
    """

    def __init__(self, file_path:str):
        self.file_path = file_path
        self.data = self.load_data(file_path)

    def load_data(self, file_path):
        """Read *file_path* into a DataFrame based on its extension.

        Raises:
            ValueError: if the extension is neither ``.csv`` nor ``.jsonl``.
        """
        readers = {
            '.csv': pd.read_csv,
            '.jsonl': lambda path: pd.read_json(path, lines=True),
        }
        for suffix, reader in readers.items():
            if file_path.endswith(suffix):
                return reader(file_path)
        raise ValueError('Unsupported file format')

    def __len__(self):
        """Number of rows in the loaded dataset."""
        return self.data.shape[0]

    def get_data_from_idx(self, i):
        """Return the ``msg_list`` entry at index label *i*."""
        messages = self.data['msg_list']
        return messages[i]
"""
Agent Base
"""
class AgentBase:
    """Base class for single-turn LLM agents.

    Subclasses override ``make_message`` (build the request passed to
    ``client.chat`` from ``data``) and ``format_output`` (turn the raw response
    into a result). Token usage reported in each response is accumulated in
    ``self.token_count``.
    """

    def __init__(self, client: "LLMClient", system_prompt: str):
        self.name = "Agent-Base"
        self.client = client
        self.system_prompt = system_prompt
        # Running totals of the usage counters reported by every response.
        self.token_count = {
            "completion_tokens": 0,
            "prompt_tokens": 0,
            "total_tokens": 0
        }

    def update_token_count(self, response):
        """Add the response's ``usage`` counters to ``self.token_count``.

        Assumes ``response`` is a dict whose optional ``usage`` entry maps the
        counter names in ``self.token_count`` to ints. A missing or ``None``
        ``usage`` block, or a missing/``None`` individual counter, is skipped
        rather than raising, since some backends omit usage data.
        """
        usage = response.get("usage") or {}
        for key in self.token_count:
            value = usage.get(key)
            if value is not None:
                self.token_count[key] += value

    def make_message(self, data):
        """Return *data* unchanged as the chat request (override in subclasses)."""
        return data

    def format_output(self, response):
        """Return the raw response unchanged (override in subclasses)."""
        return response

    def get_token_count(self):
        """Return the accumulated token-count dict (the live object, not a copy)."""
        return (self.token_count)

    def run(self, data, retry=3, retry_delay=1.0):
        """Send *data* to the client and return the formatted response.

        The whole build → chat → count → format sequence is attempted up to
        ``retry`` times in total (at least once). Between failed attempts it
        logs the error and sleeps ``retry_delay`` seconds. If every attempt
        fails, the last exception propagates to the caller.

        Tokens are counted before formatting, so usage of a response that
        later fails to parse is still recorded.
        """
        attempts = max(1, retry)
        for attempt in range(1, attempts + 1):
            try:
                message = self.make_message(data)
                response = self.client.chat(message)
                self.update_token_count(response)
                return self.format_output(response)
            except Exception as e:
                if attempt >= attempts:
                    raise
                print(f"[Agent] Retry {attempt} / {attempts} times: {e!r}")
                time.sleep(retry_delay)
    
"""
API Extractor
"""
class API_Extractor(AgentBase):
    """Agent that sends ``data`` as the user turn and parses the reply as JSON.

    ``run`` is inherited from ``AgentBase``.
    """

    def __init__(self, client:LLMClient, system_prompt:str):
        super().__init__(client, system_prompt)
        self.name = "API_Extractor"

    def make_message(self, data):
        """Build a chat request: the system prompt, then ``data`` rendered as text."""
        user_prompt = data
        message = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": f"{user_prompt}"},
        ]
        return message

    def format_output(self, response):
        """Return the first choice's message content parsed as JSON.

        A surrounding Markdown code fence (```` ``` ```` or ```` ```json ````) is removed
        first. Raises ``json.JSONDecodeError`` if the remaining text is not JSON.
        """
        content = response['choices'][0]['message']['content'].strip()
        # Remove the literal fence. str.strip("```json") would instead strip any
        # of the characters ` j s o n from both ends (e.g. `null` -> `ull`) and
        # fails when whitespace sits outside the fence.
        if content.startswith("```"):
            content = content[3:]
            if content[:4].lower() == "json":
                content = content[4:]
        if content.endswith("```"):
            content = content[:-3]
        return json.loads(content)

"""
API Verifier
"""
class API_Verifier(AgentBase):
    """Agent that sends ``data`` as the user turn and parses the reply as JSON.

    ``run`` is inherited from ``AgentBase``.
    """

    def __init__(self, client:LLMClient, system_prompt:str):
        super().__init__(client, system_prompt)
        self.name = "API_Verifier"

    def make_message(self, data):
        """Build a chat request: the system prompt, then ``data`` rendered as text."""
        user_prompt = data
        message = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": f"{user_prompt}"},
        ]
        return message

    def format_output(self, response):
        """Return the first choice's message content parsed as JSON.

        A surrounding Markdown code fence (```` ``` ```` or ```` ```json ````) is removed
        first. Raises ``json.JSONDecodeError`` if the remaining text is not JSON.
        """
        content = response['choices'][0]['message']['content'].strip()
        # Remove the literal fence. str.strip("```json") would instead strip any
        # of the characters ` j s o n from both ends (e.g. `null` -> `ull`) and
        # fails when whitespace sits outside the fence.
        if content.startswith("```"):
            content = content[3:]
            if content[:4].lower() == "json":
                content = content[4:]
        if content.endswith("```"):
            content = content[:-3]
        return json.loads(content)

"""
API Maintainer
"""
class API_Maintainer(AgentBase):
    """Agent that sends ``data`` as the user turn and parses the reply as JSON.

    ``run`` is inherited from ``AgentBase``.

    NOTE(review): a second ``class API_Maintainer`` appears later in this module
    and rebinds the name; consider deleting one of the two definitions.
    """

    def __init__(self, client:LLMClient, system_prompt:str):
        super().__init__(client, system_prompt)
        self.name = "API_Maintainer"

    def make_message(self, data):
        """Build a chat request: the system prompt, then ``data`` rendered as text."""
        user_prompt = data
        message = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": f"{user_prompt}"},
        ]
        return message

    def format_output(self, response):
        """Return the first choice's message content parsed as JSON.

        A surrounding Markdown code fence (```` ``` ```` or ```` ```json ````) is removed
        first. Raises ``json.JSONDecodeError`` if the remaining text is not JSON.
        """
        content = response['choices'][0]['message']['content'].strip()
        # Remove the literal fence. str.strip("```json") would instead strip any
        # of the characters ` j s o n from both ends (e.g. `null` -> `ull`) and
        # fails when whitespace sits outside the fence.
        if content.startswith("```"):
            content = content[3:]
            if content[:4].lower() == "json":
                content = content[4:]
        if content.endswith("```"):
            content = content[:-3]
        return json.loads(content)


"""
Problem Weaver: scripts to tasks
"""
class ProblemWeaver(AgentBase):
    """Agent that sends ``data`` as the user turn and parses the reply as JSON.

    ``run`` is inherited from ``AgentBase``.
    """

    def __init__(self, client:LLMClient, system_prompt):
        super().__init__(client, system_prompt)
        self.name = "Problem_Weaver"

    def make_message(self, data):
        """Build a chat request: the system prompt, then ``data`` rendered as text."""
        user_prompt = data
        message = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": f"{user_prompt}"},
        ]
        return message

    def format_output(self, response):
        """Return the first choice's message content parsed as JSON.

        A surrounding Markdown code fence (```` ``` ```` or ```` ```json ````) is removed
        first. Raises ``json.JSONDecodeError`` if the remaining text is not JSON.
        """
        content = response['choices'][0]['message']['content'].strip()
        # Remove the literal fence. str.strip("```json") would instead strip any
        # of the characters ` j s o n from both ends (e.g. `null` -> `ull`) and
        # fails when whitespace sits outside the fence.
        if content.startswith("```"):
            content = content[3:]
            if content[:4].lower() == "json":
                content = content[4:]
        if content.endswith("```"):
            content = content[:-3]
        return json.loads(content)

"""
Problem Weaver: API to sentence
"""
class SentenceAgent(AgentBase):
    """Agent that sends ``data`` as the user turn and parses the reply as JSON.

    ``run`` is inherited from ``AgentBase``.
    """

    def __init__(self, client:LLMClient, system_prompt):
        super().__init__(client, system_prompt)
        self.name = "Problem_Weaver: API to sentence"

    def make_message(self, data):
        """Build a chat request: the system prompt, then ``data`` rendered as text."""
        user_prompt = data
        message = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": f"{user_prompt}"},
        ]
        return message

    def format_output(self, response):
        """Return the first choice's message content parsed as JSON.

        A surrounding Markdown code fence (```` ``` ```` or ```` ```json ````) is removed
        first. Raises ``json.JSONDecodeError`` if the remaining text is not JSON.
        """
        content = response['choices'][0]['message']['content'].strip()
        # Remove the literal fence. str.strip("```json") would instead strip any
        # of the characters ` j s o n from both ends (e.g. `null` -> `ull`) and
        # fails when whitespace sits outside the fence.
        if content.startswith("```"):
            content = content[3:]
            if content[:4].lower() == "json":
                content = content[4:]
        if content.endswith("```"):
            content = content[:-3]
        return json.loads(content)

"""
Problem Weaver: Merge sentences to single-sentence
"""
class MergeSentenceAgent(AgentBase):
    """Agent that sends ``data`` as the user turn and parses the reply as JSON.

    ``run`` is inherited from ``AgentBase``.
    """

    def __init__(self, client:LLMClient, system_prompt):
        super().__init__(client, system_prompt)
        self.name = "Problem_Weaver: Merge sentences"

    def make_message(self, data):
        """Build a chat request: the system prompt, then ``data`` rendered as text."""
        user_prompt = data
        message = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": f"{user_prompt}"},
        ]
        return message

    def format_output(self, response):
        """Return the first choice's message content parsed as JSON.

        A surrounding Markdown code fence (```` ``` ```` or ```` ```json ````) is removed
        first. Raises ``json.JSONDecodeError`` if the remaining text is not JSON.
        """
        content = response['choices'][0]['message']['content'].strip()
        # Remove the literal fence. str.strip("```json") would instead strip any
        # of the characters ` j s o n from both ends (e.g. `null` -> `ull`) and
        # fails when whitespace sits outside the fence.
        if content.startswith("```"):
            content = content[3:]
            if content[:4].lower() == "json":
                content = content[4:]
        if content.endswith("```"):
            content = content[:-3]
        return json.loads(content)

"""
API Maintainer
"""
class API_Maintainer(AgentBase):
    """Agent that sends ``data`` as the user turn and parses the reply as JSON.

    ``run`` is inherited from ``AgentBase``.

    NOTE(review): this is a second ``class API_Maintainer`` in this module; it
    rebinds the name and shadows the earlier definition. Consider deleting one.
    """

    def __init__(self, client:LLMClient, system_prompt:str):
        super().__init__(client, system_prompt)
        self.name = "API_Maintainer"

    def make_message(self, data):
        """Build a chat request: the system prompt, then ``data`` rendered as text."""
        user_prompt = data
        message = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": f"{user_prompt}"},
        ]
        return message

    def format_output(self, response):
        """Return the first choice's message content parsed as JSON.

        A surrounding Markdown code fence (```` ``` ```` or ```` ```json ````) is removed
        first. Raises ``json.JSONDecodeError`` if the remaining text is not JSON.
        """
        content = response['choices'][0]['message']['content'].strip()
        # Remove the literal fence. str.strip("```json") would instead strip any
        # of the characters ` j s o n from both ends (e.g. `null` -> `ull`) and
        # fails when whitespace sits outside the fence.
        if content.startswith("```"):
            content = content[3:]
            if content[:4].lower() == "json":
                content = content[4:]
        if content.endswith("```"):
            content = content[:-3]
        return json.loads(content)


"""
Evaluation Agent
"""
class EvaluationAgent(AgentBase):
    """Agent that sends ``data`` as the user turn, parses a JSON reply, and retries."""

    def __init__(self, client:LLMClient, system_prompt):
        super().__init__(client, system_prompt)
        self.name = "Evaluation Agent"

    def make_message(self, data):
        """Build a chat request: the system prompt, then ``data`` rendered as text."""
        user_prompt = data
        message = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": f"{user_prompt}"},
        ]
        return message

    def format_output(self, response):
        """Return the first choice's message content parsed as JSON.

        A surrounding Markdown code fence (```` ``` ```` or ```` ```json ````) is removed
        first. Raises ``json.JSONDecodeError`` if the remaining text is not JSON.
        """
        content = response['choices'][0]['message']['content'].strip()
        # Remove the literal fence. str.strip("```json") would instead strip any
        # of the characters ` j s o n from both ends (e.g. `null` -> `ull`) and
        # fails when whitespace sits outside the fence.
        if content.startswith("```"):
            content = content[3:]
            if content[:4].lower() == "json":
                content = content[4:]
        if content.endswith("```"):
            content = content[:-3]
        return json.loads(content)

    def run(self, data, retry=3, retry_delay=1.0):
        """Send *data* to the client and return the parsed JSON reply.

        The build → chat → count → parse sequence is attempted up to ``retry``
        times in total (at least once). Between failed attempts it logs the
        error and sleeps ``retry_delay`` seconds. If every attempt fails, the
        last exception is re-raised; previously the caller got an
        ``UnboundLocalError`` that hid the real cause.
        """
        attempts = max(1, retry)
        for attempt in range(1, attempts + 1):
            try:
                message = self.make_message(data)
                response = self.client.chat(message)
                # Count tokens before parsing so a bad reply's usage is still recorded.
                self.update_token_count(response)
                return self.format_output(response)
            except Exception as e:
                if attempt >= attempts:
                    raise
                print(f"[Agent] Retry {attempt} / {attempts} times: {e!r}")
                time.sleep(retry_delay)