import random
from itertools import product

import pandas as pd
import torch
from torch.utils.data import Dataset

# Pick the compute device once at import time: CUDA when available, else CPU.
# The dataset classes in this file move the tensors they return onto it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE: this print runs as a side effect of importing the module.
print(f"Using device: {device}")


class DepthDataset(Dataset):

    def __init__(self, json_data, max_root_evid_vars, train=True):
        """Index every (variable, value) pair and flatten ``json_data``.

        Args:
            json_data: Mapping from record id to a record dict. The
                ``"props"`` entry of the first record supplies
                ``num_query_vars`` and ``num_evid_vars``.
            max_root_evid_vars: Width to which root evidence is padded.
            train: Stored on the instance; not otherwise used here.
        """
        super().__init__()
        self.train = train
        self.max_root_evid_vars = max_root_evid_vars

        first_record = list(json_data.values())[0]
        self.data_props = first_record["props"]
        self.num_query_vars = self.data_props["num_query_vars"]
        self.num_evid_vars = self.data_props["num_evid_vars"]
        self.total_vars = self.num_evid_vars + self.num_query_vars

        # Every variable takes the values 0 and 1, so the pair (var, val)
        # ends up at index 2 * var + val.
        assignments = list(product(range(self.total_vars), range(2)))
        self.var_values_to_idx = {
            pair: position for position, pair in enumerate(assignments)
        }
        self.idx_to_var_values = dict(enumerate(assignments))

        # The padding index sits one past the last real index and maps to None.
        self.padding_idx = len(assignments)
        self.idx_to_var_values[self.padding_idx] = None

        self.flattened_data = self._get_flattened_data(json_data)

    def _decode_list(self, l):
        """Decode a comma-separated string of integers into a list.

        An empty (or whitespace-only) string, which is what joining an empty
        list with commas produces, decodes to ``[]`` instead of raising
        ``ValueError`` from ``int("")``.

        Args:
            l: String such as ``"0,1,1"``.

        Returns:
            list[int]: The decoded integers in their original order.

        Raises:
            ValueError: If a non-empty token is not an integer literal.
        """
        if not l.strip():
            return []
        return [int(token) for token in l.split(",")]

    def _decode_dict(self, dict_):
        """Decode ``"var=val,var=val"`` into an ordered list of int pairs.

        An empty (or whitespace-only) string decodes to ``[]``; previously it
        failed while unpacking the missing ``=`` separator.

        Args:
            dict_: String such as ``"0=1,3=0"``.

        Returns:
            list[tuple[int, int]]: ``(var, val)`` pairs in input order.

        Raises:
            ValueError: If a non-empty item is not of the form ``int=int``.
        """
        if not dict_.strip():
            return []
        pairs = []
        for item in dict_.split(","):
            var, val = item.split("=")
            pairs.append((int(var), int(val)))
        return pairs

    def _get_ns_target_cols(self, parent_stats, child_stats, assigned_score):
        """Derive ``[status_change, objective_value, solving_time]`` targets.

        Outcome by (parent status, child status), where "solved" means
        ``"optimal"`` or ``"bestsollimit"``:

        * solved -> solved: only the child's solving time is reported.
        * solved -> anything else: ``status_change`` is 0.
        * timelimit -> solved: ``status_change`` is 1 (status improved).
        * timelimit -> timelimit: child's objective value and solving time.
        * every other combination: all three entries are ``None``.

        Args:
            parent_stats: Mapping with ``"status"``, ``"solving_time"`` and
                ``"objective_value"`` keys.
            child_stats: Mapping with the same three keys.
            assigned_score: Accepted but not used.

        Returns:
            list: ``[status_change, objective_value, solving_time]``.
        """
        solved_statuses = ("optimal", "bestsollimit")

        # All three keys are read from both mappings up front, so a missing
        # key raises KeyError regardless of which branch is taken.
        parent_status = parent_stats["status"]
        child_status = child_stats["status"]
        parent_stats["solving_time"]
        child_time = child_stats["solving_time"]
        parent_stats["objective_value"]
        child_objective = child_stats["objective_value"]

        child_solved = child_status in solved_statuses

        if parent_status in solved_statuses:
            if child_solved:
                return [None, None, child_time]
            return [0, None, None]

        if parent_status == "timelimit":
            if child_solved:
                return [1, None, None]
            if child_status == "timelimit":
                return [None, child_objective, child_time]

        return [None, None, None]

    def _get_rows_ns(self, json_id, root_var_val_idxs, root_record, choice_records):
        """Build two table rows for every tried choice record.

        For each record the last ``(var, val)`` pair of its decoded evidence is
        taken as the choice. One row carries that choice with the record's
        statistics and ``ml_target`` 1; a second row carries the flipped value
        ``(var, 1 - val)`` with no statistics and ``ml_target`` -1.

        Args:
            json_id: Identifier placed in the first column of every row.
            root_var_val_idxs: Padded list of root-evidence indices.
            root_record: Not read by this method; kept so existing callers
                keep working.
            choice_records: Iterable of dicts with an ``"evidence"`` string and
                a ``"stats"`` dict holding ``"objective_value"``,
                ``"num_nodes"`` and ``"solving_time"``.

        Returns:
            list[list]: Rows shaped ``[json_id, *root idxs, choice idx,
            objective_value, num_nodes, solving_time, ml_target]``.
        """
        rows = []
        for record in choice_records:
            # The last decoded pair is the assignment chosen in this record.
            chosen_var, chosen_val = self._decode_dict(record["evidence"])[-1]
            chosen_idx = self.var_values_to_idx[(chosen_var, chosen_val)]

            stats = record["stats"]
            targets = [
                stats["objective_value"],
                stats["num_nodes"],
                stats["solving_time"],
            ]
            rows.append([json_id] + root_var_val_idxs + [chosen_idx] + targets + [1])

            # Add the opposite value of the same variable as a negative row
            # that carries no statistics.
            flipped_idx = self.var_values_to_idx[(chosen_var, 1 - chosen_val)]
            rows.append(
                [json_id] + root_var_val_idxs + [flipped_idx, None, None, None, -1]
            )
        return rows

    def _get_rows_ovno(
        self, json_id, root_var_val_idxs, optimal_assignments, query_vars
    ):
        """Build a positive and a negative row for each query variable.

        The value listed for a variable in ``optimal_assignments`` gets
        ``ml_target`` 1 and the opposite value gets -1. The three statistic
        columns are ``None`` in both rows.

        Args:
            json_id: Identifier placed in the first column of every row.
            root_var_val_idxs: Padded list of root-evidence indices.
            optimal_assignments: Sequence indexed by variable id giving that
                variable's value (0 or 1).
            query_vars: Iterable of variable ids to emit rows for.

        Returns:
            list[list]: Two rows per query variable, positive row first.
        """
        rows = []
        for var in query_vars:
            best_val = optimal_assignments[var]
            labelled_values = ((best_val, 1), (1 - best_val, -1))
            rows.extend(
                [json_id]
                + root_var_val_idxs
                + [self.var_values_to_idx[(var, val)], None, None, None, label]
                for val, label in labelled_values
            )
        return rows

    def _get_flattened_data(self, json_data):
        """Flatten all JSON records into one table with a row per candidate.

        For each record the root evidence is converted to indices and padded
        to ``max_root_evid_vars``. Records whose strategy is
        ``"non_sequential"`` and that have tried choices with at least one
        better choice are expanded with ``_get_rows_ns``; every other record
        is expanded with ``_get_rows_ovno``.

        Side effects: sets ``common_evid_row_id``, ``root_evidence_cols``,
        ``choice_evid_column`` and ``target_columns`` on the instance. A
        record that yields no rows is not registered in
        ``common_evid_row_id``, because ``__getitem__`` looks rows up by id.

        Args:
            json_data: Mapping from record id to record dict.

        Returns:
            pandas.DataFrame: Indexed by ``common_evid_id`` with the root
            evidence columns, the choice column and the target columns.

        Raises:
            ValueError: If a record has more root evidence than
                ``max_root_evid_vars``, or if ``choices_tried`` disagrees with
                the number of choice records. These were bare ``assert``
                statements, which are skipped under ``python -O``.
        """
        rows = []
        self.common_evid_row_id = []
        for json_id, record in json_data.items():
            root_record = record["root_record"]
            root_var_val_idxs = [
                self.var_values_to_idx[item]
                for item in self._decode_dict(root_record["evidence"])
            ]
            if len(root_var_val_idxs) > self.max_root_evid_vars:
                raise ValueError(
                    f"record {json_id!r} has {len(root_var_val_idxs)} root "
                    f"evidence variables, more than max_root_evid_vars="
                    f"{self.max_root_evid_vars}"
                )
            missing = self.max_root_evid_vars - len(root_var_val_idxs)
            root_var_val_idxs.extend([self.padding_idx] * missing)

            use_choice_records = False
            choice_records = []
            if record["props"]["data_collection_strategy"] == "non_sequential":
                depth_record = record["depths_record"]
                choice_records = depth_record["choices"]
                choices_tried = depth_record["choices_tried"]
                if len(choice_records) != choices_tried:
                    raise ValueError(
                        f"record {json_id!r} reports choices_tried="
                        f"{choices_tried} but holds {len(choice_records)} "
                        f"choice records"
                    )
                # Without any tried or any better choice, fall back to the
                # optimal-vs-opposite rows below.
                use_choice_records = (
                    choices_tried != 0 and depth_record["better_choices_ct"] != 0
                )

            if use_choice_records:
                json_rows = self._get_rows_ns(
                    json_id, root_var_val_idxs, root_record, choice_records
                )
            else:
                json_rows = self._get_rows_ovno(
                    json_id,
                    root_var_val_idxs,
                    self._decode_list(root_record["uninterrupted_assignments"]),
                    self._decode_list(record["query_vars"]),
                )

            if not json_rows:
                continue
            self.common_evid_row_id.append(json_id)
            rows.extend(json_rows)

        self.root_evidence_cols = [f"e_{idx}" for idx in range(self.max_root_evid_vars)]
        self.choice_evid_column = "e_choice"

        self.target_columns = [
            "objective_value",
            "num_nodes",
            "solving_time",
            "ml_target",
        ]
        columns = (
            ["common_evid_id"]
            + self.root_evidence_cols
            + [self.choice_evid_column]
            + self.target_columns
        )
        df = pd.DataFrame(rows, columns=columns)
        df.set_index("common_evid_id", inplace=True)
        return df

    def __len__(self):
        """Return the number of record ids registered while flattening."""
        registered_ids = self.common_evid_row_id
        return len(registered_ids)

    def _get_probability_as_scores(self, choices, solving_time, epsilon=1e-5):
        """Turn solving times into a distribution that favours faster choices.

        Each choice receives ``1 / (time + epsilon)`` normalised so the scored
        entries sum to one.

        Args:
            choices: Choice indices into the (var, val) index space.
            solving_time: One solving time per entry of ``choices``.
            epsilon: Added to every time so a zero time does not divide by zero.

        Returns:
            tuple: ``(output_prob, mask)``, both with one entry per (var, val)
            index. ``output_prob`` is float32 and zero for unscored choices;
            ``mask`` is an int tensor with 1 at scored choices.
        """
        num_choices = len(self.var_values_to_idx)
        output_prob = torch.zeros(num_choices, dtype=torch.float)
        mask = torch.zeros(num_choices, dtype=torch.int)
        if len(choices) == 0:
            return output_prob, mask

        # Pin float32 so float64 inputs (e.g. NumPy values) match the dtype of
        # ``output_prob`` in the indexed assignment below.
        times = torch.tensor(solving_time, dtype=torch.float)
        reciprocal = 1 / (times + epsilon)
        output_prob[choices] = reciprocal / reciprocal.sum()
        mask[choices] = 1
        return output_prob, mask

    def _get_scores(self, df):
        """Score choices by a weighted sum of min-max normalised dimensions.

        Each dimension in ``weights`` is normalised over the rows where it is
        not null, multiplied by its weight, and accumulated per choice. The
        accumulated scores are divided by their total so they sum to one.

        Changes from the previous implementation:

        * a dimension whose column is missing from ``df`` is skipped instead
          of raising ``KeyError``;
        * the ``ascending`` flag (smaller is better) drives the inverted
          normalisation instead of a hard-coded column name;
        * a choice that appears in several dimensions gets the sum of its
          scores; before, the indexed assignment kept only one of them.

        Args:
            df: DataFrame holding the choice column named by
                ``self.choice_evid_column`` plus any of the weighted columns.

        Returns:
            tuple: ``(output_prob, mask)``, both with one entry per (var, val)
            index. ``mask`` is 1 for every choice that received a score.
        """
        num_choices = len(self.var_values_to_idx)
        output_prob = torch.zeros(num_choices, dtype=torch.float)
        mask = torch.zeros(num_choices, dtype=torch.int)

        if len(df) == 0:
            return output_prob, mask

        # dimension -> (weight, ascending); ascending=True means a smaller
        # value is better, so the normalisation is inverted.
        weights = {
            "status_change": (3, False),
            "objective_value": (1, False),
            "solving_time": (2, True),
        }

        totals = {}
        for dim, (weight, ascending) in weights.items():
            if dim not in df.columns:
                continue
            rows = df[df[dim].notnull()]
            if len(rows) == 0:
                continue

            values = rows[dim].astype(float)
            low = values.min()
            high = values.max()

            # Avoid division by zero if all values are the same
            if high == low:
                normalized = [1.0] * len(values)
            elif ascending:
                normalized = list((high - values) / (high - low))
            else:
                normalized = list((values - low) / (high - low))

            for choice, norm in zip(rows[self.choice_evid_column], normalized):
                key = int(choice)
                totals[key] = totals.get(key, 0.0) + norm * weight

        if totals:
            grand_total = sum(totals.values())
            choices = list(totals)
            output_prob[choices] = torch.tensor(
                [totals[choice] / grand_total for choice in choices],
                dtype=torch.float,
            )
            mask[choices] = 1

        return output_prob, mask

    def _get_scores_v2(self, df):
        """Score choices with a normalised joint log-probability.

        Three ``log_softmax`` terms are added: one over the raw objective
        values (larger is favoured), one over the reciprocal node counts and
        one over the reciprocal solving times (smaller is favoured). The sum
        is renormalised with ``logsumexp``.

        NOTE: despite the name ``output_prob``, the scored entries are
        log-probabilities. Unscored entries stay at 0.0, so consumers have to
        use ``mask`` to tell them apart.

        Columns are converted to float32 explicitly because ``log_softmax``
        does not accept integer tensors, which integer-valued columns would
        otherwise produce.

        Args:
            df: DataFrame with ``objective_value``, ``num_nodes``,
                ``solving_time`` and the choice column named by
                ``self.choice_evid_column``.

        Returns:
            tuple: ``(output_prob, mask)``, both with one entry per (var, val)
            index.
        """
        num_choices = len(self.var_values_to_idx)
        output_prob = torch.zeros(num_choices, dtype=torch.float)
        mask = torch.zeros(num_choices, dtype=torch.int)

        if len(df) == 0:
            return output_prob, mask

        def column(name):
            return torch.tensor(list(df[name]), dtype=torch.float)

        log_softmax = torch.nn.functional.log_softmax

        log_p_objective = log_softmax(column("objective_value"), dim=0)
        # The 1e-5 offset keeps a zero count or time from dividing by zero.
        log_p_nodes = log_softmax(1 / (column("num_nodes") + 1e-5), dim=0)
        log_p_time = log_softmax(1 / (column("solving_time") + 1e-5), dim=0)

        joint = log_p_objective + log_p_nodes + log_p_time
        joint_norm = joint - torch.logsumexp(joint, dim=0)

        choices = [int(choice) for choice in df[self.choice_evid_column]]
        output_prob[choices] = joint_norm
        mask[choices] = 1
        return output_prob, mask

    def __getitem__(self, index):
        """Assemble the tensors for the record at position ``index``.

        Args:
            index: Position into ``self.common_evid_row_id``.

        Returns:
            dict: Tensors moved to the module-level ``device``:

            * ``evidence``: int tensor of padded root-evidence indices.
            * ``padding_mask``: 1 where ``evidence`` is not the padding index.
            * ``all_choices``: ``arange`` over every (var, val) index.
            * ``scoring_label`` / ``scoring_mask``: result of
              ``_get_scores_v2`` on the rows that have at least one statistic.
            * ``scoring_indices``: choice indices sorted by descending
              ``scoring_label``, with unscored choices pushed back by a
              ``-1e6`` sentinel.
            * ``ml_class_label``: 1.0 where the row's ``ml_target`` is 1.
            * ``ml_class_mask``: 1 where the row's ``ml_target`` is non-zero.
        """
        common_evid_id = self.common_evid_row_id[index]

        # Select with a list so the result is always a DataFrame. A scalar
        # label collapses a single-row group into a Series, which would break
        # ``iloc[0]`` and the column selections below.
        all_rows = self.flattened_data.loc[[common_evid_id]]
        first_row = all_rows.iloc[0]
        evidence = list(first_row[self.root_evidence_cols].astype(int))

        ml_choices = list(all_rows[self.choice_evid_column])
        ml_target = list(all_rows["ml_target"].astype(float))
        rows_not_na = all_rows[
            (all_rows["objective_value"].notna())
            | (all_rows["num_nodes"].notna())
            | (all_rows["solving_time"].notna())
        ]

        evidence_tensor = torch.tensor(evidence, dtype=torch.int).to(device)
        padding_mask = (evidence_tensor != self.padding_idx).type(torch.int).to(device)
        ranking_cols = self.target_columns + [self.choice_evid_column]
        scoring_label, scoring_mask = self._get_scores_v2(rows_not_na[ranking_cols])

        # Sort on a copy so the sentinel never leaks into the returned labels.
        scoring_label_copy = scoring_label.clone()
        scoring_label_copy[scoring_mask == 0] = -1e6
        _, scoring_indices = torch.sort(scoring_label_copy, dim=0, descending=True)

        # ml_target is +1 / -1 per listed choice; 0 marks "no label".
        is_good_score = torch.zeros(len(self.var_values_to_idx), dtype=torch.float)
        is_good_score[ml_choices] = torch.tensor(ml_target)
        ml_class_mask = torch.where(is_good_score == 0, 0, 1).type(torch.int)
        ml_class_label = torch.where(is_good_score == 1, 1, 0).type(torch.float)

        return {
            "evidence": evidence_tensor,
            "padding_mask": padding_mask,
            "all_choices": torch.arange(len(self.var_values_to_idx)).to(device),
            "scoring_label": scoring_label.to(device),
            "scoring_indices": scoring_indices.to(device),
            "scoring_mask": scoring_mask.to(device),
            "ml_class_label": ml_class_label.to(device),
            "ml_class_mask": ml_class_mask.to(device),
        }


class MemoryEfficientDepthDataset(Dataset):

    def __init__(self, json_data, max_root_evid_vars, train=True):
        """Keep the raw JSON records and index every (variable, value) pair.

        No table is built here: the records are stored as given and decoded
        one at a time in ``__getitem__``.

        Args:
            json_data: Mapping from record id to a record dict. The
                ``"props"`` entry of the first record supplies
                ``num_query_vars`` and ``num_evid_vars``.
            max_root_evid_vars: Width to which root evidence is padded.
            train: Stored on the instance; not otherwise used here.
        """
        super().__init__()
        self.train = train
        self.json_ids = list(json_data.keys())
        self.jsons = json_data
        self.max_root_evid_vars = max_root_evid_vars

        first_record = list(json_data.values())[0]
        self.data_props = first_record["props"]
        self.num_query_vars = self.data_props["num_query_vars"]
        self.num_evid_vars = self.data_props["num_evid_vars"]
        self.total_vars = self.num_evid_vars + self.num_query_vars

        # Fill both directions of the mapping in one pass; every variable
        # takes the values 0 and 1.
        self.var_values_to_idx = {}
        self.idx_to_var_values = {}
        for idx, pair in enumerate(product(range(self.total_vars), range(2))):
            self.var_values_to_idx[pair] = idx
            self.idx_to_var_values[idx] = pair

        # The padding index sits one past the last real index and maps to None.
        self.padding_idx = len(self.idx_to_var_values)
        self.idx_to_var_values[self.padding_idx] = None

    def _decode_list(self, l):
        """Decode a comma-separated string of integers into a list.

        An empty (or whitespace-only) string, which is what joining an empty
        list with commas produces, decodes to ``[]`` instead of raising
        ``ValueError`` from ``int("")``.

        Args:
            l: String such as ``"0,1,1"``.

        Returns:
            list[int]: The decoded integers in their original order.

        Raises:
            ValueError: If a non-empty token is not an integer literal.
        """
        if not l.strip():
            return []
        return [int(token) for token in l.split(",")]

    def _decode_dict(self, dict_):
        """Decode ``"var=val,var=val"`` into an ordered list of int pairs.

        An empty (or whitespace-only) string decodes to ``[]``; previously it
        failed while unpacking the missing ``=`` separator.

        Args:
            dict_: String such as ``"0=1,3=0"``.

        Returns:
            list[tuple[int, int]]: ``(var, val)`` pairs in input order.

        Raises:
            ValueError: If a non-empty item is not of the form ``int=int``.
        """
        if not dict_.strip():
            return []
        pairs = []
        for item in dict_.split(","):
            var, val = item.split("=")
            pairs.append((int(var), int(val)))
        return pairs

    def __len__(self):
        """Return the number of JSON records held by the dataset."""
        record_ids = self.json_ids
        return len(record_ids)

    def get_ml_labels(self, json_value):
        """Build classification labels from the record's listed assignments.

        For every query variable, the value listed in
        ``root_record["uninterrupted_assignments"]`` is marked +1 and the
        opposite value -1; all other (var, val) entries stay 0.

        Args:
            json_value: Record dict with ``"root_record"`` (holding
                ``"uninterrupted_assignments"``) and ``"query_vars"``.

        Returns:
            tuple: ``(ml_class_label, ml_class_mask)``. The label is a float
            tensor with 1.0 at the +1 entries; the mask is an int tensor with
            1 at every entry that was marked +1 or -1.
        """
        best_values = self._decode_list(
            json_value["root_record"]["uninterrupted_assignments"]
        )
        candidates = self._decode_list(json_value["query_vars"])

        picked = []
        signs = []
        for var in candidates:
            value = best_values[var]
            picked += [
                self.var_values_to_idx[(var, value)],
                self.var_values_to_idx[(var, 1 - value)],
            ]
            signs += [1.0, -1.0]

        signed = torch.zeros(len(self.var_values_to_idx), dtype=torch.float)
        signed[picked] = torch.tensor(signs)

        ml_class_label = (signed == 1).type(torch.float)
        ml_class_mask = (signed != 0).type(torch.int)
        return ml_class_label, ml_class_mask

    def get_scoring_labels(self, json_value):
        """Score the tried choices of a record with a joint log-probability.

        Only records whose strategy is ``"non_sequential"``, with tried
        choices and at least one better choice, are scored; every other
        record yields all-zero tensors.

        Three ``log_softmax`` terms are added: one over the raw objective
        values (larger is favoured), one over the reciprocal node counts and
        one over the reciprocal solving times (smaller is favoured). The sum
        is renormalised with ``logsumexp``.

        NOTE: despite the name ``output_prob``, the scored entries are
        log-probabilities. Unscored entries stay at 0.0, so consumers have to
        use ``mask`` to tell them apart.

        Statistics are converted to float32 explicitly because ``log_softmax``
        does not accept integer tensors, which integer JSON values would
        otherwise produce.

        Args:
            json_value: Record dict with ``"props"`` and, for non-sequential
                records, ``"depths_record"``.

        Returns:
            tuple: ``(output_prob, mask)``, both with one entry per (var, val)
            index.

        Raises:
            ValueError: If ``choices_tried`` disagrees with the number of
                choice records. This was a bare ``assert``, which is skipped
                under ``python -O``.
        """
        num_choices = len(self.var_values_to_idx)
        output_prob = torch.zeros(num_choices, dtype=torch.float)
        mask = torch.zeros(num_choices, dtype=torch.int)
        if json_value["props"]["data_collection_strategy"] != "non_sequential":
            return output_prob, mask

        depth_record = json_value["depths_record"]
        choice_records = depth_record["choices"]
        choices_tried = depth_record["choices_tried"]
        if len(choice_records) != choices_tried:
            raise ValueError(
                f"choices_tried={choices_tried} but the record holds "
                f"{len(choice_records)} choice records"
            )
        if choices_tried == 0 or depth_record["better_choices_ct"] == 0:
            return output_prob, mask

        scoring_choices = []
        objectives = []
        num_nodes = []
        times = []
        for record in choice_records:
            # The last decoded pair is the assignment chosen in this record.
            choice_evidence = self._decode_dict(record["evidence"])[-1]
            scoring_choices.append(self.var_values_to_idx[choice_evidence])
            stats = record["stats"]
            objectives.append(stats["objective_value"])
            num_nodes.append(stats["num_nodes"])
            times.append(stats["solving_time"])

        def as_float(values):
            return torch.tensor(values, dtype=torch.float)

        log_softmax = torch.nn.functional.log_softmax

        log_p_objective = log_softmax(as_float(objectives), dim=0)
        # The 1e-5 offset keeps a zero count or time from dividing by zero.
        log_p_nodes = log_softmax(1 / (as_float(num_nodes) + 1e-5), dim=0)
        log_p_time = log_softmax(1 / (as_float(times) + 1e-5), dim=0)

        joint = log_p_objective + log_p_nodes + log_p_time
        joint_norm = joint - torch.logsumexp(joint, dim=0)

        output_prob[scoring_choices] = joint_norm
        mask[scoring_choices] = 1
        return output_prob, mask

    def __getitem__(self, index):
        """Decode the record at position ``index`` into model tensors.

        Args:
            index: Position into ``self.json_ids``.

        Returns:
            dict: Tensors moved to the module-level ``device``:

            * ``evidence``: int tensor of padded root-evidence indices.
            * ``padding_mask``: 1 where ``evidence`` is not the padding index.
            * ``all_choices``: ``arange`` over every (var, val) index.
            * ``scoring_label`` / ``scoring_mask``: result of
              ``get_scoring_labels``.
            * ``scoring_indices``: choice indices sorted by descending
              ``scoring_label``, with unscored choices pushed back by a
              ``-1e6`` sentinel.
            * ``ml_class_label`` / ``ml_class_mask``: result of
              ``get_ml_labels``.

        Raises:
            ValueError: If the record has more root evidence than
                ``max_root_evid_vars``. This was a bare ``assert``, which is
                skipped under ``python -O``.
        """
        json_id = self.json_ids[index]
        json_value = self.jsons[json_id]

        root_evidence = self._decode_dict(json_value["root_record"]["evidence"])
        root_var_val_idxs = [self.var_values_to_idx[item] for item in root_evidence]
        if len(root_var_val_idxs) > self.max_root_evid_vars:
            raise ValueError(
                f"record {json_id!r} has {len(root_var_val_idxs)} root "
                f"evidence variables, more than max_root_evid_vars="
                f"{self.max_root_evid_vars}"
            )
        missing = self.max_root_evid_vars - len(root_var_val_idxs)
        root_var_val_idxs.extend([self.padding_idx] * missing)

        evidence_tensor = torch.tensor(root_var_val_idxs, dtype=torch.int).to(device)
        padding_mask = (evidence_tensor != self.padding_idx).type(torch.int).to(device)

        ml_class_label, ml_class_mask = self.get_ml_labels(json_value)
        scoring_label, scoring_mask = self.get_scoring_labels(json_value)

        # Sort on a copy so the sentinel never leaks into the returned labels.
        scoring_label_copy = scoring_label.clone()
        scoring_label_copy[scoring_mask == 0] = -1e6
        _, scoring_indices = torch.sort(scoring_label_copy, dim=0, descending=True)

        return {
            "evidence": evidence_tensor,
            "padding_mask": padding_mask,
            "all_choices": torch.arange(len(self.var_values_to_idx)).to(device),
            "scoring_label": scoring_label.to(device),
            "scoring_indices": scoring_indices.to(device),
            "scoring_mask": scoring_mask.to(device),
            "ml_class_label": ml_class_label.to(device),
            "ml_class_mask": ml_class_mask.to(device),
        }
