"""Collect thresholded MRMS reflectivity frames for DFW and pickle them.

For every 10-minute slot from March through September of each year in
``yrs``, read the reflectivity grid from the matching netCDF file (when it
exists), zero out cells below ``refl_threshold``, and store the grid together
with a numeric slot index and an ISO-style timestamp. The collected frames are
wrapped in ``dataset`` (imported from ``funcs``) and written with ``pickle``.
"""
import numpy as np
import matplotlib.pyplot as plt
import ot
import math
from funcs import *
import os
from tqdm import tqdm
cur_location = os.getcwd()
from datetime import date, timedelta
import netCDF4
from pathlib import Path
import pickle
np.set_printoptions(precision=3)
pre_path = "/../../MRMSdata/nsslMosaic2D/cref_10min_DFW/"
format_suffix = ".mdv.cf.nc"
epsilon = 0.01
lamda = 0.1

# Grid cells with values below this are set to 0 before a frame is stored.
refl_threshold = 35
# Last day (inclusive) of 2015 whose files are read with the legacy variable
# name "mosaicked_refl"; every day of 2014 uses that name as well.
legacy_header_last_day = date(2015, 3, 11)


def _reflectivity_header(day):
    """Return the netCDF variable name to read the reflectivity from for *day*.

    The earlier check was ``month <= 3 and day <= 11``, which only meant
    "on or before 11 March" because the scan starts in March. Comparing whole
    dates keeps that cutoff correct if the month window is ever changed.
    """
    if day.year == 2014 or (day.year == 2015 and day <= legacy_header_last_day):
        return "mosaicked_refl"
    return "MREF"


def _load_thresholded_frame(path, header, threshold=refl_threshold):
    """Read ``[0, 0, ::-1, :]`` of variable *header* from *path* and threshold it.

    Cells below *threshold* become 0; other cells keep their value. Masked
    (missing) cells are filled with 0 first, so whatever fill value the file
    stores underneath the mask can never pass the threshold. The file is
    closed before returning, and the returned array is a fresh copy.
    """
    with netCDF4.Dataset(path) as nc:
        raw = nc.variables[header][0, 0, ::-1, :]
        # np.ma.filled also accepts a plain ndarray (returned unchanged), so
        # this works whether or not netCDF4 hands back a masked array.
        cur_mat = np.ma.filled(raw, 0)
        return np.where(cur_mat >= threshold, cur_mat, 0)


counter = 0
mat_list = []
idx_list = []
time_list = []
yrs = range(2016, 2022 + 1)
for yyyy in yrs:
    # Jump the index by one million per year so indices from different years
    # never collide (a March-September season has ~31k 10-minute slots).
    counter += 1000000
    mth_start, mth_end = 3, 10

    start_date = date(yyyy, mth_start, 1)
    end_date = date(yyyy, mth_end, 1)  # exclusive: the last day scanned is the day before
    days_between_dates = (end_date - start_date).days
    for dd in tqdm(range(days_between_dates)):
        current_date = start_date + timedelta(days=dd)
        header = _reflectivity_header(current_date)

        file_date = f"{yyyy:04}{current_date.month:02}{current_date.day:02}"
        pre_name = cur_location + pre_path + file_date + "/"
        for hh in range(24):
            for mm in range(0, 60, 10):  # minutes 00, 10, ..., 50
                # Every slot consumes an index even when its file is missing,
                # so within a year an index always maps to the same slot.
                counter += 1
                isotime = f"{current_date}T{hh:02}:{mm:02}"
                hr_n_min = f"{hh:02}{mm:02}00"
                file_name = pre_name + file_date + "_" + hr_n_min + format_suffix
                if os.path.exists(file_name):
                    mat_list.append(_load_thresholded_frame(file_name, header))
                    idx_list.append(counter)
                    time_list.append(isotime)

MRMS = dataset(mat_list, idx_list, time_list)
# Label the output with the first and last scanned year (e.g. "MRMS16-22.dat").
file_name = f"{cur_location}/../../Data/MRMS{yrs[0] % 100:02}-{yrs[-1] % 100:02}.dat"
with open(file_name, 'wb') as file:  # 'wb' denotes write binary mode
    pickle.dump(MRMS, file)