import os
import math
import glob
import datetime
import torch
import random
import argparse
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mtdt
from statistics import mean
from pyproj import Transformer
from sklearn.preprocessing import scale


def haversine(lat1, lon1, lat2, lon2, to_radians=True, earth_radius=6371):
    """
    Calculate the great circle distance between two points on a sphere.

    Parameters
    ----------
    lat1, lon1 : scalar or array-like
        Latitude and longitude of the first point(s).
    lat2, lon2 : scalar or array-like
        Latitude and longitude of the second point(s).
    to_radians : bool, optional
        If True (default) the coordinates are given in decimal degrees and
        are converted to radians first. If False they are used as given,
        i.e. they must already be in radians.
    earth_radius : float, optional
        Radius of the sphere. The result is expressed in the same unit
        (the default 6371 yields kilometres).

    Returns
    -------
    Distance(s) along the great circle. When ``to_radians`` is True the
    inputs may be any mix of scalars and arrays that broadcast together,
    and the result is a NumPy scalar or array.
    """
    if to_radians:
        # Convert every argument on its own instead of stacking them into a
        # single array, so scalars and arrays of different shapes can be
        # mixed and are broadcast by the arithmetic below.
        lat1, lon1, lat2, lon2 = (
            np.radians(np.asarray(value, dtype=np.float64))
            for value in (lat1, lon1, lat2, lon2)
        )

    a = np.sin((lat2 - lat1) / 2.0) ** 2 + \
        np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2.0) ** 2

    # Floating-point rounding can push `a` marginally outside [0, 1]
    # (e.g. for near-antipodal points), which would turn arcsin into NaN.
    # NaN inputs are left untouched by the clipping.
    a = np.clip(a, 0.0, 1.0)

    return earth_radius * 2 * np.arcsin(np.sqrt(a))



class lonlat_to_3d(object):
    """
    Convert between WGS84 longitude/latitude and geocentric X, Y, Z.

    Two pyproj transformers are built once at construction time: one from
    the 'latlong' definition to the 'geocent' definition and one for the
    reverse direction.
    """

    def __init__(self):
        # Both coordinate systems share the WGS84 ellipsoid and datum.
        geodetic_crs = {"proj": 'latlong', "ellps": 'WGS84', "datum": 'WGS84'}
        geocentric_crs = {"proj": 'geocent', "ellps": 'WGS84', "datum": 'WGS84'}

        self.transformer_lonlat = Transformer.from_crs(geodetic_crs, geocentric_crs)
        self.transformer_xyz = Transformer.from_crs(geocentric_crs, geodetic_crs)

    def transform(self, lon, lat):
        """Return the geocentric (x, y, z) of a lon/lat given in degrees, at height 0."""
        x_geo, y_geo, z_geo = self.transformer_lonlat.transform(
            lon, lat, 0, radians=False
        )
        return x_geo, y_geo, z_geo

    def inv_transform(self, x, y, z):
        """Return (lon, lat) in degrees for a geocentric point; the height is discarded."""
        lon_deg, lat_deg, _height = self.transformer_xyz.transform(
            x, y, z, radians=False
        )
        return lon_deg, lat_deg


def taxi(dataset_path=None):
    """
    Preprocess raw taxi GPS logs into ``taxi.csv`` and ``taxiplot.png``.

    Every ``*.txt`` file found in ``dataset_path`` is read as a header-less
    CSV whose columns 0-3 are used as taxi id, timestamp, longitude and
    latitude. For each file the function:

    * keeps fixes between 2008-02-01 00:00 and 2008-02-09 23:59,
    * averages fixes that share the same timestamp,
    * discards the whole file if any fix lies outside the longitude range
      [116.1, 116.8] or the latitude range [39.5, 40.3], or if any gap
      between consecutive fixes exceeds 86400 seconds,
    * repeatedly drops fixes whose speed relative to the previous fix is
      not below 0.042 km/s until no more fixes are removed,
    * discards the file if fewer than 3 fixes remain.

    The surviving sequences are shuffled with a fixed seed, concatenated,
    their time gaps divided by the mean gap and their coordinates
    standardised (zero mean, unit standard deviation). Summary statistics
    are printed and the result is written to ``taxi.csv`` and
    ``taxiplot.png`` in the current working directory.

    Parameters
    ----------
    dataset_path : str
        Directory containing the per-taxi ``*.txt`` files.

    Raises
    ------
    ValueError
        If no sequence survives the filtering (including the case where
        the directory contains no ``*.txt`` file).
    """
    min_len = 3
    all_df = []
    afn = glob.glob(os.path.join(dataset_path, "*.txt"))
    t_diff_min = 1e7
    start_time = datetime.datetime(2008, 2, 1, 0, 0)
    end_time = datetime.datetime(2008, 2, 9, 23, 59)
    for fn in afn:
        try:
            df = pd.read_csv(fn, sep=",", engine='python', header=None,
                             dtype={2: "float64", 3: "float64"})
        except Exception:
            # Best effort: files that cannot be parsed (typically empty
            # ones) are skipped. `Exception` rather than a bare `except`
            # so that Ctrl-C / SystemExit still stop the run.
            continue
        df.drop_duplicates(inplace=True)
        df.dropna(inplace=True)
        df[1] = pd.to_datetime(df[1])
        t_mask = (df[1] >= start_time) & (df[1] <= end_time)
        df = df.loc[t_mask]
        df = df.sort_values(by=1)
        # Column 1 becomes the elapsed time in seconds since the first fix.
        df["t_diff"] = df[1].diff(1).dt.total_seconds().fillna(0)
        df[1] = df["t_diff"].cumsum()

        df = df.astype(np.float64)
        # Average all fixes recorded at the same instant.
        df = df.groupby(1).mean().reset_index()
        # ========== remove outliers by geolocation ==========
        if not df[2].between(116.1, 116.8).all():
            continue
        if not df[3].between(39.5, 40.3).all():
            continue
        # ========== remove outliers by moving speed ==========
        skip = 0
        for _ in range(df.shape[0]):
            init_L = df.shape[0]
            if init_L < min_len:
                break
            # A gap longer than one day invalidates the whole file.
            if not df["t_diff"].between(0, 86400.0).all():
                skip = 1
                break
            # Speed in km/s with respect to the previous fix.
            df["speed"] = haversine(df[3], df[2], df[3].shift(), df[2].shift()) / df["t_diff"]
            # The first fix has no predecessor; reuse the second fix's speed.
            df.iloc[0, df.columns.get_loc("speed")] = df["speed"].iloc[1]
            df = df.loc[(df["speed"] < 0.042)]  # 0.1 km/s = 360 km/h
            if init_L == df.shape[0]:
                # Nothing was removed in this pass: the sequence is clean.
                break
            # Rows were dropped, so the gaps must be recomputed.
            df["t_diff"] = df[1].diff().fillna(0)
        # ======================================================
        if skip or (df.shape[0] < min_len):
            continue
        # From here on column 1 holds the gap to the previous fix.
        df[1] = df["t_diff"]
        t_diff_min = min(t_diff_min, df[1][1:].min())
        all_df.append(df[[0, 2, 3, 1]])

    if not all_df:
        # Fail with an explicit message instead of the opaque
        # "min() arg is an empty sequence" raised further down.
        raise ValueError("No valid taxi trajectories found in %r" % (dataset_path,))

    random.seed(6)
    random.shuffle(all_df)
    lens_min = min(map(len, all_df))
    lens_max = max(map(len, all_df))
    print("n_sequences:", len(all_df))
    print("min length:", lens_min)
    print("max length:", lens_max)
    print("average length:", mean(map(len, all_df)))

    all_df = pd.concat(all_df, ignore_index=True).astype(np.float64)
    all_df.rename({0: "taxi_id", 2: "Longitude", 3: "Latitude", 1: "t_diff"}, axis=1, inplace=True)
    all_df["taxi_id"] = all_df["taxi_id"].astype(int)
    print("observations:", all_df.shape[0])
    print("mean time diff before: %f" %(all_df["t_diff"].mean()))
    print("min time diff before: %f" %(t_diff_min))
    print("max time diff before: %f" %(all_df["t_diff"].max()))
    print("std time diff before: %f" %(all_df["t_diff"].std()))
    # Express the gaps in units of the mean gap.
    all_df["t_diff"] = all_df["t_diff"] / all_df["t_diff"].mean()
    print("mean time diff: %f" %(all_df["t_diff"].mean()))
    print("min time diff: %f" %(all_df["t_diff"].min()))
    print("max time diff: %f" %(all_df["t_diff"].max()))
    print("std time diff: %f" %(all_df["t_diff"].std()))

    lonlat_max = all_df[["Longitude", "Latitude"]].max().values.tolist()
    lonlat_min = all_df[["Longitude", "Latitude"]].min().values.tolist()
    lonlat_mean = all_df[["Longitude", "Latitude"]].mean()
    lonlat_std = all_df[["Longitude", "Latitude"]].std()
    print("lon_lat max: {0}, {1}".format(lonlat_max[0], lonlat_max[1]))
    print("lon_lat min: {0}, {1}".format(lonlat_min[0], lonlat_min[1]))
    print("lon_lat mean: {0}, {1}".format(lonlat_mean.values.tolist()[0], lonlat_mean.values.tolist()[1]))
    print("lon_lat std: {0}, {1}".format(lonlat_std.values.tolist()[0], lonlat_std.values.tolist()[1]))

    # Standardise the coordinates; the plot below shows the standardised values.
    all_df[["Longitude", "Latitude"]] = (all_df[["Longitude", "Latitude"]] - lonlat_mean) / lonlat_std

    all_df.plot(kind='scatter', x="Longitude", y="Latitude", s=0.1, marker=",", linewidths=0)
    plt.savefig("taxiplot.png", dpi=5000)
    all_df[["taxi_id", "Longitude", "Latitude", "t_diff"]].to_csv("taxi.csv", index=False)
 


def SCAR(dataset_path=None):
    """
    Preprocess the SCAR tracking CSV files into ``scar.csv`` and ``scar_plot.png``.

    Every ``*.csv`` file in ``dataset_path`` except ``RAATD_metadata.csv`` is
    read as a tracking file. For each file the function keeps rows flagged
    with ``location_to_keep == 1`` and dated between 1990-12-31 00:00 and
    2017-01-01 23:59, then processes each ``individual_id`` separately:

    * a few hand-picked individuals are clipped by latitude/longitude,
    * fixes sharing a timestamp are merged (mean position),
    * fixes whose speed relative to the previous fix is not below
      0.11 km/s are dropped repeatedly until nothing more is removed,
    * individuals left with fewer than 3 fixes are discarded,
    * geocentric X, Y, Z coordinates are added for every fix.

    The surviving sequences are shuffled with a fixed seed, concatenated and
    joined with the ``sex`` and ``age_class`` columns of the metadata file.
    Time gaps are divided by their mean, X/Y/Z are standardised, categorical
    columns are integer-encoded with ``pd.factorize`` and summary statistics
    are printed. The result is written to ``scar.csv`` and ``scar_plot.png``
    in the current working directory.

    Parameters
    ----------
    dataset_path : str
        Directory containing the tracking CSV files and ``RAATD_metadata.csv``.

    Raises
    ------
    ValueError
        If no sequence survives the filtering.
    """
    min_len = 3
    all_df = []
    afn = glob.glob(os.path.join(dataset_path, "*.csv"))
    metadata = pd.read_csv(os.path.join(dataset_path, "RAATD_metadata.csv"), sep=",", engine='python',
                           usecols=["individual_id", "sex", "age_class"], encoding="unicode_escape")
    t_diff_min = 1e7
    latlon_3d = lonlat_to_3d()
    start_time = datetime.datetime(1990, 12, 31, 0, 0)
    end_time = datetime.datetime(2017, 1, 1, 23, 59)

    # Hand-picked per-individual bounds used to exclude outlying fixes.
    # Each rule maps an individual's frame to the boolean mask of rows to keep.
    outlier_rules = {
        "MAW1995_emp_f_x_516": lambda d: d["Latitude"] < 0,
        "43841_ws_BASct43.csv": lambda d: d["Latitude"] < -31,
        "57332": lambda d: d["Latitude"] < -35,
        "ct23x-exF-08": lambda d: d["Latitude"] < -19,
        "unknown119": lambda d: d["Latitude"] < -47,
        "123232_AP2013": lambda d: d["Longitude"] > -120,
    }

    for fn in afn:
        if "RAATD_metadata.csv" in fn:
            continue
        df = pd.read_csv(fn, sep=",", engine='python')
        # Non in-place so that we never mutate a slice of the frame just read.
        df = df.loc[df["location_to_keep"] == 1].drop_duplicates()
        # Assemble a timestamp from the date columns and the "h:m:s" time string.
        df[["h", "m", "s"]] = df["time"].str.split(":", expand=True)
        df["DateTime"] = pd.to_datetime(df[["year", "month", "day", "h", "m", "s"]])
        t_mask = (df["DateTime"] >= start_time) & (df["DateTime"] <= end_time)
        df = df.loc[t_mask]
        n_rows = df.shape[0]
        if n_rows < min_len:
            if n_rows < 1:
                print("%s is empty!" %(fn))
            elif n_rows == 1:
                print("%s contains single observation!" %(fn))
            else:
                print("%s contains 2 observation!" %(fn))
            continue

        df = df[["DateTime", "individual_id", "abbreviated_name", "breeding_stage", "decimal_longitude", "decimal_latitude"]]
        df = df.rename({"decimal_longitude": "Longitude", "decimal_latitude": "Latitude"}, axis=1)
        for id_ in df["individual_id"].unique():
            df_id = df.loc[df["individual_id"] == id_]
            # ============== excluding some outliers ==============
            keep_rows = outlier_rules.get(id_)
            if keep_rows is not None:
                df_id = df_id[keep_rows(df_id)]
            # =====================================================

            df_id = df_id.sort_values(by="DateTime")
            # Merge fixes recorded at the same instant (mean position).
            df_id = df_id.groupby("DateTime", as_index=False).agg({"individual_id": "first", "abbreviated_name": "first", "breeding_stage": "first", "Longitude": "mean", "Latitude": "mean"})
            df_id["t_diff"] = df_id["DateTime"].diff(1).dt.total_seconds().fillna(0)
            # Elapsed seconds since the individual's first fix.
            df_id["time"] = df_id["t_diff"].cumsum()
            # ========== remove outliers by moving speed ==========
            for _ in range(df_id.shape[0]):
                init_L = df_id.shape[0]
                if init_L < min_len:
                    break
                # Speed in km/s with respect to the previous fix.
                df_id["speed"] = haversine(df_id["Latitude"], df_id["Longitude"], df_id["Latitude"].shift(), df_id["Longitude"].shift()) / df_id["t_diff"]
                # The first fix has no predecessor; reuse the second fix's speed.
                df_id.iloc[0, df_id.columns.get_loc("speed")] = df_id["speed"].iloc[1]
                df_id = df_id.loc[(df_id["speed"] < 0.11)]  # 0.1 km/s = 360 km/h
                if init_L == df_id.shape[0]:
                    # Nothing was removed in this pass: the sequence is clean.
                    break
                # Rows were dropped, so the gaps must be recomputed.
                df_id["t_diff"] = df_id["time"].diff().fillna(0)
            # ======================================================
            if df_id.shape[0] < min_len:
                print("%s less than 3 observations after cleaning!" %(id_))
                continue
            # Geocentric coordinates of every fix.
            df_id["X"], df_id["Y"], df_id["Z"] = zip(*df_id.apply(lambda x: latlon_3d.transform(x["Longitude"], x["Latitude"]), axis=1))
            all_df.append(df_id[["individual_id", "abbreviated_name", "breeding_stage", "Longitude", "Latitude", "X", "Y", "Z", "t_diff"]])
            t_diff_min = min(t_diff_min, df_id["t_diff"][1:].min())

    if not all_df:
        # Fail with an explicit message instead of the opaque
        # "min() arg is an empty sequence" raised further down.
        raise ValueError("No valid SCAR tracks found in %r" % (dataset_path,))

    random.seed(6)
    random.shuffle(all_df)
    lens_min = min(map(len, all_df))
    lens_max = max(map(len, all_df))
    print("n_sequences:", len(all_df))
    print("min length:", lens_min)
    print("max length:", lens_max)
    print("average length:", mean(map(len, all_df)))

    all_df = pd.concat(all_df, ignore_index=True)
    # Default (inner) merge: attaches sex / age_class from the metadata file.
    all_df = all_df.merge(metadata, on="individual_id")
    all_df = all_df.astype(dtype={"individual_id": "object", "abbreviated_name": "object", "breeding_stage": "object", "sex": "object", "age_class": "object",
                                  "Longitude": "float64", "Latitude": "float64", "X": "float64", "Y": "float64", "Z": "float64", "t_diff": "float64"})

    # Normalise the spelling of the breeding stage labels.
    all_df["breeding_stage"] = all_df["breeding_stage"].str.lower()
    all_df["breeding_stage"] = all_df["breeding_stage"].str.replace("-", " ", regex=False)
    all_df["breeding_stage"] = all_df["breeding_stage"].str.replace("_", " ", regex=False)

    print("observations:", all_df.shape[0])
    print("mean time diff before: %f" %(all_df["t_diff"].mean()))
    print("min time diff before: %f" %(t_diff_min))
    print("max time diff before: %f" %(all_df["t_diff"].max()))
    print("std time diff before: %f" %(all_df["t_diff"].std()))
    # Express the gaps in units of the mean gap.
    all_df["t_diff"] = all_df["t_diff"] / all_df["t_diff"].mean()
    print("mean time diff: %f" %(all_df["t_diff"].mean()))
    print("min time diff: %f" %(all_df["t_diff"].min()))
    print("max time diff: %f" %(all_df["t_diff"].max()))
    print("std time diff: %f" %(all_df["t_diff"].std()))
    lonlat_max = all_df[["Longitude", "Latitude"]].max().values.tolist()
    lonlat_min = all_df[["Longitude", "Latitude"]].min().values.tolist()
    lonlat_mean = all_df[["Longitude", "Latitude"]].mean().values.tolist()
    lonlat_std = all_df[["Longitude", "Latitude"]].std().values.tolist()
    print("lon_lat max: {0}, {1}".format(lonlat_max[0], lonlat_max[1]))
    print("lon_lat min: {0}, {1}".format(lonlat_min[0], lonlat_min[1]))
    print("lon_lat mean: {0}, {1}".format(lonlat_mean[0], lonlat_mean[1]))
    print("lon_lat std: {0}, {1}".format(lonlat_std[0], lonlat_std[1]))

    xyz_max = all_df[["X", "Y", "Z"]].max().values.tolist()
    xyz_min = all_df[["X", "Y", "Z"]].min().values.tolist()
    xyz_mean = all_df[["X", "Y", "Z"]].mean()
    xyz_std = all_df[["X", "Y", "Z"]].std()

    # Only X, Y, Z are standardised; Longitude / Latitude are stored as-is.
    all_df[["X", "Y", "Z"]] = (all_df[["X", "Y", "Z"]] - xyz_mean) / xyz_std

    print("X Y Z max: {0}, {1}, {2}".format(xyz_max[0], xyz_max[1], xyz_max[2]))
    print("X Y Z min: {0}, {1}, {2}".format(xyz_min[0], xyz_min[1], xyz_min[2]))
    print("X Y Z mean: {0}, {1}, {2}".format(xyz_mean.values.tolist()[0], xyz_mean.values.tolist()[1], xyz_mean.values.tolist()[2]))
    print("X Y Z std: {0}, {1}, {2}".format(xyz_std.values.tolist()[0], xyz_std.values.tolist()[1], xyz_std.values.tolist()[2]))
    print("abbreviated_name")
    print(all_df.abbreviated_name.value_counts(dropna=False))
    print("breeding_stage")
    print(all_df.breeding_stage.value_counts(dropna=False))
    print("age_class")
    print(all_df.age_class.value_counts(dropna=False))
    print("sex")
    print(all_df.sex.value_counts(dropna=False))

    # Treat the explicit "unknown" markers as missing values before encoding.
    all_df["breeding_stage"] = all_df["breeding_stage"].replace("unknown", np.nan)
    all_df["sex"] = all_df["sex"].replace("U", np.nan)

    # Integer-encode the categorical columns and print each code table.
    all_df["abbreviated_name"], abbreviated_name_key = pd.factorize(all_df["abbreviated_name"])
    all_df["breeding_stage"], breeding_stage_key = pd.factorize(all_df["breeding_stage"])
    all_df["age_class"], age_class_key = pd.factorize(all_df["age_class"])
    all_df["sex"], sex_key = pd.factorize(all_df["sex"])
    print(list(zip(abbreviated_name_key, range(len(abbreviated_name_key)))))
    print(list(zip(breeding_stage_key, range(len(breeding_stage_key)))))
    print(list(zip(age_class_key, range(len(age_class_key)))))
    print(list(zip(sex_key, range(len(sex_key)))))
    print("\n")

    all_df.plot(kind='scatter', x="Longitude", y="Latitude", s=0.1, marker=",", linewidths=0)
    plt.savefig("scar_plot.png", dpi=1000)
    all_df.to_csv("scar.csv", index=False)
    

def lrff(dataset_path=None):
    """
    Preprocess the LRFF tracking CSV into ``lrff.csv`` and ``lrffplot.png``.

    The file at ``dataset_path`` must provide the columns ``PTT``,
    ``Longitude``, ``Latitude`` and ``DateTime``. Rows that are duplicated or
    incomplete in those columns are dropped, timestamps are converted to
    seconds elapsed since the earliest one, and the data is split into one
    sequence per ``PTT`` value. Within a sequence, fixes sharing a timestamp
    are averaged and the gap to the previous fix is stored as ``t_diff``.
    Sequences with 3 or fewer fixes are discarded.

    The surviving sequences are shuffled with a fixed seed and concatenated,
    the time gaps are divided by their mean, summary statistics are printed,
    a scatter plot of the raw coordinates is saved to ``lrffplot.png`` and
    the table, with standardised coordinates, is written to ``lrff.csv``
    (both in the current working directory).

    Parameters
    ----------
    dataset_path : str
        Path of the CSV file to process.

    Raises
    ------
    ValueError
        If no sequence survives the filtering.
    """
    min_len = 3
    key_cols = ["PTT", "Longitude", "Latitude", "DateTime"]
    df = pd.read_csv(dataset_path, sep=",", engine='python')
    df.drop_duplicates(subset=key_cols, inplace=True)
    df.dropna(subset=key_cols, inplace=True)
    PTT = df["PTT"].unique()
    # ==== data preprocessing ====
    df["DateTime"] = pd.to_datetime(df["DateTime"])
    # Seconds elapsed since the earliest timestamp in the whole file.
    least_recent_date = df["DateTime"].min()
    df["DateTime"] = (df["DateTime"] - least_recent_date).dt.total_seconds()

    # ============================
    all_df = []
    t_diff_min = 1e7
    for ptt in PTT:
        df1 = df.loc[df["PTT"] == ptt, key_cols].sort_values(by=["DateTime"])
        df1 = df1.astype(np.float64)
        # Average all fixes recorded at the same instant.
        df1 = df1.groupby('DateTime').mean().reset_index()
        # NOTE(review): `<=` also drops sequences of exactly `min_len` fixes,
        # whereas taxi() and SCAR() keep them (`<`). Kept as-is so the
        # generated dataset does not change - confirm which is intended.
        if df1.shape[0] <= min_len:
            continue
        df1["t_diff"] = df1["DateTime"].diff().fillna(0)
        t_diff_min = min(t_diff_min, df1["t_diff"][1:].min())
        all_df.append(df1[["PTT", "Longitude", "Latitude", "t_diff"]])

    if not all_df:
        # Fail with an explicit message instead of the opaque
        # "min() arg is an empty sequence" raised further down.
        raise ValueError("No valid LRFF tracks found in %r" % (dataset_path,))

    random.seed(6)
    random.shuffle(all_df)
    lens_min = min(map(len, all_df))
    lens_max = max(map(len, all_df))
    print("n_sequences:", len(all_df))
    print("min length:", lens_min)
    print("max length:", lens_max)
    print("average length:", mean(map(len, all_df)))
    all_df = pd.concat(all_df, ignore_index=True).astype(np.float64)
    all_df["PTT"] = all_df["PTT"].astype(int)
    print("observations:", all_df.shape[0])
    print("mean time diff before: %f" %(all_df["t_diff"].mean()))
    print("min time diff before: %f" %(t_diff_min))
    print("max time diff before: %f" %(all_df["t_diff"].max()))
    print("std time diff before: %f" %(all_df["t_diff"].std()))
    # Express the gaps in units of the mean gap.
    all_df["t_diff"] = all_df["t_diff"] / all_df["t_diff"].mean()
    print("mean time diff: %f" %(all_df["t_diff"].mean()))
    print("min time diff: %f" %(all_df["t_diff"].min()))
    print("max time diff: %f" %(all_df["t_diff"].max()))
    print("std time diff: %f" %(all_df["t_diff"].std()))
    lonlat_max = all_df[["Longitude", "Latitude"]].max().values.tolist()
    lonlat_min = all_df[["Longitude", "Latitude"]].min().values.tolist()
    lonlat_mean = all_df[["Longitude", "Latitude"]].mean()
    lonlat_std = all_df[["Longitude", "Latitude"]].std()
    print("lon_lat max: {0}, {1}".format(lonlat_max[0], lonlat_max[1]))
    print("lon_lat min: {0}, {1}".format(lonlat_min[0], lonlat_min[1]))
    print("lon_lat mean: {0}, {1}".format(lonlat_mean.values.tolist()[0], lonlat_mean.values.tolist()[1]))
    print("lon_lat std: {0}, {1}".format(lonlat_std.values.tolist()[0], lonlat_std.values.tolist()[1]))

    # The plot is drawn from the raw coordinates; standardisation happens after.
    all_df.plot(kind='scatter', x="Longitude", y="Latitude", s=0.5, marker=",", linewidths=0)
    plt.savefig("lrffplot.png", dpi=1000)
    all_df[["Longitude", "Latitude"]] = (all_df[["Longitude", "Latitude"]] - lonlat_mean) / lonlat_std
    all_df.to_csv("lrff.csv", index=False)


 
if __name__ == '__main__':
    # Command-line entry point: pick the preprocessing routine from a keyword
    # contained in the dataset path.
    cli = argparse.ArgumentParser(description="Calculate stats of data")
    cli.add_argument('--dataset-path', type=str, default="", help="path to the dataset")
    options = cli.parse_args()

    # Checked in order; the first keyword found in the path wins.
    handlers = (
        ("taxi", taxi),
        ("LRFF", lrff),
        ("SCAR", SCAR),
    )
    for keyword, handler in handlers:
        if keyword in options.dataset_path:
            handler(options.dataset_path)
            break
    else:
        print('unknown dataset')







