import os
import cv2
import glob
import torch
import random
import einops
import numpy as np
import glob as gb
from utils import read_flo
from torch.utils.data import Dataset
from cvbase.optflow.visualize import flow2rgb


def readFlow(sample_dir, resolution, to_rgb):
    """Load an optical-flow file, resize it and return it channel-first.

    The flow is read with ``read_flo``, bilinearly resized to ``resolution``
    and its two displacement channels are rescaled by the resize ratio so
    they are expressed in pixels of the resized frame.

    Args:
        sample_dir: Path of the ``.flo`` file to read.
        resolution: Indexable ``(height, width)`` of the output.
        to_rgb: If truthy, convert the flow with ``flow2rgb`` and map the
            result to ``[-1, 1]``; otherwise divide the x/y displacements by
            the output width/height respectively.

    Returns:
        The processed flow rearranged from ``h w c`` to ``c h w``.
    """
    raw = read_flo(sample_dir)
    src_h, src_w, _ = np.shape(raw)
    out_w = resolution[1]
    out_h = resolution[0]

    # cv2.resize expects the target size as (width, height).
    flow = cv2.resize(raw, (out_w, out_h), interpolation=cv2.INTER_LINEAR)

    # Resizing changes the pixel grid, so displacement magnitudes must be
    # scaled by the same ratio (x with width, y with height).
    flow[:, :, 0] = flow[:, :, 0] * out_w / src_w
    flow[:, :, 1] = flow[:, :, 1] * out_h / src_h

    if not to_rgb:
        # Express displacements as a fraction of the output frame size.
        flow[:, :, 0] = flow[:, :, 0] / out_w
        flow[:, :, 1] = flow[:, :, 1] / out_h
    else:
        # Per the original author's note the 2-channel flow becomes a
        # 3-channel image here; it is then shifted/scaled and clipped to [-1, 1].
        flow = np.clip((flow2rgb(flow) - 0.5) * 2, -1., 1.)

    return einops.rearrange(flow, 'h w c -> c h w')

def readRGB(sample_dir, resolution):
    """Load an image as RGB in ``[-1, 1]``, resized and channel-first.

    Args:
        sample_dir: Path of the image file to read.
        resolution: Indexable ``(height, width)`` of the output.

    Returns:
        Floating-point array rearranged from ``h w c`` to ``c h w`` with
        values clipped to ``[-1, 1]``.

    Raises:
        FileNotFoundError: If OpenCV could not read ``sample_dir``.
    """
    bgr = cv2.imread(sample_dir)
    if bgr is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with the offending path instead of deep inside cvtColor.
        raise FileNotFoundError(f"Missing or unreadable image: {sample_dir}")

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # Map 8-bit intensities [0, 255] to [-1, 1] before resizing.
    rgb = ((rgb / 255.0) - 0.5) * 2.0
    # cv2.resize expects the target size as (width, height).
    rgb = cv2.resize(rgb, (resolution[1], resolution[0]), interpolation=cv2.INTER_LINEAR)
    rgb = np.clip(rgb, -1., 1.)
    return einops.rearrange(rgb, 'h w c -> c h w')

def readSeg(sample_dir, resolution=None):
    """Load a segmentation mask image scaled by 1/255, channel-first.

    Args:
        sample_dir: Path of the mask image to read.
        resolution: Optional indexable ``(height, width)``. When truthy the
            mask is resized with nearest-neighbour interpolation so no new
            intermediate label values are introduced; otherwise the mask
            keeps its native size.

    Returns:
        The mask rearranged from ``h w c`` to ``c h w``.

    Raises:
        FileNotFoundError: If OpenCV could not read ``sample_dir``.
    """
    mask = cv2.imread(sample_dir)
    if mask is None:
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the division below fails with an opaque TypeError.
        raise FileNotFoundError(f"Missing or unreadable image: {sample_dir}")

    gt = mask / 255
    if resolution:
        # cv2.resize expects the target size as (width, height).
        gt = cv2.resize(gt, (resolution[1], resolution[0]), interpolation=cv2.INTER_NEAREST)
    return einops.rearrange(gt, 'h w c -> c h w')

    
class FrameData(Dataset):
    """Video-frame dataset with a random-window training mode and a
    whole-sequence evaluation mode.

    Training mode (``train=True``): every item is a randomly chosen window of
    ``seq_length`` frames (spaced ``gap`` apart) from a randomly chosen
    sequence, together with two optical-flow fields per frame. The item index
    is ignored and ``len()`` is fixed at 10000.

    Evaluation mode (``train=False``): item ``idx`` is the complete sequence
    ``val_seq[idx]`` with its ground-truth masks. Masks are loaded at their
    native size (``readSeg`` is called without a resolution).

    Args:
        data_dir: Indexable of root directories: ``[0]`` flow root (contains
            ``Flows_gap{k}`` folders), ``[1]`` image root, ``[2]``
            ground-truth root.
        resolution: ``(height, width)`` passed to the frame/flow readers.
        dataset: Dataset name; ``'FBMS'`` selects directory-listing based
            loading in evaluation mode, any other value assumes frames named
            ``00000.jpg``, ``00001.jpg``, ... with matching ``.png`` masks.
        gap: Stride, in frames, between consecutive frames of a training
            window.
        to_rgb: Forwarded to ``readFlow``.
        train: Selects training (True) or evaluation (False) mode.
        val_seq: Sequence names to iterate over in evaluation mode.
    """

    def __init__(self, data_dir, resolution, dataset, gap=2, to_rgb=False, train=True, val_seq=None):
        self.dataset = dataset
        # NOTE(review): this stores the *builtin* ``eval`` function, which is
        # almost certainly unintended. Nothing in this class reads it; it is
        # kept unchanged so external readers of ``.eval`` see the same value.
        self.eval = eval
        self.to_rgb = to_rgb
        self.data_dir = data_dir
        self.img_dir = data_dir[1]
        self.gap = gap
        self.resolution = resolution
        self.seq_length = 7
        self.train = train
        if train:
            # Every entry directly under the image root is treated as a sequence.
            self.seq = [os.path.basename(path) for path in glob.glob(os.path.join(self.img_dir, '*'))]
        else:
            self.seq = val_seq

    def __len__(self):
        if self.train:
            # Training samples are drawn at random, so this is a nominal size.
            return 10000
        return len(self.seq)

    def __getitem__(self, idx):
        """Return a random training sample, or evaluation sequence ``idx``.

        Training returns ``(flows, rgbs, flow_idxs)``; evaluation returns
        ``(rgbs, gts, seq_name, frame_indices)``.
        """
        if self.train:
            return self._train_sample()
        if self.dataset == 'FBMS':
            return self._fbms_sequence(self.seq[idx])
        return self._indexed_sequence(self.seq[idx])

    def _train_sample(self):
        """Draw a random frame window plus two flow fields per frame."""
        seq_name = random.choice(self.seq)
        n_frames = len(glob.glob(os.path.join(self.img_dir, seq_name, '*.jpg')))
        gap = self.gap
        n = self.seq_length
        # Keep the window inside the sequence. ``gap * n // 2`` is never
        # smaller than the half-extent ``gap * (n // 2)`` the window needs.
        margin = gap * n // 2
        centre = random.randint(margin, n_frames - margin - 1)
        seq_ids = [centre + gap * (i - n // 2) for i in range(n)]

        flow_idxs = []
        flow_dirs = []
        for i in range(n):
            # Pick two partner positions inside the window for frame i: one
            # before and one after it, or two distinct others at the ends.
            if i == 0:
                pair = np.random.choice(np.arange(1, n), 2, replace=False).tolist()
            elif i == n - 1:
                pair = np.random.choice(np.arange(n - 1), 2, replace=False).tolist()
            else:
                pair = [np.random.choice(i), np.random.choice(np.arange(i + 1, n))]
            flow_idxs.extend(pair)

            flo_name = str(seq_ids[i]).zfill(5) + '.flo'
            for partner in pair:
                # Signed frame offset to the partner selects the flow folder.
                offset = int(partner - i) * gap
                flow_dirs.append(os.path.join(self.data_dir[0], f'Flows_gap{offset}', seq_name, flo_name))

        rgb_dirs = [os.path.join(self.data_dir[1], seq_name, str(i).zfill(5) + '.jpg') for i in seq_ids]
        flow_idxs = np.array(flow_idxs)
        flows = [readFlow(flow_dir, self.resolution, self.to_rgb) for flow_dir in flow_dirs]
        rgbs = [readRGB(rgb_dir, self.resolution) for rgb_dir in rgb_dirs]

        out_flow = np.stack(flows, 0)  # 2 * seq_length flows, each C, H, W
        out_rgb = np.stack(rgbs, 0)  # seq_length frames, each C, H, W
        return out_flow, out_rgb, flow_idxs

    @staticmethod
    def _annotation_offsets(listing):
        """Return ``(png_names, offsets)`` for a sorted directory listing.

        ``offsets`` are the numeric stems of the ``.png`` files relative to
        the first ``.png`` file. Non-``.png`` entries are discarded *before*
        the baseline is chosen, so a stray file that sorts first (for example
        ``.DS_Store``) can no longer break or shift the offsets.
        """
        png_names = [name for name in listing if name.endswith(".png")]
        if not png_names:
            return png_names, []
        first = int(png_names[0][:-4])
        return png_names, [int(name[:-4]) - first for name in png_names]

    def _fbms_sequence(self, seq_name):
        """Load every ``.jpg`` frame and every ``.png`` mask of an FBMS sequence."""
        frame_root = os.path.join(self.data_dir[1], seq_name)
        rgb_dirs = [os.path.join(frame_root, name) for name in sorted(os.listdir(frame_root)) if name.endswith(".jpg")]
        rgbs = np.stack([readRGB(rgb_dir, self.resolution) for rgb_dir in rgb_dirs], axis=0)

        gt_root = os.path.join(self.data_dir[2], seq_name)
        gt_names, val_idx = self._annotation_offsets(sorted(os.listdir(gt_root)))
        gt_dirs = [os.path.join(gt_root, name) for name in gt_names]
        gts = np.stack([readSeg(gt_dir) for gt_dir in gt_dirs], axis=0)
        return rgbs, gts, seq_name, val_idx

    def _indexed_sequence(self, seq_name):
        """Load a sequence whose frames and masks are named ``00000``, ``00001``, ..."""
        # The frame count is the number of entries in the image folder.
        tot = len(glob.glob(os.path.join(self.data_dir[1], seq_name, '*')))
        rgb_dirs = [os.path.join(self.data_dir[1], seq_name, str(i).zfill(5) + '.jpg') for i in range(tot)]
        gt_dirs = [os.path.join(self.data_dir[2], seq_name, str(i).zfill(5) + '.png') for i in range(tot)]
        rgbs = np.stack([readRGB(rgb_dir, self.resolution) for rgb_dir in rgb_dirs], axis=0)
        gts = np.stack([readSeg(gt_dir) for gt_dir in gt_dirs], axis=0)
        return rgbs, gts, seq_name, list(range(tot))
                
