import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
import numpy as np
import os
import argparse
from datasets import load_dataset, disable_caching
from tqdm.auto import tqdm
from utils import get_dataset
import json
from datetime import datetime
from transformers import set_seed
import outlines
from pydantic import BaseModel, Field, constr
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.samplers import greedy, multinomial
import ctrlg

class Gsm8kResponse(BaseModel):
    """Structured-output schema for GSM8K answers (``--instruction_format json``).

    main() dumps ``model_json_schema()`` of this class and compiles it with
    outlines' ``build_regex_from_schema``, so editing the fields or their
    constraints changes the regex used for constrained decoding.
    """
    # Free-text reasoning, limited to at most 1000 characters.
    reason: constr(max_length=1000)
    # Final numeric answer.
    # NOTE(review): `pattern` is a string constraint; on an `int` field pydantic
    # may ignore or reject it depending on version — confirm it actually reaches
    # the JSON schema / generated regex as intended.
    answer: int = Field(pattern=r'[1-9][0-9]{0,9}')


class SummevalResponse(BaseModel):
    """Structured-output schema for SummEval ratings (``--instruction_format json``).

    main() dumps ``model_json_schema()`` of this class and compiles it with
    outlines' ``build_regex_from_schema``, so editing the fields or their
    constraints changes the regex used for constrained decoding.
    """
    # Free-text analysis of the summary, limited to at most 2000 characters.
    analysis: constr(max_length=2000)
    # Rating on a 1-5 scale.
    # NOTE(review): `pattern` is a string constraint; on an `int` field pydantic
    # may ignore or reject it depending on version — confirm it actually reaches
    # the JSON schema / generated regex as intended.
    rating: int = Field(pattern=r'[1-5]')

# ctrlg HMM checkpoint for each short --model name; main() passes the selected
# entry to ctrlg.HMM.from_pretrained for --control_type decog.
# NOTE(review): every entry is an empty placeholder in this copy — presumably a
# local path or hub id must be filled in before running decog.
HMM_MODEL = dict(
    llama8b="",
    llama70b="",
    qwen7b="",
)

# Short --model names mapped to the Hugging Face checkpoints that main() loads.
_MODEL_IDS = {
    "llama8b": "meta-llama/Llama-3.1-8B-Instruct",
    "llama70b": "meta-llama/Llama-3.1-70B-Instruct",
    "qwen7b": "Qwen/Qwen2.5-7B-Instruct",
}

# Number of examples sampled (without replacement) when --test_run is set.
_TEST_RUN_SIZE = 100

# Generation length cap shared by the outlines and ctrlg decoding paths.
_MAX_NEW_TOKENS = 500

# Regexes for --control_type structured with --instruction_format nl.
_NL_ANSWER_REGEX = {
    "gsm8k": r'[\w \d\.\*\-\=\+\,\?\/\n]{50,1000}The final answer is [\$\,\d]{0,20}\.',
    "summeval": r'[\w \d\.\*\-\=\+\,\?\/\n]{50,1000}Rating: [1-5]\.',
}

# Pydantic schemas for --control_type structured with --instruction_format json.
_JSON_SCHEMAS = {
    "gsm8k": Gsm8kResponse,
    "summeval": SummevalResponse,
}


def _is_chat_model(model_name):
    """Return True for model ids whose chat template this script relies on."""
    return "Llama" in model_name or "Qwen" in model_name


def _format_prompt(tokenizer, model_name, text):
    """Render `text` as a single user turn with the chat template.

    Models outside the Llama/Qwen families get the raw text unchanged.
    """
    if not _is_chat_model(model_name):
        return text
    conversation = [{"role": "user", "content": text}]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        return_dict=False,
    )


def _load_split(args):
    """Load the split selected by `args`, subsampled when args.test_run is set.

    Raises:
        ValueError: if args.split is neither "train" nor "test" (non-summeval).
    """
    dataset = get_dataset(args.dataset, args.instruction_format, args.num_shots, score_type=args.score_type)
    if args.dataset == "summeval":
        # For summeval the object returned by get_dataset is used as-is.
        split = dataset
    elif args.split in ("train", "test"):
        split = dataset[args.split]
    else:
        raise ValueError(f"Invalid split: {args.split!r}")

    if args.test_run:
        # Sample without replacement so the smoke test has no duplicated rows,
        # and never request more rows than the split contains.
        sample_size = min(_TEST_RUN_SIZE, len(split))
        split = split.select(np.random.choice(len(split), sample_size, replace=False))
    return split


def _output_file_name(args):
    """Build the timestamped JSONL output path that encodes the run configuration."""
    formatted_time = datetime.now().strftime("%b%d_%H_%M")
    test_run_string = "_test_" if args.test_run else ""
    # Must be an f-string: the alpha value is part of the file name.
    alpha_str = f"alpha{args.alpha}-" if args.control_type == "decog" else ""
    model_name = os.path.basename(args.model)
    if args.dataset == "summeval":
        stem = (f'{test_run_string}{args.dataset}-{args.score_type}-{args.control_type}-'
                f'{args.instruction_format}-{alpha_str}{model_name}-{formatted_time}')
    else:
        stem = (f'{test_run_string}{args.dataset}-{args.split}-{args.control_type}-'
                f'{args.instruction_format}-{args.num_shots}shot-{alpha_str}{model_name}-{formatted_time}')
    return os.path.join(args.output_dir, stem + '.jsonl')


def _build_outlines_generator(model, args):
    """Create the greedy outlines generator for the unstructured/structured control types.

    Raises:
        ValueError: for an unknown control type, format type, or dataset.
    """
    if args.control_type == "unstructured":
        return outlines.generate.text(model, sampler=greedy())
    if args.control_type != "structured":
        raise ValueError(f"Invalid control type: {args.control_type!r}")

    if args.instruction_format == "json":
        schema_model = _JSON_SCHEMAS.get(args.dataset)
        if schema_model is None:
            raise ValueError(f"No JSON schema for dataset {args.dataset!r}")
        schema_regex = build_regex_from_schema(json.dumps(schema_model.model_json_schema()))
        generator = outlines.generate.regex(model, schema_regex, sampler=greedy())
    elif args.instruction_format == "nl":
        if args.dataset == "gsm8k":
            print("start compiling regex")
        regex = _NL_ANSWER_REGEX.get(args.dataset)
        if regex is None:
            raise ValueError(f"No answer regex for dataset {args.dataset!r}")
        generator = outlines.generate.regex(model, regex, sampler=greedy())
    else:
        raise ValueError("Invalid format type")
    print("done with regex")
    return generator


def _answer_phrases(dataset_name, model_name):
    """Return the phrase groups the decog constraint requires in the output.

    Each inner list is compiled into one Aho-Corasick pattern DFA, and all groups
    are intersected in _build_dfa_model. The wording is presumably tuned to how
    each model family phrases its final answer.

    Raises:
        ValueError: for an unsupported dataset or model family.
    """
    if dataset_name == "gsm8k":
        if "Llama" in model_name:
            return [['The final answer is', 'The answer is']]
        if "Qwen" in model_name:
            return [['The final answer is: **', 'The answer is: **']]
        raise ValueError("Invalid model")
    if dataset_name == "summeval":
        if "Llama" in model_name:
            return [["Rating: ", " Rating: "]]
        if "Qwen" in model_name:
            return [["Rating: **"]]
        raise ValueError("Invalid model")
    raise ValueError(f"Invalid dataset: {dataset_name!r}")


def _build_dfa_model(tokenizer, phrases_list, vocab_size, eos_token_id):
    """Compile the phrase groups plus an EOS requirement into one ctrlg DFAModel on CUDA."""
    ac_builder = ctrlg.AhoCorasickBuilder(vocab_size)
    eos_builder = ctrlg.EOSBuilder(vocab_size, eos_token_id)
    dfa_graphs = []
    for phrases in phrases_list:
        patterns = []
        for phrase in phrases:
            patterns.extend(ctrlg.generate_patterns(tokenizer, phrase))
        dfa_graphs.append(ac_builder.build(patterns))
    dfa_graphs.append(eos_builder.build())
    product = ctrlg.DFA_prod(dfa_graphs, mode='intersection')
    return ctrlg.DFAModel(product, vocab_size).to('cuda')


def _batch_rows(batch):
    """Turn a column-major batch (column name -> list of values) into per-row dicts."""
    columns = list(batch)
    return [dict(zip(columns, values)) for values in zip(*(batch[c] for c in columns))]


def _write_record(f, dataset_name, example, output):
    """Append one JSONL line for `example` and the model `output` (gsm8k/summeval only)."""
    if dataset_name == "gsm8k":
        record = {
            "question": example["question"],
            "label": example["label"],
            "output": output,
        }
    elif dataset_name == "summeval":
        record = {
            "text": example["text"],
            "summary": example["machine_summary"],
            "label": example["label"],
            "output": output,
        }
    else:
        return
    f.write(json.dumps(record) + "\n")


def _run_outlines(generator, tokenizer, split, args, f):
    """Generate with an outlines generator in batches of args.batch_size and write JSONL to `f`."""
    for start in tqdm(range(0, len(split), args.batch_size)):
        batch = split[start:start + args.batch_size]
        prompts = [_format_prompt(tokenizer, args.model, fp) for fp in batch["full_prompt"]]
        outputs = generator(prompts, max_tokens=_MAX_NEW_TOKENS)
        for row, output in zip(_batch_rows(batch), outputs):
            _write_record(f, args.dataset, row, output)


def _run_decog(model, tokenizer, hmm_model, dfa_model, split, args, f):
    """Generate one example at a time under the ctrlg HMM/DFA constraint and write JSONL to `f`."""
    min_new_tokens = 1
    beam_size = 1
    # A rendered chat template already contains the model's special tokens
    # (e.g. Llama-3's <|begin_of_text|>); encoding it with add_special_tokens=True
    # would prepend a second BOS.
    # NOTE(review): the outlines path passes the same templated string to
    # outlines, which may tokenize it with special tokens — confirm.
    add_special = not _is_chat_model(args.model)

    for example in tqdm(split):
        fp = example["full_prompt"]
        prompt = _format_prompt(tokenizer, args.model, fp)

        # Raw prompt tokens (no chat markup) are given to ctrlg as prefix_ids.
        # Works for llama3 but not necessarily for other LLMs.
        prefix_ids = tokenizer.encode(fp, add_special_tokens=False)
        prompt_ids = tokenizer.encode(prompt, add_special_tokens=add_special)

        constraint_logits_processor = ctrlg.ConstraintLogitsProcessor(
            hmm_model, dfa_model,
            min_new_tokens, _MAX_NEW_TOKENS,
            prompt_ids, prefix_ids=prefix_ids, suffix_ids=[],
            alpha=args.alpha)
        constraint_logits_processor.hmm_batch_size = beam_size
        constraint_logits_processor.temperature = 1.0

        input_ids = torch.tensor([prompt_ids], device='cuda')
        output_ids = model.generate(
            input_ids=input_ids,
            # Single unpadded sequence: every position is attended to.
            attention_mask=torch.ones_like(input_ids),
            # Greedy decoding; temperature is not passed because it is ignored
            # (and only warned about) when do_sample=False.
            do_sample=False,
            num_return_sequences=beam_size,
            min_new_tokens=min_new_tokens,
            max_new_tokens=_MAX_NEW_TOKENS,
            logits_processor=LogitsProcessorList([constraint_logits_processor]),
        )
        # Decode only the generated continuation. Searching the full decoded text
        # for "assistant\n" breaks when the prompt itself contains that string.
        output = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True)
        print(output)
        _write_record(f, args.dataset, example, output)


def main(args):
    """Run generation for one (model, dataset, control type) configuration.

    Loads the requested split. It then generates either with a greedy outlines
    generator ("unstructured" / "structured") or with transformers plus a ctrlg
    HMM/DFA constraint ("decog"). One JSON line per example is written to a
    timestamped file in args.output_dir. Note that args.model is rewritten in
    place from the short name to the Hugging Face model id.

    Raises:
        KeyError: if args.model is not a key of HMM_MODEL.
        ValueError: for unsupported option combinations, or for decog when the
            HMM_MODEL entry for the model is empty.
    """
    set_seed(42)

    hmm_path = HMM_MODEL[args.model]
    args.model = _MODEL_IDS.get(args.model, args.model)
    is_decog = args.control_type == "decog"
    if is_decog and not hmm_path:
        # Fail before loading the (possibly 70B) LLM rather than after it.
        raise ValueError(f"No HMM checkpoint configured in HMM_MODEL for {args.model!r}")

    split = _load_split(args)

    generator = None
    hmm_model = None
    if is_decog:
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        hmm_model = ctrlg.HMM.from_pretrained(hmm_path).to('cuda')
    else:
        model = outlines.models.transformers(
            args.model,
            model_kwargs={
                'device_map': "auto",
                'torch_dtype': torch.bfloat16,
                'trust_remote_code': True,
            },
        )
        generator = _build_outlines_generator(model, args)

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'

    dfa_model = None
    if is_decog:
        # The constraint depends only on dataset and model, so compile it once
        # instead of once per example.
        dfa_model = _build_dfa_model(
            tokenizer,
            _answer_phrases(args.dataset, args.model),
            hmm_model.vocab_size,
            hmm_model.eos_token_id,
        )

    file_name = _output_file_name(args)
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)

    with open(file_name, 'w', encoding="utf-8") as f:
        if is_decog:
            _run_decog(model, tokenizer, hmm_model, dfa_model, split, args, f)
        else:
            _run_outlines(generator, tokenizer, split, args, f)


def parse_args(argv=None):
    """Parse the command-line options for this generation script.

    Args:
        argv: Optional list of argument strings. When None (the default),
            argparse reads ``sys.argv[1:]``, which keeps the original CLI
            behaviour. Passing a list lets notebooks and tests call this
            without touching the real command line.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='qwen7b', choices=["llama8b", "llama70b", "qwen7b"], help='model name')
    parser.add_argument('--dataset', type=str, default="gsm8k", choices=["gsm8k", "summeval"], help='dataset')
    parser.add_argument('--split', type=str, default='test', choices=['train','test'], help='split')
    parser.add_argument('--batch_size', type=int, default=16, help='batch size')
    parser.add_argument('--instruction_format', type=str, default="no-format", choices=["nl", "json", "no-format"], help='format type')
    parser.add_argument('--control_type', type=str, default="decog", choices=["structured", "unstructured", "decog"], help='control type')
    parser.add_argument('--score_type', type=str, default="coherence", choices=["coherence", "fluency", "relevance", "consistency"], help='score type')
    parser.add_argument('--num_shots', type=int, default=0, help='number of few-shot examples')
    parser.add_argument('--alpha', type=float, default=1.0, help='alpha')
    parser.add_argument('--output_dir', type=str, default="", help='output directory')
    parser.add_argument('--test_run', action='store_true', help='test run')
    return parser.parse_args(argv)

if __name__ == '__main__':
    # Script entry point: parse the CLI and hand the namespace straight to main().
    main(parse_args())
