"""Code is adapted from https://github.com/MIT-AI-Accelerator/neurips-2020-sevir. Their license is MIT License."""
import sys
import os
import os.path as osp
from typing import List, Union, Dict, Sequence
from math import ceil
import numpy as np
import numpy.random as nprand
import datetime
import pandas as pd
import h5py 
import cv2

## --- START MODIFICATION --- ##
# ADDED Imports for new features
import pickle
import random
## --- END MODIFICATION --- ##

import torch
from torch.utils.data import Dataset as TorchDataset, DataLoader
from torch.nn.functional import avg_pool2d
from torchvision import transforms 

from matplotlib import colors
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
from moviepy.editor import ImageSequenceClip

# --- ORIGINAL, UNCHANGED FUNCTIONS ---
def change_layout_np(data,
                     in_layout='NHWT', out_layout='NHWT',
                     ret_contiguous=False):
    """Convert a NumPy array between the supported axis layouts.

    Layout letters: batch 'N', time 'T', channel 'C', height 'H', width 'W'.
    The array is first brought to 'NHWT' and then converted to ``out_layout``.

    Parameters
    ----------
    data: np.ndarray
        Array laid out according to ``in_layout``.
    in_layout: str
        One of 'NHWT', 'NTHW', 'NTCHW'. For 'NTCHW' only channel 0 is kept.
    out_layout: str
        One of 'NHWT', 'NTHW', 'NTCHW'. For 'NTCHW' a size-1 channel axis is
        inserted at position 2.
    ret_contiguous: bool
        If True, return a C-contiguous array.

    Returns
    -------
    np.ndarray
        The data in ``out_layout``.

    Raises
    ------
    NotImplementedError
        If either layout string is not supported.
    """
    # Step 1: normalize the input to the canonical 'NHWT' layout.
    if in_layout == 'NHWT':
        pass
    elif in_layout == 'NTHW':
        data = np.transpose(data, axes=(0, 2, 3, 1))
    elif in_layout == 'NTCHW':
        data = data[:, :, 0, :, :]  # keep channel 0 only
        data = np.transpose(data, axes=(0, 2, 3, 1))
    else:
        raise NotImplementedError(f'Unsupported in_layout: {in_layout}')

    # Step 2: convert from 'NHWT' to the requested output layout.
    if out_layout == 'NHWT':
        pass
    elif out_layout == 'NTHW':
        data = np.transpose(data, axes=(0, 3, 1, 2))
    elif out_layout == 'NTCHW':
        data = np.transpose(data, axes=(0, 3, 1, 2))
        data = np.expand_dims(data, axis=2)
    else:
        raise NotImplementedError(f'Unsupported out_layout: {out_layout}')

    if ret_contiguous:
        # ndarray has no ``ascontiguousarray`` method; the NumPy function
        # must be used instead.
        data = np.ascontiguousarray(data)
    return data

def change_layout_torch(data,
                        in_layout='NHWT', out_layout='NHWT',
                        ret_contiguous=False):
    """Convert a tensor between the supported axis layouts.

    Layout letters: batch 'N', time 'T', channel 'C', height 'H', width 'W'.
    Supported layouts are 'NHWT', 'NTHW' and 'NTCHW'. The tensor is first
    permuted to 'NHWT' (an 'NTCHW' input keeps only channel 0) and then to
    ``out_layout`` (an 'NTCHW' output gains a size-1 channel axis at dim 2).

    Raises ``NotImplementedError`` for any other layout string. When
    ``ret_contiguous`` is true the result is passed through ``.contiguous()``.
    """
    # Bring the input into the canonical NHWT arrangement.
    if in_layout == 'NTCHW':
        data = data[:, :, 0, :, :].permute(0, 2, 3, 1)
    elif in_layout == 'NTHW':
        data = data.permute(0, 2, 3, 1)
    elif in_layout != 'NHWT':
        raise NotImplementedError

    # Re-arrange from NHWT into the requested output arrangement.
    if out_layout in ('NTHW', 'NTCHW'):
        data = data.permute(0, 3, 1, 2)
        if out_layout == 'NTCHW':
            data = data.unsqueeze(2)
    elif out_layout != 'NHWT':
        raise NotImplementedError

    return data.contiguous() if ret_contiguous else data

# --- ORIGINAL, UNCHANGED CONSTANTS ---
# Names of the image modalities referenced throughout this module.
SEVIR_DATA_TYPES = ['vis', 'ir069', 'ir107', 'vil', 'lght']

# Raw storage dtype per modality (int16 for all of them).
SEVIR_RAW_DTYPES = {name: np.int16 for name in SEVIR_DATA_TYPES}

# 49 offsets from -7200 to 7200 in steps of 300
# (presumably seconds, i.e. 5-minute frames over +/- 2 hours — TODO confirm).
LIGHTING_FRAME_TIMES = 60 * np.arange(-120.0, 125.0, 5)

# Spatial shape override; only 'lght' is listed here.
SEVIR_DATA_SHAPE = {
    'lght': (48, 48),
}

# rescale='sevir': value = scale * (raw + offset)
PREPROCESS_SCALE_SEVIR = {
    'vis': 1,
    'ir069': 1 / 1174.68,
    'ir107': 1 / 2562.43,
    'vil': 1 / 47.54,
    'lght': 1 / 0.60517,
}
PREPROCESS_OFFSET_SEVIR = {
    'vis': 0,
    'ir069': 3683.58,
    'ir107': 1552.80,
    'vil': -33.44,
    'lght': -0.02990,
}

# rescale='01': only 'vil' is actually rescaled (divided by 255); no offsets.
PREPROCESS_SCALE_01 = {
    'vis': 1,
    'ir069': 1,
    'ir107': 1,
    'vil': 1 / 255,
    'lght': 1,
}
PREPROCESS_OFFSET_01 = dict.fromkeys(SEVIR_DATA_TYPES, 0)

class SEVIRTorchDataset(TorchDataset):
    """
    A unified PyTorch Dataset for SEVIR data that handles loading, filtering,
    caching, and transformations.

    Each item is one window of ``seq_len`` frames cut from an event at offset
    ``seq_idx * stride`` along the last (time) axis. ``__getitem__`` returns
    only the 'vil' modality, resized to ``img_size`` x ``img_size``.

    The HDF5 files are opened in the constructor and stay open until
    ``close()`` is called.
    """
    def __init__(self,
                 dataset_dir: str,
                 data_types: Sequence[str] = ('vil',),
                 traing: str ='test',
                 seq_len: int = 25,
                 img_size: int = 128,
                 stride: int = 20,
                 raw_seq_len: int = 49,
                 sevir_catalog: Union[str, pd.DataFrame] = None,
                 start_date: datetime.datetime = None,
                 end_date: datetime.datetime = None,
                 datetime_filter=None,
                 catalog_filter='default',
                 shuffle: bool = False,
                 shuffle_seed: int = 1,
                 preprocess: bool = True,
                 rescale_method: str = '01',
                 filter_by_mean: bool = True,
                 filter_threshold: float = 16,
                 verbose: bool = False,
                 **kwargs):
        """
        Parameters
        ----------
        dataset_dir: str
            Root directory; HDF5 files are looked up in its "data"
            sub-directory and the default catalog is "CATALOG.csv" inside it.
        data_types: Sequence[str]
            Modalities to load. 'vil' is required when ``filter_by_mean`` is
            set, and ``__getitem__`` always returns the 'vil' entry.
        traing: str
            Tag used only in the cache file name ``valid_sequences_{traing}.pkl``.
        seq_len, stride, raw_seq_len: int
            Window length, window step, and assumed number of frames per
            event; together they determine ``num_seq_per_event``.
        img_size: int
            Side length of the resized output frames.
        sevir_catalog: str or pd.DataFrame
            Catalog CSV path or an already-loaded catalog frame.
        start_date, end_date: datetime.datetime
            Keep rows with ``start_date < time_utc <= end_date``.
        datetime_filter: callable
            Receives the ``time_utc`` column and returns a row mask.
        catalog_filter: 'default', callable or None
            'default' keeps rows with ``pct_missing == 0``; a callable receives
            the catalog and returns a row mask; None disables the filter.
        shuffle, shuffle_seed:
            When ``shuffle`` is set, the sequence list is shuffled with a
            generator seeded by ``shuffle_seed``.
        preprocess, rescale_method:
            Whether and how ``preprocess_data_dict`` rescales each sample.
        filter_by_mean, filter_threshold:
            Keep only windows in which every frame's spatial mean of the raw
            'vil' values is at least ``filter_threshold``.
        verbose: bool
            Print each HDF5 file as it is opened.
        **kwargs:
            Accepted and ignored.
        """

        super().__init__()

        self.dataset_dir = dataset_dir
        # Keep a private list so the instance never shares state with the
        # caller's sequence (or with a shared default argument).
        self.data_types = list(data_types)
        self.seq_len = seq_len
        self.img_size = img_size
        self.stride = stride
        self.raw_seq_len = raw_seq_len
        self.shuffle = shuffle
        self.shuffle_seed = shuffle_seed
        self.preprocess = preprocess
        self.rescale_method = rescale_method
        self.sevir_data_dir = os.path.join(dataset_dir, "data")

        if sevir_catalog is None:
            sevir_catalog = os.path.join(dataset_dir, "CATALOG.csv")
        if isinstance(sevir_catalog, str):
            self.catalog = pd.read_csv(sevir_catalog, parse_dates=['time_utc'], low_memory=False)
        else:
            self.catalog = sevir_catalog
        # Only rows whose file name contains 'STORMEVENTS' are used.
        if self.catalog is not None:
            self.catalog = self.catalog[self.catalog['file_name'].str.contains('STORMEVENTS', na=False)]

        if start_date is not None:
            self.catalog = self.catalog[self.catalog.time_utc > start_date]
        if end_date is not None:
            self.catalog = self.catalog[self.catalog.time_utc <= end_date]
        if datetime_filter:
            self.catalog = self.catalog[datetime_filter(self.catalog.time_utc)]
        if catalog_filter is not None:
            if catalog_filter == 'default':
                catalog_filter = lambda c: c.pct_missing == 0
            self.catalog = self.catalog[catalog_filter(self.catalog)]

        self._compute_samples()
        self._open_files(verbose=verbose)

        self.valid_sequences = []
        if filter_by_mean:
            if 'vil' not in self.data_types:
                raise ValueError("Filtering requires 'vil' in data_types.")
            cache_fn = f'valid_sequences_{traing}.pkl'
            cache_path = os.path.join(self.dataset_dir, cache_fn)
            if os.path.exists(cache_path):
                # NOTE(review): the cache key is only `traing`; a cache written
                # with a different seq_len/stride/threshold/catalog filter is
                # loaded as-is. Delete the file when those settings change.
                # NOTE: pickle must only be loaded from a trusted dataset_dir.
                print(f"Loading cached sequence IDs from {cache_path}")
                with open(cache_path, 'rb') as f:
                    self.valid_sequences = pickle.load(f)
            else:
                print("Scanning all sequences for filtering... This may take a while.")
                for event_idx, event_id in enumerate(self._samples.index):
                    for seq_idx in range(self.num_seq_per_event):
                        raw_data = self._get_raw_vil_sequence(event_idx, seq_idx)
                        # Mean over dims 1 and 2 gives one value per frame;
                        # every frame has to reach the threshold.
                        if raw_data is not None and torch.all(raw_data.mean(dim=(1, 2)) >= filter_threshold):
                            self.valid_sequences.append((event_id, seq_idx))
                print(f"Finished. Saving {len(self.valid_sequences)} IDs to cache...")
                with open(cache_path, 'wb') as f:
                    pickle.dump(self.valid_sequences, f)
        else:
            # No filtering: every window of every event is a sample.
            self.valid_sequences = [
                (event_id, seq_idx)
                for event_id in self._samples.index
                for seq_idx in range(self.num_seq_per_event)
            ]

        self.reset()
        self.transform = transforms.Compose([transforms.Resize((self.img_size, self.img_size), antialias=True)])
        print(f"Dataset ready. Found {len(self.valid_sequences)} valid samples.")

    def reset(self):
        """Shuffle ``valid_sequences`` in place when ``shuffle`` is enabled.

        A private generator seeded with ``shuffle_seed`` is used so that the
        process-wide ``random`` state is left untouched; the resulting order
        is the same as seeding the global generator with that seed.
        """
        if self.shuffle:
            random.Random(self.shuffle_seed).shuffle(self.valid_sequences)

    def close(self):
        """Close all HDF5 files opened by this dataset.

        After calling this, items can no longer be read.
        """
        for handle in getattr(self, '_hdf_files', {}).values():
            handle.close()
        self._hdf_files = {}

    def __len__(self):
        """Return the number of (event, window) samples."""
        return len(self.valid_sequences)

    def __getitem__(self, index):
        """Load one sample and return its 'vil' frames.

        Returns the 'vil' data permuted to (T, 1, H, W) and passed through the
        resize transform. Raises ``IndexError`` when ``index >= len(self)``.

        NOTE(review): only 'vil' is returned even when other data types are
        loaded; if 'vil' is not in ``data_types`` this raises ``KeyError``.
        """
        if index >= len(self):
            raise IndexError("Index out of range")

        event_id, seq_idx = self.valid_sequences[index]
        row = self._samples.loc[event_id]

        ret_dict = {}
        for imgt in self.data_types:
            fname, hdf_idx = row[f'{imgt}_filename'], row[f'{imgt}_index']
            seq_slice = slice(seq_idx * self.stride, seq_idx * self.stride + self.seq_len)
            # The one-element slice on axis 0 keeps a leading batch axis.
            data_i = self._hdf_files[fname][imgt][int(hdf_idx):int(hdf_idx) + 1, :, :, seq_slice]
            ret_dict[imgt] = torch.from_numpy(data_i.astype(np.float32)) # (1, H, W, T)

        if self.preprocess:
            # The preprocess function expects a batch; the tensors already
            # carry a size-1 batch axis from the slicing above.
            ret_dict = self.preprocess_data_dict(ret_dict, rescale=self.rescale_method)

        vil_data = ret_dict["vil"].squeeze(0) # Remove dummy batch dim -> (H, W, T)
        vil_data = vil_data.permute(2, 0, 1) #(T, H, W)
        vil_data_with_channel = vil_data.unsqueeze(1) # Shape (T, 1, H, W)
        vil_data_resized = self.transform(vil_data_with_channel) # Output (T, 1, H', W')

        return vil_data_resized

    @staticmethod
    def preprocess_data_dict(data_dict, data_types=None, layout='NHWT', rescale='01'):
        """
        Rescale the selected entries as ``scale * (data + offset)`` and convert
        them from 'NHWT' to ``layout``. ``data_dict`` is updated in place and
        also returned.

        Parameters
        ----------
        data_dict:  Dict[str, Union[np.ndarray, torch.Tensor]]
        data_types: Sequence[str]
            The data types that we want to rescale. This mainly excludes "mask" from preprocessing.
            Defaults to all keys of ``data_dict``.
        layout: str
            consists of batch_size 'N', seq_len 'T', channel 'C', height 'H', width 'W'
        rescale:    str
            'sevir': use the offsets and scale factors in original implementation.
            '01': scale all values to range 0 to 1, currently only supports 'vil'
        Returns
        -------
        data_dict:  Dict[str, Union[np.ndarray, torch.Tensor]]
            preprocessed data
        Raises
        ------
        ValueError
            If ``rescale`` is neither 'sevir' nor '01'.
        """
        if rescale == 'sevir':
            scale_dict = PREPROCESS_SCALE_SEVIR
            offset_dict = PREPROCESS_OFFSET_SEVIR
        elif rescale == '01':
            scale_dict = PREPROCESS_SCALE_01
            offset_dict = PREPROCESS_OFFSET_01
        else:
            raise ValueError(f'Invalid rescale option: {rescale}.')
        if data_types is None:
            data_types = data_dict.keys()
        for key, data in data_dict.items():
            if key in data_types:
                if isinstance(data, np.ndarray):
                    data = scale_dict[key] * (
                            data.astype(np.float32) +
                            offset_dict[key])
                    data = change_layout_np(data=data,
                                            in_layout='NHWT',
                                            out_layout=layout)
                elif isinstance(data, torch.Tensor):
                    data = scale_dict[key] * (
                            data.float() +
                            offset_dict[key])
                    data = change_layout_torch(data=data,
                                               in_layout='NHWT',
                                               out_layout=layout)
                # Values of any other type are stored back unchanged.
                data_dict[key] = data
        return data_dict

    # --- Internal Helper Methods ---
    def _get_raw_vil_sequence(self, event_idx, seq_idx):
        """Return one raw 'vil' window as a float32 tensor, or None.

        None is returned when the event has no usable 'vil' file, when the
        window is shorter than ``seq_len``, or when reading raises
        ``KeyError``/``ValueError``.
        """
        row = self._samples.iloc[event_idx]
        fname, hdf_idx = row.get('vil_filename'), row.get('vil_index')
        if fname is None or pd.isna(fname) or fname not in self._hdf_files:
            return None
        try:
            event_data = self._hdf_files[fname]['vil'][int(hdf_idx):int(hdf_idx) + 1, :, :, :]
            seq_slice = slice(seq_idx * self.stride, seq_idx * self.stride + self.seq_len)
            sampled_seq = event_data[:, :, :, seq_slice]
            # A window running past the end of the event comes back short.
            if sampled_seq.shape[3] != self.seq_len:
                return None
            return torch.from_numpy(sampled_seq.astype(np.float32))
        except (KeyError, ValueError):
            return None

    def _compute_samples(self):
        """Build ``self._samples``: one row per event id with file name and
        index columns for each requested data type.

        Only ids that have exactly one catalog row for every requested data
        type are kept.
        """
        imgt_set = set(self.data_types)
        filtcat = self.catalog[self.catalog['img_type'].isin(self.data_types)]
        filtcat = filtcat.groupby('id').filter(lambda x: imgt_set.issubset(x['img_type']))
        filtcat = filtcat.groupby('id').filter(lambda x: len(x) == len(imgt_set))
        self._samples = filtcat.groupby('id').apply(self._df_to_series)

    def _df_to_series(self, df):
        """Flatten the catalog rows of one event into a single Series with
        ``{type}_filename`` and ``{type}_index`` entries."""
        d = {}
        df = df.set_index('img_type')
        for i in self.data_types:
            s = df.loc[i]
            # NOTE(review): for 'lght' the index is taken from `s.name`, which
            # after set_index looks like the img_type label rather than an
            # event id — confirm the intended value.
            d.update({ f'{i}_filename': s.file_name, f'{i}_index': s.file_index if i != 'lght' else s.name })
        return pd.Series(d)

    def _open_files(self, verbose=True):
        """Open (read-only) every HDF5 file referenced by ``self._samples``.

        Files that do not exist under ``sevir_data_dir`` are skipped silently.
        """
        self._hdf_files = {}
        for data_type in self.data_types:
            filenames = self._samples[f'{data_type}_filename'].dropna().unique()
            for f in filenames:
                if f not in self._hdf_files:
                    path = os.path.join(self.sevir_data_dir, f)
                    if os.path.exists(path):
                        if verbose:
                            print('Opening HDF5 file:', f)
                        self._hdf_files[f] = h5py.File(path, 'r')

    @property
    def num_seq_per_event(self):
        """Number of full ``seq_len`` windows, stepped by ``stride``, that fit
        into ``raw_seq_len`` frames."""
        return 1 + (self.raw_seq_len - self.seq_len) // self.stride






# Colour table used by gray2color: eleven RGB rows with components in [0, 1].
COLOR_MAP = [
    [0, 0, 0],
    [0.30196078431372547, 0.30196078431372547, 0.30196078431372547],
    [0.1568627450980392, 0.7450980392156863, 0.1568627450980392],
    [0.09803921568627451, 0.5882352941176471, 0.09803921568627451],
    [0.0392156862745098, 0.4117647058823529, 0.0392156862745098],
    [0.0392156862745098, 0.29411764705882354, 0.0392156862745098],
    [0.9607843137254902, 0.9607843137254902, 0.0],
    [0.9294117647058824, 0.6745098039215687, 0.0],
    [0.9411764705882353, 0.43137254901960786, 0.0],
    [0.6274509803921569, 0.0, 0.0],
    [0.9058823529411765, 0.0, 1.0],
]

# Four RGB colours given as 0-255 integers, scaled into [0, 1].
HMF_COLORS = np.array([[82, 82, 82],
                       [252, 141, 89],
                       [255, 255, 191],
                       [145, 191, 219]]) / 255

# Full-scale pixel value.
PIXEL_SCALE = 255.0

# Bin edges passed to BoundaryNorm in gray2color (eleven edges).
BOUNDS = [
    0.0, 16.0, 31.0, 59.0, 74.0,
    100.0, 133.0, 160.0, 181.0, 219.0,
    255,
]

# Pixel-value thresholds (a subset of the edges in BOUNDS).
THRESHOLDS = (16, 74, 133, 160, 181, 219)


def gray2color(image, **kwargs):
    """Colour a grayscale image using COLOR_MAP and BOUNDS.

    The image values are binned by a ``BoundaryNorm`` built from BOUNDS and
    the resulting indices are looked up in a ``ListedColormap`` built from
    COLOR_MAP. Extra keyword arguments are accepted and ignored.

    NOTE(review): COLOR_MAP has 11 entries while BOUNDS defines 10 intervals;
    confirm that the bin-to-colour assignment is the intended one.
    """
    # Discrete palette plus the norm that turns pixel values into palette indices.
    palette = colors.ListedColormap(COLOR_MAP )
    binning = colors.BoundaryNorm(BOUNDS, palette.N)

    # Bin the pixels, then look up their colours.
    return palette(binning(image))


if __name__ == '__main__':
    # Smoke check: run one random 1x256x256 frame through the colouring routine.
    data = (torch.randn(1, 256, 256) * 255).numpy()
    color = gray2color(data)

