from matplotlib import pyplot as plt
import pandas as pd
import torch
import numpy as np
import re
import random


def fix_seed(seed: int = 42):
    """Seed every random number generator used by this project.

    Seeds, in order: Python's ``random`` module, NumPy's global RNG, the
    PyTorch CPU generator, the current CUDA device and all CUDA devices.

    Note: cuDNN determinism flags are NOT changed here. For bit-exact cuDNN
    runs one would additionally set
    ``torch.backends.cudnn.deterministic = True`` and
    ``torch.backends.cudnn.benchmark = False``.
    """
    import random
    import numpy as np
    import torch

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def dimensionless_jerk(movement, fs):
    """Return the (negated) dimensionless jerk of a 1-D movement signal.

    The value computed is ``-(D**3 / P**2) * sum(j**2) * dt``, where
    ``dt = 1 / fs``, ``D = len(movement) * dt`` is the signal duration,
    ``P = max(|movement|)`` is its peak magnitude and ``j`` is the second
    finite difference of ``movement`` divided by ``dt**2``.
    Calling ``j`` the "jerk" assumes ``movement`` is a velocity profile
    (NOTE(review): callers in this file also pass acceleration / gyro
    channels — confirm that is intended).

    Parameters:
    - movement: 1-D array-like of samples.
    - fs: sampling frequency in Hz.

    Returns:
    - A non-positive scalar (NumPy float). It is non-finite, with a NumPy
      RuntimeWarning, when the signal is identically zero (``P == 0``).

    Raises:
    - ValueError if ``movement`` is empty.
    """
    # Convert once to a float array so every operation below is vectorised
    # and integer inputs cannot overflow when squared.
    movement = np.asarray(movement, dtype=float)
    dt = 1.0 / fs
    movement_peak = np.max(np.abs(movement))
    movement_dur = len(movement) * dt
    jerk = np.diff(movement, 2) / dt**2
    # Normalisation that makes the measure independent of duration and amplitude.
    scale = movement_dur**3 / movement_peak**2
    return -scale * np.sum(jerk**2) * dt


def log_dimensionless_jerk(movement, fs):
    """Return ``-ln(|dimensionless_jerk(movement, fs)|)``.

    Taking the log of the magnitude compresses the very wide range of
    dimensionless-jerk values; the sign flip makes the result larger for
    signals with a smaller dimensionless jerk.
    """
    magnitude = abs(dimensionless_jerk(movement, fs))
    return -np.log(magnitude)


def njs(acc_data, gyro_data, fs=50):
    """Mean log dimensionless jerk over three accelerometer and three gyro axes.

    ``log_dimensionless_jerk`` is evaluated on columns 0, 1 and 2 of each
    input. The six per-axis scores are averaged with equal weight.

    Parameters:
    - acc_data: 2-D array-like with at least 3 columns. Presumably (T, 3)
      accelerometer samples — TODO confirm with callers.
    - gyro_data: 2-D array-like with at least 3 columns. Presumably (T, 3)
      gyroscope samples — TODO confirm with callers.
    - fs: sampling frequency in Hz. The default of 50 matches the rate that
      was previously hard-coded.

    Returns:
    - The mean of the six per-axis scores as a NumPy float.
    """
    # Collect the six scalars directly rather than broadcasting them into arrays
    # shaped like the input. Those arrays inherit the input dtype, so integer
    # sensor data would truncate the float scores.
    per_axis_scores = []
    for sensor in (acc_data, gyro_data):
        for axis in range(3):
            per_axis_scores.append(log_dimensionless_jerk(sensor[:, axis], fs=fs))
    return np.mean(per_axis_scores)


def window_wise_stability(data):
    """Score the smoothness (``njs``) of every window in a batch.

    Parameters:
    - data: rank-3 array/tensor indexed as (window, time, channel). Channels
      0-2 are passed to ``njs`` as the accelerometer input and channels 3-5
      as the gyroscope input.

    Returns:
    - float64 NumPy array of shape (n_windows, 1) with one score per window.
    """
    assert len(data.shape) == 3
    scores = [njs(window[:, :3], window[:, 3:6]) for window in data]
    return np.array(scores, dtype=float).reshape(-1, 1)


def get_rom_from_label(ts_label: torch.Tensor, dataset):
    """Return each sample's range of motion, normalised by the largest dataset angle.

    Every index in ``ts_label`` (flattened) is looked up in
    ``dataset.label_dict``. The integer following ``R`` in the resulting name
    is extracted and divided by ``max(dataset.angles)``.

    Returns:
    - float tensor of shape (ts_label.numel(), 1).
    """
    rom_pattern = re.compile(r"R(\d+)")
    flat_indices = ts_label.view(-1).cpu().numpy()
    names = [dataset.label_dict[idx] for idx in flat_indices]
    rom = torch.Tensor([int(rom_pattern.search(name).group(1)) for name in names])
    max_angle = np.max(dataset.angles)
    return torch.unsqueeze(rom / max_angle, -1)


def get_subject_from_label(ts_label: torch.Tensor, dataset):
    """Return the subject id (the integer after ``S``) from each sample's label name.

    Label names are obtained by looking up every flattened index of
    ``ts_label`` in ``dataset.label_dict``.

    Returns:
    - float tensor of shape (ts_label.numel(), 1).
    """
    subject_pattern = re.compile(r"S(\d+)")
    names = [dataset.label_dict[idx] for idx in ts_label.view(-1).cpu().numpy()]
    subject_ids = [int(subject_pattern.search(name).group(1)) for name in names]
    return torch.unsqueeze(torch.Tensor(subject_ids), -1)


def get_exercise_from_label(ts_label: torch.Tensor, dataset):
    """Return a binary 0/1 exercise label for each sample.

    The integer after ``E`` is parsed from each label name (looked up in
    ``dataset.label_dict``). The ids are then remapped by
    ``map_binary_labels_auto``: the smaller id becomes 0 and the larger
    becomes 1. That helper raises ValueError unless exactly two distinct ids
    are present in this batch.

    Returns:
    - float tensor of shape (ts_label.numel(), 1).
    """
    exercise_pattern = re.compile(r"E(\d+)")
    names = [dataset.label_dict[idx] for idx in ts_label.view(-1).cpu().numpy()]
    raw_ids = torch.Tensor([int(exercise_pattern.search(name).group(1)) for name in names])
    binary_ids = map_binary_labels_auto(raw_ids)
    return torch.unsqueeze(binary_ids, -1)


def get_exercise_from_label_v2(ts_label: torch.Tensor, dataset):
    """Return the raw exercise id (the integer after ``E``) for each sample.

    Unlike ``get_exercise_from_label``, the ids are not remapped to 0/1.

    Returns:
    - float tensor of shape (ts_label.numel(), 1).
    """
    exercise_pattern = re.compile(r"E(\d+)")
    names = [dataset.label_dict[idx] for idx in ts_label.view(-1).cpu().numpy()]
    exercise_ids = [int(exercise_pattern.search(name).group(1)) for name in names]
    return torch.unsqueeze(torch.Tensor(exercise_ids), -1)


def get_rom_from_label_v2(ts_label: torch.Tensor, dataset):
    """Return the raw range-of-motion value (the integer after ``R``) for each sample.

    Unlike ``get_rom_from_label``, the values are not normalised.

    Returns:
    - float tensor of shape (ts_label.numel(), 1).
    """
    rom_pattern = re.compile(r"R(\d+)")
    names = [dataset.label_dict[idx] for idx in ts_label.view(-1).cpu().numpy()]
    rom_values = [int(rom_pattern.search(name).group(1)) for name in names]
    return torch.unsqueeze(torch.Tensor(rom_values), -1)


def map_labels(tensor, label_mapping):
    """
    Re-index the labels in ``tensor`` according to ``label_mapping``.

    Parameters:
    - tensor: PyTorch tensor of original label values.
    - label_mapping: sequence where position ``i`` holds the original label
                     that should become standard label ``i``.

    Returns:
    - A new tensor (the input is left untouched) in which every element equal
      to ``label_mapping[i]`` is replaced by ``i``. Values absent from
      ``label_mapping`` are kept as-is. If an original label occurs more than
      once in ``label_mapping``, its last position wins.
    """
    original_to_standard = {}
    for standard_label, original_label in enumerate(label_mapping):
        original_to_standard[original_label] = standard_label

    remapped = tensor.clone()
    for original_label, standard_label in original_to_standard.items():
        # Mask against the untouched input so values written by an earlier
        # iteration are never re-mapped by a later one.
        remapped[tensor == original_label] = standard_label
    return remapped


def split_list(input_list, train_ratio, seed=None):
    """
    Randomly split ``input_list`` into two lists according to ``train_ratio``.

    Parameters:
    - input_list: List of elements to be split (not modified; ``.copy()`` is
      shuffled instead).
    - train_ratio: Fraction in [0, 1] of elements placed in the first list. The
      split index is ``int(len(input_list) * train_ratio)``, i.e. rounded down.
    - seed: Optional seed. When given (including ``0``), the shuffle is
      reproducible and is drawn from a private ``random.Random(seed)``. That
      gives the same permutation as ``random.seed(seed)`` followed by
      ``random.shuffle``, but without resetting the global ``random`` state.
      When ``None``, the global ``random`` module is used.

    Returns:
    - first_list: The first ``split_index`` shuffled elements.
    - second_list: The remaining shuffled elements.

    Raises:
    - ValueError: if ``train_ratio`` is outside [0, 1].
    """
    if not 0 <= train_ratio <= 1:
        raise ValueError("train_ratio must be between 0 and 1")

    shuffled_list = input_list.copy()
    # `is None`, not truthiness: seed=0 is a valid, reproducible seed.
    rng = random if seed is None else random.Random(seed)
    rng.shuffle(shuffled_list)

    split_index = int(len(shuffled_list) * train_ratio)
    return shuffled_list[:split_index], shuffled_list[split_index:]


def map_binary_labels_auto(labels):
    """Remap a tensor holding exactly two distinct values to 0 and 1.

    ``torch.unique`` returns its values sorted, so the smaller value becomes 0
    and the larger becomes 1. A new tensor is returned; ``labels`` is not
    modified.

    Raises:
    - ValueError: if ``labels`` does not contain exactly two distinct values.
    """
    distinct = torch.unique(labels)
    if distinct.numel() != 2:
        raise ValueError("The input tensor must contain exactly two unique labels.")

    low_value, high_value = distinct.tolist()
    remapped = labels.clone()
    remapped[labels == low_value] = 0
    remapped[labels == high_value] = 1
    return remapped


class sliding_windows(torch.nn.Module):
    """Cut a sequence (and optional per-step labels) into fixed-size windows.

    Windows have ``width`` steps and start every ``step`` steps. Trailing
    steps that do not fill a whole window are dropped, which is the behaviour
    of ``Tensor.unfold``.
    """

    def __init__(self, width, step):
        # https://stackoverflow.com/questions/53972159/how-does-pytorchs-fold-and-unfold-work
        super().__init__()
        self.width = width
        self.step = step

    def forward(self, input_time_series, labels=None):
        """Window ``input_time_series`` along dim -2 and ``labels`` along dim 0.

        Parameters:
        - input_time_series: tensor of shape (..., L, C), windowed along dim -2.
        - labels: optional tensor windowed along dim 0 (e.g. shape (L,)), or None.

        Returns:
        - (windows, label_windows). ``windows`` has shape
          (..., n_windows, width, C). ``label_windows`` has shape
          (n_windows, width) for 1-D labels, or is None when ``labels`` is None.
        """
        # unfold appends the window dimension last: (..., n, C, width).
        # Swap the last two dims to get (..., n, width, C).
        unfolded = input_time_series.unfold(-2, size=self.width, step=self.step)
        input_transformed = torch.swapaxes(unfolded, -2, -1)
        # Identity check: comparing a tensor with `!=` is elementwise/ambiguous.
        if labels is not None:
            labels_transformed = labels.unfold(0, self.width, self.step)
        else:
            labels_transformed = None
        return input_transformed, labels_transformed

    def get_num_sliding_windows(self, total_length):
        """Number of windows ``forward`` yields for a sequence of ``total_length`` steps.

        Equals ``floor((total_length - width) / step) + 1``, the count produced
        by ``Tensor.unfold``, or 0 when the sequence is shorter than one window.
        """
        return int(max(0, (total_length - self.width) // self.step + 1))


def get_confusion_matrix(
    file_dir,
    prediction_targets=None,
    show=True,
    title="Baseline",
    show_y=True,
    hinge=False,
    vmin=80,
    vmax=100,
    cmap="YlGnBu",
    fontsize=8,
    save_title=None,
):
    """Build a lower-triangular (task x evaluated target x seed) score tensor from a CSV log.

    For every seed and every task ``j`` in ``prediction_targets`` (in order),
    the row with the lowest ``val_loss_pred`` among rows whose ``cur_task``
    equals that task is selected. From that row, the score columns of target
    ``j`` and all earlier targets ``k <= j`` are read.

    Parameters:
    - file_dir: path or buffer accepted by ``pandas.read_csv``. It must contain
      the columns ``seed``, ``cur_task`` and ``val_loss_pred``, plus one column
      per target (``<target>_hinge`` when ``hinge`` is True).
    - prediction_targets: ordered task names. Defaults to
      ["gender", "smile", "glasses", "head_pose"].
    - show: if True, draw a seaborn heatmap of the seed-averaged scores x 100.
    - title, show_y, vmin, vmax, cmap, fontsize: heatmap appearance.
    - hinge: read ``<target>_hinge`` columns instead of ``<target>``.
    - save_title: if not None (and ``show`` is True), also save the figure to
      ``./figures/<save_title>.png``.

    Returns:
    - res: float ndarray of shape (n_targets, n_targets, n_seeds).
      ``res[j, k, s]`` is the target-``k`` score of the best row for task ``j``
      and seed ``s``. It is NaN for ``k > j``.
    """
    # None sentinel instead of a mutable list default.
    if prediction_targets is None:
        prediction_targets = ["gender", "smile", "glasses", "head_pose"]
    prediction_targets = list(prediction_targets)
    n_targets = len(prediction_targets)
    column_names = [t + "_hinge" if hinge else t for t in prediction_targets]

    df_all = pd.read_csv(file_dir)
    seeds = sorted(df_all["seed"].unique())

    # res_acc[j][k] collects one score per seed; only k <= j is filled in the loop.
    res_acc = [[[] for _ in range(n_targets)] for _ in range(n_targets)]
    for seed in seeds:
        df_seed = df_all[df_all["seed"] == seed]  # independent of the task: filter once
        for j, target in enumerate(prediction_targets):
            df_task = df_seed[df_seed["cur_task"] == target]
            # Row with the lowest validation prediction loss for this task/seed.
            best_row = df_task.loc[df_task["val_loss_pred"].idxmin()]
            for k in range(j + 1):
                res_acc[j][k].append(best_row[column_names[k]])

    # Entries strictly above the diagonal (k > j) are undefined: fill with NaN.
    for j in range(n_targets):
        for k in range(j + 1, n_targets):
            res_acc[j][k] = [None] * len(seeds)
    res = np.array(res_acc, dtype=float)

    if show:
        _plot_confusion_matrix(
            res, prediction_targets, title, show_y, vmin, vmax, cmap, fontsize, save_title
        )
    return res


def _plot_confusion_matrix(
    res, prediction_targets, title, show_y, vmin, vmax, cmap, fontsize, save_title
):
    """Draw (and optionally save) the seed-averaged heatmap for ``get_confusion_matrix``."""
    import seaborn as sns

    n_targets = len(prediction_targets)
    # Rows: evaluated targets (reversed); columns: "Time j" = task j.
    df_c = pd.DataFrame(
        np.nanmean(res, axis=2)[:, ::-1].T * 100,
        columns=[f"Time {i}" for i in range(n_targets)],
        index=prediction_targets[::-1],
    )
    plt.figure(figsize=(10, 10))
    g = sns.heatmap(
        df_c,
        annot=True,
        cmap=cmap,
        fmt=".3g",
        vmin=vmin,
        vmax=vmax,
        linewidth=0.5,
        cbar=False,
        annot_kws={"fontsize": fontsize},
    )
    # NaN cells (upper triangle) are left unpainted, so they show this background.
    g.set_facecolor("xkcd:grey")
    plt.title(title)

    if not show_y:
        plt.yticks([])
    plt.xticks(fontsize=fontsize)
    plt.yticks(fontsize=fontsize)
    fig = plt.gcf()
    # Save before plt.show(): some backends close/clear the figure once shown.
    if save_title is not None:
        fig.savefig(f"./figures/{save_title}.png")
    plt.show()


def calculate_accuracy(y_pred, y_true):
    """Return the fraction of samples whose arg-max prediction equals the target.

    Parameters:
    - y_pred: scores/logits with classes along dim 1, e.g. shape (N, C).
    - y_true: class indices, shape (N,). A column of shape (N, 1) is also
      accepted; the ``get_*_from_label`` helpers produce that shape.

    Returns:
    - Number of correct predictions divided by ``y_true.size(0)``, as a Python float.
    """
    _, predicted = torch.max(y_pred, 1)
    targets = y_true
    # An (N, 1) target column would broadcast against the (N,) predictions into
    # an (N, N) comparison matrix and inflate the count; drop the singleton dim.
    if targets.dim() == predicted.dim() + 1 and targets.size(-1) == 1:
        targets = targets.squeeze(-1)
    correct = (predicted == targets).sum().item()
    return correct / y_true.size(0)
