import myanchor
from myanchor import anchor_timeseries
import anomalyDetector
import preprocess_data
from model import model
import torch
from pathlib import Path
import numpy as np
import time
import os
from io import BytesIO

# Anomaly-score cut-off. A time step whose score is >= threshold counts as
# anomalous. It is used to pick the steps to explain in the __main__ block
# and to label perturbed windows in get_predict() / predict().
threshold = 100

def get_predict(prev):
    """Build a batch classifier used by the anchor explainer for one window.

    The unperturbed prefix ``test_dataset[:prev]`` is scored once here with
    ``anomalyDetector.anomalyScoreResume``. The returned closure then calls
    ``anomalyScoreResume`` again for each perturbed window, passing the cached
    prefix results (``pasthidden=hidden``, ``predictions``, ...), presumably so
    the model does not re-run the prefix for every sample.

    Reads module-level globals set by the ``__main__`` block:
    ``test_dataset``, ``args``, ``mymodel``, ``mean``, ``cov`` and
    ``channel_idx``.

    Args:
        prev: Index of the first time step of the window being explained.

    Returns:
        ``predict(texts)``: maps a sequence of perturbed windows (each a 1-D
        array that replaces feature 0 of ``test_dataset[prev:prev+len]``) to
        a numpy array of 0/1 labels. A label is 1 when the anomaly score at
        the window's last step is ``>= threshold``.
    """
    prefix = test_dataset[:prev]
    # Score the clean prefix once; every perturbed sample resumes from it.
    hidden, predictions, _, rearranged, errors, hiddens, predicted_scores = \
        anomalyDetector.anomalyScoreResume(args, mymodel, prefix, mean, cov,
                                           score_predictor=None,
                                           channel_idx=channel_idx)

    def predict(texts):
        res = []
        print(len(texts))
        tt = time.time()
        for text in texts:
            # clone(): a basic slice is a view. Writing the perturbation into
            # it would overwrite test_dataset itself and corrupt later
            # prefixes and the windows explained afterwards.
            window = test_dataset[prev:len(text) + prev].clone()
            perturbed = torch.tensor(np.array(text)).unsqueeze(1).unsqueeze(1)
            # Assumes test_dataset is (time, batch, feature) as indexed in
            # __main__; only feature 0 is perturbed.
            window[:, :, 0] = perturbed[:, :, 0]
            testdata = torch.cat([prefix, window], dim=0)
            _, _, score, _, _, _, _ = anomalyDetector.anomalyScoreResume(
                args, mymodel, testdata, mean, cov,
                init_hidden=0,
                pasthidden=hidden,
                # Pass copies so each sample starts from the same cached
                # prefix results, even if the callee appends to them.
                predictions=predictions.copy(),
                rearranged=rearranged.copy(),
                errors=errors.copy(),
                hiddens=hiddens.copy(),
                predicted_scores=predicted_scores.copy(),
                score_predictor=None,
                channel_idx=channel_idx,
            )
            res.append(1 if score[-1] >= threshold else 0)
        print(time.time() - tt)
        return np.array(res)

    return predict


def predict(prev, texts):
    """Label perturbed windows by re-scoring the whole prefix each time.

    This is the non-incremental counterpart of ``get_predict``. For every
    window it concatenates ``test_dataset[:prev]`` with the perturbed window
    and runs ``anomalyDetector.anomalyScore`` on the result.

    Reads module-level globals set by the ``__main__`` block:
    ``test_dataset``, ``args``, ``mymodel``, ``mean``, ``cov`` and
    ``channel_idx``.

    Args:
        prev: Index of the first time step of the window being explained.
        texts: Sequence of perturbed windows. Each is a 1-D array that
            replaces feature 0 of ``test_dataset[prev:prev+len]``.

    Returns:
        numpy array of 0/1 labels, 1 when the score at the last step is
        ``>= threshold``.
    """
    prefix = test_dataset[:prev]
    res = []
    for text in texts:
        # clone(): the slice is a view; without it the perturbation would be
        # written back into the global test_dataset.
        window = test_dataset[prev:len(text) + prev].clone()
        perturbed = torch.tensor(np.array(text)).unsqueeze(1).unsqueeze(1)
        window[:, :, 0] = perturbed[:, :, 0]
        testdata = torch.cat([prefix, window], dim=0)
        score, _, _, _, _ = anomalyDetector.anomalyScore(args, mymodel, testdata, mean, cov,
                                                         score_predictor=None,
                                                         channel_idx=channel_idx)
        print(score)
        res.append(1 if score[-1] >= threshold else 0)
    return np.array(res)

if __name__ == "__main__":
    ecg_filenames = ("chfdb_chf01_275.pkl",
                     "chfdb_chf13_45590.pkl",
                     "chfdbchf15.pkl",
                     "ltstdb_20221_43.pkl",
                     "ltstdb_20321_240.pkl",
                     "mitdb__100_180.pkl",
                     "qtdbsel102.pkl",
                     "stdb_308_0.pkl",
                     "xmitdb_x108_0.pkl"
                     )
    respiration_filenames = ("nprs43.pkl",
                             "nprs44.pkl"
                             )
    space_shuttle_filenames = ("TEK14.pkl",
                               "TEK16.pkl",
                               "TEK17.pkl")
    gesture_filenames = ("ann_gun_CentroidA.pkl",)
    power_demand_filenames = ("power_data.pkl",)
    nyc_taxi_filenames = ("nyc_taxi.pkl",)

    # Data type -> pickle files to explain. Named `datasets` rather than
    # `dict` so the builtin is not shadowed.
    datasets = {
        'ecg': ecg_filenames,
        'respiration': respiration_filenames,
        'space_shuttle': space_shuttle_filenames,
        'gesture': gesture_filenames,
        'power_demand': power_demand_filenames,
        'nyc_taxi': nyc_taxi_filenames,
    }

    save_path = 'anchor_res'
    # Length of the window (ending at the anomalous step) handed to the explainer.
    window_size = 20

    # NOTE: this deliberately runs at module level. get_predict() and
    # predict() read test_dataset, args, mymodel, mean, cov and channel_idx
    # as module globals, so do not wrap this loop in a function.
    for datatype, filenames in datasets.items():
        for filename in filenames:
            print(datatype)
            print(filename)

            out_dir = Path(save_path, datatype)
            # Replaces a bare `except: pass` around os.mkdir, which also hid
            # a missing save_path and then failed in open().
            out_dir.mkdir(parents=True, exist_ok=True)

            TimeseriesData = preprocess_data.PickleDataLoad(data_type=datatype, filename=filename,
                                                            augment_test_data=False)

            checkpoint = torch.load(str(Path('save', datatype, 'checkpoint', filename).with_suffix('.pth')))
            args = checkpoint['args']

            # Override some of the checkpoint's stored settings for this run.
            args.prediction_window_size = 10
            args.beta = 1.0
            args.save_fig = True
            args.compensate = False

            channel_idx = 0

            nfeatures = TimeseriesData.trainData.size(-1)

            mymodel = model.RNNPredictor(rnn_type=args.model,
                                         enc_inp_size=nfeatures,
                                         rnn_inp_size=args.emsize,
                                         rnn_hid_size=args.nhid,
                                         dec_out_size=nfeatures,
                                         nlayers=args.nlayers,
                                         res_connection=args.res_connection).to(args.device)
            mymodel.load_state_dict(checkpoint['state_dict'])
            train_dataset = TimeseriesData.batchify(args, TimeseriesData.trainData[:TimeseriesData.length], bsz=1)
            test_dataset = TimeseriesData.batchify(args, TimeseriesData.testData, bsz=1)

            if 'means' in checkpoint and 'covs' in checkpoint:
                print('=> loading pre-calculated mean and covariance')
                mean, cov = checkpoint['means'][channel_idx], checkpoint['covs'][channel_idx]
            else:
                print('=> calculating mean and covariance')
                mean, cov = anomalyDetector.fit_norm_distribution_param(args, mymodel, train_dataset, channel_idx=channel_idx)

            explainer = anchor_timeseries.AnchorTimeseries()

            score, _, _, _, _ = anomalyDetector.anomalyScore(args, mymodel, test_dataset, mean, cov,
                                                             score_predictor=None,
                                                             channel_idx=channel_idx)

            # Output path is "<save_path>/<datatype>/<filename>.txt",
            # e.g. anchor_res/nyc_taxi/nyc_taxi.pkl.txt.
            with open(str(out_dir / (filename + '.txt')), 'w') as f:
                for i, sco in enumerate(score):
                    if sco >= threshold:
                        start = max(i - window_size + 1, 0)
                        npredict = get_predict(start)
                        print(str(i) + ':')
                        f.write("%d:\n" % i)
                        b = time.time()
                        # Feature 0 of the window ending at step i, as a 1-D array.
                        instance = np.array(test_dataset[start:i + 1, :, 0].cpu()).squeeze(1)
                        print(instance)
                        exp = explainer.explain_instance(instance, npredict, threshold=0.99)
                        names = exp.names()
                        print(names)
                        f.write(str(names) + '\n')
                        precision = exp.precision()
                        print(precision)
                        f.write(str(precision) + '\n')
                        elapsed = time.time() - b
                        print('time: ' + str(elapsed))
                        f.write('time: %f\n' % elapsed)
                        f.write('\n')
                        # Flush per anomaly so partial results survive a crash.
                        f.flush()

                f.write('\n')
