import torch
import pickle
from torch.nn import functional as F
from batlinet.data.transformation.base import BaseDataTransformation

def half_gaussian_kernel(kernel_size=33, sigma=5.0):
    """Return a one-sided (half) Gaussian kernel for causal filtering.

    Tap ``i`` is proportional to
    ``exp(-(i - (kernel_size - 1))**2 / (2 * sigma**2))``, so the kernel
    peaks at its LAST tap and decays towards the first one. When used as
    the weight of ``F.conv1d`` (a cross-correlation) on a signal padded only
    on the left, the last tap multiplies the current sample and earlier taps
    multiply older samples, so no future sample contributes.

    Args:
        kernel_size: Number of taps; must be >= 1.
        sigma: Standard deviation of the Gaussian, measured in taps; must be > 0.

    Returns:
        Float32 tensor of shape ``(1, 1, kernel_size)`` whose entries sum to 1,
        laid out as a ``conv1d`` weight (one output and one input channel).

    Raises:
        ValueError: If ``kernel_size < 1`` or ``sigma <= 0``.
    """
    if kernel_size < 1:
        raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
    if sigma <= 0:
        # sigma == 0 would divide by zero below and yield a NaN kernel.
        raise ValueError(f"sigma must be > 0, got {sigma}")
    # Peak at the right end (index kernel_size - 1): with conv1d this tap
    # lines up with the current sample, giving it the largest weight.
    center = kernel_size - 1
    x = torch.arange(kernel_size).float()
    kernel = torch.exp(-((x - center) ** 2) / (2 * sigma ** 2))
    kernel = kernel / kernel.sum()  # normalise so a constant signal is preserved
    return kernel.view(1, 1, -1)

def gaussian_filter1d(x, kernel_size=33, sigma=5.0, clip_max=1.0):
    """Causally smooth every row of a 2-D tensor with a half-Gaussian kernel.

    Steps:

    1. Values above ``clip_max`` are set to ``clip_max`` to limit spikes.
    2. Each row is left-padded with ``kernel_size - 1`` copies of its first
       value (replicate padding), so the output keeps length ``L``.
    3. The padded rows are convolved with ``half_gaussian_kernel``, so
       output position ``t`` depends only on inputs at positions ``<= t``.

    Args:
        x: Tensor of shape ``(B, L)``. Non-floating tensors are converted to
            float32 first.
        kernel_size: Number of kernel taps (the current sample plus
            ``kernel_size - 1`` preceding ones).
        sigma: Standard deviation of the Gaussian, in samples.
        clip_max: Upper clipping bound applied before filtering. The default
            of ``1.0`` presumably assumes ``x`` is normalised to at most 1 —
            confirm for each caller.

    Returns:
        Tensor of shape ``(B, L)`` with the (floating) dtype of ``x``.

    Raises:
        ValueError: If ``x`` is not 2-D.
    """
    if x.dim() != 2:
        raise ValueError(
            f"expected a 2-D tensor of shape (B, L), got shape {tuple(x.shape)}")
    if not x.is_floating_point():
        x = x.float()
    pad = kernel_size - 1

    # Only an upper bound is needed: a lower bound of x.min() never clips.
    x = x.clamp(max=clip_max)

    x = x.unsqueeze(1)  # (B, 1, L): conv1d / replicate pad need a channel dim

    # Left-only padding with the first value keeps the filter causal and
    # avoids an artificial ramp from zero at the start of each row.
    x = F.pad(x, (pad, 0), mode='replicate')

    # Match x's dtype as well as its device so float64 input does not hit a
    # conv1d dtype mismatch against the float32 kernel.
    kernel = half_gaussian_kernel(kernel_size, sigma).to(device=x.device, dtype=x.dtype)
    return F.conv1d(x, kernel).squeeze(1)  # (B, L)

class Dataset:
    """Index-aligned pair of a feature tensor and a label tensor.

    ``dataset[i]`` returns ``{'feature': feature[i], 'label': label[i]}``;
    the same index object is applied to both tensors.
    """

    def __init__(self, feature: torch.Tensor, label: torch.Tensor):
        n_feature, n_label = len(feature), len(label)
        # Both tensors must have the same leading length to be index-aligned.
        assert n_feature == n_label, (n_feature, n_label)

        self.label = label
        self.feature = feature

    def __len__(self):
        """Number of items, i.e. the leading length of the label tensor."""
        return len(self.label)

    def __getitem__(self, item: int):
        """Return the feature/label pair selected by ``item`` as a dict."""
        return dict(feature=self.feature[item], label=self.label[item])

    @property
    def device(self):
        """Device of the label tensor (taken as the dataset's device)."""
        return self.label.device

    def to(self, device: str):
        """Move the label tensor, then the feature tensor, to ``device``; return ``self``."""
        for attr in ('label', 'feature'):
            setattr(self, attr, getattr(self, attr).to(device))
        return self


class DataBundle:
    """Train/test datasets for two targets that share the same input features.

    The constructor builds:

    * ``train_data`` / ``test_data``: features paired with the ``*_soh``
      tensors, after ``label_transformation`` (if given) and causal
      smoothing by ``gaussian_filter1d``.
    * ``train_rul`` / ``test_rul``: features paired with the ``*_label``
      tensors, after ``rul_transformation`` (if given); not smoothed.

    Every transformation is fitted on the training split only and then
    applied to both splits, so test data never influences the fit.
    """

    def __init__(self,
                 train_feature: torch.Tensor,
                 train_label: torch.Tensor,
                 train_soh: torch.Tensor,
                 test_feature: torch.Tensor,
                 test_label: torch.Tensor,
                 test_soh: torch.Tensor,
                 feature_transformation: BaseDataTransformation = None,
                 label_transformation: BaseDataTransformation = None,
                 rul_transformation: BaseDataTransformation = None):
        """Cast inputs to float32, fit/apply the transformations and build the datasets.

        Args:
            train_feature, test_feature: Model inputs, shared by both targets.
            train_label, test_label: Target of the ``*_rul`` datasets;
                ``rul_transformation`` is fitted on ``train_label``.
            train_soh, test_soh: Target of ``train_data`` / ``test_data``;
                ``label_transformation`` is fitted on ``train_soh``. They are
                passed to ``gaussian_filter1d``, which requires 2-D
                ``(B, L)`` tensors.
            feature_transformation, label_transformation, rul_transformation:
                Optional stateful transformations exposing ``fit``,
                ``transform``, ``inverse_transform`` and ``to``. ``None``
                leaves the corresponding tensors unchanged.
        """
        # Work in float32 regardless of the incoming dtype.
        train_feature = train_feature.float()
        train_label = train_label.float()
        train_soh = train_soh.float()
        test_feature = test_feature.float()
        test_label = test_label.float()
        test_soh = test_soh.float()

        self.feature_transformation = feature_transformation
        self.label_transformation = label_transformation
        self.rul_transformation = rul_transformation

        # Fit the stateful transformations (features, then RUL, then SOH).
        # The RUL and SOH targets live in separate variables so that a None
        # transformation can neither leave a name unbound nor hand the RUL
        # labels to the SOH datasets.
        train_feature, test_feature = self._fit_transform(
            feature_transformation, train_feature, test_feature)
        train_rul, test_rul = self._fit_transform(
            rul_transformation, train_label, test_label)
        train_soh, test_soh = self._fit_transform(
            label_transformation, train_soh, test_soh)

        # Only the SOH target is smoothed; the RUL target is used as-is.
        train_soh = gaussian_filter1d(train_soh)
        test_soh = gaussian_filter1d(test_soh)

        # Build datasets; each split's two datasets share one feature tensor.
        self.train_data = Dataset(train_feature, train_soh)
        self.test_data = Dataset(test_feature, test_soh)

        self.train_rul = Dataset(train_feature, train_rul)
        self.test_rul = Dataset(test_feature, test_rul)

    @staticmethod
    def _fit_transform(transformation, train: torch.Tensor, test: torch.Tensor):
        """Fit ``transformation`` on ``train`` and apply it to both splits (no-op when None)."""
        if transformation is None:
            return train, test
        transformation.fit(train)
        return transformation.transform(train), transformation.transform(test)

    def to(self, device: str) -> 'DataBundle':
        """Move every dataset and every non-None transformation to ``device``; return ``self``.

        NOTE(review): ``train_data``/``train_rul`` (and likewise the test
        pair) start out sharing one feature tensor, but each ``Dataset.to``
        moves its own reference, so a cross-device move stores the features
        once per dataset.
        """
        self.train_data = self.train_data.to(device)
        self.test_data = self.test_data.to(device)
        self.train_rul = self.train_rul.to(device)
        self.test_rul = self.test_rul.to(device)
        if self.feature_transformation is not None:
            self.feature_transformation = self.feature_transformation.to(device)
        if self.label_transformation is not None:
            self.label_transformation = self.label_transformation.to(device)
        if self.rul_transformation is not None:
            self.rul_transformation = self.rul_transformation.to(device)
        return self

    @property
    def device(self):
        """Device of the training features."""
        return self.train_data.feature.device

    @torch.no_grad()
    def evaluate(self, prediction: torch.Tensor, metric: str) -> float:
        """Score ``prediction`` against the test RUL target.

        Both ``self.test_rul.label`` and ``prediction`` are mapped back with
        ``rul_transformation.inverse_transform`` (when set), so the score is
        in the units of the original ``test_label``. ``prediction`` is
        assumed to have the same shape as ``self.test_rul.label``; torch
        broadcasts mismatched shapes instead of rejecting them.

        Args:
            prediction: Model output in the (transformed) RUL space.
            metric: One of ``'RMSE'``, ``'MAE'``, ``'MAPE'``.

        Returns:
            The metric value as a Python float.
        """
        target = self.test_rul.label
        if self.rul_transformation is not None:
            target = self.rul_transformation.inverse_transform(target)
            prediction = self.rul_transformation.inverse_transform(prediction)

        return self._evaluate_score(target, prediction, metric)

    @staticmethod
    def _evaluate_score(
        target: torch.Tensor,
        prediction: torch.Tensor,
        metric: str
    ) -> float:
        """Compute RMSE, MAE or MAPE (as a fraction, not a percentage) over all elements.

        Raises:
            AssertionError: If ``metric`` is not one of the supported names.
        """
        assert metric in ['RMSE', 'MAE', 'MAPE'], metric
        if metric == 'RMSE':
            score = torch.mean((target - prediction) ** 2) ** 0.5
        elif metric == 'MAE':
            score = torch.mean((target - prediction).abs())
        else:
            # NOTE: any zero entry in target makes this inf (or nan when the
            # prediction is also zero there).
            score = torch.abs((target - prediction) / target).mean()

        return float(score)

    @staticmethod
    def load(path: str) -> 'DataBundle':
        """Unpickle a bundle previously written by ``dump``.

        NOTE(review): ``pickle.load`` can execute arbitrary code; only load
        files from trusted sources.
        """
        with open(path, 'rb') as f:
            return pickle.load(f)

    def dump(self, path: str) -> None:
        """Pickle the whole bundle (datasets and transformations) to ``path``."""
        with open(path, 'wb') as f:
            pickle.dump(self, f)
