import os
import time
import torch
import numpy as np
from tqdm import tqdm

def mimic_layer_def(mimic_layer_type, mimic_layer_idx):
    """Return the state-dict key of the weight selected as the mimic layer.

    Args:
        mimic_layer_type: One of "visual_attn_out", "visual_attn_in",
            "visual_mlp_fc", "visual_mlp_proj", "visual_ln", "text_attn_in",
            "text_attn_out", "text_mlp_fc", "text_mlp_proj", "token",
            "text_ln".
        mimic_layer_idx: Index of the residual block, interpolated into the
            key. Ignored for "visual_ln", "token" and "text_ln", whose keys
            do not contain a block index.

    Returns:
        The parameter name as a string, e.g.
        ``'visual.transformer.resblocks.11.mlp.c_fc.weight'``.

    Raises:
        ValueError: If ``mimic_layer_type`` is not one of the supported types.
    """
    ## name templates; "{idx}" is replaced by the residual block index ##
    name_templates = {
        "visual_attn_out": 'visual.transformer.resblocks.{idx}.attn.out_proj.weight',
        "visual_attn_in": 'visual.transformer.resblocks.{idx}.attn.in_proj_weight',
        "visual_mlp_fc": 'visual.transformer.resblocks.{idx}.mlp.c_fc.weight',
        "visual_mlp_proj": 'visual.transformer.resblocks.{idx}.mlp.c_proj.weight',
        "visual_ln": 'visual.ln_post.weight',
        "text_attn_in": 'transformer.resblocks.{idx}.attn.in_proj_weight',
        "text_attn_out": 'transformer.resblocks.{idx}.attn.out_proj.weight',
        "text_mlp_fc": 'transformer.resblocks.{idx}.mlp.c_fc.weight',
        "text_mlp_proj": 'transformer.resblocks.{idx}.mlp.c_proj.weight',
        "token": 'token_embedding.weight',
        "text_ln": 'ln_final.weight',
    }

    try:
        template = name_templates[mimic_layer_type]
    except KeyError:
        # Fail with an explicit message instead of an UnboundLocalError.
        raise ValueError(
            f"Unknown mimic_layer_type {mimic_layer_type!r}; "
            f"expected one of {sorted(name_templates)}"
        ) from None

    # Templates without "{idx}" are returned unchanged by format().
    return template.format(idx=mimic_layer_idx)
        
def compute_task_vector(current_model, reference_model, mimic_layer_name):
    """Return the weight difference ``reference - current`` for one layer.

    Args:
        current_model: Model whose ``state_dict()`` supplies the subtrahend.
        reference_model: Model whose ``state_dict()`` supplies the minuend.
        mimic_layer_name: State-dict key of the entry to compare; it must be
            present in both state dicts.

    Returns:
        ``reference_model.state_dict()[mimic_layer_name]`` minus
        ``current_model.state_dict()[mimic_layer_name]``.
    """
    ## snapshot both state dicts, then diff the selected entry ##
    current_weights, reference_weights = current_model.state_dict(), reference_model.state_dict()

    delta = reference_weights[mimic_layer_name] - current_weights[mimic_layer_name]
    return delta
   
def _resolve_mimic_parameter(model, mimic_layer_type, mimic_layer_idx):
    """Return the weight of ``model`` selected by the mimic layer type/index.

    Raises:
        ValueError: If ``mimic_layer_type`` is not a supported type.
    """
    # Accessors are lazy so only the attributes of the requested layer are
    # touched; the model does not need to expose every listed sub-module.
    getters = {
        "visual_attn_out": lambda: model.visual.transformer.resblocks[mimic_layer_idx].attn.out_proj.weight,
        "visual_attn_in": lambda: model.visual.transformer.resblocks[mimic_layer_idx].attn.in_proj_weight,
        "visual_mlp_fc": lambda: model.visual.transformer.resblocks[mimic_layer_idx].mlp.c_fc.weight,
        "visual_mlp_proj": lambda: model.visual.transformer.resblocks[mimic_layer_idx].mlp.c_proj.weight,
        "visual_ln": lambda: model.visual.ln_post.weight,
        "text_attn_out": lambda: model.transformer.resblocks[mimic_layer_idx].attn.out_proj.weight,
        "text_attn_in": lambda: model.transformer.resblocks[mimic_layer_idx].attn.in_proj_weight,
        "text_mlp_fc": lambda: model.transformer.resblocks[mimic_layer_idx].mlp.c_fc.weight,
        "text_mlp_proj": lambda: model.transformer.resblocks[mimic_layer_idx].mlp.c_proj.weight,
        "token": lambda: model.token_embedding.weight,
        "text_ln": lambda: model.ln_final.weight,
    }

    try:
        getter = getters[mimic_layer_type]
    except KeyError:
        raise ValueError(
            f"Unknown mimic_layer_type {mimic_layer_type!r}; "
            f"expected one of {sorted(getters)}"
        ) from None

    return getter()


def compute_mimic_score(model, per_sample_loss, task_vector, mimic_layer_type="visual_mlp_fc", mimic_layer_idx=11, method="normed_proj", temperature=1.0):
    """Score each sample by how well its negative gradient follows the task vector.

    For every entry of ``per_sample_loss`` the gradient with respect to the
    selected mimic-layer weight is computed, negated, flattened and compared
    with the L2-normalized, flattened ``task_vector``.

    Args:
        model: Model that owns the mimic-layer weight and produced the losses.
        per_sample_loss: Sized iterable of per-sample losses. Each entry is
            passed to ``torch.autograd.grad`` and must therefore still be
            attached to the autograd graph.
        task_vector: Tensor compared against the gradients; after flattening
            it must have as many elements as the mimic-layer weight.
        mimic_layer_type: Which weight to differentiate against (see
            ``_resolve_mimic_parameter`` for the supported values).
        mimic_layer_idx: Residual block index; unused for "visual_ln",
            "token" and "text_ln".
        method: Must end with "cos" (cosine similarity: the gradient is
            L2-normalized too) or "proj" (projection of the raw negative
            gradient onto the unit task vector). If it starts with "normed",
            the scores are divided by ``temperature`` and passed through a
            softmax over the samples.
        temperature: Softmax temperature; only used for "normed*" methods.

    Returns:
        A 1-D tensor with ``len(per_sample_loss)`` entries on
        ``task_vector.device``. The scores are stored as plain values
        (``.item()``), so they carry no autograd history.

    Raises:
        ValueError: If ``method`` ends with neither "cos" nor "proj", or if
            ``mimic_layer_type`` is not supported.
    """
    ## validate up front so bad arguments fail before any gradient work ##
    use_cosine = method.endswith("cos")
    if not (use_cosine or method.endswith("proj")):
        raise ValueError(
            f"Unknown method {method!r}; it must end with 'cos' or 'proj'"
        )

    # The target parameter does not change between samples: resolve it once.
    mimic_param = _resolve_mimic_parameter(model, mimic_layer_type, mimic_layer_idx)

    tv_flat = task_vector.flatten()
    normed_tv_flat = torch.nn.functional.normalize(tv_flat, p=2, dim=0)

    mimic_score_tensor = torch.zeros(len(per_sample_loss), device=task_vector.device)

    ## compute per-sample gradient ##
    for i, loss_ in tqdm(enumerate(per_sample_loss), total=len(per_sample_loss)):
        # retain_graph=True: the same graph is reused for the next sample.
        grad = torch.autograd.grad(loss_, mimic_param, retain_graph=True)[0]

        ## turn to negative gradient (descent direction) ##
        neg_grad_flat = (-grad).flatten()

        ## compute mimic score ##
        if use_cosine:
            neg_grad_flat = torch.nn.functional.normalize(neg_grad_flat, p=2, dim=0)
        score = torch.dot(neg_grad_flat, normed_tv_flat)

        mimic_score_tensor[i] = score.item()

    if method.startswith("normed"):
        mimic_score_tensor = mimic_score_tensor / temperature
        return torch.nn.Softmax(dim=0)(mimic_score_tensor)

    return mimic_score_tensor
    
def reweigh_gradients(model, reference_model, per_sample_loss, mimic_layer_type, mimic_layer_idx, mimic_method, mimic_temperature):
    """Combine per-sample losses into one loss weighted by their mimic scores.

    Steps: resolve the mimic layer's parameter name, build the task vector
    (reference weight minus current weight) for that layer, score every
    sample against it, then take the dot product of the losses and scores.

    Args:
        model: Model being trained; source of the current weights.
        reference_model: Model providing the reference weights.
        per_sample_loss: Per-sample losses; combined with the scores via the
            ``@`` operator, so it must support that operation.
        mimic_layer_type: Layer type forwarded to ``mimic_layer_def`` and
            ``compute_mimic_score``.
        mimic_layer_idx: Residual block index forwarded alongside the type.
        mimic_method: Scoring method forwarded to ``compute_mimic_score``.
        mimic_temperature: Temperature forwarded to ``compute_mimic_score``.

    Returns:
        Tuple ``(calibrated_loss, mimic_score)``.
    """
    ## task vector of the chosen mimic layer ##
    layer_key = mimic_layer_def(mimic_layer_type, mimic_layer_idx)
    layer_task_vector = compute_task_vector(model, reference_model, layer_key)

    ## per-sample gradients -> mimic scores ##
    scoring_options = dict(
        mimic_layer_type=mimic_layer_type,
        mimic_layer_idx=mimic_layer_idx,
        method=mimic_method,
        temperature=mimic_temperature,
    )
    mimic_score = compute_mimic_score(model, per_sample_loss, layer_task_vector, **scoring_options)

    ## score-weighted sum of the per-sample losses ##
    return per_sample_loss @ mimic_score, mimic_score
