import os
import socket
# Route the Hugging Face cache to a per-host directory under the cluster mount
# (the path embeds this machine's hostname; presumably an Azure ML compute
# cluster share — confirm). This runs before `transformers` is imported below,
# presumably so the library sees TRANSFORMERS_CACHE when it initialises.
host_name = socket.gethostname()
CACHE_LOCATION = f"/mnt/batch/tasks/shared/LS_root/mounts/clusters/{host_name}/code/HUGGINGFACE_CACHE"
print(f"Hostname is {host_name}")
print(f"Cache location is {CACHE_LOCATION}")
os.environ["TRANSFORMERS_CACHE"] = CACHE_LOCATION
# Other cache variables, currently left disabled:
# os.environ['HF_HOME'] = CACHE_LOCATION
# os.environ['XDG_CACHE_HOME'] = CACHE_LOCATION

import copy
import torch
from IPython import embed
import logging
import numpy as np
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from transformers import OPTForCausalLM, AutoTokenizer, AutoConfig, AutoModelForCausalLM, T5Tokenizer, T5ForConditionalGeneration
from accelerate import infer_auto_device_map, init_empty_weights
from torch.utils.data import DataLoader

import generate_data
from timer import Timer
# Disable Hugging Face tokenizers' internal parallelism; presumably to avoid its
# fork warning/deadlock when DataLoader worker processes are used below.
os.environ.update(TOKENIZERS_PARALLELISM="false")

def verify_accuracy(trainer, lightning_model, num_digits, type, batch_size=1, digit_only=False, silent=True, flash=True, num_samples=10):
    """Validate ``lightning_model`` on addition problems of 1..``num_digits`` digits.

    One DataLoader per digit count is built from ``type`` datasets. When
    ``flash`` is set, a second series is built from ``"flash"`` datasets. All
    loaders are passed, in that order, to ``trainer.validate``. With
    ``digit_only`` every digit count other than ``num_digits`` gets an empty
    placeholder loader, and its metrics are not read.

    Returns:
        ``(passed, accs)``.
        - ``passed`` is True when no flash accuracy is below 0.75 and no think
          accuracy is below 0.99.
        - ``accs`` is a NumPy array of the think accuracies followed by the
          flash accuracies.
        The think list starts with a fixed ``1.0``. It presumably stands in
        for the 1-digit case, whose think metric is never read.
    """
    lightning_model.set_cur_num_digits(num_digits)
    digit_counts = range(1, num_digits + 1)

    def is_skipped(digits):
        # In digit_only mode only the num_digits case is actually evaluated.
        return digit_only and digits != num_digits

    def build_loader(digits, dataset_type):
        if is_skipped(digits):
            source = []
        else:
            source = generate_data.AdditionDataset(
                num_examples=max(num_samples, batch_size),
                num_digits=digits,
                dataset_type="english",
                tokenizer=lightning_model.tokenizer,
                type=dataset_type,
                silent=silent,
            )
        return DataLoader(source, batch_size=batch_size, num_workers=1)

    # Regular loaders for every digit count first, then the flash series.
    loaders = [build_loader(d, type) for d in digit_counts]
    if flash:
        loaders.extend(build_loader(d, "flash") for d in digit_counts)

    results = trainer.validate(lightning_model, dataloaders=loaders, verbose=False)

    think_accs = [1.0]
    flash_accs = []
    for digits in (d for d in digit_counts if not is_skipped(d)):
        if flash:
            flash_accs.append(results[0]["val_flash_acc_{}_digits".format(digits)])
        if digits > 1:
            think_accs.append(results[0]["val_acc_{}_digits".format(digits)])

    flash_ok = not any(np.array(flash_accs) < .75)  # flash only needs 75%
    think_ok = not any(np.array(think_accs) < .99)  # think needs 99%
    return flash_ok and think_ok, np.array(think_accs + flash_accs)

def move_to_cuda(inputs, device="cuda"):
    """Move every tensor value of ``inputs`` to ``device``, in place.

    Args:
        inputs: mutable mapping whose values have a ``.to(device)`` method
            (e.g. tokenizer output with ``return_tensors="pt"``). Its values
            are replaced in place.
        device: target device. Defaults to ``"cuda"`` as before; pass e.g.
            ``"cpu"`` for CPU-only runs.

    Returns:
        The same ``inputs`` object, so the call can be chained.
    """
    for key in inputs:
        inputs[key] = inputs[key].to(device)
    return inputs

def decode(tokenizer, input, delete_after_period=True):
    """Decode the first sequence of a batch of token ids to text.

    Args:
        tokenizer: object providing ``batch_decode``. Special tokens are
            skipped and tokenization spaces are not cleaned up.
        input: batch of token-id sequences; only element 0 is returned.
        delete_after_period: when True (the default), any text after the
            fourth ``"."`` is dropped; the fourth period itself is kept. When
            False, the decoded text is returned unchanged. Previously this
            flag was ignored and truncation always happened.

    Returns:
        The decoded, possibly truncated, string.
    """
    raw_line = tokenizer.batch_decode(input, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    if not delete_after_period:
        return raw_line

    # More than four pieces means at least four periods: keep everything up to
    # and including the fourth one.
    pieces = raw_line.split(".")
    if len(pieces) > 4:
        return ".".join(pieces[:4]) + "."
    return raw_line

def gen(model, tokenizer, input):
    """Run ``model.generate`` on the prompt ``input`` and return decoded text.

    The prompt is tokenized to PyTorch tensors and moved to the GPU with
    ``move_to_cuda``. At most 20 new tokens are generated. The output is
    decoded with ``decode(..., delete_after_period=False)``.
    """
    encoded = tokenizer(input, return_tensors="pt")
    on_device = move_to_cuda(encoded)
    generated_ids = model.generate(**on_device, max_new_tokens=20)
    return decode(
        tokenizer=tokenizer,
        input=generated_ids,
        delete_after_period=False,
    )

def generate_digit_and_special_tokens(tokenizer, max_number=100000):
    """Collect every token id used when tokenizing numbers, plus ids 0-3.

    Each string ``"{i}. {i}"`` for ``i`` in ``range(max_number)`` is
    tokenized, and all resulting ids are added to the set. Ids 0-3 are always
    included; they are presumably the tokenizer's special tokens, so confirm
    this for the tokenizer in use.

    Args:
        tokenizer: Hugging Face-style tokenizer. It is called once on a list of
            strings and must return a mapping whose ``"input_ids"`` is a
            sequence of per-string id lists.
        max_number: exclusive upper bound of the numbers tokenized. The default
            matches the previously hard-coded 100000.

    Returns:
        A set of int token ids.
    """
    digit_tokens = {0, 1, 2, 3}

    texts = ["{0}. {0}".format(i) for i in range(max_number)]
    if not texts:
        return digit_tokens

    # One batched call instead of one call (and one tensor) per number. With
    # the default padding=False no pad ids are added, so each string's ids are
    # the same as when it is tokenized on its own.
    for ids in tokenizer(texts)["input_ids"]:
        digit_tokens.update(ids)

    return digit_tokens

def whitelist_tokens(tokenizer):
    """Return ids of tokens whose decoded text uses only digits, '.' and ' '.

    Every id in ``range(tokenizer.vocab_size)`` is decoded on its own, with
    special tokens skipped. An id is kept when its text contains no other
    characters. An id that decodes to the empty string (e.g. a skipped special
    token) is therefore accepted. Ids 0-3 are always included.

    Args:
        tokenizer: object with ``vocab_size`` and ``decode``.

    Returns:
        A set of allowed int token ids. The logged count is the size of this
        set; the old code counted a list in which ids 0-3 could appear twice.
    """
    allowed_chars = set("1234567890. ")
    ok_tokens = {0, 1, 2, 3}

    for token_id in range(tokenizer.vocab_size):
        decoded_word = tokenizer.decode([token_id], skip_special_tokens=True, clean_up_tokenization_spaces=False)
        if set(decoded_word) <= allowed_chars:
            ok_tokens.add(token_id)

    logging.info("Whitelist contains {} tokens".format(len(ok_tokens)))
    return ok_tokens

def manually_generate(model, tokenizer, orig_inputs, token_whitelist, maxlen=20):
    """Greedily decode up to ``maxlen`` tokens from a model, one step at a time.

    Args:
        model: callable as ``model(input_ids, attention_mask=..., past_key_values=...)``
            that returns an object with a ``logits`` attribute, indexed as
            ``logits[:, -1, :]`` (the last axis is the vocabulary).
        tokenizer: only ``tokenizer.eos_token_id`` is read; decoding stops
            after that token is produced.
        orig_inputs: mapping with ``"input_ids"`` and ``"attention_mask"``
            tensors. It is not modified.
        token_whitelist: iterable of allowed token ids, or a falsy value to
            allow every token. Other ids get ``-inf`` logits before argmax.
        maxlen: maximum number of new tokens.

    Returns:
        ``(input_ids, log_prob)``.
        - ``input_ids`` is the prompt with the generated tokens appended.
        - ``log_prob`` is the summed log-probability of the chosen tokens
          under the whitelist-restricted distribution.

    Only batch size 1 is supported, because ``.item()`` is called on
    per-step results.
    """
    log_prob = 0.
    # torch.cat below always builds new tensors, so the originals are never
    # mutated and no defensive copy is needed.
    input_ids = orig_inputs["input_ids"]
    attention_mask = orig_inputs["attention_mask"]
    banned_ids = None  # built once, on the first step, when the vocab size is known

    # Only token ids and a Python float leave this function, so there is no
    # need to record an autograd graph at every step.
    with torch.no_grad():
        for _ in range(maxlen):
            # No KV cache is used: the full sequence is re-run every step.
            outputs = model(input_ids, attention_mask=attention_mask, past_key_values=None)
            next_token_logits = outputs.logits[:, -1, :]

            if token_whitelist:
                if banned_ids is None:
                    vocab_size = next_token_logits.shape[-1]
                    banned = sorted(set(range(vocab_size)) - set(token_whitelist))
                    banned_ids = torch.tensor(banned, dtype=torch.long, device=next_token_logits.device)
                # Only the last position feeds argmax/log_softmax, so only it is masked.
                next_token_logits = next_token_logits.index_fill(-1, banned_ids, -np.inf)

            log_probs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
            next_token = torch.argmax(next_token_logits, dim=-1)
            log_prob += log_probs[0, next_token].item()

            input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
            attention_mask = torch.cat([attention_mask, torch.ones_like(next_token).unsqueeze(-1)], dim=-1)

            # Stop once EOS has been appended.
            if next_token.item() == tokenizer.eos_token_id:
                break

    return input_ids, log_prob


def load_model(model_size, family="opt"):
    """Load a pretrained model and its tokenizer.

    Args:
        model_size: checkpoint size suffix, i.e. the ``X`` in
            ``facebook/opt-X`` or ``google/<family>-X``.
        family: selects the checkpoint.
            - ``"opt"`` loads ``facebook/opt-<model_size>`` in float16, using a
              device map from ``infer_auto_device_map`` and offloading to the
              ``"offload"`` folder.
            - Any family containing ``"t5"`` loads
              ``google/<family>-<model_size>`` with ``AutoModelForSeq2SeqLM``.

    Returns:
        ``(model, tokenizer)``.

    Raises:
        ValueError: if ``family`` matches neither case.
    """
    logging.info("Loading {} model".format(model_size))

    if family == "opt":
        name = "facebook/opt-{}".format(model_size)

        # Build a weightless skeleton so accelerate can plan device placement
        # before the real weights are loaded.
        config = AutoConfig.from_pretrained(name)
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config)

        # Decoder layers are kept whole on a single device.
        device_map = infer_auto_device_map(model, no_split_module_classes=["OPTDecoderLayer"], dtype="float16")
        logging.info(device_map)
        model = OPTForCausalLM.from_pretrained(name, torch_dtype=torch.float16, device_map=device_map, offload_folder="offload", offload_state_dict=True)
        tokenizer = AutoTokenizer.from_pretrained(name)

    elif "t5" in family:
        name = "google/{}-{}".format(family, model_size)

        tokenizer = AutoTokenizer.from_pretrained(name)
        model = AutoModelForSeq2SeqLM.from_pretrained(name)

    else:
        # Previously this fell through to an UnboundLocalError at the return.
        raise ValueError("Unsupported model family: {!r}".format(family))

    return model, tokenizer