import numpy as np
import torch
import copy
import time
import torch.nn as nn
from functools import partial
from tqdm import tqdm
from optimalfair.algorithm.classifierbase import basicprocess
from optimalfair.utils.models import *
from optimalfair.utils.model_utils import *

import sklearn, sklearn.linear_model, sklearn.calibration
from optimalfair.utils.linearpost_utils import *
# from utils.linearpost_postprocess import PostProcessor
from utils.models import DummyEstimator
import matplotlib.pyplot as plt


class classifier(basicprocess):
    """Attribute-blind linear post-processing evaluated on the test split.

    ``train`` fits the base predictors supplied by the parent class, computes
    their probability outputs on ``self.test_data``, runs
    ``postprocess_and_evaluate`` for the configured fairness bound and plots
    accuracy against the fairness violation.

    The attributes ``n_class``, ``n_group``, ``test_data``, ``fair_bound``,
    ``fair_metric`` and ``options`` as well as the ``fit_*`` methods are not
    defined here; they are presumably provided by ``basicprocess`` -- verify
    against the base class.
    """

    def __init__(self, dataset, options, name=''):
        """Forward all arguments to ``basicprocess`` and read the calibration option."""
        super().__init__(dataset, options, name)
        # NOTE(review): ``train`` treats this as a truthy/falsy switch that
        # enables the calibrated evaluation -- confirm the option's value type.
        self.calibration = self.options['calibration']

    def train(self):
        """Fit the base models, post-process the blind predictor and plot results.

        Side effects: sets ``self.model_Y_give_XA`` and
        ``self.model_AY_give_X``, opens a matplotlib figure via ``plt.show()``
        and prints the result tables. Returns ``None``.

        The calibrated evaluation is only run (and only plotted/printed) when
        ``self.calibration`` is truthy. Previously the calibrated results were
        plotted unconditionally although their computation was disabled, which
        raised ``NameError``.
        """
        self.model_Y_give_XA = self.fit_Y_give_XA()
        self.model_AY_give_X = self.fit_AY_give_X()

        # Attribute-aware setting: predict P(Y | X, A) on the test set, then
        # expand it to a joint (group, class) array by placing each row's class
        # probabilities in the slot of its observed group (one-hot outer
        # product); all other group slots are zero.
        inputs_ = np.concatenate([self.test_data.X, self.test_data.A], axis=1)
        probas_y_ = self.model_Y_give_XA.predict_proba(inputs_).reshape(-1, self.n_class).numpy()
        group_one_hot = np.eye(self.n_group)[self.test_data.A.astype(int).reshape(-1)]
        probas_ay_ = np.einsum("ij,ik->ijk", group_one_hot, probas_y_)

        # Attribute-blind setting: predict P(A, Y | X) on the test set.
        probas_ay_u_ = self.model_AY_give_X.predict_proba(self.test_data.X).reshape(
            -1, self.n_group, self.n_class).numpy()

        postprocess_kwargs = {
            'n_test': len(self.test_data),
            'n_classes': self.n_class,
            'n_groups': self.n_group,
            'labels': self.test_data.Y.reshape(-1),
            'groups': self.test_data.A.reshape(-1),
            'p_ay_x': probas_ay_,
            'max_workers': 16,
        }
        # Only the blind variant is evaluated below; the aware kwargs serve as
        # the template whose probabilities are swapped out.
        postprocess_u_kwargs = {**postprocess_kwargs, 'p_ay_x': probas_ay_u_}

        alphas_blind = [self.fair_bound]

        # The post-processing utilities are called with 'sp' where this
        # project's options use 'dp'; other metric names pass through as-is.
        criterion = 'sp' if self.fair_metric == 'dp' else self.fair_metric
        violation_metric = f'delta_{criterion}'

        seeds = self.options['seed']

        df_blind = postprocess_and_evaluate(
            alphas_blind,
            seeds,
            criterion,
            ['accuracy', violation_metric],
            **postprocess_u_kwargs,
        )

        df_blind_cal = None
        if self.calibration:
            # Factory for the calibrator applied to the joint (A, Y) scores.
            calibrator_ay_factory = partial(
                sklearn.calibration.CalibratedClassifierCV,
                estimator=DummyEstimator(self.n_class * self.n_group),
                cv='prefit',
                method='sigmoid',
            )
            df_blind_cal = postprocess_and_evaluate(
                alphas_blind,
                seeds,
                criterion,
                ['accuracy', violation_metric],
                calibrator_ay_factory=calibrator_ay_factory,
                **postprocess_u_kwargs,
            )

        fig, ax = plt.subplots(1, 1, figsize=(3.5, 3))
        plot_results(ax, df_blind, violation_metric, 'accuracy', label='not cal.')
        if df_blind_cal is not None:
            plot_results(ax, df_blind_cal, violation_metric, 'accuracy', label='cal.')
        ax.set_xlabel("fairness violation")
        ax.set_ylabel("classification accuracy")
        dataset_name = self.options['data']
        criterion_name = criterion
        model_name = self.options['model']

        ax.set_title(f"{dataset_name} / {criterion_name} (attr. blind) / {model_name}")
        ax.legend()
        plt.show()

        print("Results for attribute blind (not calibrated):")
        print(df_blind)
        if df_blind_cal is not None:
            print("Results for attribute blind (calibrated):")
            print(df_blind_cal)