import numpy as np
import matplotlib.pyplot as plt
import ot
import math
from tqdm import tqdm
from datetime import date, timedelta
import netCDF4
from pathlib import Path
import pickle
from globalcontrol import *
from funcs import *
from RWp_OT import *
import torch
# Tuning constants. effective_threshold is handed to mat2dist below;
# epsilon and lamda are kept as module-level settings.
epsilon, lamda, effective_threshold = 0.01, 0.1, 0.01

# --- Load the pickled MRMS archive ---------------------------------------
# The unpickled object exposes .mat, .idx and .time, unpacked below.
# NOTE(review): pickle.load executes arbitrary code; only point this at
# trusted, locally produced data files.
file_name = f"{cur_location}/../../Data/MRMS16-22.dat"
with open(file_name, 'rb') as f:
    MRMS = pickle.load(f)
mat_list = MRMS.mat
idx_list = MRMS.idx
time_list = MRMS.time
print("overall data", len(idx_list))

# --- Reduce resolution and convert frames to distributions ---------------
downsampled_mat_list = downsampling(mat_list, factor=5)
downsampled_MRMS_dist = [
    mat2dist(mat, threshold=effective_threshold, mode="normal")
    for mat in tqdm(downsampled_mat_list)
]

def compare_and_sort(source_idx, seq_length, dataset_dist, p, maxiter = 100, eps = 0.01,
                     eta = 0.01, invalid_dis = 100000):
    """Rank candidate sequences in ``dataset_dist`` by RWp distance to a source sequence.

    The query is built with ``generate_seq_dist(dataset_dist, source_idx, seq_length)``.
    Each candidate start index ``target_idx`` in
    ``range(len(dataset_dist) - seq_length)`` is first validated with ``seq_check``.
    Valid candidates are scored with ``compute_RWp``. Invalid ones receive the sentinel
    distance ``invalid_dis``, which is presumably larger than any real RWp value, so
    they rank last.

    Args:
        source_idx: start index of the query sequence, as passed to ``generate_seq_dist``.
        seq_length: number of entries per sequence; must be >= 1.
        dataset_dist: per-frame distributions (in this script, the ``mat2dist`` output).
        p: forwarded to ``compute_RWp`` as its ``p`` argument.
        maxiter: forwarded to ``compute_RWp`` as ``maxiter`` (default 100).
        eps: forwarded to ``compute_RWp`` as ``eps2`` (default 0.01).
        eta: forwarded to ``compute_RWp`` as ``eta2`` (default 0.01).
        invalid_dis: distance recorded for candidates rejected by ``seq_check``.

    Returns:
        A list of ``[distance, target_idx]`` pairs in ascending distance order. Ties
        keep ascending ``target_idx`` order, because the sort is stable and candidates
        are visited in index order.

    Raises:
        ValueError: if ``seq_length`` < 1. Without this guard, the slice
            ``dataset_dist[:-0]`` is empty and the ranking would silently come back empty.
    """
    if seq_length < 1:
        raise ValueError(f"seq_length must be >= 1, got {seq_length}")

    source_seq_dist = generate_seq_dist(dataset_dist, source_idx, seq_length)

    # Equivalent to len(dataset_dist[:-seq_length]) without materialising a copy.
    # NOTE(review): if a sequence starting at target_idx covers
    # target_idx .. target_idx + seq_length - 1, the last valid start
    # (len - seq_length) is excluded here -- confirm against generate_seq_dist.
    n_candidates = max(len(dataset_dist) - seq_length, 0)

    dis_and_idx = []
    for target_idx in tqdm(range(n_candidates)):
        if seq_check(dataset_dist, target_idx, seq_length):
            target_seq_dist = generate_seq_dist(dataset_dist, target_idx, seq_length)
            # Forward the caller's solver settings rather than fixed values.
            dis = compute_RWp(source_seq_dist, target_seq_dist, p,
                              eta2 = eta, eps2 = eps, maxiter = maxiter)
        else:
            dis = invalid_dis
        dis_and_idx.append([dis, target_idx])

    return sorted(dis_and_idx, key=lambda pair: pair[0])

# --- Query configuration ---------------------------------------------------
seq_length = 1      # number of consecutive entries in each compared sequence
source_idx = 86210  # start index of the query sequence (passed to generate_seq_dist)
p = 1               # forwarded to compute_RWp as its p argument

# Query frames at downsampled resolution; only used by the optional preview below.
source_seq_imgs = seq_images(downsampled_mat_list, source_idx, seq_length)
# display(source_seq_imgs, nrow = seq_length)

sorted_dis_and_idx = compare_and_sort(source_idx, seq_length, downsampled_MRMS_dist, p)
# Use a distinct name so the dataset-wide ``idx_list`` (MRMS.idx) is not overwritten
# by the retrieval result, which has a different meaning.
retrieved_idx_list = retrieving(sorted_dis_and_idx, top_k = 10, time_buffer = 100)
images, time_names = collecting_for_display(retrieved_idx_list, seq_length, mat_list, time_list)

# --- Persist results ---------------------------------------------------------
# NOTE(review): ':' is a reserved character in Windows file names, so this output
# cannot be created there. The name is left unchanged because downstream readers
# may depend on it.
file_name = f"thm_seq:{seq_length}_rwp:{p}_source_idx:{source_idx}"
data = [sorted_dis_and_idx, images, time_names]
with open(file_name + ".pkl", 'wb') as file:  # binary mode, as pickle requires
    pickle.dump(data, file)