import mylime
import argparse
from mylime import lime_timeseries
import anomalyDetector
import preprocess_data
from model import model
import torch
from pathlib import Path
import numpy as np
import time
from io import BytesIO
import os
# Anomaly-score cut-off shared by predict(), get_predict() and the __main__ loop.
# NOTE(review): this constant, not the --threshold option below, is what those
# comparisons use -- the parsed CLI value is never assigned to it.
threshold = 10000

# One (flag, type, default, help) entry per command-line option.
# --filename is used both as the dataset pickle name and (with a .pth suffix)
# as the checkpoint name under save/<datatype>/checkpoint/.
_CLI_OPTIONS = (
    ('--datatype', str, 'nyc_taxi', 'type of the input data'),
    ('--filename', str, 'nyc_taxi.pkl', 'checkpoint of model'),
    ('--threshold', int, 100, 'the threshold of anomaly detection'),
)

parser = argparse.ArgumentParser(description='explain anomaly detection with anchor')
for _flag, _type, _default, _help in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, default=_default, help=_help)
del _flag, _type, _default, _help


def predict(prev, texts):
    """Classify perturbed windows as normal/anomalous by re-scoring from scratch.

    For each perturbed window, channel 0 of ``test_dataset[prev:prev + len(text)]``
    is replaced (on a copy) by the window's values. The untouched history
    ``test_dataset[:prev]`` is prepended, and the whole sequence is scored with
    ``anomalyDetector.anomalyScore``.

    Relies on module globals assigned in ``__main__``: ``test_dataset``, ``args``,
    ``model``, ``mean``, ``cov``, ``channel_idx`` and ``threshold``.

    Args:
        prev: index in ``test_dataset`` at which the explained window starts.
        texts: iterable of perturbed windows. Each is presumably a 1-D sequence of
            channel-0 values (it is reshaped with two ``unsqueeze(1)`` calls);
            TODO confirm against the LIME explainer.

    Returns:
        ``np.ndarray`` with one ``[p_normal, p_anomaly]`` row per window. The row is
        ``[0, 1]`` when the last element of the returned score is ``>= threshold``,
        and ``[1, 0]`` otherwise.
    """
    # Prefix preceding the explained window; never modified, so slice it once.
    history = test_dataset[:prev]
    res = []
    for text in texts:
        # clone(): basic slicing returns a view, so writing into it would
        # overwrite the global test_dataset and corrupt later explanations.
        window = test_dataset[prev:prev + len(text)].clone()
        perturbed = torch.tensor(np.array(text)).unsqueeze(1).unsqueeze(1)
        window[:, :, 0] = perturbed[:, :, 0]
        testdata = torch.cat([history, window], dim=0)
        score, _, _, _, _ = anomalyDetector.anomalyScore(args, model, testdata,
                                                         mean, cov,
                                                         score_predictor=None,
                                                         channel_idx=channel_idx)
        is_anomaly = 1 if score[-1] >= threshold else 0
        res.append([1 - is_anomaly, is_anomaly])
    return np.array(res)


def get_predict(prev):
    """Build a LIME prediction function that resumes scoring after a fixed prefix.

    ``test_dataset[:prev]`` is scored once with ``anomalyDetector.anomalyScoreResume``.
    The returned state is reused as the starting point for every perturbed window:
    the hidden state plus copies of predictions, rearranged, errors, hiddens and
    predicted_scores. The same in-memory buffer is also reused, rewound before
    each call. Presumably this avoids re-running the model over the prefix for
    each LIME sample; TODO confirm against anomalyScoreResume.

    Relies on module globals assigned in ``__main__``: ``test_dataset``, ``args``,
    ``model``, ``mean``, ``cov``, ``channel_idx`` and ``threshold``.

    Args:
        prev: index in ``test_dataset`` at which the explained window starts.

    Returns:
        A callable ``predict_fn(texts) -> np.ndarray`` with one
        ``[p_normal, p_anomaly]`` row per perturbed window. The row is ``[0, 1]``
        when the last stacked score is ``>= threshold``, and ``[1, 0]`` otherwise.
    """
    history = test_dataset[:prev]
    # In-memory stand-in for the temp file anomalyScoreResume expects; the same
    # buffer is handed to every resumed call below.
    state_file = BytesIO()
    (hidden, predictions, _, rearranged, errors, hiddens,
     predicted_scores) = anomalyDetector.anomalyScoreResume(args, model, history,
                                                            mean, cov,
                                                            score_predictor=None,
                                                            channel_idx=channel_idx,
                                                            tempfilename=state_file)
    state_file.seek(0, 0)

    def predict_fn(texts):
        res = []
        print(len(texts))
        tt = time.time()
        for text in texts:
            # Drop missing entries (None) from the perturbed window.
            text = [x for x in text if x is not None]
            state_file.seek(0, 0)
            # clone(): basic slicing returns a view, so writing into it would
            # overwrite the global test_dataset and corrupt later explanations.
            window = test_dataset[prev:prev + len(text)].clone()
            perturbed = torch.tensor(np.array(text)).unsqueeze(1).unsqueeze(1)
            window[:, :, 0] = perturbed[:, :, 0]
            testdata = torch.cat([history, window], dim=0)
            # Copies keep the cached prefix state intact across samples.
            _, _, score, _, _, _, _ = anomalyDetector.anomalyScoreResume(
                args, model, testdata, mean, cov,
                init_hidden=0,
                pasthidden=hidden,
                predictions=predictions.copy(),
                rearranged=rearranged.copy(),
                errors=errors.copy(),
                hiddens=hiddens.copy(),
                predicted_scores=predicted_scores.copy(),
                score_predictor=None,
                channel_idx=channel_idx,
                tempfilename=state_file,
            )
            score = torch.stack(score)
            is_anomaly = 1 if score[-1] >= threshold else 0
            res.append([1 - is_anomaly, is_anomaly])
        print(time.time() - tt)
        return np.array(res)

    return predict_fn



if __name__ == "__main__":

    # Pickled datasets known for each data type. Only consumed by a sweep over
    # every dataset (currently not wired up); kept as a catalogue.
    ecg_filenames = ("chfdb_chf01_275.pkl",
                     "chfdb_chf13_45590.pkl",
                     "chfdbchf15.pkl",
                     "ltstdb_20221_43.pkl",
                     "ltstdb_20321_240.pkl",
                     "mitdb__100_180.pkl",
                     "qtdbsel102.pkl",
                     "stdb_308_0.pkl",
                     "xmitdb_x108_0.pkl"
                     )
    respiration_filenames = ("nprs43.pkl",
                             "nprs44.pkl"
                             )
    space_shuttle_filenames = ("TEK14.pkl",
                               "TEK16.pkl",
                               "TEK17.pkl")
    gesture_filenames = ("ann_gun_CentroidA.pkl",)
    power_demand_filenames = ("power_data.pkl",)
    nyc_taxi_filenames = ("nyc_taxi.pkl",)

    # Named so that it does not shadow the builtin ``dict``.
    filenames_by_datatype = {
        'ecg': ecg_filenames,
        'respiration': respiration_filenames,
        'space_shuttle': space_shuttle_filenames,
        'gesture': gesture_filenames,
        'power_demand': power_demand_filenames,
        'nyc_taxi': nyc_taxi_filenames,
    }

    args_ = parser.parse_args()
    datatype = args_.datatype
    filename = args_.filename
    # NOTE(review): --threshold is parsed but not applied. The module-level
    # ``threshold`` constant is what predict()/get_predict() and the loop below
    # compare against. Assigning ``threshold = args_.threshold`` here would also
    # change the default cut-off (CLI default 100 vs. constant 10000).

    save_path = 'lime_res2'
    # Create lime_res2/ and lime_res2/<datatype>/, tolerating existing directories.
    os.makedirs(Path(save_path, datatype), exist_ok=True)
    # The report keeps the dataset's own extension, e.g. nyc_taxi.pkl.txt.
    result_path = Path(save_path, datatype, filename + '.txt')

    TimeseriesData = preprocess_data.PickleDataLoad(data_type=datatype, filename=filename,
                                                    augment_test_data=False)

    checkpoint = torch.load(str(Path('save', datatype, 'checkpoint', filename).with_suffix('.pth')))
    args = checkpoint['args']

    # Override a few of the checkpoint's stored settings for this run.
    args.prediction_window_size = 10
    args.beta = 1.0
    args.save_fig = True
    args.compensate = False

    channel_idx = 0

    nfeatures = TimeseriesData.trainData.size(-1)

    # Rebinds the module-level name ``model`` (the imported module) to the network
    # instance; predict()/get_predict() read this global.
    model = model.RNNPredictor(rnn_type=args.model,
                               enc_inp_size=nfeatures,
                               rnn_inp_size=args.emsize,
                               rnn_hid_size=args.nhid,
                               dec_out_size=nfeatures,
                               nlayers=args.nlayers,
                               res_connection=args.res_connection).to(args.device)
    model.load_state_dict(checkpoint['state_dict'])
    train_dataset = TimeseriesData.batchify(args, TimeseriesData.trainData[:TimeseriesData.length], bsz=1)
    test_dataset = TimeseriesData.batchify(args, TimeseriesData.testData, bsz=1)

    if 'means' in checkpoint and 'covs' in checkpoint:
        print('=> loading pre-calculated mean and covariance')
        mean, cov = checkpoint['means'][channel_idx], checkpoint['covs'][channel_idx]
    else:
        print('=> calculating mean and covariance')
        mean, cov = anomalyDetector.fit_norm_distribution_param(args, model, train_dataset, channel_idx=channel_idx)

    explainer = lime_timeseries.LimeTimeseriesExplainer()

    score, sorted_prediction, sorted_error, _, predicted_score = anomalyDetector.anomalyScore(
        args, model, test_dataset, mean, cov,
        score_predictor=None,
        channel_idx=channel_idx)

    # Number of time steps, ending at (and including) an anomalous point,
    # that LIME is asked to explain.
    window_size = 20

    # ``with`` guarantees the report is closed (and flushed) even on error.
    with open(result_path, 'w') as f:
        for i, sco in enumerate(score):
            if sco >= threshold:
                print(str(i) + ':')
                start = max(i - window_size + 1, 0)
                npredict = get_predict(start)
                f.write("%d:\n" % i)
                b = time.time()
                instance = np.array(test_dataset[start:i + 1, :, 0].cpu()).squeeze(1)
                exp = explainer.explain_instance(instance, npredict,
                                                 num_samples=500, num_features=20)
                explanation = exp.as_list()
                print(explanation)
                f.write(str(explanation) + '\n')
                elapsed = time.time() - b
                print('time: ' + str(elapsed))
                f.write('time: %f\n' % elapsed)
                # Flush per anomaly so partial results survive an interrupted run.
                f.flush()
