import logging
import os
import sys
from collections import namedtuple

import click
import torch
import transformers
import ujson as json
from tqdm import tqdm

# FastChat is expected in a "FastChat" directory next to this script; make it importable.
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "FastChat"))
from fastchat.train.train import preprocess  # noqa: E402

# Module-level logger named after the module (not the file path, whose dots
# would be interpreted as levels of the logging hierarchy).
logger = logging.getLogger(__name__)

def extract_first_conv(conv):
    """Return the turns forming the first round of a conversation.

    Args:
        conv: list of turn dicts, each with a "from" key ("human" or "gpt").

    Returns:
        ``conv[:3]`` when the conversation opens with a "gpt" turn (the leading
        gpt turn plus the following human/gpt pair), or ``conv[:2]`` when it
        opens with a "human" turn.

    Raises:
        ValueError: if ``conv`` is empty or its first turn's role is neither
            "human" nor "gpt".
    """
    if not conv:
        raise ValueError("conversation is empty")
    first_role = conv[0]["from"]
    if first_role == "gpt":
        # A leading gpt turn precedes the first human -> gpt exchange.
        return conv[:3]
    if first_role == "human":
        return conv[:2]
    raise ValueError(f"unexpected role in first turn: {first_role!r}")

def _is_valid_conversation(conv):
    """Return True if *conv* is non-empty, only uses the "human"/"gpt" roles,
    and never has the same role speak twice in a row."""
    if not conv:
        return False
    roles = [sent["from"] for sent in conv]
    if any(role not in ("human", "gpt") for role in roles):
        return False
    return all(prev != cur for prev, cur in zip(roles, roles[1:]))


def _load_model_and_tokenizer(model_name, max_length=4096):
    """Load the slow tokenizer and an fp32 causal LM for *model_name* in eval mode."""
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, use_fast=False)
    # Reuse <unk> as the pad token; presumably the Llama-2 tokenizer defines none — TODO confirm.
    tokenizer.pad_token = tokenizer.unk_token
    # Presumably the padding/truncation length used by preprocess() — verify against FastChat.
    tokenizer.model_max_length = max_length
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_name, torch_dtype=torch.float32, device_map="sequential",
    )
    model.eval()
    return model, tokenizer


def _score_first_round(model, tokenizer, conversation, loss_fct):
    """Score the first round of *conversation* with *model*.

    Returns:
        ``(perplexity, mean_loss)`` as Python floats, both computed over the
        positions whose label is scored (label != ``loss_fct.ignore_index``)
        and not masked out; or ``None`` when no position is scored.
    """
    first_conv = extract_first_conv(conversation)
    inputs = preprocess([first_conv], tokenizer)
    inputs = {key: value.to(model.device) for key, value in inputs.items()}
    with torch.no_grad():
        ret = model(**inputs, return_dict=True)

    labels = inputs["labels"]
    attention_mask = inputs["attention_mask"]

    # Logits at position t predict token t+1, so align by shifting one step.
    # refer to https://github.com/huggingface/evaluate/blob/main/metrics/perplexity/perplexity.py
    shift_logits = ret.logits[..., :-1, :].contiguous().double()
    shift_labels = labels[..., 1:].contiguous()
    shift_loss = loss_fct(shift_logits.transpose(1, 2), shift_labels)

    # Ignored labels contribute 0 loss, so they must also be excluded from the
    # denominator; otherwise the average is diluted by unscored tokens.
    # NOTE(review): this makes perplexity ~= exp(llama2_loss). If a full-sequence
    # perplexity was intended instead, score against inputs["input_ids"].
    scored = (shift_labels != loss_fct.ignore_index) & attention_mask[..., 1:].bool()
    scored = scored.double()
    n_scored = scored.sum(1)
    if n_scored.sum().item() == 0:
        return None
    perplexity = torch.exp((shift_loss * scored).sum(1) / n_scored)
    return float(perplexity.item()), float(ret.loss)


def run(args):
    """Annotate conversations with Llama-2 perplexity/loss of their first round.

    Reads a JSON list of records from ``args.input`` and keeps those whose
    ``conversations`` pass ``_is_valid_conversation``. It then scores the shard
    ``validated[args.start::args.batch_size]`` and adds ``perplexity`` and
    ``llama2_loss`` to each record. The records are written to ``args.output``,
    suffixed with ``.{start}_{batch_size}`` when ``batch_size > 1``.
    """
    with open(args.input, encoding="utf-8") as fd:
        raw_data = json.load(fd)

    model, tokenizer = _load_model_and_tokenizer("meta-llama/Llama-2-7b-hf")
    logger.debug("%s", model)
    logger.info("model device: %s", model.device)

    validated_data = [
        one_data for one_data in raw_data
        if _is_valid_conversation(one_data.get("conversations"))
    ]

    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    save_data = []
    shard = validated_data[args.start::args.batch_size]
    for idx, one_data in enumerate(tqdm(shard)):
        scores = _score_first_round(model, tokenizer, one_data["conversations"], loss_fct)
        if scores is None:
            # No scored tokens -> NaN loss, which cannot be serialized; skip the row.
            logger.warning("skipping shard item %d: first round has no scored tokens", idx)
            continue
        one_data["perplexity"], one_data["llama2_loss"] = scores
        logger.debug("perplexity=%s llama2_loss=%s", *scores)
        save_data.append(one_data)

    save_file_name = args.output
    if args.batch_size > 1:
        # Distinct file per shard so parallel runs do not overwrite each other.
        save_file_name += f".{args.start}_{args.batch_size}"
    with open(save_file_name, "w", encoding="utf-8") as fd:
        json.dump(save_data, fd, indent=1)


@click.command()
@click.option('-i', '--input', required=True,
              type=click.Path(exists=True, dir_okay=False),
              help="JSON file holding a list of records with a 'conversations' list.")
@click.option('-o', '--output', required=True,
              help="Output JSON path; suffixed with '.<start>_<batch_size>' when batch_size > 1.")
@click.option('-b', '--batch_size', default=1, type=click.IntRange(min=1),
              help="Shard stride: process every batch_size-th validated record.")
@click.option('-s', '--start', default=0, type=int,
              help="Index of the first validated record processed by this shard.")
def main(**kwargs):
    """Score the first round of each validated conversation with Llama-2."""
    # Bundle the CLI options into an immutable attribute-style object for run().
    Arg = namedtuple('Arg', kwargs.keys())
    args = Arg(**kwargs)
    logger.info(kwargs)
    run(args)


if __name__ == '__main__':
    # Append DEBUG-and-above records to "<this file>.log" with a detailed format.
    log_config = dict(
        filename=__file__ + '.log',
        filemode='a',
        format='[%(levelname)-5.5s][%(asctime)s][%(filename)s %(lineno)d]: %(message)s',
        datefmt='%d-%m-%Y %H:%M:%S',
        level=logging.DEBUG,
    )
    logging.basicConfig(**log_config)
    # Also echo records to stderr (this handler has no formatter of its own).
    root_logger = logging.getLogger()
    root_logger.addHandler(logging.StreamHandler())
    main()
