
# cython: annotation_typing=False
import logging
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

import dotenv
from genai import Client, Credentials
from genai.schema import TextGenerationParameters

# cython: annotation_typing=False
import logging
import os

from genai.schema import TextGenerationParameters


from ibm_watsonx_ai.metanames import GenTextParamsMetaNames

from langchain_ibm import WatsonxLLM


# Populate os.environ from the ".env" file (path resolved relative to the
# current working directory) so the WATSONX_* settings read below are available.
dotenv.load_dotenv(dotenv_path=".env")




def _watsonx_param_value(value):
    """Convert one genai parameter value into a JSON-friendly watsonx value.

    Objects exposing ``model_dump`` (pydantic-style models) are converted to
    plain dicts with unset (``None``) sub-fields omitted, and enum members are
    reduced to their ``value``. Any other value is returned unchanged.
    """
    from enum import Enum

    if hasattr(value, "model_dump"):
        # NOTE(review): presumably hit by genai's nested length_penalty model
        # (decay_factor / start_index); a raw model object is not JSON
        # serializable, so send the equivalent dict instead.
        return value.model_dump(exclude_none=True)
    if isinstance(value, Enum):
        return value.value
    return value


def createWatsonxLLM(
    params: TextGenerationParameters,
    model_id="ibm-mistralai/mixtral-8x7b-instruct-v01-q",
):
    """
    Create a Watsonx LLM model configured from genai generation parameters.

    Connection settings are read from the environment variables
    WATSONX_PROJECT_ID, WATSONX_APIKEY and WATSONX_URL.

    Args:
        params: the genai text-generation parameters to translate into
            watsonx generation parameters. Fields left as ``None`` are not
            forwarded, so the service-side defaults apply to them.
        model_id: the identifier of the watsonx model to use
    Returns:
        The newly created ``WatsonxLLM`` instance
    """
    project_id = os.getenv("WATSONX_PROJECT_ID")
    api_key = os.getenv("WATSONX_APIKEY")
    api_url = os.getenv("WATSONX_URL")
    # Diagnostic only: routed through logging rather than printed to stdout.
    logging.getLogger(__name__).debug("api url: %s", api_url)

    # Map each genai field onto the corresponding watsonx metaname.
    # return_options is intentionally not forwarded (it was disabled here).
    candidate_params = {
        GenTextParamsMetaNames.DECODING_METHOD: params.decoding_method,
        GenTextParamsMetaNames.MAX_NEW_TOKENS: params.max_new_tokens,
        GenTextParamsMetaNames.MIN_NEW_TOKENS: params.min_new_tokens,
        GenTextParamsMetaNames.TEMPERATURE: params.temperature,
        GenTextParamsMetaNames.TOP_K: params.top_k,
        GenTextParamsMetaNames.TOP_P: params.top_p,
        GenTextParamsMetaNames.LENGTH_PENALTY: params.length_penalty,
        GenTextParamsMetaNames.RANDOM_SEED: params.random_seed,
        GenTextParamsMetaNames.REPETITION_PENALTY: params.repetition_penalty,
        GenTextParamsMetaNames.STOP_SEQUENCES: params.stop_sequences,
        GenTextParamsMetaNames.TIME_LIMIT: params.time_limit,
        GenTextParamsMetaNames.TRUNCATE_INPUT_TOKENS: params.truncate_input_tokens,
    }
    # Drop unset fields instead of sending explicit nulls, and normalize
    # nested models / enums into plain JSON-serializable values.
    watsonx_params = {
        name: _watsonx_param_value(value)
        for name, value in candidate_params.items()
        if value is not None
    }

    return WatsonxLLM(
        model_id=model_id,
        url=api_url,
        apikey=api_key,
        project_id=project_id,
        params=watsonx_params,
    )