import myanchor
from myanchor import anchor_timeseries
from myanchor import anchor_timeseries_temporary
import anomalyDetector
import preprocess_data
from model import model
import torch
from pathlib import Path
import numpy as np
import time
import os
from io import BytesIO
import argparse
import re

# Anomaly-score threshold: a window whose last anomaly score is >= this value
# is labelled anomalous. The --threshold CLI option defaults to this value.
threshold = 10000

parser = argparse.ArgumentParser(description='explain anomaly detection with anchor')

parser.add_argument('--datatype', type=str, default='nyc_taxi',
                    help='type of the input data')
# Pickled dataset file name; its stem also locates the model checkpoint
# (save/<datatype>/checkpoint/<stem>.pth).
parser.add_argument('--filename', type=str, default='nyc_taxi.pkl',
                    help='checkpoint of model')
# Default mirrors the module constant above; it previously said 100, which
# disagreed with the threshold the script actually used.
parser.add_argument('--threshold', type=int, default=threshold,
                    help='the threshold of anomaly detection')


def get_predict(prev):
    """Build a batch classifier for anchor explanations of a window starting at ``prev``.

    The model is first run once over ``test_dataset[:prev]`` through
    ``anomalyDetector.anomalyScoreResume``. The state it returns (hidden,
    predictions, rearranged, errors, hiddens, predicted_scores) is then passed
    back into every per-sample call, presumably so the prefix need not be
    recomputed (anomalyScoreResume itself is not visible here).

    Relies on the module-level globals ``test_dataset``, ``args``, ``mymodel``,
    ``mean``, ``cov``, ``channel_idx`` and ``threshold`` set up in ``__main__``.

    Args:
        prev: index into ``test_dataset`` where the explained window starts.

    Returns:
        ``predict(texts)``: maps an iterable of perturbed windows to an
        ``np.ndarray`` of 0/1 labels (1 when the last anomaly score is
        ``>= threshold``).
    """
    prefix = test_dataset[:prev]
    # In-memory scratch buffer handed to anomalyScoreResume as ``tempfilename``.
    file = BytesIO()
    hidden, predictions, _, rearranged, errors, hiddens, predicted_scores = \
        anomalyDetector.anomalyScoreResume(args, mymodel, prefix, mean, cov,
                                           score_predictor=None,
                                           channel_idx=channel_idx,
                                           tempfilename=file)

    def predict(texts):
        res = []
        print(len(texts))
        tt = time.time()
        for text in texts:
            # Perturbed samples may contain None entries; drop them.
            text = list(filter(lambda x: x is not None, text))
            file.seek(0, 0)
            # clone(): slicing a tensor returns a view, so writing the
            # perturbation into it would silently overwrite test_dataset and
            # corrupt every later window taken from it.
            window = test_dataset[prev:len(text) + prev].clone()
            perturbed = torch.tensor(np.array(text)).unsqueeze(1).unsqueeze(1)
            window[:, :, 0] = perturbed[:, :, 0]
            testdata = torch.cat([prefix, window], dim=0)
            # Copies keep the cached prefix state intact across samples.
            _, _, score, _, _, _, _ = anomalyDetector.anomalyScoreResume(
                args, mymodel, testdata, mean, cov,
                init_hidden=0,
                pasthidden=hidden,
                predictions=predictions.copy(),
                rearranged=rearranged.copy(),
                errors=errors.copy(),
                hiddens=hiddens.copy(),
                predicted_scores=predicted_scores.copy(),
                score_predictor=None,
                channel_idx=channel_idx,
                tempfilename=file,
            )
            score = torch.stack(score)
            res.append(1 if score[-1] >= threshold else 0)
        print(time.time() - tt)
        return np.array(res)

    return predict


def predict(prev, texts):
    """Label perturbed windows by re-scoring the whole sequence for each sample.

    Unlike the classifier returned by ``get_predict``, every sample calls
    ``anomalyDetector.anomalyScore`` on ``test_dataset[:prev]`` followed by the
    perturbed window, with no state reused between samples.

    Relies on the module-level globals ``test_dataset``, ``args``, ``mymodel``,
    ``mean``, ``cov``, ``channel_idx`` and ``threshold`` set up in ``__main__``.

    Args:
        prev: index into ``test_dataset`` where the explained window starts.
        texts: iterable of perturbed windows. ``None`` entries are dropped,
            matching ``get_predict``.

    Returns:
        ``np.ndarray`` of 0/1 labels, 1 when the last anomaly score is
        ``>= threshold``.
    """
    prefix = test_dataset[:prev]
    res = []
    for text in texts:
        text = [x for x in text if x is not None]
        # clone(): the slice is a view; writing into it would modify test_dataset.
        window = test_dataset[prev:len(text) + prev].clone()
        perturbed = torch.tensor(np.array(text)).unsqueeze(1).unsqueeze(1)
        window[:, :, 0] = perturbed[:, :, 0]
        testdata = torch.cat([prefix, window], dim=0)
        score, _, _, _, _ = anomalyDetector.anomalyScore(args, mymodel, testdata, mean, cov,
                                                         score_predictor=None,
                                                         channel_idx=channel_idx)
        print(score)
        res.append(1 if score[-1] >= threshold else 0)
    return np.array(res)

def gen_rule(len):
    """Enumerate the candidate rule triples for a window of ``len`` positions.

    For each start index ``j`` in increasing order, this emits first the
    triples ``[j, j, k]`` for ``k`` in ``0..j``. It then emits ``[j, k, x]``
    for every ``k`` in ``j+1..len-1`` and ``x`` in ``0..k-j-1``. What a
    triple means is decided by the explainer that consumes them
    (``explain_instance(rules=...)``), which is not visible here.

    Args:
        len: number of positions. The name shadows the built-in ``len`` and is
            kept only so keyword callers keep working.

    Returns:
        A list of ``[int, int, int]`` triples; empty when ``len <= 0``.
    """
    size = len
    rules = []
    for start in range(size):
        rules.extend([start, start, step] for step in range(start + 1))
        rules.extend([start, end, step]
                     for end in range(start + 1, size)
                     for step in range(end - start))
    return rules


if __name__ == "__main__":

    # Catalogue of the pickled datasets available per data type. Kept for
    # reference (e.g. to loop over every dataset); the run below only uses the
    # --datatype / --filename given on the command line.
    ecg_filenames = ("chfdb_chf01_275.pkl",
                     "chfdb_chf13_45590.pkl",
                     "chfdbchf15.pkl",
                     "ltstdb_20221_43.pkl",
                     "ltstdb_20321_240.pkl",
                     "mitdb__100_180.pkl",
                     "qtdbsel102.pkl",
                     "stdb_308_0.pkl",
                     "xmitdb_x108_0.pkl"
                     )
    respiration_filenames = ("nprs43.pkl",
                             "nprs44.pkl"
                             )
    space_shuttle_filenames = ("TEK14.pkl",
                               "TEK16.pkl",
                               "TEK17.pkl")
    gesture_filenames = ("ann_gun_CentroidA.pkl",)
    power_demand_filenames = ("power_data.pkl",)
    nyc_taxi_filenames = ("nyc_taxi.pkl",)

    # Named so it does not shadow the built-in ``dict`` at module scope.
    filenames_by_datatype = {
        'ecg': ecg_filenames,
        'respiration': respiration_filenames,
        'space_shuttle': space_shuttle_filenames,
        'gesture': gesture_filenames,
        'power_demand': power_demand_filenames,
        'nyc_taxi': nyc_taxi_filenames,
    }

    # The parsed --threshold used to be ignored. Default it to the module
    # constant so runs without the flag keep the same threshold. Rebind the
    # global so get_predict/predict, which read ``threshold`` at call time,
    # also see an explicitly supplied value.
    parser.set_defaults(threshold=threshold)
    args_ = parser.parse_args()
    threshold = args_.threshold

    datatype = args_.datatype
    filename = args_.filename
    print(datatype)
    print(filename)
    print(threshold)

    TimeseriesData = preprocess_data.PickleDataLoad(data_type=datatype, filename=filename,
                                                    augment_test_data=False)

    # Checkpoint path: save/<datatype>/checkpoint/<filename stem>.pth
    checkpoint = torch.load(str(Path('save', datatype, 'checkpoint', filename).with_suffix('.pth')))
    args = checkpoint['args']

    args.prediction_window_size = 10
    args.beta = 1.0
    args.save_fig = True
    args.compensate = False

    channel_idx = 0

    nfeatures = TimeseriesData.trainData.size(-1)

    mymodel = model.RNNPredictor(rnn_type=args.model,
                                 enc_inp_size=nfeatures,
                                 rnn_inp_size=args.emsize,
                                 rnn_hid_size=args.nhid,
                                 dec_out_size=nfeatures,
                                 nlayers=args.nlayers,
                                 res_connection=args.res_connection).to(args.device)
    mymodel.load_state_dict(checkpoint['state_dict'])
    train_dataset = TimeseriesData.batchify(args, TimeseriesData.trainData[:TimeseriesData.length], bsz=1)
    test_dataset = TimeseriesData.batchify(args, TimeseriesData.testData, bsz=1)

    if 'means' in checkpoint.keys() and 'covs' in checkpoint.keys():
        print('=> loading pre-calculated mean and covariance')
        mean, cov = checkpoint['means'][channel_idx], checkpoint['covs'][channel_idx]
    else:
        print('=> calculating mean and covariance')
        mean, cov = anomalyDetector.fit_norm_distribution_param(args, mymodel, train_dataset, channel_idx=channel_idx)

    # Number of points, ending at the flagged index, that each explanation covers.
    window_size = 20

    save_path = 'ReX_anchor_res2'
    out_dir = Path(save_path, datatype)
    # Replaces two os.mkdir calls wrapped in bare ``except: pass``.
    out_dir.mkdir(parents=True, exist_ok=True)
    explainer = anchor_timeseries_temporary.AnchorTimeseries()

    # Results go to '<filename>.txt' (e.g. 'nyc_taxi.pkl.txt'), the same name the
    # former ``.with_suffix('.txt')`` expression produced. ``with`` makes sure
    # the file is closed even if an explanation raises.
    with open(out_dir / (filename + '.txt'), 'w') as f:
        score, sorted_prediction, sorted_error, _, predicted_score = anomalyDetector.anomalyScore(
            args, mymodel, test_dataset, mean, cov,
            score_predictor=None,
            channel_idx=channel_idx)

        for i, sco in enumerate(score):
            if sco >= threshold:
                start = max(i - window_size + 1, 0)
                npredict = get_predict(start)
                print(str(i) + ':')
                f.write("%d:\n" % i)
                b = time.time()

                text = np.array(test_dataset[start:i + 1, :, 0].cpu()).squeeze(1)
                print(text)
                rules = gen_rule(len(text) - 1)
                exp = explainer.explain_instance(text=text,
                                                 classifier_fn=npredict, threshold=0.99, rules=rules)
                # exp.names() yields indices into ``rules``.
                out_exp = [rules[x] for x in exp.names()]
                print(out_exp)
                f.write(str(out_exp) + '\n')
                precision = exp.precision()
                print(precision)
                f.write(str(precision) + '\n')
                elapsed = time.time() - b
                print('time: ' + str(elapsed))
                f.write('time: %f\n' % elapsed)
                f.write('\n')
                f.flush()

        f.write('\n')
