import random
import re
import time
from typing import List, Optional

import nltk
import openai
import requests
import torch
from accelerate.utils import find_executable_batch_size
from nltk.corpus import stopwords, wordnet
from vllm import SamplingParams

from ..model_utils import get_template, load_model_and_tokenizer, load_vllm_model


class MutateModel:
    """Base class for the sentence-mutation backends defined in this module."""

    def __init__(self, model_name):
        self.model_name = model_name

    def batched_generate(self, sentences: List, max_n_tokens: int, temperature: float):
        """
        Return revised versions of ``sentences``; subclasses must override this.
        """
        raise NotImplementedError

    def is_initialized(self):
        """Print a confirmation line that names the model."""
        print(f"==> Initialized {self.model_name}")

    def process_outputs(self, sentences):
        """Clean up raw generations.

        For every sentence: drop all newlines, trim surrounding whitespace,
        then remove one leading quote character, one trailing quote character,
        and finally a trailing quote-plus-period pair (``'.`` or ``".``).

        Returns:
            list: the cleaned sentences, in input order.
        """
        quote_chars = ("'", '"')
        quoted_period = ("'.", '".')

        cleaned = []
        for raw in sentences:
            text = raw.replace('\n', '').strip()
            if text.startswith(quote_chars):
                text = text[1:]
            if text.endswith(quote_chars):
                text = text[:-1]
            # Checked last so that inputs such as 'text'. lose both characters.
            if text.endswith(quoted_period):
                text = text[:-2]
            cleaned.append(text)
        return cleaned

    def get_attribute(self, attr_name):
        """Return the attribute called ``attr_name``, or ``None`` if it is absent."""
        return getattr(self, attr_name, None)

class GPT(MutateModel):
    """Mutation backend that asks an OpenAI chat model to rephrase sentences.

    Requests go through the module-level ``openai.ChatCompletion`` interface,
    authenticated with the API key passed to the constructor.
    """

    API_RETRY_SLEEP = 10  # seconds to wait between failed attempts
    API_ERROR_OUTPUT = "$ERROR$"
    API_MAX_RETRY = 5  # attempts per sentence before giving up
    API_TIMEOUT = 20  # per-request timeout, in seconds

    def __init__(self, model_name, token, **model_kwargs):
        super().__init__(model_name)
        # NOTE: this sets the key globally on the openai module.
        openai.api_key = token
        self.system_msg = 'You are a helpful and creative assistant who writes well.'

    def generate(self,
                sentence,
                max_n_tokens: int=700,
                temperature: float=1.0,
                top_p: float=0.9):
        '''
        Ask the chat model for a revised version of one sentence.

        Args:
            sentence: str, the text to be revised
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: the revised sentence with newlines removed, or ``sentence``
            unchanged when every attempt fails with an OpenAI error
        '''
        user_message = f"Please revise the following sentences with no changes to its length and only output the revised version, the sentences are: \n '{sentence}'."
        messages = [{"role": "system", "content": self.system_msg},
                    {"role": "user", "content": user_message}]

        for attempt in range(self.API_MAX_RETRY):
            try:
                response = openai.ChatCompletion.create(
                    model=self.model_name,
                    messages=messages,
                    max_tokens=max_n_tokens,
                    temperature=temperature,
                    top_p=top_p,
                    request_timeout=self.API_TIMEOUT,
                )
                return response["choices"][0]["message"]["content"].replace('\n', '')
            except openai.error.OpenAIError as e:
                print(type(e), e)
                # No point in sleeping once the last attempt has failed.
                if attempt < self.API_MAX_RETRY - 1:
                    time.sleep(self.API_RETRY_SLEEP)

        # Best effort: fall back to the unmodified input.
        return sentence

    def batched_generate(self,
                        sentences,
                        max_n_tokens: int=1024,
                        temperature: float=1.0,
                        **kwargs):
        """Revise each sentence with one API call per sentence, sequentially.

        Extra keyword arguments are accepted for interface compatibility with
        the other backends and are ignored.

        Returns:
            list: cleaned revised sentences, in input order.
        """
        revised_sentences = [
            self.generate(s, max_n_tokens=max_n_tokens, temperature=temperature)
            for s in sentences
        ]
        return self.process_outputs(revised_sentences)

class Synonyms(MutateModel):
    """Mutation backend that swaps words for WordNet synonyms (no language model)."""

    def replace_with_synonyms(self, sentence, num=10, no_mutate_words=None):
        """Replace up to ``num`` randomly chosen words with a WordNet lemma.

        English stop words, a built-in set of model/vendor names, and the
        caller-supplied ``no_mutate_words`` are never selected. Each selected
        word is replaced by the first lemma of its first synset; words without
        a synset are left alone.

        Args:
            sentence: str, the text to mutate
            num: int, maximum number of words to replace
            no_mutate_words: optional iterable of additional words to protect
                (matched case-insensitively)
        Returns:
            str: the mutated sentence
        """
        protected = {"llama2", "meta", "vicuna", "lmsys", "guanaco", "theblokeai", "wizardlm", "mpt-chat",
                     "mosaicml", "mpt-instruct", "falcon", "tii", "chatgpt", "modelkeeper", "prompt"}
        # Tokens are lowercased before the membership test below, so the
        # protected words must be lowercased as well to ever match.
        protected.update(w.lower() for w in (no_mutate_words or ()))

        stop_words = set(stopwords.words('english'))
        words = nltk.word_tokenize(sentence)
        uncommon_words = [
            word for word in words
            if word.lower() not in stop_words and word.lower() not in protected
        ]
        selected_words = random.sample(uncommon_words, min(num, len(uncommon_words)))

        for word in selected_words:
            synsets = wordnet.synsets(word)
            if not synsets:
                continue
            lemmas = synsets[0].lemmas()
            if not lemmas:
                continue
            synonym = lemmas[0].name()
            # Match the word only as a standalone token; a plain str.replace
            # would also rewrite the first longer word that contains it.
            pattern = r"(?<!\w)" + re.escape(word) + r"(?!\w)"
            # A callable replacement inserts the synonym literally, with no
            # backslash/group-reference interpretation.
            sentence = re.sub(pattern, lambda _match, s=synonym: s, sentence, count=1)
        return sentence

    def batched_generate(self, sentences, no_mutate_words, **kwargs):
        """Apply ``replace_with_synonyms`` to every sentence.

        Extra keyword arguments are accepted and ignored.

        Returns:
            list: mutated sentences, in input order.
        """
        return [self.replace_with_synonyms(s, no_mutate_words=no_mutate_words) for s in sentences]

class VLLM(MutateModel):
    """Mutation backend that rephrases sentences with a model served through vLLM."""

    def __init__(self, model_name_or_path, **model_kwargs):
        self.model_name = model_name_or_path
        self.model = load_vllm_model(model_name_or_path, **model_kwargs)
        self.template = get_template(model_name_or_path, **model_kwargs)['prompt']
        self.init_msg = "Sure, here's the revised version:\n"

    def batched_generate(self,
                    sentences,
                    max_n_tokens: int=1024,
                    temperature: float=1.0,
                    stop_tokens: List[str]=[],
                    **kwargs):
        """Revise all sentences with a single vLLM ``generate`` call.

        Each prompt ends with the first two whitespace-separated words of the
        original sentence, so the model continues from them; those words are
        prepended to the generated text before cleanup.

        Returns:
            list: cleaned revised sentences, in input order.
        """
        # Seed words that keep the generated sequences in-distribution.
        seeds = [" ".join(sentence.split()[:2]) for sentence in sentences]

        instructions = []
        for sentence in sentences:
            instructions.append(
                f"You are a helpful and creative assistant who writes well.\nPlease revise the following sentences without changing the length and only output the revised version, the sentences are:\n'{sentence}'."
            )

        prompts = []
        for instruction, seed in zip(instructions, seeds):
            formatted = self.template.format(instruction=instruction)
            prompts.append(
                formatted + " Sure, here's the revision of the sentences without changing the length:\n" + seed
            )

        params = SamplingParams(temperature=temperature,
                                max_tokens=max_n_tokens,
                                stop=stop_tokens)
        results = self.model.generate(prompts, params, use_tqdm=False)

        revised = []
        for result, seed in zip(results, seeds):
            completion = result.outputs[0].text.strip()
            revised.append(seed + " " + completion)

        return self.process_outputs(revised)

class HuggingFace(MutateModel):
    """Mutation backend that rephrases sentences with a locally loaded Hugging Face model.

    The counters ``n_input_tokens``, ``n_output_tokens`` and ``n_fwd_passes``
    accumulate usage across calls and are updated by ``batch_generate_bs``.
    """

    def __init__(self, model_name_or_path, **model_kwargs):
        model, tokenizer = load_model_and_tokenizer(model_name_or_path, **model_kwargs)
        self.model_name = model_name_or_path
        self.model = model
        self.tokenizer = tokenizer
        self.template = get_template(model_name_or_path, **model_kwargs)['prompt']
        self.eos_token_ids = [self.tokenizer.eos_token_id]
        # Starting batch size; lowered when generation runs out of memory.
        self.generation_batch_size = 64
        self.n_input_tokens = 0
        self.n_output_tokens = 0
        self.n_fwd_passes = 0
        self.embed_layer = self.model.get_input_embeddings()

    def batch_generate_bs(self, batch_size, inputs, **generation_kwargs):
        """Generate continuations for ``inputs`` in chunks of ``batch_size``.

        Meant to be wrapped by ``find_executable_batch_size``, which supplies
        ``batch_size`` and calls this method again with a smaller value after
        an out-of-memory error.

        Args:
            batch_size: int, number of prompts per ``generate`` call
            inputs: list of prompt strings
            **generation_kwargs: forwarded to ``self.model.generate``
        Returns:
            list: decoded generations (special tokens skipped), one per input
        """
        if batch_size != self.generation_batch_size:
            self.generation_batch_size = batch_size
            print(f"WARNING: Found OOM, reducing generation batch_size to: {self.generation_batch_size}")

        # Usage is tallied locally and committed only once every chunk has
        # succeeded, so a run aborted by an OOM retry is not counted twice.
        n_input_tokens = 0
        n_output_tokens = 0
        n_fwd_passes = 0

        gen_outputs = []
        for i in range(0, len(inputs), batch_size):
            inputs_b = inputs[i:i+batch_size]
            encoded_b = self.tokenizer(inputs_b, return_tensors='pt', padding='longest')
            encoded_b = encoded_b.to(self.model.device)

            # Padded prompt shape; padding positions are included in the count.
            n_seqs, prompt_len = encoded_b["input_ids"].shape
            n_input_tokens += n_seqs * prompt_len
            n_fwd_passes += n_seqs

            with torch.no_grad():
                print('Mutating HuggingFace')
                output_ids = self.model.generate(
                        **encoded_b,
                        **generation_kwargs
                    ).cpu()

            # Decoder-only models return prompt + continuation; keep the continuation.
            if not self.model.config.is_encoder_decoder:
                output_ids = output_ids[:, prompt_len:]
            n_output_tokens += output_ids.shape[1] * output_ids.shape[0]

            decoded_outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)
            gen_outputs.extend(decoded_outputs)

        self.n_input_tokens += n_input_tokens
        self.n_output_tokens += n_output_tokens
        self.n_fwd_passes += n_fwd_passes
        return gen_outputs

    def _build_eos_token_ids(self, stop_tokens):
        """Return a flat list of ids that should end generation.

        Starts from ``self.eos_token_ids`` and adds every stop string that the
        tokenizer encodes to exactly one token. Multi-token stop strings cannot
        be expressed as an end-of-sequence id and are handled by text truncation.
        """
        eos_ids = list(self.eos_token_ids)
        for stop in stop_tokens:
            ids = self.tokenizer.encode(stop, add_special_tokens=False)
            if len(ids) == 1 and ids[0] not in eos_ids:
                eos_ids.append(ids[0])
        return eos_ids

    @staticmethod
    def _truncate_at_stop(text, stop_tokens):
        """Cut ``text`` just before the earliest occurrence of any stop string."""
        cut = len(text)
        for stop in stop_tokens:
            if not stop:
                # An empty stop string would match at position 0.
                continue
            idx = text.find(stop)
            if idx != -1:
                cut = min(cut, idx)
        return text[:cut]

    def batched_generate(self,
                        sentences,
                        max_n_tokens: int=1024,
                        temperature: float=1.0,
                        stop_tokens: Optional[List[str]]=None,
                        **kwargs,
                        ):
        """Revise all sentences with the local model.

        Each prompt ends with the first two whitespace-separated words of the
        original sentence, so the model continues from them; those words are
        prepended to the generated text before cleanup.

        Args:
            sentences: list of str to revise
            max_n_tokens: int, maximum number of new tokens per sentence
            temperature: float, sampling temperature; values <= 0 select
                greedy decoding
            stop_tokens: optional list of strings at which generation stops;
                the stop string itself is not included in the output
        Returns:
            list: cleaned revised sentences, in input order.
        """
        stop_tokens = list(stop_tokens) if stop_tokens else []

        generation_kwargs = dict(max_new_tokens=max_n_tokens,
                                 do_sample=temperature > 0,
                                 eos_token_id=self._build_eos_token_ids(stop_tokens))
        if temperature > 0:
            generation_kwargs["temperature"] = temperature

        inputs = [f"You are a helpful and creative assistant who writes well.\nPlease revise the following sentences with no changes to its length and only output the revised version, the sentences are: \n '{sentence}'."
                for sentence in sentences]
        # === we init with some starting tokens to make sure that the generated sequences are in-distribution=====
        init_msgs = [" ".join(sentence.split()[:2]) for sentence in sentences]
        inputs = [self.template.format(instruction=s).strip() + " " + m for s, m in zip(inputs, init_msgs)]

        generation_function = find_executable_batch_size(self.batch_generate_bs, self.generation_batch_size)
        outputs = generation_function(inputs, **generation_kwargs)
        # Enforce stop strings on the decoded text as well: this covers
        # multi-token stops and drops a single-token stop that was generated.
        outputs = [self._truncate_at_stop(o, stop_tokens) for o in outputs]
        outputs = [m + " " + o.strip() for o, m in zip(outputs, init_msgs)]
        outputs = self.process_outputs(outputs)

        return outputs

class TogetherAPI(MutateModel):
    """Mutation backend that rephrases sentences through the Together inference endpoint."""

    # Upper bound, in seconds, on waiting for the server; without a timeout a
    # stalled connection would block the caller indefinitely.
    REQUEST_TIMEOUT = 300

    def __init__(self, model_name_or_path, together_api_key, **model_kwargs):
        self.together_api_key = together_api_key
        self.endpoint = 'https://api.together.xyz/inference'

        self.model_name = model_name_or_path
        self.template = get_template(model_name_or_path, **model_kwargs)['prompt']
        self.init_msg = "Sure, here's the revised version:\n"

    def batched_generate(self,
                    sentences,
                    max_n_tokens: int=1024,
                    temperature: float=1.0,
                    stop_tokens: Optional[List[str]]=None,
                    **kwargs):
        """Revise each sentence with one HTTP request per sentence, sequentially.

        Each prompt ends with the first two whitespace-separated words of the
        original sentence, so the model continues from them; those words are
        prepended to the returned text before cleanup.

        Args:
            sentences: list of str to revise
            max_n_tokens: int, sent as ``max_tokens``
            temperature: float, sampling temperature
            stop_tokens: optional list of extra stop strings, appended to the
                built-in ``"</s>"`` and ``"[INST]"``
        Returns:
            list: cleaned revised sentences, in input order.
        Raises:
            requests.HTTPError: if the endpoint answers with an error status.
            requests.Timeout: if a request exceeds ``REQUEST_TIMEOUT``.
        """
        inputs = [f"You are a helpful and creative assistant who writes well.\nPlease revise the following sentences with no changes to its length and only output the revised version, the sentences are: \n '{sentence}'."
                for sentence in sentences]

        # === we init with some starting tokens to make sure that the generated sequences are in-distribution=====
        init_msgs = [" ".join(sentence.split()[:2]) for sentence in sentences]
        inputs = [self.template.format(instruction=s).strip() + " " + m for s, m in zip(inputs, init_msgs)]

        stop = ["</s>", "[INST]"] + list(stop_tokens or [])
        headers = {"Authorization": f"Bearer {self.together_api_key}"}

        outputs = []
        for prompt in inputs:
            response = requests.post(self.endpoint, json={
                "model": self.model_name,
                "max_tokens": max_n_tokens,
                "prompt": prompt,
                "request_type": "language-model-inference",
                "temperature": temperature,
                "stop": stop,
                "negative_prompt": "",
            }, headers=headers, timeout=self.REQUEST_TIMEOUT)
            # Surface HTTP failures directly instead of as a KeyError on the
            # missing 'output' field of an error payload.
            response.raise_for_status()

            outputs.append(response.json()['output']['choices'][0]['text'])

        outputs = [m + " " + o.strip() for o, m in zip(outputs, init_msgs)]
        outputs = self.process_outputs(outputs)

        return outputs
