import json
import os
import time
from collections import Counter
from functools import partial
from time import sleep
from typing import List, Dict, Callable

import anthropic
from openai import OpenAI
import google.generativeai as genai


# Model identifiers handed to the Anthropic and Gemini clients in `inference`.
CLAUDE_CKPT = "claude-3-opus-20240229"
GEMINI_CKPT = "models/gemini-1.5-pro-latest"

# System prompts: SUMMARY_PROMPT is used when summarizing a report,
# ANALYST_PROMPT is the default system message for every inference helper.
SUMMARY_PROMPT = "You are a helpful agricultural expert studying a report published by the USDA"
ANALYST_PROMPT = (
    "You are a helpful agricultural expert helping farmers decide what produce "
    "to plant next year."
)


def format_query(
    query: str,
    format_instruction: str = "You should format your response as a JSON object.",
):
    """Return *query* followed by *format_instruction* on a new line.

    Args:
        query: The prompt body.
        format_instruction: Text describing the expected output format; it is
            appended after a single newline.

    Returns:
        The combined prompt string.
    """
    template = "{}\n{}"
    return template.format(query, format_instruction)

def _extract_response_text(response, api: str) -> str:
    """Pull the generated text out of a provider-specific response object.

    Returns an empty string when the text cannot be read (the error is
    printed) or when the provider reports no text at all.
    """
    try:
        if api in ('gpt-4', 'gpt-4o'):
            text = response.choices[0].message.content
        elif api == 'claude-3':
            text = response.content[0].text
        else:  # 'gemini' -- `inference` has already rejected any other value
            text = response.text
    except Exception as e:
        print(e)
        return ""
    # A missing (None) message body is treated like an unreadable one so the
    # caller can always call str methods on the result.
    return "" if text is None else text


def _parse_lowercased_json(text: str):
    """Lowercase *text* and decode it as JSON, falling back to the lowercased text."""
    lowered = text.lower()
    try:
        return json.loads(lowered)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # pathologically nested input. Either way, return the raw text.
        return lowered


def inference(
    query: str,
    system_content: str = ANALYST_PROMPT,
    temperature: float = 0.0,
    api: str = 'gpt-4',
):
    """Send one prompt to the selected LLM backend and return its decoded reply.

    Args:
        query: The user prompt.
        system_content: The system prompt / system instruction.
        temperature: Sampling temperature forwarded to the backend.
        api: One of 'gpt-4', 'gpt-4o', 'claude-3' or 'gemini'.

    Returns:
        The reply lowercased and decoded as JSON (so keys and string values
        are lowercase). If the lowercased reply is not valid JSON, the
        lowercased text itself is returned; if no text could be extracted
        from the backend response, the result is an empty string.

    Raises:
        ValueError: If *api* is not one of the supported names.

    Note:
        Failed API calls are printed and retried every 10 seconds with no
        upper bound on the number of attempts.
    """
    if api in ('gpt-4', 'gpt-4o'):
        openai_ckpt = "gpt-4-1106-preview" if api == 'gpt-4' else "gpt-4o"
        client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
        call_model = partial(
            client.chat.completions.create,
            model=openai_ckpt,
            response_format={"type": "json_object"},
            temperature=temperature,
        )
        messages = [
            {"role": "system", "content": system_content},
            # NOTE(review): the "<json>" suffix presumably guarantees the word
            # "json" appears in the conversation, as JSON mode expects -- confirm.
            {"role": "user", "content": query + "<json>"},
        ]
    elif api == 'claude-3':
        client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
        call_model = partial(
            client.messages.create,
            model=CLAUDE_CKPT,
            temperature=temperature,
            system=system_content,
            max_tokens=4096,
        )
        messages = [
            {
                "role": "user",
                "content": [{"type": "text", "text": query}],
            }
        ]
    elif api == 'gemini':
        cfg = genai.GenerationConfig(
            temperature=temperature,
            max_output_tokens=4096,
            response_mime_type="application/json",
        )
        model = genai.GenerativeModel(
            GEMINI_CKPT,
            generation_config=cfg,
            system_instruction=system_content,
        )
        messages = query

        def call_model(messages):
            return model.generate_content(contents=messages)
    else:
        raise ValueError(f"Invalid API: {api}")

    # Retry until the backend answers; every failure is printed and followed
    # by a 10 second pause.
    while True:
        try:
            start = time.monotonic()
            response = call_model(messages=messages)
            duration = time.monotonic() - start
            break
        except Exception as e:
            print(e)
            sleep(10)

    if api == 'gemini':
        # Pad every successful Gemini call to at least 30 seconds
        # (presumably to stay under a request-rate limit -- confirm).
        sleep(max(30 - duration, 0))

    return _parse_lowercased_json(_extract_response_text(response, api))


def majority_voting_inference(
    query: str | List[str | Callable],
    system_content: str = ANALYST_PROMPT,
    temperature: float = 0.7,
    num_samples: int = 5,
    use_chain_of_thought: bool = False,
    api: str = 'gpt-4',
):
    """Sample the model several times and return the most frequent decision.

    Args:
        query: A single prompt, or a chain of prompts/callables when
            *use_chain_of_thought* is true.
        system_content: System prompt forwarded to every sample.
        temperature: Sampling temperature forwarded to every sample.
        num_samples: Number of independent samples to draw (at least 1).
        use_chain_of_thought: When true, each sample is the final response of
            `chain_of_thought_inference`; otherwise a plain `inference` call.
        api: Backend name forwarded to the inference helpers.

    Returns:
        A dict with "decision" (the most common value of the samples'
        "decision" field; ties go to the value seen first) and "explanation"
        (the list of all raw sample responses, including unusable ones).

    Raises:
        ValueError: If *num_samples* is below 1, or if no sample is a dict
            containing a "decision" key.
    """
    if num_samples < 1:
        # Fail before spending any API calls.
        raise ValueError("num_samples must be at least 1")

    responses = []
    for _ in range(num_samples):
        if use_chain_of_thought:
            response = chain_of_thought_inference(
                chain=query, system_content=system_content, temperature=temperature, api=api
            )["response"]
        else:
            response = inference(query, system_content, temperature, api)
        responses.append(response)

    # A sample can be a plain string when the model's reply was not valid
    # JSON; such samples cannot vote but must not invalidate the others.
    decisions = [
        r["decision"] for r in responses if isinstance(r, dict) and "decision" in r
    ]
    if not decisions:
        raise ValueError("no sampled response contained a 'decision' field")

    # Counter orders equal counts by first occurrence, so the winner does not
    # depend on hash/set iteration order.
    majority_decision, _count = Counter(decisions).most_common(1)[0]
    return {
        "decision": majority_decision,
        "explanation": responses,
    }


def chain_of_thought_inference(
    chain: List[str | Callable],
    system_content: str = ANALYST_PROMPT,
    temperature: float = 0.5,
    api: str = 'gpt-4',
):
    """Run a sequence of prompts, letting later steps build on earlier replies.

    Args:
        chain: Steps executed in order. A string is sent as-is; any other
            item is called with all previous responses as positional
            arguments (oldest first) and must return the prompt to send.
        system_content: System prompt forwarded to every step.
        temperature: Sampling temperature forwarded to every step.
        api: Backend name forwarded to `inference`.

    Returns:
        A dict with "query" (a list of {"prompt", "response"} dicts, one per
        executed step, in order) and "response" (the last step's response).

    Raises:
        ValueError: If *chain* contains no steps.
    """
    # Kept as an ordered list rather than a dict keyed by prompt so that two
    # steps with identical prompt text are both recorded and both passed on.
    steps = []
    for step in chain:
        if isinstance(step, str):
            prompt = step
        else:
            prompt = step(*(earlier for _, earlier in steps))
        response = inference(prompt, system_content, temperature, api)
        steps.append((prompt, response))

    if not steps:
        raise ValueError("chain must contain at least one step")

    return {
        "query": [{"prompt": p, "response": r} for p, r in steps],
        "response": steps[-1][1],
    }


def summarize(
    fname: str, products: List[str], temperature: float = 0.0, api: str = 'gpt-4',
) -> Dict[str, str]:
    """Summarize a report file for the given products, caching the result on disk.

    The summary is stored next to the report as
    ``<report path without extension>-<sorted lowercase products>.json``.
    If that file already exists it is loaded and returned without calling
    the model.

    Args:
        fname: Path of the report text file.
        products: Product names; they are lowercased and sorted before use.
        temperature: Sampling temperature forwarded to `inference`.
        api: Backend name forwarded to `inference`.

    Returns:
        The cached or freshly generated summary. A fresh summary is whatever
        `inference` returned (it is written to the cache file unchanged).
    """
    products = sorted(p.lower() for p in products)
    # splitext only strips the final extension, so dots in directory names
    # (or a leading "./") do not truncate the path.
    stem, _ext = os.path.splitext(fname)
    summary_fname = f"{stem}-{'-'.join(products)}.json"
    if os.path.exists(summary_fname):
        with open(summary_fname) as f:
            return json.load(f)

    with open(fname) as f:
        report = f.read()
    query = f"Below is an agriculture report published by the USDA:\n\n{report}\n\n"

    # Written with explicit escapes so the significant trailing whitespace of
    # the prompt survives editors that strip line endings.
    instruction_parts = [
        "Please write a detailed summary of the report.\n"
        "\n"
        "You should format your response as a JSON object. The JSON object should contain the following keys:\n"
        "    overview: a string that describes, in detail, the overview of the report. Your summary should focus on factors that affect the overall fruit and nut market.\n"
        "    "
    ]
    for p in products:
        instruction_parts.append(
            f"\n    {p}: a string that describes, in detail, information pertaining to {p} in the report. You should include information on {p} prices and production, as well as factors that affect them. \n"
            "        "
        )
    format_instruction = "".join(instruction_parts)

    query = format_query(query, format_instruction)
    # `inference` already returns decoded JSON (or the lowercased raw text),
    # so no second decoding pass is needed here.
    response = inference(query, SUMMARY_PROMPT, temperature, api)
    with open(summary_fname, "w") as f:
        json.dump(response, f, indent=4)
    return response
