import numpy as np

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, accuracy_score
import matplotlib.pyplot as plt
from sklearn.svm import SVC

def inspect_confusion_matrix(examples, predictions):
    """Print the accuracy (as a percentage) and the confusion matrix.

    Args:
        examples: iterable of mappings; each item's ``'label'`` entry is
            taken as the ground-truth label.
        predictions: iterable of mappings; each item's ``'label'`` entry is
            taken as the predicted label.

    The confusion matrix rows/columns follow the fixed label order
    ``['True', 'False', 'Neither']``.
    """
    label_order = ['True', 'False', 'Neither']
    gold_labels = [example['label'] for example in examples]
    predicted_labels = [prediction['label'] for prediction in predictions]

    matrix = confusion_matrix(gold_labels, predicted_labels, labels=label_order)
    accuracy_pct = accuracy_score(gold_labels, predicted_labels) * 100

    print("Acc: {:.2f}".format(accuracy_pct))
    print(matrix)

def filter_by_predicted_label(examples, predictions, *labels):
    """Keep only the (example, prediction) pairs whose predicted label matches.

    Args:
        examples: sequence of examples, paired positionally with
            ``predictions``.
        predictions: sequence of mappings with a ``'label'`` entry.
        *labels: predicted labels to keep. When none are given, nothing is
            filtered.

    Returns:
        A ``(examples, predictions)`` pair. With no ``labels`` the two input
        objects are returned unchanged; otherwise two new lists containing
        the retained items in their original order.
    """
    if len(labels) == 0:
        return examples, predictions

    kept_pairs = [
        (example, prediction)
        for example, prediction in zip(examples, predictions)
        if prediction["label"] in labels
    ]
    kept_examples = [example for example, _ in kept_pairs]
    kept_predictions = [prediction for _, prediction in kept_pairs]
    return kept_examples, kept_predictions

class FewshotClsReranker:
    """Re-predict a label from per-class probabilities via logistic regression.

    The three labels ``'True'``, ``'False'`` and ``'Neither'`` are mapped to
    the integer classes 0, 1 and 2 respectively.
    """

    def __init__(self):
        self.label_of_index = ['True', 'False', 'Neither']
        self.index_of_label = {'True': 0, 'False': 1, 'Neither': 2}
        self.model = LogisticRegression(random_state=42, C=10, fit_intercept=True)

    def train(self, examples, predictions):
        """Fit the model on ``class_probs`` features and print accuracies.

        Args:
            examples: iterable of mappings whose ``'label'`` is the gold label.
            predictions: iterable of mappings with ``'label'`` (the original
                predicted label) and ``'class_probs'`` (the feature vector).

        Prints the accuracy of the original predictions ("Base ACC") and the
        accuracy of the fitted model on the same training data ("Train ACC").
        """
        to_index = self.index_of_label
        baseline = np.array([to_index[pred['label']] for pred in predictions])
        targets = np.array([to_index[example['label']] for example in examples])
        features = np.array([pred['class_probs'] for pred in predictions])

        self.model.fit(features, targets)
        fitted = self.model.predict(features)

        train_pct = np.mean(fitted == targets) * 100
        base_pct = np.mean(baseline == targets) * 100
        print("Base ACC: {:.2f}".format(base_pct), "Train ACC: {:.2f}".format(train_pct))

    def apply(self, ex, pred):
        """Return the label the fitted model predicts for one prediction.

        Args:
            ex: unused by this method.
            pred: mapping with a ``'class_probs'`` feature vector.
        """
        single_row = np.array([pred['class_probs']])
        label_idx = self.model.predict(single_row)[0]
        return self.label_of_index[label_idx]


class JointClsProbExplReranker:
    """Re-predict a label from class probabilities plus premise coverage.

    Each prediction is turned into the feature vector
    ``[*class_probs, premise_coverage]`` and fed to a logistic regression.
    The labels ``'True'``, ``'False'`` and ``'Neither'`` are mapped to the
    integer classes 0, 1 and 2 respectively.
    """

    def __init__(self):
        self.model = LogisticRegression(random_state=42, C=10, fit_intercept=True)
        self.index_of_label = {'True': 0, 'False': 1, 'Neither': 2}
        self.label_of_index = ['True', 'False', 'Neither']

    @staticmethod
    def _features(pred):
        """Return the 1-D feature vector for a single prediction.

        ``class_probs`` may be a list, tuple or array; going through NumPy
        (rather than list ``+``) keeps training and inference features built
        the same way regardless of the container type.
        """
        return np.append(np.asarray(pred['class_probs']), pred['premise_coverage'])

    def train(self, examples, predictions):
        """Fit the model and print base/train accuracies.

        Args:
            examples: iterable of mappings whose ``'label'`` is the gold label.
            predictions: iterable of mappings with ``'label'`` (the original
                predicted label), ``'class_probs'`` and ``'premise_coverage'``.
                A ``'hypothesis_coverage'`` entry is not used as a feature
                and is not required.

        Prints the accuracy of the original predictions ("Base ACC") and the
        accuracy of the fitted model on the same training data ("Train ACC").
        """
        orig_preds = np.asarray([self.index_of_label[p['label']] for p in predictions])
        gt_scores = np.asarray([self.index_of_label[ex['label']] for ex in examples])
        training_feature = np.asarray([self._features(p) for p in predictions])

        self.model.fit(training_feature, gt_scores)
        train_preds = self.model.predict(training_feature)
        train_acc = np.mean(train_preds == gt_scores)
        print("Base ACC: {:.2f}".format(np.mean(orig_preds == gt_scores) * 100), "Train ACC: {:.2f}".format(train_acc * 100))

    def apply(self, ex, pred):
        """Return the label the fitted model predicts for one prediction.

        Args:
            ex: unused by this method.
            pred: mapping with ``'class_probs'`` and ``'premise_coverage'``.
        """
        # The model expects a 2-D (n_samples, n_features) input.
        features = self._features(pred).reshape(1, -1)
        p = self.model.predict(features)[0]
        return self.label_of_index[p]
