import mylime
from mylime import lime_timeseries
import anomalyDetector
import preprocess_data
from model import model
import torch
from pathlib import Path
import numpy as np
import time
# Anomaly-score cut-off: every test step whose score is >= this value is
# explained with LIME in the __main__ loop below.
threshold: int = 100


def predict(prev, texts):
    """Score LIME-perturbed windows of the test series with the anomaly detector.

    Used as the LIME ``classifier_fn``. Each element of ``texts`` is a perturbed
    version of the feature-0 values of ``test_dataset[prev:prev + len(text)]``.
    For every element, that window is rebuilt with the perturbed values written
    into feature 0 and prefixed with the untouched history
    ``test_dataset[:prev]``. The result is scored with
    ``anomalyDetector.anomalyScore``, and the score of the final time step is
    kept.

    Relies on module-level globals created in the ``__main__`` block:
    ``test_dataset``, ``args``, ``model``, ``mean``, ``cov`` and ``channel_idx``.

    Args:
        prev: Index in ``test_dataset`` at which the explained window starts.
        texts: Sequence of perturbed windows. Each is presumably a 1-D sequence
            of floats (TODO confirm against lime_timeseries).

    Returns:
        ``np.ndarray`` with one row ``[1 - s, s]`` per perturbed window, where
        ``s`` is ``score[-1]`` as returned by ``anomalyScore``.
        NOTE(review): the main loop compares scores against ``threshold = 100``,
        so these columns are not probabilities in [0, 1].
    """
    res = []
    print(len(texts))
    # The unperturbed history preceding the window is identical for every
    # sample, so slice it once.
    history = test_dataset[:prev]
    for text in texts:
        # Basic slicing returns a view that shares storage with test_dataset.
        # Clone it before the in-place write below. Otherwise each perturbation
        # would overwrite the global test data and corrupt later samples and
        # later explanations.
        window = test_dataset[prev:prev + len(text)].clone()
        perturbed = torch.as_tensor(np.asarray(text), dtype=window.dtype, device=window.device)
        # window[:, :, 0] has shape (len(text), batch); the perturbed series is 1-D.
        window[:, :, 0] = perturbed.unsqueeze(1)
        testdata = torch.cat([history, window], dim=0)
        score, _, _, _, _ = anomalyDetector.anomalyScore(args, model, testdata,
                                                         mean, cov,
                                                         score_predictor=None,
                                                         channel_idx=channel_idx)
        res.append([1 - score[-1], score[-1]])
    return np.array(res)

if __name__ == "__main__":
    # Script entry point. Loads the trained nyc_taxi RNN predictor, scores the
    # test series, and explains every step whose anomaly score reaches
    # `threshold` with LIME.
    # NOTE: `predict` reads test_dataset, args, model, mean, cov and
    # channel_idx as module globals, so they must stay at module scope here.

    TimeseriesData = preprocess_data.PickleDataLoad(data_type="nyc_taxi", filename="nyc_taxi.pkl",
                                                    augment_test_data=False)
    checkpoint = torch.load(str(Path('save', 'nyc_taxi', 'checkpoint', 'nyc_taxi').with_suffix('.pth')))
    args = checkpoint['args']

    # Override the stored training arguments for this explanation run.
    args.prediction_window_size = 10
    args.beta = 1.0
    args.save_fig = True
    args.compensate = False

    channel_idx = 0
    # Number of preceding steps included in each LIME explanation window
    # (the window covers steps i - lime_context .. i, clipped at 0).
    lime_context = 20

    nfeatures = TimeseriesData.trainData.size(-1)

    # Rebinds the name `model` from the imported module to the model instance;
    # `predict` relies on the global `model` being the instance.
    model = model.RNNPredictor(rnn_type=args.model,
                               enc_inp_size=nfeatures,
                               rnn_inp_size=args.emsize,
                               rnn_hid_size=args.nhid,
                               dec_out_size=nfeatures,
                               nlayers=args.nlayers,
                               res_connection=args.res_connection).to(args.device)
    model.load_state_dict(checkpoint['state_dict'])
    train_dataset = TimeseriesData.batchify(args, TimeseriesData.trainData[:TimeseriesData.length], bsz=1)
    test_dataset = TimeseriesData.batchify(args, TimeseriesData.testData, bsz=1)

    # Reuse the error distribution stored in the checkpoint when available;
    # otherwise fit it on the training data.
    if 'means' in checkpoint.keys() and 'covs' in checkpoint.keys():
        print('=> loading pre-calculated mean and covariance')
        mean, cov = checkpoint['means'][channel_idx], checkpoint['covs'][channel_idx]
    else:
        print('=> calculating mean and covariance')
        mean, cov = anomalyDetector.fit_norm_distribution_param(args, model, train_dataset, channel_idx=channel_idx)

    explainer = lime_timeseries.LimeTimeseriesExplainer()

    score, _, _, _, _ = anomalyDetector.anomalyScore(args, model, test_dataset, mean, cov,
                                                     score_predictor=None,
                                                     channel_idx=channel_idx)

    for i, sco in enumerate(score):
        if sco < threshold:
            continue
        print(f'{i}:')
        start = max(i - lime_context, 0)
        started_at = time.perf_counter()
        # Feature-0 values of the window ending at step i, as a 1-D array.
        instance = np.array(test_dataset[start:i + 1, :, 0].cpu()).squeeze(1)
        # Bind `start` explicitly so the classifier does not depend on the
        # loop variable's later values.
        exp = explainer.explain_instance(instance,
                                         lambda x, start=start: predict(start, x),
                                         num_samples=1000)
        print(exp.as_list())
        print(f'time: {time.perf_counter() - started_at}')