import os 
import torch  
from dataclasses import dataclass
from typing import Optional, Tuple
import random 
from torch.utils.data import DataLoader, Dataset
import random  
import h5py 
import torch.nn.functional as F


def _slice_time_window(
    fields: torch.Tensor, t0: int, t1: int, n_channels: int, what: str
) -> torch.Tensor:
    """Return ``fields[t0:t1, :n_channels]`` after checking the window fits the time axis."""
    # Plain slicing would silently truncate (t1 past the end) or wrap around
    # (negative t0), producing windows of the wrong length.
    if t0 < 0 or t1 > fields.shape[0]:
        raise ValueError(
            f"{what} window [{t0}, {t1}) is outside the time axis of length {fields.shape[0]}"
        )
    return fields[t0:t1, :n_channels]


def _cat_channels(parts: list) -> torch.Tensor:
    """Concatenate tensors along the channel axis (dim 1); a single part is returned as-is."""
    return parts[0] if len(parts) == 1 else torch.cat(parts, dim=1)


def create_data2D(
    n_input_scalar_components: int,
    n_input_vector_components: int,
    n_output_scalar_components: int,
    n_output_vector_components: int,
    scalar_fields: torch.Tensor,
    vector_fields: torch.Tensor,
    grid: Optional[torch.Tensor],
    start: int,
    time_history: int,
    time_future: int,
    time_gap: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Slice an input window and a target window out of one 2D trajectory.

    The input window covers time steps ``[start, start + time_history)``. The
    target window begins ``time_gap`` steps after the input window ends and
    spans ``time_future`` steps. Scalar channels come first, followed by vector
    channels. Each requested vector component selects two channels of
    ``vector_fields``.

    Args:
        n_input_scalar_components: Number of leading scalar channels used as input.
        n_input_vector_components: Number of vector fields used as input
            (the first ``2 * n`` channels of ``vector_fields``).
        n_output_scalar_components: Number of leading scalar channels used as target.
        n_output_vector_components: Number of vector fields used as target.
        scalar_fields: Tensor indexed ``[time, channel, ...]``; presumably
            ``[T, C_s, H, W]``. It is only read when a scalar count is > 0.
        vector_fields: Tensor indexed ``[time, channel, ...]``; presumably
            ``[T, C_v, H, W]``. It is only read when a vector count is > 0.
        grid: Currently unused. It is accepted for interface compatibility
            (e.g. to later append ``torch.cat((data, grid), dim=1)``).
        start: First time index of the input window.
        time_history: Number of input time steps; must be > 0.
        time_future: Number of target time steps.
        time_gap: Number of time steps skipped between the input and target windows.

    Returns:
        ``(data, targets)`` with shapes ``[time_history, C_in, ...]`` and
        ``[time_future, C_out, ...]``, where ``C_in = n_in_scalar + 2 * n_in_vector``
        (and likewise for outputs), assuming the fields have that many channels.
        Only one source contributes when the other count is zero. In that case
        the result is a slice (view) of that source and keeps its dtype.

    Raises:
        AssertionError: If no input or no output components are requested, or
            if ``time_history <= 0``.
        ValueError: If a requested window falls outside a field's time axis, or
            if the targets end up with zero channels.
    """
    assert n_input_scalar_components > 0 or n_input_vector_components > 0
    assert n_output_scalar_components > 0 or n_output_vector_components > 0
    assert time_history > 0

    end_time = start + time_history
    target_start_time = end_time + time_gap
    target_end_time = target_start_time + time_future

    # Build explicit part lists instead of seeding torch.cat with an empty
    # legacy tensor. That tensor is float32 and could promote the dtype of
    # vector-only inputs.
    input_parts = []
    if n_input_scalar_components > 0:
        input_parts.append(
            _slice_time_window(scalar_fields, start, end_time, n_input_scalar_components, "scalar input")
        )
    if n_input_vector_components > 0:
        input_parts.append(
            _slice_time_window(vector_fields, start, end_time, 2 * n_input_vector_components, "vector input")
        )

    target_parts = []
    if n_output_scalar_components > 0:
        target_parts.append(
            _slice_time_window(
                scalar_fields, target_start_time, target_end_time, n_output_scalar_components, "scalar target"
            )
        )
    if n_output_vector_components > 0:
        target_parts.append(
            _slice_time_window(
                vector_fields, target_start_time, target_end_time, 2 * n_output_vector_components, "vector target"
            )
        )

    # The asserts above guarantee both lists are non-empty.
    data = _cat_channels(input_parts)
    targets = _cat_channels(target_parts)

    # Guards against fields that have fewer channels than requested (e.g. zero).
    if targets.size(1) == 0:
        raise ValueError("No targets")

    return data, targets



class NavierStokes2DDataset(Dataset):
    """Randomly sampled, multi-fidelity 2D Navier-Stokes training windows.

    Directory layout expected by this class (as implied by its path handling)::

        folder_path/<fidelity>/<name>.h5

    Every immediate sub-directory of ``folder_path`` is treated as one fidelity
    level. Files are matched across fidelities by identical file name. Inside
    each file, data is read from the HDF5 group named ``subset``, using the
    datasets ``"u"``, ``"vx"``, ``"vy"`` and, optionally, ``"viscosity"``.

    Apart from a range check, each ``__getitem__`` call ignores ``idx`` and:

      1) randomly picks an ``.h5`` file from the reference-fidelity directory,
      2) randomly picks a sample inside that file,
      3) randomly picks a start time,
      4) returns ``(x, y)`` from ``create_data2D``, plus a dict that holds the
         same window for every fidelity directory containing a file with the
         same name. Each of those windows is resampled bilinearly to the
         reference spatial resolution.
    """

    def __init__(
        self,
        folder_path: str,
        n_input_scalar_components: int = 1,
        n_input_vector_components: int = 2,
        n_output_scalar_components: int = 1,
        n_output_vector_components: int = 1,
        time_history: int = 4,
        time_future: int = 1,
        time_gap: int = 0,
        trajlen: int = 14,
        subset: str = "train",    # or "valid"/"test", etc.
        max_samples: int = 10000,  # control the dataset size if desired
        reference_fidelity: str = "fine",
    ):
        """
        Args:
            folder_path: Root directory whose sub-directories are fidelity levels.
            n_input_scalar_components, n_input_vector_components,
            n_output_scalar_components, n_output_vector_components:
                Forwarded to ``create_data2D``.
            time_history, time_future, time_gap: Window lengths forwarded to
                ``create_data2D``.
            trajlen: Upper bound on the number of time steps per trajectory
                considered when choosing a start time.
            subset: A file name must contain this substring to be used. It is
                also the name of the HDF5 group read inside each file.
            max_samples: Value reported by ``__len__`` (samples per epoch).
            reference_fidelity: Sub-directory that files are drawn from. All
                other fidelities are resampled to its spatial resolution.
                Defaults to ``"fine"``, the previously hard-coded value.
        """
        super().__init__()
        self.folder_path = folder_path

        # Sorted so that fidelity/file order, and therefore seeded random
        # choices, does not depend on the arbitrary order of os.listdir.
        self.subdirectories = sorted(
            entry
            for entry in os.listdir(folder_path)
            if os.path.isdir(os.path.join(folder_path, entry))
        )
        self.n_input_scalar_components = n_input_scalar_components
        self.n_input_vector_components = n_input_vector_components
        self.n_output_scalar_components = n_output_scalar_components
        self.n_output_vector_components = n_output_vector_components
        self.time_history = time_history
        self.time_future = time_future
        self.time_gap = time_gap
        self.trajlen = trajlen
        self.max_samples = max_samples
        self.subset = subset
        self.reference_fidelity = reference_fidelity

        # Per fidelity directory: every ".h5" file whose name contains `subset`
        # (e.g. something like "NavierStokes2D_train_*****.h5").
        self.h5_files = {}
        for sub in self.subdirectories:
            sub_path = os.path.join(folder_path, sub)
            self.h5_files[sub] = sorted(
                os.path.join(sub_path, name)
                for name in os.listdir(sub_path)
                if name.endswith(".h5") and subset in name
            )
        print("Data Folder: ", folder_path)
        for k, v in self.h5_files.items():
            print(f"Subset: {k}, Number of Files: {len(v)}")

    def __len__(self):
        # Nominal number of samples per epoch; each item is drawn at random.
        return self.max_samples

    def _read_sample(
        self, group, sample_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Read one sample from an open HDF5 group.

        Returns:
            ``(u, v, viscosity)``, where:

            * ``u`` has a channel axis inserted at dim 1 (``[T, 1, H, W]`` if a
              stored sample is ``[T, H, W]``; assumed, not verified here);
            * ``v`` stacks ``vx`` and ``vy`` along dim 1;
            * ``viscosity`` is the ``"viscosity"`` entry with a leading axis
              added, or ``None`` when that dataset is absent.
        """
        u = torch.tensor(group["u"][sample_idx]).unsqueeze(1).float()
        vx = torch.tensor(group["vx"][sample_idx]).float()
        vy = torch.tensor(group["vy"][sample_idx]).float()
        v = torch.stack([vx, vy], dim=1)
        # NOTE: some files may also contain "buo_y"; it is not used by this dataset.
        viscosity = None
        if "viscosity" in group:
            viscosity = torch.tensor(group["viscosity"][sample_idx]).float().unsqueeze(0)
        return u, v, viscosity

    def _make_window(
        self, u: torch.Tensor, v: torch.Tensor, start_time: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply ``create_data2D`` with this dataset's component and window settings."""
        return create_data2D(
            self.n_input_scalar_components,
            self.n_input_vector_components,
            self.n_output_scalar_components,
            self.n_output_vector_components,
            u,  # scalar_fields
            v,  # vector_fields
            grid=None,
            start=start_time,
            time_history=self.time_history,
            time_future=self.time_future,
            time_gap=self.time_gap,
        )

    def _fidelity_paths(self, reference_path: str) -> dict:
        """Map each fidelity directory to its file with the same name, if one exists."""
        # Rebuild the path from components rather than string-replacing a
        # "/<fidelity>/" fragment. Replacement breaks with OS-specific
        # separators, or when folder_path itself contains that fragment.
        file_name = os.path.basename(reference_path)
        paths = {}
        for mode in self.subdirectories:
            candidate = os.path.join(self.folder_path, mode, file_name)
            if os.path.exists(candidate):
                paths[mode] = candidate
        return paths

    def __getitem__(self, idx):
        """Draw one random multi-fidelity window; ``idx`` only bounds iteration.

        Returns:
            ``(out_x, out_y, fidelity_data)``:

            * ``out_x``/``out_y`` come from the reference-fidelity file.
            * ``fidelity_data`` maps each fidelity name, including the
              reference one, to its ``(x, y)`` pair, bilinearly resampled to
              the reference spatial size.
            * If any fidelity file has a ``"viscosity"`` dataset,
              ``fidelity_data["viscosity"]`` is also set; the last fidelity
              read, in directory order, wins.

        Raises:
            IndexError: If ``idx >= len(self)``, so that plain ``for``
                iteration over the dataset terminates.
            FileNotFoundError: If the reference fidelity has no matching files.
            ValueError: If a trajectory is too short for the requested window,
                or if another fidelity file has fewer samples than the
                drawn index.
        """
        if idx >= len(self):
            raise IndexError(f"index {idx} out of range for dataset of length {len(self)}")

        reference_files = self.h5_files.get(self.reference_fidelity, [])
        if not reference_files:
            raise FileNotFoundError(
                f"No '{self.subset}' .h5 files found for reference fidelity "
                f"'{self.reference_fidelity}' under {self.folder_path}"
            )

        # 1. Random reference file.
        reference_path = random.choice(reference_files)
        with h5py.File(reference_path, "r") as f:
            group = f[self.subset]
            # 2. Random sample index within that file.
            num_samples = group["u"].shape[0]
            sample_idx = random.randint(0, num_samples - 1)
            # 3. Extract the PDE fields.
            ref_u, ref_v, ref_viscosity = self._read_sample(group, sample_idx)

        # 4. Random start time such that
        #    start_time + time_history + time_gap + time_future <= time_resolution.
        time_resolution = min(ref_u.shape[0], self.trajlen)
        max_start_time = time_resolution - self.time_history - self.time_future - self.time_gap
        if max_start_time < 0:
            raise ValueError(
                f"Trajectory of {time_resolution} steps is too short for "
                f"time_history={self.time_history}, time_gap={self.time_gap}, "
                f"time_future={self.time_future}"
            )
        start_time = random.randint(0, max_start_time)

        # 5. Reference window.
        out_x, out_y = self._make_window(ref_u, ref_v, start_time)
        target_size_x = (out_x.shape[2], out_x.shape[3])
        target_size_y = (out_y.shape[2], out_y.shape[3])

        fidelity_data = {}
        for fidelity, path in self._fidelity_paths(reference_path).items():
            if fidelity == self.reference_fidelity:
                # Already loaded above; avoid reopening the same file.
                u, v, viscosity = ref_u, ref_v, ref_viscosity
            else:
                with h5py.File(path, "r") as f:
                    group = f[self.subset]
                    # The same sample index is reused across fidelities.
                    available = group["u"].shape[0]
                    if sample_idx >= available:
                        raise ValueError(
                            f"{path} has {available} samples; cannot read sample "
                            f"{sample_idx} drawn from {reference_path}"
                        )
                    u, v, viscosity = self._read_sample(group, sample_idx)

            x, y = self._make_window(u, v, start_time)
            # Resample to the reference spatial resolution. For the reference
            # fidelity itself the sizes already match.
            x = F.interpolate(x, size=target_size_x, mode="bilinear", align_corners=False)
            y = F.interpolate(y, size=target_size_y, mode="bilinear", align_corners=False)
            fidelity_data[fidelity] = (x, y)
            if viscosity is not None:
                fidelity_data["viscosity"] = viscosity
        return out_x, out_y, fidelity_data