import re
import os
import warnings
import json

# These environment variables are set before torch / transformers are imported
# further down, so the libraries see them when they initialise.
# Presumably disables the Hugging Face tokenizers' internal parallelism — TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Restricts CUDA to the device with index 2.
# NOTE(review): `device` below defaults to "cpu", so this only matters once a
# CUDA device is selected.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

# Suppress every Python warning raised from this point on.
warnings.filterwarnings("ignore")

import torch
import pickle
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from outlines.types import Regex
from outlines.models import transformers
import outlines

# Device string handed to load_model_and_tokenizer below; the trailing comment
# keeps the GPU alternative at hand.
device = "cpu" #"cuda:0"

# NOTE(review): MODEL_DIR is not referenced anywhere in the visible code — the
# loader call further down uses the separate `model_dir` variable instead.
# Confirm which of the two is meant to be used.
MODEL_DIR = "/home/vcollura/TRIDENT/ilm-master/models/sto_ilm"

def load_model_and_tokenizer(model_dir: str, device: str = None):
    """Load a GPT-2 LM-head model and a matching tokenizer from ``model_dir``.

    The model is built from ``<model_dir>/config.json`` and populated from
    ``<model_dir>/pytorch_model.bin``. The tokenizer is the stock ``"gpt2"``
    tokenizer; if ``<model_dir>/additional_ids_to_tokens.pkl`` exists, the
    tokens it contains are added to the tokenizer and the model's embedding
    matrix is resized to the new vocabulary size.

    Args:
        model_dir: Directory containing the config, weights and optional
            additional-token pickle.
        device: Torch device string. When ``None``, ``"cuda"`` is used if
            available, otherwise ``"cpu"``.

    Returns:
        A ``(model, tokenizer, device)`` tuple. The model has been moved to
        ``device`` and put in eval mode.

    Raises:
        ValueError: If the pickle holds something other than a dict or a list.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    config_path = os.path.join(model_dir, "config.json")
    config = GPT2Config.from_json_file(config_path)

    model = GPT2LMHeadModel(config)
    state_dict = torch.load(os.path.join(model_dir, "pytorch_model.bin"), map_location="cpu")
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    # strict=False tolerates key mismatches, so report them instead of dropping
    # them silently. print() is used because this script filters out warnings.
    if missing:
        print(f"Missing keys when loading state dict: {list(missing)}")
    if unexpected:
        print(f"Unexpected keys when loading state dict: {list(unexpected)}")

    model.to(device)
    model.eval()

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    pkl_path = os.path.join(model_dir, "additional_ids_to_tokens.pkl")
    new_tokens = []
    # token -> ID recorded in the pickle (only known when the pickle is a dict).
    expected_ids = {}
    if os.path.exists(pkl_path):
        # NOTE: pickle.load can execute arbitrary code; only point model_dir at
        # trusted checkpoints.
        with open(pkl_path, "rb") as f:
            additional_tokens = pickle.load(f)

        if isinstance(additional_tokens, dict):
            new_tokens = list(additional_tokens.values())
            # Assumes the dict maps token ID -> token string, as the file name
            # suggests — TODO confirm against the training code.
            expected_ids = {token: token_id for token_id, token in additional_tokens.items()}
        elif isinstance(additional_tokens, list):
            new_tokens = additional_tokens
        else:
            raise ValueError(
                "additional_ids_to_tokens.pkl must contain a dict or a list, "
                f"got {type(additional_tokens).__name__}"
            )

        if new_tokens:
            tokenizer.add_tokens(new_tokens)
            model.resize_token_embeddings(len(tokenizer))
            for token in new_tokens:
                token_id = tokenizer.convert_tokens_to_ids(token)
                print(f"Token: {token} -> ID: {token_id}")
                # A mismatch means the embeddings trained for this token sit at
                # a different row than the one the tokenizer will emit.
                expected = expected_ids.get(token)
                if expected is not None and expected != token_id:
                    print(
                        f"WARNING: token {token!r} was assigned ID {token_id} "
                        f"but the pickle maps it to ID {expected}"
                    )
    else:
        print("No additional_ids_to_tokens.pkl.")

    return model, tokenizer, device

# NOTE(review): empty placeholder — with "" the loader resolves config.json and
# pytorch_model.bin relative to the current working directory. MODEL_DIR above
# looks like the intended value; confirm before running.
model_dir = "" # model dir
# Runs at import time: loads the weights and rebinds the module-level `device`
# to the value returned by the loader.
hf_model, tokenizer, device = load_model_and_tokenizer(model_dir, device)

# `transformers` here is outlines.models.transformers (see the imports), not the
# Hugging Face package; it wraps the loaded model for constrained generation.
model = transformers.Transformers(hf_model, tokenizer)

def extract_concepts(text):
    """Split ``text`` into plain-text segments and ``<|...|>`` tags, in order.

    Tags are matched non-greedily from the leftmost ``<|`` to the nearest
    following ``|>`` and may span newlines. Empty segments are dropped, so
    joining the returned pieces reproduces ``text``.

    Args:
        text: The string to split. Any falsy value yields an empty list.

    Returns:
        A list of non-empty strings alternating between literal text and tags
        as they occur in ``text``.
    """
    if not text:
        return []
    # A capturing group makes re.split keep the tags in the result, giving the
    # whole segmentation in one pass instead of re-matching the remainder after
    # every tag.
    pieces = re.split(r'(<\|.*?\|>)', text, flags=re.DOTALL)
    return [piece for piece in pieces if piece]

def generate_infilled_text(prompt):
    """Fill every supported infill placeholder in ``prompt`` with generated text.

    The prompt is split with ``extract_concepts`` and rebuilt left to right.
    Each ``<|infill_ngram|>``, ``<|infill_word|>`` or ``<|infill_sentence|>``
    placeholder is replaced by text generated from everything assembled so
    far, constrained by that placeholder's regex and token budget. Plain text
    and any other ``<|...|>`` tag are copied through unchanged.

    Args:
        prompt: Text containing zero or more infill placeholders.

    Returns:
        A ``(text, spans)`` tuple: the rebuilt text, and the list of generated
        spans in the order their placeholders appeared.
    """
    # Placeholder -> (regex constraint, max_new_tokens).
    infill_specs = {
        '<|infill_ngram|>': (r"[a-zA-Z0-9' ,.!?]+(?:[ ,.][a-zA-Z0-9' ]+)*", 9),
        '<|infill_word|>': (r"[ ]?[a-zA-Z0-9'.!,?]+[ ]?", 4),
        '<|infill_sentence|>': (r"[a-zA-Z0-9', ]+[.!?]", 16),
    }
    # Generators are built lazily, once per placeholder type, so a prompt that
    # repeats a placeholder does not rebuild the same constraint each time.
    generators = {}

    current_prompt = ""
    spans = []

    for concept in extract_concepts(prompt):
        spec = infill_specs.get(concept)
        if spec is None:
            # Literal text or an unsupported tag: keep it verbatim.
            current_prompt += concept
            continue

        regex_pattern, max_new_tokens = spec
        generator = generators.get(concept)
        if generator is None:
            generator = outlines.Generator(model, Regex(regex_pattern))
            generators[concept] = generator

        generated_text = generator(current_prompt, max_new_tokens=max_new_tokens)
        spans.append(generated_text)
        current_prompt += generated_text

    return current_prompt, spans

def process_files(file_paths):
    """Run infilling over every story in each JSON file and save the results.

    Each input file is expected to hold an iterable of objects with ``id``,
    ``title`` and ``story`` keys. For every input file a
    ``<basename>_outlines.json`` file is written to the current working
    directory, mapping ``str(id)`` to the title, the generated text (title,
    newline, infilled story) and the generated spans.

    Failures are handled best-effort:
      * a file that cannot be read or parsed is reported and skipped;
      * an item that lacks one of the required keys (or is not a mapping) is
        reported and skipped;
      * an item whose generation fails is reported and stored with its
        original story and an empty span list.

    Args:
        file_paths: Iterable of paths to input JSON files.
    """
    for file_path in file_paths:
        print(f"Processing {file_path}...")

        try:
            # JSON is UTF-8 by specification; do not depend on the locale default.
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except Exception as e:
            print(e)
            continue

        output_dict = {}
        for item in data:
            # Validate up front: the fallback below needs these same fields, so
            # a malformed item must not be allowed to raise inside the handler
            # and abort the whole run.
            try:
                item_key = str(item["id"])
                title = item["title"]
                story = item["story"]
            except (KeyError, TypeError) as e:
                print(f"Skipping malformed item {item!r}: {e!r}")
                continue

            try:
                generated_story, spans = generate_infilled_text(story)
                entry = {
                    "title": title,
                    "generated": f"{title}\n{generated_story}",
                    "spans": spans
                }
            except Exception as e:
                print(f"Error processing item {item['id']}: {str(e)}")
                # Fall back to the un-infilled story so the item still appears.
                entry = {
                    "title": title,
                    "generated": f"{title}\n{story}",
                    "spans": []
                }

            output_dict[item_key] = entry

        base_name = os.path.splitext(os.path.basename(file_path))[0]
        output_file = f"{base_name}_outlines.json"

        with open(output_file, 'w') as f:
            json.dump(output_dict, f, indent=2)

        print(f"Saved output to {output_file}")

if __name__ == "__main__":
    # Paths of the input JSON files to process. Empty by default, so running
    # the script unmodified writes no output files.
    file_list = []
    process_files(file_list)