from io import BytesIO
import time
from model import model
import preprocess_data
import torch
from pathlib import Path
import anomalyDetector
import numpy as np

class Predictor:
    """Binary anomaly classifier built on a trained ``RNNPredictor`` checkpoint.

    The checkpoint is read from ``save/<datatype>/checkpoint/<filename>.pth``
    and must contain the keys ``'args'``, ``'state_dict'``, ``'means'`` and
    ``'covs'``.  A candidate sequence is labelled 1 when the anomaly score of
    its last time step is ``>= threshold`` and 0 otherwise.

    NOTE(review): candidate values are always written into feature channel 0
    of the test data, while scoring uses ``channel_idx`` (and the matching
    entry of ``means``/``covs``).  Confirm this is intended when
    ``channel_idx != 0``.
    """

    def __init__(self, datatype, filename, threshold, channel_idx=0):
        """Load the dataset, checkpoint and model.

        Args:
            datatype: dataset family; passed to ``PickleDataLoad`` and used as
                a directory component of the checkpoint path.
            filename: dataset file name; also the checkpoint file stem.
            threshold: score at or above which a sequence is labelled 1.
            channel_idx: index into the checkpoint's ``means``/``covs`` lists;
                also forwarded as ``channel_idx`` to the scoring functions.
        """
        timeseries_data = preprocess_data.PickleDataLoad(
            data_type=datatype, filename=filename, augment_test_data=False
        )

        # NOTE(review): the checkpoint stores a non-tensor ``args`` object.
        # Recent PyTorch versions default ``torch.load`` to
        # ``weights_only=True``, which may refuse to unpickle it; only pass
        # ``weights_only=False`` for checkpoints you trust.
        checkpoint = torch.load(
            str(Path('save', datatype, 'checkpoint', filename).with_suffix('.pth'))
        )
        args = checkpoint['args']
        # Inference-time overrides of the settings saved with the checkpoint.
        args.prediction_window_size = 10
        args.beta = 1.0
        args.save_fig = True
        args.compensate = False
        self.args = args

        nfeatures = timeseries_data.trainData.size(-1)
        self.model = model.RNNPredictor(
            rnn_type=args.model,
            enc_inp_size=nfeatures,
            rnn_inp_size=args.emsize,
            rnn_hid_size=args.nhid,
            dec_out_size=nfeatures,
            nlayers=args.nlayers,
            res_connection=args.res_connection,
        ).to(args.device)
        self.model.load_state_dict(checkpoint['state_dict'])

        # Presumably (time, batch=1, features) after batchify -- TODO confirm.
        self.train_dataset = timeseries_data.batchify(
            args, timeseries_data.trainData[:timeseries_data.length], bsz=1
        )
        self.test_dataset = timeseries_data.batchify(args, timeseries_data.testData, bsz=1)
        self.channel_idx = channel_idx
        self.threshold = threshold
        self.mean = checkpoint['means'][self.channel_idx]
        self.cov = checkpoint['covs'][self.channel_idx]

    def _build_input(self, prev, values):
        """Return test rows ``[0, prev)`` followed by a window carrying ``values``.

        The window is rows ``prev .. prev + len(values)`` of
        ``self.test_dataset`` with feature channel 0 replaced by ``values``.
        The window is cloned first: slicing a tensor yields a view, so
        assigning into it directly would silently overwrite
        ``self.test_dataset`` and corrupt every later prediction.
        """
        window = self.test_dataset[prev:prev + len(values)].clone()
        replacement = torch.tensor(np.array(values)).unsqueeze(1).unsqueeze(1)
        window[:, :, 0] = replacement[:, :, 0]
        return torch.cat([self.test_dataset[:prev], window], dim=0)

    def get_predict(self, prev):
        """Return a batch classifier that resumes from a pre-scored prefix.

        The first ``prev`` test steps are scored once with
        ``anomalyDetector.anomalyScoreResume``.  The returned function passes
        the resulting hidden state and accumulators back in for every
        candidate, presumably so the prefix is not re-scored each time.

        Args:
            prev: number of leading test steps treated as fixed context.

        Returns:
            ``predict_fn(texts) -> np.ndarray``: for each candidate sequence
            of channel-0 values (``None`` entries are dropped), 1 if the
            final anomaly score is ``>= self.threshold``, else 0.
        """
        prefix = self.test_dataset[:prev]
        state_file = BytesIO()
        (hidden, predictions, scores, rearranged, errors, hiddens,
         predicted_scores) = anomalyDetector.anomalyScoreResume(
            self.args, self.model, prefix, self.mean, self.cov,
            score_predictor=None,
            channel_idx=self.channel_idx,
            tempfilename=state_file,
        )
        state_file.seek(0, 0)

        def predict_fn(texts):
            res = []
            print("predict")
            print(len(texts))
            for text in texts:
                values = [v for v in text if v is not None]
                # Rewind the shared buffer so each resume call sees it from the start.
                state_file.seek(0, 0)
                testdata = self._build_input(prev, values)
                _, _, score, _, _, _, _ = anomalyDetector.anomalyScoreResume(
                    self.args, self.model, testdata, self.mean, self.cov,
                    init_hidden=0,
                    pasthidden=hidden,
                    # Copies so every candidate starts from the same cached
                    # prefix state, even if the callee mutates them in place.
                    predictions=predictions.copy(),
                    rearranged=rearranged.copy(),
                    errors=errors.copy(),
                    hiddens=hiddens.copy(),
                    predicted_scores=predicted_scores.copy(),
                    score_predictor=None,
                    channel_idx=self.channel_idx,
                    tempfilename=state_file,
                )
                score = torch.stack(score)
                res.append(1 if score[-1] >= self.threshold else 0)
            return np.array(res)

        return predict_fn

    def predict(self, prev, texts):
        """Label each candidate by scoring prefix + candidate from scratch.

        Args:
            prev: number of leading test steps used as fixed context.
            texts: iterable of candidate sequences of channel-0 values.

        Returns:
            ``np.ndarray`` of 0/1 labels, 1 where the last anomaly score is
            ``>= self.threshold``.
        """
        res = []
        for text in texts:
            testdata = self._build_input(prev, text)
            score, _, _, _, _ = anomalyDetector.anomalyScore(
                self.args, self.model, testdata, self.mean, self.cov,
                score_predictor=None,
                channel_idx=self.channel_idx,
            )
            print(score)
            res.append(1 if score[-1] >= self.threshold else 0)
        return np.array(res)