import numpy as np
import copy
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from torch.utils.data import Dataset
import math

def Average(lst):
    """Return the arithmetic mean of the items in ``lst``.

    An empty sequence raises ZeroDivisionError (its length is 0).
    """
    total = sum(lst)
    count = len(lst)
    return total / count

def to_categorical(y, num_classes=None):
    """One-hot encode an array of integer class labels.

    Args:
        y: array-like of non-negative class indices. Float labels are
            truncated with ``astype(int)``.
        num_classes: width of the encoding. Defaults to ``max(y) + 1``, so a
            label set that skips a class index still gets one column per index.

    Returns:
        ``uint8`` array of shape ``y.shape + (num_classes,)``.

    Raises:
        ValueError: if any label is negative.
    """
    labels = np.asarray(y).astype(int)
    if labels.size and labels.min() < 0:
        raise ValueError("to_categorical expects non-negative class labels")
    if num_classes is None:
        # len(np.unique(y)) under-counted whenever a class index was absent
        # (e.g. labels {0, 2}) and then raised IndexError on the largest label.
        num_classes = int(labels.max()) + 1 if labels.size else 0
    return np.eye(num_classes, dtype='uint8')[labels]

class WeightClipper(object):
    """Clamp a module's ``weight`` tensor into ``[min_value, max_value]`` in place.

    Meant to be passed to ``nn.Module.apply`` so every submodule is visited.
    Modules without a ``weight`` attribute, or whose ``weight`` is ``None``
    (e.g. norm layers built with ``affine=False``), are left untouched.
    """

    def __init__(self, frequency=5, min_value=-1.0, max_value=1.0):
        # ``frequency`` is stored for callers; this class never reads it.
        self.frequency = frequency
        self.min_value = min_value
        self.max_value = max_value

    def __call__(self, module):
        weight = getattr(module, 'weight', None)
        if weight is None:
            return
        # In-place clamp_ writes into the parameter's storage. The previous
        # out-of-place ``w.clamp(...)`` built a new tensor and discarded it,
        # so no clipping ever happened.
        weight.data.clamp_(self.min_value, self.max_value)


class WeightInit(object):
    """Re-initialise a module's ``weight`` in place from N(mean, std**2).

    Meant to be passed to ``nn.Module.apply``. Modules without a ``weight``
    attribute, or whose ``weight`` is ``None`` (e.g. norm layers built with
    ``affine=False``), are skipped instead of raising AttributeError. Biases
    are not touched.
    """

    def __init__(self, frequency=5, mean=0.0, std=0.02):
        # ``frequency`` is stored for callers; this class never reads it.
        self.frequency = frequency
        self.mean = mean
        self.std = std

    def __call__(self, module):
        # torch.manual_seed(0)
        weight = getattr(module, 'weight', None)
        if weight is None:
            return
        nn.init.normal_(weight.data, self.mean, self.std)


def Russel_similarity(c):  # emotion similarity scores generation
    """Return a (c, c) emotion-similarity matrix from fixed polar coordinates.

    Each emotion is a point ``(radius, angle_in_degrees)``; neutral sits at the
    origin, the others on the unit circle (presumably Russell's circumplex, per
    the function name). Entry ``[i, j]`` is ``1 - d_ij / 2``, where ``d_ij`` is
    the Euclidean distance between the two points (law of cosines). With these
    radii, ``d_ij`` lies in [0, 2], so similarities lie in [0, 1] and the
    diagonal is 1.

    Args:
        c: number of classes; selects the label order (3, 4 or 5).

    Raises:
        ValueError: if ``c`` is not 3, 4 or 5.
    """
    fear = [1, 117]
    disgust = [1, 153]
    sad = [1, 198]
    neutral = [0, 0]
    happy = [1, 18]
    label_orders = {
        5: [disgust, fear, sad, neutral, happy],
        4: [neutral, sad, fear, happy],
        3: [sad, neutral, happy],
    }
    if c not in label_orders:
        raise ValueError(f"Russel_similarity supports c in (3, 4, 5), got {c!r}")
    points = label_orders[c]

    matrix = np.eye(c)
    for i, (r_i, angle_i) in enumerate(points):
        for j, (r_j, angle_j) in enumerate(points):
            distance = math.sqrt(r_i**2 + r_j**2 - 2*r_i*r_j*math.cos(math.radians(angle_i)-math.radians(angle_j)))
            matrix[i, j] = 1 - distance / 2

    return matrix

# Function to calculate cosine similarity between two vectors
def cosine_similarity(v1, v2):
    """Return the cosine of the angle between vectors ``v1`` and ``v2``.

    No guard for zero-length vectors: the denominator is then 0.
    """
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    inner = np.dot(v1, v2)
    return inner / norm_product

#embedding similarity matrix
def Semantic_Distribution_similarity(c, file_path='GloVe_Embeddings/seedv_emotions.txt'):  # emotion similarity scores generation
    """Build a (c, c) cosine-similarity matrix between emotion-word embeddings.

    The embedding file is read as whitespace-separated text: one word per line
    followed by its vector components (GloVe-style text format). Blank lines
    are skipped.

    Args:
        c: number of emotion classes; selects the label order (3, 4 or 5).
        file_path: embedding text file. Defaults to the previously hard-coded
            relative path.

    Returns:
        np.ndarray (c, c) of pairwise cosine similarities rounded to 2 decimals.

    Raises:
        ValueError: if ``c`` is not 3, 4 or 5.
        KeyError: if an emotion word needed for ``c`` is missing from the file.
    """
    # Emotion labels ordered as specified
    label_orders = {
        5: ["disgusted", "fearful", "sad", "neutral", "happy"],
        4: ["neutral", "sad", "fearful", "happy"],
        3: ['sad', 'neutral', 'happy'],
    }
    if c not in label_orders:
        raise ValueError(f"Semantic_Distribution_similarity supports c in (3, 4, 5), got {c!r}")
    labels = label_orders[c]

    embeddings = {}
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            data = line.split()
            if not data:  # blank line; data[0] would raise IndexError
                continue
            embeddings[data[0]] = np.array([float(x) for x in data[1:]])

    missing = [label for label in labels if label not in embeddings]
    if missing:
        raise KeyError(f"emotion words missing from {file_path}: {missing}")

    similarity_matrix = np.zeros((len(labels), len(labels)))
    for i, label1 in enumerate(labels):
        for j, label2 in enumerate(labels):
            similarity_matrix[i, j] = cosine_similarity(embeddings[label1], embeddings[label2])

    # Keep two decimals, as before.
    return np.round(similarity_matrix, 2)

def partialize(y, p): # generation of candidate labels based on the uniform distribution
    """Turn (presumably one-hot) labels into normalised uniform candidate sets.

    Each class of each row is independently switched on with probability
    ``p``, on top of whatever is already 1 in ``y``. While a row has exactly
    one candidate, a uniformly random class is switched on, so a row that
    starts one-hot ends with at least two candidates. Each row is then divided
    by its candidate count.

    Args:
        y: (n, c) array; not modified.
        p: per-class switch-on probability for ``np.random.binomial``.

    Returns:
        (new_y, avgC): float (n, c) array of normalised candidate rows and the
        average number of candidates per row.

    Raises:
        ValueError: if ``c < 2``. A second candidate can never be added, so
            the top-up loop would spin forever on a one-hot row.
    """
    new_y = np.asarray(y).astype(float)  # astype copies, so y is left untouched
    n, c = new_y.shape[0], new_y.shape[1]
    if c < 2:
        raise ValueError(f"partialize needs at least 2 classes, got {c}")
    avgC = 0

    for i in range(n):
        row = new_y[i, :]  # view: edits write straight into new_y
        row[np.random.binomial(1, p, c) == 1] = 1
        while np.sum(row) == 1:
            row[np.random.randint(0, c)] = 1
        avgC += np.sum(row)
        new_y[i] = row / np.sum(row)

    avgC = avgC / n
    return new_y, avgC

def partialize_Russel_Distribution(y, y0): # generation of candidate labels based on Russel_Distribution similarities
    """Sample normalised candidate-label sets from ``Russel_similarity`` scores.

    For sample ``r`` with class ``k = int(y0[r])``, every class ``col`` is
    independently switched on with probability ``Russel_similarity(c)[k, col]``.
    All ``c`` entries of the row are overwritten this way, so ``y`` only
    supplies the output shape. While a row has exactly one candidate, a random
    class is switched on. Each row is then divided by its candidate count.

    Returns:
        (new_y, avgC): float (n, c) array of normalised rows and the mean
        candidate count per sample.
    """
    out = copy.deepcopy(y).astype(float)
    num_rows, num_cls = y.shape[0], y.shape[1]
    sim = Russel_similarity(num_cls)
    count_total = 0

    for r in range(num_rows):
        cand = out[r, :]  # view into ``out``
        class_probs = sim[int(y0[r])]
        for col in range(num_cls):
            cand[col] = np.random.binomial(1, class_probs[col], 1)

        while np.sum(cand) == 1:
            cand[np.random.randint(0, num_cls)] = 1
        count_total += np.sum(cand)

        out[r] = cand / np.sum(cand)

    return out, count_total / num_rows

def partialize_Semantic_Distribution(y, y0): # generation of candidate labels based on emotion similarities
    """Sample normalised candidate-label sets from embedding similarities.

    For sample ``i`` with class ``k = int(y0[i])``, every class ``j`` is
    independently switched on with probability
    ``Semantic_Distribution_similarity(c)[k, j]``, clipped to [0, 1]. All ``c``
    entries of the row are overwritten, so ``y`` only supplies the output
    shape. While a row has exactly one candidate, a random class is switched
    on. Each row is then divided by its candidate count.

    Returns:
        (new_y, avgC): float (n, c) array of normalised rows and the mean
        candidate count per sample.
    """
    new_y = copy.deepcopy(y).astype(float)
    n, c = y.shape[0], y.shape[1]
    avgC = 0
    # Cosine similarities can be negative, and np.random.binomial raises
    # ValueError for p < 0, so clamp them to a valid probability range.
    matrix = np.clip(Semantic_Distribution_similarity(c), 0.0, 1.0)
    for i in range(n):
        row = new_y[i, :]  # view: writes go straight into new_y
        probs = matrix[int(y0[i])]
        for j in range(c):
            # Same single size-1 draw as before; [0] unpacks it to a scalar
            # instead of assigning a 1-element array into a scalar slot.
            row[j] = np.random.binomial(1, probs[j], 1)[0]

        while np.sum(row) == 1:
            row[np.random.randint(0, c)] = 1
        avgC += np.sum(row)

        new_y[i] = row / np.sum(row)

    avgC = avgC / n
    return new_y, avgC

def add_gaussian_noise_torch(input, std, mean=0.5):
    """Return ``input`` plus i.i.d. Gaussian noise N(mean, std**2) of the same shape.

    The noise is created on ``input``'s device, so CUDA inputs also work.
    Previously it was always created on the CPU. ``input`` is not modified.

    NOTE(review): the default ``mean`` of 0.5 is kept for backward
    compatibility. It shifts every sample by +0.5 on average on top of the
    jitter; zero-mean noise is the usual augmentation choice, so confirm the
    offset is intended.
    """
    noise = torch.randn(input.shape, device=input.device) * std + mean
    return input + noise

class CustomEEGDataset(Dataset): # customized dataset with the data augmentation required by the CR and PiCO methods
    """Map-style dataset over parallel arrays of samples, labels and partial labels.

    ``__getitem__`` converts the ``index``-th entry of each array to a float32
    tensor. With augmentation enabled it returns
    ``(index, data, weak_aug, strong_aug, label, partial_label)``, where the two
    views come from ``add_gaussian_noise_torch`` with std 0.2 and 0.8.
    Otherwise it returns ``(index, data, label, partial_label)``.
    """

    def __init__(self, image, labels, partial_labels, augmentation=True):
        self.data = image
        self.labels = labels
        self.partial_labels = partial_labels
        self.augmentation = augmentation

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _as_float_tensor(value):
        """Convert one entry (array, tensor, or numpy/Python scalar) to float32."""
        # torch.as_tensor turns scalars into 0-d tensors. The legacy
        # torch.Tensor(...) constructor treats an integer scalar as a shape
        # (uninitialised memory) and rejects float scalars.
        return torch.as_tensor(value, dtype=torch.float32)

    def __getitem__(self, index):
        data, label, partial_label = (
            self._as_float_tensor(x[index])
            for x in (self.data, self.labels, self.partial_labels)
        )

        # Truthiness check so numpy booleans also enable augmentation;
        # ``is True`` silently ignored them.
        if self.augmentation:
            weak_aug = add_gaussian_noise_torch(data, 0.2)
            strong_aug = add_gaussian_noise_torch(data, 0.8)
            return index, data, weak_aug, strong_aug, label, partial_label
        return index, data, label, partial_label

def load_augmented_dataset_to_device(data, label, partial_label, batch_size, shuffle_flag=True, augmentation_flag=True):
    """Wrap the arrays in a ``CustomEEGDataset`` and return a ``DataLoader`` over it.

    The loader runs in the main process (``num_workers=0``), drops the last
    incomplete batch, and pins memory. Despite the name, nothing here moves
    tensors to a device; that is left to the caller.
    """
    eeg_dataset = CustomEEGDataset(data, label, partial_label, augmentation_flag)
    loader = torch.utils.data.DataLoader(
        eeg_dataset,
        batch_size=batch_size,
        shuffle=shuffle_flag,
        num_workers=0,
        drop_last=True,
        pin_memory=True,
    )
    return loader


class partial_loss(nn.Module):
    """The supervised loss of PiCO with and without prototype-based label disambiguation.

    ``confidence`` holds one row of per-class weights per training sample (rows
    are selected by dataset index). ``forward`` is a confidence-weighted
    cross-entropy. ``confidence_update`` moves those rows toward one-hot
    pseudo-labels with an exponential moving average, in place.
    """
    def __init__(self, confidence, conf_ema_m=0.99):
        super().__init__()
        # Updated in place by confidence_update; the caller's tensor is shared.
        self.confidence = confidence
        # Snapshot of the starting confidence. ``detach()`` alone shares
        # storage with ``confidence`` and would follow every in-place update.
        self.init_conf = confidence.detach().clone()
        # EMA momentum: weight kept on the old confidence in each update.
        self.conf_ema_m = conf_ema_m

    def set_conf_ema_m(self, epoch, args):
        """Linearly move the EMA momentum from ``args.conf_ema_range[0]`` (epoch 0)
        toward ``args.conf_ema_range[1]`` (epoch ``args.epochs``)."""
        start = args.conf_ema_range[0]
        end = args.conf_ema_range[1]
        self.conf_ema_m = 1. * epoch / args.epochs * (end - start) + start

    def forward(self, outputs, index):
        """Return ``-mean_i sum_c confidence[index_i, c] * log_softmax(outputs)_ic``."""
        logsm_outputs = F.log_softmax(outputs, dim=1)
        final_outputs = logsm_outputs * self.confidence[index, :]
        average_loss = - final_outputs.sum(dim=1).mean()
        return average_loss

    def confidence_update(self, temp_un_conf, batch_index, batchY):
        """EMA-update ``confidence[batch_index]`` toward a one-hot pseudo-label.

        The pseudo-label is the argmax of ``temp_un_conf`` restricted to the
        candidate mask ``batchY``. Returns None.
        """
        with torch.no_grad():
            _, prot_pred = (temp_un_conf * batchY).max(dim=1)
            # Put the pseudo-label where the stored confidence lives, so mixing
            # it into that tensor cannot hit a device mismatch.
            pseudo_label = F.one_hot(prot_pred, batchY.shape[1]).float().to(self.confidence.device)
            self.confidence[batch_index, :] = self.conf_ema_m * self.confidence[batch_index, :] \
                + (1 - self.conf_ema_m) * pseudo_label
        return None

class SupConMocoLoss(nn.Module): # Contrastive Loss used in PiCO
    """Following Supervised Contrastive Learning.

    Two modes, selected by ``mask``:

    * ``mask`` given (partial-label / SupCon mode): ``features[:batch_size]``
      are anchors compared against every row of ``features``. ``mask[i, j]``
      marks row ``j`` as a positive for anchor ``i``. Each anchor's similarity
      with itself is removed from both the positives and the denominator.
    * ``mask`` is None (MoCo mode): ``features`` is laid out as
      ``[queries (batch_size), keys (batch_size), queue (rest)]``. Each query's
      positive is its matching key; the queue rows are negatives.
    """
    def __init__(self, temperature=0.07, base_temperature=0.07):
        super().__init__()
        self.temperature = temperature
        self.base_temperature = base_temperature

    def forward(self, features, mask=None, batch_size=-1):
        # Use the input's actual device; torch.device('cuda') means the
        # *current* CUDA device, which need not be the one holding features.
        device = features.device

        if mask is not None:
            # SupCon loss (Partial Label Mode)
            mask = mask.float().detach().to(device)
            anchor_dot_contrast = torch.div(
                torch.matmul(features[:batch_size], features.T),
                self.temperature)
            # subtract the per-row max for numerical stability
            logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
            logits = anchor_dot_contrast - logits_max.detach()
            # mask out self-contrast cases
            logits_mask = torch.scatter(
                torch.ones_like(mask),
                1,
                torch.arange(batch_size, device=device).view(-1, 1),
                0
            )
            mask = mask * logits_mask
            exp_logits = torch.exp(logits) * logits_mask
            log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)
            # An anchor with no positives would give 0/0 = nan and poison the
            # whole mean; a denominator of 1 makes it contribute 0 instead.
            pos_count = mask.sum(1)
            pos_count = torch.where(pos_count < 1e-6, torch.ones_like(pos_count), pos_count)
            mean_log_prob_pos = (mask * log_prob).sum(1) / pos_count
            loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
            loss = loss.mean()
        else:
            # MoCo loss (unsupervised)
            q = features[:batch_size]
            k = features[batch_size:batch_size * 2]
            queue = features[batch_size * 2:]
            # positive logits: N x 1
            l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
            # negative logits: N x K
            l_neg = torch.einsum('nc,kc->nk', [q, queue])
            # logits: N x (1 + K), temperature-scaled
            logits = torch.cat([l_pos, l_neg], dim=1) / self.temperature
            # the positive key is always column 0
            labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)
            loss = F.cross_entropy(logits, labels)
        return loss

class PaPiLoss(nn.Module):
    """Classifier loss plus prototype-mixture similarity loss with EMA soft targets.

    ``predicted_score_cls1`` holds one row of soft targets per training sample
    (rows are selected by dataset index). ``update_weight_byclsout1`` refreshes
    it in place, and ``forward`` uses it as the target.
    """
    def __init__(self, predicted_score_cls, pseudo_label_weight=0.99):
        super().__init__()
        # Both attributes reference the SAME tensor object (no copy), so an
        # in-place update through one is visible through the other.
        self.predicted_score_cls1 = predicted_score_cls
        self.predicted_score_cls2 = predicted_score_cls

        # Snapshot of the starting scores. ``detach()`` alone shares storage
        # with predicted_score_cls and would drift with every in-place update.
        self.init_predicted_score_cls = predicted_score_cls.detach().clone()

        # EMA weight kept on the old scores in update_weight_byclsout1.
        self.pseudo_label_weight = pseudo_label_weight

    def set_alpha(self, epoch, args):
        """Set ``alpha`` to ``(epoch / 10) * args.alpha_weight``, capped at ``args.alpha_weight``."""
        self.alpha = min((epoch / 10) * args.alpha_weight, args.alpha_weight)

    def set_pseudo_label_weight(self, epoch, args):
        """Linearly move the EMA weight between the two numbers in the
        ``"start,end"`` string ``args.pseudo_label_weight_range`` over ``args.epochs``."""
        bounds = args.pseudo_label_weight_range.split(",")
        start = float(bounds[0])
        end = float(bounds[1])
        self.pseudo_label_weight = 1. * epoch / args.epochs * (end - start) + start

    def update_weight_byclsout1(self, cls_predicted_score, batch_index, batch_partial_Y, args):
        """EMA-update the stored soft targets for ``batch_index`` from classifier logits.

        The new target is the classifier softmax restricted to the candidate
        mask ``batch_partial_Y``, renormalised per row. ``args`` is accepted for
        interface compatibility but no longer read: the row sums broadcast via
        ``keepdim=True`` instead of ``repeat(args.num_class, ...)``.
        """
        with torch.no_grad():
            y_pred_raw_probas = torch.softmax(cls_predicted_score, dim=1)
            revisedY_raw = batch_partial_Y * y_pred_raw_probas
            revisedY_raw = revisedY_raw / revisedY_raw.sum(dim=1, keepdim=True)

            cls_pseudo_label = revisedY_raw.detach()

            w = self.pseudo_label_weight
            self.predicted_score_cls1[batch_index, :] = (
                w * self.predicted_score_cls1[batch_index, :] + (1 - w) * cls_pseudo_label
            )

    def forward(self, cls_out1, cls_out2, logits_prot1, logits_prot2, logits_prot_1_mix, logits_prot_2_mix, idx_rp,
                Lambda, index, args, sim_criterion):
        """Return ``(cls_loss_1, sim_loss_2, self.alpha)``.

        ``cls_loss_1`` is the cross-entropy of ``cls_out1`` against the stored
        soft targets of ``index``. ``sim_loss_2`` applies ``sim_criterion`` to
        the log-softmax of both mixed prototype logits (scaled by
        ``1 / args.tau_proto``). Each is compared against the targets of
        ``index`` and of the partners ``index[idx_rp]``, weighted by ``Lambda``
        and ``1 - Lambda`` (presumably a mixup coefficient). ``cls_out2``,
        ``logits_prot1`` and ``logits_prot2`` are accepted but unused.
        ``set_alpha`` must be called before the first call.
        """
        prot_pred_1_mix_probas_log = torch.log_softmax(torch.div(logits_prot_1_mix, args.tau_proto), dim=1)
        prot_pred_2_mix_probas_log = torch.log_softmax(torch.div(logits_prot_2_mix, args.tau_proto), dim=1)

        soft_positive_label_target1 = self.predicted_score_cls1[index, :].clone().detach()
        soft_positive_label_target1_rp = self.predicted_score_cls1[index[idx_rp], :].clone().detach()

        # log_softmax rather than log(softmax(...)): a probability that
        # underflows to 0 gives -inf, and 0 * -inf = nan where the target is 0.
        cls_loss_all_1 = soft_positive_label_target1 * torch.log_softmax(cls_out1, dim=1)
        cls_loss_1 = -cls_loss_all_1.sum(dim=1).mean()

        sim_loss_2_1 = Lambda * sim_criterion(prot_pred_1_mix_probas_log, soft_positive_label_target1) + \
                       (1 - Lambda) * sim_criterion(prot_pred_1_mix_probas_log, soft_positive_label_target1_rp)

        sim_loss_2_2 = Lambda * sim_criterion(prot_pred_2_mix_probas_log, soft_positive_label_target1) + \
                       (1 - Lambda) * sim_criterion(prot_pred_2_mix_probas_log, soft_positive_label_target1_rp)

        sim_loss_2 = sim_loss_2_1 + sim_loss_2_2

        return cls_loss_1, sim_loss_2, self.alpha

def split_balance_class(data, label, train_rate, random):
    """Split ``data``/``label`` into train and validation sets class by class.

    For every distinct value in ``label``, the first
    ``int(count * train_rate)`` matching positions go to the training split
    and the remainder to validation. If ``random`` is truthy, each class's
    positions are shuffled before the cut (via ``np.random.shuffle``), and
    both final index lists are shuffled afterwards.

    Returns:
        (train, train_label, val, val_label)
    """
    train_positions, val_positions = [], []

    for cls in np.unique(label):
        members = np.where(label == cls)[0]
        if random:
            np.random.shuffle(members)
        cut = int(len(members) * train_rate)
        train_positions.extend(members[:cut])
        val_positions.extend(members[cut:])

    if random:
        np.random.shuffle(train_positions)
        np.random.shuffle(val_positions)

    return (data[train_positions], label[train_positions],
            data[val_positions], label[val_positions])