
import json
import logging 
from collections import defaultdict
from copy import deepcopy


import numpy as np
from tqdm import tqdm

from module.model.verification.attack import AttackWrapper

def handle_np_types(obj):
    """Convert NumPy values that ``json`` cannot serialize into built-in types.

    Suitable as the ``default=`` hook of ``json.dump`` / ``json.dumps``:
    NumPy integers (any width, including ``np.int64`` / ``np.longlong``)
    become ``int``, NumPy floats become ``float``, NumPy booleans become
    ``bool`` and arrays become (nested) lists.

    Raises:
        TypeError: for any other object, which is what ``json`` expects from
            a ``default`` hook that cannot handle the value.
    """
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj)} is not serializable")


def _extract_bounding_boxes(tree, bounds):
    """Flatten a decision tree into one axis-aligned box per non-empty leaf region.

    ``tree`` is a node dict. A leaf carries ``"leaf"`` (its value). An internal
    node carries ``"split"`` (feature key), ``"split_condition"`` (threshold),
    ``"yes"`` / ``"no"`` (child node ids) and ``"children"`` (child dicts with a
    ``"nodeid"``). The ``"yes"`` child tightens the upper bound (``x <= t``) and
    the ``"no"`` child tightens the lower bound (``x > t``), so every interval is
    half-open: ``(lower, upper]``.

    Args:
        tree: root node dict of the (sub)tree to flatten.
        bounds: mapping ``feature -> np.array([lower, upper])``. Every feature
            the tree splits on must be resolvable, e.g. a ``defaultdict``
            yielding ``[-inf, inf]``. The arrays are mutated in place during
            the traversal and restored before returning.

    Returns:
        List of ``(deepcopy(bounds), leaf_value)`` tuples in depth-first,
        "yes"-before-"no" order.
    """
    if "leaf" in tree:
        return [(deepcopy(bounds), tree["leaf"])]

    feature = tree["split"]
    threshold = tree["split_condition"]
    interval = bounds[feature]  # mutated in place below, then restored
    lower, upper = interval[0], interval[1]
    # Only tighten a bound when the threshold lies inside the current interval;
    # otherwise the existing bound is already at least as specific.
    tightens = lower <= threshold <= upper
    leaves = []

    def _descend(child_id):
        # Recurse into every child whose id matches (normally exactly one).
        for child in tree["children"]:
            if child["nodeid"] == child_id:
                leaves.extend(_extract_bounding_boxes(child, bounds))

    # "yes" branch: (lower, min(upper, threshold)]. It is empty when
    # threshold <= lower because the lower bound is exclusive. Recursing
    # there would emit a degenerate box that contains no point at all.
    if threshold > lower:
        if tightens:
            interval[1] = threshold
        _descend(tree["yes"])
        interval[1] = upper

    # "no" branch: (max(lower, threshold), upper], empty when threshold >= upper.
    if threshold < upper:
        if tightens:
            interval[0] = threshold
        _descend(tree["no"])
        interval[0] = lower

    return leaves

class DecisionTreeAttackWrapper(AttackWrapper):
    """Closest-adversarial-example attack on a single binary decision tree.

    At construction the tree is flattened into one axis-aligned box per leaf
    (via ``_extract_bounding_boxes``). A sample lies in a box when
    ``lower < x <= upper`` holds on every feature. A leaf predicts ``True``
    (positive class) when its value is ``> 0``.
    """

    def __init__(self, json_model, n_classes):
        """Validate the model and precompute its leaf boxes.

        Args:
            json_model: sequence holding exactly one tree as a node dict (the
                format read by ``_extract_bounding_boxes``).
            n_classes: number of classes; only ``2`` is supported.

        Raises:
            ValueError: if ``json_model`` does not hold exactly one tree, or
                if ``n_classes != 2``.
        """
        if len(json_model) != 1:
            raise ValueError("This attack can only be used with single decision trees")
        self.json_model = json_model

        if n_classes != 2:
            raise ValueError("Currently only binary classification is supported")
        self.n_classes = n_classes

        # Every feature starts unconstrained: (-inf, inf].
        bounds = defaultdict(lambda: np.array([-np.inf, np.inf]))
        # List of (feature -> [lower, upper] mapping, leaf value) pairs.
        self.leaves = _extract_bounding_boxes(json_model[0], bounds)

    @staticmethod
    def _in_box(sample, bounding_box):
        """Return True if ``lower < sample <= upper`` on every feature of the box."""
        return not (
            np.any(sample <= bounding_box[:, 0]) or np.any(sample > bounding_box[:, 1])
        )

    def get_bounding_box(self, X):
        """Materialise the leaf boxes as dense arrays sized for ``X``.

        Args:
            X: 2-D sample array; only ``X.shape[1]`` (number of features) is read.

        Returns:
            ``(boxes, predictions)``. ``boxes`` has shape
            ``(n_leaves, n_features, 2)`` and holds ``[lower, upper]`` per
            feature, with ``[-inf, inf]`` for features the leaf does not
            constrain. ``predictions[k]`` is ``True`` iff leaf ``k``'s value
            is ``> 0``.
        """
        bound_dicts, leaf_values = zip(*self.leaves)
        predictions = [value > 0 for value in leaf_values]
        n_features = X.shape[1]

        bounding_boxes = []
        for bound_dict in bound_dicts:
            bounding_box = np.tile(np.array([-np.inf, np.inf]), (n_features, 1))
            for feature, bound in bound_dict.items():
                # assumes feature keys are integer column indices — TODO confirm
                bounding_box[feature] = bound
            bounding_boxes.append(bounding_box)

        return np.array(bounding_boxes), predictions

    def assign_bounding_boxes(self, X):
        """Return, for every row of ``X``, the index of the leaf box containing it.

        Raises:
            AssertionError: if a sample lies in more than one box or in none.
                The check is raised explicitly so it still fires under ``-O``.
        """
        logger = logging.getLogger("DecisionTreeAttackWrapper.assign_bounding_boxes")
        bounding_boxes, _ = self.get_bounding_box(X)

        ret = np.full(X.shape[0], -1)
        for i, sample in enumerate(X):
            for box_idx, bounding_box in enumerate(bounding_boxes):
                if not self._in_box(sample, bounding_box):
                    continue
                if ret[i] != -1:
                    logger.error("Sample is in multiple bounding boxes")
                    logger.error("Sample: %s", sample)
                    logger.error("Bounding box: %s", bounding_box)
                    raise AssertionError("Sample is in multiple bounding boxes")
                ret[i] = box_idx

        unassigned = ret == -1
        if np.any(unassigned):
            logger.error("Sample is not in any bounding boxes")
            # Report the offending rows, not whichever sample the loop ended on.
            logger.error("Sample: %s", X[unassigned])
            raise AssertionError("Sample is not in any bounding boxes")
        return ret

    def adversarial_examples(self, X, y, order, options=None):
        """Compute the closest adversarial example for every row of ``X``.

        Each sample is clipped into every leaf box whose prediction differs
        from its label. The clipped point nearest to the sample, measured with
        ``np.linalg.norm(..., ord=order)``, is kept.

        Args:
            X: 2-D array of samples.
            y: labels, compared with ``!=`` against the boolean leaf
                predictions (presumably 0/1 or bool — confirm with callers).
            order: ``ord`` argument passed to ``np.linalg.norm``.
            options: optional dict with these keys:
                ``"epsilon"``: if given, the perturbation is shrunk to at most
                    ``epsilon`` (in ``order``-norm), and samples with no
                    adversarial leaf are returned unchanged. If omitted, the
                    closest adversarial point is returned unrestricted, or a
                    NaN row when none exists.
                ``"disable_progress_bar"``: disables the tqdm progress bar.
                ``"logging"``: logs leaf-to-leaf move statistics.

        Returns:
            ``np.ndarray`` with one (adversarial or fallback) row per sample.

        Raises:
            AssertionError: if a sample lies in more than one leaf box.
        """
        options = {} if options is None else options
        logger = logging.getLogger("DecisionTreeAttackWrapper.adversarial_examples")
        bounding_boxes, predictions = self.get_bounding_box(X)

        disable = options.get("disable_progress_bar", False)
        epsilon = options.get("epsilon", None)
        # Clip strictly inside each box. Lower bounds are exclusive, so a point
        # sitting exactly on a boundary would belong to the neighbouring leaf.
        shifter = 1e-7

        X_adv = []
        count_box_moves = defaultdict(int)  # (source leaf, target leaf) -> count
        count_leaves = defaultdict(int)  # source leaf -> number of samples
        for sample, label in tqdm(zip(X, y), total=X.shape[0], disable=disable):
            best_distance = np.inf
            best_adv_example = None
            curr_box = None
            adv_box = None
            for idx, bounding_box in enumerate(bounding_boxes):
                if self._in_box(sample, bounding_box):
                    if curr_box is not None:
                        logger.error("Sample is in multiple bounding boxes")
                        logger.error("Sample: %s", sample)
                        logger.error("Bounding box: %s", bounding_box)
                        logger.error("Bounding curr: %s", curr_box)
                        raise AssertionError("Sample is in multiple bounding boxes")
                    curr_box = idx

                if predictions[idx] != label:
                    adv_example = np.clip(
                        sample, bounding_box[:, 0] + shifter, bounding_box[:, 1] - shifter
                    )
                    distance = np.linalg.norm(adv_example - sample, ord=order)
                    if distance < best_distance:
                        best_distance = distance
                        best_adv_example = adv_example
                        adv_box = idx

            if best_distance == np.inf:
                # No leaf predicts a different class.
                adv_box = None
                X_adv_sample = np.full(len(sample), np.nan) if epsilon is None else sample
            elif epsilon is None:
                X_adv_sample = best_adv_example
            else:
                # Only compare against epsilon when one was given; comparing a
                # norm with None raises TypeError.
                direction = best_adv_example - sample
                norm = np.linalg.norm(direction, ord=order)
                if norm > epsilon:
                    direction = direction / (norm + shifter) * epsilon
                X_adv_sample = sample + direction
            X_adv.append(X_adv_sample)

            count_leaves[curr_box] += 1
            if epsilon is not None and best_distance < epsilon:
                count_box_moves[(curr_box, adv_box)] += 1

        if options.get("logging", False):
            self._log_box_statistics(logger, count_box_moves, count_leaves)

        return np.array(X_adv)

    @staticmethod
    def _log_box_statistics(logger, count_box_moves, count_leaves):
        """Log per-(source, target) move counts and fractions, and per-leaf counts."""

        def sort_index(box):
            # The source box is None for a sample outside every box. Sort it
            # first instead of letting ``None < int`` raise TypeError.
            return -1 if box is None else box

        moves = "".join(
            f"\tFrom {src} to {dst}: {value}, {value / count_leaves[src]:.2f}\n"
            for (src, dst), value in sorted(
                count_box_moves.items(),
                key=lambda item: (sort_index(item[0][0]), sort_index(item[0][1])),
            )
        )
        leaf_counts = "".join(
            f"\tLeaf {leaf}: {value}\n"
            for leaf, value in sorted(
                count_leaves.items(), key=lambda item: sort_index(item[0])
            )
        )
        logger.info("Bounding box moves: %s", "\n" + moves)
        logger.info("Leaf counts: %s", "\n" + leaf_counts)
        
