



import numpy as np
import torch
import copy


class weights():
    """
    Class for setting weights for each group in the training data.

    Attributes:
        p_weights: dict or None
            Desired proportion for each group, keyed by group. If None, every group gets weight 1.
        p_train: dict
            Proportion of the training data in each group, keyed by group.
        groups: list
            The keys of p_train, in their iteration order.
        n_groups: int
            Number of groups in p_train.
        weighted_loss_weights: bool
            If True, w_g = p_weights[g]/p_train[g]; if False, w_g = p_weights[g].
        weights: dict
            Current weight per group, as computed by set_weights_per_group.
        p_min, p_max: float
            Bounds used to clip the probabilities in update_DRO_weights.
        min_group:
            The key of p_train with the smallest value.
    """

    def __init__(self, p_weights, p_train, weighted_loss_weights=True):
        self.p_weights = p_weights
        self.p_train = p_train
        self.groups = list(p_train.keys())
        self.n_groups = len(self.groups)

        # Set the weights for each group
        self.weighted_loss_weights = weighted_loss_weights
        self.weights = self.set_weights_per_group()

        # set the min. p and max. p for the DRO weights
        self.p_min = 0.001
        self.p_max = 1.00

        # first - which group is the smallest in the training data
        self.min_group = min(p_train, key=p_train.get)

    def reset_p_weights(self, p_weights):
        """
        Replace the desired proportions and recompute the (unnormalized) weights per group.

        Parameters:
            p_weights: dict or None
                The new desired proportion for each group (None gives every group weight 1).
        """
        self.p_weights = p_weights
        self.weights = self.set_weights_per_group()

    def set_weights_per_group(self, normalize=False):
        """
        Compute the weights for each group of the training data from self.p_weights and self.p_train.
        If self.weighted_loss_weights is True: w_g = p_weights[g]/p_train[g], where p_train[g] is the
        proportion of the training data in group g and p_weights[g] is the desired proportion for group g.
        If it is False: w_g = p_weights[g]. If self.p_weights is None: w_g = 1 for every group.

        Parameters:
            normalize: bool
                If True, normalize the weights so that they sum to 1. Default is False.

        Returns:
            dict mapping each key of self.p_train to its weight. The result is returned,
            not stored; callers assign it to self.weights.
        """
        if self.p_weights is None:
            weights = {g: 1 for g in self.p_train}
        elif self.weighted_loss_weights:
            weights = {g: self.p_weights[g] / self.p_train[g] for g in self.p_train}
        else:
            weights = {g: self.p_weights[g] for g in self.p_train}

        if normalize:
            total = sum(weights.values())
            weights = {g: w / total for g, w in weights.items()}

        return weights

    def get_weights_sample(self, g):
        """
        Get the weights for a sample, where g represents the group membership of each datapoint in the sample.
        The lookup is done on the cpu; the result is then moved to the device of g.

        Parameters:
            g: torch.tensor of shape (n,)
                Integers representing the group membership of each datapoint in the sample.
                Array-likes accepted by torch.as_tensor work as well.

        Returns:
            torch.float32 tensor with the same shape as g, holding self.weights[g_i] for each entry.

        Raises:
            KeyError: if g contains a value that is not a key of self.weights.
        """
        g_tensor = torch.as_tensor(g)

        # Look every entry up individually and build a float32 tensor directly.
        # This avoids np.vectorize, which infers the output dtype from the first
        # looked-up value (truncating later fractional weights when the first
        # weight is an int) and which raises on empty input.
        group_ids = g_tensor.detach().cpu().reshape(-1).tolist()
        looked_up = [self.weights[group_id] for group_id in group_ids]

        weights_sample = torch.tensor(looked_up, dtype=torch.float32).reshape(g_tensor.shape)
        return weights_sample.to(g_tensor.device)

    def set_independent_weights(self, p_y, p_c):
        """
        set independent weights for each group in the training data.
        presumes that there are two classes for y, and two classes for c.
        Also presumes that:
            group 1: y = 1, c = 1
            group 2: y = 1, c = 0
            group 3: y = 0, c = 1
            group 4: y = 0, c = 0

        Parameters:
            p_y: float
                Desired proportion of y = 1.
            p_c: float
                Desired proportion of c = 1.

        Returns:
            dict with the desired proportion of each of the four groups (also stored in self.p_weights).
            self.weights is recomputed from it, normalized to sum to 1.
        """

        # Set the proportion of each group in the training data
        self.p_weights = {
            1: p_y * p_c,
            2: p_y * (1 - p_c),
            3: (1 - p_y) * p_c,
            4: (1 - p_y) * (1 - p_c),
        }

        # Set the weights for each group
        self.weights = self.set_weights_per_group(normalize=True)

        return self.p_weights

    def get_q(self, loss_g, eta, eps=10**-5):
        """
        Compute the multiplicative update factor q = exp(eta * (loss_g + eps)).

        Parameters:
            loss_g: torch.tensor or number
                The loss of a group. Plain numbers are converted to a tensor first.
            eta: float
                Step size of the update.
            eps: float
                Small constant added to the loss before scaling.

        Returns:
            torch.tensor holding q.
        """
        # torch.exp only accepts tensors; as_tensor leaves existing tensors untouched
        loss_g = torch.as_tensor(loss_g)

        return torch.exp(eta * (loss_g + eps))

    def update_DRO_weights(self, p_w, loss_dict,  eta, C=0.0, n_dict=None):
        """
        Update the probability weights per group with an exponentiated step on the group losses.
        For each group g of p_w that has a loss: p_w[g] <- p_w[g] * (exp(eta * (loss_g + eps)) + regularizer_g),
        where regularizer_g = C/sqrt(n_dict[g]) if n_dict is given and 0 otherwise.
        The result is normalized, clipped to [self.p_min, self.p_max] and normalized again.

        Groups are matched by key, so the key order of p_w and loss_dict does not need to agree.
        Groups of p_w without an entry in loss_dict keep their value before normalization;
        entries of loss_dict for groups that are not in p_w are ignored.

        Parameters:
            p_w: dict
                Current probability weight per group.
            loss_dict: dict
                Loss per group (torch.tensor or number).
            eta: float
                Step size of the update.
            C: float
                Scale of the regularizer. Only used if n_dict is given.
            n_dict: dict or None
                Number of observations per group (torch.tensor or number).

        Returns:
            dict with the same keys (and key order) as p_w, holding the updated probabilities as floats.

        Raises:
            KeyError: if n_dict is given but lacks a group that is present in both p_w and loss_dict.
        """
        groups_p_w = list(p_w.keys())

        # create torch.tensor of the values in p_w
        p_w_vec = torch.tensor([p_w[g] for g in groups_p_w], dtype=torch.float32)

        # Only plain floats are returned, so no autograd graph is needed for the update.
        with torch.no_grad():
            # go through each group and update the weights
            for i, g in enumerate(groups_p_w):
                if g not in loss_dict:
                    continue

                if n_dict is not None:
                    # as_tensor makes plain ints/floats usable with torch.sqrt
                    n_g = torch.as_tensor(n_dict[g], dtype=torch.float32)
                    regularizer = C / torch.sqrt(n_g)
                else:
                    regularizer = 0.0

                # Get the q for group g
                # NOTE(review): the regularizer is added to q itself, not to the loss inside the exponential - confirm this is intended.
                q_g = self.get_q(loss_dict[g], eta) + regularizer

                # Update the weights for group g
                p_w_vec[i] = p_w_vec[i] * q_g

            # Normalize the weights
            p_w_vec_normalized = p_w_vec / torch.sum(p_w_vec)

            # clip the p_w_vec_normalized to be between p_min and p_max
            p_w_vec_clipped = torch.clamp(p_w_vec_normalized, self.p_min, self.p_max)

            # then normalize the weights again
            p_w_vec = p_w_vec_clipped / torch.sum(p_w_vec_clipped)

        # return the probability weights as a dictionary
        return {g: p_w_vec[i].item() for i, g in enumerate(groups_p_w)}
    



        
    






            

        


        
        

                 

        

