import hydra
from omegaconf import OmegaConf

import collections
from collections.abc import Iterable
import copy
# import ctranslate2
import datasets
import numpy as np
import matplotlib.pyplot as plt
import os
import pickle
import random
import sys
import torch
import time
from transformers import AutoModelForCausalLM, LlamaForCausalLM, AutoTokenizer
from policies import LMPolicy
from constrained_samplers import *
import wandb

import re

from lm import (
    conditional_nn_generate,
    predict_error,
    sample_top_p,
    nn_log_probs
)


class WordLengthConstraint:
    """Task: generate a story in which every word has exactly ``word_len`` letters.

    Attributes:
        word_len: required length of every word in the completion.
        tokenizer: tokenizer used to encode the prompt and decode completions.
        prompt: instruction text prepended to every generation.
        prompt_ids: token ids of ``prompt`` as produced by ``tokenizer``.
    """

    # Everything that is neither a word character nor whitespace is stripped
    # before words are measured. Compiled once instead of per call.
    _PUNCTUATION = re.compile(r'[^\w\s]')

    def __init__(self, word_len, tokenizer):
        self.word_len = word_len
        self.tokenizer = tokenizer
        self.prompt = f"Generate a one-sentence story with only {word_len}-letter words:\n"
        self.prompt_ids = tokenizer(self.prompt)['input_ids']

    def reward_function(self, sequence_ids):
        """Return 1 if every word of the completion has ``self.word_len`` letters, else 0.

        Args:
            sequence_ids: full token sequence (prompt followed by completion).
                The first ``len(self.prompt_ids)`` ids are skipped.

        Returns:
            0 if the completion contains the EOS token, is empty after
            removing punctuation, or has any word of the wrong length;
            otherwise 1.
        """
        completion_ids = sequence_ids[len(self.prompt_ids):]
        if self.tokenizer.eos_token_id in completion_ids:
            return 0
        text = self.tokenizer.decode(completion_ids, skip_special_tokens=True)
        # split() with no argument splits on any whitespace run. This means a
        # leading space from the decode, a doubled space, or a newline does not
        # create an empty or over-long "word".
        words = self._PUNCTUATION.sub('', text).split()
        if not words:
            return 0
        return int(all(len(word) == self.word_len for word in words))

class LetterConstraint:
    """Lipogram task: the completion must never use a given letter.

    A per-token lookup table is built once at construction, so scoring a
    sequence costs one list lookup per generated token.

    Attributes:
        letter: the forbidden character (checked case-sensitively).
        tokenizer: tokenizer used for the prompt and for ``eos_token_id``.
        prompt: instruction text prepended to every generation.
        prompt_ids: token ids of ``prompt`` as produced by ``tokenizer``.
        id2token: decoded text of each token id in ``range(vocab_size)``.
        satisfies: ``satisfies[i]`` is True iff ``id2token[i]`` does not
            contain ``letter``.
    """

    def __init__(self, letter, tokenizer, vocab_size):
        self.letter = letter
        self.tokenizer = tokenizer
        self.prompt = f"Generate a one-sentence story without using the letter '{self.letter}':\n"
        self.prompt_ids = tokenizer(self.prompt)['input_ids']
        pieces = [tokenizer.decode([token_id]) for token_id in range(vocab_size)]
        self.id2token = pieces
        self.satisfies = [letter not in piece for piece in pieces]

    def reward_function(self, sequence_ids):
        """Score a full sequence: 1 if its completion avoids the letter, else 0.

        Args:
            sequence_ids: prompt ids followed by completion ids. The first
                ``len(self.prompt_ids)`` entries are ignored.

        Returns:
            0 as soon as the completion contains the EOS token or a token whose
            decoded text includes ``self.letter``; 1 otherwise, including for
            an empty completion.
        """
        eos = self.tokenizer.eos_token_id
        allowed = self.satisfies
        completion = sequence_ids[len(self.prompt_ids):]
        clean = all(tok != eos and allowed[tok] for tok in completion)
        return 1 if clean else 0


def _select_device(min_free_mib=6000, initial_backoff=5, max_backoff=600):
    """Return a device string, waiting until some GPU has enough free memory.

    Without CUDA this returns ``"cpu"`` immediately. With CUDA it polls
    ``torch.cuda.mem_get_info`` on every visible device and picks the one with
    the most free memory. It returns ``"cuda:<idx>"`` once that device has more
    than ``min_free_mib`` MiB free. Otherwise it sleeps with exponential backoff,
    capped at ``max_backoff`` seconds, and retries indefinitely.
    """
    if not torch.cuda.is_available():
        # No GPU: nothing to wait for.
        return "cpu"
    backoff = initial_backoff
    while True:
        free_mem = [torch.cuda.mem_get_info(i)[0] for i in range(torch.cuda.device_count())]
        best_gpu = free_mem.index(max(free_mem))
        free_mib = free_mem[best_gpu] / (1024 ** 2)
        print(f"Free memory {free_mib}", flush=True)
        if free_mib > min_free_mib:
            return f"cuda:{best_gpu}"
        print("Not enough memory")
        time.sleep(backoff)
        # Cap the wait so a long-queued job notices freed memory reasonably soon.
        backoff = min(backoff * 2, max_backoff)


def _export_cpu_results(path):
    """Write a copy of a saved results pickle with ``output_ids`` removed.

    Loads the list of per-sample dicts written by a previous run of ``main``,
    drops the ``output_ids`` key from each dict in place, and pickles the list
    to ``path + "_cpu"``. Presumably ``output_ids`` may hold device tensors that
    are awkward to unpickle elsewhere; confirm this against the sampler.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use this on
    # result files produced by this script.
    with open(path, 'rb') as f:
        results = pickle.load(f)
    if results:
        print(results[0])
    stripped = []
    for record in results:
        record.pop('output_ids', None)
        stripped.append(record)
    with open(path + "_cpu", 'wb') as f:
        pickle.dump(stripped, f)
    if stripped:
        print(stripped[0])
    print("Saved cpu-friendly results")


def _make_sampler_class(name):
    """Map a ``cfg.eval.sampler`` string to the sampler class or factory to instantiate.

    ``"WJS<p>"`` maps to ``GenericJSSampler(up_prob=float(p))``. ``"LM"`` maps to
    ``UnguidedLMSampler``. ``"ResetTW"`` and ``"TW"`` map to ``GenericTWSampler``
    with ``reset=True`` and ``reset=False`` respectively.

    Raises:
        ValueError: for any other name.
    """
    if name.startswith("WJS"):
        return GenericJSSampler(up_prob=float(name[3:]))
    if name == "LM":
        return UnguidedLMSampler
    if name == "ResetTW":
        return GenericTWSampler(reset=True)
    if name == "TW":
        return GenericTWSampler(reset=False)
    raise ValueError(f"Unknown sampler: {name!r}")


@hydra.main(config_path='../hydra_configs', config_name="qwen_0.5B_cg.yaml", version_base=None)
def main(cfg):
    """Sample ``cfg.generation_configs.n`` letter-constrained completions and save stats.

    If ``cfg.fs.eval_save_path`` already exists, the function only writes a
    stripped copy of it (see ``_export_cpu_results``) and returns.

    Otherwise it does the following:
      1. Loads the policy model and the evaluator model.
      2. Draws ``n`` samples with the configured sampler.
      3. Prints each sample's reward, step count and evaluator log prob.
      4. Prints the averages.
      5. Pickles a list of dicts with keys ``output_ids``, ``reward``,
         ``steps`` and ``eval_log_prob`` to ``cfg.fs.eval_save_path``.
    """
    # Add runtime info to cfg.
    OmegaConf.set_struct(cfg, False)
    cfg.meta = OmegaConf.create({})
    cfg.meta.original_dir = hydra.utils.get_original_cwd()
    cfg.meta.run_dir = os.getcwd()
    cfg.meta.device = _select_device()
    print(cfg)
    torch.manual_seed(cfg.seed)
    random.seed(cfg.seed)
    np.random.seed(cfg.seed)

    if os.path.exists(cfg.fs.eval_save_path):
        # A previous run already produced results: only export the stripped copy.
        _export_cpu_results(cfg.fs.eval_save_path)
        return

    n = cfg.generation_configs.n
    if n < 1:
        # Fail before loading models rather than dividing by zero at the end.
        raise ValueError(f"generation_configs.n must be positive, got {n}")

    # fp16 is only reliably supported on GPU; fall back to fp32 on CPU.
    dtype = torch.float16 if cfg.meta.device.startswith("cuda") else torch.float32
    model = AutoModelForCausalLM.from_pretrained(cfg.model.name, torch_dtype=dtype).to(cfg.meta.device)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.name)
    evaluator_model = AutoModelForCausalLM.from_pretrained(
        cfg.evaluator_model.name, torch_dtype=dtype
    ).to(cfg.meta.device)

    task = LetterConstraint(cfg.generation_configs.letter, tokenizer, model.config.vocab_size)

    piref = LMPolicy(model, tokenizer, cfg.generation_configs.max_new_tokens, cfg.meta.device)
    sampler_class = _make_sampler_class(cfg.eval.sampler)
    sampler = sampler_class(
        piref=piref,
        reward_function=task.reward_function,
        device=cfg.meta.device,
    )

    total_steps = 0
    sequences = []
    results = []
    start_time = time.perf_counter()
    for i in range(n):
        output_ids, steps = sampler.sample(task.prompt_ids)
        reward = task.reward_function(output_ids)
        total_steps += steps
        # Per-sample evaluator log prob for logging. The [0][0] indexing assumes
        # nn_log_probs returns a nested per-sequence result; TODO confirm.
        lp = nn_log_probs(evaluator_model, [output_ids], device=cfg.meta.device)[0][0]
        print("---------------------------")
        print(tokenizer.decode(output_ids, skip_special_tokens=False))
        print("---------------------------")
        print("Index:", i)
        print("Reward:", reward)
        print("Steps:", steps)
        print("Log prob:", lp, flush=True)
        print("+++++++++++++++++++++++++++")
        sequences.append(output_ids)
        results.append({'output_ids': output_ids, 'reward': reward, 'steps': steps, 'eval_log_prob': lp})
    end_time = time.perf_counter()

    log_probs = nn_log_probs(evaluator_model, sequences, device=cfg.meta.device)
    print(log_probs)
    print(f"Average log likelihood: {np.mean(log_probs):.4f}")
    print(f"Average generation time: {(end_time - start_time) / n:.4f}")
    print(f"Average steps: {total_steps / n:.4f}")

    save_path = cfg.fs.eval_save_path
    save_dir = os.path.dirname(save_path)
    if save_dir:
        # dirname is '' for a bare filename, and os.makedirs('') raises.
        os.makedirs(save_dir, exist_ok=True)
    with open(save_path, 'wb') as f:
        pickle.dump(results, f)
    print("Saved results")


if __name__ == '__main__':
    main()
