import argparse
import json
import random
import numpy as np
import torch

from models import LPILLM

from data_provider.data_factory import data_provider
import os

from utils.tools import test

if __name__ == '__main__':
    # NOTE(review): an empty CURL_CA_BUNDLE is presumably meant to make HTTP
    # clients (e.g. model downloads via `requests`) skip TLS certificate
    # verification. That weakens transport security — confirm it is still
    # required in the deployment network before keeping it.
    os.environ['CURL_CA_BUNDLE'] = ''
    # Limit the size of cached CUDA blocks that may be split, to reduce
    # allocator fragmentation; set before any CUDA allocation happens below.
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"

    parser = argparse.ArgumentParser(description='LPI-LLM Evaluation')

    # basic config
    parser.add_argument('--task_name', type=str, default='long_term_forecast',
                        help='task name, options:[long_term_forecast, short_term_forecast, imputation, classification, anomaly_detection]')
    parser.add_argument('--model', type=str, default='LPI-LLM',
                        help='model name, options: [LPI-LLM]')
    parser.add_argument('--seed', type=int, help='random seed')

    # data loader
    parser.add_argument('--data', type=str, default='LPI4AI', help='dataset type')
    parser.add_argument('--root_path', type=str, default='./dataset/', help='root path of the data file')
    parser.add_argument('--data_path', type=str, default='set_cleaned_sampled.pickle', help='data file')
    parser.add_argument('--features', type=str, default='M',
                        help='forecasting task, options:[M, S, MS]; '
                            'M:multivariate predict multivariate, S: univariate predict univariate, '
                            'MS:multivariate predict univariate')
    parser.add_argument('--target', type=str, default='OT', help='target feature in S or MS task')
    parser.add_argument('--loader', type=str, default='modal', help='dataset type')
    parser.add_argument('--freq', type=str, default='h',
                        help='freq for time features encoding, '
                            'options:[s:secondly, t:minutely, h:hourly, d:daily, b:business days, w:weekly, m:monthly], '
                            'you can also use more detailed freq like 15min or 3h')
    parser.add_argument('--checkpoints', type=str, default='weights/checkpoint', help='location of model checkpoints')

    # forecasting task
    parser.add_argument('--seq_len', type=int, default=96, help='input sequence length')
    parser.add_argument('--label_len', type=int, default=48, help='start token length')
    parser.add_argument('--pred_len', type=int, default=96, help='prediction sequence length')
    parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='subset for M4')

    # model define
    parser.add_argument('--enc_in', type=int, default=7, help='encoder input size')
    parser.add_argument('--dec_in', type=int, default=7, help='decoder input size')
    parser.add_argument('--c_out', type=int, default=7, help='output size')
    parser.add_argument('--d_model', type=int, default=16, help='dimension of model')
    parser.add_argument('--n_heads', type=int, default=8, help='num of heads')
    parser.add_argument('--e_layers', type=int, default=2, help='num of encoder layers')
    parser.add_argument('--d_layers', type=int, default=1, help='num of decoder layers')
    parser.add_argument('--d_ff', type=int, default=32, help='dimension of fcn')
    parser.add_argument('--moving_avg', type=int, default=25, help='window size of moving average')
    parser.add_argument('--factor', type=int, default=1, help='attn factor')
    parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
    parser.add_argument('--embed', type=str, default='timeF',
                        help='time features encoding, options:[timeF, fixed, learned]')
    parser.add_argument('--activation', type=str, default='gelu', help='activation')
    parser.add_argument('--output_attention', action='store_true', help='whether to output attention in encoder')
    parser.add_argument('--patch_len', type=int, default=16, help='patch length')
    parser.add_argument('--stride', type=int, default=8, help='stride')
    parser.add_argument('--prompt_domain', type=int, default=0, help='')

    # optimization
    parser.add_argument('--num_workers', type=int, default=10, help='data loader num workers')
    parser.add_argument('--batch_size', type=int, default=4, help='batch size of train input data')
    parser.add_argument('--learning_rate', type=float, default=0.0001, help='optimizer learning rate')
    parser.add_argument('--des', type=str, default='test', help='exp description')
    parser.add_argument('--loss', type=str, default='MSE', help='loss function')
    parser.add_argument('--use_amp', action='store_true', help='use automatic mixed precision training', default=False)
    parser.add_argument('--llm_layers', type=int, default=32)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    args = parser.parse_args()

    # --seed used to be parsed but never applied. Seed every RNG the
    # evaluation may touch (data loading, model construction) before either
    # happens, so repeated runs with the same seed are reproducible.
    if args.seed is not None:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    test_data, test_loader = data_provider(args, 'test')

    if args.model == 'LPI-LLM':
        model = LPILLM.Model(args)
    else:
        print('Model No Support!')
        # Exit with a non-zero status: the bare exit() reported success (0)
        # to calling scripts and depends on the optional `site` module.
        raise SystemExit(1)

    # NOTE(review): torch.load unpickles the file and can execute arbitrary
    # code — only load trusted checkpoints (consider weights_only=True on
    # PyTorch versions that support it).
    checkpoint = torch.load(args.checkpoints, map_location=device)
    # strict=False tolerates key mismatches (presumably because parts of the
    # model, e.g. frozen backbone weights, are not stored in the checkpoint —
    # confirm). Report the mismatches instead of dropping them silently.
    load_result = model.load_state_dict(checkpoint, strict=False)
    if load_result.missing_keys:
        print('Warning: {} model keys were not found in the checkpoint'.format(
            len(load_result.missing_keys)))
    if load_result.unexpected_keys:
        print('Warning: {} checkpoint keys did not match any model key'.format(
            len(load_result.unexpected_keys)))
    model = model.float().to(device=device, dtype=torch.bfloat16)
    # Evaluation mode: disables dropout (--dropout defaults to 0.1).
    # Idempotent, so harmless if the test loop also switches modes.
    model.eval()

    def absolute_error(pred, true):
        """Return the batch mean of the absolute error summed along dim 1.

        ``|pred - true|`` is reduced with a sum over dimension 1 (presumably
        the prediction-horizon axis — confirm against the data loader); the
        remaining values are then averaged into a scalar tensor.
        """
        per_sample_total = (pred - true).abs().sum(dim=1)
        return per_sample_total.mean()

    def errors(predicted, target, threshold=0.03):
        """Compute evaluation metrics on thresholded predictions.

        Predictions strictly below ``threshold`` are replaced by 0 before the
        absolute error against ``target`` is taken.

        Args:
            predicted: prediction tensor; dimension 1 is the reduced axis
                (presumably the prediction horizon — confirm against the
                data loader).
            target: tensor broadcastable against ``predicted``.
            threshold: predictions below this value are treated as 0.
                Defaults to 0.03, the value the script always used.

        Returns:
            ``(cae, top_1_mae, top_5_mae)``: the absolute error summed over
            dim 1 and averaged over everything else, and the mean of the
            largest 1 / 5 absolute errors along dim 1. If dim 1 has fewer
            than 5 entries, the top-5 figure uses all of them instead of
            raising.
        """
        # Build the replacement zero on the predictions' own device and dtype.
        # A bare torch.tensor(0.0) is a CPU float32 scalar, which relies on
        # cross-device/cross-dtype promotion in torch.where (rejected by older
        # PyTorch releases for CUDA or bfloat16 predictions).
        zero = predicted.new_zeros(())
        pred_filtered = torch.where(predicted < threshold, zero, predicted)
        error = torch.abs(pred_filtered - target)

        top_1_mae = torch.mean(torch.topk(error, 1, dim=1).values)
        # topk raises when k exceeds the size of the reduced dimension.
        k5 = min(5, error.size(1))
        top_5_mae = torch.mean(torch.topk(error, k5, dim=1).values)
        # Sum over dim 1 without squeezing: for (B, L, 1) the mean is
        # identical, while squeezing a (B, 1) tensor left no dim 1 to sum.
        cae = torch.mean(torch.sum(error, dim=1))
        return cae, top_1_mae, top_5_mae

    # Loss and metric callables handed to utils.tools.test.
    criterion, metric = absolute_error, errors

    print("Evaluating model on test data...")
    outputs = test(args, model, test_loader, criterion, metric, device)
    # Five results come back; `preds` is unpacked but not reported here.
    test_loss, test_mae_loss, preds, top_1, top_5 = outputs
    report = "Evaluation results - Test Loss: {:.7f} Test CAE: {:.7f} Test Top 1 MAE: {:.7f} Test Top 5 MAE: {:.7f}"
    print(report.format(test_loss, test_mae_loss, top_1, top_5))