import os
import argparse
import numpy as np
from utils.utils import *
from exp_AD import Exp_AD
import random
import warnings
import ast
# Install a blanket 'ignore' filter so that no warnings are shown for the rest of the run.
warnings.filterwarnings('ignore')
# Release cached, currently unused GPU memory held by PyTorch's allocator before any work starts.
# NOTE(review): `torch` is not imported explicitly in this file; it is presumably
# re-exported by `from utils.utils import *` -- confirm.
torch.cuda.empty_cache()

def set_seed(seed=2025):
    """Seed every random number generator used by the script.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU, the current
    CUDA device and all CUDA devices), switches cuDNN to deterministic mode
    with benchmarking disabled, and stores the seed in the
    ``PYTHONHASHSEED`` environment variable.

    Args:
        seed: Value handed to each seeding function (default 2025).
    """
    seeders = (
        random.seed,                 # Python's built-in generator
        np.random.seed,              # NumPy global generator
        torch.manual_seed,           # PyTorch CPU generator
        torch.cuda.manual_seed,      # PyTorch generator of the current GPU
        torch.cuda.manual_seed_all,  # PyTorch generators of all GPUs
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN: deterministic kernels on, benchmarking (auto-tuning) off.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False

    # Environment variables must be strings.
    os.environ['PYTHONHASHSEED'] = str(seed)

def str2bool(v):
    """Interpret a command-line string as a boolean.

    Args:
        v: Text to interpret. A value that is already a ``bool`` is
            returned unchanged.

    Returns:
        ``True`` only when ``v`` equals ``'true'`` ignoring case,
        ``False`` for every other string.
    """
    if isinstance(v, bool):
        return v
    # The trailing comma matters: ``('true')`` is just the string 'true', and
    # ``in`` on a string is a substring test, so '', 't' or 'rue' would all
    # be accepted as True.
    return v.lower() in ('true',)

def main(config):
    """Create an ``Exp_AD`` experiment from ``config`` and run the chosen mode.

    Sets the float32 matmul precision to ``'medium'``, seeds the random
    number generators with 2025, and calls ``mkdir`` on
    ``config.model_save_path`` when that path does not exist yet.

    Args:
        config: Object exposing at least ``model_save_path`` and ``mode``
            as attributes; ``vars(config)`` is what ``Exp_AD`` receives.

    Returns:
        The ``Exp_AD`` instance. Its ``train()`` is called when ``mode`` is
        ``'train'``, its ``test()`` when ``mode`` is ``'test'``; for any
        other mode neither is called.
    """
    torch.set_float32_matmul_precision('medium')
    set_seed(2025)

    save_dir = config.model_save_path
    if not os.path.exists(save_dir):
        mkdir(save_dir)

    exp = Exp_AD(vars(config))

    # The mode names coincide with the method names on the experiment object.
    mode = config.mode
    if mode in ('train', 'test'):
        getattr(exp, mode)()

    return exp


if __name__ == '__main__':
    # Command-line entry point: declare the options, parse them, print the
    # resulting configuration, then hand it to main().
    #
    # Each entry is (flag, keyword arguments for ``add_argument``). Options
    # are registered in the order listed here.
    option_table = (
        # Base arguments
        ('--lr', dict(type=float, default=1e-4, help='Learning rate')),
        ('--num_epochs', dict(type=int, default=10, help='Train epochs')),
        ('--win_size', dict(type=int, default=100, help='Length of window')),
        ('--input_c', dict(type=int, default=38, help='Input channels')),
        ('--output_c', dict(type=int, default=38, help='Output channels')),
        ('--batch_size', dict(type=int, default=128, help='Batch Size')),
        ('--dataset', dict(type=str, default='credit', help='Name of the dataset')),
        ('--mode', dict(type=str, default='train', choices=['train', 'test'], help='Execution mode')),
        ('--data_path', dict(type=str, default='./dataset/creditcard_ts.csv', help='Path of input dataset')),
        ('--model_save_path', dict(type=str, default='checkpoints', help='Directory to save trained model')),
        ('--anomaly_ratio', dict(type=float, default=4.00, help='Anomaly ratio threshold (r)')),
        # KANomaly-specific arguments (cf. Appendix D for each of them)
        ('--patch_sizes', dict(type=ast.literal_eval, default=[20, 10, 5, 2, 1], help='Patch sizes (multi-scale)')),
        ('--gridsize', dict(type=int, default=5, help='The number of frequency components')),
        ('--num_layers', dict(type=int, default=3, help='The number of KAN layers used in each Fourier-KAN Block')),
        ('--norm', dict(type=str, default='IN', choices=['BN', 'IN', 'LN'], help='BatchNorm, InstanceNorm, or LayerNorm')),
        ('--drop', dict(type=float, default=0.3, help='DropPath rate')),
    )

    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    config = parser.parse_args()
    args = vars(config)

    # Print every option, sorted by name, between two banner lines; the colour
    # prefix and reset suffix come from ``Color``.
    print(Color.MAGENTA + '------------ Options -------------')
    for option_name in sorted(args):
        print(f'{option_name!s}: {args[option_name]!s}')
    print('-------------- End ----------------' + Color.RESET)

    main(config)