from openai import OpenAI
import hydra
from omegaconf import OmegaConf

import collections
from collections.abc import Iterable
import copy
# import ctranslate2
import datasets
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import random
import sys
import torch
import time
from transformers import AutoModelForCausalLM, LlamaForCausalLM, AutoTokenizer
from policies import LMPolicy
from constrained_samplers import *
import wandb

import re

from lm import (
    conditional_nn_generate,
    predict_error,
    sample_top_p,
    nn_log_probs
)


class WordLengthConstraint:
    """Constraint task: every word of the generation must have a fixed length.

    Attributes:
        word_len: Required number of characters in each word.
        tokenizer: Tokenizer used to encode the prompt and decode generations.
        prompt: Instruction text placed in front of every generation.
        prompt_ids: Token ids of ``prompt`` as returned by ``tokenizer``.
    """

    def __init__(self, word_len, tokenizer):
        self.word_len = word_len
        self.tokenizer = tokenizer
        self.prompt = f"Generate a one-sentence story with only {word_len}-letter words:\n"
        self.prompt_ids = tokenizer(self.prompt)['input_ids']

    def reward_function(self, sequence_ids):
        """Score a full token sequence (prompt followed by generation).

        Args:
            sequence_ids: Token ids whose first ``len(self.prompt_ids)``
                entries are dropped before scoring.

        Returns:
            1 if the generation contains no EOS token and, after removing
            punctuation, every space-separated word has exactly
            ``self.word_len`` characters; 0 otherwise.
        """
        sequence_ids = sequence_ids[len(self.prompt_ids):]
        # A generation that already emitted EOS is scored as a failure.
        if self.tokenizer.eos_token_id in sequence_ids:
            return 0
        s = self.tokenizer.decode(sequence_ids, skip_special_tokens=True)
        # Strip everything that is neither a word character nor whitespace.
        clean_text = re.sub(r'[^\w\s]', '', s)
        # Splitting on a single space means empty pieces (from leading, trailing
        # or doubled spaces) count as words and fail unless word_len is 0.
        for word in clean_text.split(" "):
            if len(word) != self.word_len:
                return 0
        return 1

class LetterConstraint:
    """Constraint task: the generation must never use a forbidden letter.

    Attributes:
        letter: The forbidden string (checked with ``in`` against token text).
        tokenizer: Tokenizer used to encode the prompt and decode single ids.
        prompt: Instruction text placed in front of every generation.
        prompt_ids: Token ids of ``prompt`` as returned by ``tokenizer``.
        id2token: Decoded text of every token id in ``range(vocab_size)``.
        satisfies: Per-token flag, True when the token text lacks ``letter``.
    """

    def __init__(self, letter, tokenizer, vocab_size):
        self.letter = letter
        self.tokenizer = tokenizer
        self.prompt = f"Generate a one-sentence story without using the letter '{self.letter}':\n"
        self.prompt_ids = tokenizer(self.prompt)['input_ids']
        # Decode each id on its own so the check below is a per-token table lookup.
        self.id2token = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
        self.satisfies = [letter not in token_text for token_text in self.id2token]

    def reward_function(self, sequence_ids):
        """Score a full token sequence (prompt followed by generation).

        Args:
            sequence_ids: Token ids whose first ``len(self.prompt_ids)``
                entries are dropped before scoring.

        Returns:
            0 if any remaining token is the EOS token or contains the
            forbidden letter; 1 otherwise (including an empty generation).
        """
        completion_ids = sequence_ids[len(self.prompt_ids):]
        eos_id = self.tokenizer.eos_token_id
        violated = any(
            tok == eos_id or not self.satisfies[tok]
            for tok in completion_ids
        )
        return 0 if violated else 1

def get_generations(fname, prompt_ids, tokenizer):
    """Load pickled generation records and decode each one to text.

    Args:
        fname: Path of a pickle file holding an iterable of records, each
            with an ``'output_ids'`` entry.
        prompt_ids: Prompt token ids; that many leading ids are dropped
            from every record before decoding.
        tokenizer: Object whose ``decode`` turns token ids into a string.

    Returns:
        List of decoded generation strings, in file order. Each string is
        also printed after a separator line.
    """
    # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
    with open(fname, 'rb') as handle:
        records = pickle.load(handle)

    generations = []
    for record in records:
        completion_ids = record['output_ids'][len(prompt_ids):]
        text = tokenizer.decode(completion_ids, skip_special_tokens=True)
        generations.append(text)
        print("-----------------------------------")
        print(text)
    return generations

def judge_coherence(client, texts):
    """Ask an LLM judge which of two text fragments is more coherent.

    The two fragments are shown in a random order (drawn from ``np.random``)
    so the judge cannot learn a position preference for either side.

    Args:
        client: Object exposing ``responses.create(model=..., input=...,
            temperature=...)`` whose result has an ``output_text`` attribute.
        texts: Pair ``(algorithm_text, baseline_text)``.

    Returns:
        Tuple ``(alg_wins, baseline_wins)``. Exactly one entry is 1 when the
        judge's answer contains only 'A' or only 'B'; both are 0 when the
        answer cannot be parsed.
    """
    total_prompt = "You are a writing quality judge. You will be given a pair of text fragments. Say ONLY 'A' if the first fragment is more coherent, or ONLY 'B' if the second fragment is more coherent."
    # Index of the text shown as "Fragment A". Converted to a plain int because
    # indexing a sequence with numpy.bool_ is deprecated.
    first_idx = int(np.random.random() <= 0.5)
    total_prompt += f"""
        Fragment A: {texts[first_idx]}

        Fragment B: {texts[1-first_idx]}
        """
    backoff_time = 30
    # Retry indefinitely with exponential backoff until the request succeeds.
    while True:
        try:
            response = client.responses.create(
                model="gpt-4o-mini",
                input=total_prompt,
                temperature=0
            )
            break
        except Exception as err:
            # Catch Exception rather than everything, so Ctrl-C and SystemExit
            # still stop the run, and show why the request failed.
            print(f"Failed ({err!r}); sleeping for {backoff_time}")
            time.sleep(backoff_time)
            backoff_time *= 2
    print("++++++++++++++++++++++++++++")
    line = response.output_text.strip()
    print(line)
    print("----------------------------")
    alg_wins = 0
    baseline_wins = 0
    # Map the judge's letter back to algorithm/baseline using the shuffle above.
    if "A" in line and "B" not in line:
        if first_idx == 0:
            alg_wins += 1
        else:
            baseline_wins += 1
    elif "B" in line and "A" not in line:
        if first_idx == 0:
            baseline_wins += 1
        else:
            alg_wins += 1
    else:
        print("Failure to parse judge output line:")
        print(line)
    print("Alg wins:", alg_wins)
    print("Baseline wins:", baseline_wins)
    return alg_wins, baseline_wins

@hydra.main(config_path='../hydra_configs', config_name="qwen_0.5B_cg.yaml", version_base=None)
def main(cfg):
    """Compare saved generations of a sampler against a baseline with an LLM judge.

    Loads two pickled generation files (the evaluated sampler's and the
    baseline's), pairs them index by index, asks ``judge_coherence`` for a
    verdict on each pair, and prints the win counts.

    Args:
        cfg: Hydra/OmegaConf config. Fields read here: ``seed``,
            ``model.name``, ``model.vocab_size``, ``generation_configs.letter``,
            ``generation_configs.n``, ``generation_configs.max_new_tokens``,
            ``fs.eval_save_path``, ``eval.sampler``, ``eval.baseline_sampler``
            and ``openai_api_key``.

    Raises:
        ValueError: If either file holds fewer than ``generation_configs.n``
            generations.
    """
    # add runtime info to cfg
    OmegaConf.set_struct(cfg, False)
    cfg.meta = OmegaConf.create({})
    cfg.meta.original_dir = hydra.utils.get_original_cwd()
    cfg.meta.run_dir = os.getcwd()
    if torch.cuda.is_available():
        # Pick the GPU that currently reports the most free memory.
        free_mem = [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())]
        best_gpu = free_mem.index(max(free_mem))
        cfg.meta.device = f"cuda:{best_gpu}"
    else:
        cfg.meta.device = "cpu"
    print(cfg)
    torch.manual_seed(cfg.seed)
    random.seed(cfg.seed)
    # judge_coherence draws the A/B presentation order from np.random, so it
    # must be seeded too for the run to be reproducible.
    np.random.seed(cfg.seed)

    # NOTE(review): torch_dtype is a model-loading option; presumably it has no
    # effect on a tokenizer — confirm before removing.
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.name,torch_dtype=torch.float16)

    task = LetterConstraint(cfg.generation_configs.letter, tokenizer, cfg.model.vocab_size)

    n = cfg.generation_configs.n
    eval_save_path = cfg.fs.eval_save_path
    # os.path.join keeps the path relative when eval_save_path has no
    # directory component (string concatenation would yield "/<file>").
    baseline_save_path = os.path.join(
        os.path.dirname(eval_save_path),
        f"{cfg.eval.baseline_sampler}_n{n}.pkl",
    )

    print(f"Loading generations for {cfg.eval.sampler}...")
    L_alg = get_generations(eval_save_path, task.prompt_ids, tokenizer)

    print(f"Loading generations for {cfg.eval.baseline_sampler}...")
    L_baseline = get_generations(baseline_save_path, task.prompt_ids, tokenizer)

    # Fail early with a clear message instead of an IndexError mid-evaluation.
    if len(L_alg) < n or len(L_baseline) < n:
        raise ValueError(
            f"Expected at least {n} generations per sampler, got "
            f"{len(L_alg)} ({cfg.eval.sampler}) and "
            f"{len(L_baseline)} ({cfg.eval.baseline_sampler})"
        )

    client = OpenAI(api_key = cfg.openai_api_key)

    # Only the first n generations of each file are compared.
    texts = list(zip(L_alg[:n], L_baseline[:n]))

    alg_wins = 0
    baseline_wins = 0
    for pair in texts:
        a, b = judge_coherence(client, pair)
        alg_wins += a
        baseline_wins += b
        # Short pause between consecutive judge requests.
        time.sleep(1)
    print(f"mnt = {cfg.generation_configs.max_new_tokens}; results for {cfg.eval.sampler} vs {cfg.eval.baseline_sampler}: {alg_wins}, {baseline_wins}")
    print(alg_wins, baseline_wins)

# Script entry point: main is wrapped by @hydra.main, so it is called without
# arguments here and receives its cfg from Hydra.
if __name__ == '__main__':
    main()
