
import torch
import sys
from weights import weights
import copy
from torch.autograd.functional import hessian, jacobian
import numpy as np
import time
from helpers import fast_xtdx, set_seed

class weight_searcher():
    # Constructor arguments: model_class, p_train, X_train, y_train, g_train, X_val, y_val, g_val, weights_val_obj
    def __init__(self, model_class, p_train, X_train, y_train, g_train, X_val, y_val, g_val, weights_val_obj, grad_type='finite_diff', weight_rounding=6, GDRO=False, **model_param):

        # set the attributes
        self.model_class = model_class
        self.p_train = p_train
        self.X_train = X_train
        self.y_train = y_train
        self.g_train = g_train
        self.X_val = X_val
        self.y_val = y_val
        self.g_val = g_val
        self.grad_type = grad_type
        self.weight_rounding = weight_rounding
        self.weights_val_obj = weights_val_obj
        self.weights_val = self.weights_val_obj.get_weights_sample(self.g_val)
        self.GDRO = GDRO

        # take the smallest group, calculate number of samples
        min_group = min(self.p_train, key=self.p_train.get)
        n_dict = {g: torch.sum(g_train == g) for g in torch.unique(g_train)}
        n_dict[min_group] = torch.sum(g_train == min_group)
        self.n_dict = n_dict

        # round the weights to the specified number of decimal places for p_train
        self.p_train = {g: self.p_train[g] for g in self.p_train.keys()}

        # set the number of groups
        self.n_groups = torch.unique(g_train).shape[0]

        # here, the gradient is saved - 0 for group 0, 1 for group 1, etc.
        self.current_grad = torch.zeros(self.n_groups)

        # save dict of model param
        self.model_param = model_param

        # save the penalty type and l_1, l_2 penalty
        self.penalty_type = model_param['penalty_type']
        self.penalty_strength = model_param['penalty_strength']
        self.parallel_fit = model_param['parallel_fit']
    
       
         # check; if penalty_type is l1, set l_1_penalty to penalty_strength
        if self.penalty_type == 'l1':
            self.l_1_penalty = torch.tensor(float(self.penalty_strength))
        else:
            self.l_1_penalty = torch.tensor(0.)
        
        # check; if penalty_type is l2, set l_2_penalty to penalty_strength
        if self.penalty_type == 'l2':
            self.l_2_penalty = torch.tensor(float(self.penalty_strength))
        else:
            self.l_2_penalty = torch.tensor(0.)
        


    def reset_to_smallest_group_last(self, p_train, p_weights, subsample_weights=True):
        """
        Relabel the groups so the smallest group comes last, then rebuild the
        validation weighting to match.

        The relabelling is delegated to ``reset_p_dict_smallest_last`` and
        ``reset_g_smallest_last`` of the current ``self.weights_val_obj``; what
        "smallest" means is defined there (presumably smallest training
        probability — confirm).

        Parameters:
            p_train: dict
                Training probability per group.
            p_weights: dict
                Target probability per group.
            subsample_weights: bool
                Not used by this method.

        Returns:
            (p_train, p_weights, g_train, g_val) after relabelling. Also replaces
            self.p_train, self.g_train, self.g_val, self.weights_val_obj and
            self.weights_val.
        """
        print('p_train is {}'.format(p_train))

        relabeler = self.weights_val_obj

        # relabel the probability dictionaries and keep the new training probabilities
        p_train, p_weights = relabeler.reset_p_dict_smallest_last(p_train, p_weights)
        self.p_train = p_train

        # relabel the group memberships of both splits with the old weights object
        relabeled_train = relabeler.reset_g_smallest_last(self.g_train)
        relabeled_val = relabeler.reset_g_smallest_last(self.g_val)
        self.g_train, self.g_val = relabeled_train, relabeled_val

        # build a fresh weights object from the relabelled probabilities
        self.weights_val_obj = weights(p_weights=p_weights, p_train=p_train)
        self.weights_val_obj.set_weights_per_group()
        self.weights_val = self.weights_val_obj.get_weights_sample(self.g_val)

        return p_train, p_weights, relabeled_train, relabeled_val



    

    def round_p_dict(self, p):

        # round each entry in p to the specified number of decimal places
        p = {g: torch.round(p[g], decimals=self.weight_rounding) for g in p.keys()}

        return p

    def sum_to_1_p_dict(self, p):

        # normalize each entry in p to sum to 1
        p_sum = sum(p.values())
        p = {g: p[g] / p_sum for g in p.keys()}

        return p
    
    def get_max_p_dict(self, p):
        """Return the largest value stored in ``p`` (raises ValueError if ``p`` is empty)."""
        values = list(p.values())
        return max(values)
    
    def clip_p_dict(self, p, min_p=0.0, max_p=1.0):
        
        # convert min_p and max_p to tensors
        min_p_tensor = torch.as_tensor(min_p)
        max_p_tensor = torch.as_tensor(max_p)

        # clip each entry in p to be higher than min_p, lower than max_p
        p = {g: min(max_p_tensor, max(min_p_tensor, p[g])) for g in p.keys()}

        return p
    
    def clip_p_dict_per_group(self, p, p_min, p_max):

        # check; if p_min is a float, convert to tensor and apply to all groups
        if type(p_min) == float:
            p_min = {g: torch.tensor(p_min) for g in p.keys()}
        
        # check; if p_max is a float, convert to tensor and apply to all groups
        if type(p_max) == float:
            p_max = {g: torch.tensor(p_max) for g in p.keys()}

        # clip each entry in p to be higher than min_p, lower than max_p
        p = {g: min(p_max[g], max(p_min[g], p[g])) for g in p.keys()}

        return p
    
    

    def tuple_of_tensors_to_tensor(self, tuple_of_tensors):
        return  torch.stack(list(tuple_of_tensors), dim=0)
    

    


    def sigmoid_x_beta(self, x, beta):

        # add intercept to x
        x = torch.cat((torch.ones(x.shape[0], 1), x), dim=1).to(torch.float32)
  
        # calculate the sigmoid
        return 1.0 / (1.0 + torch.exp(-torch.matmul(x, beta)))
    
    def calc_grad_logistic_loss(self, X, beta, w,  y, l_1_penalty, l_2_penalty, eps=1e-6):
        """
        Calculate the gradient of the logistic loss function
        """

        # calculate the sigmoid
        sigmoid = self.sigmoid_x_beta(X, beta).squeeze()

        # add the intercept to X
        X = torch.cat((torch.ones(X.shape[0], 1), X), dim=1).to(torch.float32)

        # calculate the gradient in two steps:
        # 1. first term: w(sigmoid - y)
        # 2. second term: multiply X^T with the first term
        grad = torch.matmul(X.T, w*(sigmoid - y))

        # divide by the number of samples
        grad /= X.shape[0]

        # add the l_2 penalty
        if l_2_penalty>0:
            grad += 2 * l_2_penalty * beta

        elif l_1_penalty>0:
            # the following is an approximation of the derivative of the l_1 penalty
            beta_squared = (beta**2).squeeze()
            sqrt_beta_squared = torch.sqrt(beta_squared + eps)


            grad += ((beta / sqrt_beta_squared) * l_1_penalty)

        return grad
    


    
    
    def calc_Hessian_weighted_logistic_loss(self, X,  w, beta, l_1_penalty, l_2_penalty, eps=1e-6, divide_by_n=True):
        """
        Calculate the Hessian of the weighted logistic loss with respect to beta.

        The data term is computed as ``fast_xtdx(X_aug, d)`` with
        d_i = w_i * sigmoid_i * (1 - sigmoid_i) and X_aug = [1, X]
        (presumably X_aug^T diag(d) X_aug — confirm against helpers.fast_xtdx).
        At most one penalty curvature is added: 2 * l_2 * I if l_2_penalty > 0,
        otherwise l_1 * diag(eps / (beta^2 + eps)^(3/2)), the second derivative
        of the smooth approximation l_1 * sqrt(beta^2 + eps).

        Parameters:
            X: torch.Tensor
                Features without an intercept column.
            w: torch.Tensor
                Per-sample weights.
            beta: torch.Tensor
                Parameters (intercept first).
            l_1_penalty, l_2_penalty:
                Penalty strengths.
            eps: float
                Smoothing constant of the l_1 approximation.
            divide_by_n: bool
                If True, both the data term and the penalty term are divided by n.

        Returns:
            torch.Tensor: the Hessian matrix.
        """
        # per-sample curvature sigmoid * (1 - sigmoid), scaled by the sample weights
        sigmoid = self.sigmoid_x_beta(X, beta).squeeze()
        diag_H = ((sigmoid * (1 - sigmoid)) * w).to(torch.float32)

        # add the intercept to X
        X = torch.cat((torch.ones(X.shape[0], 1), X), dim=1).to(torch.float32)
        n = X.shape[0]

        # data term
        H = fast_xtdx(X, diag_H)
        if divide_by_n:
            H /= n

        # penalty curvature (at most one penalty is active)
        # NOTE(review): calc_grad_logistic_loss adds 2 * l_2 * beta without dividing
        # by n, while here the penalty is divided by n when divide_by_n=True;
        # confirm which scaling matches the model's training loss.
        added_term = None
        if l_2_penalty > 0:
            added_term = torch.eye(H.shape[0]) * l_2_penalty * 2

        elif l_1_penalty > 0:
            # reshape(-1) instead of squeeze so a single parameter still yields a 1x1 diag
            beta_squared = (beta ** 2).reshape(-1)
            H_l_1_approx = eps / ((beta_squared + eps) ** (3 / 2))
            # unscaled here; divided by n below exactly like the l_2 term
            # (it used to be divided by n twice when divide_by_n=True)
            added_term = torch.diag(H_l_1_approx) * l_1_penalty

        if added_term is not None:
            if divide_by_n:
                H += added_term / n
            else:
                H += added_term

        return H
        
    
    
    

    def weight_grad_via_ift(self, model, p, X_train, y_train, g_train, X_val, y_val, g_val, weights_val_obj, eps=1e-6,  analytical_hessian=False, subsample_weights=False):
        """
        Gradient of the weighted validation loss with respect to the group
        weights, via the implicit function theorem (IFT).

        With H the Hessian of the training loss at ``model.Beta`` and J_aug the
        Jacobian (w.r.t. Beta) of ``model.augmented_loss_via_param`` (or
        ``augmented_loss_subsample_via_param``), the parameter sensitivity is
        -H^{-1} J_aug. It is combined with J_val, the Jacobian of
        ``model.weighted_val_loss_via_param``, to give the validation gradient.

        Parameters:
            model:
                A fitted model; its data/weight attributes are overwritten here.
            p: dict
                Current probability per group; only its keys are used.
            X_train, y_train, g_train, X_val, y_val, g_val: torch.Tensor
                Training and validation data.
            weights_val_obj:
                Provides ``get_weights_sample(g_val)`` for the validation weights.
            eps: float
                Ridge term added to the Hessian diagonal before inversion.
            analytical_hessian: bool
                Use calc_Hessian_weighted_logistic_loss instead of autograd.
            subsample_weights: bool
                Use the subsample variant of the augmented loss and rescale H by n / model.m.

        Returns:
            dict: group -> gradient tensor. Entries are read at index ``g - 1``,
            which assumes groups are labelled 1..K. Without subsampling, the
            last key of ``p`` gets minus the sum of the other entries
            (presumably because the probabilities sum to one).
        """
        groups = list(p.keys())
        # the last key of p is the group whose gradient is implied by the others;
        # it must be the key excluded by groups[:-1] below
        last_group = groups[-1]

        # set the data on the model so its *_via_param losses use it
        model.X_train = X_train
        model.y_train = y_train
        model.g_train = g_train

        model.X_val = X_val
        model.y_val = y_val
        model.g_val = g_val

        # set the w_train and w_val
        w_train = torch.as_tensor(model.weights_obj.get_weights_sample(g_train))
        w_val = torch.as_tensor(weights_val_obj.get_weights_sample(g_val))
        model.w_train = w_train
        model.w_val = w_val

        # with subsampled weights, remember the model's m (used to rescale H)
        if subsample_weights:
            self.m = model.m
            print('m is {}'.format(self.m))

        # first, calculate the Hessian
        if analytical_hessian:
            H = self.calc_Hessian_weighted_logistic_loss(model.X_train, w_train, model.Beta, self.l_1_penalty, self.l_2_penalty, eps=1e-6)

            # use multiplication factor if subsample_weights
            if subsample_weights:
                factor = model.X_train.shape[0] / self.m
                H *= factor
        else:
            H = hessian(model.train_loss_via_param, model.Beta)

            # a column-vector Beta gives a (d, 1, d, 1) Hessian; squeeze it to (d, d)
            if model.Beta.shape[-1] == 1:
                H = H.squeeze()

        # ridge term to improve conditioning before inversion
        H += torch.eye(H.shape[0]) * eps

        # with more parameters than training samples the data term is rank
        # deficient, so use the Moore-Penrose pseudo-inverse. (H is square, so
        # the former H.shape[0] > H.shape[1] check could never be true.)
        if H.shape[0] > X_train.shape[0]:
            H_inv = torch.pinverse(H).detach()
        else:
            H_inv = torch.inverse(H).detach()

        # second, the jacobian of the augmented loss with respect to Beta
        if subsample_weights:
            J_augmented_w = jacobian(model.augmented_loss_subsample_via_param, model.Beta)
        else:
            J_augmented_w = jacobian(model.augmented_loss_via_param, model.Beta)

        # jacobian returns a tuple when the function has several outputs; stack them
        if isinstance(J_augmented_w, tuple):
            J_augmented_w = self.tuple_of_tensors_to_tensor(J_augmented_w).detach()
        else:
            J_augmented_w = J_augmented_w.detach()

        # third, the jacobian of the weighted validation loss with respect to Beta
        J_val_w = jacobian(model.weighted_val_loss_via_param, model.Beta)
        J_val_w = J_val_w.detach()

        # if the first dim is 1, squeeze - TODO: check if right
        if J_augmented_w.shape[0] == 1:
            J_augmented_w = J_augmented_w.squeeze(0)

        # derivative of the parameters with respect to the weights: -H^{-1} J_aug
        partial_deriv_param_w = torch.matmul(-H_inv, J_augmented_w).squeeze(-1)

        # chain rule: derivative of the validation loss with respect to the weights
        grad_ift = torch.matmul(J_val_w.T, partial_deriv_param_w.T).T

        if subsample_weights:
            # one entry per group
            grad_ift = {g: torch.as_tensor(grad_ift[g - 1]) for g in groups}
        else:
            # entries for all but the last group
            grad_ift = {g: torch.as_tensor(grad_ift[g - 1]) for g in groups[:-1]}

            # the last group's change is minus the sum of the others
            grad_last_group = -torch.sum(torch.as_tensor([grad_ift[g] for g in groups[:-1]]))
            grad_ift[last_group] = grad_last_group.reshape((1,))

        return grad_ift
    
    def calc_worst_group_loss_at_p(self, p):
        """
        Fit a fresh model at the group probabilities ``p`` and summarise its
        per-group validation performance.

        Parameters:
            p: dict
                The probability for each group. Each key is an integer representing the group and each value is a float representing the weight for that group.

        Returns:
            (worst_loss, loss_per_group, worst_acc): the largest per-group loss,
            the mapping returned by ``group_loss(..., acc=False)``, and the
            smallest per-group accuracy.
        """
        fitted = self.model_class(self.model_param, p_weights=p, p_train=self.p_train)
        fitted.fit(self.X_train, self.y_train, self.g_train)

        val_data = (self.X_val, self.y_val, self.g_val)
        group_losses = fitted.group_loss(*val_data, acc=False)
        group_accs = fitted.group_loss(*val_data, acc=True)

        return max(group_losses.values()), group_losses, min(group_accs.values())
    




    def calc_loss_for_model(self, model):
        """
        Weighted validation loss of an already-fitted model.

        Parameters:
            model: model
                A fitted model exposing ``predict`` and ``loss``.

        Returns:
            The result of ``model.loss`` on the validation predictions, weighted
            by ``self.weights_val``.
        """
        val_predictions = model.predict(self.X_val)
        return model.loss(val_predictions, self.y_val, w_loss=self.weights_val)

    def calc_loss_at_p(self, p):
        """
        Fit a fresh model at the group probabilities ``p`` and return its
        weighted validation loss.

        Parameters:
            p: dict
                The probability for each group. Each key is an integer representing the group and each value is a float representing the weight for that group.
        """
        candidate = self.model_class(self.model_param, p_weights=p, p_train=self.p_train)
        candidate.fit(self.X_train, self.y_train, self.g_train)

        val_predictions = candidate.predict(self.X_val)
        return candidate.loss(val_predictions, self.y_val, w_loss=self.weights_val)
    


        
    def exp_gradient_descent(self, start_p, lr,  eps, patience, T=100, save_trajectory=False, diff=0.01, momentum=None, verbose=True,  gradient_clip=None, eta=0.1, lr_schedule='constant', decay=0.9, stable_exp=True, p_min=10e-4, analytical_hessian=False, subsample_weights=False, normalize=True, lock_in_p_g = None, use_acc=False, seed=0):
        """
        Search for group probabilities with exponentiated gradient descent.

        Each step refits the model at the current probabilities and evaluates the
        weighted validation loss (or, with GDRO, the worst-group criterion). It
        then computes the gradient with respect to the probabilities and applies
        p_g <- p_g * exp(-lr_t * grad_g), followed by rounding, clipping to
        [p_min, 1] and, optionally, renormalisation.

        Parameters:
            start_p: dict
                The probability for each group. Each key is an integer representing the group and each value is a float representing the weight for that group.
            lr: float
                The learning rate for the gradient descent
            eps: float
                Minimum decrease of the (non-GDRO) validation loss that resets the patience counter.
            patience: int
                Number of consecutive steps without improvement before stopping.
            T: int
                The maximum number of iterations for the gradient descent
            save_trajectory: bool
                If True, also return the trajectory of probabilities and losses.
            diff: float
                The difference used to calculate the finite difference
            momentum: float or None
                If given, each update is (1 - momentum) * grad + momentum * previous update.
            verbose: bool
                Print progress after every step.
            gradient_clip: float or None
                If given, rescale the gradient so its norm is at most this value.
            eta: float
                Step size passed to ``update_DRO_weights`` when GDRO is enabled.
            lr_schedule: str
                'constant' (lr), 'exponential' (lr * exp(-decay * t)) or 'inverse' (lr / (1 + decay * t)).
            decay: float
                Decay rate of the non-constant learning-rate schedules.
            stable_exp: bool
                Subtract the largest update from all updates before exponentiating.
            p_min: float or dict
                Lower bound applied to the probabilities after every step.
            analytical_hessian, subsample_weights: bool
                Forwarded to ``weight_grad_via_ift`` when ``grad_type == 'ift'``.
            normalize: bool
                Renormalise the probabilities to sum to one after clipping.
            lock_in_p_g: group key or None
                Group whose raw update is fixed to zero (its probability can still
                change through stable_exp and normalisation).
            use_acc: bool
                With GDRO, use minus the worst-group accuracy instead of the
                worst-group loss as the stopping criterion.
            seed: int
                Seed set before every model fit.

        Returns:
            best_p: dict
                Probabilities with the best criterion seen.
            If save_trajectory, additionally (p_traj, loss_traj). Row j of p_traj is
            the probability vector at which loss_traj[j] was evaluated, for steps
            1, 2, ...; the starting point is not included.

        Raises:
            ValueError: if ``lr_schedule`` or ``self.grad_type`` is not recognised.
        """
        # fail fast on configuration errors instead of after the first model fit
        if lr_schedule not in ('constant', 'exponential', 'inverse'):
            raise ValueError('The learning rate schedule is not recognized')
        if self.grad_type not in ('finite_diff', 'ift'):
            raise ValueError('The gradient type {} is not recognized'.format(self.grad_type))

        # set learning rate, and eps as attributes
        self.lr = lr
        self.eps = eps

        groups = list(start_p.keys())

        # make sure the starting probabilities are tensors, then round them
        start_p = {g: torch.as_tensor(start_p[g]) for g in start_p.keys()}
        p_at_t = self.round_p_dict(start_p)

        # trajectory buffers: both are filled one step behind (see Returns)
        if save_trajectory:
            p_at_t_traj = torch.zeros(T, len(groups))
            loss_at_t_traj = torch.zeros(max(T - 1, 0))

        if momentum is not None:
            prev_update = dict.fromkeys(groups, torch.tensor(0))

        # initialize the iteration
        t = 0
        best_loss = torch.inf
        best_p = start_p
        stop_GD = False
        patience_count = patience

        decay = torch.as_tensor(decay)

        # GDRO keeps its own validation-group probabilities, starting uniform
        if self.GDRO:
            p_GDRO_at_t = {g: torch.tensor(1/self.n_groups) for g in groups}
            best_worst_group_criterion = torch.inf
            print('The GDRO weights are initialized to {}'.format(p_GDRO_at_t))

        # the model is created once and re-weighted after every step
        model_at_t = self.model_class(self.model_param, p_weights=p_at_t, p_train=self.p_train)

        while not stop_GD and (t < T):

            # fit the model
            time_start = time.time()
            set_seed(seed)
            model_at_t.fit(self.X_train, self.y_train, self.g_train)
            time_fit = time.time()
            print('Time taken to fit the model is {}'.format(time_fit - time_start))

            # calculate the loss at the current weights
            loss_at_t = self.calc_loss_for_model(model_at_t).detach()

            if self.GDRO:
                # NOTE(review): this refits a fresh model at p_at_t (without
                # set_seed) instead of reusing model_at_t
                worst_group_loss_at_t, loss_per_group_at_t, worst_acc_at_t = self.calc_worst_group_loss_at_p(p_at_t)

                if use_acc:
                    worst_group_criterion = -worst_acc_at_t
                else:
                    worst_group_criterion = worst_group_loss_at_t

            # the loss of step t is stored at t - 1 (step 0 is not recorded)
            if save_trajectory and t > 0:
                loss_at_t_traj[t - 1] = loss_at_t

            # early stopping bookkeeping
            if self.GDRO:
                # store the criterion itself, so the comparison stays on the
                # same scale when use_acc=True
                if worst_group_criterion < best_worst_group_criterion:
                    best_worst_group_criterion = worst_group_criterion
                    patience_count = patience
                    best_p = p_at_t.copy()
                else:
                    patience_count -= 1
            else:
                if loss_at_t < (best_loss - eps):
                    best_loss = loss_at_t
                    patience_count = patience
                    best_p = p_at_t.copy()
                else:
                    patience_count -= 1

            if patience_count == 0:
                stop_GD = True

            # if GDRO, change the validation weights based on the per-group loss
            if self.GDRO:
                p_GDRO_at_t = self.weights_val_obj.update_DRO_weights(p_GDRO_at_t, loss_per_group_at_t, eta)

                self.weights_val_obj.p_weights = p_GDRO_at_t
                self.weights_val_obj.weights = self.weights_val_obj.set_weights_per_group()

                print('The GDRO probabilities are updated to {}, based on this loss per group: {}'.format(p_GDRO_at_t, loss_per_group_at_t))

                # the finite-difference path reads self.weights_val, so refresh it
                if self.grad_type == 'finite_diff':
                    self.weights_val = self.weights_val_obj.get_weights_sample(self.g_val)

            # gradient with respect to the probabilities (grad_type validated above)
            if self.grad_type == 'finite_diff':
                grad = self.calc_finite_diff_loss_p(p_at_t, diff, return_loss=False)
            else:
                grad = self.weight_grad_via_ift(model_at_t, p_at_t, self.X_train, self.y_train, self.g_train, self.X_val, self.y_val, self.g_val, self.weights_val_obj,  analytical_hessian=analytical_hessian, subsample_weights=subsample_weights)

            if verbose:
                p_at_formatted = [np.round(p_at_t[g].item(), decimals=self.weight_rounding) for g in groups]
                if self.GDRO:
                    loss_formatted = worst_group_criterion
                else:
                    loss_formatted = loss_at_t
                print('At step {}, the loss is {}, we have {} patience left, and the probabilities are {}, which sum to {} with gradients {}.'.format(t, loss_formatted, patience_count, p_at_formatted,  sum(p_at_t.values()), grad))

            # impose gradient clipping
            if gradient_clip is not None:
                # flatten every entry: IFT gradients mix 0-d and shape-(1,)
                # tensors, which torch.stack cannot combine
                grad_tensor = torch.cat([torch.as_tensor(grad[g]).reshape(-1) for g in groups])
                grad_norm = torch.norm(grad_tensor)

                if grad_norm > gradient_clip:
                    grad = {g: (grad[g] / grad_norm) * gradient_clip for g in groups}

            # learning rate at step t
            if lr_schedule == 'constant':
                lr_t = lr
            elif lr_schedule == 'exponential':
                lr_t = lr * torch.exp(-decay * t)
            else:  # 'inverse'
                lr_t = lr / (1 + decay * t)

            # raw updates; a locked-in group keeps an update of 0
            updates = dict.fromkeys(groups, torch.as_tensor(0))
            for g in groups:
                if lock_in_p_g is not None and g == lock_in_p_g:
                    continue

                update = grad[g]
                if momentum is not None:
                    update = (1 - momentum) * update + (momentum * prev_update[g])
                    prev_update[g] = update

                updates[g] = -lr_t * update

            # shift all updates so the largest is 0, keeping exp() bounded by 1
            if stable_exp:
                max_update = self.get_max_p_dict(updates)
                updates = {g: updates[g] - max_update for g in groups}

            # multiplicative update
            p_at_t_plus_1 = {g: p_at_t[g] * torch.exp(updates[g]) for g in groups}

            # round, clip to [p_min, 1], then renormalise if requested
            p_at_t = self.round_p_dict(p_at_t_plus_1)
            p_at_t = self.clip_p_dict_per_group(p_at_t, p_min=p_min, p_max=1.0)
            if normalize:
                p_at_t = self.sum_to_1_p_dict(p_at_t)

            # re-weight the model for the next fit
            model_at_t.reset_weights(p_weights=p_at_t)
            model_at_t.p_weights = p_at_t

            # the probabilities for step t + 1 are stored at row t
            if save_trajectory:
                p_at_t_traj[t] = torch.stack([torch.as_tensor(p_at_t[g], dtype=torch.float32).reshape(()) for g in groups])
            t += 1

        if self.GDRO:
            print('Returning the p={}, for which loss is {}'.format(best_p, best_worst_group_criterion))
        else:
            print('Returning the p={}, for which loss is {}'.format(best_p, best_loss))

        if save_trajectory:
            return best_p, p_at_t_traj[:t-1], loss_at_t_traj[:t-1]
        return best_p
    

