import multiprocessing
from pathlib import Path
import json
 

class MpCounter:
    """Integer counter backed by a ``multiprocessing.Value``.

    The shared value is stored in ``self.val`` (typecode ``"i"``, initial
    value 0). Increments are performed while holding the value's own lock.
    """

    def __init__(self):
        self.val = multiprocessing.Value("i", 0)

    def increment(self, n=1):
        """Add ``n`` (default 1) to the counter under the value's lock."""
        shared = self.val
        with shared.get_lock():
            shared.value += n

    @property
    def value(self):
        """Current counter value (read without explicitly taking the lock)."""
        shared = self.val
        return shared.value


def yield_chunks(data, size):
    """Lazily yield consecutive slices of ``data``, each at most ``size`` long.

    ``data`` is first materialised into a list, so any iterable is accepted
    and every yielded chunk is a list. The final chunk may be shorter.
    """
    items = list(data)
    starts = range(0, len(items), size)
    yield from (items[start:start + size] for start in starts)


def ensure_folder(folder: Path, parents=False):
    """Create ``folder`` if it does not exist yet.

    Args:
        folder: Directory path to create.
        parents: Forwarded to ``Path.mkdir``; when true, missing parent
            directories are created as well.

    Raises:
        FileNotFoundError: If a parent directory is missing and ``parents``
            is false (raised by ``Path.mkdir``).

    The directory is created directly instead of checking ``exists()`` first:
    a separate check leaves a window in which another process can create the
    folder, which would make the subsequent ``mkdir`` fail.
    """
    try:
        folder.mkdir(parents=parents)
    except FileExistsError:
        # Something already lives at this path; as before, leave it untouched.
        pass


def pick_if_present(d: dict, key_in_dict, key_new=None):
    """Return a one-item dict with ``d[key_in_dict]``, or ``{}`` if absent.

    Args:
        d: Source dictionary.
        key_in_dict: Key to look up in ``d``.
        key_new: Key to use in the returned dict. When ``None`` (the
            default) the original ``key_in_dict`` is reused.

    Returns:
        ``{key: d[key_in_dict]}`` when ``key_in_dict`` is in ``d``,
        otherwise an empty dict.
    """
    if key_in_dict not in d:
        return {}
    # Compare against None explicitly so falsy-but-valid keys such as 0 or ""
    # are honoured as the new key instead of being silently ignored.
    target_key = key_in_dict if key_new is None else key_new
    return {target_key: d[key_in_dict]}


class AverageMeterSet(object):
    """A named collection of meters kept in the ``meters`` dict.

    Meters are created on demand by :meth:`update`. The ``values``,
    ``averages``, ``sums`` and ``counts`` methods each return a dict mapping
    ``format_string.format(name)`` to the corresponding meter attribute.
    """

    def __init__(self, meters=None):
        # A falsy argument (None or an empty mapping) yields a fresh dict.
        self.meters = meters or {}

    def __getitem__(self, key):
        """Return the meter for ``key``.

        For an unknown key a temporary meter that has been updated once with
        0 is returned; it is not stored in ``meters``.
        """
        if key in self.meters:
            return self.meters[key]
        placeholder = AverageMeter()
        placeholder.update(0)
        return placeholder

    def update(self, name, value, n=1):
        """Forward ``(value, n)`` to meter ``name``, creating it if needed."""
        registry = self.meters
        if name not in registry:
            registry[name] = AverageMeter()
        registry[name].update(value, n)

    def reset(self):
        """Reset every stored meter."""
        for tracked in self.meters.values():
            tracked.reset()

    def _collect(self, attribute, format_string):
        """Map each formatted meter name to the meter's ``attribute``."""
        return {
            format_string.format(name): getattr(tracked, attribute)
            for name, tracked in self.meters.items()
        }

    def values(self, format_string="{}"):
        """Return the ``val`` attribute of every meter."""
        return self._collect("val", format_string)

    def averages(self, format_string="{}"):
        """Return the ``avg`` attribute of every meter."""
        return self._collect("avg", format_string)

    def sums(self, format_string="{}"):
        """Return the ``sum`` attribute of every meter."""
        return self._collect("sum", format_string)

    def counts(self, format_string="{}"):
        """Return the ``count`` attribute of every meter."""
        return self._collect("count", format_string)


class AverageMeter(object):
    """Computes and stores the average and current value.

    Attributes (all reset to 0 by :meth:`reset`):
        val: The most recent value passed to :meth:`update`.
        sum: Weighted sum of all values seen so far.
        count: Total weight (number of samples) seen so far.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times.

        Args:
            val: The new value (e.g. a per-sample mean over a batch).
            n: Weight of the value, i.e. how many samples it represents.
        """
        self.val = val
        # Weight the value by n so that sum and count stay consistent and
        # avg is a true weighted mean.
        self.sum += val * n
        self.count += n
        # Guard against a zero total weight (e.g. update(x, n=0) on a fresh meter).
        self.avg = self.sum / self.count if self.count else 0

    def __format__(self, fmt):
        """Format as ``"<val> (<avg>)"``, applying ``fmt`` to both numbers."""
        return f"{self.val:{fmt}} ({self.avg:{fmt}})"

def tokenize_each_demonstration(demonstration_list, tokenizer, dataset_name=None, prefix=None):
    """Strip punctuation from each demonstration pair and tokenize both sides.

    Every element of ``demonstration_list`` is indexed at positions 0 and 1.
    Both strings have all ASCII punctuation symbols removed and surrounding
    whitespace stripped; if ``prefix`` is given, ``prefix[0]`` / ``prefix[1]``
    are prepended to the cleaned first / second string respectively.

    Args:
        demonstration_list: Mutable sequence of two-element items. NOTE: it is
            updated in place - each entry is replaced by the tuple of cleaned
            (and optionally prefixed) strings.
        tokenizer: Callable applied to each cleaned string.
        dataset_name: Accepted for interface compatibility; not used here.
        prefix: Optional pair of strings prepended after cleaning.

    Returns:
        A list of ``(tokenizer(first), tokenizer(second))`` tuples, one per
        demonstration, in input order.
    """
    # One translation table that deletes every special character in a single
    # pass, replacing the previous chain of per-symbol ``str.replace`` calls.
    deletion_table = str.maketrans("", "", "~!@#$%^&*()_+`-={}[]|\\:;\"'<>,.?/")

    def strip_special_characters(input_string):
        return input_string.translate(deletion_table).strip()

    tokenized_demonstration_list = []
    for exp_id, demonstration in enumerate(demonstration_list):
        first = strip_special_characters(demonstration[0])
        second = strip_special_characters(demonstration[1])
        if prefix is not None:
            first = prefix[0] + first
            second = prefix[1] + second
        # Preserve the historical side effect: callers may read the cleaned
        # text back from the list they passed in.
        demonstration_list[exp_id] = (first, second)
        tokenized_demonstration_list.append((tokenizer(first), tokenizer(second)))
    return tokenized_demonstration_list

def apply_template(all_demos_reflect, tokenizer):
    """Build (direct-answer, step-by-step) prompt pairs from each question.

    For every entry, its ``'question'`` field is prefixed once with the
    direct-answering instruction and once with the step-by-step instruction.
    ``tokenizer`` is accepted but not used.

    Returns:
        A list of ``(negative_prompt, positive_prompt)`` tuples in input order.
    """
    direct_instruction = 'Solve the problem with direct answering. '
    stepwise_instruction = 'Solve the problem with step-by-step detailed reasoning. '

    # Entries are fetched by integer index, matching how the container has
    # always been accessed.
    questions = [
        all_demos_reflect[idx]['question']
        for idx in range(len(all_demos_reflect))
    ]
    return [
        (direct_instruction + question, stepwise_instruction + question)
        for question in questions
    ]

def apply_template_correctness(all_demos_reflect, tokenizer):
    """Build (plain, knowledge-retrieval) prompt pairs with their responses.

    For every entry the positive text is the retrieval instruction, the
    ``'question'``, a blank line, and the part of ``'answer'`` before the
    first ``'####'``; the negative text uses the plain instruction and the
    part of ``'solution'`` before the first ``'####'``. ``tokenizer`` is
    accepted but not used.

    Returns:
        A list of ``(negative_text, positive_text)`` tuples in input order.
    """
    retrieval_instruction = 'Answer the question with knowledge retrieval. '
    plain_instruction = 'Answer the question. '

    def _before_marker(text):
        # Keep only what precedes the first '####', trimmed of whitespace.
        return text.partition('####')[0].strip()

    pairs = []
    for idx in range(len(all_demos_reflect)):
        record = all_demos_reflect[idx]
        question = record['question']
        positive = retrieval_instruction + question + '\n\n' + _before_marker(record['answer'])
        negative = plain_instruction + question + '\n\n' + _before_marker(record['solution'])
        pairs.append((negative, positive))
    return pairs


class CompactJSONEncoder(json.JSONEncoder):
    """A JSON Encoder that puts small containers on single lines.

    Lists, tuples and dicts whose elements are all primitives, and which stay
    within ``MAX_ITEMS`` / ``MAX_WIDTH``, are rendered on one line; every
    other container is rendered one element per line, indented by ``indent``
    (defaults to 4 when not given).

    Dict keys are always emitted as JSON strings: ``str`` keys are quoted
    honouring ``ensure_ascii``; ``int``, ``float``, ``bool`` and ``None`` keys
    are converted to their JSON literal and quoted. Any other key type raises
    ``TypeError`` unless ``skipkeys`` is set, in which case the entry is
    dropped.
    """

    CONTAINER_TYPES = (list, tuple, dict)
    """Container datatypes include primitives or other containers."""

    MAX_WIDTH = 1000
    """Maximum width of a container that might be put on a single line."""

    MAX_ITEMS = 60
    """Maximum number of items in container that might be put on single line."""

    # Key types that can be represented as a JSON object key.
    _KEY_TYPES = (str, int, float, bool, type(None))

    def __init__(self, *args, **kwargs):
        # using this class without indentation is pointless
        if kwargs.get("indent") is None:
            kwargs["indent"] = 4
        super().__init__(*args, **kwargs)
        # Current nesting depth of multi-line containers being written.
        self.indentation_level = 0

    def encode(self, o):
        """Encode JSON object *o* with respect to single line lists."""
        if isinstance(o, (list, tuple)):
            return self._encode_list(o)
        if isinstance(o, dict):
            return self._encode_object(o)
        if isinstance(o, float):  # Use scientific notation for floats
            # NOTE: "g" keeps 6 significant digits by default, so long floats
            # are rounded in the output.
            return format(o, "g")
        return json.dumps(
            o,
            skipkeys=self.skipkeys,
            ensure_ascii=self.ensure_ascii,
            check_circular=self.check_circular,
            allow_nan=self.allow_nan,
            sort_keys=self.sort_keys,
            indent=self.indent,
            separators=(self.item_separator, self.key_separator),
            default=self.default if hasattr(self, "default") else None,
        )

    def _encode_list(self, o):
        """Encode a list/tuple, on one line when it is small and flat."""
        if self._put_on_single_line(o):
            return "[" + ", ".join(self.encode(el) for el in o) + "]"
        self.indentation_level += 1
        try:
            output = [self.indent_str + self.encode(el) for el in o]
        finally:
            # Restore the depth even if an element fails to encode, so the
            # encoder instance stays usable.
            self.indentation_level -= 1
        return "[\n" + ",\n".join(output) + "\n" + self.indent_str + "]"

    def _encode_object(self, o):
        """Encode a dict, on one line when it is small and flat."""
        if not o:
            return "{}"
        items = list(o.items())
        if self.skipkeys:
            items = [(k, v) for k, v in items if isinstance(k, self._KEY_TYPES)]
            if not items:
                return "{}"
        if self._put_on_single_line(o):
            return "{ " + ", ".join(f"{self._encode_key(k)}: {self.encode(el)}" for k, el in items) + " }"
        self.indentation_level += 1
        try:
            output = [f"{self.indent_str}{self._encode_key(k)}: {self.encode(v)}" for k, v in items]
        finally:
            self.indentation_level -= 1
        return "{\n" + ",\n".join(output) + "\n" + self.indent_str + "}"

    def _encode_key(self, key):
        """Return *key* as a quoted JSON string suitable for an object key.

        Raises:
            TypeError: If *key* is not a str, int, float, bool or None.
        """
        if isinstance(key, str):
            return json.dumps(key, ensure_ascii=self.ensure_ascii)
        if isinstance(key, self._KEY_TYPES):
            # Render the JSON literal first (1 -> 1, True -> true, None -> null),
            # then quote it, because JSON object keys must be strings.
            literal = json.dumps(key, allow_nan=self.allow_nan)
            return json.dumps(literal)
        raise TypeError(
            f"keys must be str, int, float, bool or None, not {key.__class__.__name__}"
        )

    def iterencode(self, o, **kwargs):
        """Required to also work with `json.dump`.

        Returns the complete encoded string; extra keyword arguments are
        ignored.
        """
        return self.encode(o)

    def _put_on_single_line(self, o):
        """Whether container *o* is flat and within MAX_ITEMS and MAX_WIDTH."""
        return self._primitives_only(o) and len(o) <= self.MAX_ITEMS and len(str(o)) - 2 <= self.MAX_WIDTH

    def _primitives_only(self, o):
        """Whether *o* (list, tuple or dict) holds no nested containers."""
        if isinstance(o, (list, tuple)):
            return not any(isinstance(el, self.CONTAINER_TYPES) for el in o)
        elif isinstance(o, dict):
            return not any(isinstance(el, self.CONTAINER_TYPES) for el in o.values())
        return False

    @property
    def indent_str(self) -> str:
        """Whitespace prefix for the current ``indentation_level``."""
        if isinstance(self.indent, int):
            return " " * (self.indentation_level * self.indent)
        elif isinstance(self.indent, str):
            return self.indentation_level * self.indent
        else:
            raise ValueError(f"indent must either be of type int or str (is: {type(self.indent)})")


if __name__ == "__main__":
    # Small demo: print the numbers 0..11, then the same numbers split into
    # chunks of at most 7 elements.
    numbers = list(range(12))
    print(numbers)
    for chunk in yield_chunks(numbers, 7):
        print(chunk)
