# https://github.com/PolymathicAI/multiple_physics_pretraining/blob/main/data_utils/datasets.py
""" 
Remember to parameterize the file paths eventually
"""
import torch
import torch.nn
import numpy as np
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler
import os
import re
try:
    from mixed_dset_sampler import MultisetSampler
    from hdf5_datasets import *
except ImportError:
    from .mixed_dset_sampler import MultisetSampler
    from .hdf5_datasets import *
import os
import random
import glob
from pdb import set_trace as bp

# Module-level list; nothing in the visible part of this file reads or appends to it.
broken_paths = []

# Registry mapping a dataset type string to the class that loads it.
# MixedDataset looks classes up here by the type string given in each path entry,
# so every new dataset class MUST be registered in this mapping.
DSET_NAME_TO_OBJECT = {
    "swe": SWEDataset,
    "incompNS": IncompNSDataset,
    "diffre2d": DiffRe2DDataset,
    "compNS": CompNSDataset,
    "scalarflow": ScalarFlowDataset,
}


def get_data_loader(params, paths, distributed, split='train', train_val_test=None, rank=0, train_offset=0, rollout=1):
    """Create the mixed dataset for ``split`` together with its sampler and DataLoader.

    Args:
        params: Config object. Attributes read here: ``n_steps``, ``input_size``,
            ``tie_fields``, ``use_all_fields``, ``enforce_max_steps``, ``in_chans``,
            ``batch_size``, ``epoch_size``, ``num_data_workers``; ``train_subsample``
            (only for the train split with a truthy ``train_val_test``); and,
            optionally, ``temporal_cutoff`` / ``temporal_cutoff_test`` (default -1).
        paths: Per-dataset entries, forwarded unchanged to ``MixedDataset``.
        distributed: Truthy selects ``DistributedSampler`` as the base sampler class,
            otherwise ``RandomSampler`` is used.
        split: Split name; ``'train'`` enables the train-only behaviour below.
        train_val_test: Optional 3-element split fractions.
        rank: Forwarded to ``MultisetSampler``.
        train_offset: Forwarded to ``MixedDataset``.
        rollout: Forwarded to ``MixedDataset``.

    Returns:
        Tuple ``(dataloader, dataset, sampler)``.
    """
    # Scale the training fraction by the configured subsample factor (train split only).
    if train_val_test and split == 'train': # TODO
        scaled_train = train_val_test[0] * params.train_subsample
        train_val_test = [scaled_train, train_val_test[1], train_val_test[2]]

    # The train split reads `temporal_cutoff`; every other split reads `temporal_cutoff_test`.
    cutoff_attr = "temporal_cutoff" if split == "train" else "temporal_cutoff_test"

    dataset = MixedDataset(
        paths,
        n_steps=params.n_steps,
        train_val_test=train_val_test,
        target_shape=params.input_size,
        split=split,
        tie_fields=params.tie_fields,
        use_all_fields=params.use_all_fields,
        enforce_max_steps=params.enforce_max_steps,
        train_offset=train_offset,
        rollout=rollout,
        temporal_cutoff=getattr(params, cutoff_attr, -1),
        max_channels=params.in_chans,
    )

    # `seed` is not forwarded to the sampler at the moment. The call is still kept
    # because torch.random.seed() reseeds torch's global RNG as a side effect.
    if 'train' in split:
        seed = torch.random.seed()
    else:
        seed = 0

    base_sampler = DistributedSampler if distributed else RandomSampler
    sampler = MultisetSampler(
        dataset,
        base_sampler,
        params.batch_size,
        distributed=distributed,
        max_samples=params.epoch_size,
        rank=rank,
    )

    # Ordering is delegated entirely to `sampler`, hence shuffle=False.
    loader_kwargs = dict(
        batch_size=int(params.batch_size),
        num_workers=params.num_data_workers,
        shuffle=False,
        sampler=sampler,
        drop_last=True,
        pin_memory=torch.cuda.is_available(),
    )
    dataloader = DataLoader(dataset, **loader_kwargs)
    return dataloader, dataset, sampler


class MixedDataset(Dataset):
    """Several sub-datasets exposed behind one flat sample index.

    Each entry of ``path_list`` is a 4-tuple
    ``(path, dataset_type, include_string, pde_param)`` where ``dataset_type`` is a
    key of ``DSET_NAME_TO_OBJECT``. One sub-dataset is built per entry and a global
    index is mapped to ``(sub-dataset, local index)`` through cumulative offsets.

    NOTE(review): ``__getitem__`` looks up ``subset_dict`` with ``get_name()`` while
    the non-tied / non-all-fields branch of ``_build_subset_dict`` keys it with
    ``get_name(self.extended_names)``; if those two names differ when
    ``extended_names=True`` the lookup would fail -- confirm against the
    sub-dataset implementation.
    """

    def __init__(self, path_list=[], n_steps=1, dt=1, train_val_test=(.8, .1, .1), target_shape=256,
                  split='train', tie_fields=True, use_all_fields=True, extended_names=False, 
                  enforce_max_steps=False, train_offset=0, rollout=1, temporal_cutoff=-1, max_channels=3):
        """Build one sub-dataset per ``path_list`` entry and the index offsets.

        ``path_list`` is only read, never mutated. ``enforce_max_steps`` is accepted
        for interface compatibility but is not used by this class.

        Raises:
            ValueError: if ``len()`` of a freshly built sub-dataset raises
                ``ValueError`` (reported as an empty dataset).
        """
        super().__init__()
        # Global dicts used by Mixed DSET.
        self.train_offset = train_offset
        self.path_list, self.type_list, self.include_string, self.pde_param_list = zip(*path_list)
        self.tie_fields = tie_fields
        self.extended_names = extended_names
        self.target_shape = target_shape
        self.split = split
        self.sub_dsets = []
        self.offsets = [0]
        self.train_val_test = train_val_test
        self.use_all_fields = use_all_fields
        self.rollout = rollout
        self.temporal_cutoff = temporal_cutoff
        self.max_channels = max_channels

        for dset, path, include_string, pde_param in zip(self.type_list, self.path_list, self.include_string, self.pde_param_list):
            subdset = DSET_NAME_TO_OBJECT[dset](path, include_string, pde_param=pde_param, n_steps=n_steps,
                                                 dt=dt, train_val_test=train_val_test, target_shape=target_shape, split=split,
                                                 rollout=self.rollout, temporal_cutoff=self.temporal_cutoff, max_channels=max_channels)
            # Check to make sure our dataset actually exists with these settings
            try:
                n_samples = len(subdset)
            except ValueError as err:
                raise ValueError(f'Dataset {path} is empty. Check that n_steps < trajectory_length in file.') from err
            self.sub_dsets.append(subdset)
            # offsets[i + 1] is the global index one past the last sample of sub-dataset i.
            self.offsets.append(self.offsets[-1] + n_samples)
        # The first boundary is replaced by -1; __getitem__ clamps it back to 0 with max(..., 0).
        self.offsets[0] = -1

        self.subset_dict = self._build_subset_dict()

    def get_state_names(self):
        """Return a flat list of field names.

        With ``use_all_fields`` the names of every registered dataset class are
        returned (element 2 of ``_specifics()``), in registry order. Otherwise the
        ``field_names`` of the loaded sub-datasets are returned, once per distinct
        ``get_name()``.
        """
        if self.use_all_fields:
            name_list = []
            for dset in DSET_NAME_TO_OBJECT.values():
                name_list += dset._specifics()[2]
            return name_list

        visited = set()
        name_list = []
        for dset in self.sub_dsets:
            name = dset.get_name() # Could use extended names here
            if name not in visited:
                visited.add(name)
                name_list.append(dset.field_names)
        return [f for fl in name_list for f in fl] # Flatten the names

    def _build_subset_dict(self):
        """Map each dataset name to the list of field indices it occupies."""
        # Maps fields to subsets of variables
        if self.tie_fields: # Hardcoded, but seems less effective anyway
            return {
                    # 'swe': [3],
                    'incompNS': [0, 1, 2],
                    # 'compNS': [0, 1, 2, 3],
                    # 'diffre2d': [4, 5]
                    }

        subset_dict = {}
        if self.use_all_fields:
            # Consecutive index ranges for every registered dataset, in registry order.
            cur_max = 0
            for name, dset in DSET_NAME_TO_OBJECT.items():
                n_fields = len(dset._specifics()[2])
                subset_dict[name] = list(range(cur_max, cur_max + n_fields))
                cur_max += n_fields
        else:
            # Consecutive index ranges for the loaded datasets only, starting at train_offset.
            cur_max = self.train_offset
            for dset in self.sub_dsets:
                name = dset.get_name(self.extended_names)
                if name not in subset_dict:
                    n_fields = len(dset.field_names)
                    subset_dict[name] = list(range(cur_max, cur_max + n_fields))
                    cur_max += n_fields
        return subset_dict

    def __getitem__(self, index):
        """Return ``(x, file_idx, field_indices, bcs, y)`` for a global sample index.

        Any exception raised while fetching the sample is logged (with the indices
        and the ``RANK`` environment variable) and then re-raised unchanged.
        """
        file_idx = np.searchsorted(self.offsets, index, side='right')-1 #which dataset are we are on
        local_idx = index - max(self.offsets[file_idx], 0)
        try:
            x, bcs, y, _ = self.sub_dsets[file_idx][local_idx]
        except Exception as e:
            print('FAILED AT ', file_idx, local_idx, index,int(os.environ.get("RANK", 0)), e)
            # Re-raise: falling through would hit unbound locals and hide the real error.
            raise
        return x, file_idx, torch.tensor(self.subset_dict[self.sub_dsets[file_idx].get_name()]), bcs, y

    def __len__(self):
        """Total number of samples over all sub-datasets."""
        return sum(len(dset) for dset in self.sub_dsets)



import imageio
import json
import torch.nn.functional as F
import cv2

# Full-resolution frame size: 1920 (height) x 1080 (width).

# Earlier hard-coded crop boxes, kept for reference only (not used by the code below).
# TODO: for 120 frames
# FLOW_H_TOP, FLOW_H_BOTTOM = 560, 1820
# FLOW_W_LEFT, FLOW_W_RIGHT = 240, 816
# TODO: for 20 frames
# FLOW_H_TOP, FLOW_H_BOTTOM = 1085, 1820
# FLOW_W_LEFT, FLOW_W_RIGHT = 353, 786
# 508 1820 204 882

# Frame height and width in pixels; the functions below halve both when half_res is set.
FLOW_H = 1920
FLOW_W = 1080

def calculate_crop_box(basedir="./data/ScalarReal", half_res=False, split='train', frame_num_cutoff=-1):
    """Compute one crop box covering the active region of every selected video.

    Pixels below 20 are zeroed, the bottom 100 rows (50 at half resolution) are
    blanked, and a row/column counts as active when its summed intensity exceeds 40.
    For each video the box extends from a fixed anchor point by the *number* of
    active rows/columns on each side (maximum over frames); the returned box is the
    union over all videos.

    Args:
        basedir: Directory containing ``info.json`` and the video files.
        half_res: Work on frames resized to half width and height.
        split: ``'all'`` uses train and test videos; otherwise ``meta[split + '_videos']``
            is used, falling back to the first train video when that key is absent
            (same selection rule as ``pinf_frame_to_h5``).
        frame_num_cutoff: Number of frames to read per video; ``<= 0`` reads the
            video's own ``frame_num``.

    Returns:
        ``(bbox, (size, margins))`` with ``bbox = (top, bottom, left, right)``,
        ``size = (bottom - top, right - left)`` and
        ``margins = (top, f_h - bottom, left, f_w - right)``.
    """
    if half_res:
        f_h, f_w = FLOW_H // 2, FLOW_W // 2
    else:
        f_h, f_w = FLOW_H, FLOW_W

    # Anchor point from which the box grows in all four directions.
    C_H, C_W = (f_h - 60, int(f_w / 2)) if half_res else (f_h - 120, int(f_w / 2))
    # Number of bottom rows blanked before measuring.
    blank_rows = 50 if half_res else 100
    flow_h_top, flow_h_bottom, flow_w_left, flow_w_right = C_H, C_H, C_W, C_W

    with open(os.path.join(basedir, 'info.json'), 'r') as fp:
        meta = json.load(fp)

    if split == 'all':
        video_list = []
        video_list.extend(meta['train_videos'])
        video_list.extend(meta['test_videos'])
    else:
        # Previously video_list was left undefined for any split other than 'all'.
        video_list = meta[split + '_videos'] if (split + '_videos') in meta else meta['train_videos'][0:1]

    # Defined up front so the summary print below also works for an empty video list.
    frame_num = frame_num_cutoff
    for train_video in video_list:
        f_name = os.path.join(basedir, train_video['file_name'])
        if frame_num_cutoff <= 0:
            frame_num = train_video['frame_num']
        else:
            frame_num = frame_num_cutoff

        imgs = []
        reader = imageio.get_reader(f_name, "ffmpeg")
        try:
            for frame_i in range(frame_num):
                reader.set_image_index(frame_i)
                frame = reader.get_next_data()
                H, W = frame.shape[:2]
                if half_res:
                    frame = cv2.resize(frame, (W//2, H//2), interpolation=cv2.INTER_AREA)
                imgs.append(frame)
        finally:
            # Release the reader even if decoding fails part-way.
            reader.close()

        imgs = np.array(imgs)  # (num_f, h, w, c)
        out_put = np.where(imgs < 20, 0, imgs)
        out_put[:, -blank_rows:, :, :] = 0
        img_test = out_put.sum(axis=-1)  # (num_f, h, w)
        w_test = img_test.sum(axis=1)    # summed over h -> (num_f, w)
        h_test = img_test.sum(axis=2)    # summed over w -> (num_f, h)
        w_test_mask = np.where(w_test > 40, 1, 0)
        h_test_mask = np.where(h_test > 40, 1, 0)
        # Count active rows/columns on each side of the anchor, take the max over frames.
        y1 = C_H - h_test_mask[:, :C_H].sum(axis=1).max()
        y2 = C_H + h_test_mask[:, C_H:].sum(axis=1).max()
        x1 = C_W - w_test_mask[:, :C_W].sum(axis=1).max()
        x2 = C_W + w_test_mask[:, C_W:].sum(axis=1).max()
        # Grow the running box so it covers this video as well.
        if y1 < flow_h_top: flow_h_top = y1
        if y2 > flow_h_bottom: flow_h_bottom = y2
        if x1 < flow_w_left: flow_w_left = x1
        if x2 > flow_w_right: flow_w_right = x2

    print('nf:', frame_num,  'bbox:', flow_h_top, flow_h_bottom, flow_w_left, flow_w_right)
    bbox = flow_h_top, flow_h_bottom, flow_w_left, flow_w_right
    size = (flow_h_bottom - flow_h_top, flow_w_right - flow_w_left)
    margins = (flow_h_top, f_h - flow_h_bottom, flow_w_left, f_w - flow_w_right)
    return bbox, (size, margins)


# def pinf_frame_to_h5(basedir="./data/ScalarReal", half_res=False, split='train', frame_num_cutoff=-1, target_shape=512, bbox=None, save_path=None):
def pinf_frame_to_h5(basedir="./data/ScalarReal", half_res=False, split='train', frame_num_cutoff=-1, target_shape=512, bbox=None, save_path=None):
    """Convert the videos listed in ``info.json`` into one HDF5 file per video.

    Every frame is thresholded (values below 20 set to 0), blanked outside its
    per-frame active extent, scaled by 1/255, cropped to ``bbox`` and optionally
    resized to ``target_shape x target_shape``. The result is written to
    ``<path>/<video>/<video>.h5`` as datasets ``r``, ``g``, ``b`` (channels 0, 1, 2)
    in group ``data``.

    Args:
        basedir: Directory containing ``info.json`` and the video files.
        half_res: learning happens on half-resolution images (independent of
            target_shape, which is to avoid transformer OOM).
        split: ``'all'`` uses train and test videos; otherwise ``meta[split + '_videos']``,
            falling back to the first train video when that key is absent.
        frame_num_cutoff: Frames to read per video; ``<= 0`` reads the video's own
            ``frame_num`` and adds a ``_test`` suffix to the output directory name.
        target_shape: Square output size; resizing is skipped unless this is a
            positive ``int``.
        bbox: ``(top, bottom, left, right)`` crop; ``None`` uses built-in defaults.
        save_path: Output sub-directory relative to ``basedir`` (default ``h5``).

    Returns:
        List with the output directory of each processed video.
    """
    if half_res:
        f_h = FLOW_H // 2
        f_w = FLOW_W // 2
    else:
        f_h = FLOW_H
        f_w = FLOW_W
    # `is None` rather than `== None`, so an array-like bbox is not compared element-wise.
    if bbox is None:
        flow_h_top, flow_h_bottom, flow_w_left, flow_w_right =  (542, 960, 176, 398) if half_res else (1085, 1820, 353, 786)
    else:
        flow_h_top, flow_h_bottom, flow_w_left, flow_w_right = bbox
    # Anchor point for the per-frame extent measurement (does not depend on the frame).
    C_H, C_W = (f_h - 60 , int(f_w/2)) if half_res else  (f_h - 120 , int(f_w/2))

    # read render settings
    with open(os.path.join(basedir, 'info.json'), 'r') as fp:
        meta = json.load(fp)

    # all videos should be synchronized, having the same frame_rate and frame_num
    if split == 'all':
        video_list = []
        video_list.extend(meta['train_videos'])
        video_list.extend(meta['test_videos'])
    else:
        video_list = meta[split + '_videos'] if (split + '_videos') in meta else meta['train_videos'][0:1]

    if save_path is None:
        save_path = f"{basedir}/h5"
    else:
        save_path = f"{basedir}/{save_path}"
    if frame_num_cutoff > 0:
        path = save_path + f"/n{frame_num_cutoff}_{target_shape}/{split}"
    else:
        path = save_path + f"/n{frame_num_cutoff}_{target_shape}_test/{split}"

    paths = []
    for train_video in video_list:
        imgs = []
        f_name = os.path.join(basedir, train_video['file_name'])
        filename = train_video['file_name'].split('.')[0] # name (prefix) of h5 file to save
        os.makedirs(os.path.join(path, filename), exist_ok=True)
        paths.append(os.path.join(path, filename))
        if frame_num_cutoff <= 0:
            frame_num = train_video['frame_num']
        else:
            frame_num = frame_num_cutoff

        reader = imageio.get_reader(f_name, "ffmpeg")
        try:
            for frame_i in range(frame_num):
                reader.set_image_index(frame_i)
                frame = reader.get_next_data()

                H, W = frame.shape[:2]
                if half_res:
                    frame = cv2.resize(frame, (W//2, H//2), interpolation=cv2.INTER_AREA)

                # denoise: drop dim pixels, then blank everything outside the active extent
                frame = np.where(frame<20, 0, frame)
                img_test = frame.sum(axis=-1)  # (h, w)
                w_test = img_test.sum(axis=0)  # summed over h -> (w,)
                h_test = img_test.sum(axis=1)  # summed over w -> (h,)
                w_test_mask = np.where(w_test>40, 1, 0)
                h_test_mask = np.where(h_test>40, 1, 0)
                # Extent = anchor +/- the number of active rows/columns on that side.
                y1 = C_H - h_test_mask[:C_H].sum()
                y2 = C_H + h_test_mask[C_H:].sum()
                x1 = C_W - w_test_mask[:C_W].sum()
                x2 = C_W + w_test_mask[C_W:].sum()
                frame[:y1] = 0
                frame[y2:] = 0
                frame[:, :x1] = 0
                frame[:, x2:] = 0
                imgs.append(frame)
        finally:
            # Release the reader even if decoding fails part-way.
            reader.close()

        imgs = (np.float32(imgs) / 255.)
        imgs = imgs[:, flow_h_top:flow_h_bottom, flow_w_left:flow_w_right]  # reduce image size to save GPU memory used by the transformer
        # isinstance is checked first so a non-numeric target_shape skips resizing instead of raising.
        if isinstance(target_shape, int) and target_shape > 0:
            imgs_resized = np.zeros((imgs.shape[0], target_shape, target_shape, imgs.shape[-1]), dtype=np.float32)
            for i, img in enumerate(imgs):
                # to enlarge the image: INTER_LINEAR or INTER_CUBIC interpolation
                # to shrink the image: INTER_AREA interpolation
                imgs_resized[i] = cv2.resize(img, (target_shape, target_shape), interpolation=cv2.INTER_AREA)
            imgs = imgs_resized

        # NOTE(review): h5py is not imported explicitly in the visible part of this
        # file; presumably it arrives via the star import from hdf5_datasets -- confirm.
        with h5py.File(os.path.join(path, filename, filename+'.h5'), 'w') as f:
            tasks = f.create_group('data')
            tasks.create_dataset('r', data=imgs[:, :, :, 0])
            tasks.create_dataset('g', data=imgs[:, :, :, 1])
            tasks.create_dataset('b', data=imgs[:, :, :, 2])

    return paths


############  load hy for finetune  #####################

def _read_hy_image(img_path):
    """Read one HY frame with OpenCV and divide it by 255; raise if it cannot be read."""
    hy_img = cv2.imread(img_path)
    if hy_img is None:
        # cv2.imread signals a missing/unreadable file by returning None.
        raise FileNotFoundError(f'Could not read HY frame: {img_path}')
    return hy_img / 255.0  # num_f, h, w, c


def pinf_combine_frames_to_h5(basedir="./data/ScalarReal", half_res=False, split='train', frame_num_cutoff=-1, target_shape=512, bbox=None, save_path=None, hy_dir="", hy_frames=10):
    """Combine video frames with ``rgb_XXX.png`` frames from ``hy_dir`` into HDF5 files.

    Video frames are read and denoised exactly as in ``pinf_frame_to_h5``. Then, per video:

    * ``split == 'train'``: frames ``rgb_{frame_num_cutoff:03d}.png`` ...
      ``rgb_{frame_num_cutoff + hy_frames - 1:03d}.png`` from ``hy_dir/<video>`` are
      appended after the video frames.
    * any other split: video frames with index in
      ``[frame_num_cutoff - hy_frames, frame_num_cutoff)`` (and below the number of
      video frames) are overwritten by the HY frame of the same index.

    The result is cropped to ``bbox``, optionally resized, and written to
    ``<path>/<video>/<video>.h5`` as datasets ``r``, ``g``, ``b`` in group ``data``.

    Args:
        basedir: Directory containing ``info.json`` and the video files.
        half_res: learning happens on half-resolution images (independent of
            target_shape, which is to avoid transformer OOM).
        split: ``'all'`` uses train and test videos; otherwise ``meta[split + '_videos']``,
            falling back to the first train video when that key is absent.
        frame_num_cutoff: Frames to read per video (``<= 0`` reads the video's own
            ``frame_num``). NOTE(review): the HY index ranges are derived from this
            value as-is, so a non-positive value does not look meaningful here -- confirm.
        target_shape: Square output size; resizing is skipped unless this is a
            positive ``int``.
        bbox: ``(top, bottom, left, right)`` crop; ``None`` uses built-in defaults.
        save_path: Output sub-directory relative to ``basedir`` (default ``h5``).
        hy_dir: Directory holding one sub-directory of HY frames per video.
        hy_frames: Number of HY frames to append / substitute.

    Returns:
        List with the output directory of each processed video.

    Raises:
        FileNotFoundError: a required HY frame is missing or unreadable.
        ValueError: HY frames and video frames differ in per-frame shape (train split).
    """
    if half_res:
        f_h = FLOW_H // 2
        f_w = FLOW_W // 2
    else:
        f_h = FLOW_H
        f_w = FLOW_W
    # `is None` rather than `== None`, so an array-like bbox is not compared element-wise.
    if bbox is None:
        flow_h_top, flow_h_bottom, flow_w_left, flow_w_right =  (542, 960, 176, 398) if half_res else (1085, 1820, 353, 786)
    else:
        flow_h_top, flow_h_bottom, flow_w_left, flow_w_right = bbox
    # Anchor point for the per-frame extent measurement (does not depend on the frame).
    C_H, C_W = (f_h - 60 , int(f_w/2)) if half_res else  (f_h - 120 , int(f_w/2))

    # read render settings
    with open(os.path.join(basedir, 'info.json'), 'r') as fp:
        meta = json.load(fp)

    # all videos should be synchronized, having the same frame_rate and frame_num
    if split == 'all':
        video_list = []
        video_list.extend(meta['train_videos'])
        video_list.extend(meta['test_videos'])
    else:
        video_list = meta[split + '_videos'] if (split + '_videos') in meta else meta['train_videos'][0:1]

    if save_path is None:
        save_path = f"{basedir}/h5"
    else:
        save_path = f"{basedir}/{save_path}"
    if frame_num_cutoff > 0:
        path = save_path + f"/n{frame_num_cutoff}_{target_shape}/{split}"
    else:
        path = save_path + f"/n{frame_num_cutoff}_{target_shape}_test/{split}"

    # Phase 1: read and denoise the frames of every video.
    paths = []
    fm_frames = {}
    for train_video in video_list:
        imgs = []
        f_name = os.path.join(basedir, train_video['file_name'])
        filename = train_video['file_name'].split('.')[0] # name (prefix) of h5 file to save
        os.makedirs(os.path.join(path, filename), exist_ok=True)
        paths.append(os.path.join(path, filename))
        if frame_num_cutoff <= 0:
            frame_num = train_video['frame_num']
        else:
            frame_num = frame_num_cutoff

        reader = imageio.get_reader(f_name, "ffmpeg")
        try:
            for frame_i in range(frame_num):
                reader.set_image_index(frame_i)
                frame = reader.get_next_data()

                H, W = frame.shape[:2]
                if half_res:
                    frame = cv2.resize(frame, (W//2, H//2), interpolation=cv2.INTER_AREA)

                # denoise: drop dim pixels, then blank everything outside the active extent
                frame = np.where(frame<20, 0, frame)
                img_test = frame.sum(axis=-1)  # (h, w)
                w_test = img_test.sum(axis=0)  # summed over h -> (w,)
                h_test = img_test.sum(axis=1)  # summed over w -> (h,)
                w_test_mask = np.where(w_test>40, 1, 0)
                h_test_mask = np.where(h_test>40, 1, 0)
                # Extent = anchor +/- the number of active rows/columns on that side.
                y1 = C_H - h_test_mask[:C_H].sum()
                y2 = C_H + h_test_mask[C_H:].sum()
                x1 = C_W - w_test_mask[:C_W].sum()
                x2 = C_W + w_test_mask[C_W:].sum()
                frame[:y1] = 0
                frame[y2:] = 0
                frame[:, :x1] = 0
                frame[:, x2:] = 0
                imgs.append(frame)
        finally:
            # Release the reader even if decoding fails part-way.
            reader.close()
        fm_frames[filename] = (np.float32(imgs) / 255.)

    # Phase 2: merge in the HY frames and write one h5 file per video.
    # NOTE(review): cv2.imread yields BGR channel order while the video frames come
    # from imageio; confirm the HY PNGs are stored so that the channels line up.
    for file_name, fm_imgs in fm_frames.items():
        path_vi = os.path.join(hy_dir, file_name)
        # Only used for membership tests, so a set is enough (no ordering needed).
        hy_img_names = {x for x in os.listdir(path_vi) if x.startswith('rgb') and x.endswith('.png')}

        if split == 'train':
            hy_imgs = []
            for frame_id in range(frame_num_cutoff, frame_num_cutoff+hy_frames):  # e.g. [20, 30)
                img_name = 'rgb_{:03d}.png'.format(frame_id)
                if img_name not in hy_img_names:
                    # Fail loudly instead of dropping into pdb and reusing a stale frame.
                    raise FileNotFoundError(
                        f"{img_name} doesn't exist. Check dir: {os.path.join(path_vi, img_name)}")
                hy_imgs.append(_read_hy_image(os.path.join(path_vi, img_name)))
            hy_imgs = np.float32(hy_imgs)
            if not hy_imgs.shape[1:] == fm_imgs.shape[1:]:
                raise ValueError(f'HY size{hy_imgs.shape} and FM size{fm_imgs.shape} dont match')
            imgs = np.concatenate((fm_imgs, hy_imgs))
        else:
            for idx in range(frame_num_cutoff-hy_frames, frame_num_cutoff):
                img_pth = 'rgb_{:03d}.png'.format(idx)
                if idx<len(fm_imgs):
                    fm_imgs[idx] = _read_hy_image(os.path.join(path_vi, img_pth))
            imgs = fm_imgs

        imgs = imgs[:, flow_h_top:flow_h_bottom, flow_w_left:flow_w_right]  # reduce image size to save GPU memory used by the transformer
        # isinstance is checked first so a non-numeric target_shape skips resizing instead of raising.
        if isinstance(target_shape, int) and target_shape > 0:
            imgs_resized = np.zeros((imgs.shape[0], target_shape, target_shape, imgs.shape[-1]), dtype=np.float32)
            for i, img in enumerate(imgs):
                # to enlarge the image: INTER_LINEAR or INTER_CUBIC interpolation
                # to shrink the image: INTER_AREA interpolation
                imgs_resized[i] = cv2.resize(img, (target_shape, target_shape), interpolation=cv2.INTER_AREA)
            imgs = imgs_resized

        # NOTE(review): h5py is not imported explicitly in the visible part of this
        # file; presumably it arrives via the star import from hdf5_datasets -- confirm.
        with h5py.File(os.path.join(path, file_name, file_name+'.h5'), 'w') as f:
            tasks = f.create_group('data')
            tasks.create_dataset('r', data=imgs[:, :, :, 0])
            tasks.create_dataset('g', data=imgs[:, :, :, 1])
            tasks.create_dataset('b', data=imgs[:, :, :, 2])

    return paths
