from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModel
from transformers import GPT2TokenizerFast, GPT2Tokenizer
from editor import apply_grace_to_model, GraceHyperParams
import torch
import logging
import os
import numpy as np

# Configure the root logger at import time (INFO and above, timestamped lines),
# then fetch this module's logger, which propagates to that root configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
LOG = logging.getLogger(__name__)


def edit(prompt, target_new, num_steps, replacement):
    """Apply a GRACE edit to a fresh GPT-2 and publish it as the global ``edit_model``.

    Loads GPT-2 and its tokenizer from ``./models/gpt2`` (model on CPU), reads the
    GRACE hyperparameters from ``./hparams/GRACE/gpt2.yaml`` and hands everything to
    ``apply_grace_to_model``. The edited model is stored in the module-level
    ``edit_model`` so that ``generate`` can use it later.

    Args:
        prompt: Prompt whose completion should be edited.
        target_new: Desired new completion for ``prompt``.
        num_steps: Passed straight through to ``apply_grace_to_model``
            (presumably the number of edit optimisation steps -- confirm there).
        replacement: Passed straight through to ``apply_grace_to_model``
            (semantics defined by ``editor``; not visible here).

    Returns:
        ``prompt``, unchanged.
    """
    global edit_model

    hparams = GraceHyperParams.from_hparams("./hparams/GRACE/gpt2.yaml")
    base_model = AutoModelForCausalLM.from_pretrained("./models/gpt2", device_map='cpu')
    tokenizer = GPT2Tokenizer.from_pretrained("./models/gpt2")
    # GPT-2 has no pad token; reuse EOS so padding-aware code paths work.
    tokenizer.pad_token_id = tokenizer.eos_token_id

    edit_request = dict(prompt=prompt, target_new=target_new)
    edit_model = apply_grace_to_model(
        base_model, tokenizer, edit_request, hparams, num_steps, replacement
    )
    return prompt


def generate(input_text, target_new=None):
    """Continue ``input_text`` with both the edited model and the original GPT-2.

    Uses the module-global ``edit_model`` produced by ``edit``; calling this before
    ``edit`` raises ``NameError``. The original GPT-2 is reloaded from
    ``./models/gpt2`` on every call so the two outputs can be compared.

    Args:
        input_text: Prompt string to continue.
        target_new: Optional edit target. When given, exactly as many new tokens
            are generated as ``target_new`` encodes to; otherwise 25 new tokens.

    Returns:
        ``(ori_reply, edit_reply)``. Each is a list with one ``(char, label)`` tuple
        per character of the decoded text. ``label`` is ``None`` for characters of
        the prompt and ``'output'`` for every generated character.
    """
    tok = GPT2Tokenizer.from_pretrained("./models/gpt2")
    tok.pad_token_id = tok.eos_token_id

    if target_new is None:
        max_new_tokens = 25
    else:
        max_new_tokens = len(tok.encode(target_new))

    input_ids = tok.encode(input_text, return_tensors='pt').to('cpu')
    # Measure the prompt on the decoded side. The labels below index into decoded
    # text, and decoding may normalise spacing, so len(input_text) can be off.
    prompt_len = len(tok.decode(input_ids[0], skip_special_tokens=True))

    edit_output = edit_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
    edit_reply = tok.decode(edit_output[0], skip_special_tokens=True)
    torch.cuda.empty_cache()

    ori_model = AutoModelForCausalLM.from_pretrained("./models/gpt2").to('cpu')
    ori_output = ori_model.generate(input_ids, max_new_tokens=max_new_tokens, pad_token_id=tok.eos_token_id)
    ori_reply = tok.decode(ori_output[0], skip_special_tokens=True)

    return _tag_generated(ori_reply, prompt_len), _tag_generated(edit_reply, prompt_len)


def _tag_generated(text, prompt_len):
    """Label each character of ``text``: ``None`` inside the prompt, ``'output'`` after it.

    Characters at index ``prompt_len`` and beyond are generated text. Using ``>=``
    ensures the very first generated character is labelled too.
    """
    return [(ch, 'output' if i >= prompt_len else None) for i, ch in enumerate(text)]


def get_handler(path, log_name):
    """Create a file handler and a console handler that share one log format.

    Args:
        path: Directory that will hold the log file. It is created (including
            parents) when it does not exist; an empty string means the current
            working directory.
        log_name: File name of the log inside ``path``.

    Returns:
        ``(file_handler, stream_handler)``, both at ``DEBUG`` level. The file
        handler opens ``os.path.join(path, log_name)`` as soon as it is built.

    Raises:
        OSError: if the directory cannot be created or the log file cannot be opened.
    """
    log_file_path = os.path.join(path, log_name)
    # An empty path means "current directory": there is nothing to create.
    if path and not os.path.exists(path):
        print("We are creating the logger files")
        # exist_ok=True tolerates another process creating the directory meanwhile;
        # genuine failures (e.g. permission denied) surface here instead of being hidden.
        os.makedirs(path, exist_ok=True)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    file_handler = logging.FileHandler(log_file_path)
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(formatter)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.DEBUG)
    stream_handler.setFormatter(formatter)
    return file_handler, stream_handler


def make_logs():
    """Attach a ``logs/run.log`` file handler and a console handler to ``LOG``.

    Idempotent: if ``LOG`` already has a file handler writing to the same file,
    nothing is added. Without this guard, repeated calls duplicate every log line
    and leak an open file handle per call.

    NOTE(review): ``LOG`` also propagates to the root logger set up by
    ``logging.basicConfig``, so console records may be printed twice. Its effective
    level is INFO unless set elsewhere, so the handlers' DEBUG level does not by
    itself let DEBUG records through.
    """
    log_dir, log_name = 'logs', 'run.log'
    # FileHandler stores the absolute path of its file in ``baseFilename``.
    target = os.path.abspath(os.path.join(log_dir, log_name))
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == target
        for handler in LOG.handlers
    )
    if already_attached:
        return

    f_h, s_h = get_handler(log_dir, log_name=log_name)
    LOG.addHandler(f_h)
    LOG.addHandler(s_h)


class Dict(dict):
    """A ``dict`` whose keys can also be read, written and deleted as attributes.

    ``d.key`` reads ``d['key']``, ``d.key = v`` sets ``d['key'] = v`` and
    ``del d.key`` removes the key. Missing keys raise ``AttributeError`` on attribute
    access rather than ``KeyError``. As a result ``hasattr``,
    ``getattr(d, name, default)``, ``copy.deepcopy`` and ``pickle`` work; they all
    probe optional attributes and only tolerate ``AttributeError``.
    """

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only called when normal lookup fails, so real dict methods (items, keys, ...)
        # still take precedence over keys of the same name.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


def dictToObj(dictObj):
    """Recursively convert nested dicts into ``Dict`` objects for attribute access.

    Every ``dict`` (at any depth reachable through dict values) becomes a new
    ``Dict`` with the same keys in the same order. Any other value is returned as the
    very same object, including lists, even if they contain dicts.
    """
    if isinstance(dictObj, dict):
        return Dict((key, dictToObj(value)) for key, value in dictObj.items())
    return dictObj


def summary_metrics(all_metrics):
    """Average per-edit evaluation metrics and print the summary.

    Args:
        all_metrics: A list of per-edit metric records; a single record dict is also
            accepted. Each record maps ``"pre"`` and ``"post"`` to dicts of metric
            values. Which metrics are summarised is decided by the first record.

    Returns:
        ``{"pre": {...}, "post": {...}}`` containing:

        * the mean of every scalar metric (``rewrite_acc``, ``rephrase_acc``,
          ``rewrite_ppl``, ``rephrase_ppl``, ``ood_generality_threshold_succ``)
          present in the first record;
        * for ``locality`` / ``portability`` (when present and non-empty in the
          first record), a dict with the mean of each sub-key ending in ``"acc"``.

        Each record's value, scalar or list, is first reduced with ``np.mean``.
        Records whose lists have different lengths are therefore weighted equally
        instead of making ``np.mean`` fail on a ragged list. An empty input yields
        ``{"pre": {}, "post": {}}``.
    """
    if isinstance(all_metrics, dict):
        all_metrics = [all_metrics]

    mean_metrics = {"pre": {}, "post": {}}
    if not all_metrics:
        print("Metrics Summary: ", mean_metrics)
        return mean_metrics

    scalar_keys = ("rewrite_acc", "rephrase_acc",
                   "rewrite_ppl", "rephrase_ppl", "ood_generality_threshold_succ")
    for phase in ("pre", "post"):
        first = all_metrics[0][phase]
        phase_summary = mean_metrics[phase]
        for key in scalar_keys:
            if key in first:
                phase_summary[key] = np.mean([np.mean(m[phase][key]) for m in all_metrics])
        for key in ("locality", "portability"):
            if key in first and first[key] != {}:
                phase_summary[key] = {
                    lkey: np.mean([np.mean(m[phase][key][lkey]) for m in all_metrics])
                    for lkey in first[key]
                    if lkey.endswith("acc")
                }

    print("Metrics Summary: ", mean_metrics)

    return mean_metrics
