# Create datasets from the real-world human behavior dataset.
# zhe @ 2023.05

import os
import pandas as pd
import numpy as np
import torch
import pickle
import glob
import json
from datetime import datetime
from call_feature_extraction import query as ft_query



def vis_time(timeStamp):
    """
    Format a Unix timestamp in the local timezone.

    Args:
        timeStamp: seconds since the epoch; anything accepted by ``int()``
            (e.g. an int, a float whose fraction is truncated, or an integer
            string).

    Returns:
        (styledTime, localTime): the local time formatted as
        ``"%Y-%m-%d_%H-%M-%S"`` and the ``time.struct_time`` it was built
        from (its ``tm_wday`` field gives the weekday, 0 = Monday).
    """
    import time

    # Truncate to whole seconds before converting to local time.
    localTime = time.localtime(int(timeStamp))
    styledTime = time.strftime("%Y-%m-%d_%H-%M-%S", localTime)
    return styledTime, localTime
    

def is_ne_in_df(df:pd.DataFrame):
    """
    Report whether any cell of ``df`` holds the literal string "n/e".

    Some raw data files mark missing entries with "n/e". The columns are
    scanned one at a time, and the scan stops at the first column that has
    a match.

    Returns:
        True if at least one cell equals "n/e", otherwise False (also False
        for a frame without columns).
    """
    return any(any(df[column] == "n/e") for column in df.columns)


def ValueHomeLoader(data_dir, test_size, eval_size, train_start, label_num: int=19, feature_num: int=1211, data_type="") -> tuple:
    """
    Load every per-user pickle in ``data_dir``, split each one by row position
    into train / eval / test parts, and concatenate each part across users.

    For a user frame with ``n`` rows the (end-exclusive) row ranges are::

        train: [round(n * train_start), n - round(n * (test_size + eval_size)))
        eval : [n - round(n * (test_size + eval_size)), n - round(n * test_size))
        test : [n - round(n * test_size), n)

    Boundaries below 0 are clamped to 0.

    Args:
        data_dir: directory searched for files matching ``p*{data_type}.pkl``
            (processed in sorted path order).
        test_size: fraction of each user's rows placed in the test split,
            taken from the end of the frame.
        eval_size: fraction of rows placed in the eval split, immediately
            before the test split.
        train_start: fraction of leading rows skipped before the train split.
        label_num: expected number of label columns.
        feature_num: expected number of feature columns; every frame must
            have exactly ``label_num + feature_num`` columns.
        data_type: suffix inserted into the file glob.

    Returns:
        (train, eval, test) DataFrames concatenated in file order, keeping
        the original index labels. If no file matches, three empty lists are
        returned (kept for backward compatibility).

    Raises:
        ValueError: if a frame still contains "n/e" cells or has an
            unexpected number of columns.
    """
    user_data_pths = sorted(glob.glob(f"{data_dir}/p*{data_type}.pkl"))
    if not user_data_pths:
        return [], [], []

    train_parts, eval_parts, test_parts = [], [], []
    for user_data_pth in user_data_pths:
        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        with open(user_data_pth, 'rb') as f:
            encoded_df = pickle.load(f)

        # Make sure all "n/e" values have been removed from df.
        if is_ne_in_df(encoded_df):
            raise ValueError("data frame contains 'n/e' values. These must be handled.")
        elif len(encoded_df.columns) != (label_num + feature_num):
            raise ValueError("invalid column. Please check.")

        data_len = len(encoded_df)
        # Use absolute, non-negative row positions. Negative slice bounds break
        # when a size rounds to 0: seq[-0:] is the WHOLE sequence and
        # seq[a:-0] is empty.
        train_begin = round(data_len * train_start)
        eval_begin = max(0, data_len - round(data_len * (test_size + eval_size)))
        test_begin = max(0, data_len - round(data_len * test_size))

        train_parts.append(encoded_df.iloc[train_begin:eval_begin])
        eval_parts.append(encoded_df.iloc[eval_begin:test_begin])
        test_parts.append(encoded_df.iloc[test_begin:])

    # Concatenate once instead of re-copying the accumulated frames per user.
    return pd.concat(train_parts), pd.concat(eval_parts), pd.concat(test_parts)


def get_indices_entire_sequence(data: pd.DataFrame, window_size: int, step_size: int) -> list:
    """
    Produce the (start_idx, end_idx) positions of every sliding-window
    sub-sequence over ``data``.

    The tuples are meant to slice the dataset into sub-sequences
    (``data[start_idx:end_idx]``). Those sub-sequences can then be split into
    input and target sequences.

    Args:
        data: the full sequence; only ``len(data)`` is used.

        window_size (int): the length of each sub-sequence. It should be
            input_sequence_length + target_sequence_length. For example, to
            use the past 100 time steps to predict the next 50,
            window_size = 100 + 50 = 150.

        step_size (int): how far the window moves each step; must be
            positive. If 1, the first sub-sequence is [0:window_size] and the
            next is [1:window_size + 1].

    Returns:
        indices: a list of (start_idx, end_idx) tuples with
        end_idx - start_idx == window_size.

    Raises:
        ValueError: if step_size is not positive (the window would never
            advance).

    NOTE(review): windows are emitted only while end_idx <= len(data) - 1, so
    the window ending exactly at len(data) is not produced, even though it
    would be a valid end-exclusive slice. This is kept for compatibility;
    confirm whether it is intended.
    """
    if step_size <= 0:
        raise ValueError(f"step_size must be positive, got {step_size}")

    # start + window_size <= len(data) - 1  <=>  start < len(data) - window_size
    start_limit = len(data) - window_size
    return [(start, start + window_size) for start in range(0, start_limit, step_size)]


def generate_square_subsequent_mask(dim1: int, dim2: int) -> torch.Tensor:
    """
    Build a causal attention mask: -inf strictly above the main diagonal and
    0 on and below it.

    Adapted from:
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html

    Args:
        dim1: number of rows. For both src and tgt masking this is the target
              sequence length.
        dim2: number of columns. For src masking this is the encoder
              (input) sequence length; for tgt masking it is the target
              sequence length.

    Return:
        A float tensor of shape [dim1, dim2].
    """
    neg_inf = float('-inf')
    mask = torch.full((dim1, dim2), neg_inf)
    # Keep -inf only where column > row; everything else becomes 0.
    return mask.triu(diagonal=1)


def accuracy_cal_onehot(predicts, gts):
    """
    Count top-1 / top-3 / top-5 hits of class scores against one-hot targets.

    For every position (i, j) of ``predicts``, the classes are ranked by
    ``predicts[i, j, :]``. The target class is ``argmax(gts[i, j, :])``.

    Args:
        predicts: torch tensor of class scores, indexed as predicts[i, j, :].
            Tensors that require grad are accepted.
        gts: torch tensor with the same leading layout; the target class is
            its argmax over the last dimension.

    Returns:
        (correct_num1, correct_num3, correct_num5): the number of positions
        whose target is among the 1 / 3 / 5 highest-scoring classes.
    """
    # Detach first: np.argsort needs a numpy array, and converting a tensor
    # that requires grad raises. Transfer once rather than once per (i, j).
    scores = predicts.detach().to('cpu').numpy()
    gt_ids = torch.argmax(gts, dim=-1).to('cpu')

    correct_num1 = 0
    correct_num3 = 0
    correct_num5 = 0
    for i in range(scores.shape[0]):
        for j in range(scores.shape[1]):
            # Ascending order, so the best classes are at the end.
            pred_id_list = np.argsort(scores[i, j, :])
            gt_id = int(gt_ids[i, j])

            if gt_id in pred_id_list[-1:]:
                correct_num1 += 1
            if gt_id in pred_id_list[-3:]:
                correct_num3 += 1
            if gt_id in pred_id_list[-5:]:
                correct_num5 += 1

    return correct_num1, correct_num3, correct_num5


def accuracy_cal_duration(predicts, gts, delta=0):
    """
    Relative squared error ``(p - g)**2 / (g**2 + delta)`` for one position.

    NOTE(review): only the LAST (i, j) position is evaluated, i.e.
    ``predicts[B-1, T-1, :]`` against ``gts[B-1, T-1, :]``, where (B, T) are
    the first two dimensions of ``predicts``. This is what the function has
    always returned, but it looks unintended. Confirm whether an error over
    all positions was meant before relying on this metric.

    Args:
        predicts: predictions, indexed as predicts[i, j, :].
        gts: ground truth with at least the same leading layout.
        delta: added to the denominator. The default 0 keeps the historical
            behaviour; use a value > 0 to avoid division by zero when the
            ground truth is 0.

    Returns:
        The element-wise relative squared error for the last position.

    Raises:
        ValueError: if ``predicts`` has no (i, j) positions.
    """
    if predicts.shape[0] == 0 or predicts.shape[1] == 0:
        raise ValueError("predicts must contain at least one (i, j) position")

    # Index gts with predicts' extent (not -1) so a larger gts is handled the
    # same way as before.
    i = predicts.shape[0] - 1
    j = predicts.shape[1] - 1
    dis = pow((predicts[i, j, :] - gts[i, j, :]), 2)
    return dis / (pow(gts[i, j, :], 2) + delta)


def cos_sim(a, b):
    """
    Pairwise cosine similarity between the rows of ``a`` and the rows of ``b``.

    Inputs that are not tensors are converted with ``torch.tensor``, and a
    1-D input is treated as a single row.

    :return: Matrix res with res[i][j] = cos_sim(a[i], b[j])
    Adapted from sentence_transformers.util.cos_sim().
    """
    def _as_rows(x):
        # Accept lists/arrays as well as tensors; lift a lone vector to one row.
        x = x if isinstance(x, torch.Tensor) else torch.tensor(x)
        return x.unsqueeze(0) if x.dim() == 1 else x

    lhs = torch.nn.functional.normalize(_as_rows(a), p=2, dim=1)
    rhs = torch.nn.functional.normalize(_as_rows(b), p=2, dim=1)
    return torch.mm(lhs, rhs.t())


def accuracy_cal_target(predicts, gts, userId=-1, human_flag=False, templates=None):
    """
    Count top-1 / top-3 / top-5 hits of predicted embeddings against a set of
    target template embeddings.

    For each position (i, j), ``predicts[i, j, :]`` is compared by cosine
    similarity with every template embedding, and the templates are ranked.
    The ground-truth template is the one most similar to ``gts[i, j, :]``.

    Args:
        predicts: torch tensor indexed as predicts[i, j, :]. Tensors that
            require grad are accepted.
        gts: torch tensor with the same leading layout.
        userId: -1 selects the "all" templates; otherwise the templates under
            key "pXX" (userId zero-padded to 2 digits) are used.
        human_flag: when ``templates`` is not a dict, load
            "human_target_template.pkl" if True, else
            "gpt_target_template.pkl" (relative to the working directory).
        templates: optional pre-loaded dict; this skips the file load.

    Returns:
        (correct_num1, correct_num3, correct_num5): hit counts over all (i, j).
    """
    if isinstance(templates, dict):
        target_templates = templates
    else:
        template_pth = "human_target_template.pkl" if human_flag else "gpt_target_template.pkl"
        # `with` closes the file handle. NOTE: pickle.load can execute
        # arbitrary code -- the template files must be trusted.
        with open(template_pth, "rb") as f:
            target_templates = pickle.load(f)

    if userId == -1:
        target_json = target_templates["all"]
    else:
        target_json = target_templates["p{0}".format(str(userId).zfill(2))]

    # Convert the template embeddings once instead of inside every cos_sim call.
    embeddings = target_json["embeddings"]
    if not isinstance(embeddings, torch.Tensor):
        embeddings = torch.tensor(embeddings)

    # Detach so np.array() below works for tensors that require grad;
    # transfer to CPU once rather than per position.
    pred_cpu = predicts.detach().to('cpu')
    gt_cpu = gts.detach().to('cpu')

    correct_num1 = 0
    correct_num3 = 0
    correct_num5 = 0
    for i in range(pred_cpu.shape[0]):
        for j in range(pred_cpu.shape[1]):
            predict_sim = cos_sim(pred_cpu[i, j, :], embeddings)
            # cos_sim returns a (1, K) matrix; row 0 is the ascending ranking.
            pred_id_list = np.argsort(np.array(predict_sim))[0]

            gt_sim = cos_sim(gt_cpu[i, j, :], embeddings)
            gt_id = int(gt_sim.argmax())

            if gt_id in pred_id_list[-1:]:
                correct_num1 += 1
            if gt_id in pred_id_list[-3:]:
                correct_num3 += 1
            if gt_id in pred_id_list[-5:]:
                correct_num5 += 1

    return correct_num1, correct_num3, correct_num5