import time
import datetime
import torch
import numpy as np
import random
import sys
import argparse

def parse_arguments(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``, so existing zero-argument calls
            behave as before.

    Returns:
        argparse.Namespace holding one attribute per option defined below.
    """
    # train args
    parser = argparse.ArgumentParser(description='context2seq-NAT (PyTorch version)')
    parser.add_argument('--base_dir', type=str, help='Output base dir')
    parser.add_argument('--target', default="heart_rate", help='derived_speed or heart_rate')

    parser.add_argument('--patience', default=10, type=int, help='Patience for early stopping')
    parser.add_argument('--epoch', default=60, type=int, help='Max epochs')
    parser.add_argument('--attributes', default="userId", help='Attributes to include')
    parser.add_argument('--include_device', action='store_true', help='Include device')
    parser.add_argument('--input_attributes', default="distance,altitude,time_elapsed", help='Input features')
    parser.add_argument('--pretrain', action='store_true',
                        help='Use pretrain model weights')
    parser.add_argument('--fn', default="endomondoHR_proper.json", help='Path to original data')
    parser.add_argument('--pretrain_file', default="", help='Pretrain file name (placeholder in this example)')
    parser.add_argument('--trainValidTestFN', default="/xx/pkl",
                        help="Path to the pkl file containing train/valid/test splits + contextMap")
    # The help text here used to be a copy of the --trainValidTestFN one.
    parser.add_argument('--dataset', default="fitrec",
                        help="Dataset name")
    parser.add_argument('--eval', action='store_true',
                        help='Eval mode')
    parser.add_argument('--temporal', action='store_true', help='Use temporal inputs (context inputs)')

    parser.add_argument('--batch_size', default=64, type=int, help='Batch size')
    parser.add_argument('--attribute_dim', default=5, type=int, help='Dimension of user/sport/gender embeddings')
    parser.add_argument('--hidden_dim', default=128, type=int, help='Dimension of LSTM hidden state')
    parser.add_argument('--lr', default=0.0002, type=float, help='Learning rate')

    parser.add_argument('--weight_decay', default=0, type=float, help='Global L2 weight decay')

    parser.add_argument('--model_type', default="transformer2", help='LSTM, transformer, MLP')
    parser.add_argument('--contrastive_loss', action='store_true', help='Using Contrastive loss')
    parser.add_argument('--use_sport', action='store_true', help='Including sport in contrastive loss')
    parser.add_argument('--device', default=None, help='Execution Device')
    parser.add_argument('--full_temporal', action='store_true', help='Including time-aware full temporal attention')
    parser.add_argument('--limit_full_temporal_length', type=int, default=None, help='Limit full temporal length')
    parser.add_argument('--num_workers', type=int, default=0, help='Num of workers for DataLoader')
    parser.add_argument('--feature_dropout', action='store_true', help='Use context modality dropout')
    parser.add_argument('--advanced_feature_dropout', action='store_true', help='Use advanced(must switch on first) context modality dropout')
    # The help text here used to be a copy of the --feature_dropout one.
    parser.add_argument('--clip_grad_norm', type=float, default=None,
                        help='Gradient clipping norm threshold (default: None)')
    parser.add_argument('--contrastive_weight', type=float, default=0.1, help='contrastive weight')
    return parser.parse_args(argv)

def convert_unix_to_datetime(unix_timestamp):
    """Turn a Unix timestamp into a naive ``datetime`` in local time.

    The timestamp is broken down with ``time.localtime`` and only the
    year-through-second fields are kept, so any fractional second is dropped.
    """
    broken_down = time.localtime(unix_timestamp)
    return datetime.datetime(
        broken_down.tm_year,
        broken_down.tm_mon,
        broken_down.tm_mday,
        broken_down.tm_hour,
        broken_down.tm_min,
        broken_down.tm_sec,
    )

def normalize_array(values, z_multiple=5):
    """Return the z-scores of ``values`` scaled by ``z_multiple``.

    Args:
        values: Array-like object exposing ``mean()`` and ``std()`` and
            supporting element-wise arithmetic (e.g. a NumPy array).
        z_multiple: Factor applied to the z-scores.

    Returns:
        ``(values - mean) / std * z_multiple``. When the standard deviation
        is not strictly positive, ``1e-6`` is used as the divisor instead to
        avoid dividing by zero.
    """
    mean_val = values.mean()
    # Compute the standard deviation once; it was previously evaluated twice.
    std_val = values.std()
    if not std_val > 0:
        std_val = 1e-6
    return (values - mean_val) / std_val * z_multiple

def get_device():
    """Pick the accelerator device to run on.

    Preference order is MPS, then CUDA. There is no CPU fallback: when
    neither accelerator is usable, an error line is printed and the process
    exits with status 1 (the message mentions only CUDA).

    Returns:
        torch.device: ``torch.device("mps")`` or ``torch.device("cuda")``.

    Raises:
        SystemExit: If neither MPS nor CUDA is available.
    """
    # PyTorch builds without the ``torch.backends.mps`` submodule would raise
    # AttributeError here before CUDA is even checked, so look it up safely.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available() and mps_backend.is_built():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    print("@@@@@@@@@@@@[ERROR] CUDA is not available. Exiting program.")
    sys.exit(1)

def set_random_seed(seed=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) generators.

    Args:
        seed: Value passed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
def to_builtin(x):
    """Convert a NumPy scalar or 0-dim tensor to the matching Python builtin.

    Args:
        x: Any object.

    Returns:
        ``bool`` for ``np.bool_``, ``int`` for NumPy integers, ``float`` for
        NumPy floats, the converted ``item()`` of a 0-dimensional tensor, and
        ``x`` itself for everything else.
    """
    # np.bool_ is not a subclass of np.integer and is not JSON-serializable,
    # so it needs its own branch.
    if isinstance(x, np.bool_):
        return bool(x)
    if isinstance(x, np.integer):
        return int(x)
    if isinstance(x, np.floating):
        return float(x)
    if torch.is_tensor(x) and x.ndim == 0:
        return to_builtin(x.item())
    return x

def to_jsonable(obj):
    """Recursively convert ``obj`` into plain Python containers and scalars.

    Args:
        obj: A NumPy array, a tensor, a dict, a list/tuple, or any other
            value.

    Returns:
        Nested lists for arrays and tensors, a dict whose keys went through
        ``to_builtin`` and whose values were converted recursively, a list
        for lists and tuples, and ``to_builtin(obj)`` for anything else.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if torch.is_tensor(obj):
        # detach() first: calling numpy() on a tensor that requires grad
        # raises RuntimeError. tolist() yields the same nested lists (or a
        # scalar for 0-dim tensors) without the NumPy round trip.
        return obj.detach().cpu().tolist()

    if isinstance(obj, dict):
        return {to_builtin(key): to_jsonable(value) for key, value in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(item) for item in obj]

    return to_builtin(obj)