import argparse
import csv
import json
import os
import pdb
from tqdm import tqdm
from typing import List, Dict, Tuple, Optional, Any

import torch
import numpy as np
import pandas as pd

import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

from src.dataset import *
from src.utils import format_sentence, get_quantization_config


def find_last_nth_index(value, lst, n):
    """Return the index of the ``n``-th occurrence of ``value`` counted from the end.

    Args:
        value: Element to search for; elements are compared with ``==``.
        lst: Sequence to search (must support ``len`` and ``reversed``).
        n: 1-based occurrence counter from the end of ``lst``; ``n=1`` selects
            the last occurrence, ``n=2`` the second to last, and so on.

    Returns:
        The 0-based index into ``lst`` of the selected occurrence, or ``-1``
        when ``value`` occurs fewer than ``n`` times or ``n`` is not positive.
    """
    # A non-positive n has no meaningful "n-th from the end" occurrence; without
    # this guard it would be used as a negative list index further down.
    if n < 1:
        return -1

    remaining = n
    # Walk from the back and stop at the n-th hit instead of collecting all hits.
    for offset, item in enumerate(reversed(lst)):
        if item == value:
            remaining -= 1
            if remaining == 0:
                # Convert the offset in the reversed view into a forward index.
                return len(lst) - offset - 1
    return -1



def _parse_args() -> argparse.Namespace:
    """Define and parse the command-line options of this inference script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_dir", type=str, required=True,
        help="Path to the entity tracking model.")
    parser.add_argument("--load_in_8bit", action="store_true", help="Load model with 8-bit")
    parser.add_argument("--load_in_4bit", action="store_true", help="Load model with 4-bit")

    parser.add_argument(
        "--data_path",
        type=str,
        help="Path to a jsonl file of data.",
        required=True,
    )
    parser.add_argument(
        "--object_vocabulary_file",
        type=str,
        default="data/objects_with_bnc_frequency.csv",
        # The vocabulary loader reads the "object_name" column.
        help='Path to a .csv file with a string field "object_name".'
    )
    # NOTE(review): not read anywhere in this function group; kept so existing
    # command lines that pass it keep working.
    parser.add_argument(
        "--sampling_seed",
        type=int,
        default=22,
        help="Seed for random sampling."
    )
    parser.add_argument(
        "--output_dir", type=str,
        default="results/baseline_inference"
    )
    parser.add_argument(
        "--few_shot_prompt",
        type=str,
        default=False,
        choices=[False, "PROMPT", "PROMPT_ALTFORM", "PROMPT_ALLBOX_ALTFORM", "INSTRUCTION",]
    )
    parser.add_argument(
        "--chat", action="store_true",
    )

    parser.add_argument(
        "--batch_size",
        type=int,
        default=8,
    )

    # distributed inference (the model is loaded with tp_plan="auto")
    parser.add_argument('--distributed',
                        dest="distributed",
                        action="store_true")
    # Parsed but never read here; presumably supplied by the process launcher.
    parser.add_argument("--local-rank", "--local_rank", type=int)

    return parser.parse_args()


def _load_object_vocabulary(path: str) -> dict[str, int]:
    """Map every ``object_name`` in the CSV at ``path`` to its 0-based row index."""
    object_map = {}
    with open(path, encoding="utf-8-sig") as f:
        for i, row in enumerate(csv.DictReader(f)):
            object_map[row["object_name"]] = i
    return object_map


def _load_model_and_tokenizer(args: argparse.Namespace):
    """Load the causal LM and its tokenizer as requested by ``args``.

    In distributed mode the process group is initialised on the GPU selected by
    the ``LOCAL_RANK`` environment variable and the model is loaded with
    ``tp_plan="auto"``; otherwise the model is placed on CUDA when available,
    else on CPU.

    Returns:
        A ``(model, tokenizer, device)`` tuple; ``device`` is where input
        batches should be moved before generation.
    """
    if args.distributed:
        rank = int(os.environ.get("LOCAL_RANK",0))
        print(f"{rank=}")
        torch.cuda.set_device(rank)
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)  # https://github.com/pytorch/pytorch/issues/146767
        torch.distributed.init_process_group("nccl", device_id=device)
        tp_plan = "auto"
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        tp_plan = None
    model_kwargs = {"tp_plan": tp_plan}

    if args.distributed:
        # Distributed runs forward the bit-width flags directly instead of
        # building a quantization config.
        if args.load_in_8bit:
            model_kwargs["load_in_8bit"] = True
        elif args.load_in_4bit:
            model_kwargs["load_in_4bit"] = True
    else:
        model_kwargs["quantization_config"] =  get_quantization_config(args)
    unquantized = model_kwargs.get("quantization_config") is None
    if unquantized and transformers.utils.is_torch_bf16_gpu_available():
        model_kwargs["torch_dtype"] = torch.bfloat16
    model = AutoModelForCausalLM.from_pretrained(args.model_dir, **model_kwargs)
    # Only an unquantized, non-distributed model is moved explicitly.
    if unquantized and not args.distributed:
        model = model.to(device)

    tokenizer = AutoTokenizer.from_pretrained(args.model_dir)
    # Batched generation needs a pad token; the EOS token is reused for it.
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer, device


def _build_output_path(args: argparse.Namespace) -> str:
    """Compose the result file path from the model, quantization, prompt and data names."""
    return os.path.join(args.output_dir, f"inference_{os.path.basename(os.path.normpath(args.model_dir))}"
                                         f"{'_8bit' if args.load_in_8bit else '_4bit' if args.load_in_4bit else ''}"
                                         f"{'_fs'+args.few_shot_prompt if args.few_shot_prompt else ''}"
                                         f"{'_chat' if args.chat else ''}"
                                         f"_{os.path.basename(os.path.normpath(args.data_path))}")


def _extract_objects(generation: str, object_map: dict[str, int]) -> list[str]:
    """Return the distinct vocabulary words found in ``generation``.

    The text is lower-cased, commas and periods are treated as separators, and
    the words are split on any whitespace so that a word adjacent to a newline
    or tab (for instance a generation that ends in ``"\\n"``) is still matched.
    Order of first appearance is kept so the output is reproducible.
    """
    words = generation.lower().replace(",", " ").replace(".", " ").split()
    return list(dict.fromkeys(word for word in words if word in object_map))


def _fraction_in(items: list[str], reference: list[str]) -> float:
    """Return the fraction of ``items`` that occur in ``reference``.

    An empty ``items`` list yields NaN (the value ``np.mean`` of an empty list
    would give) without triggering NumPy's empty-slice RuntimeWarning. Note
    that ``json.dumps`` serialises NaN as the bare token ``NaN``.
    """
    if not items:
        return float("nan")
    return np.mean([item in reference for item in items]).item()


def main():
    """Run batched greedy generation over a JSONL dataset and log scored results.

    For each example the model continuation is decoded, the vocabulary objects
    it mentions are extracted, and a JSON record with the prompt, raw and parsed
    answer, gold items, exact-match flag, precision and recall is appended to
    the output file (one JSON object per line).
    """
    args = _parse_args()

    # load object map
    object_map = _load_object_vocabulary(args.object_vocabulary_file)

    model, tokenizer, device = _load_model_and_tokenizer(args)
    # When launched with one process per GPU, each process executes this whole
    # script; letting all of them append would duplicate every record, so only
    # global rank 0 writes results.
    is_writer = not args.distributed or torch.distributed.get_rank() == 0

    os.makedirs(args.output_dir, exist_ok=True)
    out_path = _build_output_path(args)
    print(f"Writing output to {out_path}")

    # Data
    sampled_data = pd.read_json(args.data_path, orient='records', lines=True).to_dict("records")

    # Resolved once: the few-shot examples are looked up by name among the
    # module globals (None when no few-shot prompt was requested).
    few_shot_examples = globals().get(args.few_shot_prompt)

    for bi in tqdm(range(0, len(sampled_data), args.batch_size)):
        batch_data = sampled_data[bi:bi+args.batch_size]
        batch_sent = []
        for dat in batch_data:
            target = dat.get("gold", dat["masked_content"])
            # "the a and the b" -> ["a", "b"]; the literal "nothing" means no items.
            orig_items = [] if target == "nothing" else target.removeprefix("the ").split(" and the ")
            dat["orig_items"] = orig_items
            example_sent = format_sentence(dat, args.few_shot_prompt, few_shot_examples, chat_format=args.chat, tokenizer=tokenizer)
            batch_sent.append(example_sent)

        # Left padding keeps every prompt flush against its continuation.
        batch_input = tokenizer(batch_sent, return_tensors="pt", padding=True, padding_side="left").to(device)
        with torch.no_grad():
            output = model.generate(
                **batch_input,
                max_new_tokens=50,
                stop_strings=[".", "\n", "Box"],
                tokenizer=tokenizer,
                num_beams=1,
                do_sample=False,  # greedy decoding
            )
        # Drop the (padded) prompt tokens so only the continuation is decoded.
        generation = tokenizer.batch_decode(output[:, batch_input.input_ids.shape[1]:], skip_special_tokens=True)

        lines = []
        for i, dat in enumerate(batch_data):
            pred_objs = _extract_objects(generation[i], object_map)
            write_d = {
                'prefix': batch_sent[i],
                'original_answer': generation[i],
                'parsed_original_answer': pred_objs,
                'gold_items': dat['orig_items'],
                'gold_answer': dat["masked_content"],
                'numops': dat["numops"],
                'numops_global': dat["prefix"].count(". ")-1,
                'correct': set(pred_objs) == set(dat['orig_items']),
                'precision': _fraction_in(pred_objs, dat['orig_items']),
                'recall': _fraction_in(dat['orig_items'], pred_objs),
            }
            lines.append(json.dumps(write_d) + "\n")

        if is_writer:
            # Append once per batch; closing the file after each batch keeps
            # finished results on disk if a later batch fails.
            with open(out_path, 'a') as wf:
                wf.writelines(lines)

    # for dat in tqdm(sampled_data):
    #     target = dat.get("gold", dat["masked_content"])
    #     orig_items = [] if target == "nothing" else target.removeprefix("the ").split(" and the ")
    #     dat["orig_items"] = orig_items
    #     example_sent = format_sentence(args, dat)
    #     input_example = tokenizer(example_sent, return_tensors="pt").to(device)
    #
    #     with torch.no_grad():
    #         output = model.generate(
    #             **input_example,
    #             max_new_tokens=50,
    #             stop_strings=[".", "\n", "Box"],
    #             tokenizer=tokenizer,
    #             num_beams=1,
    #             do_sample=False,  # greedy decoding
    #         )
    #     generation = tokenizer.batch_decode(output[:, input_example.input_ids.shape[1]:])[0]
    #
    #     # print(generation)
    #     # print()
    #     pred_objs = list(set([o for o in generation.lower().replace(","," ").replace("."," ").split(" ") if o in object_map]))
    #     write_d = {
    #         'prefix': example_sent,
    #         'original_answer': generation,
    #         'parsed_original_answer': pred_objs,
    #         'gold_items': dat['orig_items'],
    #         'gold_answer': dat["masked_content"],
    #         'numops': dat["numops"],
    #         'numops_global': example_sent.count(". ")-1,
    #         'correct': set(pred_objs) == set(dat['orig_items']),
    #         'precision': np.mean([o in dat['orig_items'] for o in pred_objs]).item(),
    #         'recall': np.mean([o in pred_objs for o in dat['orig_items']]).item(),
    #     }
    #     with open(out_path, 'a') as wf:
    #         wf.write(json.dumps(write_d) + "\n")




# Run the inference pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()