import torch
import torch.nn.functional as F
import numpy as np
from scipy.stats import gaussian_kde
from utils.common import calculate_quantile, compute_class_weights
import logging

txt_logger = logging.getLogger("sfda_reg")


def create_interval(
    y_bank,
    class_num_mode='self_peak',
    class_value_mode="self",
    q_range=0.1,
    y_min=None,
    y_max=None,
):
    """
    return:
        Tensor
            class_boundaries: [num_classes+1]
    """
    if y_min is None:
        y_bank_min, _ = torch.min(y_bank, dim=-1)
        y_min = y_bank_min.min().item()
    if y_max is None:
        y_bank_max, _ = torch.max(y_bank, dim=-1)
        y_max = y_bank_max.max().item()

    global_min = y_min
    global_max = y_max

    q_u = 1.0
    q_l = 0.0
    y_middle = 0.5
    match class_num_mode:
        case "self":
            q_u = 0.5 + q_range / 2.0
            q_l = 0.5 - q_range / 2.0
            y_middle = 0.5
        case "self_mean":
            y_mean = torch.mean(y_bank.flatten()).item()
            mean_q = calculate_quantile(y_bank.flatten(), y_mean)
            q_u = mean_q + q_range / 2.0
            q_l = mean_q - q_range / 2.0
            y_middle = mean_q
        case "self_peak":
            y_bank_mean_np = torch.mean(y_bank, dim=1).numpy()
            kde = gaussian_kde(y_bank_mean_np)
            x_grid = np.linspace(y_bank_mean_np.min(),
                                 y_bank_mean_np.max(), 1000)
            density = kde(x_grid)
            peak_index = np.argmax(density)
            y_peak = x_grid[peak_index]
            peak_q = calculate_quantile(y_bank.flatten(), y_peak)
            q_u = peak_q + q_range / 2.0
            q_l = peak_q - q_range / 2.0
            y_middle = peak_q

    if y_middle < 0.48:
        value_mode = 'left'
    elif y_middle > 0.52:
        value_mode = 'right'
    else:
        value_mode = 'middle'

    y_interval = torch.quantile(y_bank, q_u) - torch.quantile(y_bank, q_l)
    num_classes = int((global_max - global_min) / y_interval)
    if num_classes > 100:
        num_classes = int(0.5 * num_classes)

    # generate equal length interval
    total_range = global_max - global_min
    interval_length = float(total_range / num_classes)
    class_boundaries = []
    current = global_min

    while len(class_boundaries) <= num_classes:
        class_boundaries.append(current)
        current += interval_length

        # cover global_max
        if len(class_boundaries) == num_classes and current < global_max:
            current = global_max
    class_boundaries = torch.tensor(class_boundaries)
    class_boundaries = torch.round(class_boundaries * 1000) / 1000

    if class_value_mode in ['left', 'right', 'middle']:
        value_mode = class_value_mode

    y_cls_value = generate_interval_value(class_boundaries, value_mode)

    return class_boundaries, y_cls_value, value_mode, q_u, q_l, interval_length, global_max, global_min


def assign_partial_label(
        y_sample_bank,
        class_boundaries,
        overlap_threshold=0.0,
        mode='y_overlap_den_relax',
        relax_num=2):
    """Build multi-hot (partial) class labels from per-point sample sets.

    args:
        y_sample_bank: tensor whose row ``i`` holds the sampled values of
            point ``i`` (indexed as [N, num_samples]).
        class_boundaries: 1-D tensor [num_classes+1] of class edges.
        overlap_threshold: a class is switched on only when its ratio is
            strictly greater than this value.
        mode: y_overlap_len, y_overlap_den, y_overlap_len_relax,
            y_overlap_den_relax
            - 'len': ratio = length of the overlap between the row's
              [min, max] span and the class, divided by the span width.
            - 'den': ratio = fraction of the row's samples inside the class.
              Samples beyond the outer edges count towards the first / last
              class and the last class includes its upper edge, matching
              ``assign_reg_cls_label``.
            - 'relax': additionally switch on every class from ``relax_num``
              below the first active class to ``relax_num`` above the last
              one, add ``1 / num_classes`` to their ratios and renormalise
              the row to sum to 1.
            A mode without 'y_overlap' returns all-zero outputs.
        relax_num: number of neighbouring classes added on each side.
    return:
        - multi_hot_labels: [N, num_classes]
        - multi_hot_ratio: [N, num_classes]
    """
    dataset_len = y_sample_bank.size(0)
    num_classes = class_boundaries.size(0) - 1
    multi_hot_labels = torch.zeros(
        (dataset_len, num_classes), dtype=torch.float32)
    multi_hot_ratio = torch.zeros(
        (dataset_len, num_classes), dtype=torch.float32)

    if 'y_overlap' not in mode:
        return multi_hot_labels, multi_hot_ratio

    # Convert the edges to Python floats once instead of calling .item()
    # for every (point, class) pair.
    bounds = class_boundaries.tolist()
    last_class = num_classes - 1

    if 'len' in mode:
        row_min = torch.min(y_sample_bank, dim=-1)[0].tolist()
        row_max = torch.max(y_sample_bank, dim=-1)[0].tolist()
        for i in range(dataset_len):
            point_min = row_min[i]
            point_max = row_max[i]
            point_width = point_max - point_min
            for c in range(num_classes):
                overlap_start = max(point_min, bounds[c])
                overlap_end = min(point_max, bounds[c + 1])
                # A zero-width span never satisfies this, so the division
                # below cannot hit point_width == 0.
                if overlap_end > overlap_start:
                    overlap_ratio = (overlap_end - overlap_start) / point_width
                    if overlap_ratio > overlap_threshold:
                        multi_hot_ratio[i, c] = overlap_ratio
                        multi_hot_labels[i, c] = 1.0
    elif 'den' in mode:
        for i in range(dataset_len):
            # Pull out-of-range samples onto the outer edges so they fall
            # into the first / last class instead of being dropped.
            samples = y_sample_bank[i].clamp(min=bounds[0], max=bounds[-1])
            num_samples = samples.size(0)
            for c in range(num_classes):
                in_class = samples >= bounds[c]
                if c == last_class:
                    # last class is closed on the right
                    in_class = in_class & (samples <= bounds[c + 1])
                else:
                    in_class = in_class & (samples < bounds[c + 1])
                in_class_count = in_class.sum().item()
                if in_class_count > 0:
                    density = in_class_count / num_samples
                    multi_hot_ratio[i, c] = density
                    if density > overlap_threshold:
                        multi_hot_labels[i, c] = 1.0

    if 'relax' in mode:
        for i in range(dataset_len):
            active = torch.nonzero(multi_hot_labels[i], as_tuple=True)[0]
            if active.numel() == 0:
                # No class passed the threshold for this point: there is
                # nothing to widen, and indexing active[0] would raise.
                continue
            start = max(int(active[0]) - relax_num, 0)
            stop = min(int(active[-1]) + relax_num, num_classes - 1) + 1
            multi_hot_labels[i, start:stop] = 1.0
            multi_hot_ratio[i, start:stop] += 1.0 / num_classes
            multi_hot_ratio[i] /= multi_hot_ratio[i].sum()

    return multi_hot_labels, multi_hot_ratio


def assign_reg_cls_label(values, boundaries):
    """Map every regression value to the index of the class containing it.

    Class ``i`` covers ``[boundaries[i], boundaries[i + 1])``; the last
    class also includes its upper edge. Values above the last boundary get
    the last class, and values matching no class (e.g. below the first
    boundary) keep the initial label 0.

    Input:
        values (torch.Tensor or list): regression values, any shape
        boundaries (torch.Tensor or list): class edges, length num_classes+1

    return:
        torch.Tensor: long tensor of class indices with the same shape as
            ``values``
    """
    # as_tensor is a no-op for tensors and makes the documented list input work.
    values = torch.as_tensor(values)
    flat_values = values.flatten()

    num_classes = len(boundaries) - 1
    last_class = num_classes - 1
    labels = torch.zeros_like(flat_values, dtype=torch.long)
    labels[flat_values > boundaries[-1]] = last_class

    for cls_idx in range(num_classes):
        lower_bound = boundaries[cls_idx]
        upper_bound = boundaries[cls_idx + 1]
        if cls_idx == last_class:
            below_upper = flat_values <= upper_bound
        else:
            below_upper = flat_values < upper_bound
        labels[(flat_values >= lower_bound) & below_upper] = cls_idx

    # Restore the exact input shape; reshaping by size(1) alone would mix up
    # inputs with more than two dims.
    return labels.view(values.shape)


def generate_interval_value(y_cls_bound, eval_mode="middle"):
    match eval_mode:
        case "middle":
            class_reg_value = (
                y_cls_bound[1:] + y_cls_bound[:-1]).float() / 2.0
        case "right":
            class_reg_value = y_cls_bound[1:]
        case "left":
            class_reg_value = y_cls_bound[:-1]
    return class_reg_value.float()  # tensor


class RegValueHist:
    """Class-based ("histogram") view of a model's regression predictions.

    Built from a dict of prediction tensors: class boundaries are derived
    from the sampled predictions, the predictions and ground truth are
    converted to class labels / partial labels, and the results are stored
    on the CPU as attributes.
    """

    def __init__(
        self,
        t_dict,
        class_num_mode="self_peak",
        class_value_mode="left",
        partial_label_mode="y_overlap_den_relax",
        q_range=0.1,
        y_min=None,
        y_max=None,
    ):
        """Build the histogram state from ``t_dict``.

        args:
            t_dict: dict with the tensors "y_pred_sample", "y_pred" and
                "y_true" (see ``initialized_from_do_dict``).
            class_num_mode, class_value_mode, q_range, y_min, y_max:
                forwarded to ``create_interval``.
            partial_label_mode: forwarded to ``assign_partial_label`` as
                its ``mode``.
        """
        self.initialized_from_do_dict(
            t_dict, class_num_mode, class_value_mode, partial_label_mode, q_range,
            y_min, y_max)

    @staticmethod
    def _align_device(value, device):
        """Move ``value`` to ``device`` when it is a tensor; else return it as is."""
        return value.to(device) if torch.is_tensor(value) else value

    def initialized_from_do_dict(
            self, t_dict, class_num_mode, class_value_mode, partial_label_mode, q_range,
            y_min, y_max):
        """Compute boundaries, labels and class weights and store them on self."""
        y_sample_bank = t_dict["y_pred_sample"]  # reg result with dropout
        y_pred_reg = t_dict["y_pred"].flatten()  # reg prediction wo dropout
        # reg ground truth - for validation and comparison
        y_true_reg = t_dict["y_true"].flatten()

        # generate class boundary
        (y_cls_bound, y_cls_value, used_class_value_mode, q_u, q_l,
         interval_length, global_max, global_min) = create_interval(
            y_sample_bank,
            class_num_mode=class_num_mode,
            class_value_mode=class_value_mode,
            q_range=q_range,
            y_min=y_min,
            y_max=y_max)
        num_classes = y_cls_bound.size(0) - 1
        # widen partial labels by 5% of the class count, at least one class
        partial_label_relax_num = max(1, int(num_classes * 0.05))

        # generate class labels
        y_true_cls = assign_reg_cls_label(y_true_reg, y_cls_bound)

        y_pred_partial_cls, y_pred_partial_cls_ratio = assign_partial_label(
            y_sample_bank,
            y_cls_bound,
            mode=partial_label_mode,
            relax_num=partial_label_relax_num)
        y_pred_cls = assign_reg_cls_label(y_pred_reg, y_cls_bound)

        y_pred_sample_cls = assign_reg_cls_label(y_sample_bank, y_cls_bound)

        # NOTE(review): compute_class_weights comes from utils.common; with
        # return_counts=True it is unpacked as three values here -- confirm
        # their exact meaning against that helper.
        self.point_weight, self.weights, self.cls_counts = compute_class_weights(
            y_pred_cls, num_classes, return_counts=True)
        self.do_point_weight, self.do_weights, self.do_cls_counts = compute_class_weights(
            y_pred_sample_cls, num_classes, return_counts=True)

        # ground true info
        self.y_true_reg = y_true_reg.cpu()
        self.y_true_cls = y_true_cls.cpu()
        # prediction info
        self.y_pred_reg = y_pred_reg.cpu()
        self.y_pred_sample = y_sample_bank.cpu()
        self.y_pred_cls = y_pred_cls.cpu()
        self.y_pred_partial_cls = y_pred_partial_cls.cpu()
        self.y_pred_partial_cls_ratio = y_pred_partial_cls_ratio.cpu()
        # hist info
        self.y_cls_bound = y_cls_bound.cpu()
        self.y_cls_value = y_cls_value.cpu()
        self.y_cls_value_mode = used_class_value_mode
        self.num_classes = num_classes
        self.partial_label_relax_num = partial_label_relax_num
        self.partial_label_mode = partial_label_mode
        self.q_range = q_range
        self.q_u = q_u
        self.q_l = q_l
        self.interval_length = interval_length
        self.do_max = global_max
        self.do_min = global_min
        self.class_num_mode = class_num_mode
        self.class_value_mode = class_value_mode

    def get_y_hist_info(self):
        """Return the histogram configuration and value ranges as a dict."""
        y_hist_info = {
            "q_range": self.q_range,
            "q_u": self.q_u,
            "q_l": self.q_l,
            "num_classes": self.num_classes,
            "interval_length": self.interval_length,
            "sample_max": self.do_max,
            "sample_min": self.do_min,
            "true_max": self.y_true_reg.max().item(),
            "true_min": self.y_true_reg.min().item(),
            "class_num_mode": self.class_num_mode,
            "class_value_mode": self.class_value_mode,
            "used_class_value_mode": self.y_cls_value_mode,
            "partial_label_relax_num": self.partial_label_relax_num,
            "partial_label_mode": self.partial_label_mode,
        }
        return y_hist_info

    def update_y(self, y_dict, mv_coef=0.5):
        """Blend new predictions into the stored ones with a moving average.

        args:
            y_dict: dict with 'y_pred_cls', 'y_pred_reg' and
                'y_pred_cls_prob'.
            mv_coef: weight of the stored values; the new values get
                ``1 - mv_coef``.

        ``y_pred_reg`` and ``y_pred_partial_cls_ratio`` are averaged,
        ``y_pred_cls`` is replaced by the new labels.
        """
        last_y_pred_reg = self.y_pred_reg
        last_y_pred_partial_cls_ratio = self.y_pred_partial_cls_ratio

        # The stored state was moved to the CPU at construction; bring the
        # incoming tensors to the same device so the blend below cannot fail
        # on a device mismatch and the stored state stays in one place.
        new_y_pred_reg = self._align_device(
            y_dict['y_pred_reg'], last_y_pred_reg.device)
        new_y_pred_partial_cls_ratio = self._align_device(
            y_dict['y_pred_cls_prob'], last_y_pred_partial_cls_ratio.device)
        new_y_pred_cls = self._align_device(
            y_dict['y_pred_cls'], last_y_pred_reg.device)

        # update self attribute
        self.y_pred_reg = (
            last_y_pred_reg * mv_coef + new_y_pred_reg * (1 - mv_coef))
        # NOTE(review): F.normalize defaults to the L2 norm, so the blended
        # ratio ends up with unit L2 norm rather than summing to 1 like the
        # ratios produced by the 'relax' modes of assign_partial_label --
        # confirm this is intended (p=1 would keep a sum-to-1 vector).
        blended_ratio = (
            last_y_pred_partial_cls_ratio * mv_coef
            + new_y_pred_partial_cls_ratio * (1 - mv_coef))
        self.y_pred_partial_cls_ratio = F.normalize(blended_ratio, dim=-1)
        self.y_pred_cls = new_y_pred_cls
        txt_logger.info("Y Hist Updated")

