from statistics import mean
import torch
from multiprocessing import shared_memory
import multiprocessing
from collections import defaultdict
from copy import deepcopy
from multiprocessing import Manager, Process, Value, Queue
import torch.multiprocessing as mp
import torch.nn.functional as F

def _calculate_cos_affinity(feature_1, feature_2, smooth=True):
    '''
    Pairwise cosine affinity between two sets of row vectors.

    NOTE:
    feature_1: B*C
    feature_2: N*C
    smooth: if True, rescale the cosine from [-1, 1] to [0, 1] via 0.5*cos+0.5
    return: B*N, clamped to (0, 1) when smooth else to (-1, 1)
    '''
    epsilon = 1e-8
    # L2-normalize every row so the matrix product below is the cosine.
    unit_rows = F.normalize(feature_1, p=2, dim=-1)
    unit_cols = F.normalize(feature_2, p=2, dim=-1).t()
    cosine = torch.mm(unit_rows, unit_cols)

    # NOTE(review): for float32 inputs 1 - 1e-8 rounds to exactly 1.0, so the
    # upper clamp only removes overshoot above 1; it does not keep the
    # result strictly below 1.
    if not smooth:
        # raw cosine affinity, [-1, 1]
        return torch.clamp(cosine, -1 + epsilon, 1 - epsilon)

    # shifted cosine affinity, [0, 1]
    return torch.clamp(0.5 * cosine + 0.5, 0 + epsilon, 1 - epsilon)

def _generate_raw_feature(theta_list, dim_num, device=None):
    '''
    Sample random vectors whose first coordinate is prescribed.

    Row i of the result has theta_list[i] as its first coordinate; the other
    dim_num-1 coordinates are a random direction (normalized Gaussian sample)
    scaled to length sqrt(1 - theta_list[i]**2). For |theta| <= 1 each row
    therefore has unit L2 norm and cosine theta with the first axis.

    NOTE:
    theta_list: 1-D tensor of length K. Values with |theta| > 1 make the
                square root NaN, so callers should clamp beforehand.
    dim_num: dimensionality C of the generated vectors
    device: where to place the result. None (default) keeps the historical
            behaviour of moving the tensor to the current CUDA device; pass
            e.g. "cpu" to use the function without a GPU.
    return: K*C
    '''
    feature_num = len(theta_list)
    final_tensor = torch.zeros([feature_num, dim_num])

    # Random unit direction in the (dim_num-1)-dimensional complement of axis 0.
    tmp_tensor = torch.randn([feature_num, dim_num - 1])
    tmp_tensor = F.normalize(tmp_tensor, p=2, dim=-1)

    # Shrink it so that, together with theta on axis 0, the row has norm 1.
    scale = torch.sqrt(1 - theta_list ** 2)
    tmp_tensor.mul_(scale.unsqueeze(-1))

    final_tensor[:, 1:] = tmp_tensor
    final_tensor[:, 0] = theta_list

    if device is None:
        return final_tensor.cuda()
    return final_tensor.to(device)


def generate_raw_feature(theta_list, dim_num):
    '''
    Public alias of _generate_raw_feature.

    NOTE:
    theta_list, dim_num: forwarded unchanged
    return: whatever _generate_raw_feature returns for these arguments
    '''
    raw_features = _generate_raw_feature(theta_list, dim_num)
    return raw_features


def multi_release_features_dict(lock, memo_index, memo_label, memory, rot_dic, feature_number):
    '''
    Worker loop: claim indices from a shared counter and, for each claimed
    vector in `memory`, store `feature_number` synthetic features in `rot_dic`.

    NOTE:
    lock: guards the increment-and-read of memo_index
    memo_index: shared counter exposing `.value`; it is incremented BEFORE
                being read, so the first index processed is initial value + 1
    memo_label: sequence of labels, indexed in parallel with `memory`
    memory: tensor whose last dimension is the feature dimension; it is used
            together with CUDA tensors, so it is presumably already on the
            GPU - TODO confirm against the caller
    rot_dic: mapping written as rot_dic[label] = K*C tensor (moved to CPU)
    feature_number: K, number of features generated per label
    return: None; the loop ends once the claimed index reaches len(memo_label)
    '''
    dim = memory.shape[-1]
    # First standard basis vector as a column: [1, 0, ..., 0]^T
    stand = torch.zeros([dim, 1]).cuda()
    stand[0] = 1

    while True:
        # Take the next work item; the exhaustion check happens under the lock.
        with lock:
            memo_index.value += 1
            index = memo_index.value
            exhausted = index >= len(memo_label)
        if exhausted:
            break

        prototype = memory[index]
        label = memo_label[index]

        # R_0 sends the normalized prototype to the last axis and R_X sends
        # R_0 @ e_0 to the last axis, so R = R_0^-1 R_X R_0 carries the first
        # axis e_0 onto the prototype direction.
        R_0, R_0_inv = _rot_map(prototype)
        rotated_axis = torch.mm(R_0, deepcopy(stand)).squeeze()
        R_X, _ = _rot_map(rotated_axis)
        R = R_0_inv.mm(R_X).mm(R_0)

        # Draw the target cosine similarities, then clip them to a fixed window.
        cos_std = 0.08 * torch.ones([feature_number])
        cos_list = torch.normal(mean=0.8, std=cos_std)
        cos_list = torch.clamp(cos_list, min=0.705, max=0.995)

        # Features are built around the first axis and then rotated by R.
        raw_features = _generate_raw_feature(cos_list, dim)
        rot_dic[label] = torch.mm(raw_features, R.t()).cpu()


def multi_release_features_dict_l(event, lock, phase, memo_index, label_memo_l, rot_dic, feature_number):
    '''
    Long-lived worker loop: each time `event` is set, (re)load the labels and
    memory from `label_memo_l`, then claim indices from the shared counter and
    store `feature_number` synthetic features per label in `rot_dic`.

    NOTE:
    event: the worker blocks on event.wait(); the worker that runs past the
           last index clears it
    lock: guards the increment-and-read of memo_index
    phase: only used to compute `task_number`, which no live code reads
    memo_index: shared counter exposing `.value`; incremented BEFORE it is
                read, so the first index processed is initial value + 1
    label_memo_l: indexable; [0] is the label sequence, [1] the memory tensor
    rot_dic: mapping written as rot_dic[label] = K*C tensor (moved to CPU)
    feature_number: K, number of features generated per label
    return: never returns - the `while True` loop has no break
    '''
    # 0 means "the current batch has not been loaded yet".
    read_time = 0
    while True:
        # If the event is down, force a reload once it is set again.
        if not event.is_set():
            read_time = 0
        event.wait()
        if read_time == 0:
            # First pass of a batch: snapshot labels and memory, build e_0.
            memo_label = label_memo_l[0]
            # NOTE(review): task_number is only referenced by the
            # commented-out mean_number formula below; it is otherwise unused.
            task_number = len(memo_label)//phase-1
            memory = label_memo_l[1]
            memory = memory.cuda()
            dim = memory.shape[-1]
            # First standard basis vector as a column: [1, 0, ..., 0]^T
            stand = torch.zeros([dim, 1]).cuda()
            stand[0] = 1
            read_time += 1
            task_number += 1
        with lock:
            memo_index.value += 1
            index = memo_index.value
            if index >= len(memo_label):
                # Batch exhausted: lower the event and go back to waiting.
                # NOTE(review): other workers may still be computing their
                # last item when the event is cleared - confirm the caller
                # does not treat the cleared event as "all results written".
                event.clear()
                read_time = 0
                continue
        _memory = memory[index]
        label = memo_label[index]
        # R_0 sends the normalized memory vector to the last axis and R_X
        # sends R_0 @ e_0 to the last axis, so R = R_0^-1 R_X R_0 carries the
        # first axis e_0 onto the memory vector's direction.
        R_0, R_0_inv = _rot_map(_memory)
        R_X, _ = _rot_map(torch.mm(R_0, deepcopy(stand)).squeeze())
        R = torch.mm(torch.mm(R_0_inv, R_X), R_0)
        
        # Draw the target cosine similarities for the generated features.
        # The commented-out assignments are alternative settings.
        # mean_number = 0.8 - 0.01*task_number/2
        # mean_number = 0.85
        mean_number = 0.8
        # mean_number = 0.75
        # mean_number = 1.0
        max_number = 0.995
        # max_number = 1.0
        # Clamp window is symmetric around the mean: [2*mean - max, max].
        min_number = mean_number*2-max_number
        # min_number = 0.7
        # max_number sits 1.96 standard deviations above the mean
        # (presumably chosen as the two-sided 95% normal quantile).
        std_number = (max_number-mean_number)/1.96
        # std_number = (0.995-mean_number)/3
        cos_list = torch.normal(mean=mean_number,std=std_number*torch.ones([feature_number]))
        cos_list = torch.clamp(cos_list, min=min_number, max=max_number)
        # cos_list = torch.clamp(cos_list, min=min_number, max=0.995)
        # Features are built around the first axis and then rotated by R.
        released_raw_features = _generate_raw_feature(cos_list, dim)
        released_feature = torch.mm(released_raw_features, R.t())
        rot_dic[label] = released_feature.cpu()


def multi_calculate_R(event, lock, memo_index, label_memo_l, rot_dic):
    '''
    Long-lived worker loop: each time `event` is set, (re)load the labels and
    memory from `label_memo_l`, then claim indices from the shared counter and
    store one rotation matrix per label in `rot_dic`.

    NOTE:
    event: the worker blocks on event.wait(); the worker that runs past the
           last index clears it
    lock: guards the increment-and-read of memo_index
    memo_index: shared counter exposing `.value`; incremented BEFORE it is
                read, so the first index processed is initial value + 1
    label_memo_l: indexable; [0] is the label sequence, [1] the memory tensor
    rot_dic: mapping written as rot_dic[label] = C*C rotation (moved to CPU)
    return: never returns - the `while True` loop has no break
    '''
    # 0 means "the current batch has not been loaded yet".
    read_time = 0
    while True:
        # If the event is down, force a reload once it is set again.
        if not event.is_set():
            read_time = 0
        event.wait()
        if read_time == 0:
            # First pass of a batch: snapshot labels and memory, build e_0.
            memo_label = label_memo_l[0]
            memory = label_memo_l[1]
            memory = memory.cuda()
            dim = memory.shape[-1]
            # First standard basis vector as a column: [1, 0, ..., 0]^T
            stand = torch.zeros([dim, 1]).cuda()
            stand[0] = 1
            read_time += 1
        with lock:
            memo_index.value += 1
            index = memo_index.value
            if index >= len(memo_label):
                # Batch exhausted: lower the event and go back to waiting.
                event.clear()
                read_time = 0
                continue
        _memory = memory[index]
        label = memo_label[index]
        # R_0 sends the normalized memory vector to the last axis and R_X
        # sends R_0 @ e_0 to the last axis, so R = R_0^-1 R_X R_0 carries the
        # first axis e_0 onto the memory vector's direction.
        R_0, R_0_inv = _rot_map(_memory)
        R_X, _ = _rot_map(torch.mm(R_0, deepcopy(stand)).squeeze())
        R = torch.mm(torch.mm(R_0_inv, R_X), R_0)
        rot_dic[label] = R.cpu()



def release_features_dict(memory, memo_list, rot_dic, feature_number):
    '''
    Sequential variant: for every label in `memo_list`, generate
    `feature_number` synthetic features around memory[label] and store them
    in `rot_dic`.

    NOTE:
    memory: indexable by label; memory[label] is a 1-D feature vector of size C
    memo_list: iterable of labels to process
    rot_dic: mapping written as rot_dic[label] = K*C tensor; the stored tensor
             is NOT moved to the CPU here
    feature_number: K, number of features generated per label
    return: None
    '''
    for label in memo_list:
        prototype = memory[label]
        dim = len(prototype)

        # First standard basis vector as a column: [1, 0, ..., 0]^T
        stand = torch.zeros([dim, 1]).cuda()
        stand[0] = 1

        # R_0 sends the normalized prototype to the last axis and R_X sends
        # R_0 @ e_0 to the last axis, so R = R_0^-1 R_X R_0 carries the first
        # axis e_0 onto the prototype direction.
        R_0, R_0_inv = _rot_map(prototype)
        rotated_axis = torch.mm(R_0, stand).squeeze()
        R_X, _ = _rot_map(rotated_axis)
        R = R_0_inv.mm(R_X).mm(R_0)

        # Draw the target cosine similarities, then clip them to a fixed window.
        cos_std = 0.08 * torch.ones([feature_number])
        cos_list = torch.normal(mean=0.85, std=cos_std)
        cos_list = torch.clamp(cos_list, min=0.705, max=0.995)

        # Features are built around the first axis and then rotated by R.
        raw_features = _generate_raw_feature(cos_list, dim)
        rot_dic[label] = torch.mm(raw_features, R.t())


def _rot_map(V):
    '''
    Build the rotation that carries the direction of V onto the last axis.

    The rotation is a product of n-1 plane rotations. Step k acts on
    coordinates (k, k+1) of the running vector: it zeroes coordinate k and
    moves its magnitude into coordinate k+1. After the last step the
    normalized input has become [0, ..., 0, 1].

    NOTE:
    V: 1-D tensor of size C; it is L2-normalized internally and not modified
    return: (Rot, Rot_inv), both C*C, with Rot @ normalize(V) = e_last and
            Rot_inv the product of the transposed step rotations in reverse
            order (i.e. the inverse / transpose of Rot)

    The matrices are created with V's device and dtype. When two adjacent
    components are both zero the step is the identity instead of the 0/0 = NaN
    rotation a plain division would produce.
    '''
    n_dim = V.shape[-1]
    V = F.normalize(V, p=2, dim=-1)

    identity = torch.eye(n_dim, dtype=V.dtype, device=V.device)
    Rot = identity.clone()
    Rot_inv = identity.clone()

    for rotate in range(n_dim - 1):
        rot_norm = torch.sqrt(V[rotate] ** 2 + V[rotate + 1] ** 2)

        # Both components zero: nothing to rotate, so use cos=1, sin=0.
        # torch.where keeps this branch-free (no host sync on GPU tensors).
        degenerate = rot_norm == 0
        one = torch.ones_like(rot_norm)
        safe_norm = torch.where(degenerate, one, rot_norm)
        cos_theta = torch.where(degenerate, one, V[rotate + 1] / safe_norm)
        sin_theta = V[rotate] / safe_norm  # numerator is 0 when degenerate

        rot_mat = identity.clone()
        rot_mat[rotate, rotate] = cos_theta
        rot_mat[rotate, rotate + 1] = -sin_theta
        rot_mat[rotate + 1, rotate] = sin_theta
        rot_mat[rotate + 1, rotate + 1] = cos_theta

        # Advance the running vector and accumulate the rotation / inverse.
        V = torch.mm(rot_mat, V.unsqueeze(-1)).squeeze()
        Rot = torch.mm(rot_mat, Rot)
        Rot_inv = torch.mm(Rot_inv, rot_mat.t())
    return Rot, Rot_inv