from calc.calc_coverage import CovPreCalculator,CovPreCalculatorLIME
from fit import anchor_fit, lime_fit
import argparse
import preprocess_data
import torch
from pathlib import Path
import anomalyDetector
from model import model
from predict import Predictor
import numpy as np
import sys


def load(datatype, filename):
    """Restore the trained predictor and the data needed to evaluate it.

    Reads the pickled dataset for ``datatype``/``filename``, loads the
    checkpoint stored at ``save/<datatype>/checkpoint/<filename>.pth``,
    rebuilds the RNN predictor from the hyper-parameters saved in the
    checkpoint and obtains the mean/covariance of channel 0, either from the
    checkpoint (when present) or by fitting them on the training data.

    Args:
        datatype: Dataset family; used as a directory name and passed to the
            data loader.
        filename: Dataset file name; its suffix is replaced by ``.pth`` to
            locate the checkpoint.

    Returns:
        A dict with the keys ``"model"``, ``"test_dataset"``, ``"mean"`` and
        ``"cov"``.
    """
    data = preprocess_data.PickleDataLoad(
        data_type=datatype,
        filename=filename,
        augment_test_data=False,
    )

    checkpoint_path = Path('save', datatype, 'checkpoint', filename).with_suffix('.pth')
    checkpoint = torch.load(str(checkpoint_path))

    # Start from the arguments the model was trained with, then override the
    # evaluation-time settings.
    args = checkpoint['args']
    args.prediction_window_size = 10
    args.beta = 1.0
    args.save_fig = True
    args.compensate = False

    # Only the first channel's statistics are loaded/fitted below.
    channel_idx = 0
    feature_count = data.trainData.size(-1)

    predictor = model.RNNPredictor(
        rnn_type=args.model,
        enc_inp_size=feature_count,
        rnn_inp_size=args.emsize,
        rnn_hid_size=args.nhid,
        dec_out_size=feature_count,
        nlayers=args.nlayers,
        res_connection=args.res_connection,
    )
    mymodel = predictor.to(args.device)
    mymodel.load_state_dict(checkpoint['state_dict'])

    train_dataset = data.batchify(args, data.trainData[:data.length], bsz=1)
    test_dataset = data.batchify(args, data.testData, bsz=1)

    has_saved_stats = all(key in checkpoint for key in ('means', 'covs'))
    if not has_saved_stats:
        print('=> calculating mean and covariance')
        mean, cov = anomalyDetector.fit_norm_distribution_param(
            args, mymodel, train_dataset, channel_idx=channel_idx)
    else:
        print('=> loading pre-calculated mean and covariance')
        mean = checkpoint['means'][channel_idx]
        cov = checkpoint['covs'][channel_idx]

    result = {}
    result["model"] = mymodel
    result["test_dataset"] = test_dataset
    result["mean"] = mean
    result["cov"] = cov
    return result

def normalize(rules):
    """Scale rule weights so that their absolute values sum to one.

    Args:
        rules: Iterable of ``(feature, weight)`` pairs. It is traversed
            twice, so it must be a re-iterable container such as a list.

    Returns:
        A new list of ``(feature, weight / total)`` tuples, where ``total``
        is the sum of the absolute weights. The sign of every weight is
        kept. When ``total`` is zero (empty input or all-zero weights) the
        pairs are returned unscaled instead of dividing by zero.
    """
    total = sum(abs(weight) for _, weight in rules)
    if total == 0:
        # Nothing to scale: every weight is zero (or there are no rules).
        return [(feature, weight) for feature, weight in rules]
    return [(feature, weight / total) for feature, weight in rules]


def selection(lime_dist, num):
    """Pick the first ``num`` entries whose leading feature pair is unique.

    Two entries count as duplicates when the first two components of their
    feature (``entry[0][0]`` and ``entry[0][1]``) are equal; only the first
    occurrence is kept and the input order is preserved.

    Args:
        lime_dist: Iterable of entries whose first element is an indexable
            feature with at least two components.
        num: Maximum number of entries to return.

    Returns:
        A list with at most ``num`` entries. Fewer are returned when the
        input does not contain ``num`` distinct keys; an empty list is
        returned when ``num`` is not positive.
    """
    if num <= 0:
        return []

    selected = []
    # Keys are kept in a list (not a set) so that unhashable feature
    # components keep working; it is updated incrementally instead of being
    # rebuilt from ``selected`` on every iteration.
    seen_keys = []
    for entry in lime_dist:
        key = (entry[0][0], entry[0][1])
        if key not in seen_keys:
            seen_keys.append(key)
            selected.append(entry)
        if len(selected) == num:
            return selected
    return selected



def _read_lines(path):
    """Return all lines of the text file at ``path`` and close the file."""
    with open(path, "r") as handle:
        return handle.readlines()


def _score(calculator_cls, words, rules, myfit, predict):
    """Run ``calculator_cls`` on one explanation (label 1) and return its ``calc()`` result."""
    return calculator_cls(words=words, rules=rules, myfit=myfit, predict=predict, label=1).calc()


def _print_summary(title, results):
    """Print the averages of column 0 (coverage) and column 1 (precision) under ``title``."""
    results = np.array(results)
    print(title)
    if len(results) == 0:
        # No sample was evaluated, so there is nothing to average.
        print("coverage: n/a")
        print("precision: n/a")
        print()
        return
    print("coverage: ", end="")
    print(sum(results[:, 0]) / len(results))
    print("precision: ", end="")
    print(sum(results[:, 1]) / len(results))
    print()


def calc(datatype, filename, threshold):
    """Evaluate stored Anchor, ReX-Anchor, LIME and ReX-LIME explanations.

    Loads the test data via ``load``, reads the four explanation result
    files for ``datatype``/``filename``, computes a result row for every
    explained position and prints per-sample values followed by the average
    coverage and precision of each method.

    Args:
        datatype: Dataset family; selects the sub-directory of each result
            folder.
        filename: Dataset file name; ``<filename>.txt`` is read from each
            result folder.
        threshold: Passed through to ``Predictor``.

    Returns:
        None. All results are written to standard output.
    """
    ckpt = load(datatype, filename)
    test_dataset = ckpt["test_dataset"]
    anchor_exp = _read_lines("./anchor_res/%s/%s.txt" % (datatype, filename))
    rex_anchor_exp = _read_lines("./ReX_anchor_res/%s/%s.txt" % (datatype, filename))
    lime_exp = _read_lines("./lime_res/%s/%s.txt" % (datatype, filename))
    rex_lime_exp = _read_lines("./ReX_lime_res/%s/%s.txt" % (datatype, filename))
    predictor = Predictor(datatype=datatype, filename=filename, threshold=threshold)

    anchor_res = []
    rex_anchor_res = []
    lime_res = []
    rex_lime_res = []

    # The files are walked record by record: anchor records span 5 lines,
    # ReX-LIME records 4 lines and LIME records 3 lines. The ReX-anchor file
    # is indexed with the anchor stride, i.e. it is assumed to share the
    # anchor file's layout -- TODO confirm.
    # NOTE(review): eval() executes whatever expression the result files
    # contain; only run this on files you produced yourself. If the lines
    # are plain literals, ast.literal_eval would be a safer replacement.
    for i, j, k in zip(range(0, len(anchor_exp), 5),
                       range(0, len(rex_lime_exp), 4),
                       range(0, len(lime_exp), 3)):
        # The last two characters of the header line are dropped before
        # evaluation (presumably a separator plus the newline).
        pos = eval(anchor_exp[i][:-2])
        print(pos)

        # Window of at most 20 steps ending at ``pos``, first channel only.
        prev = max(pos - 20 + 1, 0)
        words = np.array(test_dataset[prev:pos + 1, :, 0].cpu().squeeze(1))

        anchors = eval(anchor_exp[i + 1])
        rex_anchors = eval(rex_anchor_exp[i + 1])
        lime = eval(lime_exp[k + 1])
        rex_lime = eval(rex_lime_exp[j + 1])
        predict = predictor.get_predict(prev)

        # Skip samples whose LIME / ReX-LIME explanation has a maximum
        # weight of exactly zero.
        if max(np.array(lime)[:, 1]) == 0:
            continue
        if max([x[1] for x in rex_lime]) == 0:
            continue

        # Give ReX-LIME the same number of rules as LIME, then scale both
        # weight sets.
        rex_lime = selection(rex_lime, len(lime))
        lime = normalize(lime)
        rex_lime = normalize(rex_lime)

        anchor_res.append(
            _score(CovPreCalculator, words, anchors, anchor_fit.AnchorFit(anchors, words), predict))
        rex_anchor_res.append(
            _score(CovPreCalculator, words, rex_anchors, anchor_fit.RexFit(rex_anchors, words), predict))
        print(anchor_res[-1])
        print(rex_anchor_res[-1])
        lime_res.append(
            _score(CovPreCalculatorLIME, words, lime, lime_fit.FitLime(lime, words), predict))
        rex_lime_res.append(
            _score(CovPreCalculatorLIME, words, rex_lime, lime_fit.FitRex(rex_lime, words), predict))
        print(lime_res[-1])
        print(rex_lime_res[-1])
        sys.stdout.flush()

    _print_summary("Anchor:", anchor_res)
    _print_summary("Rex_Anchor:", rex_anchor_res)
    _print_summary("LIME:", lime_res)
    _print_summary("ReX_LIME", rex_lime_res)


if __name__ == "__main__":
    # Script entry point: read the dataset selection from the command line
    # and run the evaluation with a fixed threshold of 10000.
    cli = argparse.ArgumentParser(description='analyse explanation result')
    cli.add_argument(
        '--datatype',
        type=str,
        default='ecg',
        help="path of file needed analyse",
    )
    cli.add_argument(
        '--filename',
        type=str,
        default='chfdb_chf13_45590.pkl',
        help="path",
    )
    options = cli.parse_args()

    calc(options.datatype, options.filename, threshold=10000)
