import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import scipy.io.arff as sioarff
import scipy.linalg as scilin

def iou_measure(pred, target):
    """
    Per-sample overlap between the predicted and target label distributions.

    ``pred`` logits are turned into probabilities with a softmax over dim 1, and
    each ``target`` row is normalized to sum to one. The score is the inner product
    of the two distributions multiplied by the number of non-zero target entries.
    For 0/1 targets this equals the predicted probability mass on the true labels.

    Args:
        pred: logits, one row per sample.
        target: label weights with the same shape as ``pred``.

    Returns:
        1-D tensor with one score per row. Rows whose target is all zeros score 0.
    """
    num_class_sample = torch.sum(target != 0, dim=1)  # number of labels per sample
    pred_prob = F.softmax(pred, dim=1)
    row_mass = torch.sum(target, dim=1, keepdim=True)
    # A label-free row would otherwise produce 0/0 = NaN and poison batch means.
    # Dividing such rows by 1 makes their target distribution (and score) zero.
    row_mass = torch.where(row_mass == 0, torch.ones_like(row_mass), row_mass)
    target_prob = target / row_mass
    return torch.sum(pred_prob * target_prob, dim=1) * num_class_sample

def compute_info(model, dataloader, device):
    """
    Run ``model`` over ``dataloader`` and group the extracted features by label.

    The model's forward pass must return ``(output, features)``; only
    ``features`` is used. A sample's multiplicity is the number of non-zero
    entries in its label row.

    Returns:
        A 3-tuple ``(mu_G, mu_c_dict, before_class_dict)``:

        * ``before_class_dict``: two-level dict of feature lists (numpy arrays).
          ``before_class_dict[m]`` holds every sample of multiplicity ``m``.
          The second-level key is the label set:
            - for ``m == 1``, the integer class index;
            - otherwise, ``str`` of the numpy array of active label indices,
              e.g. ``'[1 2]'``.
        * ``mu_c_dict[c]``: mean feature of the multiplicity-1 samples of class ``c``.
        * ``mu_G``: mean feature over all multiplicity-1 samples. It is the int
          ``0`` if there are none.

    Also prints the total number of samples and the per-class multiplicity-1 counts.
    """
    model.eval()
    num_data = 0
    all_features = []
    all_labels = []

    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device).float(), targets.to(device).float()
        with torch.no_grad():
            _, features = model(inputs)
        all_features.extend(features.cpu())
        all_labels.extend(targets.cpu())
        num_data += targets.shape[0]

    # Group features by multiplicity, then by label set.
    before_class_dict = dict()
    for cur_label, cur_feature in zip(all_labels, all_features):
        multip = torch.sum(cur_label != 0).item()  # multiplicity
        if multip == 1:
            key = torch.argmax(cur_label).item()
        else:
            key = str(torch.argwhere(cur_label).squeeze().detach().cpu().numpy())
        # Tensor.numpy() shares memory with the tensor (and with every other
        # .numpy() view of it). Copy so that stored features never alias each
        # other or any accumulator derived from them.
        feature = cur_feature.numpy().copy()
        before_class_dict.setdefault(multip, {}).setdefault(key, []).append(feature)

    # Class and global means over multiplicity-1 samples only. They are computed
    # out of place, so the stored features are left untouched.
    single = before_class_dict.get(1, {})
    class_sums = {cla: np.sum(feats, axis=0) for cla, feats in single.items()}
    num_class_dict = {cla: len(feats) for cla, feats in single.items()}
    mu_c_dict = {cla: class_sums[cla] / num_class_dict[cla] for cla in class_sums}

    # mu_G covers only multiplicity-1 samples, so divide by their count rather
    # than by num_data, which also includes multi-label samples.
    num_single = sum(num_class_dict.values())
    mu_G = sum(class_sums.values()) / num_single if num_single else 0

    print(num_data, num_class_dict)
    return mu_G, mu_c_dict, before_class_dict

# Within-class covariance matrix
def compute_Sigma_W(before_class_dict, mu_c_dict, device):
    """
    Within-class covariance: the average over all samples of
    ``outer(h - mu_c, h - mu_c)``, where ``mu_c`` is the mean of the sample's class.

    Args:
        before_class_dict: ``{class: list of feature arrays}``.
        mu_c_dict: ``{class: mean feature array}`` with the same keys.
        device: torch device the computation runs on.

    Returns:
        ``(d, d)`` numpy array.
    """
    num_data = 0
    Sigma_W = 0

    for target, feats in before_class_dict.items():
        class_features = torch.from_numpy(np.array(feats)).float().to(device)
        class_mean = torch.from_numpy(mu_c_dict[target]).float().to(device)
        diffs = class_features - class_mean  # one row per sample
        # diffs.T @ diffs equals the sum of per-sample outer products,
        # done in one matmul instead of one kernel launch per sample.
        Sigma_W = Sigma_W + diffs.T @ diffs
        num_data += diffs.shape[0]
    Sigma_W /= num_data

    return Sigma_W.cpu().numpy()

# Between-class covariance matrix
def compute_Sigma_B(mu_c_dict, mu_G, device):
    """
    Between-class covariance: ``(1/K) * sum_c outer(mu_c - mu_G, mu_c - mu_G)``.

    Args:
        mu_c_dict: ``{class: mean feature array}``. Keys need not be contiguous.
        mu_G: global mean feature (numpy array).
        device: torch device the computation runs on.

    Returns:
        ``(d, d)`` numpy array.
    """
    mu_G = torch.from_numpy(mu_G).float().to(device)
    # Iterate the values rather than indexing 0..K-1, so a class missing from
    # mu_c_dict (no multiplicity-1 samples) does not raise KeyError.
    class_means = torch.from_numpy(np.stack(list(mu_c_dict.values()), axis=0)).float().to(device)
    diffs = class_means - mu_G  # one row per class
    Sigma_B = diffs.T @ diffs / len(mu_c_dict)

    return Sigma_B.cpu().numpy()

def compute_W_H_relation(W, mu_c_dict, mu_G):
    """
    NC3 metric: Frobenius distance between ``W @ H`` (scaled to unit Frobenius
    norm) and the simplex-ETF matrix ``(I_K - (1/K) * 1 1^T) / sqrt(K - 1)``.

    Args:
        W: classifier weight tensor. Row ``i`` is multiplied with column ``i`` of
            ``H``, so it is assumed to belong to class ``i``.
        mu_c_dict: ``{i: mean feature of class i}``, indexed by ``0..K-1``.
        mu_G: global mean feature (numpy array).

    Returns:
        ``(metric, H)``. ``metric`` is a Python float. ``H`` is a CPU float32
        tensor whose column ``i`` is ``mu_c_dict[i] - mu_G``.
    """
    K = len(mu_c_dict)
    H = torch.empty(mu_c_dict[0].shape[0], K)
    for i in range(K):
        H[:, i] = torch.from_numpy(mu_c_dict[i] - mu_G).float()

    # Work on W's own device/dtype instead of a hard-coded .cuda(), so this runs
    # on CPU-only machines and when W lives on a non-default GPU.
    device, dtype = W.device, W.dtype
    WH = torch.mm(W, H.to(device=device, dtype=dtype))
    WH = WH / torch.norm(WH, p='fro')
    eye = torch.eye(K, device=device, dtype=dtype)
    ones = torch.ones((K, K), device=device, dtype=dtype)
    sub = 1 / pow(K - 1, 0.5) * (eye - 1 / K * ones)

    res = torch.norm(WH - sub, p='fro')
    return res.item(), H

def compute_ETF(W):
    """
    NC2 metric: distance of the rows of ``W`` from a simplex equiangular tight frame.

    The Gram matrix ``W @ W.T`` is scaled to unit Frobenius norm and compared
    against ``(I_K - (1/K) * 1 1^T) / sqrt(K - 1)``, where ``K`` is the number
    of rows of ``W``.

    Returns:
        The Frobenius distance as a Python float.
    """
    num_rows = W.shape[0]
    gram = W @ W.T
    gram = gram / torch.norm(gram, p='fro')

    centering = torch.eye(num_rows) - 1 / num_rows * torch.ones((num_rows, num_rows))
    target = (centering / (num_rows - 1) ** 0.5).to(W.device)
    return torch.norm(gram - target, p='fro').item()

def angle_metric(m1_feature_dict, m2_feature_dict):
    """
    Ratio of the mean "matched" pair angle to the mean "all pairs" angle.

    Numerator: for every label pair key ``'[i j]'`` in ``m2_feature_dict``, the
    angle (degrees) between the normalized mean multiplicity-2 feature and the
    normalized sum of the class means of ``i`` and ``j``.

    Denominator: the same angle for every class pair ``(i, j)`` with ``i < j``
    (in the iteration order of ``m1_feature_dict``), measured against every
    multiplicity-2 mean direction.

    Args:
        m1_feature_dict: ``{class: list of feature arrays}``.
        m2_feature_dict: ``{'[i j]': list of feature arrays}``. Keys are ``str``
            of a numpy index array; the last two indices are used.

    Returns:
        ``mean(numerator) / mean(denominator)``. Also prints both list lengths.
    """
    def _unit(v):
        return v / np.linalg.norm(v, axis=-1, keepdims=True)

    def _angle_deg(cos):
        # Round-off can push |cos| slightly above 1, where arccos returns NaN.
        return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))

    # Each class mean is computed once and reused by numerator and denominator.
    m1_means = {key: np.mean(np.stack(feats, axis=0), axis=0)
                for key, feats in m1_feature_dict.items()}

    numerator = []
    m2_dirs = []
    for m_key, m_feats in m2_feature_dict.items():
        key1, key2 = (int(k) for k in m_key[1:-1].split(' ')[-2:])
        pair_dir = _unit(m1_means[key1] + m1_means[key2])
        m2_dir = _unit(np.mean(np.stack(m_feats, axis=0), axis=0))
        m2_dirs.append(m2_dir)
        numerator.append(_angle_deg(np.sum(pair_dir * m2_dir)))

    # Denominator: all class pairs i < j against all multiplicity-2 directions,
    # computed with one matmul instead of a triple Python loop.
    # np.triu_indices yields pairs in the same (i, j) order as nested loops.
    rows, cols = np.triu_indices(len(m1_means), k=1)
    if len(rows) and m2_dirs:
        class_means = np.stack(list(m1_means.values()), axis=0)
        pair_dirs = _unit(class_means[rows] + class_means[cols])
        denominator = _angle_deg(pair_dirs @ np.stack(m2_dirs, axis=0).T).ravel()
    else:
        denominator = np.empty(0)

    print(len(numerator), len(denominator))
    return np.mean(numerator) / np.mean(denominator)

def imbalance_angle_stat(m1_feature_dict, m2_feature_dict, lengths=(500, 50, 5)):
    """
    Mean pair angle for groups of label pairs, grouped by per-pair sample count.

    Intended for the case where multiplicity-1 classes are balanced but
    multiplicity-2 pairs are not. For each key ``'[i j]'`` in
    ``m2_feature_dict``, the angle (degrees) is taken between the mean
    multiplicity-2 feature and the normalized sum of the class means of ``i``
    and ``j``. The angle goes into the first bucket ``b`` with
    ``len(m2_feature_dict[key]) == lengths[b]``.

    Args:
        m1_feature_dict: ``{class: list of feature arrays}``.
        m2_feature_dict: ``{'[i j]': list of feature arrays}``.
        lengths: expected per-pair sample counts, one per bucket.

    Returns:
        Tuple with one mean angle per entry of ``lengths``, in that order
        (``nan`` for an empty bucket).

    Raises:
        ValueError: if a pair's sample count matches no entry of ``lengths``.
    """
    m1_means = {}

    def _class_mean(key):
        # Cached: a class takes part in many pairs.
        if key not in m1_means:
            m1_means[key] = np.mean(np.stack(m1_feature_dict[key], axis=0), axis=0)
        return m1_means[key]

    buckets = [[] for _ in lengths]
    for m_key, m_feats in m2_feature_dict.items():
        key1, key2 = (int(k) for k in m_key[1:-1].split(' ')[-2:])
        pair_dir = _class_mean(key1) + _class_mean(key2)
        pair_dir = pair_dir / np.linalg.norm(pair_dir)
        m_mean = np.mean(np.stack(m_feats, axis=0), axis=0)
        cos = np.sum(pair_dir * (m_mean / np.linalg.norm(m_mean)))
        # Round-off can push |cos| slightly above 1, where arccos returns NaN.
        angle = np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))

        count = len(m_feats)
        for bucket, expected in zip(buckets, lengths):
            if count == expected:
                bucket.append(angle)
                break
        else:
            raise ValueError("Check Length!")

    return tuple(np.mean(bucket) for bucket in buckets)

def _classifier_weight(model):
    """Return the final linear layer's weight: ``model.fc`` if it exists, else ``model.classifier``."""
    try:
        return model.fc.weight.data
    except AttributeError:
        return model.classifier.weight.data


def calculate_nc_stats(model, trainloader, device, imbalance=False):
    """
    Compute neural-collapse statistics of ``model`` on ``trainloader``.

    Returns:
        ``(collapse_metric, nc2_w, nc2_h, nc3, mean_pair_angle, angle_m)``, plus a
        trailing ``[al1, al2, al3]`` from ``imbalance_angle_stat`` (buckets
        500/50/5) when ``imbalance`` is true.

        * ``collapse_metric`` (NC1): ``trace(Sigma_W @ pinv(Sigma_B)) / K`` over
          the multiplicity-1 classes.
        * ``nc2_w`` / ``nc2_h`` (NC2): ``compute_ETF`` of the classifier weight and
          of the stacked class means.
        * ``nc3`` (NC3): ``compute_W_H_relation`` metric.
        * ``mean_pair_angle``: mean angle (degrees) between each multiplicity-2 mean
          feature and the normalized sum of its two class means.
        * ``angle_m``: ``angle_metric`` ratio.

    The classifier weight is read from ``model.fc`` or, if that attribute is
    missing, ``model.classifier``. Indexing ``feature_dict[1]`` and
    ``feature_dict[2]`` raises ``KeyError`` if no samples of that multiplicity exist.
    """
    mu_G, mu_c_dict, feature_dict = compute_info(model, trainloader, device)

    # NC1
    Sigma_W = compute_Sigma_W(feature_dict[1], mu_c_dict, device)
    Sigma_B = compute_Sigma_B(mu_c_dict, mu_G, device)
    collapse_metric = np.trace(Sigma_W @ scilin.pinv(Sigma_B)) / len(mu_c_dict)

    # Look the weight up once. Only a missing attribute triggers the fallback, so
    # genuine errors inside the metric computations are no longer masked by a
    # bare except.
    W = _classifier_weight(model)

    # NC2
    nc2_w = compute_ETF(W)
    # NOTE(review): raw (uncentred) class means are used here, while
    # compute_W_H_relation centres them by mu_G -- confirm this is intended.
    feature_means = torch.from_numpy(np.array([mu_c_dict[c] for c in mu_c_dict]))
    nc2_h = compute_ETF(feature_means)

    # NC3
    nc3, _ = compute_W_H_relation(W, mu_c_dict, mu_G)

    # Pair angles
    m1_feature_dict = feature_dict[1]
    m2_feature_dict = feature_dict[2]

    if imbalance:
        al1, al2, al3 = imbalance_angle_stat(m1_feature_dict, m2_feature_dict, lengths=(500, 50, 5))

    m1_means = {}  # per-class mean cache; a class appears in many label pairs
    cur_angle_list = []
    for m_key, m_feats in m2_feature_dict.items():
        key1, key2 = (int(k) for k in m_key[1:-1].split(' ')[-2:])
        for k in (key1, key2):
            if k not in m1_means:
                m1_means[k] = np.mean(np.stack(m1_feature_dict[k], axis=0), axis=0)
        pair_dir = m1_means[key1] + m1_means[key2]
        pair_dir = pair_dir / np.linalg.norm(pair_dir)
        m_mean = np.mean(np.stack(m_feats, axis=0), axis=0)
        cos = np.sum(pair_dir * (m_mean / np.linalg.norm(m_mean)))
        # Round-off can push |cos| slightly above 1, where arccos returns NaN.
        cur_angle_list.append(np.degrees(np.arccos(np.clip(cos, -1.0, 1.0))))

    # Angle metric
    angle_m = angle_metric(m1_feature_dict, m2_feature_dict)
    if imbalance:
        return collapse_metric, nc2_w, nc2_h, nc3, np.mean(cur_angle_list), angle_m, [al1, al2, al3]
    return collapse_metric, nc2_w, nc2_h, nc3, np.mean(cur_angle_list), angle_m
