import geopandas as gpd
from shapely.geometry import box, MultiPolygon
from shapely import wkt
from shapely.strtree import STRtree
import pickle as pk
from tqdm import tqdm
import numpy as np
import time
import h5py
from shapely.geometry import Point, Polygon



country_data = gpd.read_file("/data/lloyds/country.csv", encoding="latin1")

# Collect the ids of every country whose "sovereignty" value is one of the
# listed states, echoing each matching sovereignty as it is encountered.
country_set = set()
for _, country_row in country_data.iterrows():
    sovereignty = country_row["sovereignty"]
    if sovereignty not in (
        "Netherlands",
        "United Kingdom of Great Britain and Northern Ireland",
        "Belgium",
        "Denmark",
        "Germany",
        "France",
        "Norway",
        "Sweden",
        "Finland",
    ):
        continue
    print(sovereignty)
    country_set.add(country_row["country_id"])

# Parallel lists: polygon[k] is the parsed geofence of the k-th port or
# sub-port, polygon_name[k] its name. Main ports come first, then sub-ports.
polygon = []
polygon_name = []
port_data = gpd.read_file("/data/lloyds/port.csv", encoding="latin1")
sub_data = gpd.read_file("/data/lloyds/sub_port.csv", encoding="latin1")

# port_set holds positions into `polygon` / `polygon_name` that belong to a
# selected country; sub_ports holds the portId of each selected main port.
port_set = set()
sub_ports = []
for _, port_row in port_data.iterrows():
    position = len(polygon)
    polygon.append(wkt.loads(port_row["geofence"]))
    polygon_name.append(port_row["name"])
    if port_row["countryId"] in country_set:
        port_set.add(position)
        sub_ports.append(port_row["portId"])

# A sub-port is selected when its parent portId is a selected main port.
# Use a set for O(1) membership instead of scanning the list per row.
selected_port_ids = set(sub_ports)
for _, sub_row in sub_data.iterrows():
    position = len(polygon)
    polygon.append(wkt.loads(sub_row["geofence"]))
    polygon_name.append(sub_row["name"])
    if sub_row["portId"] in selected_port_ids:
        port_set.add(position)

# Keep the running-counter value the original loops left behind
# (total number of loaded port + sub-port geofences).
idx = len(polygon)
    
# Map every country_id to its sovereignty; on duplicate ids the later row wins,
# matching sequential assignment.
country_dict = dict(zip(country_data["country_id"], country_data["sovereignty"]))

print("Number of Selected Ports", len(port_set))



LAT, LON, SOG, COG, HEAD, NAV, TIMESTAMP, TYPE, MMSI, OD = range(10)


def _iter_od_segments(traj, port_ids):
    """Yield the port-to-port segments of a single ship trajectory.

    A row is an anchor when its last column is not -1 (it carries a port
    label, i.e. a position into ``polygon``) and its SOG is 0. Each pair of
    consecutive anchors whose labels differ, and are both in ``port_ids``,
    delimits a candidate segment. The segment runs from the first anchor to
    the second, inclusive. It is yielded only if every OD label inside it is
    -1 or one of its two endpoint ports, i.e. no third port is touched.

    Args:
        traj: 2-D array indexed as ``traj[row, column]``.
            NOTE(review): assumes column -1 and column ``OD`` hold the same
            port label (true when traj has exactly 10 columns) -- confirm.
        port_ids: Collection of port labels counted as selected ports.

    Yields:
        The sub-array ``traj[start:end + 1]`` for each accepted segment.
    """
    anchors = np.flatnonzero((traj[:, -1] != -1) & (traj[:, SOG] == 0))
    for start, end in zip(anchors[:-1], anchors[1:]):
        origin, destination = traj[start, -1], traj[end, -1]
        if origin == destination:
            continue
        if origin not in port_ids or destination not in port_ids:
            continue
        segment = traj[start:end + 1]
        if set(np.unique(segment[:, OD])) <= {-1, origin, destination}:
            yield segment


num = 0
# NOTE(review): ct_set, ct_set_, pass_set, pass_set_ and get_single_user_traj
# are not defined in this file section; they must be provided elsewhere.
with h5py.File('EUR.h5', 'w') as f:
    for ship_id in tqdm(ct_set | ct_set_ | pass_set | pass_set_):
        traj = get_single_user_traj(ship_id)
        if traj is None:
            continue
        for od_segment in _iter_od_segments(traj, port_set):
            # Only the departure and arrival rows of each segment are stored.
            f.create_dataset(str(num), data=od_segment[[0, -1], :], compression='gzip')
            num += 1

print("Number of Port-to-port trajectory segments", num)
