import random

from myanchor import anchor_timeseries
import anomalyDetector
import preprocess_data
from model import model
import torch
from pathlib import Path
import numpy as np
import time
import os



# Anomaly-score cutoff used by the predictors in this file: a final score at or
# above this value is labelled 1, anything below it is labelled 0.
threshold = 10000


def get_predict(prev):
    """Build a batch classifier that resumes scoring after a fixed prefix.

    The first ``prev`` steps of the module-level ``test_dataset`` are scored
    once, up front, with ``anomalyDetector.anomalyScoreResume``. The values
    returned by that call are handed back to ``anomalyScoreResume`` for every
    candidate scored by the returned closure.

    Relies on the module-level names ``test_dataset``, ``args``, ``model``,
    ``mean``, ``cov``, ``channel_idx`` and ``threshold``.

    Args:
        prev: Number of leading steps of ``test_dataset`` that are kept
            unchanged in front of every candidate window.

    Returns:
        A function ``predict(texts)``. For each sequence in ``texts`` it
        writes the sequence into index 0 of the last axis of the window
        ``test_dataset[prev:prev + len(text)]``, scores the prefix plus that
        window, prints the last score, and labels the candidate 1 if that
        score is ``>= threshold`` and 0 otherwise. The labels are returned
        as a NumPy array, one entry per candidate.
    """
    prefix = test_dataset[:prev]
    # The per-step score list from the warm-up pass is not needed afterwards.
    hidden, predictions, _, rearranged, errors, hiddens, predicted_scores = \
        anomalyDetector.anomalyScoreResume(args, model, prefix, mean, cov,
                                           score_predictor=None,
                                           channel_idx=channel_idx)

    def predict(texts):
        labels = []
        for text in texts:
            # Basic slicing returns a view of test_dataset; clone it so that
            # writing the candidate values does not overwrite the shared data.
            window = test_dataset[prev:prev + len(text)].clone()
            candidate = torch.tensor(np.array(text)).unsqueeze(1).unsqueeze(1)
            window[:, :, 0] = candidate[:, :, 0]
            sequence = torch.cat([prefix, window], dim=0)

            # Each call receives its own shallow copy of the cached results
            # (presumably because anomalyScoreResume extends them in place --
            # confirm against anomalyDetector).
            _, _, score, _, _, _, _ = anomalyDetector.anomalyScoreResume(
                args, model, sequence, mean, cov,
                init_hidden=0,
                pasthidden=hidden,
                predictions=predictions.copy(),
                rearranged=rearranged.copy(),
                errors=errors.copy(),
                hiddens=hiddens.copy(),
                predicted_scores=predicted_scores.copy(),
                score_predictor=None,
                channel_idx=channel_idx)
            score = torch.stack(score)
            print(score[-1])
            # Only the score of the final step decides the label.
            labels.append(1 if score[-1] >= threshold else 0)
        return np.array(labels)

    return predict



def predict(prev, texts):
    """Label candidate windows by scoring the whole sequence from scratch.

    For each sequence in ``texts`` the window
    ``test_dataset[prev:prev + len(text)]`` is copied, index 0 of its last
    axis is replaced by the candidate values, and the first ``prev`` steps of
    ``test_dataset`` followed by that window are scored with
    ``anomalyDetector.anomalyScore``.

    Relies on the module-level names ``test_dataset``, ``args``, ``model``,
    ``mean``, ``cov``, ``channel_idx`` and ``threshold``.

    Args:
        prev: Number of leading steps of ``test_dataset`` kept unchanged in
            front of every candidate window.
        texts: Iterable of candidate sequences of values.

    Returns:
        NumPy array with one label per candidate: 1 if the last score
        returned by ``anomalyScore`` is ``>= threshold``, otherwise 0.
    """
    # The unchanged prefix is the same for every candidate.
    prefix = test_dataset[:prev]
    labels = []
    for text in texts:
        # Basic slicing returns a view of test_dataset; clone it so that
        # writing the candidate values does not overwrite the shared data.
        window = test_dataset[prev:prev + len(text)].clone()
        candidate = torch.tensor(np.array(text)).unsqueeze(1).unsqueeze(1)
        window[:, :, 0] = candidate[:, :, 0]
        sequence = torch.cat([prefix, window], dim=0)

        # Only the first of the five returned values (the scores) is used.
        score, _, _, _, _ = anomalyDetector.anomalyScore(
            args, model, sequence, mean, cov,
            score_predictor=None,
            channel_idx=channel_idx)

        # Only the score of the final step decides the label.
        labels.append(1 if score[-1] >= threshold else 0)
    return np.array(labels)



if __name__ == '__main__':
    # The dataset and its checkpoint are selected by these two fixed names.
    datatype = 'ecg'
    filename = 'chfdb_chf13_45590.pkl'

    TimeseriesData = preprocess_data.PickleDataLoad(
        data_type=datatype,
        filename=filename,
        augment_test_data=False,
    )

    checkpoint_path = Path('save', datatype, 'checkpoint', filename).with_suffix('.pth')
    checkpoint = torch.load(str(checkpoint_path))

    # Start from the arguments stored in the checkpoint, then override a few.
    args = checkpoint['args']
    args.prediction_window_size = 10
    args.beta = 1.0
    args.save_fig = True
    args.compensate = False

    channel_idx = 0
    nfeatures = TimeseriesData.trainData.size(-1)

    # NOTE: from this assignment on, ``model`` refers to the network instance
    # rather than the imported ``model`` module; the predictor functions in
    # this file read it under that name.
    model = model.RNNPredictor(
        rnn_type=args.model,
        enc_inp_size=nfeatures,
        rnn_inp_size=args.emsize,
        rnn_hid_size=args.nhid,
        dec_out_size=nfeatures,
        nlayers=args.nlayers,
        res_connection=args.res_connection,
    ).to(args.device)
    model.load_state_dict(checkpoint['state_dict'])

    train_dataset = TimeseriesData.batchify(
        args, TimeseriesData.trainData[:TimeseriesData.length], bsz=1)
    test_dataset = TimeseriesData.batchify(args, TimeseriesData.testData, bsz=1)

    exp = anchor_timeseries.AnchorTimeseries()

    # Reuse the statistics stored in the checkpoint when both are present,
    # otherwise fit them on the training data.
    has_stats = all(key in checkpoint.keys() for key in ('means', 'covs'))
    if has_stats:
        print('=> loading pre-calculated mean and covariance')
        mean = checkpoint['means'][channel_idx]
        cov = checkpoint['covs'][channel_idx]
    else:
        print('=> calculating mean and covariance')
        mean, cov = anomalyDetector.fit_norm_distribution_param(
            args, model, train_dataset, channel_idx=channel_idx)

    # Examine the 20 steps that end at (and include) index ``pos``.
    pos = 406
    window_len = 20
    start = pos - window_len + 1

    # np.array() copies, so ``text`` is independent of test_dataset.
    text = np.array(test_dataset[start:pos + 1, :, 0].cpu()).squeeze(1)
    print(list(text))

    npredict = get_predict(start)
    print(npredict([text]))
    words, positions, true_label, sample_fn = exp.get_sample_fn(text, npredict)

    # Sensitivity sweep, last position first: perturb one position at a time
    # and count how many of the perturbed windows are labelled 1.
    trials = 100
    for i in reversed(range(len(text))):
        flagged = 0
        print(i, end=':')
        for _ in range(trials):
            changed = text.copy()
            # Four candidate replacements, drawn in this order: original plus
            # normal noise, original plus uniform noise in [-1, 1), a plain
            # normal draw, and original plus uniform noise from ``random``.
            candidates = [
                text[i] + np.random.normal(),
                text[i] + np.random.random() * 2 - 1,
                np.random.normal(),
                text[i] + random.random() * 2 - 1,
            ]
            changed[i] = np.random.choice(candidates)

            labels = npredict([changed])
            flagged += labels[-1]
            # Pause for a second whenever the perturbed window is labelled 0.
            if labels[-1] == 0:
                time.sleep(1)
        print(flagged, trials, sep='/')




