import os
import random
import torch
import argparse
import numpy as np
import pandas as pd

from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

from utils import PHD


# Smallest subsample size handed to the PHD solver as ``min_points`` in get_phd_single.
MIN_SUBSAMPLE = 40
# get_phd_single divides (token count - MIN_SUBSAMPLE) by this value to obtain the
# solver's ``point_jump`` step between subsample sizes.
INTERMEDIATE_POINTS = 7


# Set random seed for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, the current CUDA device and all CUDA devices), and records the seed
    in the ``PYTHONHASHSEED`` environment variable.

    NOTE(review): assigning PYTHONHASHSEED at runtime does not change string
    hash randomisation of the interpreter that is already running; it only
    affects child processes started afterwards.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_rng in seeders:
        seed_rng(seed)


def get_phd_single(text, solver, tokenizer, model, device, max_length=512, prompt_number=0):
    """Estimate the PHD intrinsic dimension of one text's token embeddings.

    ``text`` is tokenized (truncated to ``max_length`` tokens) and run through
    ``model`` without gradients. The per-token rows of the model's first output,
    for the first sequence in the batch, form the point cloud that is passed to
    ``solver.fit_transform``.

    Args:
        text: Input string (may already include a prompt prefix).
        solver: Object exposing ``fit_transform(X, min_points=..., max_points=...,
            point_jump=...)``, e.g. ``utils.PHD``.
        tokenizer: Tokenizer callable that returns PyTorch tensors for ``model``.
        model: Model whose ``model(**inputs)[0]`` is assumed to be the last hidden
            state with shape (batch, seq, dim) -- TODO confirm for non-AutoModel heads.
        device: Device the tokenized inputs are moved to.
        max_length: Truncation length in tokens, special tokens included.
        prompt_number: Extra leading rows to skip. In total the first
            ``prompt_number + 2`` token rows are discarded (see the note in the body).

    Returns:
        Whatever ``solver.fit_transform`` returns for the point cloud.

    Raises:
        ValueError: If fewer than ``MIN_SUBSAMPLE + INTERMEDIATE_POINTS`` token rows
            remain. That would make the subsampling step zero or negative.
    """
    inputs = tokenizer(text, truncation=True, max_length=max_length, return_tensors="pt").to(device)
    with torch.no_grad():
        outp = model(**inputs)

    hidden = outp[0][0]
    # NumPy has no bfloat16 dtype, so upcast that case; other dtypes pass through unchanged.
    if hidden.dtype == torch.bfloat16:
        hidden = hidden.float()
    points = hidden.cpu().numpy()

    # Drop the leading special token (<CLS>/<BOS>), then ``prompt_number + 1`` more rows.
    # The trailing token is NOT dropped. That suits tokenizers that append nothing at the
    # end (e.g. LLaMA's BOS-only scheme), but it keeps <SEP> for BERT-style tokenizers.
    # NOTE(review): with prompt_number=0 this also discards the first text token. The +1
    # appears to compensate for callers that pass ``len(prompt_ids) - 2`` with a BOS-only
    # tokenizer. Confirm before changing.
    points = points[1:]
    points = points[prompt_number + 1:]

    mx_points = points.shape[0]
    mn_points = MIN_SUBSAMPLE
    step = (mx_points - mn_points) // INTERMEDIATE_POINTS
    if step <= 0:
        # A non-positive point_jump cannot yield an increasing series of subsample
        # sizes from min_points upward, so fail loudly instead of calling the solver.
        raise ValueError(
            f"Text too short for PHD estimation: {mx_points} usable token rows, "
            f"need at least {MIN_SUBSAMPLE + INTERMEDIATE_POINTS}."
        )

    return solver.fit_transform(points, min_points=mn_points, max_points=mx_points - step,
                                point_jump=step)


def truncate_text(text, tokenizer, truncation_length):
    """Return ``text`` cut down to at most ``truncation_length`` tokenizer tokens.

    The text is split with ``tokenizer.tokenize``. The kept prefix is turned back
    into a string with ``tokenizer.convert_tokens_to_string``, so the result is the
    detokenised prefix rather than a character slice of the input.
    """
    kept_tokens = tokenizer.tokenize(text)[:truncation_length]
    return tokenizer.convert_tokens_to_string(kept_tokens)


# Entry point: compute PHD intrinsic-dimension estimates for each text (inference only, no training)
def main(args):
    """Compute PHD estimates for each text in a JSON-lines file, with and without prompts.

    Reads ``args.data_dir/args.filename`` (JSON lines with a ``text`` column) and
    truncates each text to ``args.truncation_length`` tokens. It then adds two columns:

    * ``dim``: PHD of the truncated text alone.
    * ``dim_prompt``: one PHD per entry of ``prompts``, computed on ``prompt + text``.
      The prompt's own token rows are skipped via ``prompt_number``.

    The frame is written to
    ``<data_dir>/<filename stem>_dim_truncate_<truncation_length>_temp.jsonl``.

    ``args.max_samples`` (optional; defaults to 2, the script's earlier hard-coded
    limit) caps how many rows are processed. ``None`` or a value <= 0 processes
    every row.
    """
    # Takes effect only if CUDA has not been initialised yet in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda_devices

    set_seed(args.seed)

    df = pd.read_json(os.path.join(args.data_dir, args.filename), lines=True)
    max_samples = getattr(args, "max_samples", 2)
    if max_samples is not None and max_samples > 0:
        df = df.head(max_samples)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    tokenizer = AutoTokenizer.from_pretrained(args.base_lm_name, cache_dir=args.cache_dir)
    model = AutoModel.from_pretrained(args.base_lm_name, cache_dir=args.cache_dir)
    model.to(device)
    model.eval()  # inference only

    solver = PHD(alpha=1.0, metric='euclidean', n_points=9)

    prompts = [
        'Identifying text generated by large language models (LLMs) like GPT can be challenging, but there are some strategies and telltale signs that can help. Here are some suggestions: Repetition and Redundancy and Overly Polished or Generic Language. ',
        'LLM-generated text may contain repetitive phrases or ideas. Although models are improving, they might still circle back to the same points or restate them in slightly different ways. ',
        'LLMs often produce text that is well-structured and free from obvious errors, sometimes appearing overly polished. The language may also be somewhat generic, lacking the nuance, style, or personal touch that a human writer might include. ',
        'The tone or style of LLM-generated text might fluctuate unexpectedly. For instance, a paragraph may shift from formal to informal language or from being highly detailed to overly vague. ',
        'While LLMs are generally accurate, they can make mistakes that are unusual for humans, such as misinterpreting context, creating fictional facts, or making nonsensical statements. ',
        'LLMs can generate text that seems insightful but lacks true depth. The text may superficially address a topic without providing meaningful analysis or understanding. ',
        'The phrasing in LLM-generated text might sound slightly off, either too formal or too casual, or use uncommon word combinations that a human writer might not choose. ',
        'LLMs might draw connections between ideas that don\'t logically follow from each other. This can result in text that feels disjointed or where the flow of argument is not clear. ',
        'Some LLMs may overuse certain phrases or sentence structures, especially ones that were prominent in their training data. ',
        'Use AI detection tools designed to identify LLM-generated content. These tools analyze patterns in the text to predict whether it was written by a human or a machine. ',
        'LLMs may provide inconsistent answers when asked the same question in different ways. A human writer is more likely to maintain consistent opinions or facts across different contexts. ',
        'Using a combination of these strategies can help you more reliably identify LLM-generated text, although it’s worth noting that as AI models improve, the distinction between human and machine-generated text is becoming increasingly subtle. ',
    ]
    print('Number of prompts: ', len(prompts))

    # Prompt token counts do not depend on the row, so compute them once up front.
    # The "- 2" pairs with the "+ 2" rows that get_phd_single skips at the start.
    prompt_numbers = [
        tokenizer(prompt, return_tensors="pt", truncation=True, max_length=1024)['input_ids'].shape[1] - 2
        for prompt in prompts
    ]

    dims = []
    dims_prompt = []
    for raw_text in tqdm(df['text'], total=df.shape[0]):
        text = truncate_text(raw_text, tokenizer, args.truncation_length)

        dims.append(get_phd_single(
            text=text,
            solver=solver,
            model=model,
            tokenizer=tokenizer,
            device=device,
            max_length=args.max_length + 2,
            prompt_number=0,
        ))

        # The tokenizer budget is widened by the prompt length so the text keeps the
        # same room. Passing prompt_number excludes the prompt's token rows from the
        # point cloud, so only the (prompt-conditioned) text embeddings are measured.
        dims_prompt.append([
            get_phd_single(
                text=prompt + text,
                solver=solver,
                model=model,
                tokenizer=tokenizer,
                device=device,
                max_length=args.max_length + 2 + prompt_number,
                prompt_number=prompt_number,
            )
            for prompt, prompt_number in zip(prompts, prompt_numbers)
        ])

    df['dim'] = dims
    df['dim_prompt'] = dims_prompt

    # splitext keeps dots inside the stem ("data.v2.jsonl" -> "data.v2") and handles
    # a leading "./", where split('.')[0] would return an empty string.
    save_filename_prefix = os.path.splitext(args.filename)[0]
    out_path = os.path.join(
        args.data_dir,
        f'{save_filename_prefix}_dim_truncate_{args.truncation_length}_temp.jsonl',
    )
    df.to_json(out_path, orient='records', lines=True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--cuda_devices", type=str, default='0', help="CUDA devices to use")
    parser.add_argument("--data_dir", type=str, default='data/generated_dataset_temp', help="Data directory")
    parser.add_argument("--filename", type=str, default='human_text.jsonl', help="Data filename")
    parser.add_argument("--base_lm_name", type=str, default='meta-llama/Meta-Llama-3-8B-Instruct', help="Base LM model name")
    parser.add_argument("--cache_dir", type=str, default='/data/mjmao/ood/hf_models', help="Cache directory for models")
    parser.add_argument("--max_length", type=int, default=1000, help="Max length of input text")
    parser.add_argument("--truncation_length", type=int, default=150, help="Truncation number of tokens")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--max_samples", type=int, default=2,
                        help="Number of rows to process (<= 0 processes every row)")

    args = parser.parse_args()
    main(args)
