import myanchor
from myanchor import anchor_timeseries
from myanchor import anchor_timeseries_temporary
import anomalyDetector
import preprocess_data
from model import model
import torch
from pathlib import Path
import numpy as np
import time
import os
from io import BytesIO
import argparse


# Score at or above which the prediction functions label the last step of a
# window as anomalous.  They read this module global at call time; the
# __main__ block reassigns it before any explanation runs.
threshold = 100


def _build_parser():
    """Return the command-line parser for this script.

    ``--filename`` is used both as the pickle name handed to the data loader
    and as the stem of the ``.pth`` checkpoint under ``save/<datatype>/``.
    """
    cli = argparse.ArgumentParser(
        description='explain anomaly detection with anchor',
    )
    cli.add_argument(
        '--datatype',
        type=str,
        default='nyc_taxi',
        help='type of the input data',
    )
    cli.add_argument(
        '--filename',
        type=str,
        default='nyc_taxi.pkl',
        help='checkpoint of model',
    )
    cli.add_argument(
        '--threshold',
        type=int,
        default=100,
        help='the threshold of anomaly detection',
    )
    return cli


parser = _build_parser()


def get_predict(prev):
    """Build the batch classifier handed to the Anchor explainer.

    The model is run once over the fixed history ``test_dataset[:prev]`` with
    ``anomalyDetector.anomalyScoreResume``; the returned state (hidden state,
    predictions, errors, ...) is then reused for every perturbed window so
    only the window itself has to be re-scored.

    Reads the module globals ``args``, ``mymodel``, ``test_dataset``, ``mean``,
    ``cov``, ``channel_idx`` and ``threshold`` set up in the ``__main__`` block.

    Args:
        prev: index of the first time step of the explained window; the
            ``prev`` steps before it are kept unchanged.

    Returns:
        A function ``predict(texts)`` that maps a sequence of candidate windows
        (sequences of channel-0 values; ``None`` entries are dropped) to a
        numpy array holding 1 where the score of the last step is
        ``>= threshold`` and 0 otherwise.
    """
    prefix = test_dataset[:prev]
    # Scratch buffer passed to anomalyScoreResume as ``tempfilename``; it is
    # rewound before every reuse.
    file = BytesIO()
    hidden, predictions, scores, rearranged, errors, hiddens, predicted_scores = \
        anomalyDetector.anomalyScoreResume(args, mymodel, prefix, mean, cov,
                                           score_predictor=None,
                                           channel_idx=channel_idx,
                                           tempfilename=file)
    file.seek(0, 0)
    print(file.read())

    def predict(texts):
        res = []
        print(len(texts))
        tt = time.time()
        for text in texts:
            text = [x for x in text if x is not None]
            file.seek(0, 0)
            # Basic slicing returns a view of test_dataset; clone it so the
            # perturbed values do not overwrite the global series (which would
            # corrupt later samples, later windows and the explained instance).
            window = test_dataset[prev:prev + len(text)].clone()
            window[:, :, 0] = torch.as_tensor(np.array(text),
                                              dtype=window.dtype,
                                              device=window.device).unsqueeze(1)
            testdata = torch.cat([prefix, window], dim=0)
            # Each call gets copies of the cached state so the resumed run
            # cannot leak changes into the next sample.
            _, _, score, _, _, _, _ = anomalyDetector.anomalyScoreResume(
                args, mymodel, testdata, mean, cov,
                init_hidden=0,
                pasthidden=hidden,
                predictions=predictions.copy(),
                rearranged=rearranged.copy(),
                errors=errors.copy(),
                hiddens=hiddens.copy(),
                predicted_scores=predicted_scores.copy(),
                score_predictor=None,
                channel_idx=channel_idx,
                tempfilename=file,
            )
            score = torch.stack(score)
            res.append(1 if score[-1] >= threshold else 0)
        print(time.time() - tt)
        return np.array(res)

    return predict


def predict(prev, texts):
    """Label candidate windows by re-scoring the whole series for each one.

    Counterpart of the closure returned by ``get_predict`` that does not reuse
    any cached model state: each candidate is scored from the start of the
    series with ``anomalyDetector.anomalyScore``.  Not called by the
    ``__main__`` block, which uses ``get_predict``.

    Reads the module globals ``args``, ``mymodel``, ``test_dataset``, ``mean``,
    ``cov``, ``channel_idx`` and ``threshold``.

    Args:
        prev: index of the first time step of the window being replaced.
        texts: sequence of candidate windows (sequences of channel-0 values).

    Returns:
        numpy array with 1 where the score of the last step is
        ``>= threshold`` and 0 otherwise.
    """
    prefix = test_dataset[:prev]
    res = []
    for text in texts:
        # Basic slicing returns a view of test_dataset; clone it so writing
        # the candidate values does not overwrite the global series.
        window = test_dataset[prev:prev + len(text)].clone()
        window[:, :, 0] = torch.as_tensor(np.array(text),
                                          dtype=window.dtype,
                                          device=window.device).unsqueeze(1)
        testdata = torch.cat([prefix, window], dim=0)
        score, _, _, _, _ = anomalyDetector.anomalyScore(args, mymodel, testdata, mean, cov,
                                                         score_predictor=None,
                                                         channel_idx=channel_idx)
        print(score)
        res.append(1 if score[-1] >= threshold else 0)
    return np.array(res)

if __name__ == "__main__":
    # Script entry point: load data and checkpoint, score the test series,
    # pick the f-beta-optimal threshold, then run Anchor on the window ending
    # at every test point whose score passes the explanation threshold.
    #
    # Everything stays at module scope on purpose: get_predict() and predict()
    # read args, mymodel, test_dataset, mean, cov, channel_idx and threshold
    # as module globals.

    args_ = parser.parse_args()

    datatype = args_.datatype
    filename = args_.filename
    # NOTE(review): --threshold is parsed but never applied; the value printed
    # below is the module default and `threshold` is reassigned further down.

    print(datatype)
    print(filename)
    print(threshold)

    TimeseriesData = preprocess_data.PickleDataLoad(data_type=datatype, filename=filename,
                                                    augment_test_data=False)

    # Checkpoint path: save/<datatype>/checkpoint/<filename with .pth suffix>.
    checkpoint = torch.load(str(Path('save', datatype, 'checkpoint', filename).with_suffix('.pth')))
    args = checkpoint['args']

    # Override evaluation settings stored with the checkpoint.
    args.prediction_window_size = 10
    args.beta = 1.0
    args.save_fig = True
    args.compensate = False

    channel_idx = 0  # the channel that is scored and explained

    nfeatures = TimeseriesData.trainData.size(-1)

    mymodel = model.RNNPredictor(rnn_type=args.model,
                                 enc_inp_size=nfeatures,
                                 rnn_inp_size=args.emsize,
                                 rnn_hid_size=args.nhid,
                                 dec_out_size=nfeatures,
                                 nlayers=args.nlayers,
                                 res_connection=args.res_connection).to(args.device)
    mymodel.load_state_dict(checkpoint['state_dict'])
    train_dataset = TimeseriesData.batchify(args, TimeseriesData.trainData[:TimeseriesData.length], bsz=1)
    test_dataset = TimeseriesData.batchify(args, TimeseriesData.testData, bsz=1)

    if 'means' in checkpoint and 'covs' in checkpoint:
        print('=> loading pre-calculated mean and covariance')
        mean, cov = checkpoint['means'][channel_idx], checkpoint['covs'][channel_idx]
    else:
        print('=> calculating mean and covariance')
        mean, cov = anomalyDetector.fit_norm_distribution_param(args, mymodel, train_dataset,
                                                                channel_idx=channel_idx)

    save_path = 'anchor_res2'
    explainer = anchor_timeseries.AnchorTimeseries()
    # Alternative configuration: threshold = 150000, save_path = 'ReX_anchor_res',
    # explainer = anchor_timeseries_temporary.AnchorTimeseries()

    # Create save_path/<datatype> (and save_path itself) if missing; any other
    # OS error (e.g. permissions) is raised here rather than silently ignored.
    os.makedirs(Path(save_path, datatype), exist_ok=True)

    score, sorted_prediction, sorted_error, _, predicted_score = anomalyDetector.anomalyScore(
        args, mymodel, test_dataset, mean, cov,
        score_predictor=None,
        channel_idx=channel_idx)

    # 4. Evaluate the result.
    # Precision, recall and f_beta are computed repeatedly while sampling the
    # threshold from 1 to the maximum anomaly score, either equidistantly or
    # logarithmically.
    print('=> calculating precision, recall, and f_beta')
    precision, recall, f_beta, th = anomalyDetector.get_precision_recall(
        args, score, num_samples=1000, beta=args.beta,
        label=TimeseriesData.testLabel.to(args.device), threshold=True)

    # Print the indices of the labelled anomalies.
    labels = TimeseriesData.testLabel.to(args.device).cpu()
    for i, x in enumerate(labels):
        if x != 0:
            print(i)

    print('data: ', args.data, ' filename: ', args.filename,
          ' f-beta (no compensation): ', f_beta.max().item(), ' beta: ', args.beta)

    # Presumably th[k] is the threshold that produced f_beta[k]; pick the best.
    threshold = th[f_beta.argmax().item()]

    # Print the indices flagged as anomalous at that threshold.
    anomaly = (score > threshold).float()
    for i, x in enumerate(anomaly):
        if x != 0:
            print(i)
    print(score.max())

    explain_len = 20  # length of the explained window ending at point i

    with open(Path(save_path, datatype, filename + '.txt'), 'w') as f:
        f.write("threshold:" + str(threshold) + "\n")
        # NOTE(review): this discards the f-beta-optimal threshold recorded
        # above; both the choice of points to explain and the classifier built
        # by get_predict() (which reads this global) use 10000.  Looks like a
        # debugging override -- confirm.
        threshold = 10000
        for i, sco in enumerate(score):
            if sco >= threshold:
                start = max(i - explain_len + 1, 0)
                npredict = get_predict(start)
                print(str(i) + ':')
                f.write("%d:\n" % i)
                b = time.time()
                # np.array copies, so the instance is independent of test_dataset.
                instance = np.array(test_dataset[start:i + 1, :, 0].cpu()).squeeze(1)
                print(instance)
                exp = explainer.explain_instance(instance, npredict, threshold=0.99)
                print(exp.names())
                f.write(str(exp.names()) + '\n')
                print(exp.precision())
                f.write(str(exp.precision()) + '\n')
                elapsed = time.time() - b
                print('time: ' + str(elapsed))
                f.write('time: %f\n' % elapsed)
                f.write('\n')
                f.flush()  # keep completed explanations if the run is interrupted

        f.write('\n')
