import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList
import numpy as np
import os
import argparse
from datasets import load_dataset, disable_caching
from tqdm.auto import tqdm
from utils import get_dataset
import json
from datetime import datetime
from transformers import set_seed
import outlines
from pydantic import BaseModel, Field, constr, create_model
from outlines.fsm.json_schema import build_regex_from_schema
from outlines.samplers import greedy, multinomial
from typing import List
import time
import ctrlg


def generate_ace05_response_model(roles: List[str], max_length: int = 20):
    """Create the dynamic pydantic model used to constrain one example's output.

    Args:
        roles: Field names of the generated model; one required field is
            declared per entry.
        max_length: ``max_length`` constraint passed to ``constr`` for every
            field.

    Returns:
        A pydantic model class named ``ArgumentResponse``.
    """
    required_fields = {}
    for role_name in roles:
        # ``...`` (Ellipsis) as the default marks the field as required.
        required_fields[role_name] = (constr(max_length=max_length), ...)
    return create_model('ArgumentResponse', **required_fields)


# HMM checkpoint location for each ``--model`` alias. The selected value is
# handed to ``ctrlg.HMM.from_pretrained`` when ``--control_type decog`` is used.
# The entries are empty placeholders and must be filled in before a decog run.
HMM_MODEL = dict(
    llama8b="",
    llama70b="",
    qwen7b="",
)


# Maps the short ``--model`` aliases to the Hugging Face checkpoint names.
_HF_MODEL_NAMES = {
    "llama8b": "meta-llama/Llama-3.1-8B-Instruct",
    "llama70b": "meta-llama/Llama-3.1-70B-Instruct",
    "qwen7b": "Qwen/Qwen2.5-7B-Instruct",
}

# Generation bounds shared by all control types.
_MIN_NEW_TOKENS = 1
_MAX_NEW_TOKENS = 100


def _build_chat_prompt(tokenizer, model_name, user_content):
    """Render ``user_content`` as a single-turn chat prompt string.

    Raises:
        ValueError: If ``model_name`` does not contain ``"Instruct"``.
    """
    if "Instruct" not in model_name:
        raise ValueError("Invalid model")
    conversation = [{
        "role": "user",
        "content": user_content,
    }]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
        return_dict=False,
    )


def _write_record(out_file, label, output):
    """Append one ``{"label", "output"}`` JSON line to ``out_file``."""
    line = {
        "label": label,
        "output": output,
    }
    out_file.write(json.dumps(line) + "\n")


def _run_unstructured(args, model, tokenizer, split, file_name):
    """Greedy free-text generation with outlines, processed in batches."""
    generator = outlines.generate.text(model, sampler=greedy())
    num_batches = int(np.ceil(len(split) / args.batch_size))
    with open(file_name, 'w') as f:
        for batch_idx in tqdm(range(num_batches)):
            start, end = batch_idx * args.batch_size, (batch_idx + 1) * args.batch_size
            batch = split[start:end]
            # Every supported model is chat-tuned, so the chat template is
            # applied for all of them (not only for Llama); otherwise the
            # generator would be called with an empty prompt list.
            prompts = [
                _build_chat_prompt(tokenizer, args.model, fp)
                for fp in batch["full_prompt"]
            ]
            outputs = generator(prompts, max_tokens=_MAX_NEW_TOKENS)
            for label, output in zip(batch["label"], outputs):
                _write_record(f, label, output)


def _run_structured(args, model, tokenizer, split, file_name):
    """Regex-constrained JSON generation with a per-example schema.

    Raises:
        NotImplementedError: For any instruction format other than ``json``.
    """
    if args.instruction_format != "json":
        raise NotImplementedError("Not implemented yet")

    with open(file_name, 'w') as f:
        for idx in tqdm(range(len(split))):
            example = split[idx]
            label = example["label"]
            prompt = _build_chat_prompt(tokenizer, args.model, example["full_prompt"])

            # The schema depends on the example's valid roles, so the regex and
            # the generator have to be rebuilt for every example.
            response_model = generate_ace05_response_model(label["valid_roles"])
            print(label["trigger"][2])      # print event type
            schema_regex = build_regex_from_schema(
                json.dumps(response_model.model_json_schema())
            )
            generator = outlines.generate.regex(model, schema_regex, sampler=greedy())

            output = generator(prompt, max_tokens=_MAX_NEW_TOKENS)
            _write_record(f, label, output)


def _build_decog_patterns(tokenizer, model_name, valid_roles):
    """Build the pattern list given to ``ctrlg.WildcardTrieBuilder.build``.

    The list alternates a tokenized role header with a ``(1, 5)`` tuple and
    ends with an EOS pattern.

    Raises:
        ValueError: If ``model_name`` contains neither ``"Llama"`` nor ``"Qwen"``.
    """
    # (1, 5) is passed through to the trie builder unchanged; presumably a
    # wildcard span of 1 to 5 free tokens for the argument text -- confirm
    # against the ctrlg implementation in use.
    wildcard = (1, 5)
    patterns = []
    if "Llama" in model_name:
        for i_role, role in enumerate(valid_roles):
            header = f'The {role} is:' if i_role == 0 else f'\nThe {role} is:'
            patterns.append(tokenizer.encode(header, add_special_tokens=False))
            patterns.append(wildcard)
        patterns.append([tokenizer.eos_token_id])
    elif "Qwen" in model_name:
        for i_role, role in enumerate(valid_roles):
            if i_role == 0:
                header = f'### Summary:\nThe {role} is: "'
            else:
                # Closes the previous quoted argument before opening the next.
                header = f'"\nThe {role} is: "'
            patterns.append(tokenizer.encode(header, add_special_tokens=False))
            patterns.append(wildcard)
        ending_pattern = tokenizer.encode(
            f'"{tokenizer.eos_token}', add_special_tokens=False
        )
        patterns.append(ending_pattern)
    else:
        raise ValueError("Invalid model")
    return patterns


def _run_decog(args, model, tokenizer, split, file_name, hmm_path):
    """HMM/DFA-constrained greedy decoding via ctrlg, one example at a time."""
    hmm_model = ctrlg.HMM.from_pretrained(hmm_path).to('cuda')
    vocab_size = hmm_model.vocab_size
    wc_builder = ctrlg.WildcardTrieBuilder(vocab_size)

    beam_size = 1
    temperature = 0.0
    assistant_marker = "assistant\n"

    with open(file_name, 'w') as f:
        for idx in tqdm(range(len(split))):
            example = split[idx]
            fp = example["full_prompt"]
            label = example["label"]
            prompt = _build_chat_prompt(tokenizer, args.model, fp)

            prefix_ids = tokenizer.encode(fp, add_special_tokens=False)
            suffix_ids = []
            # The chat template already rendered the special tokens into the
            # text; adding them again would duplicate BOS for tokenizers that
            # prepend one.
            prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)

            patterns = _build_decog_patterns(tokenizer, args.model, label["valid_roles"])
            dfa_graphs = wc_builder.build([patterns])
            dfa_model = ctrlg.DFAModel(dfa_graphs, vocab_size).to('cuda')

            # initialize the constraints logits processor
            constraint_logits_processor = ctrlg.ConstraintLogitsProcessor(
                hmm_model, dfa_model,
                _MIN_NEW_TOKENS, _MAX_NEW_TOKENS,
                prompt_ids, prefix_ids=prefix_ids, suffix_ids=suffix_ids,
                alpha=args.alpha)

            # set the hmm_batch_size & temperature
            constraint_logits_processor.hmm_batch_size = beam_size
            constraint_logits_processor.temperature = 1.0

            input_ids = torch.tensor([prompt_ids], device='cuda')
            output_ids = model.generate(
                temperature=temperature,
                input_ids=input_ids,
                do_sample=False,
                num_return_sequences=beam_size,
                min_new_tokens=_MIN_NEW_TOKENS, max_new_tokens=_MAX_NEW_TOKENS,
                logits_processor=LogitsProcessorList([constraint_logits_processor]),
            )

            decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            marker_pos = decoded.find(assistant_marker)
            if marker_pos != -1:
                output = decoded[marker_pos + len(assistant_marker):]
            else:
                # Marker missing: decode only the tokens after the prompt
                # instead of slicing at a bogus offset derived from -1.
                output = tokenizer.decode(
                    output_ids[0][len(prompt_ids):], skip_special_tokens=True
                )
            print(output)

            _write_record(f, label, output)


def main(args):
    """Run generation over a dataset split and write the results as JSONL.

    Reads ``model``, ``dataset``, ``split``, ``batch_size``,
    ``instruction_format``, ``control_type``, ``num_shots``, ``alpha``,
    ``output_dir`` and ``test_run`` from ``args``. ``args.model`` is replaced
    in place by the full Hugging Face checkpoint name. One JSON object with
    the keys ``label`` and ``output`` is written per example to a timestamped
    file inside ``args.output_dir``.

    Args:
        args: Namespace as produced by ``parse_args``.

    Raises:
        KeyError: If ``args.model`` is not a key of ``HMM_MODEL``.
        ValueError: If ``control_type`` is ``decog`` but no HMM checkpoint is
            configured, or if the model is not supported by the chosen mode.
        NotImplementedError: For ``structured`` control with a format other
            than ``json``.
    """
    set_seed(42)

    hmm_path = HMM_MODEL[args.model]
    # Fail before loading any large model if the HMM path is still a placeholder.
    if args.control_type == "decog" and not hmm_path:
        raise ValueError(
            f"No HMM checkpoint configured for model '{args.model}'; "
            "fill in HMM_MODEL before running with --control_type decog"
        )
    args.model = _HF_MODEL_NAMES.get(args.model, args.model)

    split = get_dataset(args.dataset, args.instruction_format, args.num_shots, split=args.split)

    if args.test_run:
        # Sample without replacement so the smoke-test subset has no duplicate
        # examples, and never ask for more examples than the split contains.
        subset_size = min(100, len(split))
        split = split.select(np.random.choice(len(split), size=subset_size, replace=False))

    if args.control_type == "decog":
        # decog drives ``model.generate`` directly, so it needs the raw HF model.
        model = AutoModelForCausalLM.from_pretrained(
            args.model,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True
        )
    else:
        model = outlines.models.transformers(
            args.model,
            model_kwargs={
                'device_map': "auto",
                'torch_dtype': torch.bfloat16,
                'trust_remote_code': True
            }
        )

    tokenizer = AutoTokenizer.from_pretrained(args.model)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'

    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    formatted_time = datetime.now().strftime("%b%d_%H_%M")
    file_name = os.path.join(
        args.output_dir,
        f'{args.dataset}-{args.split}-{args.control_type}-wc-{args.instruction_format}'
        f'-{args.num_shots}shot-alpha{args.alpha}-{os.path.basename(args.model)}-{formatted_time}.jsonl'
    )

    if args.control_type == "unstructured":
        _run_unstructured(args, model, tokenizer, split, file_name)
    elif args.control_type == "structured":
        _run_structured(args, model, tokenizer, split, file_name)
    elif args.control_type == "decog":
        _run_decog(args, model, tokenizer, split, file_name, hmm_path)


def parse_args():
    """Parse the command-line options of this script from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding ``model``, ``dataset``, ``split``,
        ``batch_size``, ``instruction_format``, ``control_type``,
        ``num_shots``, ``alpha``, ``output_dir`` and ``test_run``.
    """
    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = (
        ('--model', dict(type=str, default='qwen7b', choices=["llama8b", "llama70b", "qwen7b"], help='model name')),
        ('--dataset', dict(type=str, default="ace05-en", choices=["ace05-en"], help='dataset')),
        ('--split', dict(type=str, default='test', choices=['train', 'dev', 'test'], help='split')),
        ('--batch_size', dict(type=int, default=16, help='batch size')),
        ('--instruction_format', dict(type=str, default="no-format", choices=["nl", "json", "no-format"], help='format type')),
        ('--control_type', dict(type=str, default="decog", choices=["structured", "unstructured", "decog"], help='control type')),
        ('--num_shots', dict(type=int, default=0, help='number of few-shot examples')),
        ('--alpha', dict(type=float, default=1.0, help='alpha')),
        ('--output_dir', dict(type=str, default="", help='output directory')),
        ('--test_run', dict(action='store_true', help='test run')),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

if __name__ == '__main__':
    # Script entry point: parse the CLI flags and pass them straight to main().
    main(parse_args())
