import os
import sys
import csv
import pdb
import json
import pickle
import regex as re
from typing import Dict, Tuple, List, Iterable, Optional, Any, Union
from functools import partial

import numpy as np
import pandas as pd
import torch

# sys.path.append("../src")
from src.dataset import BinaryProbeDataLoader


# The old loader asserts that the number of raw examples is a multiple of this.
NUM_BOXES = 7
# Length of the per-example object indicator vectors built by the old loader.
N_OBJECTS = 100
# Not referenced in the visible part of this file; presumably an upper bound on
# the number of operations in a query -- TODO confirm against callers.
MAX_QUERY_OPS = 3


def load_object_names(object_file_path) -> Tuple[Dict[str, int], List[str]]:
    """Read the object vocabulary from a CSV file.

    The file must have a header row with an ``object_name`` column. It is
    decoded as ``utf-8-sig`` so that a leading byte-order mark does not end up
    in the first header name.

    Args:
        object_file_path: Path of the CSV file to read.

    Returns:
        A pair ``(object_map, object_list)``. ``object_list`` holds the
        ``object_name`` values in file order; ``object_map`` maps each name to
        its zero-based row index. If a name occurs more than once, the map
        keeps the index of its last occurrence.

    Raises:
        KeyError: If the file has no ``object_name`` column.
    """
    # newline="" is what the csv module asks for: it lets the reader see the
    # raw line endings, so newlines inside quoted fields are kept verbatim.
    with open(object_file_path, encoding="utf-8-sig", newline="") as f:
        object_list = [row["object_name"] for row in csv.DictReader(f)]
    object_map = {name: index for index, name in enumerate(object_list)}
    return object_map, object_list



def load_box_data_old(path_to_data, object_to_index_map: Dict[str, int], include_empty:bool=False) -> pd.DataFrame:
    """Load a JSONL box-state dataset and attach object indicator vectors.

    Every non-blank line of ``path_to_data`` is parsed as a JSON object that
    must provide ``"sentence"`` and ``"masked_content"``. An example counts as
    empty when its ``masked_content`` contains ``"is empty"`` or ``"nothing"``.

    Each kept example dict gains three fields before it becomes a row:

    * ``box_contents``: torch vector of length ``N_OBJECTS`` with a 1 at the
      index of every object named in ``masked_content`` (all zeros for an
      empty example).
    * ``num_ops``: list of length ``N_OBJECTS`` repeating
      ``len(sentence_parts) - 2``, where ``sentence_parts`` is the sentence
      split on ``"."``.
    * ``all_mentioned_objects``: numpy vector of length ``N_OBJECTS`` with a 1
      for every ``the <word>`` match in all sentence parts but the last; the
      word ``contents`` is skipped.

    Args:
        path_to_data: Path of the JSONL file.
        object_to_index_map: Maps an object name to its vector index.
        include_empty: Keep empty examples as well when true.

    Returns:
        A DataFrame with one row per kept example, in file order.

    Raises:
        AssertionError: If the number of examples is not a multiple of
            ``NUM_BOXES``.
        KeyError: If an object name is missing from ``object_to_index_map``.
    """
    with open(path_to_data, "r", encoding="UTF-8") as data_f:
        # Skip blank lines (e.g. a trailing newline) rather than failing in json.loads.
        raw_examples = [json.loads(line) for line in data_f if line.strip()]

    assert len(raw_examples) % NUM_BOXES == 0, f"Number of examples is not a multiple of {NUM_BOXES}!"

    rows = []
    for ex in raw_examples:
        masked_content = ex["masked_content"]
        is_empty = "is empty" in masked_content or "nothing" in masked_content
        if is_empty and not include_empty:
            continue

        s_parts = ex["sentence"].strip(".").split(".")

        # Allocate a fresh tensor for every kept example so that all rows hold
        # the same container type (a shared numpy buffer used to leak into the
        # first kept row while later rows received torch tensors).
        box_contents = torch.zeros(N_OBJECTS)
        if not is_empty:
            contents = [_.replace("the ", "") for _ in
                        masked_content.replace("<extra_id_0> ", "").replace("contains ", "").split(" and ")]
            for c in contents:
                box_contents[object_to_index_map[c]] = 1

        ex["box_contents"] = box_contents
        ex["num_ops"] = [len(s_parts) - 2] * N_OBJECTS

        mentioned_objects = np.zeros(N_OBJECTS)  # vector with mentioned objects
        # Only the parts before the final one are scanned; the trailing space
        # lets the pattern match a name at the very end of the joined text.
        o_names = re.findall(r'the ([^ ,.]+) ', " ".join(s_parts[:-1]) + " ")
        for o in o_names:
            if o == "contents":
                continue
            mentioned_objects[object_to_index_map[o]] = 1
        ex["all_mentioned_objects"] = mentioned_objects
        rows.append(ex)

    return pd.DataFrame(rows)


def remove_empty(df: pd.DataFrame) -> pd.DataFrame:
    """Return the rows of ``df`` whose ``masked_content`` describes a non-empty box.

    A row is dropped when its ``masked_content`` contains ``"is empty"`` or
    ``"nothing"``. The original index labels of the kept rows are preserved.
    """
    def _has_contents(text):
        return not ("is empty" in text or "nothing" in text)

    keep_mask = df.masked_content.apply(_has_contents)
    return df[keep_mask]


def load_box_data(dataset_path: str, object_to_index_map: Dict[str, int], num_prior_state:int=-1) -> pd.DataFrame:
    """Load a JSONL dataset and attach the labels built by ``BinaryProbeDataLoader``.

    The features are read from ``dataset_path``, empty boxes are dropped via
    ``remove_empty``, and only rows with ``numops >= -num_prior_state - 1`` are
    kept. The loader's ``examples``, ``num_ops`` and ``mentioned_objects`` are
    stored as the columns ``box_contents``, ``global_numops`` and
    ``all_mentioned_objects``.

    When ``num_prior_state`` is not ``-1`` and the path names a move-content
    dataset, the labels are loaded from the full dataset (the path with
    ``"-subsample-states"`` removed) together with a pickled subset mask stored
    next to it.

    Args:
        dataset_path: Path of the JSONL file holding the dataset features.
        object_to_index_map: Object name to index mapping, passed to the loader.
        num_prior_state: Passed to the loader as ``local_operation_order``;
            also determines the ``numops`` cut-off.

    Returns:
        The filtered DataFrame with the three label columns added.

    Raises:
        AssertionError: If the loader and the filtered frame differ in length.
    """
    lowered_path = dataset_path.lower()
    is_move_content = "movecontent" in lowered_path or "move_content" in lowered_path
    if num_prior_state != -1 and is_move_content:
        full_dataset_path = dataset_path.replace("-subsample-states", "")
        mask_path = full_dataset_path.replace("gpt.jsonl", "subsample-states-mask.p")
        # NOTE: unpickling can execute arbitrary code; only use mask files
        # that come from a trusted source.
        with open(mask_path, "rb") as rep_f:
            subset_mask = pickle.load(rep_f)
    else:
        full_dataset_path = dataset_path
        subset_mask = None

    # load rest of the dataset features (sentence, masked_content, etc)
    df = pd.read_json(dataset_path, lines=True, orient="records")
    df = remove_empty(df)
    # Size of the non-empty subset, taken before the numops filter.
    len_subset = len(df)
    # Explicit copy: columns are added below, and assigning onto a filtered
    # slice would trigger pandas' SettingWithCopyWarning.
    df = df[df.numops >= (-num_prior_state) - 1].copy()

    # load the labels
    # One dummy one-element array per non-empty example is passed as the
    # loader's first argument -- presumably a stand-in for real
    # representations; TODO confirm against BinaryProbeDataLoader.
    ds = BinaryProbeDataLoader([np.zeros(1) for _ in range(len_subset)], full_dataset_path, object_to_index_map,
                          include_empty=False, local_operation_order=num_prior_state,
                          subset_mask=subset_mask)

    assert len(df) == len(ds), f"length of ds ({len(ds)}) and df ({len(df)}) do not match!"
    df["box_contents"] = ds.examples
    df["global_numops"] = ds.num_ops
    df["all_mentioned_objects"] = ds.mentioned_objects

    return df


