import os
import pandas as pd
import numpy as np
import bisect
from nowcasting import image
from nowcasting.mask import *
from nowcasting.config import cfg
from nowcasting.utils import *

def encode_month(month):
    """Encode the month into a point on the unit circle

    Month ``m`` is mapped to the angle ``2 * pi * m / 12`` and encoded as
    ``(cos(angle), sin(angle))``. December (angle ``2 * pi``) and January
    (angle ``pi / 6``) therefore end up next to each other.

    Parameters
    ----------
    month : np.ndarray, array_like or int
        (...,) int, between 1 and 12. Lists and Python scalars are accepted
        and converted with ``np.asarray``.
    Returns
    -------
    ret : np.ndarray
        (..., 2) float32; ``ret[..., 0]`` is the cosine, ``ret[..., 1]`` the sine
    """
    # asarray lets plain ints / lists be passed; `.shape` is needed below.
    month = np.asarray(month)
    angle = 2 * np.pi * month / 12.0
    ret = np.empty(shape=month.shape + (2,), dtype=np.float32)
    ret[..., 0] = np.cos(angle)
    ret[..., 1] = np.sin(angle)
    return ret


def decode_month(code):
    """Decode the month code back to the month value

    Inverse of :func:`encode_month`. The angle of the ``(cos, sin)`` vector
    is recovered with ``arctan2``, which is defined for any input pair. Noisy
    or non-normalised codes (e.g. model predictions) therefore still decode to
    the nearest month instead of producing NaN, as ``arccos`` does outside
    ``[-1, 1]``.

    Parameters
    ----------
    code : np.ndarray or array_like
        (..., 2) float. ``code[..., 0]`` is the cosine component and
        ``code[..., 1]`` the sine component. Only the direction of each
        vector is used, not its length.
    Returns
    -------
    month : np.ndarray
        (...,) int, between 1 and 12
    """
    code = np.asarray(code)
    assert code.shape[-1] == 2
    # arctan2 returns the angle in (-pi, pi].
    angle = np.arctan2(code[..., 1], code[..., 0])
    month = np.round(angle / (2.0 * np.pi) * 12.0).astype(int)
    # Wrap into 1..12: angle 0 ("month 0") is the same point as month 12,
    # and negative angles correspond to months 7..11.
    month = (month - 1) % 12 + 1
    return month


def get_valid_datetime_set():
    """Load the pickled valid-datetime collection from ``cfg.HKO_VALID_DATETIME_PATH``.

    Returns
    -------
    valid_datetime_set
        The unpickled object stored at that path (presumably a set of
        datetimes, judging by the name -- TODO confirm).
    """
    # `pickle` is not imported explicitly at the top of this file (it could
    # only arrive through a star import), so import it here.
    import pickle
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(cfg.HKO_VALID_DATETIME_PATH, 'rb') as f:
        valid_datetime_set = pickle.load(f)
    return valid_datetime_set


def get_exclude_mask():
    """Read the ``exclude_mask`` array from ``mask_dat.npz``.

    The archive is looked up inside ``cfg.HKO_DATA_BASE_PATH`` and closed
    before returning.

    Returns
    -------
    exclude_mask : np.ndarray
        The ``exclude_mask`` entry of the archive.
    """
    npz_path = os.path.join(cfg.HKO_DATA_BASE_PATH, 'mask_dat.npz')
    with np.load(npz_path) as archive:
        mask = archive['exclude_mask'][:]
    return mask


def convert_datetime_to_filepath(date_time):
    """Build the path of the radar PNG for ``date_time``.

    Layout: ``<cfg.HKO_PNG_PATH>/YYYY/MM/DD/RADyymmddHHMM00.png``, where
    ``yy`` is ``year - 2000``.

    Parameters
    ----------
    date_time : datetime.datetime

    Returns
    -------
    ret : str
    """
    year, month, day = date_time.year, date_time.month, date_time.day
    file_name = 'RAD%02d%02d%02d%02d%02d00.png' % (year - 2000, month, day,
                                                   date_time.hour,
                                                   date_time.minute)
    return os.path.join(cfg.HKO_PNG_PATH,
                        "%04d" % year,
                        "%02d" % month,
                        "%02d" % day,
                        file_name)


def convert_datetime_to_maskpath(date_time):
    """Build the path of the mask file for ``date_time``.

    Layout: ``<cfg.HKO_MASK_PATH>/YYYY/MM/DD/RADyymmddHHMM00.mask``, where
    ``yy`` is ``year - 2000``.

    Parameters
    ----------
    date_time : datetime.datetime

    Returns
    -------
    ret : str
    """
    year, month, day = date_time.year, date_time.month, date_time.day
    file_name = 'RAD%02d%02d%02d%02d%02d00.mask' % (year - 2000, month, day,
                                                    date_time.hour,
                                                    date_time.minute)
    return os.path.join(cfg.HKO_MASK_PATH,
                        "%04d" % year,
                        "%02d" % month,
                        "%02d" % day,
                        file_name)


class HKOSimpleBuffer(object):
    """Buffer of consecutive frames whose datetimes are a slice of ``df.index``.

    At most ``max_buffer_length`` frames are held. :meth:`get` serves
    requests from the buffer and reloads a window starting at the first
    requested timestamp when the request is not covered.
    """
    def __init__(self, df, max_buffer_length, width, height):
        self._df = df
        self._max_buffer_length = max_buffer_length
        assert self._df.size > self._max_buffer_length
        self._width = width
        self._height = height
        # Nothing is loaded until reset() or the first get(); get() checks for
        # None so it can be called without a prior reset().
        self._datetime_keys = None
        self._frame_dat = None
        self._noise_mask_dat = None

    def reset(self):
        """Load the first ``max_buffer_length`` frames of ``df``."""
        self._datetime_keys = self._df.index[:self._max_buffer_length]
        self._load()

    def _load(self):
        """Read the frames for the current ``_datetime_keys`` window."""
        paths = [convert_datetime_to_filepath(key) for key in self._datetime_keys]
        # Near the end of df the window can be shorter than
        # max_buffer_length, so size everything by the actual window length.
        num_frames = len(self._datetime_keys)
        self._frame_dat = image.quick_read_frames(path_list=paths,
                                                  im_h=self._height,
                                                  im_w=self._width,
                                                  grayscale=True)
        self._frame_dat = self._frame_dat.reshape((num_frames, 1,
                                                   self._height, self._width))
        self._noise_mask_dat = np.zeros((num_frames, 1,
                                         self._height, self._width),
                                        dtype=np.uint8)

    def get(self, timestamps):
        """Return the buffered frames and noise masks covering ``timestamps``.

        timestamps must be sorted

        Parameters
        ----------
        timestamps : sequence of datetimes
            Sorted datetimes present in ``df.index``. Only the first and last
            entries are used; the result is the contiguous slice between them.

        Returns
        -------
        frame_dat : np.ndarray
            (num, 1, height, width) frames from timestamps[0] to timestamps[-1] inclusive
        noise_mask_dat : np.ndarray
            (num, 1, height, width) uint8, filled with zeros by ``_load``
        """
        if self._datetime_keys is None or\
                not (timestamps[0] in self._datetime_keys and
                     timestamps[-1] in self._datetime_keys):
            # Positional (integer) index of the first requested frame in df.
            read_begin_ind = self._df.index.get_loc(timestamps[0])
            read_end_ind = min(read_begin_ind + self._max_buffer_length, self._df.size)
            assert self._df.index[read_end_ind - 1] >= timestamps[-1]
            self._datetime_keys = self._df.index[read_begin_ind:read_end_ind]
            self._load()
        begin_ind = self._datetime_keys.get_loc(timestamps[0])
        end_ind = self._datetime_keys.get_loc(timestamps[-1]) + 1
        return self._frame_dat[begin_ind:end_ind, :, :, :],\
               self._noise_mask_dat[begin_ind:end_ind, :, :, :]


def pad_hko_dat(frame_dat, mask_dat, batch_size):
    """Zero-pad the batch axis (axis 1) of frames and masks up to ``batch_size``.

    Works for arrays of any rank >= 2; the HKO iterators produce
    (seq_len, batch, 1, height, width) arrays.

    Parameters
    ----------
    frame_dat : np.ndarray
        (seq_len, valid_batch_size, ...) frames
    mask_dat : np.ndarray
        (seq_len, valid_batch_size, ...) masks
    batch_size : int
        Target size of axis 1

    Returns
    -------
    ret_frame_dat : np.ndarray
        ``frame_dat`` zero-padded along axis 1, or ``frame_dat`` itself when
        no padding is needed
    ret_mask_dat : np.ndarray
        ``mask_dat`` zero-padded likewise, or ``mask_dat`` itself
    valid_batch_size : int
        Number of real samples when padding happened, otherwise ``batch_size``
    """
    valid_batch_size = frame_dat.shape[1]
    if valid_batch_size >= batch_size:
        return frame_dat, mask_dat, batch_size
    # Keep every trailing dimension as-is; only axis 1 grows.
    ret_frame_dat = np.zeros(shape=(frame_dat.shape[0], batch_size) + frame_dat.shape[2:],
                             dtype=frame_dat.dtype)
    ret_mask_dat = np.zeros(shape=(mask_dat.shape[0], batch_size) + mask_dat.shape[2:],
                            dtype=mask_dat.dtype)
    ret_frame_dat[:, :valid_batch_size, ...] = frame_dat
    ret_mask_dat[:, :valid_batch_size, ...] = mask_dat
    return ret_frame_dat, ret_mask_dat, valid_batch_size


# Loaded once when the module is imported (get_exclude_mask reads
# mask_dat.npz from cfg.HKO_DATA_BASE_PATH), so importing this module
# requires that file to exist. Used by precompute_mask.
_exclude_mask = get_exclude_mask()
def precompute_mask(img):
    """Compute a boolean validity mask for a radar image.

    A pixel is kept (True) unless it is flagged by the module-level
    ``_exclude_mask`` or it has a non-zero value below
    ``cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD``. For uint8 images the
    threshold is scaled by 255 and rounded, so the configured value is
    interpreted on a 0-1 pixel scale.

    Parameters
    ----------
    img : np.ndarray
        Image array; ``_exclude_mask`` must broadcast to its shape.

    Returns
    -------
    mask : np.ndarray
        bool array with the same shape as ``img``
    """
    rain_threshold = cfg.HKO.ITERATOR.FILTER_RAINFALL_THRESHOLD
    if img.dtype == np.uint8:
        rain_threshold = round(rain_threshold * 255.0)
    keep = np.zeros_like(img, dtype=bool)
    keep[:] = np.broadcast_to((1 - _exclude_mask).astype(bool), shape=img.shape)
    weak_echo = (img > 0) & (img < rain_threshold)
    keep[weak_echo] = False
    return keep


class HKOIterator(object):
    """The iterator for HKO-7 dataset

    Draws clips of ``seq_len`` datetimes spaced ``base_freq`` apart from the
    frames listed in a pickled pandas dataframe, and loads the matching
    frames and masks. Datetimes missing from the dataframe are left as zeros
    (mask ``False``) in the returned arrays.
    """
    def __init__(self, pd_path, sample_mode, seq_len=30,
                 max_consecutive_missing=2, begin_ind=None, end_ind=None,
                 stride=None, width=None, height=None, base_freq='6min'):
        """Create the iterator.

        Random sample: draw clips that start at a random frame of the
            dataframe, rejecting clips that violate the missing-frame criterion.
        Sequent sample: sweep forward through time from ``begin_time``.
            Each clip is {T_begin, T_begin + base_freq, ..., T_begin + (seq_len - 1) * base_freq}.
            After a clip is taken the start moves forward by ``stride`` steps:
            T_begin += base_freq * stride. When a clip violates the
            missing-frame criterion, the start jumps to the next existing
            datetime after the current one.

        Parameters
        ----------
        pd_path : str
            path of the saved pandas dataframe; its index holds the datetimes
            of the available frames
        sample_mode : str
            Can be "random" or "sequent"
        seq_len : int
            Number of frames per clip
        max_consecutive_missing : int
            The maximum number of consecutive missing frames allowed in a clip
        begin_ind : int or None
            Positional index (into the dataframe index) of the begin frame; 0 if None
        end_ind : int or None
            Positional index of the end frame; the last frame if None
        stride : int or None, optional
            Step between consecutive clips, in units of ``base_freq``;
            required for "sequent"
        width : int or None, optional
            Frame width; cfg.HKO.ITERATOR.WIDTH if None
        height : int or None, optional
            Frame height; cfg.HKO.ITERATOR.HEIGHT if None
        base_freq : str, optional
            Spacing between consecutive frames, as a pandas frequency string
        """
        if width is None:
            width = cfg.HKO.ITERATOR.WIDTH
        if height is None:
            height = cfg.HKO.ITERATOR.HEIGHT
        self._df = pd.read_pickle(pd_path)
        self.set_begin_end(begin_ind=begin_ind, end_ind=end_ind)
        # All available frame datetimes, for O(1) membership tests.
        self._df_index_set = frozenset(self._df.index)
        self._exclude_mask = get_exclude_mask()
        self._seq_len = seq_len
        self._width = width
        self._height = height
        self._stride = stride
        self._max_consecutive_missing = max_consecutive_missing
        self._base_freq = base_freq
        self._base_time_delta = pd.Timedelta(base_freq)
        assert sample_mode in ["random", "sequent"], "Sample mode=%s is not supported" %sample_mode
        self.sample_mode = sample_mode
        if sample_mode == "sequent":
            assert self._stride is not None
            self._current_datetime = self.begin_time
            # On a buffer miss, _load_frames reads _buffer_mult times the
            # span of the requested batch, so later batches can reuse it.
            self._buffer_mult = 6
            self._buffer_datetime_keys = None
            self._buffer_frame_dat = None
            self._buffer_mask_dat = None
        else:
            self._max_buffer_length = None

    def set_begin_end(self, begin_ind=None, end_ind=None):
        """Set the positional begin/end frame indices (defaults: first/last frame)."""
        self._begin_ind = 0 if begin_ind is None else begin_ind
        self._end_ind = self.total_frame_num - 1 if end_ind is None else end_ind

    @property
    def total_frame_num(self):
        """Number of entries in the dataframe (``df.size``)."""
        return self._df.size

    @property
    def begin_time(self):
        """Datetime of the begin frame."""
        return self._df.index[self._begin_ind]

    @property
    def end_time(self):
        """Datetime of the end frame."""
        return self._df.index[self._end_ind]

    @property
    def use_up(self):
        """True once a "sequent" sweep has moved past ``end_time``; always False for "random"."""
        if self.sample_mode == "random":
            return False
        else:
            return self._current_datetime > self.end_time

    def _next_exist_timestamp(self, timestamp):
        """Return the first datetime in the dataframe index strictly after ``timestamp``.

        Returns None when no such datetime exists. Assumes the index is
        sorted (required by ``bisect``).
        """
        next_ind = bisect.bisect_right(self._df.index, timestamp)
        if next_ind >= len(self._df.index):
            return None
        return self._df.index[next_ind]

    def _is_valid_clip(self, datetime_clip):
        """Check if the given datetime_clip is valid

        A clip is invalid if more than ``max_consecutive_missing`` consecutive
        datetimes are missing from the dataframe, or if every datetime is missing.

        Parameters
        ----------
        datetime_clip : sequence of datetimes

        Returns
        -------
        ret : bool
        """
        missing_count = 0
        for timestamp in datetime_clip:
            if timestamp not in self._df_index_set:
                missing_count += 1
                if missing_count > self._max_consecutive_missing or\
                        missing_count >= len(datetime_clip):
                    return False
            else:
                missing_count = 0
        return True

    def _load_frames(self, datetime_clips):
        """Load the frames and masks for a list of datetime clips.

        Parameters
        ----------
        datetime_clips : list
            ``batch_size`` clips, each a sequence of ``seq_len`` datetimes

        Returns
        -------
        frame_dat : np.ndarray
            (seq_len, batch_size, 1, height, width) uint8
        mask_dat : np.ndarray
            (seq_len, batch_size, 1, height, width) bool
        Entries whose datetime is missing from the dataframe stay zero / False.
        """
        assert isinstance(datetime_clips, list)
        for clip in datetime_clips:
            assert len(clip) == self._seq_len
        batch_size = len(datetime_clips)
        frame_dat = np.zeros((self._seq_len, batch_size, 1, self._height, self._width),
                             dtype=np.uint8)
        mask_dat = np.zeros((self._seq_len, batch_size, 1, self._height, self._width),
                            dtype=bool)
        if batch_size == 0:
            # Nothing to read, e.g. a "sequent" iterator that ran out of valid
            # clips. Both branches below index into the clips / hit list and
            # would fail on an empty batch.
            return frame_dat, mask_dat
        if self.sample_mode == "random":
            paths = []
            mask_paths = []
            hit_inds = []
            for i in range(self._seq_len):
                for j in range(batch_size):
                    timestamp = datetime_clips[j][i]
                    if timestamp in self._df_index_set:
                        paths.append(convert_datetime_to_filepath(timestamp))
                        mask_paths.append(convert_datetime_to_maskpath(timestamp))
                        hit_inds.append([i, j])
            hit_inds = np.array(hit_inds, dtype=int)
            all_frame_dat = image.quick_read_frames(path_list=paths,
                                                    im_h=self._height,
                                                    im_w=self._width,
                                                    grayscale=True)
            all_mask_dat = quick_read_masks(mask_paths)
            frame_dat[hit_inds[:, 0], hit_inds[:, 1], :, :, :] = all_frame_dat
            mask_dat[hit_inds[:, 0], hit_inds[:, 1], :, :, :] = all_mask_dat
        else:
            # Earliest and latest datetimes of the batch that actually exist
            # in the dataframe. The last clip's last datetime is the overall
            # maximum, and the first clip's first datetime is the overall
            # minimum, so they are safe starting points for min / max.
            first_timestamp = datetime_clips[-1][-1]
            last_timestamp = datetime_clips[0][0]
            for i in range(self._seq_len):
                for j in range(batch_size):
                    timestamp = datetime_clips[j][i]
                    if timestamp in self._df_index_set:
                        first_timestamp = min(first_timestamp, timestamp)
                        last_timestamp = max(last_timestamp, timestamp)
            if self._buffer_datetime_keys is None or\
                not (first_timestamp in self._buffer_datetime_keys
                     and last_timestamp in self._buffer_datetime_keys):
                read_begin_ind = self._df.index.get_loc(first_timestamp)
                read_end_ind = self._df.index.get_loc(last_timestamp) + 1
                read_end_ind = min(read_begin_ind +
                                   self._buffer_mult * (read_end_ind - read_begin_ind),
                                   self._df.size)
                self._buffer_datetime_keys = self._df.index[read_begin_ind:read_end_ind]
                # Fill in the buffer
                paths = [convert_datetime_to_filepath(key)
                         for key in self._buffer_datetime_keys]
                mask_paths = [convert_datetime_to_maskpath(key)
                              for key in self._buffer_datetime_keys]
                self._buffer_frame_dat = image.quick_read_frames(path_list=paths,
                                                                 im_h=self._height,
                                                                 im_w=self._width,
                                                                 grayscale=True)
                self._buffer_mask_dat = quick_read_masks(mask_paths)
            for i in range(self._seq_len):
                for j in range(batch_size):
                    timestamp = datetime_clips[j][i]
                    if timestamp in self._df_index_set:
                        assert timestamp in self._buffer_datetime_keys
                        ind = self._buffer_datetime_keys.get_loc(timestamp)
                        frame_dat[i, j, :, :, :] = self._buffer_frame_dat[ind, :, :, :]
                        mask_dat[i, j, :, :, :] = self._buffer_mask_dat[ind, :, :, :]
        return frame_dat, mask_dat

    def reset(self, begin_ind=None, end_ind=None):
        """("sequent" only) Restart the sweep from ``begin_ind``."""
        assert self.sample_mode == "sequent"
        self.set_begin_end(begin_ind=begin_ind, end_ind=end_ind)
        self._current_datetime = self.begin_time

    def random_reset(self):
        """("sequent" only) Restart the sweep from a random frame index.

        The index is drawn from [0, total_frame_num - 5 * seq_len).
        """
        assert self.sample_mode == "sequent"
        self.set_begin_end(begin_ind=np.random.randint(0,
                                                       self.total_frame_num -
                                                       5 * self._seq_len),
                           end_ind=None)
        self._current_datetime = self.begin_time

    def check_new_start(self):
        """("sequent" only) Whether the next clip starts a new sequence.

        True if the clip at the current position is invalid (the sweep will
        jump ahead), or if the current position is ``begin_time``.
        """
        assert self.sample_mode == "sequent"
        datetime_clip = pd.date_range(start=self._current_datetime,
                                      periods=self._seq_len,
                                      freq=self._base_freq)
        if self._is_valid_clip(datetime_clip):
            return self._current_datetime == self.begin_time
        else:
            return True

    def sample(self, batch_size, only_return_datetime=False):
        """Sample a minibatch from the hko7 dataset based on the given type and pd_file

        Parameters
        ----------
        batch_size : int
            Batch size
        only_return_datetime : bool
            Whether to only return the datetimes ("sequent" mode only)
        Returns
        -------
        frame_dat : np.ndarray
            Shape: (seq_len, valid_batch_size, 1, height, width)
        mask_dat : np.ndarray
            Shape: (seq_len, valid_batch_size, 1, height, width)
        datetime_clips : list
            length should be valid_batch_size; in "sequent" mode this may be
            smaller than batch_size (possibly 0) when the data runs out
        new_start : bool or None
            "sequent" with batch_size == 1: whether this clip starts a new
            sequence. None otherwise.
        If ``only_return_datetime`` is True, only ``(datetime_clips, new_start)``
        is returned.
        """
        if self.sample_mode == 'sequent':
            if self.use_up:
                raise ValueError("The HKOIterator has been used up!")
            datetime_clips = []
            new_start = False
            for _ in range(batch_size):
                while not self.use_up:
                    datetime_clip = pd.date_range(start=self._current_datetime,
                                                  periods=self._seq_len,
                                                  freq=self._base_freq)
                    if self._is_valid_clip(datetime_clip):
                        new_start = new_start or (self._current_datetime == self.begin_time)
                        datetime_clips.append(datetime_clip)
                        self._current_datetime += self._stride * self._base_time_delta
                        break
                    new_start = True
                    self._current_datetime =\
                        self._next_exist_timestamp(timestamp=self._current_datetime)
                    if self._current_datetime is None:
                        # No timestamp left: point current_datetime just past
                        # end_time so that use_up becomes True.
                        self._current_datetime = self.end_time + self._base_time_delta
                        break
            new_start = None if batch_size != 1 else new_start
            if only_return_datetime:
                return datetime_clips, new_start
        else:
            assert only_return_datetime is False
            datetime_clips = []
            new_start = None
            for _ in range(batch_size):
                # NOTE(review): loops forever if the dataframe holds no valid clip.
                while True:
                    rand_ind = np.random.randint(0, self._df.size, 1)[0]
                    random_datetime = self._df.index[rand_ind]
                    datetime_clip = pd.date_range(start=random_datetime,
                                                  periods=self._seq_len,
                                                  freq=self._base_freq)
                    if self._is_valid_clip(datetime_clip):
                        datetime_clips.append(datetime_clip)
                        break
        frame_dat, mask_dat = self._load_frames(datetime_clips=datetime_clips)
        return frame_dat, mask_dat, datetime_clips, new_start

# Simple test for the performance of the HKO iterator.
if __name__ == '__main__':
    np.random.seed(123)
    import time
    import cProfile, pstats
    from nowcasting.helpers.visualization import save_hko_movie

    def _time_sampling(hko_iter, batch_size, repeat):
        """Call ``hko_iter.sample`` ``repeat`` times; return elapsed wall-clock seconds."""
        begin = time.time()
        for _ in range(repeat):
            hko_iter.sample(batch_size=batch_size)
        return time.time() - begin

    def _save_clip_movies(prefix, batch, mask, datetimes):
        """Save the raw, masked and mask-only movies of the first clip of a batch.

        Files are named ``<prefix>_<YYYYmmddHHMM>{,_filtered,_mask}.mp4``.
        Returns the seconds spent writing the masked (``_filtered``) movie.
        """
        name_str = prefix + '_' + datetimes[0][0].strftime('%Y%m%d%H%M')
        save_hko_movie(batch[:, 0, 0, :, :],
                       datetimes[0],
                       mask[:, 0, 0, :, :],
                       masked=False,
                       save_path=name_str + '.mp4')
        tic = time.time()
        save_hko_movie(batch[:, 0, 0, :, :],
                       datetimes[0],
                       mask[:, 0, 0, :, :],
                       masked=True,
                       save_path=name_str + '_filtered.mp4')
        toc = time.time()
        save_hko_movie(mask[:, 0, 0, :, :].astype(np.uint8) * 255,
                       datetimes[0],
                       None,
                       masked=False,
                       save_path=name_str + '_mask.mp4')
        return toc - tic

    minibatch_size = 32
    seq_len = 30
    train_hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TRAIN,
                                 sample_mode="random",
                                 seq_len=seq_len)
    valid_hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_VALID,
                                 sample_mode="sequent",
                                 seq_len=seq_len,
                                 stride=5)
    test_hko_iter = HKOIterator(pd_path=cfg.HKO_PD.RAINY_TEST,
                                sample_mode="sequent",
                                seq_len=seq_len,
                                stride=5)

    # Sampling throughput, in frames per second.
    repeat_time = 3
    frames_per_run = minibatch_size * seq_len * repeat_time

    pr = cProfile.Profile()
    pr.enable()
    elapsed = _time_sampling(train_hko_iter, minibatch_size, repeat_time)
    pr.disable()
    ps = pstats.Stats(pr).sort_stats('cumulative')
    ps.print_stats(20)
    print("Train Data Sample FPS: %f" % (frames_per_run / float(elapsed)))

    elapsed = _time_sampling(valid_hko_iter, minibatch_size, repeat_time)
    print("Valid Data Sample FPS: %f" % (frames_per_run / float(elapsed)))
    elapsed = _time_sampling(test_hko_iter, minibatch_size, repeat_time)
    print("Test Data Sample FPS: %f" % (frames_per_run / float(elapsed)))

    # Round-trip check of the month encoding.
    code = encode_month(np.arange(1, 13))
    month = decode_month(code)
    print(code)
    print(month.T)

    for i in range(30):
        train_batch, train_mask, sample_datetimes, new_start = \
            train_hko_iter.sample(batch_size=minibatch_size)
        elapsed = _save_clip_movies('train_' + str(i), train_batch, train_mask,
                                    sample_datetimes)
        print('train, time:', elapsed)

    # Counts the validation batches that were written out.
    valid_batch_count = 0
    while not valid_hko_iter.use_up:
        valid_batch, valid_mask, sample_datetimes, new_start =\
            valid_hko_iter.sample(batch_size=minibatch_size)
        if valid_batch.shape[1] == 0:
            break
        elapsed = _save_clip_movies('valid_' + str(valid_batch_count), valid_batch,
                                    valid_mask, sample_datetimes)
        print('valid, time:', elapsed)
        print(valid_batch.shape[1])
        valid_batch_count += 1
    print(valid_batch_count)
