import json
import os
from dataclasses import dataclass, field
import time
from typing import List, Optional
import torch
import numpy as np
import h5py
from os.path import join, exists

from src.utils import pylogger
from src.utils.permutations import perm_w_start_idcs
from src.utils.probe_sampling import box_sample, get_probe_idcs, select_from_video

# Module-level logger for this file.
# NOTE(review): `rank_zero_only=True` presumably restricts output to the rank-0 process in
# distributed runs — confirm against `pylogger.RankedLogger`.
log = pylogger.RankedLogger(__name__, rank_zero_only=True)

class OpenFoamDataset:
    def __init__(self, 
                split: str,
                mode: str,
                image_size: int,
                global_root: str,
                local_root: str,
                dataset_rel_path: str,
                files: List[str],
                n_steps: int,
                seed: int,
                n_probe_history: int = 1,  # if 0, the probe values at the current timestep are taken
                do_cache: bool=True,
                split_sizes: list=[0.9, 0.1],
                probe_idcs: Optional[list|int]=None,
                probe_mask_box: Optional[list[list[float]]]=None,
                norm_method='normal',
                norm_multiplier=1.,
                position_scale=1000.,
                keep_probe_idcs_fixed=False,
                static_probes=np.array([[64, 1]]),
                jump_by_sequence: bool=False,
                **kwargs
                ):
        assert split in ['train', 'val', 'test']
        assert sum(split_sizes) == 1., "split sizes must sum to 1"

        self.position_scale = position_scale
        self.mode = mode
        self.img_size = image_size
        self.split_sizes = split_sizes
        self.n_steps = n_steps
        self.n_probe_history = n_probe_history
        self.seed = seed
        self.probe_mask_box = probe_mask_box
        self.norm_method = norm_method
        self.norm_multiplier = norm_multiplier
        self.dataset_rel_path = dataset_rel_path
        self.keep_probe_idcs_fixed = keep_probe_idcs_fixed
        self.static_probes = static_probes
        self.jump_by_sequence = jump_by_sequence

        # check if local root available, if yes, use it
        self.root_dir = join(global_root, dataset_rel_path)
        if local_root is not None and (local_root_dir := exists(join(local_root, dataset_rel_path))):
            assert False, "something seems to be off in if condition - double check"
            self.root_dir = local_root_dir
        
        self.case_files = self._get_cases(files, split)
        self.indices, self.lengths = self._compute_indices_for_cases()
        self.cum_lengths = np.cumsum([0] + self.lengths)
        # make all mesh_centers positive for RoPE
        # self.meshcenters_scaled = self.shift_scale_positions(self.mesh_centers)
        
        # TODO: 
        self.probe_idcs = probe_idcs
        
        self.cache = None
        if do_cache:
            self.cache = self._load_cache()

        # mean and std for normalization        
        with open(join(self.root_dir, 'normalization.pkl'), 'r') as file:
            data = json.load(file)
        self.norm_consts = {key: torch.tensor(value) for key, value in data.items()}
        
    def _get_cases(self, files, split):
        # select target split
        files = files[split]
        # get case files from filesystem
        cases = sorted([f for f in os.listdir(self.root_dir) if f.endswith(".hdf5")])
        # filtered cases
        cases = [f for f in cases if f in files]
        print(f'cases for split: {split}: {cases}')

        case_files = sorted([join(self.root_dir, case) for case in cases])
        return case_files
    
    def _load_cache(self):
        cache = []
        i_timesteps = 0
        i_trajectories = 0
        for case_file in self.case_files:
            with h5py.File(case_file, 'r', libver='latest') as hdf_file:
                # print(f'entered file {os.getpid()}')
                case_group = hdf_file['case']
                fields_group = case_group['fields']

                # get all timesteps of trajectory
                # U_dict = {}
                U_list = []
                sorted_keys = sorted(fields_group['U'].keys(), key=float)
                for i, timestep_key in enumerate(sorted_keys):
                    # U_dict[f'U_{i}'] = fields_group['U'][timestep_key][:] # U at timestep t
                    U_list.append(fields_group['U'][timestep_key][:])
                    i_timesteps += 1
            
            i_trajectories += 1
            # cache.append(U_dict)
            cache.append(np.ascontiguousarray(np.stack(U_list, axis=0)))
        
        print(f'Computed cache with {i_timesteps} timesteps')
        return cache

    def _compute_indices_for_cases(self):
        """
        Compute the length of each case based on the sequence_length in metadata.
        """
        indices_list = []
        for case_file in self.case_files:
            with h5py.File(case_file, 'r') as hdf_file:
                case_group = hdf_file['case']
                metadata = {key: value for key, value in case_group['metadata'].attrs.items()}
                seqlen = metadata.get('sequence_length', 0)
                
                assert seqlen > self.n_probe_history + self.n_steps
                # Start only from probe history values
                # Leave out last steps to predict n_steps into the future

                if self.jump_by_sequence:
                    step_size = self.n_probe_history + self.n_steps
                    idcs = np.arange(self.n_probe_history -1, seqlen - self.n_steps, step_size)
                else:
                    idcs = np.arange(self.n_probe_history - 1, seqlen - self.n_steps)
                
                indices_list.append(idcs)
    
        lengths = [len(sub_idcs) for sub_idcs in indices_list]
        return indices_list, lengths
        
    def __len__(self):
        """
        Return the total number of items in the dataset.
        """
        return self.cum_lengths[-1].item()

    def __getitem__(self, idx):
        """
        Build the sample belonging to the flat dataset index `idx`.

        :param idx: Flat index over the samples of all cases.
        :return: A dictionary with the keys `field` (normalized U frames of the sample window),
            `probe_idcs` (result of `get_probe_idcs`), `probe_pos` (probe indices rescaled by
            `position_scale / img_size`, float32) and `probe_field` (probe values rescaled with
            the dataset min/max). NumPy arrays are converted to torch tensors.
        """
        # Translate the flat index into (case, position inside the case) via the cumulative lengths.
        case_idx = np.digitize(idx, self.cum_lengths).item() - 1
        offset_in_case = idx - self.cum_lengths[case_idx]
        anchor_timestep = self.indices[case_idx][offset_in_case]

        probe_idcs = get_probe_idcs(self.probe_idcs, self.img_size, self.static_probes)
        if self.keep_probe_idcs_fixed:
            # Store the probes; they are handed back to `get_probe_idcs` on later calls.
            self.probe_idcs = probe_idcs

        field_window, probe_window = self.get_frame_from_cache_or_file(case_idx, anchor_timestep, probe_idcs)

        # Map the un-normalized probe values onto a non-negative range scaled by `position_scale`.
        # TODO: investigate! this offset is hacky
        # TODO: remove after deadline!
        positivity_offset = 1.
        shifted = self.unnormalize(probe_window) - self.norm_consts['min'] + positivity_offset
        value_span = self.norm_consts['max'] - self.norm_consts['min']
        probe_vid_field_scaled = shifted / value_span * self.position_scale
        assert torch.all(probe_vid_field_scaled >= 0), f'not all coords are >= 0, coords.min(): {probe_vid_field_scaled.min()}'

        sample = {
            'field': field_window,
            'probe_idcs': probe_idcs,
            'probe_pos': (probe_idcs / self.img_size * self.position_scale).astype(np.float32),
            'probe_field': probe_vid_field_scaled,
        }

        # Hand out torch tensors only: wrap whatever is still a NumPy array.
        return {
            key: torch.from_numpy(value) if isinstance(value, np.ndarray) else value
            for key, value in sample.items()
        }


    def get_frame_from_cache_or_file(self, case_index, timestep_index, probe_idcs):
        '''
        Load the window of U frames around `timestep_index` and the matching probe values.

        probe history:                [t-h+1, ... t]        h = n_probe_history     h points in probe history
        future timesteps:             [t+1, ... t+n]        n = n_steps;
        ----------------------------
        total steps = h + n

        The frames come from the in-memory cache when it exists, otherwise from the HDF5 file
        of the case. Frames are normalized with `normalize`. Probe values of the future
        timesteps are zeroed, so only the history part carries measurements.

        :param case_index: Index into `self.case_files` / `self.cache`.
        :param timestep_index: Anchor timestep t inside the case.
        :param probe_idcs: Probe indices handed to `select_from_video`.
        :return: Tuple `(U, probe_history)`: the normalized frames of the window stacked along
            the first axis, and the probe values per window position (future positions are zero).
        '''
        timestep_indices = np.arange(timestep_index - self.n_probe_history + 1, timestep_index + self.n_steps + 1)

        if self.cache is not None:
            U = torch.from_numpy(self.cache[case_index][timestep_indices]).float()
            U = self.normalize(U)

            probe_history = select_from_video(U, probe_idcs)
            if self.n_steps > 0:
                # The window is indexed relative to its own start: positions [0, h) are the
                # history, positions [h, h+n) are the future. Using the absolute timestep here
                # would leave future probe values in place for all but the earliest windows.
                probe_history[self.n_probe_history:] = 0

            return U, probe_history

        U_list = []
        probe_field_list = []
        case_file = self.case_files[case_index]
        with h5py.File(case_file, 'r', libver='latest') as hdf_file:
            U_group = hdf_file['case']['fields']['U']
            # Timestep keys are strings; order them by their numeric value.
            sorted_keys = sorted(U_group.keys(), key=float)
            for t_idx in timestep_indices:
                U = torch.from_numpy(U_group[sorted_keys[t_idx]][:]).float()  # U at timestep t_idx
                if U.shape[0] != self.img_size:
                    U = center_crop(U, self.img_size, self.img_size)
                U = self.normalize(U)

                probe_values = select_from_video(U, probe_idcs)
                if t_idx > timestep_index:
                    # Future timestep: keep shape and dtype of real probe values, but hide them.
                    probe_values = torch.zeros_like(probe_values)
                probe_field_list.append(probe_values)
                U_list.append(U)

        probe_history = torch.stack(probe_field_list, dim=0)
        U_vid_field = torch.stack(U_list, dim=0)
        return U_vid_field, probe_history
    
    
    
    def normalize(self, u):
        if self.norm_method == 'normal':
            u_norm = (u - self.norm_consts['mean'].to(u.device)) / self.norm_consts['std'].to(u.device)
            return u_norm * self.norm_multiplier
        if self.norm_method == 'abs_max':
            abs_max = torch.max(torch.abs(self.norm_consts['min']), torch.abs(self.norm_consts['max']))
            u_norm = u / abs_max.to(u.device)
            return u_norm * self.norm_multiplier
        raise ValueError

    def unnormalize(self, u):
        if self.norm_method == 'normal':
            u_unnorm = u * self.norm_consts['std'].to(u.device) + self.norm_consts['mean'].to(u.device)
            return u_unnorm / self.norm_multiplier
        if self.norm_method == 'abs_max':
            abs_max = torch.max(torch.abs(self.norm_consts['min']), torch.abs(self.norm_consts['max']))
            u_unnorm = u * abs_max.to(u.device)
            return u_unnorm / self.norm_multiplier
        raise ValueError
        
    
    def case_to_normalized_reynolds(self, case_or_caseidx):
        
        if self.dataset_rel_path in [
            'h5_LES/singleJet', 
            'h5_LES/singleJet.03',
            'h5_LES/singleJet.03_large',
            'h5_LES/singleJet.04',
            ]:
            
            case = self.case_files[case_or_caseidx] if isinstance(case_or_caseidx, int) else case_or_caseidx
            
            fname = case.split('/')[-1]
            assert fname.startswith('Re') and fname.endswith(".hdf5"), 'needs to have format '
            reynolds = float(fname.replace('Re', '').replace('.hdf5', ''))

            # normalize range into [-1, 1]
            range = [200, 2500]
            assert range[0] <= reynolds <= range[1], f"usually values are in the range [{range[0], range[1]} for this dataset"
            mean = np.mean(range)
            half_range = (range[1] - range[0])/2
            
            return (reynolds-mean)/half_range
        
        raise ValueError('need to specify reynolds case normalization for each dataset')


def center_crop(image, new_height, new_width):
    """
    Cut the central `new_height` x `new_width` window out of the two leading axes of `image`.

    Any trailing axes (e.g. channels) are kept unchanged. When the size difference is odd,
    the extra row/column is dropped at the far (bottom/right) side.

    :param image: Array-like with `.shape` and slice indexing whose first two axes are
        height and width.
    :param new_height: Height of the crop; must not exceed the image height.
    :param new_width: Width of the crop; must not exceed the image width.
    :return: The cropped window, obtained by slicing `image`.
    :raises ValueError: If the requested crop is larger than the image.
    """
    h, w = image.shape[:2]
    if new_height > h or new_width > w:
        # Negative start offsets would silently produce a wrong (or empty) slice.
        raise ValueError(
            f"crop size ({new_height}, {new_width}) exceeds image size ({h}, {w})"
        )
    # Calculate the starting indices for the crop
    start_h = (h - new_height) // 2
    start_w = (w - new_width) // 2
    return image[start_h:start_h + new_height, start_w:start_w + new_width]


