"""
Code is partially adapted from (hao2022manifold)
"""


import torch
from torch.nn import functional as F
import numpy as np



def patch_based_manifold_Distillation(teacher, student, layer):
    """
    Patch-level manifold-distillation score (L_MD) between teacher and student embeddings.

    - teacher: fixed reference embeddings for the batch (3-D tensor)
    - student: embeddings from the current masked forward pass (same shape as teacher)
    - layer:   accepted for interface compatibility; not used by the computation

    Both inputs are L2-normalised along the last axis and their first two axes
    are swapped, so one Gram matrix is built per index of the original second
    axis. The Gram matrices are averaged and the squared difference between the
    teacher and student averages is summed.

    Returns the negated loss as a scalar tensor.
    """
    student_unit = F.normalize(student, dim=-1)
    teacher_unit = F.normalize(teacher, dim=-1)

    # Swap the first two axes so bmm builds one similarity matrix per slice.
    student_swapped = student_unit.permute(1, 0, 2)
    teacher_swapped = teacher_unit.permute(1, 0, 2)

    student_gram = torch.bmm(student_swapped, student_swapped.transpose(-1, -2))
    teacher_gram = torch.bmm(teacher_swapped, teacher_swapped.transpose(-1, -2))

    gram_gap = teacher_gram.mean(0) - student_gram.mean(0)
    discrepancy = torch.sum(torch.mul(gram_gap, gram_gap))

    # Negated because the scores feed a priority queue: the largest
    # discrepancy must end up at the top of the queue.
    return 0 - discrepancy

def manifold_Distillation(args, teacher, student):
    """
    Randomly-sampled manifold-distillation score (L_MD) for neurons or heads.

    - args:    object exposing ``aggregate``; must be 'sum' or 'mean'
    - teacher: fixed reference embeddings for the batch, 3-D (bsz, patch_num, features)
    - student: embeddings from the current masked forward pass, same shape as teacher

    A quarter of the ``bsz * patch_num`` rows (at least one) is sampled without
    replacement; the same rows are taken from teacher and student. The pairwise
    similarity (Gram) matrices of the L2-normalised rows are compared and the
    squared differences are aggregated according to ``args.aggregate``.

    Returns the negated loss as a scalar tensor (larger discrepancy -> smaller
    value, so it sorts first in a min-ordered priority queue).

    Raises ValueError if ``args.aggregate`` is neither 'sum' nor 'mean'.
    """
    aggregate = args.aggregate
    if aggregate not in ('sum', 'mean'):
        # Fail before doing any tensor work; ValueError is still an Exception,
        # so existing broad handlers keep working.
        raise ValueError('aggregate not specified')

    F_s = F.normalize(student, dim=-1)
    F_t = F.normalize(teacher, dim=-1)

    bsz, patch_num, _ = F_s.shape
    total_rows = bsz * patch_num

    # The sample size must be an integer: slicing a tensor with a float bound
    # raises TypeError. Keep at least one row so tiny inputs still give a score.
    K = max(1, int(0.25 * total_rows))
    sampler = torch.randperm(total_rows)[:K]

    f_s = F_s.reshape(total_rows, -1)[sampler]
    f_t = F_t.reshape(total_rows, -1)[sampler]

    M_s = f_s.mm(f_s.T)
    M_t = f_t.mm(f_t.T)

    M_diff = M_t - M_s
    squared = M_diff * M_diff
    loss_mf_rand = squared.sum() if aggregate == 'sum' else squared.mean()

    return -1 * loss_mf_rand



def cnn_mmd(teacher, student):
    """
    Mean-discrepancy score (L_MD) between teacher and student features for CNN output channels.

    - teacher: fixed reference embeddings for the batch
    - student: embeddings from the current masked forward pass (same shape as teacher)

    Both inputs are L2-normalised along the last axis and averaged over the
    first axis; the score is the sum of squared differences between the two means.

    Returns the negated loss as a scalar tensor.
    """
    F_s = F.normalize(student, dim=-1)
    F_t = F.normalize(teacher, dim=-1)

    mean_gap = torch.mean(F_t, dim=0) - torch.mean(F_s, dim=0)

    # Square element-wise BEFORE summing: squaring the already-summed gap lets
    # positive and negative differences cancel and can report zero discrepancy
    # for clearly different features. The loss is also assigned directly
    # instead of being accumulated into a variable that was never initialised.
    loss_mf_rand = torch.sum(mean_gap * mean_gap)

    return -1 * loss_mf_rand



def KLDiv(TeacherOutput, StudentOutput, temp=4):
    """
    Temperature-scaled KL-divergence score (L_KD) between teacher and student logits.

    - TeacherOutput: object exposing ``.logits`` with the fixed reference logits
    - StudentOutput: object exposing ``.logits`` from the current masked forward pass
    - temp:          softmax temperature T

    Both logit tensors are divided by T before ``log_softmax`` over dim 1, the
    KL divergence KL(teacher || student) is summed over all elements and then
    multiplied by T*T.

    Returns the negated, non-negative-clamped divergence as a scalar tensor.
    """
    T = temp

    # The temperature has to soften the distributions themselves; the T*T
    # factor below only compensates for that softening and is not a
    # substitute for it.
    student_log_probs = F.log_softmax(StudentOutput.logits / T, dim=1)
    teacher_log_probs = F.log_softmax(TeacherOutput.logits / T, dim=1)

    kl_div = F.kl_div(
        student_log_probs,
        teacher_log_probs,
        reduction='sum',
        log_target=True,
    ) * (T * T)

    # Guard against tiny negative values caused by floating-point round-off.
    kl_div = torch.clamp(kl_div, min=0)
    return -1 * kl_div




def genericPrune(model, train_dataset, args, prunedProps):
    """
    Generic outline of the core ranking algorithm from the main paper.

    Visible steps:
    - takes the first batch from ``train_dataset`` and moves every value to CUDA
    - runs one hooked forward pass to record reference logits and layer-wise outputs
    - for every (layer, head) pair runs another hooked forward pass and compares
      its layer-wise outputs with the reference via ``manifold_Distillation``
    - pushes one ``(score, layer, head, "name")`` entry per head into a priority queue
    - the neuron loop is only a placeholder ("Re-Use above")

    NOTE(review): this block reads as a sketch, not runnable code -- it does not
    parse as written (``args.lambda``, ``....``) and uses several names that are
    not defined in the visible file. See the NOTE(review) comments below.
    """
    
    # Only the first batch yielded by train_dataset is used for scoring.
    batch = next(iter(train_dataset))
    for k, v in batch.items():
        batch[k] = v.to("cuda", non_blocking=True)
        
    # Forward Hooking Classes to Pre-Compute the outputs
    # NOTE(review): `ModelHooking` and `maskingProps` are not defined or imported in the
    # visible file; the parameter is called `prunedProps` -- confirm which one is intended.
    modelObject = ModelHooking(args=args, model=model, maskProps=maskingProps) 
    base_logit_output, base_layer_wise_output = modelObject.forwardPass(batch)
    modelObject.purge_hooks()
    
    # NOTE(review): `PriorityQueue` is not imported in the visible file (presumably queue.PriorityQueue).
    globalHeadRanking = PriorityQueue() 
    for layer in range(prunedProps["num_layers"]):
        
        for head in (range(prunedProps["num_att_head"])):
            
            # NOTE(review): neither `layer` nor `head` is visibly passed to ModelHooking, so how
            # the current head gets masked is not shown here -- presumably via maskProps; confirm.
            modelObject = ModelHooking(args=args, model=model.eval(), maskProps=maskingProps)
            with torch.no_grad():
                current_logit_output, current_layer_wise_output = modelObject.forwardPass(batch)
            modelObject.purge_hooks()
            
            # Accumulating the MD Loss
            # Only layers after the current one contribute; the last layer itself is also
            # included when it is the one being scored and args.head_include_fin_layer_mmd is set.
            # NOTE(review): `MMDLayerResults` is never initialised before the `+=` below.
            for idx in range(len(base_layer_wise_output)):
                    if idx > layer or ((layer == prunedProps["num_layers"]-1) and layer == idx and args.head_include_fin_layer_mmd):
                        with torch.no_grad():
                            err = manifold_Distillation(args, base_layer_wise_output[idx], current_layer_wise_output[idx])
                            MMDLayerResults += err
            # Obtain KL Loss
            # NOTE(review): `lambda` is a reserved keyword, so `args.lambda` is a syntax error;
            # `KLErr` is also never used in the visible code.
            if args.loss_type == "KL" or args.loss_type == "MMD+KL":    
                KLErr = args.lambda * KLDiv(base_logit_output, current_logit_output, temp=args.temp)
            
            # NOTE(review): `MMDResults` is not defined in the visible code (the accumulator above
            # is `MMDLayerResults`), and "name" is a literal placeholder string.
            globalHeadRanking.put((MMDResults.detach().cpu(), layer, head, "name"))
            
        for neuron in (range(prunedProps["inter_size"])):
            
            ... # Re-Use above...
            
    # NOTE(review): `....` is a placeholder and not valid Python; pruneModel unpacks two values
    # (head results, intermediate-neuron results) from this call.
    return globalHeadRanking, ....
def pruneModel(args, model, train_dataset, model_config):
    """
    Entry point of the compression framework: score components, search, return masks.

    - Collects the prunable dimensions from ``model_config`` and ``args.seq_len``
    - Scores heads and intermediate neurons with ``genericPrune``
    - Runs ``globalRanking`` with those scores plus the MAC details from ``get_mac_details``

    Returns a dict with the keys "head_mask" and "neuron_mask".
    """
    # Sizes of every prunable dimension, looked up once.
    prunable_dims = {}
    prunable_dims["num_att_head"] = model_config["num_attention_heads"]
    prunable_dims["inter_size"] = model_config["intermediate_size"]
    prunable_dims["hidden_size"] = model_config["hidden_size"]
    prunable_dims["num_layers"] = model_config["num_hidden_layers"]
    prunable_dims["patch_size"] = args.seq_len + 1

    head_scores, neuron_scores = genericPrune(model, train_dataset, args, prunable_dims)

    # Everything the search stage needs: both rankings and the MAC costs.
    search_inputs = {}
    search_inputs["head_results"] = head_scores
    search_inputs["intermediate_results"] = neuron_scores
    search_inputs["mac_details"] = get_mac_details(args, prunable_dims)

    selected = globalRanking(args, prunable_dims, search_inputs)

    return {
        "head_mask": selected["head_mask"],
        "neuron_mask": selected["intermediate_mask"],
    }



def globalRanking(args, prunedProps, lProps):
    """
    Partition search over attention heads and intermediate neurons under a MAC budget.

    - args:        object exposing ``mac_constraint`` (fraction of the baseline MACs to keep)
    - prunedProps: dict with "num_layers", "num_att_head" and "inter_size"
    - lProps:      dict with
        * "head_results"["final_head_ranking"]:           iterable of (score, layer, head, name)
        * "intermediate_results"["final_neuron_ranking"]: iterable of (score, layer, neuron, name)
        * "mac_details": dict with "head_mac", "neuron_mac" and "base_mac"

    Rankings are consumed in the order given; an entry's importance is the
    negated score. For every head count k = 1..(layers * heads) the first k
    heads are kept and the remaining budget is filled with as many top-ranked
    neurons as fit. The feasible split with the highest total importance wins.

    Returns a dict with "head_mask" (num_layers x num_att_head) and
    "intermediate_mask" (num_layers x inter_size); kept entries are 1, the rest 0.
    If no feasible split has positive importance, both masks are all zeros.
    """
    # Keep plain Python rows: pushing mixed (float, int, int, str) tuples
    # through np.array would turn every field into a string.
    head_rank = [
        (float(score), *rest)
        for score, *rest in lProps["head_results"]["final_head_ranking"]
    ]
    neuron_rank = [
        (float(score), *rest)
        for score, *rest in lProps["intermediate_results"]["final_neuron_ranking"]
    ]

    head_mac = lProps["mac_details"]["head_mac"]
    neuron_mac = lProps["mac_details"]["neuron_mac"]
    baseline_mac = lProps["mac_details"]["base_mac"]

    capacity_mac = args.mac_constraint * baseline_mac

    # Never look past the entries that were actually ranked.
    total_heads = prunedProps["num_att_head"] * prunedProps["num_layers"]
    max_heads = min(total_heads, len(head_rank))

    max_importance = 0
    # Initialised so the mask construction below is well defined even when no
    # candidate beats an importance of 0 (e.g. all scores are zero).
    head_indicies = 0
    neuron_indicies = 0

    # Running sum of head importance; added in the same order as a fresh
    # re-summation would, but without the quadratic inner loop.
    head_importance = 0
    for num_heads in range(1, max_heads + 1):
        head_importance += -1 * head_rank[num_heads - 1][0]

        remaining_mac = capacity_mac - head_mac * num_heads
        if remaining_mac < 0:
            # The heads alone already exceed the budget: not a valid candidate.
            continue

        current_importance = head_importance
        num_neurons = 0
        while remaining_mac >= neuron_mac and num_neurons < len(neuron_rank):
            current_importance += -1 * neuron_rank[num_neurons][0]
            num_neurons += 1
            remaining_mac -= neuron_mac

        if current_importance > max_importance:
            max_importance = current_importance
            head_indicies = num_heads
            neuron_indicies = num_neurons

    final_head_mask = torch.zeros((prunedProps["num_layers"], prunedProps["num_att_head"]))
    final_neuron_mask = torch.zeros((prunedProps["num_layers"], prunedProps["inter_size"]))

    for _, head_layer, head_index, *_ in head_rank[:head_indicies]:
        final_head_mask[int(head_layer), int(head_index)] = 1

    for _, neuron_layer, neuron_index, *_ in neuron_rank[:neuron_indicies]:
        final_neuron_mask[int(neuron_layer), int(neuron_index)] = 1

    # Per-layer count of kept heads and neurons.
    print(final_head_mask.sum(-1), final_neuron_mask.sum(-1))

    masks = {
        "head_mask": final_head_mask,
        "intermediate_mask": final_neuron_mask,
    }

    return masks
