import sys
import math
import glob
import tqdm
import torch
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mtdt




class scar_dataset_facthmm(torch.utils.data.Dataset):
    """Per-individual sequence dataset built from a "scar" CSV file.

    Rows are grouped by ``individual_id`` (groups in order of first
    appearance, rows in file order) and each group becomes one tensor item.

    task 0 (or any value not listed below): columns (X, Y, Z, t_diff), read
        as float32.
    tasks 1-4: columns (state, <label>), where ``state`` comes from the CSV at
        ``decoded_path`` joined to the dataset rows by row position, and
        <label> is read as int64:
            1 -> abbreviated_name (17 classes, all rows kept)
            2 -> age_class        (5 classes, rows with label -1 dropped)
            3 -> sex              (2 classes, rows with label -1 dropped)
            4 -> breeding_stage   (14 classes, rows with label -1 dropped)
    Individuals left with no rows are skipped.

    Attributes:
        data_dim: 3 (presumably the X/Y/Z column count; t_diff excluded).
        dataset_stats: the ``dataset_stats`` argument, stored unchanged.
        task: the ``task`` argument.
        n_class: number of label classes, or None for task 0.
        class_cnt: always None (per-class counts are not computed).
        s: list of per-individual tensors returned by ``__getitem__``.
    """

    # task -> (label column, number of classes, drop rows whose label is -1)
    _LABEL_TASKS = {
        1: ("abbreviated_name", 17, False),
        2: ("age_class", 5, True),
        3: ("sex", 2, True),
        4: ("breeding_stage", 14, True),
    }

    def __init__(self, dataset_path=None, dataset_stats=None, task=0, decoded_path=None):
        self.data_dim = 3
        self.dataset_stats = dataset_stats
        self.task = task
        self.s = []
        # Defined for every task so callers can read it uniformly.
        self.class_cnt = None
        if task in self._LABEL_TASKS:
            label, self.n_class, drop_missing = self._LABEL_TASKS[task]
            df = pd.read_csv(dataset_path, sep=",", engine='python',
                             usecols=["individual_id", label], dtype={label: np.int64})
            # pd.concat(axis=1) aligns on the default RangeIndex, i.e. the
            # decoded states are matched to the dataset rows by position.
            df = pd.concat([df, pd.read_csv(decoded_path, sep=",", engine='python')], axis=1)
            value_cols = ["state", label]
        else:
            value_cols = ["X", "Y", "Z", "t_diff"]
            self.n_class = None
            label, drop_missing = None, False
            df = pd.read_csv(dataset_path, sep=",", engine='python',
                             usecols=["individual_id"] + value_cols,
                             dtype=dict.fromkeys(value_cols, np.float32))
        # One groupby pass instead of a full boolean scan of the frame per
        # individual; sort=False keeps groups in order of first appearance.
        for _, rows in df.groupby("individual_id", sort=False):
            if drop_missing:
                rows = rows[rows[label] != -1]
            if len(rows) > 0:
                self.s.append(torch.from_numpy(rows[value_cols].values))

    def __getitem__(self, index):
        """Return the tensor for the ``index``-th individual."""
        return self.s[index]

    def __len__(self):
        """Return the number of individuals kept."""
        return len(self.s)


class taxi_dataset_facthmm(torch.utils.data.Dataset):
    """Per-taxi sequence dataset read from a CSV file.

    Each item is a float32 tensor with columns (Longitude, Latitude, t_diff)
    holding all rows of one ``taxi_id``. Taxis appear in order of first
    appearance in the file, and rows keep their file order. Rows whose
    ``taxi_id`` is missing are dropped.

    Attributes:
        data_dim: 2 (Longitude/Latitude; t_diff is an extra column).
        dataset_stats: the ``dataset_stats`` argument, stored unchanged.
        s: list of per-taxi tensors returned by ``__getitem__``.
    """

    def __init__(self, dataset_path=None, dataset_stats=None):
        self.data_dim = 2
        self.dataset_stats = dataset_stats
        df = pd.read_csv(dataset_path, sep=",", engine='python',
                         dtype={"Longitude": np.float32, "Latitude": np.float32, "t_diff": np.float32})
        # One groupby pass instead of a full boolean scan of the frame per
        # taxi; sort=False keeps taxis in order of first appearance.
        self.s = [
            torch.from_numpy(rows[["Longitude", "Latitude", "t_diff"]].values)
            for _, rows in df.groupby("taxi_id", sort=False)
        ]

    def __getitem__(self, index):
        """Return the tensor for the ``index``-th taxi."""
        return self.s[index]

    def __len__(self):
        """Return the number of taxis."""
        return len(self.s)
        

class lrff_dataset_facthmm(torch.utils.data.Dataset):
    """Per-tag sequence dataset read from an "lrff" CSV file.

    Rows are grouped by the ``PTT`` column. Each item is a float32 tensor with
    columns (Longitude, Latitude, t_diff) holding all rows of one ``PTT``
    value. Groups appear in order of first appearance, and rows keep their
    file order. Rows whose ``PTT`` is missing are dropped.

    Attributes:
        data_dim: 2 (Longitude/Latitude; t_diff is an extra column).
        dataset_stats: the ``dataset_stats`` argument, stored unchanged.
        s: list of per-PTT tensors returned by ``__getitem__``.
    """

    def __init__(self, dataset_path=None, dataset_stats=None):
        self.data_dim = 2
        self.dataset_stats = dataset_stats
        df = pd.read_csv(dataset_path, sep=",", engine='python',
                         dtype={"Longitude": np.float32, "Latitude": np.float32, "t_diff": np.float32})
        # One groupby pass instead of a full boolean scan of the frame per
        # PTT value; sort=False keeps groups in order of first appearance.
        self.s = [
            torch.from_numpy(rows[["Longitude", "Latitude", "t_diff"]].values)
            for _, rows in df.groupby("PTT", sort=False)
        ]

    def __getitem__(self, index):
        """Return the tensor for the ``index``-th PTT value."""
        return self.s[index]

    def __len__(self):
        """Return the number of distinct PTT values."""
        return len(self.s)

        
def collate_batch(batch):
    """Pack a list of variable-length tensors into one ``PackedSequence``.

    The input need not be sorted by length; ``enforce_sorted=False`` lets
    PyTorch sort internally and record the permutation on the result.
    """
    packed = torch.nn.utils.rnn.pack_sequence(sequences=batch, enforce_sorted=False)
    return packed
    

class CustomLoader(torch.utils.data.DataLoader):
    """Loader that batches whole sequences by total observation count.

    Instead of a fixed number of sequences per batch, sequences are collected
    until their summed length reaches ``max_obs``. Each batch is yielded as the
    result of ``collate_batch``. On overflow, the newest sequence is carried
    into the next batch. A single sequence longer than ``max_obs`` is yielded
    on its own.

    Args:
        max_obs: observation budget per batch.
        dataset: indexable dataset whose items are tensors sliced along dim 0.
        mode: "train" uses the first ``train_frac`` of each sequence, "test"
            the whole sequence, any other value the remainder after the cut.
        shuffle: forwarded to DataLoader; determines ``self.sampler``'s order.
        num_workers: forwarded to DataLoader. Note that ``__iter__`` below
            reads ``self.dataset`` directly in this process.
        progress_bar: show a tqdm progress bar while iterating.
        train_frac: fraction used to place the train/remainder cut
            (default 0.8, the previously hard-coded value).
    """

    def __init__(self, max_obs, dataset=None, mode="train", shuffle=True, num_workers=0,
                 progress_bar=True, train_frac=0.8):
        super().__init__(dataset=dataset, shuffle=shuffle, num_workers=num_workers)
        self.max_obs = max_obs
        self.mode = mode
        self.no_bar = not progress_bar
        self.train_frac = train_frac

    def _select(self, seq):
        """Return the part of ``seq`` (cut along dim 0) used by ``self.mode``."""
        cut = int(seq.size(0) * self.train_frac)
        if self.mode == "train":
            return seq[:cut]
        if self.mode == "test":
            return seq
        return seq[cut:]

    def __iter__(self):
        batch, batch_len = [], 0
        for idx in tqdm.tqdm(self.sampler, smoothing=0, mininterval=1.0, file=sys.stdout, disable=self.no_bar):
            seq = self._select(self.dataset[idx])
            n = seq.size(0)
            if n == 0:
                # e.g. a 1-step sequence in "train" mode (int(1 * 0.8) == 0);
                # pack_sequence rejects zero-length sequences.
                continue
            batch.append(seq)
            batch_len += n
            while batch_len >= self.max_obs:
                if batch_len == self.max_obs or len(batch) == 1:
                    # Exact fit, or one sequence that alone exceeds the budget:
                    # emit it as-is (splitting it off would leave an empty batch).
                    yield collate_batch(batch)
                    batch, batch_len = [], 0
                else:
                    # Overflow: emit all but the newest sequence, carry it over.
                    last = batch.pop()
                    yield collate_batch(batch)
                    batch, batch_len = [last], last.size(0)
        if batch:
            yield collate_batch(batch)


def get_data_n_stats(dataset_path, task=0, decoded_path=None):
    """Build the dataset matching ``dataset_path`` with its normalisation stats.

    The dataset kind is chosen by the first of "taxi", "scar", "lrff" found as
    a substring of ``dataset_path`` (checked in that order). The stats passed
    to the dataset are ``[mean, std]``, each a one-row nested list.
    ``task`` and ``decoded_path`` are forwarded only for the "scar" dataset.

    Raises:
        ValueError: if none of the known names occurs in ``dataset_path``.
    """
    known = (
        ("taxi",
         [[116.40801498508985, 39.91548332789868]],
         [[0.10802431873215311, 0.09817366369732516]]),
        ("scar",
         [[969575.6856847513, 419860.75879995385, -5350214.682521352]],
         [[2291630.8851368325, 2262637.5868772673, 610694.1851632163]]),
        ("lrff",
         [[145.11196719704435, -19.2005747817734]],
         [[4.718201911499462, 4.572033288285534]]),
    )
    for name, mean, std in known:
        if name in dataset_path:
            break
    else:
        raise ValueError('unknown dataset name')

    stats = [mean, std]
    if name == "taxi":
        return taxi_dataset_facthmm(dataset_path, stats)
    if name == "scar":
        return scar_dataset_facthmm(dataset_path, stats, task, decoded_path)
    return lrff_dataset_facthmm(dataset_path, stats)


if __name__ == '__main__':
    # Smoke test: load a taxi CSV given on the command line and report its size.
    # (The path used to be hard-coded to '', which always failed in read_csv.)
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <taxi_csv_path>")
    dataset_path = sys.argv[1]
    dataset = taxi_dataset_facthmm(dataset_path)
    print(f"{len(dataset)} sequences loaded from {dataset_path!r}")





