import time
import torch
import numpy as np
from scipy import stats
from tqdm import tqdm

def mean_squared_error(ys_pred, ys):
    """Return the mean squared difference between ``ys`` and ``ys_pred``.

    The difference is taken element-wise (with the usual broadcasting) and
    the mean runs over every element, so the result is a 0-d tensor.
    """
    squared_residuals = torch.square(ys - ys_pred)
    return torch.mean(squared_residuals)

def get_optimal_ncl_acc(ys_list):
    """Return the accuracy (in percent) of always predicting each row's majority label.

    ``ys_list`` holds one 2-D label tensor per boolean task; the tensors are
    concatenated along dim 1, giving one row per input.  For every row the
    more frequent of the two labels -1 / 1 is taken as the prediction, and
    the number of matching entries is divided by the total number of entries.
    Entries that are neither -1 nor 1 never count as correct but still
    contribute to the total.
    """
    labels = torch.cat(ys_list, dim=1)
    n_rows, n_cols = labels.shape  # n_cols spans all tasks: wsize * number of tasks

    neg_per_row = labels.eq(-1).sum(dim=1)
    pos_per_row = labels.eq(1).sum(dim=1)

    # Predicting the majority label of a row gets exactly that many entries right.
    n_correct = torch.maximum(neg_per_row, pos_per_row).sum().item()
    return 100 * n_correct / (n_rows * n_cols)

def get_optimal_ncl_loss(ys_list, cont_ncl_loss_list=None):
    """Return the MSE of the best context-free predictor for continuous tasks.

    Args:
        ys_list: Sequence of 2-D tensors, one per continuous task, each of
            shape ``(ncl_bsize, ncl_wsize)``.  Row ``b`` holds the true
            y-values of input ``x_b`` under the sampled ``w``.  In the mixed
            mode all tensors must share the same shape.
        cont_ncl_loss_list: ``None`` selects the individual mode (typically
            ``len(ys_list) == 1``).  Otherwise an indexable of per-task
            losses (numbers or scalar tensors); entry ``i`` belongs to
            ``ys_list[i]`` and selects the mixed mode.

    Returns:
        A 0-d tensor: the MSE between all targets and the per-row optimum.

        * Individual mode: the optimum for row ``b`` is the plain row mean,
          i.e. ``E_w[f(x_b, w)]``.
        * Mixed mode: the optimum for row ``b`` is the mean of all targets
          of that row where task ``i`` is weighted by
          ``1 / cont_ncl_loss_list[i]``.
    """
    ys_combined = torch.cat(ys_list, dim=1)  # (ncl_bsize, n_task * ncl_wsize)

    if cont_ncl_loss_list is None:
        row_means = ys_combined.mean(dim=1, keepdim=True)
        # expand_as is not strictly needed (broadcasting would do), but it
        # guarantees that both operands have identical shapes.
        return mean_squared_error(ys_combined, row_means.expand_as(ys_combined))

    n_task = len(ys_list)
    first = ys_list[0]
    _, ncl_wsize = first.shape
    # Weights are fractional, so fall back to the default float dtype when
    # the targets are stored as integers.
    dtype = first.dtype if first.is_floating_point() else torch.get_default_dtype()

    # Inverse-loss weight of every task, shape (n_task,).  Built on the
    # targets' device so no tensor silently ends up on the CPU.
    inv_losses = torch.stack([
        torch.as_tensor(
            1 / cont_ncl_loss_list[i], dtype=dtype, device=first.device
        ).reshape(())
        for i in range(n_task)
    ])

    # Vectorised form of
    #   sum_i sum_j w_i * ys_i[b, j]  /  sum_i sum_j w_i
    # computed for every row b at once.
    row_sums = torch.stack(list(ys_list)).to(dtype).sum(dim=2)  # (n_task, ncl_bsize)
    weighted_sums = (inv_losses.unsqueeze(1) * row_sums).sum(dim=0)  # (ncl_bsize,)
    total_weight = ncl_wsize * inv_losses.sum()

    opt_ys = (weighted_sums / total_weight).unsqueeze(1)  # (ncl_bsize, 1)
    return mean_squared_error(ys_combined, opt_ys.expand_as(ys_combined))

# Shared module instances: the element-wise logistic sigmoid and the binary
# cross-entropy on probabilities (default "mean" reduction).
sigmoid, bce_loss = torch.nn.Sigmoid(), torch.nn.BCELoss()

def cross_entropy(ys_pred, ys):
    """Return the mean binary cross-entropy between logits and {-1, 1} labels.

    Args:
        ys_pred: Tensor of raw scores (logits) in ``(-inf, inf)``.
        ys: Tensor of labels in ``{-1, 1}`` with the same shape as
            ``ys_pred``; mapped to ``{0, 1}`` targets internally.

    Returns:
        A 0-d tensor with the loss averaged over all elements.
    """
    target = (ys + 1) / 2

    # Integer-valued logits are accepted; they are cast to the target's
    # floating-point dtype first.
    if torch.is_floating_point(ys_pred):
        logits = ys_pred
    else:
        logits = ys_pred.to(target.dtype)

    # Work on the logits directly instead of sigmoid -> BCELoss: once the
    # sigmoid saturates to exactly 0 or 1, BCELoss can only return its clamped
    # log value, whereas the logit formulation stays exact.
    return torch.nn.functional.binary_cross_entropy_with_logits(logits, target)

def get_optimal_ncl_bool_loss_from_true_no_context_function(ys_list):
    """Return the cross-entropy of predicting each row's most frequent label.

    ``ys_list`` holds one 2-D label tensor per boolean task; they are joined
    along dim 1 into a tensor of shape ``(N, M * num_task)``.  The per-row
    mode (as returned by ``torch.mode``) is repeated across the row and passed
    to ``cross_entropy`` as the prediction, with the joined labels as target.
    """
    labels = torch.cat(ys_list, dim=1)

    # Most frequent label of every row, shape (N,).
    row_mode = torch.mode(labels, dim=1).values

    # One identical prediction per entry of the row.
    predictions = row_mode[:, None].expand_as(labels)
    return cross_entropy(predictions, labels)

# Hard-coded optimal no-context loss for every known boolean experiment
# (the derivation of these constants is not part of this module).
_OPTIMAL_NCL_BOOL_LOSS = {
    'conjunction_15': 0.24639,
    'disjunction_15': 0.24386,
    'sparse_parity_15_2': 0.68954,
    'sparse_parity_15_3': 0.69385,
    'conjunction_20': 0.14085,
    'disjunction_20': 0.13822,
    'sparse_parity_20_2': 0.69102,
    'sparse_parity_20_3': 0.69464,
    'conjunction_10': 0.40751,
    'disjunction_10': 0.40316,
    'sparse_parity_10_2': 0.68295,
    'sparse_parity_10_3': 0.69017,
    'parity_10': 0.69656,
}


def get_optimal_ncl_bool_loss(exp_name):
    """Return the stored optimal no-context loss for a boolean experiment.

    Args:
        exp_name: Experiment identifier such as ``'conjunction_15'`` or
            ``'sparse_parity_10_3'``.

    Returns:
        The loss constant registered for ``exp_name``.

    Raises:
        ValueError: If ``exp_name`` is not one of the known experiments.
    """
    try:
        return _OPTIMAL_NCL_BOOL_LOSS[exp_name]
    except KeyError:
        known = ', '.join(sorted(_OPTIMAL_NCL_BOOL_LOSS))
        raise ValueError(
            f"Unknown experiment name {exp_name!r}; expected one of: {known}"
        ) from None
