# ============================ Imports ============================
import os
import sys
import glob
import math
import time
import random
import argparse
import pickle
import json
import warnings
warnings.filterwarnings("ignore")

import ast
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, MinMaxScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import torchvision.models as models
from torchvision import transforms

from scipy.signal import stft, hilbert
from torch.autograd import Function

from thop import profile

class DaliaDataset(Dataset):
    """Windowed multimodal samples loaded from per-subject ``*.pt`` files.

    Files are discovered under ``<root_dir>/<subject>/*.pt`` for each entry of
    ``subjects``; subject directories that do not exist are skipped. Each file
    is loaded with ``torch.load`` and indexed like a mapping: one entry per
    modality key plus a ``'label'`` entry.

    Args:
        root_dir: Directory containing one sub-directory per subject.
        modalities: Modality keys to read from each file. With more than one,
            every modality present in the file is resampled to
            ``cfg.base_sample_rate``, fitted to ``duration * base_sample_rate``
            rows, and concatenated along dim 1.
        subjects: Sub-directory names to include.
        cfg: Config object providing ``base_sample_rate``, ``duration`` and
            ``sampling_rates`` (mapping from modality key to its native rate).
            ``duration`` is presumably seconds if the rates are in Hz.
        transform: Optional callable applied to the input tensor ``x`` before
            it is returned from ``__getitem__``.
    """

    def __init__(self, root_dir, modalities, subjects, cfg, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.subjects = subjects
        self.modalities = modalities
        self.base_sr = cfg.base_sample_rate
        self.duration = cfg.duration
        self.sampling_rates = cfg.sampling_rates
        self.file_paths = []

        for subject in subjects:
            subject_dir = os.path.join(root_dir, subject)
            if os.path.exists(subject_dir):
                # glob() yields files in filesystem order; sort so the
                # index -> file mapping is reproducible across machines/runs.
                self.file_paths.extend(
                    sorted(glob.glob(os.path.join(subject_dir, "*.pt")))
                )

    def __len__(self):
        return len(self.file_paths)

    def _resample(self, modality, mod_data):
        """Resample one modality to ``base_sr`` and fit it to the window length.

        Args:
            modality: Key into ``self.sampling_rates`` (also used in errors).
            mod_data: Tensor whose dim 0 is time. Upsampling additionally
                requires it to be 2-D ``(time, features)``.

        Returns:
            Tensor with exactly ``int(duration * base_sr)`` rows: longer
            results are truncated, shorter ones are padded by repeating the
            last row.

        Raises:
            ValueError: if upsampling a non-2-D tensor, or if padding is
                needed but the resampled tensor is empty.
        """
        expected_length = int(self.duration * self.base_sr)
        orig_sr = self.sampling_rates[modality]
        resample_factor = self.base_sr / orig_sr

        if resample_factor < 1:
            # Downsample by keeping every step-th sample. Take the step
            # directly from orig_sr / base_sr: int(1 / factor) involves two
            # float roundings, so an exact integer ratio could land just
            # below the integer and be truncated down by one.
            step = int(orig_sr / self.base_sr)
            resampled = mod_data[::step]
        else:
            if len(mod_data.shape) != 2:
                raise ValueError(
                    f"Unexpected tensor shape for {modality}: {mod_data.shape}"
                )
            # F.interpolate expects (batch, channels, length): move time last.
            target_len = int(mod_data.shape[0] * resample_factor)
            resampled = (
                F.interpolate(
                    mod_data.permute(1, 0).unsqueeze(0),
                    size=target_len,
                    mode='linear',
                    align_corners=False,
                )
                .squeeze(0)
                .permute(1, 0)
            )

        current_length = resampled.shape[0]
        if current_length > expected_length:
            resampled = resampled[:expected_length]
        elif current_length < expected_length:
            if current_length == 0:
                raise ValueError(
                    f"Cannot pad empty resampled tensor for {modality}"
                )
            # Repeat the last row. expand() keeps the trailing dims, so both
            # 1-D and 2-D tensors are padded with a matching shape.
            last_frame = resampled[-1:]
            padding = last_frame.expand(
                expected_length - current_length, *last_frame.shape[1:]
            )
            resampled = torch.cat([resampled, padding], dim=0)

        return resampled

    def __getitem__(self, idx):
        """Load file ``idx`` and return ``(x, y)``.

        ``x`` depends on ``self.modalities``:

        * more than one modality: each modality present in the file is passed
          through ``_resample`` and the results are concatenated on dim 1.
          Modalities absent from the file are skipped, so the feature width
          can differ between files (NOTE(review): confirm this is intended).
        * exactly ``'chest_ACC'``: passed through ``_resample``.
        * any other single modality: returned as stored, without resampling
          (NOTE(review): presumably already at base_sr; confirm).

        If ``self.transform`` is set it is applied to ``x``. ``y`` is the
        file's ``'label'`` entry, returned as stored.

        Raises:
            KeyError: if none of several requested modalities is in the file,
                or a requested single modality / ``'label'`` is missing.
        """
        # NOTE(review): torch.load is pickle-based; only point this dataset
        # at trusted files.
        data = torch.load(self.file_paths[idx])

        if len(self.modalities) > 1:
            resampled_data = [
                self._resample(modality, data[modality])
                for modality in self.modalities
                if modality in data
            ]
            if not resampled_data:
                raise KeyError(
                    f"None of {self.modalities} found in {self.file_paths[idx]}"
                )
            x = torch.cat(resampled_data, dim=1)
        elif self.modalities[0] == 'chest_ACC':
            x = self._resample(self.modalities[0], data[self.modalities[0]])
        else:
            x = data[self.modalities[0]]

        if self.transform is not None:
            x = self.transform(x)

        y = data['label']
        return x, y
