import numpy as np
import pandas as pd
import itertools
import random
from sklearn.preprocessing import RobustScaler
from scipy.signal import butter, filtfilt
import torch
from torch.utils.data import Dataset
import os
from utilities.utils import (
    sliding_windows,
    njs,
    get_rom_from_label_v2,
    get_exercise_from_label_v2,
)
from .BaseDataset import BaseDataset


def sliding_window(df, window_size=150, window_step=50):
    """Cut a DataFrame into fixed-length, possibly overlapping windows.

    Args:
        df: DataFrame with one row per time step (L rows, D columns).
        window_size: Number of consecutive rows in each window.
        window_step: Row offset between the starts of successive windows.

    Returns:
        numpy array of shape (B, window_size, D) holding every complete
        window, including the one that ends exactly on the last row of
        ``df``. B is 0 when ``df`` has fewer than ``window_size`` rows.
    """
    values = df.to_numpy()
    # The "+ 1" keeps the final window whose end coincides with the end of
    # the data; without it that window (and the len(df) == window_size case)
    # is silently dropped.
    starts = range(0, len(values) - window_size + 1, window_step)
    if len(starts) == 0:
        # np.stack rejects an empty sequence, so build the empty batch here.
        return np.empty((0, window_size, values.shape[1]), dtype=values.dtype)
    return np.stack([values[i: i + window_size] for i in starts], axis=0)


class PhysiQTorchDatasetBuilder:
    """Load PHYSIQ recordings from CSV and build windowed torch datasets.

    One CSV file is looked up per (subject, exercise, angle) combination
    under ``./data/PHYSIQ``. Combinations whose file does not exist are
    skipped silently.

    Attributes:
        data: Maps ``"S{subject}_E{exercise}_R{angle}"`` to the loaded
            DataFrame (first six columns, with the first and last 150 rows
            trimmed).
        extra_label: Maps the same keys to the value returned by ``njs`` for
            the first three and the next three columns of the recording.
        features: Column names of the first recording loaded, or ``None``
            when nothing was loaded.
        subjects: Sorted list of subjects that have at least one recording.
        train_subjects: Subjects picked for the first dataset by the most
            recent call to ``subjectwise_train_valid_split``.
    """

    def __init__(
            self,
            subjects_num: list,
            exercises_num: list,
            angles: list,
            standardize=False,
    ):
        """Read every available recording for the requested combinations.

        Args:
            subjects_num: Subject identifiers to look for.
            exercises_num: Exercise identifiers to look for.
            angles: Angle identifiers to look for.
            standardize: When true, each recording is scaled independently
                with a ``RobustScaler`` fitted on that recording alone.
        """
        self.data = {}
        self.extra_label = {}
        self.features = None
        self.subjects = subjects_num
        self.train_subjects = None
        self.exercises = exercises_num
        self.angles = angles

        valid_subjects = set()
        for subject, exercise, angle in itertools.product(
                subjects_num, exercises_num, angles
        ):
            path = f"./data/PHYSIQ/S{subject}_E{exercise}_R_{angle}_0.csv"
            if not os.path.exists(path):
                continue
            key = f"S{subject}_E{exercise}_R{angle}"

            # Drop the first/last 150 rows and keep the first six columns.
            watch = pd.read_csv(path, index_col="Timestamp").iloc[150:-150, :6]

            if standardize:
                # fit_transform refits on every call, so a fresh scaler per
                # recording is equivalent to reusing one instance. The
                # Timestamp index is kept so standardized and raw frames
                # have the same layout.
                watch = pd.DataFrame(
                    RobustScaler().fit_transform(watch),
                    columns=watch.columns,
                    index=watch.index,
                )

            self.data[key] = watch

            values = watch.to_numpy()
            self.extra_label[key] = njs(values[:, :3], values[:, 3:6])

            if self.features is None:
                self.features = watch.columns.tolist()
            valid_subjects.add(subject)

        # Sorted so that a seeded subject split does not depend on set
        # iteration order.
        self.subjects = sorted(valid_subjects)

    @staticmethod
    def _build_dataset(
            values,
            labels,
            stabilities,
            window_size,
            step_size,
            augmentations,
            output_stability,
    ):
        """Wrap one side of the split in a torch dataset, or None if empty."""
        if not values:
            return None
        if output_stability:
            return PhysiQTorchDataset(
                values, labels, stabilities, window_size, step_size,
                augmentations,
            )
        return IMUTorchDataset(
            values, labels, window_size, step_size, augmentations
        )

    def subjectwise_train_valid_split(
            self,
            train_ratio=0.8,
            augmentations=None,
            window_size=128,
            step_size=32,
            output_stability=False,
            seed=None,
    ):
        """Split the loaded recordings by subject into two torch datasets.

        Args:
            train_ratio: Fraction of subjects assigned to the first dataset
                (the count is truncated to an integer).
            augmentations: Transform functions forwarded to the first
                dataset only; the second dataset never gets any.
            window_size: Window length forwarded to the dataset classes.
            step_size: Window step forwarded to the dataset classes.
            output_stability: When true, build ``PhysiQTorchDataset`` objects
                that also carry the cached ``extra_label`` value of each
                recording; otherwise build ``IMUTorchDataset`` objects.
            seed: If given, seeds the global ``random`` module before the
                subjects are sampled.

        Returns:
            ``(train_dataset, valid_dataset)``. Either element is ``None``
            when no recording falls on that side of the split.
        """
        train_num = int(len(self.subjects) * train_ratio)
        if seed is not None:
            random.seed(seed)
        self.train_subjects = random.sample(self.subjects, train_num)
        train_subjects = set(self.train_subjects)

        # Each side collects (values, labels, stabilities) in parallel lists.
        train = ([], [], [])
        valid = ([], [], [])

        for subject, exercise, angle in itertools.product(
                self.subjects, self.exercises, self.angles
        ):
            key = f"S{subject}_E{exercise}_R{angle}"
            # Decide from what was actually loaded rather than re-probing the
            # filesystem, which may have changed since construction.
            if key not in self.data:
                continue

            values, labels, stabilities = (
                train if subject in train_subjects else valid
            )
            values.append(self.data[key].values)
            labels.append(key)
            # Already computed once in __init__ for every loaded recording.
            stabilities.append(self.extra_label[key])

        train_dataset = self._build_dataset(
            *train, window_size, step_size, augmentations, output_stability
        )
        valid_dataset = self._build_dataset(
            *valid, window_size, step_size, None, output_stability
        )
        return train_dataset, valid_dataset


class IMUTorchDataset(Dataset):
    """Torch dataset of sliding windows cut from a list of recordings.

    On construction every recording in ``X`` is windowed with the project
    helper ``sliding_windows`` and the windows of all recordings are
    concatenated; each window carries the integer label of its recording.

    Attributes:
        X: Concatenated windows (a tensor after construction).
        label: Integer label per window (a tensor after construction).
        label_dict: Maps integer label back to the original label value.
        transform_functions: Optional callables applied in ``__getitem__``.
    """

    def __init__(
            self, X, label, window_size, step_size, transform_function_list=None
    ):
        """Encode labels as integers if needed, then generate the windows.

        Args:
            X: Sequence of recordings, one array-like per recording.
            label: One label per recording. If any label is not an ``int``,
                the sorted unique labels are numbered 0..K-1.
            window_size: Forwarded to ``sliding_windows``.
            step_size: Forwarded to ``sliding_windows``.
            transform_function_list: Optional callables; each one adds an
                extra transformed copy to every item returned.
        """
        self.window_size = window_size
        self.step_size = step_size
        self.transform_functions = transform_function_list
        self.X = X

        needs_encoding = any(not isinstance(entry, int) for entry in label)
        if needs_encoding:
            encoder = {
                name: code for code, name in enumerate(sorted(set(label)))
            }
            self.label = [encoder[name] for name in label]
            self.label_dict = {code: name for name, code in encoder.items()}
        else:
            # Already integers: the lookup table is the identity.
            self.label = label
            self.label_dict = {value: value for value in label}

        self.generate_data()

    def generate_data(self):
        """Replace ``X`` and ``label`` with their windowed, stacked versions."""
        windower = sliding_windows(self.window_size, self.step_size)
        print("Generating data")
        print(len(self.X))

        window_batches = []
        label_batches = []
        for position, recording in enumerate(self.X):
            code = self.label[position]
            windows, _ = windower(torch.tensor(recording), None)
            window_batches.append(windows)
            # Repeat the recording's label once per window it produced.
            label_batches.append(torch.tensor([code] * windows.shape[0]))

        self.X = torch.cat(window_batches, dim=0)
        self.label = torch.cat(label_batches, dim=0)

    def __len__(self):
        """Return the number of stored windows."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(x, label)`` as numpy arrays with a leading axis.

        Without transforms the leading axis has length 1. With transforms it
        holds the original window followed by one transformed copy per
        transform function, and the label is repeated to match.
        """
        window = self.X[idx].unsqueeze(0).numpy()
        target = self.label[idx].unsqueeze(0).numpy()

        if self.transform_functions is None:
            return window, target

        views = [window]
        views.extend(fn(window) for fn in self.transform_functions)
        targets = [target] * len(views)

        return np.concatenate(views, axis=0), np.concatenate(targets, axis=0)


class PhysiQTorchDataset(IMUTorchDataset):
    """Windowed dataset that also carries one extra value per recording.

    Behaves like ``IMUTorchDataset`` but additionally repeats the entry of
    ``stb_values`` belonging to each recording for every window cut from it,
    and returns it as a third element from ``__getitem__``.
    """

    def __init__(
            self,
            X,
            label,
            stb_values,
            window_size,
            step_size,
            transform_function_list=None,
    ):
        """Store the per-recording values, then defer to the parent class.

        Args:
            X: Sequence of recordings, one array-like per recording.
            label: One label per recording.
            stb_values: One value per recording, aligned with ``X``.
            window_size: Forwarded to the parent constructor.
            step_size: Forwarded to the parent constructor.
            transform_function_list: Optional callables applied per item.
        """
        # Must be assigned first: the parent constructor invokes
        # generate_data(), which reads this attribute.
        self.stb_values = stb_values
        super().__init__(
            X, label, window_size, step_size, transform_function_list
        )

    def generate_data(self):
        """Window every recording and expand labels and values to match."""
        windower = sliding_windows(self.window_size, self.step_size)
        print("Generating data")
        print(len(self.X))

        window_batches = []
        label_batches = []
        stability_batches = []

        for position, recording in enumerate(self.X):
            code = self.label[position]
            stability = self.stb_values[position]
            windows, _ = windower(torch.tensor(recording), None)
            count = windows.shape[0]
            window_batches.append(windows)
            # One copy of the label and of the extra value per window.
            label_batches.append(torch.tensor([code] * count))
            stability_batches.append(torch.tensor([stability] * count))

        self.X = torch.cat(window_batches, dim=0)
        self.label = torch.cat(label_batches, dim=0)
        self.stb_values = torch.cat(stability_batches, dim=0)

    def __getitem__(self, idx):
        """Return ``(x, label, stb)`` as numpy arrays with a leading axis.

        Without transforms the leading axis has length 1. With transforms it
        holds the original window followed by one transformed copy per
        transform function; label and value are repeated to match.
        """
        window = self.X[idx].unsqueeze(0).numpy()
        target = self.label[idx].unsqueeze(0).numpy()
        stability = self.stb_values[idx].unsqueeze(0).numpy()

        if self.transform_functions is None:
            return window, target, stability

        views = [window]
        views.extend(fn(window) for fn in self.transform_functions)
        copies = len(views)

        return (
            np.concatenate(views, axis=0),
            np.concatenate([target] * copies, axis=0),
            np.concatenate([stability] * copies, axis=0),
        )


class PhysiQDataset(BaseDataset):
    """Multi-task view of the PHYSIQ windows: stability, exercise and ROM.

    All recordings of a fixed subject/exercise/angle grid are loaded and
    windowed once; the window indices are then divided 60/20/20 into
    train/valid/test with a fixed permutation seed, and this instance keeps
    only the indices of the requested split.
    """

    # Class-level attribute shared by all instances: names of the targets
    # each sample provides.
    prediction_targets = ["stability", "exercise", "rom"]

    # A window counts as stable (label 1) when its stability value is
    # strictly greater than this threshold.
    STABILITY_THRESHOLD = -24.4

    # ROM values mapped to 1 when ``all_binary`` is set; all others map to 0.
    BINARY_POSITIVE_ROMS = (90, 120, 150)

    def __init__(
            self,
            split="train",
            seed=42,
            all_binary=False,
    ):
        """Load the data and select the indices belonging to ``split``.

        Args:
            split: One of ``'train'``, ``'valid'`` or ``'test'``.
            seed: Forwarded to the base class and to the dataset builder;
                also seeds the final shuffle of the selected indices.
            all_binary: When true, ``exercise`` and ``rom`` are returned as
                0/1 values instead of their decoded values.

        Raises:
            FileNotFoundError: If no recording could be loaded.
            ValueError: If ``split`` is not a recognised split name.
        """
        super(PhysiQDataset, self).__init__(
            # get_train_valid_test_split() reads self.split, which is
            # presumably stored by the base class from this argument.
            split=split,
            transform=None,
            num_tasks=len(self.prediction_targets),
            seed=seed,
        )

        total_subjects = [4, 12, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31]
        exercises = [1, 3]
        angles = [30, 60, 90, 120, 150]
        step_size = 32

        dataset_builder = PhysiQTorchDatasetBuilder(
            subjects_num=total_subjects,
            exercises_num=exercises,  # shoulder abduction and forward flexion
            angles=angles,  # 5 different angles
        )

        # train_ratio=1 puts every subject in the first dataset, so this
        # retrieves the whole dataset; the second return value is unused.
        self.dataset, _ = dataset_builder.subjectwise_train_valid_split(
            train_ratio=1,
            augmentations=None,
            window_size=128,
            step_size=step_size,
            output_stability=True,
            seed=seed,
        )
        if self.dataset is None:
            # The builder returns None when it found no recording; fail with
            # a clear message instead of a TypeError from len(None) below.
            raise FileNotFoundError(
                "No PHYSIQ recordings were found under ./data/PHYSIQ; "
                "cannot build the dataset."
            )

        self.all_binary = all_binary
        self.annotations = list(range(len(self.dataset)))
        # Keep only the indices that belong to the requested split.
        self.get_train_valid_test_split()

        np.random.seed(seed)
        np.random.shuffle(self.annotations)
        self.sub_dataset = self.annotations

    def get_train_valid_test_split(self):
        """Restrict ``self.annotations`` to the split named by ``self.split``.

        The first 60% of a fixed-seed permutation is train, the next 20% is
        valid and the remainder is test. The seed is hard-coded so every
        instance partitions the indices the same way. Note that this reseeds
        the global numpy random state.

        Raises:
            ValueError: If ``self.split`` is not 'train', 'valid' or 'test'.
        """
        total_len = len(self.annotations)
        train_size = int(0.6 * total_len)
        valid_size = int(0.2 * total_len)

        # Ensure consistent splits
        fixed_split_seed = 1123456
        np.random.seed(fixed_split_seed)
        indices = np.random.permutation(total_len)

        train_indices = indices[:train_size]
        valid_indices = indices[train_size:train_size + valid_size]
        test_indices = indices[train_size + valid_size:]

        if self.split == 'train':
            self.annotations = [self.annotations[i] for i in train_indices]
        elif self.split == 'valid':
            self.annotations = [self.annotations[i] for i in valid_indices]
        elif self.split == 'test':
            self.annotations = [self.annotations[i] for i in test_indices]
        else:
            raise ValueError(f"Invalid split value: {self.split}. Must be 'train', 'valid', or 'test'.")

    def __getitem__(self, idx):
        """Return one sample as a dict of input and prediction targets.

        Returns:
            Dict with keys ``"image"`` (the window), ``"stability"`` (0/1
            long tensor), ``"rom"`` and ``"exercise"``. The last two are long
            tensors, or plain 0/1 ints when ``all_binary`` is set.
        """
        dataset_idx = self.sub_dataset[idx]
        x, labels, stability = self.dataset[dataset_idx]
        labels = torch.Tensor(labels).type(torch.long)
        x = torch.Tensor(x)
        stability = torch.Tensor(stability)

        # Binarize stability against the class-level threshold.
        stability = stability > self.STABILITY_THRESHOLD
        stability = stability.type(torch.long).squeeze()

        # Decode exercise and ROM from the integer recording label using the
        # project helpers.
        exercise = get_exercise_from_label_v2(labels, self.dataset)
        rom = get_rom_from_label_v2(labels, self.dataset)

        sample = {
            "image": x.squeeze(0),  # "image" to make it consistent with other datasets
            "stability": stability,
        }

        exercise = exercise.type(torch.long).squeeze()
        rom = rom.type(torch.long).squeeze()

        if self.all_binary:
            # hardcoded, only works when exercise = [1, 3]
            # TODO: come up with better solution
            binary_exercise = {
                "exercise": 1 if exercise == 1 else 0,
            }
            binary_rom = {
                "rom": 1 if rom in self.BINARY_POSITIVE_ROMS else 0,
            }
            sample.update(binary_rom)
            sample.update(binary_exercise)
        else:
            sample["rom"] = rom
            sample["exercise"] = exercise

        return sample
