import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from typing import List, Dict, Tuple, Optional
import math
from dataclasses import dataclass
from collections import defaultdict
import torch.nn.functional as F

@dataclass
class MultiSeriesBatch:
    """Container for one time-aligned batch covering every series.

    The shapes noted on the fields below are the ones produced by
    ``GlobalManifoldDataset.__getitem__`` in this file, where ``B`` is the
    number of time-aligned samples in the batch and ``S`` is the number of
    series. Other producers of this type are not constrained by them.
    """

    # Trailing slice of each sample's joint window, taken as
    # ``[:, -global_win_last_n:, :]`` from a (B, window_size, S) tensor.
    global_windows: torch.Tensor     
    # One full window per (sample, series) pair: (B * S, window_size).
    individual_windows: torch.Tensor 
    # Series index of each row of ``individual_windows``: (B * S,), int64.
    series_ids: torch.Tensor   
    # Absolute start index of the sample each row came from: (B * S,), int64.
    time_indices: torch.Tensor     
    # Series value at the last position of each window: (B * S,).
    aligned_values: torch.Tensor  
    # Value of every series ``global_horizon`` window steps after the last
    # window position: (B, S).
    global_future_targets: torch.Tensor 
    # Plain dict of bookkeeping values ('batch_size', 'window_size',
    # 'num_series', 'mode', 'global_horizon').
    batch_metadata: Dict            


class GlobalManifoldDataset(Dataset):
    """Time-aligned multi-series window dataset backed by a CSV file.

    The CSV is read with its first column as the index; every remaining
    column (except one literally named ``'timestamp'``) is treated as a
    series unless ``sequence_columns`` is given. The rows are split
    chronologically at ``int(len(data) * train_frac)``; ``mode='train'``
    uses the rows before the split, any other ``mode`` uses the rows from
    the split onward.

    One dataset item is a ``MultiSeriesBatch`` made of ``batch_time_steps``
    consecutive time-aligned samples. Each sample holds, for every series,
    a window of ``window_size`` values spaced ``window_step`` rows apart.

    Args:
        csv_file: Path of the CSV file to load.
        sequence_columns: Columns to use as series; ``None`` selects all
            columns except ``'timestamp'``.
        window_size: Number of values in each window.
        window_step: Row distance between consecutive window values.
        batch_time_steps: Number of samples grouped into one item.
            Trailing samples that do not fill a whole group are dropped.
        stride: Row distance between consecutive sample starts
            (``None`` means 1).
        train_frac: Fraction of rows assigned to the training split.
        mode: ``'train'`` for the training split, anything else for the
            held-out split.
        normalize: Whether to z-score each series.
        global_win_last_n: How many trailing window positions are kept in
            ``global_windows`` (the slice is ``[:, -global_win_last_n:, :]``,
            so ``0`` keeps the whole window because ``-0 == 0``).
        global_horizon: Forecast distance, in window steps, of
            ``global_future_targets``.
        use_normalized_aligned: Whether ``aligned_values`` are read from the
            normalized series (``True``) or from the raw CSV values.

    Raises:
        ValueError: If no series columns are available.
    """

    def __init__(
        self,
        csv_file: str,
        sequence_columns: Optional[List[str]] = None,
        window_size: int = 50,
        window_step: int = 1,
        batch_time_steps: int = 5,
        stride: Optional[int] = None,
        train_frac: float = 0.7,
        mode: str = 'train',
        normalize: bool = True,
        global_win_last_n: int = 1,
        global_horizon: int = 1,
        use_normalized_aligned: bool = True
    ):
        self.data = pd.read_csv(csv_file, header=0, index_col=0)
        self.window_size = window_size
        self.window_step = window_step
        self.batch_time_steps = batch_time_steps
        self.stride = stride
        self.mode = mode
        self.train_frac = train_frac
        self.normalize = normalize
        self.global_win_last_n = global_win_last_n
        self.global_horizon = global_horizon
        self.use_normalized_aligned = use_normalized_aligned

        if sequence_columns is None:
            self.sequence_columns = [col for col in self.data.columns if col != 'timestamp']
        else:
            # Copy so later mutation of the caller's list cannot desync us.
            self.sequence_columns = list(sequence_columns)

        self.num_series = len(self.sequence_columns)
        if self.num_series == 0:
            # Without this check the dataset builds fine and only fails
            # later, inside __getitem__, with an unrelated-looking error.
            raise ValueError(
                "GlobalManifoldDataset needs at least one series column"
            )

        self._preprocess_data()
        self.time_aligned_samples = self._build_time_aligned_samples()
        self.batch_indices = self._build_batch_indices()

    def _preprocess_data(self):
        """Compute the split point and the (optionally normalized) series.

        Normalization statistics are taken from the training rows only
        (those before ``split_point``) so that held-out values never leak
        into the scaling. If the training split is empty the whole series
        is used instead, to avoid NaN statistics.
        """
        self.normalized_series = {}
        self.stats = {}
        self.split_point = int(len(self.data) * self.train_frac)

        for col in self.sequence_columns:
            series = self.data[col].values

            if self.normalize:
                train_series = series[:self.split_point]
                if len(train_series) == 0:
                    train_series = series
                mean = np.mean(train_series)
                # Epsilon keeps constant series from dividing by zero.
                std = np.std(train_series) + 1e-8
                normalized = (series - mean) / std
                self.stats[col] = {'mean': mean, 'std': std}
            else:
                normalized = series

            self.normalized_series[col] = normalized

    def _build_downsampled_windows(self, series: np.ndarray, start_idx: int) -> Optional[np.ndarray]:
        """Return ``window_size`` values of ``series`` spaced ``window_step`` apart.

        The window starts at ``start_idx``. ``None`` is returned when the
        window would run past the end of ``series``.
        """
        required_length = (self.window_size - 1) * self.window_step + 1
        if start_idx + required_length > len(series):
            return None

        # Integer-array indexing gathers all positions at once and yields a
        # fresh array (not a view of ``series``).
        positions = start_idx + np.arange(self.window_size) * self.window_step
        return np.asarray(series)[positions]

    def _build_time_aligned_samples(self) -> List[Dict]:
        """Build one sample per valid start index inside the active split.

        A start index is valid when the full window plus the forecast
        horizon fits before the end of the split, so neither windows nor
        targets cross the train/held-out boundary.
        """
        if self.mode == 'train':
            range_start, range_end = 0, self.split_point
        else:
            range_start, range_end = self.split_point, len(self.data)

        required_length = (self.window_size - 1) * self.window_step + 1
        extra_future = self.global_horizon * self.window_step
        rel_max_start_idx = (range_end - range_start) - (required_length + extra_future)
        stride = 1 if self.stride is None else self.stride

        samples = []
        for rel_start in range(0, rel_max_start_idx + 1, stride):
            abs_start = range_start + rel_start
            series_windows = {}
            for series_idx, col in enumerate(self.sequence_columns):
                window = self._build_downsampled_windows(self.normalized_series[col], abs_start)
                if window is None:
                    break
                series_windows[series_idx] = {
                    'window': window,
                    'series_id': series_idx,
                    'column_name': col
                }
            else:
                # Reached only when every series produced a window.
                samples.append({
                    'global_start_idx': abs_start,
                    'series_windows': series_windows,
                    'valid': True
                })

        return samples

    def _build_batch_indices(self) -> List[List[int]]:
        """Group sample indices into consecutive, full-size batches.

        A trailing group with fewer than ``batch_time_steps`` samples is
        dropped.
        """
        total = len(self.time_aligned_samples)
        size = self.batch_time_steps
        return [list(range(first, first + size)) for first in range(0, total - size + 1, size)]

    def __len__(self):
        """Return the number of full batches available."""
        return len(self.batch_indices)

    def __getitem__(self, idx) -> "MultiSeriesBatch":
        """Assemble the ``idx``-th batch of time-aligned samples.

        Returns:
            A ``MultiSeriesBatch`` whose per-row tensors are ordered
            sample-major, series-minor.
        """
        batch_sample_indices = self.batch_indices[idx]
        columns = self.sequence_columns

        # Offsets are the same for every sample, so compute them once.
        window_span = (self.window_size - 1) * self.window_step
        horizon_span = self.global_horizon * self.window_step

        # Resolve the value sources once instead of per sample and series.
        raw_series = None
        if not (self.use_normalized_aligned and self.normalize):
            raw_series = {col: self.data[col].values for col in columns}
        aligned_source = self.normalized_series if self.use_normalized_aligned else raw_series
        future_source = self.normalized_series if self.normalize else raw_series

        global_windows = []
        individual_data = []
        series_ids = []
        time_indices = []
        aligned_values = []
        global_future_targets = []

        for sample_idx in batch_sample_indices:
            sample = self.time_aligned_samples[sample_idx]
            start_idx = sample['global_start_idx']
            t_aligned = start_idx + window_span      # last window position
            t_future = t_aligned + horizon_span      # forecast target position

            windows = [
                sample['series_windows'][series_idx]['window']
                for series_idx in range(self.num_series)
            ]
            # Stacked windows are (series, time); transpose to (time, series).
            global_windows.append(np.array(windows).T)

            individual_data.extend(windows)
            series_ids.extend(range(self.num_series))
            time_indices.extend([start_idx] * self.num_series)
            aligned_values.extend(aligned_source[col][t_aligned] for col in columns)

            global_future_targets.append(
                np.array([future_source[col][t_future] for col in columns], dtype=float)
            )

        global_windows_tensor_all = torch.tensor(np.array(global_windows), dtype=torch.float32)
        global_windows_tensor = global_windows_tensor_all[:, -self.global_win_last_n:, :]

        individual_windows_tensor = torch.tensor(np.array(individual_data), dtype=torch.float32)
        series_ids_tensor = torch.tensor(series_ids, dtype=torch.long)
        time_indices_tensor = torch.tensor(time_indices, dtype=torch.long)
        aligned_values_tensor = torch.tensor(np.array(aligned_values), dtype=torch.float32)
        global_future_targets_tensor = torch.tensor(
            np.array(global_future_targets), dtype=torch.float32
        )

        return MultiSeriesBatch(
            global_windows=global_windows_tensor,
            individual_windows=individual_windows_tensor,
            series_ids=series_ids_tensor,
            time_indices=time_indices_tensor,
            aligned_values=aligned_values_tensor,
            global_future_targets=global_future_targets_tensor,
            batch_metadata={
                'batch_size': len(batch_sample_indices),
                'window_size': self.window_size,
                'num_series': self.num_series,
                'mode': self.mode,
                'global_horizon': self.global_horizon,
            }
        )


class ContrastiveManifoldSampler:
    """Random index sampler that yields a fixed number of batches per epoch.

    Each pass over the sampler yields exactly ``batches_per_epoch`` dataset
    indices, so the number of items produced matches ``len(sampler)`` (and
    therefore the length a ``DataLoader`` reports). Indices are drawn from
    random permutations of the dataset; when ``batches_per_epoch`` exceeds
    the dataset size a fresh permutation is started each time the previous
    one is exhausted.

    Args:
        dataset: Any sized object; only ``len(dataset)`` is used.
        batches_per_epoch: Number of indices yielded per iteration.
        pos_time_threshold: Stored on the instance; not used by the sampler
            itself.
        pos_series_threshold: Stored on the instance; not used by the
            sampler itself.
    """

    def __init__(
        self,
        dataset: "GlobalManifoldDataset",
        batches_per_epoch: int = 100,
        pos_time_threshold: int = 10,
        pos_series_threshold: float = 0.3
    ):
        self.dataset = dataset
        self.batches_per_epoch = batches_per_epoch
        self.pos_time_threshold = pos_time_threshold
        self.pos_series_threshold = pos_series_threshold

    def __iter__(self):
        """Yield ``len(self)`` randomly ordered dataset indices as ints."""
        dataset_size = len(self.dataset)
        remaining = len(self)
        while remaining > 0:
            # Truncating the permutation keeps the epoch at the exact length.
            chunk = torch.randperm(dataset_size)[:remaining].tolist()
            yield from chunk
            remaining -= len(chunk)

    def __len__(self):
        """Return the number of indices one iteration yields."""
        if len(self.dataset) == 0:
            # Nothing can be sampled from an empty dataset.
            return 0
        return max(self.batches_per_epoch, 0)



def collate_manifold_batch(batches):
    """Unwrap the single pre-built batch handed over by the ``DataLoader``.

    The dataset already returns complete batches, so a list input is reduced
    to its first element; any non-list input is passed through untouched.
    """
    if not isinstance(batches, list):
        return batches
    first, *_ = batches[:1] or [batches[0]]
    return first


def create_manifold_data_loaders(
    csv_file: str,
    sequence_columns: List[str] = None,
    window_size: int = 50,
    window_step: int = 1,
    batch_time_steps: int = 5,
    stride: int = 5,
    train_frac: float = 0.7,
    normalize: bool = True,
    batches_per_epoch: int = 100,
    num_workers: int = 4,
    global_win_last_n: int = 1,
    global_horizon: int = 1
) -> Tuple[DataLoader, DataLoader, Dict]:
    """Build the train and test loaders for one CSV file.

    Two ``GlobalManifoldDataset`` instances are created with identical
    settings apart from ``mode`` (``'train'`` and ``'test'``). The training
    loader draws indices through a ``ContrastiveManifoldSampler``; the test
    loader iterates in order. Both use ``batch_size=1`` with
    ``collate_manifold_batch``, since every dataset item is already a batch.

    Returns:
        ``(train_loader, test_loader, dataset_info)`` where ``dataset_info``
        holds the series count, window size, batch time steps, the number of
        batches in each dataset and ``input_dim`` (equal to ``window_size``).
    """
    # Everything the two datasets have in common; only ``mode`` differs.
    shared_dataset_kwargs = dict(
        csv_file=csv_file,
        sequence_columns=sequence_columns,
        window_size=window_size,
        window_step=window_step,
        batch_time_steps=batch_time_steps,
        stride=stride,
        train_frac=train_frac,
        normalize=normalize,
        global_win_last_n=global_win_last_n,
        global_horizon=global_horizon,
    )
    train_dataset = GlobalManifoldDataset(mode='train', **shared_dataset_kwargs)
    test_dataset = GlobalManifoldDataset(mode='test', **shared_dataset_kwargs)

    # Loader options shared by both splits.
    shared_loader_kwargs = dict(
        batch_size=1,
        collate_fn=collate_manifold_batch,
        num_workers=num_workers,
        pin_memory=True,
    )
    train_loader = DataLoader(
        train_dataset,
        sampler=ContrastiveManifoldSampler(
            train_dataset,
            batches_per_epoch=batches_per_epoch
        ),
        **shared_loader_kwargs
    )
    test_loader = DataLoader(test_dataset, shuffle=False, **shared_loader_kwargs)

    dataset_info = {
        'num_series': train_dataset.num_series,
        'window_size': window_size,
        'batch_time_steps': batch_time_steps,
        'total_train_batches': len(train_dataset),
        'total_test_batches': len(test_dataset),
        'input_dim': window_size
    }

    return train_loader, test_loader, dataset_info