# Copyright (c) Guangsheng Bao.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import time

import numpy as np
import datasets
import torch
import random
import argparse
import os
import json
import custom_datasets
from model import load_tokenizer, load_model
import anthropic
from groq import Groq
from openai import AzureOpenAI
import transformers
from transformers import AutoTokenizer


# API clients are created once, at import time. Every credential is read from
# an environment variable; 'xxxx' is the placeholder used when it is unset.
client_claude = anthropic.Anthropic(
    api_key=os.environ.get("ANTHROPIC_API_KEY", 'xxxx'),
)
client_groq = Groq(
    api_key=os.environ.get("GROQ_API_KEY", 'xxxx'),
)

import google.generativeai as genai
genai.configure(api_key=os.environ.get("GOOGLE_API_KEY", 'xxxx'))
client_gemini = genai.GenerativeModel('gemini-pro')

# SECURITY: a literal Azure key used to be hard-coded on this line. Secrets do
# not belong in source control; the previously committed key should be treated
# as leaked and rotated. The name is kept, but its value now comes from the
# same AZURE_API_KEY variable the client was already reading.
azure_key = os.environ.get("AZURE_API_KEY", 'xxxx')
azure_endpoint = 'https://hyunai.openai.azure.com/'
client_gpt4 = AzureOpenAI(
    azure_endpoint=azure_endpoint,
    api_key=azure_key,
    api_version="2024-02-01"
)
def generate_gemini(sys_prompt,cur_prompt,kwargs):
    """Generate one completion with the module-level Gemini client.

    Args:
        sys_prompt: Accepted for signature parity with the other
            ``generate_*`` helpers; it is not sent to Gemini.
        cur_prompt: Prompt text passed to ``generate_content``.
        kwargs: Sampling options. ``max_tokens``, ``temperature``, ``top_p``
            and ``top_k`` are read when present; missing keys are passed as
            ``None`` (i.e. left unset).

    Returns:
        The ``text`` attribute of the Gemini response.
    """
    # Use .get(): the caller only sets ONE of top_p / top_k / temperature (or
    # none of them), so indexing kwargs['temperature'] raised KeyError in every
    # other configuration. top_p / top_k are forwarded instead of being dropped.
    # NOTE: the former stop_sequences=['x'] was removed on purpose - it ended
    # generation at the first letter "x" in the output.
    generation_config = genai.types.GenerationConfig(
        # Only one candidate for now.
        candidate_count=1,
        max_output_tokens=kwargs.get('max_tokens'),
        temperature=kwargs.get('temperature'),
        top_p=kwargs.get('top_p'),
        top_k=kwargs.get('top_k'),
    )
    # Every listed harm category is set to BLOCK_NONE.
    harm_categories = (
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
        "HARM_CATEGORY_DANGEROUS_CONTENT",
    )
    safety_settings = [
        {"category": category, "threshold": "BLOCK_NONE"}
        for category in harm_categories
    ]
    response = client_gemini.generate_content(
        cur_prompt,
        generation_config=generation_config,
        safety_settings=safety_settings,
    )
    return response.text

def generate_claude(sys_prompt,cur_prompt,kwargs):
    """Generate one completion with Claude 3 Opus via the module-level client.

    Args:
        sys_prompt: System prompt, passed through the top-level ``system``
            argument of the request.
        cur_prompt: Content of the single user message.
        kwargs: Extra request options forwarded unchanged (the caller
            supplies ``max_tokens`` and at most one sampling option).

    Returns:
        The text of the first content block of the reply.
    """
    message = client_claude.messages.create(
        model='claude-3-opus-20240229',
        **kwargs,
        system=sys_prompt,
        # The system prompt travels in `system=` above. A {"role": "system"}
        # entry inside `messages` is not a valid Messages API role, so only the
        # user turn is sent - matching generate_claude_sonnet / _haiku.
        messages=[
            {"role": "user", "content": cur_prompt}
        ]
    )
    return message.content[0].text

def generate_gpt4turbo(sys_prompt,cur_prompt,kwargs):
    """Generate one chat completion with the module-level Azure OpenAI client.

    Args:
        sys_prompt: Content of the system message.
        cur_prompt: Content of the user message.
        kwargs: Extra request options forwarded unchanged to ``create``.

    Returns:
        The message content of the first choice.
    """
    conversation = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": cur_prompt},
    ]
    completion = client_gpt4.chat.completions.create(
        model='gpt-4-turbo-0409',
        messages=conversation,
        logit_bias=None,
        **kwargs,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

def generate_claude_sonnet(sys_prompt,cur_prompt,kwargs):
    """Generate one completion with Claude 3 Sonnet via the module-level client.

    Args:
        sys_prompt: System prompt, passed as the request's ``system`` argument.
        cur_prompt: Content of the single user message.
        kwargs: Extra request options forwarded unchanged to ``create``.

    Returns:
        The text of the first content block of the reply.
    """
    user_turn = {"role": "user", "content": cur_prompt}
    reply = client_claude.messages.create(
        model='claude-3-sonnet-20240229',
        system=sys_prompt,
        messages=[user_turn],
        **kwargs,
    )
    first_block = reply.content[0]
    return first_block.text

def generate_claude_haiku(sys_prompt,cur_prompt,kwargs):
    """Generate one completion with Claude 3 Haiku via the module-level client.

    Args:
        sys_prompt: System prompt, passed as the request's ``system`` argument.
        cur_prompt: Content of the single user message.
        kwargs: Extra request options forwarded unchanged to ``create``.

    Returns:
        The text of the first content block of the reply.
    """
    user_turn = {"role": "user", "content": cur_prompt}
    reply = client_claude.messages.create(
        model='claude-3-haiku-20240307',
        system=sys_prompt,
        messages=[user_turn],
        **kwargs,
    )
    first_block = reply.content[0]
    return first_block.text


def generate_llama(sys_prompt,cur_prompt,kwargs):
    """Generate one chat completion with Llama 3 70B via the module-level Groq client.

    Args:
        sys_prompt: Content of the system message.
        cur_prompt: Content of the user message.
        kwargs: Extra request options forwarded unchanged to ``create``.

    Returns:
        The message content of the first choice.
    """
    conversation = [
        {"role": "system", "content": sys_prompt},
        {"role": "user", "content": cur_prompt},
    ]
    completion = client_groq.chat.completions.create(
        model="llama3-70b-8192",
        messages=conversation,
        stop=None,
        **kwargs,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content



def save_data(output_file, args, data):
    """Write the run arguments and the generated data as two JSON files.

    ``{output_file}.args.json`` receives the attributes of ``args`` and
    ``{output_file}.raw_data.json`` receives ``data``. Missing parent
    directories of ``output_file`` are created first, so the result of a
    finished generation run is not lost to a ``FileNotFoundError``.

    Args:
        output_file: Path prefix (without extension) for both files.
        args: Namespace-like object whose ``__dict__`` is JSON-serialisable.
        data: JSON-serialisable payload to store.
    """
    out_dir = os.path.dirname(output_file)
    if out_dir:  # empty when output_file is a bare name in the working dir
        os.makedirs(out_dir, exist_ok=True)

    # write args to file
    args_file = f"{output_file}.args.json"
    with open(args_file, "w") as fout:
        json.dump(vars(args), fout, indent=4)
    print(f"Args written into {args_file}")

    # write the data to a json file in the save folder
    data_file = f"{output_file}.raw_data.json"
    with open(data_file, "w") as fout:
        json.dump(data, fout, indent=4)
    print(f"Raw data written into {data_file}")


def load_data(input_file):
    """Read back the JSON payload stored at ``{input_file}.raw_data.json``.

    Args:
        input_file: Path prefix (without the ``.raw_data.json`` suffix).

    Returns:
        The decoded JSON content of the file.
    """
    data_file = input_file + ".raw_data.json"
    print(data_file)
    with open(data_file, "r") as handle:
        loaded = json.load(handle)
    print(f"Raw data loaded from {data_file}")
    return loaded


class DataBuilder:
    """Builds paired original / machine-sampled texts for one dataset.

    Samples come from a remote API backend when ``args.openai_model`` is set,
    otherwise from the local model loaded with ``load_model``. In both cases
    the tokenizer from ``load_tokenizer`` is used to cut the prompt prefix.
    """

    # System role per dataset for the chat-style backends.
    _ROLES = {'xsum': 'You are a News writer.',
              'writing': 'You are a Fiction writer.',
              'pubmed': 'You are a Technical writer.'}
    # Instruction placed in front of the prefix in the user message.
    _PROMPTS = {'xsum': 'Please write an article with about 150 words starting exactly with:',
                'writing': 'Please write an article with about 150 words starting exactly with:',
                'pubmed': 'Please answer the question in about 50 words.'}
    # Chat models served through the legacy ``openai.ChatCompletion`` call.
    _OPENAI_CHAT_MODELS = ('gpt-3.5-turbo', 'gpt-4')

    def __init__(self, args):
        """Store ``args`` and load the tokenizer (and the local model if needed)."""
        self.args = args
        self.base_tokenizer = load_tokenizer(args.base_model_name, args.dataset, args.cache_dir)
        # No local model is loaded when sampling goes through a remote API.
        self.base_model = None if args.openai_model else load_model(args.base_model_name, args.device, args.cache_dir)

    def _sampling_kwargs(self):
        """Return the one sampling option selected on the command line.

        Precedence is top-p, then top-k, then temperature; an empty dict is
        returned when none of the ``--do_*`` flags is set.
        """
        if self.args.do_top_p:
            return {'top_p': self.args.top_p}
        if self.args.do_top_k:
            return {'top_k': self.args.top_k}
        if self.args.do_temperature:
            return {'temperature': self.args.temperature}
        return {}

    def _configure_openai(self):
        """Import ``openai`` lazily, apply ``--openai_key`` / ``--openai_base`` and return the module."""
        import openai
        assert self.args.openai_key is not None, "Must provide OpenAI API key as --openai_key"
        openai.api_key = self.args.openai_key
        if self.args.openai_base is not None:
            openai.api_base = self.args.openai_base
        return openai

    def _openai_sample(self, prefix):
        """Sample one continuation of ``prefix`` from the configured API backend.

        Args:
            prefix: Decoded prompt prefix. Except for ``pubmed`` its last
                space-separated word is dropped before prompting.

        Returns:
            The generated text, starting with the prefix.

        Raises:
            NotImplementedError: If ``args.openai_model`` is not a known backend.
        """
        if self.args.dataset != 'pubmed':  # keep Answer: prefix for pubmed
            prefix = ' '.join(prefix.split(' ')[:-1])

        # sample from the openai model
        kwargs = {"max_tokens": 200, **self._sampling_kwargs()}
        model_name = self.args.openai_model

        if model_name == 'davinci':
            openai = self._configure_openai()
            # The engine must always be set, not only when --openai_base is given.
            kwargs["engine"] = model_name
            response = openai.Completion.create(prompt=f"{prefix}", **kwargs)
            return prefix + response['choices'][0]['text']

        backends = {
            'gemini': generate_gemini,
            'llama': generate_llama,
            'claude': generate_claude,
            'claude_sonnet': generate_claude_sonnet,
            'claude_haiku': generate_claude_haiku,
            'gpt4turbo': generate_gpt4turbo,
        }
        if model_name not in backends and model_name not in self._OPENAI_CHAT_MODELS:
            raise NotImplementedError

        system_prompt = self._ROLES[self.args.dataset]
        user_prompt = f'{self._PROMPTS[self.args.dataset]} {prefix}'
        if model_name in backends:
            # Dispatch on self.args: the module-level `args` only exists when
            # this file runs as a script.
            response = backends[model_name](system_prompt, user_prompt, kwargs)
        else:
            openai = self._configure_openai()
            kwargs["engine"] = model_name
            kwargs["model"] = model_name
            kwargs["messages"] = [
                {'role': 'system', 'content': system_prompt},
                {'role': 'user', 'content': user_prompt},
            ]
            response = openai.ChatCompletion.create(**kwargs)
            response = response['choices'][0]['message']['content']

        # ChatGPT may repeat the prefix
        if response.startswith(prefix[:20]):
            return response
        return prefix + ' ' + response

    def _openai_sample_with_retry(self, prefix, max_tries=10, wait_seconds=10):
        """Call ``_openai_sample`` until it succeeds, at most ``max_tries`` times.

        Raises:
            RuntimeError: If every attempt failed. Giving up silently would
                shift all later samples by one position and pair them with
                the wrong original texts.
        """
        last_error = None
        for try_num in range(1, max_tries + 1):
            try:
                return self._openai_sample(prefix)
            except Exception as ex:  # API boundary: any failure is retried
                last_error = ex
                print(ex)
                if try_num < max_tries:
                    print(f'Attempt {try_num} failed: waiting {wait_seconds} seconds before retry ...')
                    time.sleep(wait_seconds)
        raise RuntimeError(f'Sampling failed after {max_tries} attempts') from last_error

    # sample from base_model using ****only**** the first 30 tokens in each example as context
    def _sample_from_model(self, texts, min_words=55, prompt_tokens=30):
        """Generate one sample per input text.

        Args:
            texts: Original texts of one batch.
            min_words: Minimum word count every local-model sample must reach
                (only used by the local-model path).
            prompt_tokens: Number of leading tokens kept as the prompt; for
                ``pubmed`` the text before the separator is used instead.

        Returns:
            A list of generated texts, aligned with ``texts``.
        """
        is_pubmed = self.args.dataset == 'pubmed'
        if is_pubmed:
            texts = [t[:t.index(custom_datasets.SEPARATOR)] for t in texts]
        # encode each text as a list of token ids
        all_encoded = self.base_tokenizer(texts, return_tensors="pt", padding=True, return_token_type_ids=False).to(self.args.device)
        if not is_pubmed:
            all_encoded = {key: value[:, :prompt_tokens] for key, value in all_encoded.items()}

        if self.args.openai_model:
            # decode the prefixes back into text
            prefixes = self.base_tokenizer.batch_decode(all_encoded['input_ids'], skip_special_tokens=True)
            return [self._openai_sample_with_retry(prefix) for prefix in prefixes]

        self.base_model.eval()
        decoded = [''] * len(texts)
        # Both values are the same on every regeneration round.
        sampling_kwargs = self._sampling_kwargs()
        min_length = 50 if is_pubmed else 150

        # sample from the model until we get a sample with at least min_words words for each example
        # this is an inefficient way to do this (since we regenerate for all inputs if just one is too short), but it works
        tries = 0
        m = 0
        while m < min_words:
            if tries != 0:
                print()
                print(f"min words: {m}, needed {min_words}, regenerating (try {tries})")
                prefixes = self.base_tokenizer.batch_decode(all_encoded['input_ids'], skip_special_tokens=True)
                for prefix, x in zip(prefixes, decoded):
                    if len(x.split()) == m:
                        print(prefix, '=>', x)

            outputs = self.base_model.generate(**all_encoded, min_length=min_length, max_length=200, do_sample=True,
                                               **sampling_kwargs, pad_token_id=self.base_tokenizer.eos_token_id,
                                               eos_token_id=self.base_tokenizer.eos_token_id)
            decoded = self.base_tokenizer.batch_decode(outputs, skip_special_tokens=True)

            m = min(len(x.split()) for x in decoded)
            tries += 1

        return decoded

    @staticmethod
    def _trim_to_shorter_length(texta, textb):
        """Cut both texts to the word count (split on single spaces) of the shorter one."""
        words_a = texta.split(' ')
        words_b = textb.split(' ')
        shorter_length = min(len(words_a), len(words_b))
        return ' '.join(words_a[:shorter_length]), ' '.join(words_b[:shorter_length])

    @staticmethod
    def _truncate_to_substring(text, substring, idx_occurrence):
        """Drop everything from the ``idx_occurrence``-th occurrence of ``substring`` on.

        The text is returned unchanged when it has fewer occurrences.
        """
        assert idx_occurrence > 0, 'idx_occurrence must be > 0'
        idx = -1
        for _ in range(idx_occurrence):
            idx = text.find(substring, idx + 1)
            if idx == -1:
                return text
        return text[:idx]

    def generate_samples(self, raw_data, batch_size):
        """Sample a machine text for every entry of ``raw_data``.

        Args:
            raw_data: List of original texts.
            batch_size: Number of texts sampled per call; must be positive.
                A final, smaller batch is processed too, so no trailing
                examples are dropped.

        Returns:
            ``{"original": [...], "sampled": [...]}`` with both lists aligned
            and each pair trimmed to the same number of words.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer')

        data = {
            "original": [],
            "sampled": [],
        }
        is_pubmed = self.args.dataset in ['pubmed']
        min_words = 30 if is_pubmed else 55
        n_batches = (len(raw_data) + batch_size - 1) // batch_size

        for batch, start in enumerate(range(0, len(raw_data), batch_size)):
            print('Generating samples for batch', batch, 'of', n_batches)
            original_text = raw_data[start:start + batch_size]
            sampled_text = self._sample_from_model(original_text, min_words=min_words)

            for o, s in zip(original_text, sampled_text):
                if is_pubmed:
                    s = self._truncate_to_substring(s, 'Question:', 2)
                    o = o.replace(custom_datasets.SEPARATOR, ' ')

                o, s = self._trim_to_shorter_length(o, s)

                # add to the data
                data["original"].append(o)
                data["sampled"].append(s)

        return data

def generate_data(args, dataset, key):
    """Load a dataset, clean and filter its texts, and generate sampled pairs.

    Args:
        args: Parsed command-line namespace (``cache_dir``, ``n_samples``,
            ``batch_size`` and the fields ``DataBuilder`` reads).
        dataset: Dataset name; names in ``custom_datasets.DATASETS`` are
            loaded with ``custom_datasets.load``, others with
            ``custom_datasets.load_dataset`` (train split).
        key: Column to read when the dataset is not a custom one.

    Returns:
        The dict produced by ``DataBuilder.generate_samples``.
    """
    if dataset in custom_datasets.DATASETS:
        data = custom_datasets.load(dataset, args.cache_dir)
    else:
        data = custom_datasets.load_dataset(dataset, split='train', cache_dir=args.cache_dir)[key]

    # Order-preserving de-duplication (deterministic, as opposed to set()),
    # then trim each example and collapse every whitespace run, newlines
    # included, into a single space.
    unique_texts = list(dict.fromkeys(data))
    cleaned = [' '.join(text.strip().split()) for text in unique_texts]

    # Prefer examples with > 250 words, unless that would leave nothing.
    if dataset in ['writing', 'squad', 'xsum']:
        long_examples = [text for text in cleaned if len(text.split()) > 250]
        if long_examples:
            cleaned = long_examples

    # Shuffle and cap at 5,000 examples so tokenization stays cheap.
    random.shuffle(cleaned)
    candidates = cleaned[:5_000]

    # Keep only examples with <= 512 tokens according to base_tokenizer;
    # this has the extra effect of removing low-quality/garbage content.
    data_builder = DataBuilder(args)
    token_ids = data_builder.base_tokenizer(candidates)["input_ids"]
    kept = [text for text, ids in zip(candidates, token_ids) if len(ids) <= 512]

    # print stats about remaining data
    print(f"Total number of samples: {len(kept)}")
    print(f"Average number of words: {np.mean([len(text.split()) for text in kept])}")

    return data_builder.generate_samples(kept[:args.n_samples], batch_size=args.batch_size)

if __name__ == '__main__':
    # Command-line entry point: parse options, seed the RNGs, build the data
    # for the chosen dataset and write it next to the parsed arguments.
    cli = argparse.ArgumentParser()
    add_option = cli.add_argument
    add_option('--output_file', type=str, default="./exp_gpt3/data/xsum_gpt2")
    add_option('--dataset', type=str, default="xsum")
    add_option('--n_samples', type=int, default=200)
    add_option('--openai_base', type=str, default=None)
    add_option('--openai_key', type=str, default=None)
    add_option('--openai_model', type=str, default=None)  # davinci, gpt-3.5-turbo, gpt-4
    add_option('--base_model_name', type=str, default="gpt2")
    add_option('--batch_size', type=int, default=50)
    add_option('--do_top_k', action='store_true')
    add_option('--top_k', type=int, default=40)
    add_option('--do_top_p', action='store_true')
    add_option('--top_p', type=float, default=0.96)
    add_option('--do_temperature', action='store_true')
    add_option('--temperature', type=float, default=0.8)
    add_option('--seed', type=int, default=0)
    add_option('--device', type=str, default="cuda")
    add_option('--cache_dir', type=str, default="../cache")
    # `args` stays a module-level name: code elsewhere in this file reads it.
    args = cli.parse_args()

    # Point the XDG cache at the chosen directory and make sure it exists.
    os.environ["XDG_CACHE_HOME"] = args.cache_dir
    cache_missing = not os.path.exists(args.cache_dir)
    if cache_missing:
        os.makedirs(args.cache_dir)
    print(f"Using cache dir {args.cache_dir}")

    # Seed every RNG in use so the shuffle and sampling are repeatable.
    for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
        seed_fn(args.seed)

    print(f'Loading dataset {args.dataset}...')
    # Text column per dataset; datasets without an entry pass None as the key.
    dataset_keys = {'xsum': 'document', 'squad': 'context', 'writing': 'document'}
    text_key = dataset_keys.get(args.dataset)
    data = generate_data(args, args.dataset, text_key)
    save_data(args.output_file, args, data)
