import os
import numpy as np
import pandas as pd
import glob
import re
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from utils.timefeatures import time_features
from data_provider.m4 import M4Dataset, M4Meta
from data_provider.uea import subsample, interpolate_missing, Normalizer
from sktime.datasets import load_from_tsfile_to_dataframe
import warnings

import os
import torch
import numpy as np
from torch.utils.data import Dataset
import torch.nn.functional as F

# Import-time side effect: silence every Python warning for the whole process, not just this module.
warnings.filterwarnings('ignore')


class UEAloader(Dataset):
    """
    Dataset class for datasets included in:
        Time Series Classification Archive (www.timeseriesclassification.com)
    Argument:
        limit_size: debugging aid; a value in (0, 1] keeps that proportion of the samples,
            a value > 1 keeps that absolute number of samples.
    Attributes:
        all_df: (num_samples * seq_len, num_columns) dataframe indexed by integer indices, with multiple rows corresponding to the same index (sample).
            Each row is a time step; Each column contains either metadata (e.g. timestamp) or a feature.
        feature_df: (num_samples * seq_len, feat_dim) dataframe; contains the subset of columns of `all_df` which correspond to selected features
        feature_names: names of columns contained in `feature_df` (same as feature_df.columns)
        all_IDs: (num_samples,) series of IDs contained in `all_df`/`feature_df` (same as all_df.index.unique() )
        labels_df: (num_samples, num_labels) pd.DataFrame of label(s) for each sample
        class_names: category labels, in the order of their integer codes in `labels_df`
        max_seq_len: maximum sequence (time series) length. If None, script argument `max_seq_len` will be used.
            (Moreover, script argument overrides this attribute)
    """

    def __init__(self, root_path, file_list=None, limit_size=None, flag=None):
        self.root_path = root_path
        self.all_df, self.labels_df = self.load_all(root_path, file_list=file_list, flag=flag)
        self.all_IDs = self.all_df.index.unique()  # all sample IDs (integer indices 0 ... num_samples-1)

        if limit_size is not None:
            if limit_size > 1:
                limit_size = int(limit_size)
            else:  # interpret as proportion if in (0, 1]
                limit_size = int(limit_size * len(self.all_IDs))
            self.all_IDs = self.all_IDs[:limit_size]
            self.all_df = self.all_df.loc[self.all_IDs]

        # use all features
        self.feature_names = self.all_df.columns
        self.feature_df = self.all_df

        # pre_process
        normalizer = Normalizer()
        self.feature_df = normalizer.normalize(self.feature_df)
        print("All sampele ..............", len(self.all_IDs))

    def load_all(self, root_path, file_list=None, flag=None):
        """
        Load the dataset from one `.ts` file found in `root_path`.

        Args:
            root_path: directory containing the dataset's .ts files
            file_list: optionally, a list of file names within `root_path` to consider.
                Otherwise, all entries of `root_path` are considered.
            flag: optional regex (typically 'TRAIN' or 'TEST'); only paths matched by
                `re.search(flag, path)` are kept.
        Returns:
            all_df: a single dataframe with all data of the selected file
            labels_df: dataframe containing label(s) for each sample
        Raises:
            Exception: if there are no candidate paths, or none of the remaining paths is a .ts file.
        """
        if file_list is None:
            # NOTE: glob order is filesystem-dependent.
            data_paths = glob.glob(os.path.join(root_path, '*'))
        else:
            data_paths = [os.path.join(root_path, p) for p in file_list]
        if not data_paths:
            raise Exception('No files found using: {}'.format(os.path.join(root_path, '*')))
        if flag is not None:
            data_paths = [p for p in data_paths if re.search(flag, p)]
        input_paths = [p for p in data_paths if os.path.isfile(p) and p.endswith('.ts')]
        if not input_paths:
            pattern = '*.ts'
            raise Exception("No .ts files found using pattern: '{}'".format(pattern))

        return self.load_single(self._select_ts_file(input_paths, flag))

    @staticmethod
    def _select_ts_file(input_paths, flag):
        """Pick the single .ts file to load from the non-empty list of candidate paths."""
        first = input_paths[0]
        if 'CharacterTrajectories' in first and flag in ('TRAIN', 'TEST'):
            # Several CharacterTrajectories files may match `flag` (e.g. an "_eq_" variant);
            # always use the canonical CharacterTrajectories_<flag>.ts, preferably from root_path.
            wanted = 'CharacterTrajectories_{}.ts'.format(flag)
            for path in input_paths:
                if os.path.basename(path) == wanted:
                    return path
            # Fallback: the historical hard-coded location, relative to the working directory.
            return './data/UEA_multivariate/CharacterTrajectories/' + wanted
        # A single file contains the dataset.
        return first

    def load_single(self, filepath):
        """
        Load one .ts file into a long-format feature dataframe and integer-coded labels.

        Sets `self.class_names` and `self.max_seq_len` as side effects.
        Returns:
            df: (num_samples * seq_len, feat_dim) dataframe whose index is the sample number,
                with missing values replaced per sample by `interpolate_missing`
            labels_df: (num_samples, 1) dataframe of category codes
        """
        df, labels = load_from_tsfile_to_dataframe(filepath, return_separate_X_and_y=True,
                                                   replace_missing_vals_with='NaN')
        print("Check df shape:", df.shape, "from file:", filepath)
        labels = pd.Series(labels, dtype="category")
        self.class_names = labels.cat.categories
        # NOTE(review): codes are stored as int8, so more than 127 classes would not fit, and
        # nn.CrossEntropyLoss needs int64 class indices, so consumers must cast.
        labels_df = pd.DataFrame(labels.cat.codes,
                                 dtype=np.int8)

        lengths = df.applymap(len).values  # (num_samples, num_dimensions) length of each series

        # If any sample has dimensions of differing length, apply `subsample` to every series.
        horiz_diffs = np.abs(lengths - lengths[:, :1])
        if np.sum(horiz_diffs) > 0:
            df = df.applymap(subsample)

        lengths = df.applymap(len).values
        vert_diffs = np.abs(lengths - lengths[:1, :])
        if np.sum(vert_diffs) > 0:  # if any column (dimension) has varying length across samples
            self.max_seq_len = int(np.max(lengths[:, 0]))
        else:
            self.max_seq_len = lengths[0, 0]
        print("Max Sequence Length:", self.max_seq_len, np.sum(vert_diffs))

        # First create a (seq_len, feat_dim) dataframe for each sample, indexed by a single integer ("ID" of the sample)
        # Then concatenate into a (num_samples * seq_len, feat_dim) dataframe, with multiple rows corresponding to the
        # sample index (i.e. the same scheme as all datasets in this project)
        df = pd.concat((pd.DataFrame({col: df.loc[row, col] for col in df.columns}).reset_index(drop=True).set_index(
            pd.Series(lengths[row, 0] * [row])) for row in range(df.shape[0])), axis=0)

        # Replace NaN values sample by sample
        grp = df.groupby(by=df.index)
        df = grp.transform(interpolate_missing)

        return df, labels_df

    def instance_norm(self, case):
        """
        Normalize one sample tensor, but only for EthanolConcentration (numerical stability);
        other datasets are returned unchanged.
        """
        if 'EthanolConcentration' in self.root_path:
            # NOTE(review): the mean is taken over dim 0 but the variance over dim 1; kept as-is
            # for reproducibility, confirm this asymmetry is intended.
            mean = case.mean(0, keepdim=True)
            case = case - mean
            stdev = torch.sqrt(torch.var(case, dim=1, keepdim=True, unbiased=False) + 1e-5)
            case = case / stdev
        return case

    def __getitem__(self, ind):
        """Return (features tensor of the `ind`-th sample, its label tensor)."""
        sample_id = self.all_IDs[ind]
        features = torch.from_numpy(self.feature_df.loc[sample_id].values)
        return self.instance_norm(features), torch.from_numpy(self.labels_df.loc[sample_id].values)

    def __len__(self):
        return len(self.all_IDs)


def collate_fn(data, max_len=None):
    """
    Collate a list of ``(features, label)`` pairs into a zero-padded batch.

    Args:
        data: list of (features, label) tensor pairs; each `features` is (seq_len, feat_dim).
        max_len: padded length; defaults to the longest sequence in the batch.
            Longer sequences are truncated to it.
    Returns:
        X: (batch_size, max_len, feat_dim) float tensor, zero-padded at the end of the time axis
        targets: labels stacked along a new batch dimension
        padding_masks: (batch_size, max_len) boolean tensor, True = real time step
    """
    batch_size = len(data)
    features, labels = zip(*data)

    # Stack and pad features (convert 2D to 3D tensors, i.e. add batch dimension)
    lengths = [feat.shape[0] for feat in features]  # original sequence length for each time series
    if max_len is None:
        max_len = max(lengths)

    X = torch.zeros(batch_size, max_len, features[0].shape[-1])
    for i, (feat, length) in enumerate(zip(features, lengths)):
        end = min(length, max_len)
        X[i, :end, :] = feat[:end, :]

    targets = torch.stack(labels, dim=0)

    # int64 instead of int16: int16 cannot represent sequence lengths above 32767.
    padding_masks = padding_mask(torch.tensor(lengths, dtype=torch.long), max_len=max_len)

    return X, targets, padding_masks

def padding_mask(lengths, max_len=None):
    """
    Used to mask padded positions: creates a (batch_size, max_len) boolean mask from a tensor of sequence lengths,
    where True means keep element at this position (time step).

    Args:
        lengths: 1-D integer tensor of sequence lengths.
        max_len: mask width; when falsy (None or 0) the largest value in `lengths` is used.
    """
    if not max_len:
        # The previous `lengths.max_val()` does not exist on torch tensors and raised AttributeError.
        max_len = int(lengths.max())
    positions = torch.arange(0, max_len, device=lengths.device).type_as(lengths)
    # Broadcast (1, max_len) against (batch_size, 1) -> (batch_size, max_len)
    return positions.unsqueeze(0).lt(lengths.unsqueeze(1))

class AmortizedExplanationLoader(UEAloader):
    """
    Dataset of (meta-input, meta-target) pairs for training an amortized explainer.

    Meta-input:  the raw series, zero-padded or truncated to the attribution length, concatenated
                 along the last axis with its pre-computed InputXGradient saliency map.
    Meta-target: the pre-computed "BestEnsemble" attribution map, min-max scaled to [0, 1] per instance.

    Raw data loading is inherited from UEAloader. Attribution maps are read from
    ``./explanations/<args.dataset>/<args.model>/<args.dnn_type>/`` and are matched to samples by
    position, i.e. map ``i`` belongs to ``self.all_IDs[i]``. Maps are assumed to be stored as
    [num_samples, T, C] (TODO confirm against the Stage 1 writer).
    """

    def __init__(self, args, root_path, file_list=None, limit_size=None, flag=None, target_set='train'):
        # Populate self.all_IDs, self.feature_df and self.labels_df before anything else.
        super().__init__(root_path, file_list=file_list, limit_size=limit_size, flag=flag)

        self.args = args
        self.target_set = target_set

        # Directory where Stage 1 explanations are stored
        self.save_dir = os.path.join(f"./explanations/{args.dataset}", args.model, args.dnn_type)

        prefix = f"{args.dataset}-{args.seed}-{args.model}-{args.dnn_type}-{target_set}"
        self.saliency_maps = self._load_attributions(f"{prefix}_InputXGradient.npy")
        self.target_maps = self._load_attributions(f"{prefix}_BestEnsemble.npy")

        # Maps are indexed by position in __getitem__; too few maps would otherwise only surface
        # as an IndexError in the middle of an epoch.
        n_samples = len(self.all_IDs)
        for kind, maps in (("saliency", self.saliency_maps), ("ensemble", self.target_maps)):
            if len(maps) < n_samples:
                raise ValueError(
                    f"{kind} maps in {self.save_dir} cover {len(maps)} samples, "
                    f"but the {target_set} split has {n_samples}."
                )

        self.target_maps = self._normalize_instance_wise(self.target_maps)

        print(f"AmortizedLoader ({target_set}) successfully initialized with {len(self.all_IDs)} samples.")

    def _load_attributions(self, filename):
        """Return the 'attributions' entry of a dict saved with ``np.save`` in ``self.save_dir``."""
        path = os.path.join(self.save_dir, filename)
        # NOTE: allow_pickle=True can execute code embedded in the file; only load Stage 1 outputs.
        return np.asarray(np.load(path, allow_pickle=True).item()['attributions'])

    def _normalize_instance_wise(self, data):
        """Min-max scale each map (first axis = instance) independently to [0, 1]; constant maps become 0."""
        flat = data.reshape(data.shape[0], -1)
        mins = flat.min(axis=1, keepdims=True)
        maxs = flat.max(axis=1, keepdims=True)
        norm = (flat - mins) / (maxs - mins + 1e-8)  # epsilon avoids division by zero for constant maps
        return norm.reshape(data.shape)

    def __getitem__(self, index):
        """Return (raw series concatenated with saliency along the last axis, normalized target map)."""
        raw_x, _ = super().__getitem__(index)
        sal_x = torch.from_numpy(self.saliency_maps[index]).float()
        target_y = torch.from_numpy(self.target_maps[index]).float()

        # Align the raw series with the attribution length along the time axis (dim 0).
        t_fixed = sal_x.shape[0]
        t_raw = raw_x.shape[0]
        if t_raw < t_fixed:
            # F.pad pads the last dim first: (C_left, C_right, T_left, T_right) -> zeros appended in time.
            raw_x = F.pad(raw_x, (0, 0, 0, t_fixed - t_raw), "constant", 0)
        elif t_raw > t_fixed:
            raw_x = raw_x[:t_fixed, :]

        amortized_input = torch.cat([raw_x.float(), sal_x], dim=-1)
        return amortized_input, target_y


class AmortizedExplanationLoader_V2(UEAloader):
    """
    Variant of the amortized-explanation dataset whose meta-input is the raw series alone
    (no saliency channels, no padding/truncation).

    Each item is ``(raw series as float32, normalized "BestEnsemble" target map, raw series, label)``.
    Target maps are read from ``./explanations/<args.dataset>/<args.model>/<args.dnn_type>/`` and
    matched to samples by position (map ``i`` belongs to ``self.all_IDs[i]``).
    """

    def __init__(self, args, root_path, file_list=None, limit_size=None, flag=None, target_set='train'):
        # Populate self.all_IDs, self.feature_df and self.labels_df before anything else.
        super().__init__(root_path, file_list=file_list, limit_size=limit_size, flag=flag)

        self.args = args
        self.target_set = target_set

        # Directory where Stage 1 explanations are stored
        self.save_dir = os.path.join(f"./explanations/{args.dataset}", args.model, args.dnn_type)

        ensemble_file = f"{args.dataset}-{args.seed}-{args.model}-{args.dnn_type}-{target_set}_BestEnsemble.npy"
        ensemble_path = os.path.join(self.save_dir, ensemble_file)

        # NOTE: allow_pickle=True can execute code embedded in the file; only load Stage 1 outputs.
        self.target_maps = np.asarray(np.load(ensemble_path, allow_pickle=True).item()['attributions'])

        # Maps are indexed by position in __getitem__; too few maps would otherwise only surface
        # as an IndexError in the middle of an epoch.
        if len(self.target_maps) < len(self.all_IDs):
            raise ValueError(
                f"ensemble maps in {self.save_dir} cover {len(self.target_maps)} samples, "
                f"but the {target_set} split has {len(self.all_IDs)}."
            )

        self.target_maps = self._normalize_instance_wise(self.target_maps)

        print(f"AmortizedLoader ({target_set}) successfully initialized with {len(self.all_IDs)} samples.")

    def _normalize_instance_wise(self, data):
        """Min-max scale each map (first axis = instance) independently to [0, 1]; constant maps become 0."""
        flat = data.reshape(data.shape[0], -1)
        mins = flat.min(axis=1, keepdims=True)
        maxs = flat.max(axis=1, keepdims=True)
        norm = (flat - mins) / (maxs - mins + 1e-8)  # epsilon avoids division by zero for constant maps
        return norm.reshape(data.shape)

    def __getitem__(self, index):
        """Return (raw series as float32, normalized target map, raw series, label)."""
        raw_x, labels = super().__getitem__(index)
        target_y = torch.from_numpy(self.target_maps[index]).float()
        # NOTE(review): unlike AmortizedExplanationLoader, raw_x is not aligned to the target length.
        return raw_x.float(), target_y, raw_x, labels

