# data_preprocessing.py
# -*- coding: utf-8 -*-
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, Any

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

@dataclass
class DataBundle:
    """Container for the loaders, tensors and metadata of one ID/OOD setup.

    Instances are produced by ``build_id_ood_from_config`` in this module.
    The OOD-related fields stay ``None`` when no OOD dataset is configured.
    """

    # Loader over the in-distribution training split.
    train_loader: DataLoader
    # Loader over the in-distribution test split.
    test_loader: DataLoader
    # All in-distribution test inputs concatenated into one tensor.
    x_test: torch.Tensor
    # Labels matching ``x_test`` along the first dimension.
    y_test: torch.Tensor
    # All OOD test inputs in one tensor, or None without an OOD dataset.
    x_ood_test: Optional[torch.Tensor]
    # Size of one (flattened) input sample.
    in_dim: int
    # Number of target classes of the in-distribution task.
    num_classes: int
    # Loader over the OOD test split, or None without an OOD dataset.
    ood_loader: Optional[DataLoader] = None
    # Free-form metadata; the default is None, so the annotation is Optional.
    info: Optional[Dict[str, Any]] = None

class NumpyTensorDS(Dataset):
    """Map-style dataset over an in-memory pair of features and labels.

    Args:
        X: Features as a NumPy array, a tensor, or a (nested) Python sequence.
            The first dimension indexes samples.
        y: Labels in any of the same forms, indexed in parallel with ``X``.
        x_dtype: dtype the features are converted to.
        y_dtype: dtype the labels are converted to.
    """

    def __init__(self, X, y, x_dtype=torch.float32, y_dtype=torch.long):
        # torch.as_tensor accepts ndarrays and tensors alike, and additionally
        # plain Python sequences, which the separate from_numpy()/.to() steps
        # could not handle. No copy is made when the dtype already matches.
        self.X = torch.as_tensor(X, dtype=x_dtype)
        self.y = torch.as_tensor(y, dtype=y_dtype)

    def __len__(self):
        """Return the number of samples (size of the first dimension of X)."""
        return self.X.shape[0]

    def __getitem__(self, i):
        """Return the ``(features, label)`` pair stored at index ``i``."""
        return self.X[i], self.y[i]


class FlattenTransform:
    """Callable transform that collapses a tensor into a single dimension."""

    def __call__(self, x):
        """Return ``x`` with all of its elements laid out in one dimension."""
        # reshape() yields a view whenever view() would have worked and falls
        # back to a copy for non-contiguous input, where view(-1) raises.
        return x.reshape(-1)

def load_mnist(batch_size=64, shuffle=True, num_workers=2, flatten=True, root='./data'):
    """Build the train and test DataLoaders for MNIST.

    Both splits are requested with ``download=True``, so the dataset files are
    fetched into ``root`` when they are not present there.

    Args:
        batch_size: Batch size used by both loaders.
        shuffle: Whether the training loader shuffles. The test loader never
            shuffles, regardless of this flag.
        num_workers: Worker count handed to both loaders.
        flatten: When true, ``FlattenTransform`` is appended so every sample
            comes out as a 1-D tensor.
        root: Directory in which torchvision stores and looks up the dataset.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    pipeline = [
        transforms.ToTensor(),
        # Single-channel normalisation constants (presumably the MNIST
        # training-set mean / std -- confirm before changing).
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
    if flatten:
        pipeline.append(FlattenTransform())
    transform = transforms.Compose(pipeline)

    train_dataset = datasets.MNIST(root=root, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(root=root, train=False, download=True, transform=transform)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )
    # Evaluation order is kept fixed so extracted tensors are reproducible.
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, test_loader

def load_fashion_mnist(batch_size=64, shuffle=True, num_workers=2, flatten=True, root='./data'):
    """Build the train and test DataLoaders for Fashion-MNIST.

    Both splits are requested with ``download=True``, so the dataset files are
    fetched into ``root`` when they are not present there.

    Args:
        batch_size: Batch size used by both loaders.
        shuffle: Whether the training loader shuffles. The test loader never
            shuffles, regardless of this flag.
        num_workers: Worker count handed to both loaders.
        flatten: When true, ``FlattenTransform`` is appended so every sample
            comes out as a 1-D tensor.
        root: Directory in which torchvision stores and looks up the dataset.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    pipeline = [
        transforms.ToTensor(),
        # Single-channel normalisation constants (presumably the Fashion-MNIST
        # training-set mean / std -- confirm before changing).
        transforms.Normalize((0.2860,), (0.3530,)),
    ]
    if flatten:
        pipeline.append(FlattenTransform())
    transform = transforms.Compose(pipeline)

    train_dataset = datasets.FashionMNIST(root=root, train=True, download=True, transform=transform)
    test_dataset = datasets.FashionMNIST(root=root, train=False, download=True, transform=transform)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )
    # Evaluation order is kept fixed so extracted tensors are reproducible.
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return train_loader, test_loader

def get_x_y_from_loader(loader: DataLoader, device: str = "cpu") -> Tuple[torch.Tensor, torch.Tensor]:
    """Materialise every batch of a loader into one input and one label tensor.

    Args:
        loader: Iterable yielding ``(x_batch, y_batch)`` tensor pairs.
        device: Device each batch is moved to before concatenation.

    Returns:
        ``(x_all, y_all)``: the batches concatenated along dimension 0, in
        iteration order. If the loader yields nothing, two empty tensors on
        ``device`` are returned (float inputs, ``torch.long`` labels).
    """
    x_parts, y_parts = [], []
    for x_batch, y_batch in loader:
        x_parts.append(x_batch.to(device))
        y_parts.append(y_batch.to(device))

    if not x_parts:
        # torch.cat rejects an empty list, so handle the no-batch case
        # explicitly instead of surfacing an opaque RuntimeError.
        return (
            torch.empty(0, device=device),
            torch.empty(0, dtype=torch.long, device=device),
        )

    return torch.cat(x_parts, dim=0), torch.cat(y_parts, dim=0)

def _canonical_dataset_name(name: str, role: str) -> str:
    """Normalise a configured dataset name and reject unsupported ones."""
    key = name.lower().replace("-", "_")
    if key not in ("mnist", "fashion_mnist"):
        raise ValueError(
            f"Unsupported {role} dataset {name!r}; expected 'mnist' or 'fashion_mnist'."
        )
    return key


def _loaders_for(key: str, *, shuffle: bool, num_workers: int):
    """Return ``(train_loader, test_loader)`` for an already validated key."""
    loader_fn = load_mnist if key == "mnist" else load_fashion_mnist
    # NOTE: the batch size is fixed at 64 and not read from the config.
    return loader_fn(batch_size=64, shuffle=shuffle, num_workers=num_workers, flatten=True)


def build_id_ood_from_config(cfg: Dict[str, Any]) -> DataBundle:
    """Assemble the in-distribution (ID) and optional OOD data from a config.

    The settings are read from ``cfg["data"]`` when that key exists, otherwise
    from ``cfg`` itself:

    * ``id_dataset`` (required): ``"mnist"`` or ``"fashion_mnist"``; matching
      is case-insensitive and ``-`` is treated like ``_``.
    * ``ood_dataset`` (optional): same choices; omitted or ``None`` disables
      the OOD part.
    * ``num_workers`` (optional, default 0): worker count for every loader.

    Args:
        cfg: Configuration mapping as described above.

    Returns:
        A ``DataBundle`` holding the ID train/test loaders, the full ID test
        set as tensors, and -- if configured -- the OOD test loader plus its
        inputs as one tensor. The tensors are placed on CUDA when available,
        otherwise on the CPU.

    Raises:
        KeyError: If ``id_dataset`` is missing.
        ValueError: If the ID or OOD dataset name is not supported. Both names
            are checked before anything is loaded.
    """
    data_cfg = cfg.get("data", cfg)

    id_dataset = data_cfg["id_dataset"]
    ood_dataset = data_cfg.get("ood_dataset", None)
    nwrk = int(data_cfg.get("num_workers", 0))

    # Validate both names up front so a typo fails immediately with a clear
    # message instead of as an unbound local after the ID data was loaded.
    id_key = _canonical_dataset_name(id_dataset, "ID")
    ood_key = None if ood_dataset is None else _canonical_dataset_name(ood_dataset, "OOD")

    train_loader, test_loader = _loaders_for(id_key, shuffle=True, num_workers=nwrk)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    x_test, y_test = get_x_y_from_loader(test_loader, device=device)

    # Both supported datasets are loaded flattened from 28x28 images and have
    # ten classes.
    in_dim = 28 * 28
    num_classes = 10
    info = {
        "id": {
            "dataset": id_dataset,
            "in_dim": in_dim,
            "num_classes": num_classes,
        }
    }

    x_ood_test = None
    ood_loader = None
    if ood_key is not None:
        # Only the test split of the OOD dataset is used; its labels are
        # discarded.
        _, ood_loader = _loaders_for(ood_key, shuffle=False, num_workers=nwrk)
        x_ood_test, _ = get_x_y_from_loader(ood_loader, device=device)
        info["ood"] = {
            "dataset": ood_dataset,
            "in_dim": in_dim,
            "num_classes": num_classes,
        }

    return DataBundle(
        train_loader=train_loader,
        test_loader=test_loader,
        x_test=x_test,
        y_test=y_test,
        x_ood_test=x_ood_test,
        in_dim=in_dim,
        num_classes=num_classes,
        ood_loader=ood_loader,
        info=info,
    )
