import torch
from torch import linalg as LA
import numpy as np
from tqdm import tqdm
from Experiments.cost import *

def get_class_mean(X, y):
    """
    Average the rows of ``X`` separately for every distinct label in ``y``.

    Args:
        X (torch.Tensor): Feature matrix of shape (n, d).
        y (torch.Tensor): Labels of shape (n,); a 2-D label tensor is squeezed first.

    Returns:
        class_means (torch.Tensor): Per-class mean rows, shape (c, d), ordered like ``class_labels``.
        class_labels (torch.Tensor): Sorted unique labels, shape (c,).
    """
    labels = y.squeeze() if y.dim() == 2 else y
    class_labels = torch.unique(labels)
    per_class = [X[labels == label_value].mean(dim=0) for label_value in class_labels]
    return torch.stack(per_class, dim=0), class_labels



def max_min_mat(MAT):
    """Return the largest column-wise minimum of ``MAT`` (min over dim 0, then max)."""
    column_minima = MAT.min(dim=0).values
    return column_minima.max()

class Probability_Measure:
    """Weighted point cloud: support locations with per-point weights and optional class labels.

    Attributes:
        location: tensor of support points; the first dimension indexes points.
        size: number of points, ``location.size(0)``.
        weight: per-point weights; defaults to uniform ``1 / size`` with shape ``(size, 1)``.
        label: class-label tensor, or None.
        predicted_label: labels assigned by a classifier; None until set externally.
        num_classes: ``max(label) + 1``, or None when no labels were given.
        max_min_values: optional per-point values read by ``get_value_gap``.
        value_gap: last result of ``get_value_gap`` (None until computed).
        best_weight: cached per-point weights built by ``get_best_weights`` (None until computed).
    """

    def __init__(self, location, weight=None, label=None, max_min_values=None):
        self.location = location
        self.size = location.size(0)
        self.label = label
        self.predicted_label = None
        # label is optional: torch.max(None) used to crash the default constructor.
        self.num_classes = None if label is None else int(torch.max(label)) + 1
        # These caches are read by get_value_gap / get_best_weights; they used to be
        # commented out, which made both methods raise AttributeError.
        self.max_min_values = max_min_values
        self.value_gap = None
        self.best_weight = None
        if weight is None:
            self.weight = 1 / (self.size * torch.ones(location.size(0), 1, device=self.location.device))
        else:
            self.weight = weight

    def get_label_indices(self, cls, predicted=True):
        """Indices of points whose (predicted or true) label equals ``cls``.

        The result is squeezed, so a single match comes back as a 0-dim tensor.
        """
        if predicted:
            return torch.nonzero(self.predicted_label == cls).squeeze()
        else:
            return torch.nonzero(self.label == cls).squeeze()

    def sampler(self, sampling_size):
        """Return ``sampling_size`` distinct locations drawn uniformly without replacement."""
        sample_idx = torch.randperm(self.size)[:sampling_size]
        sample = self.location[sample_idx]
        return sample

    def append(self, p, q=None):
        """Return a new measure that concatenates the points of two measures.

        ``m.append(p, q)`` concatenates ``p`` and ``q`` (``m`` itself is ignored, as before).
        With a single argument, e.g. ``p.append(q)`` or ``Probability_Measure.append(p, q)``,
        it concatenates ``self`` and ``p``. Weights are concatenated without renormalisation.
        The label is None if either side has no labels.
        """
        if q is None:
            p, q = self, p
        label = None if p.label is None or q.label is None else torch.cat((p.label, q.label))
        return Probability_Measure(torch.cat((p.location, q.location), dim=0),
                                   weight=torch.cat((p.weight, q.weight), dim=0), label=label)

    def get_best_weights(self, tar_class_weights):
        """Build per-point weights that reproduce a target class distribution.

        Args:
            tar_class_weights: dict mapping class label -> target (unnormalised) class mass.

        Returns:
            ``self.best_weight``. Every point of class ``k`` receives
            ``tar_class_weights[k] / sum(tar_class_weights.values()) / n_k``, where ``n_k``
            is the number of points with label ``k``. Each class then carries its normalised
            target share, so the weights sum to 1 when every target class is present.
            (Previously the share was multiplied by ``n_k``.) Classes with no points are skipped.
        """
        total_tar_num = sum(tar_class_weights.values())
        if self.best_weight is None:
            self.best_weight = torch.zeros(self.location.size(0), device=self.location.device)
        for key, value in tar_class_weights.items():
            # as_tuple=True keeps the index 1-D even for a single match.
            class_indices = torch.nonzero(self.label == key, as_tuple=True)[0]
            num_i = class_indices.numel()
            if num_i == 0:
                continue
            self.best_weight[class_indices] = value / total_tar_num / num_i
        return self.best_weight

    def generate_class_balanced_weights(self):
        """Per-sample weights inversely proportional to class frequency, normalised to sum to 1.

        Every class receives the same total mass ``1 / num_distinct_labels``.
        """
        # inverse maps every sample to the position of its label in the unique list.
        _, inverse, counts = self.label.unique(return_inverse=True, return_counts=True)

        class_weights = 1.0 / counts.float()
        class_weights = class_weights / class_weights.sum()

        # Previously the loop compared the loop variable with itself (always True),
        # so every sample got the last class's weight.
        sample_weights = class_weights[inverse]
        return sample_weights / sample_weights.sum()

    def get_value_gap(self, cla):
        """Return min(max_min_values over label cla[0]) - max(max_min_values over label cla[1]).

        Requires ``max_min_values`` to have been passed to the constructor (or assigned).
        """
        label_1_indices = torch.nonzero(self.label == cla[0], as_tuple=False)
        label_2_indices = torch.nonzero(self.label == cla[1], as_tuple=False)
        self.value_gap = torch.min(self.max_min_values[label_1_indices]) - torch.max(
            self.max_min_values[label_2_indices])
        return self.value_gap

    def get_indices(self, class_number):
        """Squeezed indices of points whose true label equals ``class_number``."""
        return torch.squeeze(torch.nonzero(self.label == class_number, as_tuple=False))


class Semi_Discrete_OT:
    """Entropic semi-discrete optimal transport from a sampled measure ``nu`` to a discrete ``mu``.

    The dual potentials on the points of ``mu`` are stored in ``reweight_factors``
    (shape ``(mu.size, 1)``) and fitted by stochastic gradient steps in ``compute_OT``.
    """

    def __init__(self, nu, mu, cost=None):
        """
        :param nu: Probability_Measure-like object; its ``sampler`` draws points of the source measure nu
        :param mu: Probability_Measure-like object holding the discrete target measure mu
        :param cost: callable ``cost(nu_locations, mu_locations)`` used by ``get_weighted_distance``;
                     when None, ``l2_distance`` is used on first call
        """
        self.nu = nu
        self.mu = mu
        self.num_classes = self.nu.num_classes
        self.cost = cost
        self.reweight_factors = None
        self.weighted_distance = None

    def get_weighted_distance(self):
        """Set ``weighted_distance = cost(nu.location, mu.location) - reweight_factors.T``.

        The result has one row per nu point and one column per mu point.
        """
        if self.cost is None:
            self.cost = l2_distance
        self.weighted_distance = self.cost(self.nu.location, self.mu.location) - self.reweight_factors.T

    def compute_OT(self, lr=None, max_iter=50000, batch_size=1, epsilon=0.05):
        """Fit ``reweight_factors`` by stochastic gradient steps and return them.

        Each iteration draws ``batch_size`` points from nu. For every point it computes the
        softmax over mu of ``(reweight_factors - ||x - y||) / epsilon``. It then updates
        ``reweight_factors += lr(i) * (mu.weight - mean_over_batch(softmax))``.

        Args:
            lr: constant step size. If None, the schedule ``1.0 / (1 + i / 100)`` is used.
            max_iter: number of stochastic steps.
            batch_size: nu samples per step.
            epsilon: entropic regularisation strength (temperature of the softmax).
        """
        if self.reweight_factors is None:
            self.reweight_factors = torch.zeros(self.mu.size, 1, device=self.mu.location.device)
        if lr is None:
            lr_0 = 1.0
            l_0 = 100
            lr = lambda i: lr_0 / (1 + i / l_0)
        else:
            step_size = lr
            lr = lambda i: step_size
        for i in tqdm(range(max_iter)):
            nu_sample_x = self.nu.sampler(batch_size)
            # (mu.size, batch_size): column b holds -||x_b - y_j|| + g_j for every mu point j.
            neg_cost = -LA.norm(nu_sample_x.unsqueeze(1) - self.mu.location.unsqueeze(0),
                                dim=2).T + self.reweight_factors
            # softmax over mu equals exp(.)/sum(exp(.)) but cannot under/overflow to NaN
            # (raw exp(-d / 0.05) is 0 for any distance beyond ~5).
            kai_normalized = torch.softmax(neg_cost / epsilon, dim=0)
            increment = lr(i) * (self.mu.weight - kai_normalized.mean(dim=1, keepdim=True))
            self.reweight_factors.data = self.reweight_factors.data + increment
        self.reweight_factors = self.reweight_factors.data
        return self.reweight_factors

    def classify(self, epsilon=0.05):
        """Assign every nu point to the mu point maximising ``reweight_factors_j - ||x - y_j||``.

        Sets ``nu.transported_to`` (index of the chosen mu point, CPU long tensor) and
        ``nu.predicted_label`` (``mu.label`` of that point), and returns the labels.
        Euclidean distance is used, as in ``compute_OT``. Because the entropic
        weights are a positive rescaling followed by a monotone map, their row argmax equals
        the argmax of the raw score for any ``epsilon > 0``. ``epsilon`` is therefore kept only
        for API compatibility. Scoring directly avoids the NaN rows produced when every
        ``exp`` underflowed.
        """
        g = self.reweight_factors.reshape(-1)
        transported_to = torch.empty(self.nu.size, dtype=torch.long)
        for i, nu_sample in enumerate(self.nu.location):
            transported_to[i] = torch.argmax(-LA.norm(nu_sample - self.mu.location, dim=1) + g)
        self.nu.predicted_label = self.mu.label[transported_to]
        self.nu.transported_to = transported_to
        return self.nu.predicted_label

    def evaluate_max_min(self):
        """For each matrix ``values[x]``, return max over columns of the min over rows.

        ``self.values`` must have been set (e.g. by ``get_values``). The result is a
        float32 CPU tensor with one entry per x. The old version called a non-existent
        ``self.max_min_mat``.
        """
        max_min_values = torch.zeros(self.values.size()[0])
        max_min_values.copy_(self.values.min(dim=1).values.max(dim=1).values)
        return max_min_values

    def get_values(self):
        """Fill ``self.values[x, a, b] = (||x - mu_2[b]|| - g_2[b]) - (||x - mu_1[a]|| - g_1[a])``.

        NOTE(review): this class never sets ``X``, ``mu_1``, ``mu_2``, ``g``, ``mu_1_indices``
        or ``mu_2_indices``; a caller must assign them before calling this method.
        """
        self.values = torch.zeros(self.X.size, self.mu_1.size, self.mu_2.size)
        g_mu_1 = self.g[self.mu_1_indices].reshape(-1)
        g_mu_2 = self.g[self.mu_2_indices].reshape(-1)
        for x_index, x in enumerate(self.X.location):
            shifted_2 = LA.norm(x - self.mu_2.location, dim=1) - g_mu_2
            shifted_1 = LA.norm(x - self.mu_1.location, dim=1) - g_mu_1
            # Rows index mu_1 points, columns index mu_2 points.
            self.values[x_index] = shifted_2.unsqueeze(0) - shifted_1.unsqueeze(1)


class OT_Score:
    """Per-sample OT margin scores derived from a fitted Semi_Discrete_OT.

    ``ot_score`` has one row per nu sample and one column per class. See
    ``compute_ot_score`` for the formula.
    """

    def __init__(self, semi_discrete_ot):
        self.semi_discrete_ot = semi_discrete_ot
        self.ot_score = torch.zeros(self.semi_discrete_ot.nu.size, self.semi_discrete_ot.num_classes,
                                    device="cuda" if torch.cuda.is_available() else "cpu")
        self.min_score = None

    def compute_ot_score(self, cls=None):
        """Fill ``self.ot_score`` in place.

        Let ``W = semi_discrete_ot.weighted_distance``. For every nu sample ``i`` with
        predicted label ``c`` and every mu class ``k`` that has at least one mu point::

            ot_score[i, k] = min_{j: mu.label[j] == k} W[i, j] - min_{j: mu.label[j] == c} W[i, j]

        so ``ot_score[i, c] == 0``.

        Args:
            cls: if not None, only rows predicted as ``cls`` are computed (``cls=0`` now
                means class 0; it used to fall through to "all classes"). Otherwise every
                predicted class ``0 .. nu.num_classes - 1`` is processed. Classes without mu
                support are skipped and their entries stay untouched.
        """
        sdot = self.semi_discrete_ot
        W = sdot.weighted_distance
        predicted = sdot.nu.predicted_label.reshape(-1).to(W.device)
        mu_labels = sdot.mu.label.reshape(-1).to(W.device)

        # class_min[k][i] = min of W[i, j] over mu points j of class k.
        class_min = {}
        for k in range(sdot.mu.num_classes):
            in_class = mu_labels == k
            if in_class.any():
                class_min[k] = W[:, in_class].min(dim=1).values

        nu_classes = range(sdot.nu.num_classes) if cls is None else [int(cls)]
        for c in nu_classes:
            rows = predicted == c
            if c not in class_min or not rows.any():
                continue
            own_min = class_min[c][rows]
            # Boolean-mask index_put writes into ot_score itself (chained
            # indexing would write into a temporary copy).
            rows_out = rows.to(self.ot_score.device)
            for k, mins in class_min.items():
                self.ot_score[rows_out, k] = (mins[rows] - own_min).to(self.ot_score.device)

    def get_min_ot_score(self):
        """Row-wise minimum of ``ot_score``, ignoring exact zeros (own class / not computed)."""
        masked_tensor = self.ot_score.clone()
        masked_tensor[masked_tensor == 0] = float('inf')

        self.min_score, _ = torch.min(masked_tensor, dim=1)
        return self.min_score
#######################################################

def sample_gaussian_within_ball(mean, cov, radius, num_samples=1):
    """Draw ``num_samples`` points from N(mean, cov) restricted to ``||x - mean|| <= radius``.

    Uses rejection sampling with one ``np.random.multivariate_normal`` draw per attempt,
    so seeded runs reproduce the same sequence as before.

    Returns:
        np.ndarray of shape (num_samples, len(mean)).

    Raises:
        ValueError: if ``radius`` is negative (no sample could ever be accepted).
    """
    if radius < 0:
        raise ValueError("radius must be non-negative")
    mean = np.asarray(mean)
    samples = []
    while len(samples) < num_samples:
        sample = np.random.multivariate_normal(mean, cov)
        if np.linalg.norm(sample - mean) <= radius:
            samples.append(sample)
    return np.array(samples)


def sample_uniform_within_ball(mean, radius, num_samples=1):
    """Draw ``num_samples`` points uniformly from the Euclidean ball of ``radius`` around ``mean``.

    Rejection sampling: draw uniformly from the bounding box ``[mean - radius, mean + radius]``
    and keep points inside the ball. Works for any dimension ``len(mean)``; the previous
    version only supported 2-D. The acceptance rate falls quickly as the dimension grows.

    Returns:
        np.ndarray of shape (num_samples, len(mean)).

    Raises:
        ValueError: if ``radius`` is negative (no sample could ever be accepted).
    """
    if radius < 0:
        raise ValueError("radius must be non-negative")
    mean = np.asarray(mean, dtype=float)
    low, high = mean - radius, mean + radius
    samples = []
    while len(samples) < num_samples:
        sample = np.random.uniform(low, high)
        if np.linalg.norm(sample - mean) <= radius:
            samples.append(sample)
    return np.array(samples)


class Semi_OT:
    """Evaluation helpers for a semi-discrete OT solution with potentials ``g`` on ``mu``'s points.

    ``X`` holds the query points. ``mu_1`` / ``mu_2`` are two discrete sub-measures, and
    ``mu_1_indices`` / ``mu_2_indices`` select their entries of ``g``.
    """

    def __init__(self, mu, g, X=None, mu_1=None, mu_2=None):
        self.mu = mu
        self.mu_1 = mu_1
        self.mu_2 = mu_2
        self.mu_1_indices = None
        self.mu_2_indices = None
        self.nu_1_indices = None
        self.nu_2_indices = None
        self.g = g
        self.X = X
        self.X_labels = None
        self.values = None  # filled by get_values: (X.size, mu_1.size, mu_2.size)

    def split(self, cls_1, cls_2):
        pass

    def get_values(self):
        """Fill ``self.values[x, a, b] = (||x - mu_2[b]|| - g_2[b]) - (||x - mu_1[a]|| - g_1[a])``.

        ``mu_1_indices`` / ``mu_2_indices`` must be set so that ``g[mu_k_indices]`` has one
        entry per point of ``mu_k``.
        """
        self.values = torch.zeros(self.X.size, self.mu_1.size, self.mu_2.size)
        g_mu_1 = self.g[self.mu_1_indices].reshape(-1)
        g_mu_2 = self.g[self.mu_2_indices].reshape(-1)
        for x_index, x in enumerate(self.X.location):
            shifted_2 = LA.norm(x - self.mu_2.location, dim=1) - g_mu_2
            shifted_1 = LA.norm(x - self.mu_1.location, dim=1) - g_mu_1
            # Rows index mu_1 points, columns index mu_2 points.
            self.values[x_index] = shifted_2.unsqueeze(0) - shifted_1.unsqueeze(1)

    def max_min_mat(self, MAT):
        """Max over columns of the column-wise minimum of ``MAT``."""
        min_vec, _ = torch.min(MAT, 0)
        return torch.max(min_vec)

    def min_max_mat(self, MAT):
        """Min over rows of the row-wise maximum of ``MAT``."""
        max_vec, _ = torch.max(MAT, 1)
        return torch.min(max_vec)

    def evaluate_max_min(self):
        """Apply ``max_min_mat`` to every ``values[x]``; returns a float32 CPU vector."""
        max_min_values = torch.zeros(self.values.size()[0])
        for x_index, x_mat in enumerate(self.values):
            max_min_values[x_index] = self.max_min_mat(x_mat)
        return max_min_values

    def evaluate_min_max(self):
        """Apply ``min_max_mat`` to every ``values[x]``; returns a float32 CPU vector."""
        min_max_values = torch.zeros(self.values.size()[0])
        for x_index, x_mat in enumerate(self.values):
            min_max_values[x_index] = self.min_max_mat(x_mat)
        return min_max_values

    @staticmethod
    def _nearest_support(points, support, g):
        """Return, per point x, the index j maximising ``g_j - ||x - support_j||`` (CPU long tensor)."""
        g = g.reshape(-1)
        chosen = torch.empty(len(points), dtype=torch.long)
        for i, x in enumerate(points):
            chosen[i] = torch.argmax(-LA.norm(x - support, dim=1) + g)
        return chosen

    def classify(self, epsilon=0.05):
        """Label each X point with ``mu.label`` of the mu point it is transported to.

        The argmax of the entropic weights equals the argmax of the raw score for any
        ``epsilon > 0``, so ``epsilon`` is kept only for API compatibility. Scoring directly
        avoids NaNs when every ``exp`` underflows.
        """
        transported_to = self._nearest_support(self.X.location, self.mu.location, self.g)
        self.X_labels = self.mu.label[transported_to]
        return self.X_labels

    def transport(self, epsilon=0.05):
        """Index (into the concatenation mu_1 ++ mu_2) of the point each X point is sent to.

        ``g`` must hold one potential per point of the concatenated measure.
        """
        # Probability_Measure.append(self.mu_1, self.mu_2) raised TypeError (missing ``q``);
        # calling it through an instance with both measures concatenates mu_1 then mu_2.
        mu = self.mu_1.append(self.mu_1, self.mu_2)
        # Size by the concatenated measure, not self.mu (previously a shape mismatch).
        self.X_transported_to = self._nearest_support(self.X.location, mu.location, self.g)
        return self.X_transported_to

    def get_indices(self, class_num):
        # NOTE(review): Semi_OT never sets ``self.label``; this raises AttributeError unless
        # a caller assigns it first.
        return torch.nonzero(self.label == class_num, as_tuple=False)


def get_indices(one_hot_labels, target_label):
    """Return (squeezed) row indices whose one-hot argmax equals ``target_label``."""
    class_ids = one_hot_labels.argmax(dim=1)
    matches = (class_ids == target_label).nonzero()
    return matches.squeeze()


class Label_Handler:
    """Per-class index bookkeeping for a label vector, with class-balanced sampling helpers."""

    def __init__(self, label_vec):
        self.min_cla_num = None
        self.label_vec = label_vec
        self.class_labels = torch.unique(label_vec)
        self.num_class = self.class_labels.numel()
        self.class_indices = {}  # int label -> torch.nonzero(label_vec == label) result
        self.class_nums = {}  # int label -> number of samples with that label
        self.get_class_indices()

    def get_class_indices(self):
        """Populate ``class_indices`` / ``class_nums`` and set ``min_cla_num`` (smallest class size)."""
        counts = []
        for cls in self.class_labels:
            key = int(cls)
            self.class_indices[key] = torch.nonzero(self.label_vec == cls)
            self.class_nums[key] = len(self.class_indices[key])
            counts.append(self.class_nums[key])
        # With no classes, fall back to the label-vector length (as before).
        self.min_cla_num = min(counts, default=len(self.label_vec))

    def get_indices(self, i):
        """Return ``torch.nonzero(label_vec == i)``."""
        return torch.nonzero(self.label_vec == i)

    def sampler(self, cla, k):
        """Randomly pick up to ``k`` stored indices of class ``cla`` without replacement."""
        k = int(min(self.class_nums[cla], k))
        cla_indices = self.class_indices[cla]
        return cla_indices[torch.randperm(cla_indices.size(0))][:k]

    def sample_balance_indices(self, classes, k):
        """Sample the same number of indices (at most ``k``, at most the smallest class) per class."""
        k = int(min(min([self.class_nums[i] for i in classes]), k))
        sample_indices = {}
        for cla in classes:
            sample_indices[cla] = self.sampler(cla, k)
        return sample_indices

    def combine_labels(self, labels):
        """Concatenate the tensors in dict ``labels`` (in iteration order); None if it is empty.

        Uses a single ``torch.cat`` instead of re-concatenating per entry (quadratic copying).
        A single entry is returned as-is, matching the previous behaviour.
        """
        chunks = list(labels.values())
        if not chunks:
            return None
        if len(chunks) == 1:
            return chunks[0]
        return torch.cat(chunks)


def get_features(model, dataloader, noise=None, device="cuda"):
    """Run ``model.get_feature`` over every batch of ``dataloader`` and concatenate the outputs.

    The layout of the first batch decides the mode. An ``(inputs, labels)`` list/tuple makes
    labels be collected too. Otherwise batches are treated as bare inputs, and a 1-element
    list/tuple is unwrapped. The dataloader is iterated exactly once.

    Args:
        model: object exposing ``get_feature(inputs)`` and ``eval()``; put into eval mode.
        dataloader: iterable of batches.
        noise: if truthy, std of Gaussian noise added to the inputs. Added out of place, so
            the caller's tensors are never modified.
        device: device the batches are moved to (default ``"cuda"``, the previous behaviour).

    Returns:
        ``(features, labels)`` for labeled batches, otherwise ``features``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.eval()
    feature_chunks, label_chunks = [], []
    label_included = None
    # Inference only: avoid building autograd graphs for every batch.
    with torch.no_grad():
        for batch in dataloader:
            if label_included is None:
                # A bare tensor batch of size 2 must not be mistaken for (inputs, labels).
                label_included = isinstance(batch, (list, tuple)) and len(batch) == 2
            if label_included:
                inputs, labels = batch
                label_chunks.append(labels.to(device))
            elif isinstance(batch, (list, tuple)) and len(batch) == 1:
                inputs = batch[0]
            else:
                inputs = batch
            inputs = inputs.to(device)
            if noise:
                inputs = inputs + torch.normal(0, noise, size=inputs.size(), device=inputs.device)
            feature_chunks.append(model.get_feature(inputs).data)
    if label_included is None:
        raise ValueError("dataloader yielded no batches")
    features = torch.cat(feature_chunks, dim=0)
    # Now you can access the output of the first network
    print("Output of the feature network:", features.size())
    if label_included:
        return features, torch.cat(label_chunks, dim=0)
    return features


def get_mean_features(source_features, src_sampler):
    """Mean feature row per class, ordered like ``src_sampler.class_labels``.

    Args:
        source_features: (n, d) feature tensor.
        src_sampler: Label_Handler-like object providing ``num_class``, ``class_labels``
            and ``class_indices`` (int label -> index tensor).

    Returns:
        (num_class, d) tensor on ``source_features.device``.
    """
    src_mean_features = torch.zeros(src_sampler.num_class, source_features.size(1), device=source_features.device)
    for row, cls_label in enumerate(src_sampler.class_labels):
        # reshape(-1) rather than squeeze(): a class with a single sample must stay a
        # 1-element index vector. Otherwise the mean is taken over its feature entries.
        cls_indices = src_sampler.class_indices[int(cls_label)].reshape(-1)
        src_mean_features[row] = source_features[cls_indices].mean(dim=0)
    return src_mean_features
