

# standard libraries
import torch
import numpy as np
import sys

# for autograd
from weights import weights
from helpers import set_seed, fast_xtdx
from training import train_model, train_model_GDRO

# for timing
import time
import gc
import os
import shutil

# torch libraries
from torch.autograd.functional import hessian, jacobian
from torch import nn
from torchvision import models
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module
from torch.func import functional_call, vmap, vjp, jvp, jacrev

# for linear/logistic regression
from sklearn.linear_model import LogisticRegression as logreg, SGDClassifier 
from sklearn.metrics import mean_squared_error as mse

# for multiprocessing
import multiprocessing as mp
from functools import partial

# for BERT
from transformers import BertForSequenceClassification, BertModel
import types



class model():

    def __init__(self, weights_obj):
        """
        Keep a handle on the weights object and cache its per-group weights.

        Parameters:
            weights_obj: object
                Must expose set_weights_per_group(); whatever that call
                returns is stored on self.weights.
        """
        self.weights_obj = weights_obj

        # query the stored object once and keep the result
        per_group_weights = self.weights_obj.set_weights_per_group()
        self.weights = per_group_weights

    def acc(self, y_hat, y):
        """
        Calculate the accuracy of the model.
        """
        y_class = torch.round(torch.sigmoid(y_hat))

        return torch.mean((y_class == y).to(torch.float32))
    
    
    def group_loss(self, X, y, g, acc=False, pred_type='linear'):
        """
        Method to calculate the loss for a selected group and the last group.

        Parameters:
            X: np.array
                The feature matrix.
            y: np.array
                The target variable.
            g: np.array
                The group membership of each observation.
          
        Returns:
            loss: float
                The loss for the selected group and the last group.
        """

        # get all the groups
        groups = [int(group.item()) for group in torch.unique(g)]

        # create a dict to store the loss for each group
        loss_dict = dict.fromkeys(groups, 0)
        
        # loop through the groups
        for group in groups:

            # select observations for the group
            i_g = (g == group)

            # get the X, y, y_hat for the group
            X_g = X[i_g, :]
            y_g = y[i_g]
            y_hat_g = self.predict(X_g)
            

            # get the loss/acc for the group
            if acc:
                loss_g = self.acc(y_hat_g, y_g)
            else:
                loss_g = self.loss(y_hat_g, y_g)

            # store the loss
            loss_dict[group] = loss_g
        
        return loss_dict
        

    def optimize_GDRO_via_SGD(self, X, y, g, T, batch_size, eta_param, eta_q, C=0.0, use_val=False, X_val=None, y_val=None, g_val=None, early_stopping=False, patience=1, use_acc=False, learning_rate_schedule='bottou'):

        # get the number of groups
        groups = torch.unique(g)

        # get the initial q_t
        q_t ={int(group): 1/len(groups) for group in groups}

        # set the initial weights
        self.weights_obj.reset_p_weights(q_t)

        # define the n_dict; this is the number of observations in each group
        n_dict = {int(group): torch.sum(g == group) for group in groups}

        # save the best worst-case loss
        if use_acc:
            best_worst_group_criterion = 0
        else:
            best_worst_group_criterion = np.inf

        # set the initial learning rate
        if learning_rate_schedule == 'bottou':
            alpha = self.l_1_penalty.detach().numpy()
            w = np.sqrt(1/np.sqrt(alpha))
            grad_log_loss = (-w - 1)/((-w - 1)*w)
            eta0 = w / max(1.0, -grad_log_loss)
            optimal_init = 1.0 / (eta0 * alpha)
            


        # loop over T epochs
        for t in range(T):
            
            if learning_rate_schedule == 'bottou':
                eta_t = 1.0 / (alpha * (optimal_init + t ))
            elif learning_rate_schedule == 'constant':
                eta_t = eta_param
            print('At epoch {}, the learning rate is {}'.format(t, eta_t))

            

            # shuffle the indeces
            indices = np.arange(X.shape[0])
            shuffled_indices = np.random.permutation(indices)

            # loop over the batches, including the last batch
            for i in range(0, len(shuffled_indices), batch_size):

                # get the final index
                last_batch_index = min(i+batch_size, len(shuffled_indices))

                # get the batch indices
                batch_indices = shuffled_indices[i:last_batch_index]

                # get the batch
                X_b, y_b, g_b = X[batch_indices, :], y[batch_indices], g[batch_indices]

                # update the parameters
                self.gradient_step_DRO(X_b, y_b.squeeze(-1), g_b, eta_t)

                # get loss per group for the batch
                loss_dict = self.group_loss(X_b, y_b, g_b, acc=False)
                
                # based on the loss, update the q_t
                q_t = self.weights_obj.update_DRO_weights(q_t, loss_dict,  eta_q, C=C, n_dict=n_dict)

                # reset the p_weights of the model
                self.weights_obj.reset_p_weights(q_t)

            # after completing the first epoch, get the param
            if t == 0:
                best_Beta = self.Beta
                best_t = t

            
            # after completing an epoch, get the loss for the entire dataset
            if use_val:
                loss_dict = self.group_loss(X_val, y_val, g_val, acc=False)
                acc_dict = self.group_loss(X_val, y_val, g_val, acc=True)
            else:
                loss_dict = self.group_loss(X, y, g, acc=False)
                acc_dict = self.group_loss(X, y, g, acc=True)

         
            # record the worst-case loss, acc
            worst_group_loss = max(loss_dict.values())
            worst_group_acc = min(acc_dict.values())

            # print the following stats
            print('At epoch {}, the worst group loss is {} and the worst group acc is {}'.format(t, worst_group_loss, worst_group_acc))

            if use_acc:
                worst_group_criterion = -worst_group_acc
            else:
                worst_group_criterion = worst_group_loss

            if early_stopping and worst_group_criterion >= best_worst_group_criterion:
                patience -= 1
                if patience == 0:
                    print('Early stopping at epoch {}'.format(t))
                    break
            
            # if loss is better, save the parameters
            if worst_group_criterion < best_worst_group_criterion:
                best_worst_group_criterion = worst_group_criterion
                best_Beta = self.Beta
                best_t = t
            
        return best_Beta, best_worst_group_criterion, best_t





      


    def cross_val_loss(self, X, y, g, k=5, use_weighted_loss=False, weights_obj=None, seed=1):
        """
        Estimate the held-out loss of the model with k-fold cross validation.

        Parameters:
            X: feature matrix (indexed by row).
            y: target variable.
            g: group membership of each observation.
            k: int
                The number of folds. Default is 5.
            use_weighted_loss: bool
                If True, the held-out loss is weighted with the sample
                weights obtained from weights_obj. Default is False.
            weights_obj: weights
                Only used when use_weighted_loss is True. Default is None.
            seed: int
                Passed to set_seed before the indices are shuffled.

        Returns:
            float
                The mean of the k held-out losses.
        """

        # shuffled row indices; seeding happens right before the shuffle
        order = np.arange(X.shape[0])
        set_seed(seed)
        np.random.shuffle(order)

        # split the shuffled indices into k folds
        folds = np.array_split(order, k)
        fold_losses = np.zeros(k)

        for fold_no, held_out in enumerate(folds):

            # everything that is not in the held-out fold is used for fitting
            train_idx = np.concatenate([folds[j] for j in range(k) if j != fold_no])

            # optional sample weights for the held-out fold
            loss_kwargs = {}
            if use_weighted_loss:
                loss_kwargs['w_loss'] = weights_obj.get_weights_sample(g[held_out])

            # refit on the training folds, then score the held-out fold
            self.fit(X[train_idx, :], y[train_idx], g[train_idx])
            y_hat_held_out = self.predict(X[held_out, :])

            fold_losses[fold_no] = self.loss(y_hat_held_out, y[held_out], **loss_kwargs)

        return np.mean(fold_losses)
    

    def fnet_single(self, params, x):
        return functional_call(self.nn_model, params, (x.unsqueeze(0),)).squeeze(0)
    
    @torch.jit.script
    def jit_contraction(jac1, jac2, compute: str):
        # Contract two Jacobian tensors over their shared trailing axis 'f'.
        # Both inputs are indexed as (sample, a, f); `compute` selects which
        # of the remaining axes are kept in the result.
        # NOTE(review): defined in the class body without `self` and invoked
        # as self.jit_contraction(j1, j2, compute) elsewhere in this class -
        # presumably relies on the scripted function not being bound as a
        # method; confirm before removing the decorator.
        if compute == 'full':
            # keep both output axes: result is indexed (N, M, a, b)
            return torch.einsum('Naf,Mbf->NMab', jac1, jac2)
        elif compute == 'trace':
            # sum over the matching output axis as well: result is (N, M)
            return torch.einsum('Naf,Maf->NM', jac1, jac2)
        elif compute == 'diagonal':
            # keep only matching output indices: result is (N, M, a)
            return torch.einsum('Naf,Maf->NMa', jac1, jac2)
        else:
            raise ValueError("Invalid compute option")



    def empirical_ntk_jacobian_contraction(self, fnet_single, params, x_1, x_2, compute='diagonal', jac1=None, jac2=None):

        if jac1 is None:
            # Compute Jacobian of x_1
            jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x_1).values()
            jac1 = [j.flatten(2) for j in jac1]
            print('Computed jac1')
        if jac2 is None:
            # Compute Jacobian of x_2
            jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x_2).values()
            jac2 = [j.flatten(2) for j in jac2]
            print('Computed jac2')

        # Compute Jacobian contraction
        K = torch.stack([self.jit_contraction(j1, j2, compute) for j1, j2 in zip(jac1, jac2)])
        K = K.sum(0)
        return K, jac1, jac2
    
    
    def compute_ntk_kernel(self, X, X_fit, compute='diagonal', jac1=None, jac2=None):

        K, jac1, jac2 = self.empirical_ntk_jacobian_contraction(self.fnet_single, self.nn_model_params, X, X_fit, compute=compute, jac1=jac1, jac2=jac2)
        K = K.reshape(X.shape[0], X_fit.shape[0])
        
        # set the jac1, jac2 as attributes
        self.jac1, self.jac2 = jac1, jac2

        return K


    def compute_linear_kernel(self, X, X_fit):

        # for now, return the linear kernel
        return X @ X_fit.T

    




class logistic_regression(model):

    def __init__(self,model_param_dict, p_weights=None, p_train=None, add_intercept=True, warm_start=True, verbose=False, weighted_loss_weights=True):
        """
        Configure a (weighted) logistic regression backed by scikit-learn.

        Parameters:
            model_param_dict: dict
                Read keys: 'penalty_type', 'penalty_strength', 'solver',
                'tol', 'seed', 'use_SGDClassifier', 'T', and - only when
                the SGD classifier is used - 'eta0' and 'learning_rate'.
            p_weights, p_train:
                Forwarded to the weights object.
            add_intercept: bool
                Whether the estimator fits an intercept.
            warm_start: bool
                Forwarded to the scikit-learn estimator.
            verbose: bool
                Forwarded to the scikit-learn estimator.
            weighted_loss_weights: bool
                Forwarded to the weights object.
        """
        group_weights = weights(p_weights, p_train, weighted_loss_weights=weighted_loss_weights)
        super().__init__(group_weights)

        # plain configuration
        self.p_weights = p_weights
        self.p_train = p_train
        self.add_intercept = add_intercept
        self.penalty_type = model_param_dict['penalty_type']
        self.penalty_strength = model_param_dict['penalty_strength']
        self.solver = model_param_dict['solver']
        self.tol = model_param_dict['tol']
        self.seed = model_param_dict['seed']
        self.use_SGDClassifier = model_param_dict['use_SGDClassifier']
        self.verbose =verbose

        # only the penalty matching penalty_type is non-zero
        l1_value = float(self.penalty_strength) if self.penalty_type == 'l1' else 0.
        l2_value = float(self.penalty_strength) if self.penalty_type == 'l2' else 0.
        self.l_1_penalty = torch.tensor(l1_value, requires_grad=True)
        self.l_2_penalty = torch.tensor(l2_value, requires_grad=True)

        # create the scikit-learn estimator
        if not self.use_SGDClassifier:
            self.logreg = logreg(
                penalty=self.penalty_type,
                C=1/self.penalty_strength,
                fit_intercept=self.add_intercept,
                random_state=self.seed,
                solver=self.solver,
                tol=self.tol,
                verbose=self.verbose,
                max_iter=model_param_dict['T'],
                warm_start=warm_start,
                dual=False,
            )
        else:
            # l1_ratio mirrors which penalty is active (None if neither)
            l1_ratio = 1 if self.l_1_penalty > 0 else (0 if self.l_2_penalty > 0 else None)
            print('Using SGDClassifier')
            self.logreg = SGDClassifier(
                loss='log_loss',
                penalty=self.penalty_type,
                alpha=self.penalty_strength,
                fit_intercept=self.add_intercept,
                random_state=self.seed,
                l1_ratio=l1_ratio,
                tol=self.tol,
                verbose=self.verbose,
                max_iter=model_param_dict['T'],
                warm_start=warm_start,
                eta0=model_param_dict['eta0'],
                learning_rate=model_param_dict['learning_rate'],
            )
            
    def gradient_step_DRO(self, X_b, y_b, g_b, eta_param):
        """
        Take one weighted SGD step on a mini-batch via partial_fit.

        Parameters:
            X_b, y_b, g_b:
                Features, targets and group membership of the mini-batch.
            eta_param: float
                Constant step size installed on the estimator for this step.

        Side effects:
            Updates self.coef_, self.intercept_ and self.Beta from the
            estimator's current coefficients.
        """

        # get the per-sample weights for the groups in this batch
        w_b = self.weights_obj.get_weights_sample(g_b)

        # set the learning rate
        self.logreg.learning_rate = 'constant'
        self.logreg.eta0 = eta_param

        # partial_fit needs the label set on the FIRST call only, and rejects
        # a later call whose `classes` differs from the first one. Passing
        # np.unique(y_b) every time therefore failed as soon as a mini-batch
        # happened to miss a label; afterwards we pass None instead.
        # (The first batch still has to contain every label.)
        if hasattr(self.logreg, 'classes_'):
            classes = None
        else:
            classes = np.unique(y_b)

        # use the partial fit method of the logreg object
        self.logreg.partial_fit(X_b, y_b, sample_weight=w_b, classes=classes)

        # set the coefficients and intercept
        self.coef_ = torch.tensor(self.logreg.coef_).to(torch.float32)
        self.intercept_ = torch.tensor(self.logreg.intercept_).to(torch.float32)

        # save the parameters - add the intercept to coefficients if needed
        if self.add_intercept:
            self.Beta = torch.cat([self.intercept_, self.coef_.squeeze(0)]).unsqueeze(-1)
        else:
            # NOTE(review): here Beta keeps the estimator's coef_ layout,
            # whereas the intercept branch yields a column vector - confirm
            # that consumers of Beta handle both layouts.
            self.Beta = self.coef_
            

    def reset_weights(self, p_weights):
        """
        Replace self.weights_obj by a new weights object built from p_weights.

        p_train and weighted_loss_weights are carried over from the current
        weights object.
        """
        current = self.weights_obj
        self.weights_obj = weights(
            p_weights,
            current.p_train,
            weighted_loss_weights=current.weighted_loss_weights,
        )


    def fit(self, X, y, g=None):
        """
        Fit the scikit-learn estimator and mirror its parameters as tensors.

        Parameters:
            X: feature matrix.
            y: targets; a trailing axis of size one is squeezed away.
            g: group membership used to look up per-sample weights;
               if None every sample gets weight one.

        Side effects:
            Sets self.coef_, self.intercept_ (float32 tensors) and self.Beta.
        """

        # per-sample weights
        sample_w = np.ones(X.shape[0]) if g is None else self.weights_obj.get_weights_sample(g)

        # the estimator expects targets without a trailing singleton axis
        targets = y.squeeze(-1) if y.shape[-1]==1 else y

        self.logreg.fit(X, targets, sample_weight=sample_w)

        # turn the coefficients and intercept into tensors
        self.coef_ = torch.tensor(self.logreg.coef_).to(torch.float32)
        self.intercept_ = torch.tensor(self.logreg.intercept_).to(torch.float32)

        if not self.add_intercept:
            self.Beta = self.coef_
            return

        # intercept first, then the coefficients, as one column vector
        stacked = torch.concatenate([self.intercept_, self.coef_.squeeze(0)])
        self.Beta = stacked.unsqueeze(-1)

    def predict(self, X, type_pred='linear'):
        if self.add_intercept:
            X = torch.concatenate([torch.ones(X.shape[0], 1), X], axis=1)
            X = X.to(torch.float32)


        if type_pred == 'linear':
            pred = torch.matmul(X, self.Beta)
        elif type_pred == 'probabilities':
            pred = torch.sigmoid(torch.matmul(X, self.Beta))
        elif type_pred == 'class':
            pred = torch.round(torch.sigmoid(torch.matmul(X, self.Beta)))
        else:
            sys.exit('type_pred must be one of: linear, probabilities, class')
  
        return pred
    
    def loss(self, y_hat, y, w_loss=None, binary=True, reduction='mean'):

        if w_loss is not None:
            w_loss = w_loss.unsqueeze(-1)
        if binary:
            loss_func = BCEWithLogitsLoss(weight=w_loss, reduction=reduction)
        else:
            loss_func = CrossEntropyLoss(weight=w_loss, reduction=reduction)
        
        loss = loss_func(y_hat, y.to(torch.float32))
        
        return loss
    

    def train_loss_via_param(self,Beta):
        
        # add intercept if needed
        if self.add_intercept:
            X_train = torch.cat([torch.ones(self.X_train.shape[0], 1), self.X_train], axis=1).to(torch.float32)
        else:
            X_train = self.X_train.to(torch.float32)

        # get the predictions based on the X_train and Beta
        logits = torch.matmul(X_train, Beta)

        # get the loss based on the y_hat and y_train
        train_loss = self.loss(logits, self.y_train, w_loss=self.w_train)

        # get the penalty
        l_1_penalty = self.l_1_penalty
        l_2_penalty = self.l_2_penalty

        # add penalty if needed
        if l_1_penalty > 0:
            l_1_norm_Beta = torch.sum(torch.abs(Beta))
            add_to_loss = l_1_penalty * l_1_norm_Beta

            # ensure that the gradient is calculated
            train_loss += add_to_loss

        if l_2_penalty > 0:
            l_2_norm_Beta_sq = torch.matmul(Beta.T, Beta)[0, 0]
            add_to_loss = l_2_penalty * l_2_norm_Beta_sq

            # ensure that the gradient is calculated
            train_loss += add_to_loss

        return train_loss
    
    def augmented_loss_subsample_via_param(self, Beta):
        # add intercept if needed
        if self.add_intercept:
            X_train = torch.cat([torch.ones(self.X_train.shape[0], 1), self.X_train], axis=1).to(torch.float32)
        else:
            X_train = self.X_train.to(torch.float32)
            
        # get the predictions based on the X_train and Beta
        logits = torch.matmul(X_train, Beta)

        # get the groups
        groups = torch.unique(self.g_train)

        # loop over the groups (except last group) and calculate the augmented loss
        augmented_loss_subsample_list= []
        for group in groups:
            augmented_loss_g = self.augmented_loss_subsample(logits, self.y_train, self.g_train,  group.int().item())
            augmented_loss_subsample_list.append(augmented_loss_g)
            
        # create tuple of augmented loss
        augmented_loss_subsample = tuple(augmented_loss_subsample_list)

        return augmented_loss_subsample
    
    
    def augmented_loss_via_param(self, Beta):

        # add intercept if needed
        if self.add_intercept:
            X_train = torch.cat([torch.ones(self.X_train.shape[0], 1), self.X_train], axis=1).to(torch.float32)
        else:
            X_train = self.X_train.to(torch.float32)
            
        # get the predictions based on the X_train and Beta
        logits = torch.matmul(X_train, Beta)

        # get the groups
        groups = torch.unique(self.g_train)

        # get the last group, defined as the group with the highest index
        last_group = max(groups)
        last_group_loss = self.loss(logits[self.g_train == last_group], self.y_train[self.g_train == last_group])
        

        # loop over the groups (except last group) and calculate the augmented loss
        augmented_loss_list= []
        for group in groups:
            if group != last_group:
                augmented_loss_g = self.augmented_loss(logits, self.y_train, self.g_train, group, last_group_loss)
                augmented_loss_list.append(augmented_loss_g)
                
        # create tuple of augmented loss
        augmented_loss = tuple(augmented_loss_list)

        return augmented_loss

     
    def weighted_val_loss_via_param(self, Beta):

        # add intercept if needed
        if self.add_intercept:
            X_val = torch.cat([torch.ones(self.X_val.shape[0], 1), self.X_val], axis=1).to(torch.float32)
        else:
            X_val = self.X_val.to(torch.float32)

        # get the predictions based on the X_val and Beta
        logits = torch.matmul(X_val, Beta)

        # get the loss based on the y_hat and y_val
        print('unique weights in val:', torch.unique(self.w_val))
        val_loss = self.loss(logits, self.y_val, w_loss=self.w_val)

        return val_loss
    
    def augmented_loss(self, y_hat, y, g,  selected_group, loss_last):

        # get the predictions/labels for the selected group
        y_hat_selected = y_hat[g == selected_group]
        y_selected = y[g == selected_group]

        # get the loss for the selected group, last group
        loss_selected = self.loss(y_hat_selected, y_selected)

        # return loss difference
        return (loss_selected - loss_last)
    
    def augmented_loss_subsample(self, y_hat, y, g,  selected_group):

        # get the predictions/labels for the selected group
        y_hat_selected = y_hat[g == selected_group]
        y_selected = y[g == selected_group]

        # get the summed loss for the selected group
        loss_selected_sum = self.loss(y_hat_selected, y_selected, reduction='mean')
        
        # multiply the summed loss by 1/m
        n_g = torch.sum(g == selected_group)
        factor = (n_g/self.m)
        augmented_loss_subsample = loss_selected_sum * factor 

        return augmented_loss_subsample

    
    
    def calc_weighted_grad(self, X, beta, w,  y, l_1_penalty, l_2_penalty, eps=1e-6):
        """
        Calculate the gradient of the weighted logistic loss function.

        Parameters:
            X: feature matrix without intercept column.
            beta: parameters; in this method only used for the penalty terms.
            w: per-observation weights.
            y: targets.
            l_1_penalty, l_2_penalty: penalty strengths; at most one of them
                is applied (L2 takes precedence).
            eps: smoothing constant of the L1 approximation.

        Returns:
            The gradient with a trailing axis of size one.
        """

        # calculate the sigmoid
        # NOTE(review): the probabilities come from self.predict, i.e. from
        # self.Beta, not from the `beta` argument - confirm callers always
        # pass beta equal to self.Beta.
        sigmoid = self.predict(X,  type_pred='probabilities').squeeze()

        # add the intercept to X if needed
        if self.add_intercept:
            X = torch.cat((torch.ones(X.shape[0], 1), X), dim=1).to(torch.float32)

        # calculate the gradient in two steps:
        # 1. first term: w(sigmoid - y)
        # 2. second term: multiply X^T with the first term
        grad = torch.matmul(X.T, w*(sigmoid - y))
        

        # divide by the number of samples
        grad /= X.shape[0]

        # take average over the last axis and restore a trailing axis
        # NOTE(review): this presumably expects grad to be 2-D at this point
        # (e.g. y given as a column so that the subtraction above
        # broadcasts); the input shapes are not established here - confirm.
        grad = grad.mean(-1).unsqueeze(-1)

        # add the l_2 penalty (scaled by 1/n like the data term)
        if l_2_penalty>0:
            added_term = 2 * l_2_penalty * beta
            grad += (added_term/ X.shape[0])

        elif l_1_penalty>0:
            # the following is an approximation of the derivative of the l_1 penalty:
            # beta / sqrt(beta^2 + eps) is a smooth stand-in for sign(beta)
            beta_squared = (beta**2)
            sqrt_beta_squared = torch.sqrt(beta_squared + eps)
            added_term =  ((beta / sqrt_beta_squared) * l_1_penalty)
            grad += (added_term/ X.shape[0])


        return grad
    
    
    def calc_weighted_Hessian(self, X,  w, beta, l_1_penalty, l_2_penalty, eps=1e-6, divide_by_n=True):
        """
        Calculate the Hessian of the weighted logistic loss function.

        Parameters:
            X: feature matrix without intercept column. A column of ones is
                always prepended here, regardless of self.add_intercept.
            w: per-observation weights.
            beta: parameters at which the Hessian is evaluated
                (presumably including the intercept entry, so that it
                matches the augmented X - confirm against callers).
            l_1_penalty, l_2_penalty: penalty strengths; at most one of them
                is applied (L2 takes precedence).
            eps: smoothing constant of the L1 approximation.
            divide_by_n: bool
                If True the penalty term is divided by the number of
                observations, like the data term; otherwise it is added as is.

        Returns:
            The Hessian matrix H.
        """

        # diagonal entries sigmoid(x_i^T beta) * (1 - sigmoid(x_i^T beta)) * w_i
        sigmoid = self.sigmoid_x_beta(X, beta).squeeze()
        diag_H = ((sigmoid * (1 - sigmoid)) * w ).to(torch.float32)

        # add the intercept to X
        X = torch.cat((torch.ones(X.shape[0], 1), X), dim=1).to(torch.float32)
        n_obs = X.shape[0]

        # calculate the Hessian of the data term, X^T D X / n
        H = fast_xtdx(X, diag_H)
        H /= n_obs

        # curvature contributed by the penalty (None if no penalty is active)
        added_term = None
        if l_2_penalty>0:
            added_term = torch.eye(H.shape[0])*l_2_penalty *2

        elif l_1_penalty>0:
            # second derivative of the smooth L1 approximation
            # l_1_penalty * sqrt(beta^2 + eps), i.e. eps / (beta^2 + eps)^(3/2)
            beta_squared = (beta**2).squeeze()
            H_l_1_approx = eps/((beta_squared + eps)**(3/2))
            # The previous code divided this term by n here AND again below,
            # so it ended up scaled by 1/n^2 (or by 1/n when divide_by_n was
            # False). It is now scaled exactly like the L2 term.
            added_term = torch.diag(H_l_1_approx)*l_1_penalty

        if added_term is not None:
            if divide_by_n:
                H += (added_term/n_obs)
            else:
                H += added_term

        return H
    


    
class kernel_logistic_regression(logistic_regression):

    def __init__(self, kernel, model_param_dict, p_weights=None, p_train=None, add_intercept=True, use_SGDClassifier=False, warm_start=True, verbose=False, nn_model=None ):
        """
        Logistic regression fitted on a kernel matrix instead of raw features.

        Parameters:
            kernel: str or None
                'linear' uses compute_linear_kernel, 'ntk' uses
                compute_ntk_kernel (requires nn_model), None disables the
                kernel (self.compute_kernel is set to None).
            model_param_dict, p_weights, p_train, add_intercept,
            warm_start, verbose:
                Forwarded to logistic_regression.__init__.
            use_SGDClassifier: bool
                Kept for backward compatibility; the parent class reads this
                setting from model_param_dict['use_SGDClassifier'].
            nn_model: torch.nn.Module or None
                Network whose parameters define the NTK; only used for 'ntk'.
        """
        # Forward by keyword: the parent signature is
        # (model_param_dict, p_weights, p_train, add_intercept, warm_start,
        #  verbose, weighted_loss_weights). The previous positional call
        # shifted use_SGDClassifier into warm_start, warm_start into verbose
        # and verbose into weighted_loss_weights.
        super().__init__(
            model_param_dict,
            p_weights=p_weights,
            p_train=p_train,
            add_intercept=add_intercept,
            warm_start=warm_start,
            verbose=verbose,
        )

        # Jacobian caches filled by compute_ntk_kernel; defined up front so
        # that reading them before the first kernel computation is safe
        self.jac1 = None
        self.jac2 = None

        if kernel == 'linear':
            self.compute_kernel = self.compute_linear_kernel

        elif kernel == 'ntk':
            self.nn_model = nn_model.eval()

            # save the parameters, turn each to torch.float32
            self.nn_model_params = {k: v.detach().to(torch.float32) for k, v in nn_model.named_parameters()}
            self.compute_kernel = self.compute_ntk_kernel

        elif kernel is None:
            self.compute_kernel = None

    def fit(self, X, y, g=None):
        """
        Fit the estimator on the kernel matrix of X with itself.

        Parameters:
            X: training inputs; stored on self.X_fit for later predictions.
            y: targets.
            g: group membership used to look up per-sample weights;
               if None every sample gets weight one.

        Side effects:
            Sets self.X_fit, self.coef_, self.intercept_ and self.Beta.
        """
        # per-sample weights
        sample_w = np.ones(X.shape[0]) if g is None else self.weights_obj.get_weights_sample(g)

        # kernel matrix of the training inputs, with a tiny ridge on the
        # diagonal for numerical stability
        gram = self.compute_kernel(X, X)
        gram += 1e-8 * np.eye(gram.shape[0])

        # the kernel matrix takes the place of the feature matrix
        self.logreg.fit(gram, y, sample_weight=sample_w)
        self.X_fit = X

        # mirror the estimator's parameters as tensors
        self.coef_ = torch.tensor(self.logreg.coef_)
        self.intercept_ = torch.tensor(self.logreg.intercept_)

        if not self.add_intercept:
            self.Beta = self.coef_
            return

        # intercept first, then the coefficients, as one column vector
        stacked = torch.cat([self.intercept_, self.coef_.squeeze(0)])
        self.Beta = stacked.unsqueeze(-1)

    def predict(self, X, type_pred='class'):
        # Compute the kernel matrix
        K = self.compute_kernel(X, self.X_fit, jac2=self.jac2)
        
        if type_pred == 'linear':
            return self.logreg.decision_function(K)
        elif type_pred == 'probabilities':
            return self.logreg.predict_proba(K)[:, 1]
        elif type_pred == 'class':
            return self.logreg.predict(K)
        else:
            raise ValueError('type_pred must be one of: linear, probabilities, class')

                    
class logistic_regression_subsample(logistic_regression):

    def __init__(self,model_param_dict, p_weights=None, p_train=None, add_intercept=True, verbose=False):
        """
        Logistic regression that is fitted on a group-wise (sub)sample.

        Parameters:
            model_param_dict, p_weights, p_train, add_intercept:
                Forwarded to logistic_regression.__init__, which also stores
                them (and the values read from model_param_dict) on self.
                The weights object is created with
                weighted_loss_weights=False.
            verbose: bool
                Controls the progress printing of this class.
        """
        super().__init__(model_param_dict, p_weights, p_train, add_intercept, weighted_loss_weights=False)

        # the parent was called without verbose, so set it here
        self.verbose = verbose

        # Make sure every group weight is a tensor. The previous check only
        # looked at self.p_weights[1], which raised a KeyError when there is
        # no group 1 and a TypeError when p_weights is None.
        if self.p_weights is not None:
            self.p_weights = {
                k: v if isinstance(v, torch.Tensor) else torch.tensor(v, requires_grad=False)
                for k, v in self.p_weights.items()
            }

    
    def fit_super(self, X, y):
        """
        Fit with the parent class's fit on (X, y), without passing groups.

        Exists so that this class's own fit, which has a different
        signature, can still reach the parent implementation.
        """
        super().fit(X, y)

    def get_sample_groups(self, X, y, g, seed):
        """
        Draw a sample of the same size as the data, with replacement.

        Every observation of group g is drawn with probability proportional
        to p_g * (1 / n_g), where p_g = self.p_weights[g] and n_g is the
        size of the group; the probabilities are normalised to sum to one.

        Returns:
            tuple
                (X_tilde, y_tilde, i_sample): the sampled rows of X and y and
                the sampled indices.
        """

        # set the seed
        set_seed(seed)

        n_total = len(g)

        # selection probability of every single observation
        probs = np.zeros(n_total)
        for group in np.unique(g):
            members = np.where(g == group)[0]
            probs[members] = self.p_weights[group] * (1/len(members))

        # (numerically) normalise the probabilities
        probs = probs/np.sum(probs)

        # sample the indices
        i_sample = np.random.choice(np.arange(n_total), size=n_total, replace=True, p=probs)

        return X[i_sample, :], y[i_sample], i_sample

    def get_subsample_groups(self, X, y, g, seed):
        """
        Creates subsample of original sample
        Selects p_g * n_g unique samples from group g
        """

        

        # get the groups
        groups = np.unique(g)

        # loop over each group, create dict with per group: indeces, and size
        group_dict = {}
        for group in groups:
            group_dict[group] = {}
            group_dict[group]['i'] = np.where(g == group)[0]
            group_dict[group]['n'] = len(group_dict[group]['i'])
        
      

        # now, get the subsample per group, each of size n_tilde
        i_sample = []
        for group in groups:
            i_group = group_dict[group]['i']
            n_g = group_dict[group]['n']
            m_g = torch.ceil(self.p_weights[group]*n_g).int().item()
           
            if self.verbose:
                print('Sampling without replacement for group {}, sampling {} from the original {}'.format(group, m_g, n_g))
            np.random.seed(seed)
            i_sample_group = np.random.choice(i_group, m_g, replace=False)

            i_sample.append(i_sample_group)
        
        
        # now, combine the indeces
        i_sample = np.concatenate(i_sample)
        self.m = len(i_sample)
        if self.verbose:
            print('Size of subsample: {}'.format(len(i_sample)))
        
        # get the subsample
        X_tilde = X[i_sample, :]
        y_tilde = y[i_sample]
      
        return X_tilde, y_tilde, i_sample
    
    def fit(self, X, y, g, seed=0, subsample=True):
        """
        Draw a (sub)sample of the data and fit the parent model on it.

        Parameters:
            X, y, g: features, targets and group membership.
            seed: int
                Seed handed to the sampling routine.
            subsample: bool
                True selects get_subsample_groups (without replacement),
                False get_sample_groups (with replacement).

        Side effects:
            Sets self.X_train, self.y_train, self.g_train and self.i_sample.
        """

        # pick the sampling routine, then draw
        draw = self.get_subsample_groups if subsample else self.get_sample_groups
        X_tilde, y_tilde, i_sample = draw(X, y, g, seed)

        print('seed: {}, first 10 indeces: {}'.format(seed, i_sample[:10]))

        # remember what the model is trained on
        self.X_train, self.y_train = X_tilde, y_tilde
        self.g_train = g[i_sample]

        # fit the model on the drawn sample
        self.fit_super(X_tilde, y_tilde)

        # remember which observations were drawn
        self.i_sample = i_sample


class JTT_model(logistic_regression):
    """
    Two-stage model: an identifier model (or externally supplied
    predictions) marks the training observations that are misclassified,
    after which the predictor is fitted with per-group weights that
    up-weight those mistakes by lambda_JTT.

    Group coding used throughout this class (see get_group_for_JTT):
        1: y = 0, identifier correct
        2: y = 0, identifier wrong
        3: y = 1, identifier wrong
        4: y = 1, identifier correct
    """

    def __init__(self,  model_param_dict_identifier, model_param_dict_predictor,  p_weights=None, p_train=None, add_intercept=True, class_balanced_identifier=True, p_train_identifier=None, create_identifier=False):
        super().__init__(model_param_dict_predictor, p_weights, p_train, add_intercept)

        # if class_balanced_identifier is True, we set the weights of the identifier to be class balanced
        if class_balanced_identifier:
            p_weights_identifier = {0: 0.5, 1: 0.5}
        else:
            p_weights_identifier = p_train_identifier

        # set the identifier model
        if create_identifier:
            self.identifier = logistic_regression(model_param_dict_identifier, p_weights_identifier, p_train_identifier, add_intercept)
        else:
            self.identifier = None
            print('No identifier model created')

        # set the p_train
        self.p_train_model = p_train

    def get_p_dict(self, g):
        """
        Return a dict mapping each group label (as int) in g to its share
        of the observations.
        """

        # select all the groups
        groups = [int(group.item()) for group in torch.unique(g)]
        p_dict = dict.fromkeys(groups, 0)
        n = len(g)

        for group in groups:
            n_g = (g == group).sum()
            p_dict[group] = (n_g / n).item()

        return p_dict

    # fits the logistic regression model
    def fit_super(self, X, y, g):
        super().fit(X, y, g)

    # fits the identifier model
    def fit_identifier(self, X, y):
        """
        Fit the identifier model and store its coefficients in
        Beta_identifier.

        Raises:
            ValueError: if the model was built with create_identifier=False.
        """

        if self.identifier is None:
            raise ValueError(
                'No identifier model available: construct JTT_model with '
                'create_identifier=True or pass pred_train to fit().'
            )

        # fit the model, using the y as the group variable for potential class-balancing
        self.identifier.fit(X, y, g=y.squeeze(-1))

        # set the parameters
        self.Beta_identifier = self.identifier.Beta

    # get the predictions of the identifier
    def predict_identifier(self, X):
        """
        Return the rounded sigmoid predictions (0/1) of the identifier.
        """

        # add intercept if needed
        if self.add_intercept:
            X = torch.cat([torch.ones(X.shape[0], 1), X], dim=1)
            X = X.to(torch.float32)

        # get the predictions
        pred = torch.round(torch.sigmoid(torch.matmul(X, self.Beta_identifier)))

        return pred

    # get the group variable based on the identifier and y
    def get_group_for_JTT(self, y, mistakes):
        """
        Combine the label and the mistake indicator into one group variable.

        Returns a float tensor of shape (y.shape[0], 1) with
            1 if y=0 and mistake=0, 2 if y=0 and mistake=1,
            3 if y=1 and mistake=1, 4 if y=1 and mistake=0.
        Entries matching none of these combinations stay 0.
        """

        g = torch.zeros(y.shape[0], 1)
        g[(y == 0) & (mistakes == 0)] = 1
        g[(y == 0) & (mistakes == 1)] = 2
        g[(y == 1) & (mistakes == 1)] = 3
        g[(y == 1) & (mistakes == 0)] = 4

        return g

    def get_g_train_JTT(self, y_train, y_hat_class_train, batched=False):
        """
        Mark where the predicted class differs from the label and turn
        that into the JTT group variable.

        Parameters:
            y_train: labels
            y_hat_class_train: predicted classes, same shape as y_train
            batched: if True, compare in chunks of `interval` rows
        """

        if batched:
            # get the mistakes in batches; the buffer follows the shape of
            # y_train so 1-D and column-shaped labels both work
            mistakes = torch.zeros(y_train.shape, dtype=torch.float32)
            interval = 1000
            for start in range(0, y_train.shape[0], interval):
                stop = start + interval  # slicing clips the final chunk
                mistakes[start:stop] = (y_train[start:stop] != y_hat_class_train[start:stop]).to(torch.float32)
        else:
            mistakes = (y_train != y_hat_class_train).to(torch.float32)

        print('total observations: {}'.format(y_train.shape[0]))
        print('how many mistakes in (1) total: {}, (2) class y=0: {}, and (3) class y=1: {}'.format(torch.sum(mistakes), torch.sum(mistakes[y_train==0]), torch.sum(mistakes[y_train==1])))

        g_train_JTT = self.get_group_for_JTT(y_train, mistakes)
        return g_train_JTT

    def fit(self, X_train, y_train, lambda_JTT=1, p_y_JTT=0.5, pred_train=None):
        """
        Fit the predictor with JTT weights.

        Parameters:
            X_train: feature matrix
            y_train: binary labels
            lambda_JTT: multiplier applied to the weight of misclassified
                observations (groups 2 and 3)
            p_y_JTT: target share of class 1 after re-weighting
            pred_train: optional predictions to use instead of fitting the
                identifier; they are rounded to obtain classes
        """

        # check if pred_train is None
        if pred_train is None:
            # first, fit the identifier
            self.fit_identifier(X_train, y_train)

            # get the predictions where the identifier made a mistake
            y_hat_class = self.predict_identifier(X_train)
        else:
            y_hat_class = torch.round(pred_train)

        # define group variable via combinations of y and mistakes
        g_train_JTT = self.get_g_train_JTT(y_train, y_hat_class)
        self.g_train_JTT = g_train_JTT

        # show counts
        print('Division of groups in JTT model: {}'.format(np.unique(g_train_JTT, return_counts=True)))

        # class weights that move the share of class 1 to p_y_JTT
        p_y_1 = y_train.float().mean()
        weight_class_1 = p_y_JTT/p_y_1
        weight_class_0 = (1-p_y_JTT)/(1 - p_y_1)

        # per-group weights, following the coding of get_group_for_JTT:
        # groups 1 and 2 hold y=0, groups 3 and 4 hold y=1, and the
        # mistake groups (2 and 3) are additionally scaled by lambda_JTT
        weights_JTT = {
            1: weight_class_0,
            2: weight_class_0 * lambda_JTT,
            3: weight_class_1 * lambda_JTT,
            4: weight_class_1,
        }
        self.weights_JTT = weights_JTT

        # hand the weights to the weights object used by the parent fit
        self.weights_obj.weights = weights_JTT
        self.fit_super(X_train, y_train, g_train_JTT.squeeze(-1))



class DFR_model(logistic_regression_subsample):
    """
    Fits k_models models, each on its own group-based (sub)sample drawn
    with seed k, and averages their coefficients.
    """

    def __init__(self, model_param_dict, p_weights=None, p_train=None, add_intercept=True, k_models=10):
        super().__init__(model_param_dict, p_weights, p_train, add_intercept)
        self.k_models = k_models
        self.parallel_fit = model_param_dict['parallel_fit']

    def fit_super(self, X, y):
        return super().fit_super(X, y)

    def fit_single_model(self, k, X, y, g, subsample):
        """
        Draw the k-th sample (seed=k), fit on it and return the fitted
        coefficients (self.Beta).
        """

        # get the subsample
        if subsample:
            X_tilde_k, y_tilde_k, i_sample_k = self.get_subsample_groups(X, y, g, seed=k)
        else:
            X_tilde_k, y_tilde_k, i_sample_k = self.get_sample_groups(X, y, g, seed=k)

        # fit the model
        self.fit_super(X_tilde_k, y_tilde_k)

        # get the parameters of the model
        return self.Beta

    def fit(self, X, y, g, subsample=True):
        """
        Fit k_models models and store the average coefficients in self.Beta.

        Parameters:
            X: feature matrix
            y: target variable
            g: group membership of each observation
            subsample: forwarded to fit_single_model
        """

        if self.parallel_fit:
            print('Parallel execution')
            # Use multiprocessing for parallel execution
            with mp.Pool(processes=mp.cpu_count()) as pool:
                fit_func = partial(self.fit_single_model, X=X, y=y, g=g, subsample=subsample)
                results = pool.map(fit_func, range(self.k_models))

            # the workers run in other processes, so the sample size m
            # they computed never reaches this instance; recompute it here
            print('p_weights: {}'.format(self.p_weights))

            # per group: ceil(p_g * n_g), looked up by group label rather
            # than by list position so the labels need not be 1..K
            m = 0
            for group in torch.unique(g):
                n_g = int((g == group).sum())
                p_g = torch.as_tensor(self.p_weights[group.item()])
                m += torch.ceil(p_g * n_g).int().item()
            self.m = m

            betas = [torch.as_tensor(Beta_i, dtype=torch.float32).clone() for Beta_i in results]
        else:
            # Sequential execution; copy each result right away because
            # the next fit overwrites self.Beta
            betas = [
                torch.as_tensor(self.fit_single_model(k, X, y, g, subsample), dtype=torch.float32).clone()
                for k in range(self.k_models)
            ]

        # stack along a new leading axis; the coefficient shape is taken
        # from the fitted models instead of assuming an intercept column
        Beta = torch.stack(betas, dim=0)

        # average over the parameters
        Beta_avg = torch.mean(Beta, axis=0).to(torch.float32)

        # set the parameters
        self.Beta = Beta_avg


class least_squares(model):
    """
    Weighted least squares with a closed-form numpy fit, plus torch
    helpers for the (penalised) loss, its gradient, Hessian and a single
    Newton step.

    NOTE(review): the torch helpers read self.X_train, self.y_train,
    self.g_train, self.w_train, self.X_val, self.y_val and self.w_val;
    apart from X_train/y_train (optionally set by fit) these are not
    assigned in this class and are presumably set by the caller.
    """

    def __init__(self, model_param_dict, p_weights=None, p_train=None, add_intercept=True):
        super().__init__(weights(p_weights, p_train))

        self.p_weights = p_weights
        self.p_train = p_train
        self.add_intercept = add_intercept
        self.model_param_dict = model_param_dict
        self.l_1_penalty = 0.0
        self.l_2_penalty = 0.0

    def expected_bias_intercept(self, a_1, a_0, p_1_te, p_1_tilde):
        """
        Bias term: squared intercept gap, scaled per group by
        (1 - p_1_tilde)^2 and p_1_tilde^2 and mixed with the test shares.
        """
        # get the difference in the intercepts
        diff = (a_1 - a_0)**2

        # get the expected bias for group 1, group 0
        bias_group_1 = ((1 - p_1_tilde)**2) * diff
        bias_group_0 = (p_1_tilde**2) * diff

        # get the total bias
        p_0_te = 1 - p_1_te
        bias = p_1_te * bias_group_1 + p_0_te * bias_group_0

        return bias

    def expected_var_intercept(self, a_1, a_0, p_1_tilde, p_1_tr, n, sigma_1, sigma_0, d):
        """
        Variance term: ((d + 1) / n) * phi * sigma_tilde, with
        phi = p_1_tilde^2 / p_1_tr + (1 - p_1_tilde)^2 / (1 - p_1_tr).
        a_1 and a_0 are accepted but not used.
        """

        # define weight terms
        phi_1 = (p_1_tilde**2)/p_1_tr
        phi_0 = ((1 - p_1_tilde)**2)/(1 - p_1_tr)
        phi = phi_1 + phi_0

        # get the expected variance
        sigma_tilde = (sigma_0 * (1 - p_1_tilde)) + (sigma_1 * p_1_tilde)

        # get the variance
        factor = ((d + 1)/(n))*phi
        var = factor * (sigma_tilde)

        return var

    def expected_loss_intercept(self, a_1, a_0,  p_1_tilde,  p_1_te, p_1_tr, n, sigma_1, sigma_0, d):
        """
        Return (loss, bias, var) where loss = bias + var + noise and
        noise is the test-share weighted mix of sigma_0 and sigma_1.
        """

        bias = self.expected_bias_intercept(a_1, a_0, p_1_te, p_1_tilde)

        var = self.expected_var_intercept(a_1, a_0, p_1_tilde, p_1_tr, n, sigma_1, sigma_0, d)

        noise = (1 - p_1_te) * sigma_0 + p_1_te * sigma_1

        loss = bias + var + noise

        return loss, bias, var

    def return_params(self):
        # return the parameters of the OLS
        return (self.Beta)

    def fit(self, X, y, g=None, set_X_y_train=True):
        """
        Closed-form weighted least squares: solves (X'WX) Beta = X'Wy.

        Parameters:
            X: np.array, feature matrix (an intercept column is prepended
                when add_intercept is set)
            y: np.array, target variable
            g: optional group membership; when given, the observation
                weights come from weights_obj.get_weights_sample(g),
                otherwise all weights are one
            set_X_y_train: if True, store X (including the intercept
                column, if added) and y as X_train / y_train
        """

        # add intercept if needed
        if self.add_intercept:
            X = np.c_[np.ones(X.shape[0]), X]

        # observation weights
        if g is None:
            w = np.ones(X.shape[0])
        else:
            w = self.weights_obj.get_weights_sample(g).squeeze(-1)

        # X'W via column scaling; same values as X.T @ diag(w) without
        # building the (n, n) diagonal matrix
        X_t_W = np.multiply(X.T, np.asarray(w))

        # create X_t_W_X and X_t_W_y
        X_t_W_X = np.matmul(X_t_W, X)
        X_t_W_y = np.matmul(X_t_W, y)

        # solve the normal equations directly instead of inverting X_t_W_X
        Beta = np.linalg.solve(X_t_W_X, X_t_W_y)

        # set beta
        self.Beta = Beta

        # set X_train and y_train
        if set_X_y_train:
            self.X_train = X
            self.y_train = y

    def predict(self, X):
        if self.add_intercept:
            X = np.c_[np.ones(X.shape[0]), X]

        return np.matmul(X, self.Beta)

    def MSE(self, y_hat, y, w_loss=None):
        """
        Mean of the (optionally weighted) squared errors; works for numpy
        arrays and torch tensors.
        """

        # get the loss per observation
        loss = (y - y_hat)**2

        # get the weighted loss
        if w_loss is not None:
            loss = loss * w_loss

        # get the squared error - check if numpy or torch
        if isinstance(loss, np.ndarray):
            MSE = np.mean(loss)
        else:
            MSE = torch.mean(loss)

        return MSE

    def loss(self, y_hat, y, w_loss=None):
        MSE = self.MSE(y_hat, y, w_loss)
        return MSE

    def train_loss_via_param(self, Beta):
        """
        Weighted training loss as a function of Beta, plus the l1
        (smoothed) and l2 penalties when self.l_1_penalty /
        self.l_2_penalty are positive. Returns a scalar tensor.
        """

        # add intercept if needed
        if self.add_intercept:
            X_train = torch.cat([torch.ones(self.X_train.shape[0], 1), self.X_train], axis=1).to(torch.float32)
        else:
            X_train = self.X_train.to(torch.float32)

        # get the predictions based on the X_train and Beta
        y_hat = torch.matmul(X_train, Beta)

        # get the loss based on the y_hat and y_train
        train_loss = self.loss(y_hat, self.y_train, w_loss=self.w_train)

        # get the penalty
        l_1_penalty = self.l_1_penalty
        l_2_penalty = self.l_2_penalty

        # add penalty if needed
        if l_1_penalty > 0:
            # sqrt(beta^2 + eps) is a smooth approximation of |beta|
            eps = 1e-6
            l_1_norm_Beta = torch.sum(torch.sqrt(Beta**2 + eps))
            train_loss = train_loss + l_1_penalty * l_1_norm_Beta
        if l_2_penalty > 0:
            # sum of squares equals Beta'Beta but stays a scalar for both
            # 1-D and column-shaped Beta, so it can be added to the loss
            train_loss = train_loss + l_2_penalty * torch.sum(Beta**2)

        return train_loss

    def augmented_loss_via_param(self, Beta):
        """
        Return the loss gaps (group 1 - last group, group 2 - last group)
        on the training data, where the last group is the largest label
        in self.g_train.
        """

        # add intercept if needed
        if self.add_intercept:
            X_train = torch.cat([torch.ones(self.X_train.shape[0], 1), self.X_train], axis=1).to(torch.float32)
        else:
            X_train = self.X_train.to(torch.float32)

        # get the predictions based on the X_train and Beta
        y_hat = torch.matmul(X_train, Beta)

        # get the groups
        groups = torch.unique(self.g_train)

        # get the last group, defined as the group with the highest index
        last_group = max(groups)

        # loss gap of groups 1 and 2 relative to the last group
        augmented_loss_1 = self.augmented_loss(y_hat, self.y_train, self.g_train, 1, last_group)
        augmented_loss_2 = self.augmented_loss(y_hat, self.y_train, self.g_train, 2, last_group)

        return (augmented_loss_1, augmented_loss_2)

    def weighted_val_loss_via_param(self, Beta):
        """
        Weighted loss on the validation data (X_val, y_val, w_val) as a
        function of Beta.
        """

        # add intercept if needed
        if self.add_intercept:
            X_val = torch.cat([torch.ones(self.X_val.shape[0], 1), self.X_val], axis=1).to(torch.float32)
        else:
            X_val = self.X_val.to(torch.float32)

        # get the predictions based on the X_val and Beta
        y_hat = torch.matmul(X_val, Beta)

        # get the loss based on the y_hat and y_val
        print('unique w_val: {}'.format(torch.unique(self.w_val)))
        val_loss = self.loss(y_hat, self.y_val, w_loss=self.w_val)

        return val_loss

    def augmented_loss(self, y_hat, y, g,  selected_group, last_group):
        """
        Loss of the selected group minus the loss of the last group.
        """

        # get the predictions/labels for the selected group
        y_hat_selected = y_hat[g == selected_group]
        y_selected = y[g == selected_group]

        # get the predictions/labels for the last group
        y_hat_last = y_hat[g == last_group]
        y_last = y[g == last_group]

        # get the loss for the selected group, last group
        loss_selected = self.loss(y_hat_selected, y_selected)
        loss_last = self.loss(y_hat_last, y_last)

        # return loss difference
        return (loss_selected - loss_last)

    def calc_weighted_grad(self, X, beta, w,  y, l_1_penalty, l_2_penalty, eps=1e-6):
        """
        Gradient 2 * X'W(X beta - y) of the weighted squared error, plus
        the gradients of the l2 and (smoothed) l1 penalties when the
        respective penalty is positive. Both penalties are applied when
        both are set, matching train_loss_via_param.

        Parameters:
            X: feature matrix without intercept column
            beta: coefficients
            w: 1-D tensor of observation weights
            y: targets
            l_1_penalty, l_2_penalty: penalty strengths
            eps: smoothing constant of the l1 approximation
        """

        # add intercept if needed
        if self.add_intercept:
            X = torch.cat([torch.ones(X.shape[0], 1), X], axis=1).to(torch.float32)

        # X'W via column scaling (same values as X.T @ diag(w), no (n, n) matrix)
        X_T_w = X.T * w

        # calculate the X_T_w_y and X_T_w_X*beta
        X_T_w_y = torch.matmul(X_T_w, y)
        X_T_w_X_beta = torch.matmul(X_T_w, torch.matmul(X, beta))

        # calculate the gradient
        grad = (-X_T_w_y + X_T_w_X_beta)*2

        # add the l_2 penalty
        if l_2_penalty > 0:
            grad += 2 * l_2_penalty * beta

        if l_1_penalty > 0:
            # derivative of sqrt(beta^2 + eps); computed elementwise in the
            # shape of beta so it lines up with grad
            grad += l_1_penalty * (beta / torch.sqrt(beta**2 + eps))

        return grad

    def calc_weighted_Hessian(self, X, beta, w,  l_1_penalty, l_2_penalty, eps=1e-6):
        """
        Hessian 2 * X'WX of the weighted squared error, plus the diagonal
        contributions of the l2 and (smoothed) l1 penalties when the
        respective penalty is positive.
        """

        # add intercept if needed
        if self.add_intercept:
            X = torch.cat([torch.ones(X.shape[0], 1), X], axis=1).to(torch.float32)

        # calculate X_T_W_X without building diag(w)
        X_T_W_X = torch.matmul(X.T * w, X)
        H = 2*X_T_W_X

        # add the l_2 penalty
        if l_2_penalty > 0:
            H += torch.eye(H.shape[0])*l_2_penalty *2

        if l_1_penalty > 0:
            # second derivative of sqrt(beta^2 + eps) is eps / (beta^2 + eps)^(3/2);
            # reshape(-1) keeps a 1-D vector even for a single coefficient
            beta_squared = (beta**2).reshape(-1)
            H_l_1_approx = eps/((beta_squared + eps)**(3/2))
            H += torch.diag(H_l_1_approx)*l_1_penalty

        return H

    def calc_quadratic_solution(self, X, y, w, beta_0, l_1_penalty, l_2_penalty):
        """
        One Newton step from beta_0: beta = beta_0 - H^{-1} grad.

        Sets self.Beta to beta_0 and returns (beta, grad, H).
        """
        # set the initial beta
        self.Beta = beta_0

        # calculate the gradient as a column vector
        grad = self.calc_weighted_grad(X, beta_0, w,  y, l_1_penalty, l_2_penalty,).mean(-1).unsqueeze(-1)

        # calculate the Hessian
        H = self.calc_weighted_Hessian(X, beta_0, w, l_1_penalty, l_2_penalty)

        # solve H * step = grad instead of forming the inverse of H
        beta = beta_0 - torch.linalg.solve(H, grad)

        return beta, grad, H
    
        
        

    

class least_squares_subsample(least_squares):
    """
    Least squares fitted on a subsample that keeps every group-0
    observation and draws group-1 observations so that their share is
    (approximately) p_weights[1].
    """

    def __init__(self, model_param_dict, p_weights=None, p_train=None, add_intercept=True):
        # forward the arguments in the order the parent expects, so the
        # parent's weights object is built from (p_weights, p_train)
        super().__init__(model_param_dict, p_weights, p_train, add_intercept)

        self.p_weights = p_weights
        self.p_train = p_train
        self.add_intercept = add_intercept
        self.model_param_dict = model_param_dict

    def expected_bias_intercept(self, a_1, a_0, p_1_te, p_0_te, p_1_tilde):
        """
        Bias term: squared intercept gap, scaled per group by
        (1 - p_1_tilde)^2 and p_1_tilde^2 and mixed with the test shares.
        """
        # get the difference in the intercepts
        diff = (a_1 - a_0)**2

        # get the expected bias for group 1, group 0
        bias_group_1 = ((1 - p_1_tilde)**2) * diff
        bias_group_0 = (p_1_tilde**2) * diff

        # get the total bias
        bias = p_1_te * bias_group_1 + p_0_te * bias_group_0

        return bias

    def expected_var_intercept(self, a_1, a_0, p_1_tilde, n_tilde, sigma_1, sigma_0, d):
        """
        Variance term: ((d + 1) / n_tilde) * (sigma_tilde + intercept-gap
        variance), with the gap variance (a_1 - a_0)^2 * p (1 - p).
        """
        # get the expected variance
        sigma_tilde = ((sigma_0 * (1 - p_1_tilde)) + (sigma_1 * p_1_tilde))
        var_from_intercept_diff = ((a_1 - a_0)**2) * (p_1_tilde * (1 - p_1_tilde))
        factor = (d + 1)/(n_tilde)
        var = factor * (sigma_tilde + var_from_intercept_diff)

        return var

    def expected_loss_subsample_intercept(self, a_1, a_0,  p_1_te, p_0_te, p_1_tilde, n_tilde, sigma_1, sigma_0, d):
        """
        Return (loss, bias, var) where loss = bias + var + noise.
        """

        bias = self.expected_bias_intercept(a_1, a_0, p_1_te, p_0_te, p_1_tilde)

        var = self.expected_var_intercept(a_1, a_0, p_1_tilde, n_tilde, sigma_1, sigma_0, d)

        noise = p_0_te * sigma_0 + p_1_te * sigma_1

        loss = bias + var + noise

        return loss, bias, var

    def get_subsample_two_groups(self, X, y, g, seed):
        """
        Build the subsample: all observations with g == 0 plus a random
        draw (without replacement) of observations with g == 1.

        Returns:
            X_tilde, y_tilde, i_sample (group-0 rows first)

        Side effects:
            sets n_tilde, n_1_tilde, n_0_tilde and X_tilde.
        """

        # set the seed
        set_seed(seed)

        # first, take all the observations with g = 0
        i_0 = np.where(g == 0)[0]
        X_0 = X[i_0, :]
        n_0 = X_0.shape[0]

        # probability of sampling group 1 observations
        p_1_tilde = self.p_weights[1]

        # target size: n_0 must make up a share (1 - p_1_tilde) of it
        n_tilde = int(np.round(n_0/(1 - p_1_tilde)))
        self.n_tilde = n_tilde
        n_1_tilde = n_tilde - n_0
        self.n_1_tilde = n_1_tilde
        self.n_0_tilde = n_0

        # where to take the sample from?
        i_1 = np.where(g == 1)[0]
        i_sample_0 = i_0

        # pick, from i_1, n_1_tilde observations
        # NOTE(review): when group 1 is too small all of it is used, but
        # n_tilde / n_1_tilde keep the target sizes - confirm this is intended
        if n_1_tilde >= len(i_1):
            i_sample_1 = i_1
        else:
            i_sample_1 = np.random.choice(i_1, int(n_1_tilde), replace=False)

        # rows of the two groups
        X_0_tilde = X_0
        X_1_tilde = X[i_sample_1, :]

        # combine the two samples indeces
        i_sample = np.concatenate([i_sample_0, i_sample_1])

        # combine the two samples
        X_tilde = np.concatenate((X_0_tilde, X_1_tilde), axis=0)

        self.X_tilde = X_tilde

        # create Y_tilde
        y_tilde = y[i_sample]

        return X_tilde, y_tilde, i_sample

    def fit_subsample(self, X, y, g, seed=0,  set_X_y_train=True):
        """
        Draw the two-group subsample, fit (unweighted) on it and store the
        selected row indices in i_sample.
        """

        # get the subsample
        X_tilde, y_tilde, i_sample = self.get_subsample_two_groups(X, y, g, seed)

        # fit the model
        self.fit(X_tilde, y_tilde, set_X_y_train=set_X_y_train)

        # set the sample
        self.i_sample = i_sample


       




class resnet_model(nn.Module):
    """
    Torchvision ResNet whose final fully connected layer is replaced by a
    linear layer with y_dim outputs.

    Parameters:
        base_model: str
            One of 'resnet18', 'resnet34', 'resnet50'.
        y_dim: int
            Number of outputs of the new final layer.
        pretrained: bool
            Forwarded to the torchvision constructor.

    Raises:
        ValueError: if base_model is not one of the supported names.
    """

    # names of the torchvision constructors this wrapper supports
    _SUPPORTED_BASE_MODELS = ('resnet18', 'resnet34', 'resnet50')

    def __init__(self, base_model, y_dim, pretrained=True):
        super(resnet_model, self).__init__()

        # fail early: without a backbone the module could be constructed
        # but would only break later, inside forward
        if base_model not in self._SUPPORTED_BASE_MODELS:
            raise ValueError(
                'Unsupported base_model {!r}; expected one of {}'.format(
                    base_model, self._SUPPORTED_BASE_MODELS)
            )

        # the supported names are also the torchvision constructor names
        build_backbone = getattr(models, base_model)
        self.base_model = build_backbone(pretrained=pretrained)

        # change model.fc to a linear layer with y_dim outputs
        num_features = self.base_model.fc.in_features
        self.base_model.fc = nn.Linear(num_features, y_dim)

    def forward(self, x):

        # pass through the backbone, ending in the new linear layer
        y_pred = self.base_model(x)

        return y_pred


class embedding_creator_resnet(nn.Module):
    """
    Wraps a network so that it outputs embeddings: every child module
    except the last one is applied, and the result is flattened per sample.
    """

    def __init__(self, original_resnet, prevent_finetuning = True):
        super().__init__()

        # optionally freeze all parameters of the wrapped network, so these
        # layers are not trained subsequently
        if prevent_finetuning:
            for weight in original_resnet.parameters():
                weight.requires_grad_(False)

        # keep every child module except the final one
        all_children = list(original_resnet.children())
        self.list_feature_modules = all_children[:-1]

        # chain the kept modules into the feature extractor
        self.feature_extractor = nn.Sequential(*self.list_feature_modules)

        # collapses everything after the batch axis
        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)

    def forward(self, x):

        # extract the features and flatten them to one vector per sample
        return self.flatten(self.feature_extractor(x))


def rename_fc(model):
    """
    Move model.classifier to model.fc and re-expose `classifier` as a
    bound method that forwards its argument to model.fc.

    Returns the same (modified) model.
    """
    # move the head over to the new attribute name
    head = model.classifier
    model.fc = head
    del model.classifier

    # keep `model.classifier(x)` working by delegating to fc
    def classifier(self, x):
        return self.fc(x)

    model.classifier = types.MethodType(classifier, model)
    return model



def bert_forward_for_classification(model):
    """
    Patch model.forward so it accepts a single tensor x and returns logits.

    The original forward stays available as model.base_forward. The new
    forward passes x[:, :, 0], x[:, :, 1] and x[:, :, 2] as input_ids,
    attention_mask and token_type_ids and returns the `.logits` of the
    result. Returns the same (modified) model.
    """

    def forward(self, x):
        # split the third axis of x into the three keyword inputs
        field_names = ('input_ids', 'attention_mask', 'token_type_ids')
        inputs = {name: x[:, :, k] for k, name in enumerate(field_names)}
        outputs = self.base_forward(**inputs)
        return outputs.logits

    # keep the original forward reachable, then install the wrapper
    model.base_forward = model.forward
    model.forward = types.MethodType(forward, model)
    return model


def bert_forward_for_embeddings(model):
    """
    Patch model.forward so it accepts a single tensor x and returns the
    output of model.bert with hidden states included.

    The original forward stays available as model.base_forward. The new
    forward passes x[:, :, 0], x[:, :, 1] and x[:, :, 2] as input_ids,
    attention_mask and token_type_ids to model.bert, requesting
    output_hidden_states and return_dict. Returns the same (modified) model.
    """

    def forward(self, x):
        # split the third axis of x into the three keyword inputs
        field_names = ('input_ids', 'attention_mask', 'token_type_ids')
        inputs = {name: x[:, :, k] for k, name in enumerate(field_names)}
        return self.bert(**inputs, output_hidden_states=True, return_dict=True)

    # keep the original forward reachable, then install the wrapper
    model.base_forward = model.forward
    model.forward = types.MethodType(forward, model)
    return model


    
class BERT_model(nn.Module):
    """
    Wrapper around a pretrained 'bert-base-uncased' sequence classifier
    whose forward takes a single stacked input tensor.

    With replace_fc the wrapped model returns logits; otherwise, with
    embedding_creator_forward, it returns the output of the inner BERT
    encoder. If neither flag is set no model is loaded.
    """

    def __init__(self, output_dim,  output_attentions = False,  output_hidden_states = False, replace_fc=True, embedding_creator_forward=False):
        super().__init__()

        # decide which forward patch to apply; replace_fc takes precedence
        if replace_fc:
            patch_forward = bert_forward_for_classification
        elif embedding_creator_forward:
            patch_forward = bert_forward_for_embeddings
        else:
            # neither mode requested: leave the wrapper without a model
            return

        # load the pretrained classifier and install the chosen forward
        pretrained = BertForSequenceClassification.from_pretrained(
            'bert-base-uncased',
            num_labels=output_dim,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )
        self.model = patch_forward(pretrained)

    def forward(self, x):
        return self.model(x)
    
    

class embedding_creator_BERT(nn.Module):
    """
    Create version of the model, specifically for creating embeddings.

    Parameters:
        original_BERT: module used as feature extractor; all of its
            parameters are frozen (requires_grad = False)
        embedding_type: 'CLS' to take position 0 of the last entry of
            output.hidden_states, or 'pool' to take output[1]
    """

    def __init__(self, original_BERT, embedding_type):
        super(embedding_creator_BERT, self).__init__()

        # freeze the wrapped model so its layers are not trained subsequently
        for param in original_BERT.parameters():
            param.requires_grad = False

        # define the feature extractor
        self.feature_extractor = original_BERT

        # define the flatten layer
        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
        self.embedding_type = embedding_type

    def forward(self, x):
        """
        Return the detached, flattened embedding for x.

        Raises:
            ValueError: if embedding_type is neither 'CLS' nor 'pool'.
        """

        # get the raw output of the wrapped model
        output = self.feature_extractor(x)

        if self.embedding_type == 'CLS':
            # first position of the last entry of hidden_states
            embedding = output.hidden_states[-1][:, 0, :].detach()
        elif self.embedding_type == 'pool':
            # second element of the output
            embedding = output[1].detach()
        else:
            # previously this fell through to an UnboundLocalError
            raise ValueError(
                "Unknown embedding_type {!r}; expected 'CLS' or 'pool'".format(self.embedding_type)
            )

        # flatten the embedding
        embedding_flat = self.flatten(embedding)

        return embedding_flat



def get_predictions_in_batches(model, loader):
    """
    Run the model over all batches of the loader (in eval mode, without
    gradients) and return the sigmoid of its outputs, concatenated along
    the first axis.

    Parameters:
        model: torch module
        loader: sized iterable of batches; the input tensor is element 0
            of each batch

    Returns:
        tensor with the predictions of all batches, in loader order.
        Batches may have different sizes.
    """

    # total steps to take (only used for the progress message)
    total_step = len(loader)

    model = model.eval()   # Set model to evaluate mode
    list_pred_batches = []

    # go over each batch
    with torch.no_grad():
        for i, (images) in enumerate(loader):

            # input batch
            b_x = images[0]

            # if the first dimension has size 1, remove it
            # NOTE(review): presumably for loaders that wrap an already
            # batched tensor; a genuine single-sample batch also loses
            # its batch axis here - confirm against the callers
            if b_x.shape[0] == 1:
                b_x = b_x.squeeze(0)

            # get the predicted probabilities
            list_pred_batches.append(torch.sigmoid(model(b_x)))

            print('Step {}/{}'.format(i+1, total_step))

    # a single batch is returned as is
    if len(list_pred_batches) == 1:
        return list_pred_batches[0]

    # concatenate along the first axis; unlike stacking, this does not
    # require the batches to share the same size
    return torch.cat(list_pred_batches, dim=0)



def get_embedding_in_batches_tokens(model, loader, device, save_dir='./temp_embeddings', classifier=None):
    """
    Compute embeddings batch by batch, spilling every batch to a .npy file
    in save_dir, then load all files back and concatenate them.

    Parameters:
        model: torch module producing the embedding for a batch
        loader: iterable of (x, y) batches
        device: device the inputs are moved to
        save_dir: directory for the temporary files. It is created if
            needed; afterwards only the files written here are deleted and
            the directory itself is removed only if it is empty, so
            unrelated content of an existing directory is left alone.
        classifier: optional module; if given, its rounded sigmoid output
            on each embedding is compared with y and the accuracy printed

    Returns:
        tensor with the embeddings of all batches, in loader order.
    """
    model = model.eval()
    os.makedirs(save_dir, exist_ok=True)

    # works for 'mps', 'mps:0' and torch.device objects alike
    on_mps = str(device).startswith('mps')

    # files to load
    file_paths = []

    try:
        with torch.no_grad():
            for i, (x, y) in enumerate(loader):
                x = x.to(device)
                embedding = model(x)

                # check: get pred from classifier
                if classifier is not None:
                    pred_class = torch.round(torch.sigmoid(classifier(embedding)).cpu()).squeeze(-1)
                    print(pred_class)
                    print(y)
                    acc = (pred_class == y.cpu()).sum().item()/y.shape[0]
                    print('Accuracy: {}'.format(acc))

                # Save embedding to disk
                file_path = os.path.join(save_dir, f'embedding_batch_{i+1}.npy')
                file_paths.append(file_path)
                np.save(file_path, embedding.cpu().numpy())

                print(f'Processed and saved batch {i+1}/{len(loader)}')

                # release the batch before processing the next one
                del embedding, x
                if on_mps:
                    torch.mps.empty_cache()
                    torch.mps.synchronize()
                gc.collect()

        # Load all embeddings from disk
        all_embeddings = [torch.from_numpy(np.load(file_path)) for file_path in file_paths]
    finally:
        # clean up our temporary files, also when something went wrong
        for file_path in file_paths:
            if os.path.exists(file_path):
                os.remove(file_path)

        # remove the directory only when nothing else lives in it
        try:
            os.rmdir(save_dir)
        except OSError:
            pass

    return torch.cat(all_embeddings, dim=0)


def get_embedding_in_batches_images( model, loader, to_float32):
    """
    Run the model over all batches of the loader (in eval mode, without
    gradients) and return its outputs concatenated along the first axis.

    Parameters:
        model: torch module producing the embedding for a batch
        loader: sized iterable of batches; the input tensor is element 0
            of each batch
        to_float32: if truthy, cast each input batch to torch.float32

    Returns:
        tensor with the embeddings of all batches, in loader order.
        Batches may have different sizes.
    """

    # total steps to take
    total_step = len(loader)
    print('Total steps: {}'.format(total_step))

    # set model to evaluation mode
    model = model.eval()
    list_embedding_batches = []

    with torch.no_grad():

        # go over each batch
        for i, (images) in enumerate(loader):

            # input batch
            b_x = images[0]

            # if the first dimension has size 1, remove it
            if b_x.shape[0] == 1:
                b_x = b_x.squeeze(0)

            # convert to float32 if needed
            if to_float32:
                b_x = b_x.to(torch.float32)

            # check: if the input is 3d, add a leading dimension
            # (this also restores the axis removed above for a single
            # 3d sample)
            if len(b_x.shape) == 3:
                b_x = b_x.unsqueeze(0)

            # get and keep the embedding
            list_embedding_batches.append(model(b_x))

            print('Step {}/{}'.format(i+1, total_step))

    # a single batch is returned as is
    if len(list_embedding_batches) == 1:
        return list_embedding_batches[0]

    # concatenate along the first axis; unlike stacking, this does not
    # require the batches to share the same size
    return torch.cat(list_embedding_batches, dim=0)

    





    
    
       

    
    
        
