import logging
import time
from datetime import datetime
import google.generativeai as google_ai
from concurrent.futures import ThreadPoolExecutor, as_completed
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTNeoXForCausalLM
import torch
import json
import gc
import os
from tqdm.auto import tqdm
from openai import OpenAI
import anthropic
from together import Together

from . import keys

OAI_client = OpenAI(api_key=keys.OPENAI_API_KEY)
ANT_client = anthropic.Anthropic(api_key=keys.API_KEY_ANTHROPIC)
TOGETHER_client = Together(api_key=keys.API_KEY_TOGETHER)

MODEL_TO_COST = {"gpt-4o-mini": {"input": 0.00000015, "output": 0.0000006},
                    "gpt-4o": {"input": 0.0000025, "output": 0.00001},
                    "gpt-4": {"input": 0.000030, "output": 0.000060},
                    "o1-mini": {"input": 0.000003, "output": 0.000012},
                    "o1": {"input": 0.000015, "output": 0.00006},
                    "o3-mini": {"input": 0.0000011, "output": 0.0000044},
                    "claude-3-7-sonnet-20250219": {"input": 0.000003, "output": 0.000015},
                    "claude-3-5-sonnet-20241022": {"input": 0.000003, "output": 0.000015},
                    "claude-3-5-haiku-20241022": {"input": 0.0000008, "output": 0.000004},
                    }


google_ai.configure(api_key=keys.API_KEY_GOOGLE)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

LOCAL_MODEL_PATH_MAP = {
    "meta-llama/Llama-2-70b-chat-hf": "/vast/work/public/ml-datasets/llama-2/Llama-2-70b-chat-hf",
    "meta-llama/Meta-Llama-3.1-8B-Instruct": "/vast/work/public/ml-datasets/llama-3.1/Meta-Llama-3.1-8B-Instruct",
    "meta-llama/Meta-Llama-3.1-70B-Instruct": "/vast/work/public/ml-datasets/llama-3.1/Meta-Llama-3.1-70B-Instruct",
    "meta-llama/Meta-Llama-3.1-405B-Instruct": "/vast/work/public/ml-datasets/llama-3.1/Meta-Llama-3.1-405B-Instruct"
}

def GET_API_COST(model, response):
    if 'gpt' in model or 'o1' in model or 'o3' in model:
        return (
            response.usage.prompt_tokens * MODEL_TO_COST[model]["input"]
            + response.usage.completion_tokens * MODEL_TO_COST[model]["output"]
        )
    elif 'claude' in model:
        return (
            response.usage.input_tokens * MODEL_TO_COST[model]["input"]
            + response.usage.output_tokens * MODEL_TO_COST[model]["output"]
        )

def generate_OAI_response(
    prompt, system_prompt="", model_name="gpt-4o-mini", cost=0, max_tokens=500, temperature=0
):
    """Make one OAI inference with error handling, retry logic, 
    and graceful handling of invalid prompts by checking the exception message."""
    while True:
        try:
            if model_name in ['o1', 'o1-mini', "o3-mini"]:
                response = OAI_client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "user", "content": prompt},
                    ],
                )
            else:
                response = OAI_client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": "You are a helpful assistant." + " " + system_prompt},
                        {"role": "user", "content": prompt},
                    ],
                    max_tokens=max_tokens,
                    temperature=temperature,
                )

            # Calculate and update the cost of the API call
            openai_cost = GET_API_COST(model_name, response)
            cost += openai_cost

            # Extract and return the response content
            response_content = response.choices[0].message.content
            return response_content, cost

        except Exception as e:
            # Convert the exception to a string so we can inspect its message
            error_str = str(e)
            print(error_str)
            # Check if the error message contains "invalid_prompt"
            if "violating" in error_str:
                return "Sorry, I can't help with this.", cost

            # Otherwise, print the error and retry after 10 seconds
            print(f"API error occurred: {e}. Retrying in 10 seconds...")
            time.sleep(10)


def document_api_cost(total_cost):

    curr_loc = os.getcwd()

    # Find the most recent folder containing 'SafeSysGen'
    while 'SafeSysGen' not in os.path.basename(curr_loc) and os.path.dirname(curr_loc) != curr_loc:
        curr_loc = os.path.dirname(curr_loc)

    if 'SafeSysGen' not in os.path.basename(curr_loc):
        raise FileNotFoundError("SafeSysGen folder not found in the current directory structure.")

    file_path = os.path.join(curr_loc, "oai_costs.json")
    print(file_path)
    # Ensure the directory exists
    directory = os.path.dirname(file_path)
    if not os.path.exists(directory):
        os.makedirs(directory)
        print(f"Created directory: {directory}")

    # Check if the file exists, if not create an empty file with an empty dictionary
    if not os.path.exists(file_path):
        with open(file_path, "w") as file:
            json.dump({}, file)
        print(f"Created {file_path} to document API cost.")
    
    # Read and handle the file content safely
    try:
        with open(file_path, "r") as file:
            content = file.read().strip()  # Read and strip whitespace
            data = json.loads(content) if content else {}  # Load or initialize
    except (json.JSONDecodeError, FileNotFoundError):
        data = {}  # Reset to an empty dictionary if the file is invalid
    
    # Add the new entry
    data[datetime.now().strftime("%Y-%m-%d %H:%M:%S")] = total_cost

    # Write back the updated data
    with open(file_path, "w") as file:
        json.dump(data, file, indent=4)

def generate_OAI_responses_threaded(
    prompts, system_prompt="", model_name="gpt-4o-mini", max_tokens=300, temperature=0, max_workers=20
):
    """Generate responses using multithreading with a specified number of workers and progress bar."""
    results = {}
    total_prompts = len(prompts)
    total_cost = 0

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_prompt = {
            executor.submit(
                generate_OAI_response, prompt, system_prompt, model_name, 0, max_tokens, temperature
            ): prompt
            for prompt in prompts
        }

        with tqdm(total=total_prompts, desc="Generating responses") as pbar:
            for future in as_completed(future_to_prompt):
                prompt = future_to_prompt[future]
                try:
                    result = future.result()
                    results[prompt], cost = result
                    total_cost += cost
                except Exception as e:
                    print(f"Error processing prompt: {prompt}")
                    print(f"Error: {e}")
                finally:
                    pbar.update(1)

    document_api_cost(total_cost)

    return results

def generate_TOGETHER_response(
    prompt, system_prompt="", model_name="deepseek-ai/DeepSeek-V3", cost=0, max_tokens=500, temperature=0
):
    """make one ANT inference with error handling and retry logic."""
    while True:
        try:
            response = TOGETHER_client.chat.completions.create(
                model=model_name,
                messages=[{"role": "user", 
                        "content": prompt}],
                max_tokens=max_tokens,
                temperature=temperature)
            # calculate the cost of the api call
            api_cost = 0
            cost+=api_cost
            
            # extract and return the response content
            response_content = response.choices[0].message.content
            return response_content, cost

        except Exception as e:
            # print the error and retry after 10 seconds
            print(f"API error occurred: {e}. Retrying in 10 seconds...")
            time.sleep(10)
            
def generate_TOGETHER_responses_threaded(
    prompts, system_prompt = "" , model_name="deepseek-ai/DeepSeek-V3", max_tokens=300, temperature=0, max_workers=20
):
    """Generate responses using multithreading with a specified number of workers and progress bar."""
    results = {}
    total_prompts = len(prompts)
    total_cost = 0

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_prompt = {
            executor.submit(
                generate_TOGETHER_response, prompt, system_prompt, model_name, 0, max_tokens, temperature
            ): prompt
            for prompt in prompts
        }

        with tqdm(total=total_prompts, desc="Generating responses") as pbar:
            for future in as_completed(future_to_prompt):
                prompt = future_to_prompt[future]
                try:
                    result = future.result()
                    results[prompt], cost = result
                    total_cost += cost
                except Exception as e:
                    print(f"Error processing prompt: {prompt}")
                    print(f"Error: {e}")
                finally:
                    pbar.update(1)

    document_api_cost(total_cost)

    return results

def generate_ANT_response(
    prompt, system_prompt="", model_name="claude-3-5-sonnet-20241022", cost=0, max_tokens=500, temperature=0
):
    """make one ANT inference with error handling and retry logic."""
    while True:
        try:
            response = ANT_client.messages.create(
                model=model_name,
                max_tokens=max_tokens,
                temperature=temperature,
                system = system_prompt,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": prompt
                            }
                        ]
                    }
                ]
            )
            # calculate the cost of the api call
            api_cost = GET_API_COST(model_name, response)
            cost+=api_cost
            
            # extract and return the response content
            response_content = response.content[0].text
            return response_content, cost

        except Exception as e:
            # print the error and retry after 10 seconds
            print(f"API error occurred: {e}. Retrying in 10 seconds...")
            time.sleep(10)


def get_response_with_retry(api_call, prompt, wait_time):
    """
    Make an API call and retry on failure after a specified wait time.
    """
    while True:
        try:
            return api_call()
        except Exception as e:
            logger.info(f"Prompt: {prompt}")
            logger.info(f"Error message: {e}")
            logger.info(f"Waiting for {wait_time} seconds before retrying...")

            time.sleep(wait_time)
            
            
def generate_ANT_responses_threaded(
    prompts, system_prompt = "" , model_name="claude-3-5-sonnet-20241022", max_tokens=300, temperature=0, max_workers=20
):
    """Generate responses using multithreading with a specified number of workers and progress bar."""
    results = {}
    total_prompts = len(prompts)
    total_cost = 0

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_prompt = {
            executor.submit(
                generate_ANT_response, prompt, system_prompt, model_name, 0, max_tokens, temperature
            ): prompt
            for prompt in prompts
        }

        with tqdm(total=total_prompts, desc="Generating responses") as pbar:
            for future in as_completed(future_to_prompt):
                prompt = future_to_prompt[future]
                try:
                    result = future.result()
                    results[prompt], cost = result
                    total_cost += cost
                except Exception as e:
                    print(f"Error processing prompt: {prompt}")
                    print(f"Error: {e}")
                finally:
                    pbar.update(1)

    document_api_cost(total_cost)

    return results


def get_response_from_google_model(model_name, prompt, max_tokens=2000, temperature=0, wait_time=10):
    model = google_ai.GenerativeModel(model_name)

    def api_call():
        response = model.generate_content(
            prompt,
            generation_config=google_ai.types.GenerationConfig(
                candidate_count=1,
                max_output_tokens=max_tokens,
                temperature=temperature,
            ),
        )
        
        try:
            final_output = response.candidates[0].content.parts[0].text
            return final_output
        except:
            return "No response generated due to safety training."
        
    return get_response_with_retry(api_call, prompt, wait_time)
    

def generate_gemini_responses_threaded(prompts, model_name, max_tokens=2000, temperature=0, max_workers=30, wait_time=10):
    """Get responses for multiple prompts using multi-threading with a progress bar."""
    
    def get_response(prompt):
        return prompt, get_response_from_google_model(model_name, prompt, max_tokens, temperature, wait_time)
    
    responses = {}
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(get_response, prompt): prompt for prompt in prompts}
        for future in tqdm(as_completed(futures), total=len(prompts), desc="Processing prompts"):
            prompt, response = future.result()
            responses[prompt] = response
    
    return responses

def generate_gemini_responses_threaded_modified(prompts, model_name, max_tokens=2000, temperature=0, max_workers=30, wait_time=10):
    """Get responses for multiple prompts using multi-threading with a progress bar."""
    
    def get_response(prompt):
        return prompt, get_response_from_google_model(model_name, prompt, max_tokens, temperature, wait_time)
    
    responses = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(get_response, prompt): prompt for prompt in prompts}
        for future in tqdm(as_completed(futures), total=len(prompts), desc="Processing prompts"):
            prompt, response = future.result()
            responses.append(response)
    
    return responses


def free_memory():
    """Free up memory and collect gabage."""
    if 'model' in globals():
        del model
    if 'tokenizer' in globals():
        del tokenizer
    torch.cuda.empty_cache()
    gc.collect()


def load_hf_model_and_tokenizer(model_name):
    """Load a model and a tokenizer from Huggingface using cuda if avaiable."""
    if "EleutherAI" in model_name:
        model = GPTNeoXForCausalLM.from_pretrained(model_name, 
                                          token=keys.HF_AUTH_TOKEN)
        tokenizer = AutoTokenizer.from_pretrained(model_name, 
                                          token=keys.HF_AUTH_TOKEN)
    elif model_name in LOCAL_MODEL_PATH_MAP:
        # local models (in /vast/work/public/ml-datasets)
        model = AutoModelForCausalLM.from_pretrained(LOCAL_MODEL_PATH_MAP[model_name], 
                                     token=keys.HF_AUTH_TOKEN, 
                                     torch_dtype=torch.float16)
        tokenizer = AutoTokenizer.from_pretrained(LOCAL_MODEL_PATH_MAP[model_name], 
                                            token=keys.HF_AUTH_TOKEN)
    else:
        model = AutoModelForCausalLM.from_pretrained(model_name, 
                                             token=keys.HF_AUTH_TOKEN, 
                                             torch_dtype=torch.float16)
        tokenizer = AutoTokenizer.from_pretrained(model_name, 
                                            token=keys.HF_AUTH_TOKEN)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print(f"Loaded {model_name} in {device}")
    return model, tokenizer, device


def get_response_from_hf_model(prompts, model, tokenizer, device, max_length=200, max_new_tokens=200, batch_size=10):
    """Get responses from the given hf model, running in batches for efficiency."""
    response_dict = {}
    for i in tqdm(range(0, len(prompts), batch_size), desc="Processing Batches"):
        batch_prompts = prompts[i:i + batch_size]
        
        inputs = tokenizer(batch_prompts, 
                           return_tensors="pt", 
                           padding=True, 
                           truncation=True, 
                           max_length=max_length).to(device)
        
        outputs = model.generate(**inputs, max_length=max_length + max_new_tokens, do_sample=False)
        generated_texts = [tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
        
        for j, prompt in enumerate(batch_prompts):
            response_dict[prompt] = generated_texts[j][len(prompt):]
            print(f"Prompt: {prompt}\nResponse: {generated_texts[j][len(prompt):]}")
            print("-"*40)
    
    return response_dict

def make_mutithreaded_inference_by_model(prompts, model_name, system_prompt="You are a helpful assistant.", max_tokens=2000, temperature=0, max_workers=30):
    if 'gpt' in model_name or 'o1' in model_name or 'o3' in model_name:
        responses = generate_OAI_responses_threaded(
            prompts, system_prompt=system_prompt, model_name=model_name, max_tokens=max_tokens, temperature=0, max_workers=30
        )
    elif 'claude' in model_name:
        responses = generate_ANT_responses_threaded(
            prompts, system_prompt=system_prompt, model_name=model_name, max_tokens=max_tokens, temperature=0, max_workers=30
        )
    elif 'gemini' in model_name:    
        responses = generate_gemini_responses_threaded(prompts, model_name=model_name, max_tokens=max_tokens, temperature=0, max_workers=30)
    else:
        responses = generate_TOGETHER_responses_threaded(
            prompts, system_prompt=system_prompt, model_name=model_name, max_tokens=max_tokens, temperature=0, max_workers=30
        )
    return responses



def load_model_and_tokenizer(model_name, cache_dir=None):
    """Load model and tokenizer from local path if available, otherwise download."""
    if cache_dir is None:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            token=keys.HF_TOKEN_JOHN,
            # force_download=True 
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
            token=keys.HF_TOKEN_JOHN,
            cache_dir=cache_dir
        )

    if cache_dir is None:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=keys.HF_TOKEN_JOHN)
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name, token=keys.HF_TOKEN_JOHN, cache_dir=cache_dir)

    device = next(model.parameters()).device
    if device.type == "cuda":
        print(f"Model is using GPU: {torch.cuda.get_device_name(device)}")
    elif device.type == "mps":
        print("Model is using Apple Silicon GPU")
    else:
        print("Model is using CPU")

    return model, tokenizer


def batch_HF_inference(
    model, tokenizer, prompts, max_new_tokens=1000, batch_size=10, temperature=1.0,
):
    """Generate inferences for a list of prompts using batched processing on GPU."""
    results = {}
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # model.to(device)
    num_batches = (len(prompts) + batch_size - 1) // batch_size

    for i in tqdm(
        range(0, len(prompts), batch_size), total=num_batches, desc="Processing batches"
    ):
        batch_prompts = prompts[i : i + batch_size]

        # Tokenize the batch
        encoded_inputs = tokenizer(batch_prompts, padding=True, return_tensors="pt").to(
            device
        )

        with torch.no_grad():
            outputs = model.generate(
                **encoded_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True if temperature > 0 else False,
            )


        # Decode the outputs
        decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)

        # Process each output in the batch, maintaining order
        for prompt, full_response in zip(batch_prompts, decoded_outputs):
            # Find where the actual response starts
            response = full_response[len(prompt) :].strip()
            results[prompt] = response
            print(f"Response: {response}")
            print("---")

    return results