import torch
import torch.nn as nn
import collections
import copy
import numpy as np
import psutil
import time
import random
from thop import profile
from utils.data_poison import get_poison_batch, get_poison_batch_aaai, get_ra_batch_aaai
from utils.sup_contrastive_loss import SupConLoss
import torch.nn.functional as F
from collections import OrderedDict

# Default compute device for this module: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def communication_fedvpt(server_model, models, client_weights, client_num, d_weight):
    """Aggregate client prompt/head tensors into the server and broadcast the result.

    For every tensor in ``obtain_prompt()["head"]`` and for
    ``obtain_prompt()["Prompt_Tokens"]`` the new server value is
    ``sum_i d_weight[i] * client_weights[i] * client_value_i``, accumulated in
    float32 and copied into the server tensor in place. The aggregate is then
    loaded into the server model and into each of the first ``client_num``
    client models.

    Args:
        server_model: object exposing ``obtain_prompt()`` / ``load_prompt()``.
        models: sequence of client objects with the same interface.
        client_weights: per-client aggregation weights (indexed by client).
        client_num: number of leading entries of ``models`` to aggregate.
        d_weight: additional per-client multiplicative weights.

    Returns:
        dict ``{"head": {key: tensor}, "Prompt_Tokens": tensor}`` holding
        ``new_server_value - old_server_value`` for every aggregated tensor.
    """
    with torch.no_grad():
        # Layout (as used below): {"head": {param_name: tensor}, "Prompt_Tokens": tensor}.
        server_prompt_state_dict = server_model.obtain_prompt()

        # Snapshot the pre-aggregation values; the copy_ calls below write into
        # the server tensors in place.
        initial_server_prompt_state_dict = {
            "head": {key: tensor.clone().detach() for key, tensor in server_prompt_state_dict["head"].items()},
            "Prompt_Tokens": server_prompt_state_dict["Prompt_Tokens"].clone().detach(),
        }

        # Fetch each client's prompt once instead of once per aggregated tensor.
        client_prompt_state_dicts = [models[client_idx].obtain_prompt() for client_idx in range(client_num)]
        # Same left-to-right product as the per-tensor expression d * w * tensor.
        client_scales = [d_weight[client_idx] * client_weights[client_idx] for client_idx in range(client_num)]

        # head: weighted sum per parameter (e.g. weight, bias)
        for key, server_tensor in server_prompt_state_dict["head"].items():
            temp = torch.zeros_like(server_tensor, dtype=torch.float32)
            for scale, client_sd in zip(client_scales, client_prompt_state_dicts):
                temp += scale * client_sd["head"][key]
            server_tensor.data.copy_(temp)

        # Prompt_Tokens: weighted sum
        server_tokens = server_prompt_state_dict["Prompt_Tokens"]
        temp1 = torch.zeros_like(server_tokens, dtype=torch.float32)
        for scale, client_sd in zip(client_scales, client_prompt_state_dicts):
            temp1 += scale * client_sd["Prompt_Tokens"]
        server_tokens.data.copy_(temp1)

        # aggregate and distribute
        server_model.load_prompt(server_prompt_state_dict)
        for client_idx in range(client_num):
            models[client_idx].load_prompt(server_prompt_state_dict)

        # Update applied to the server this round (new - old).
        benign_gradient = {
            "head": {
                key: server_prompt_state_dict["head"][key].clone().detach() - initial_server_prompt_state_dict["head"][key]
                for key in server_prompt_state_dict["head"]
            },
            "Prompt_Tokens": server_prompt_state_dict["Prompt_Tokens"].clone().detach() - initial_server_prompt_state_dict["Prompt_Tokens"],
        }

    return benign_gradient


# vpt client train
def vpt_train(model, data_loader, optimizer, loss_fun, device):
    """Run one training pass of `model` over `data_loader` (benign VPT client).

    `model(x)` is expected to return a 2-tuple whose first element is the
    logits; the second element is ignored here.

    Returns:
        (sum of per-batch losses / number of batches, fraction of correct
        argmax predictions over all samples seen).
    """
    model.train()
    running_loss = 0
    n_seen = 0
    n_hit = 0

    for inputs, labels in data_loader:
        optimizer.zero_grad()
        inputs = inputs.to(device)
        labels = labels.to(device)

        logits, _ = model(inputs)
        batch_loss = loss_fun(logits, labels)

        # Bookkeeping uses the forward-pass logits (before the parameter update).
        running_loss += batch_loss.item()
        n_seen += labels.size(0)
        predicted = logits.data.max(1)[1]
        n_hit += predicted.eq(labels.view(-1)).sum().item()

        batch_loss.backward()
        optimizer.step()

    return running_loss / len(data_loader), n_hit / n_seen
    

# Vpt Malicious client train
def _neurotoxin_project_grads(model, bottom_k_masks):
    """Replace the head / Prompt_Tokens gradients of `model` in place with their masked projection."""
    grads = {"head": {}, "Prompt_Tokens": None}
    for name, param in model.named_parameters():
        if param.grad is None:
            continue
        if 'head' in name:
            # Head params are keyed by their last name component (e.g. "weight", "bias").
            grads["head"][name.split('.')[-1]] = param.grad.clone().detach()
        elif 'Prompt_Tokens' in name:
            grads["Prompt_Tokens"] = param.grad.clone().detach()

    projected_grads = project_gradients_fedvpt(grads, bottom_k_masks)

    for name, param in model.named_parameters():
        if 'head' in name:
            key = name.split('.')[-1]
            if key in projected_grads["head"]:
                param.grad.copy_(projected_grads["head"][key])
        elif 'Prompt_Tokens' in name and projected_grads["Prompt_Tokens"] is not None:
            param.grad.copy_(projected_grads["Prompt_Tokens"])


def vpt_poison_train_backdoor_aaai(model, data_loader, optimizer, loss_fun, device, poison_number_per_batch, poison_label_swap, attack_type, benign_gradients):
    """Run one local training pass of a malicious VPT client.

    Each batch is passed through ``get_poison_batch_aaai`` (project helper;
    presumably stamps a trigger on ``poison_number_per_batch`` samples and
    relabels them via ``poison_label_swap`` -- TODO confirm) before training.

    Supported ``attack_type`` values:
      * ``"badnets"``: plain ``loss_fun`` on the poisoned batch.
      * ``"neurotoxin"``: as badnets, but after ``backward()`` the head and
        ``Prompt_Tokens`` gradients are masked to the coordinates where
        ``benign_gradients`` has the smallest magnitude (bottom 98%).
      * ``"CBI"``: ``loss_fun`` plus ``0.5 *`` supervised contrastive loss on
        L2-normalised features (the second output of ``model``).

    Returns:
        (sum of per-batch losses / number of batches, accuracy against the
        poisoned targets).

    Raises:
        ValueError: if ``attack_type`` is not one of the supported values.
    """
    # Previously an unknown attack type trained nothing and then failed with ZeroDivisionError.
    if attack_type not in ("badnets", "neurotoxin", "CBI"):
        raise ValueError(f"Unsupported attack_type: {attack_type!r}")

    model.train()
    loss_all = 0
    total = 0
    correct = 0

    scl = SupConLoss() if attack_type == "CBI" else None
    # benign_gradients is fixed for the whole pass, so compute the masks once rather than per batch.
    bottom_k_masks = select_bottom_k_fedvpt(benign_gradients, k_percent=98) if attack_type == "neurotoxin" else None

    for batch in data_loader:
        optimizer.zero_grad()
        if attack_type == "CBI":
            # CBI calls the poison helper with its default `evaluation` argument, as before.
            data, targets = get_poison_batch_aaai(batch, poison_label_swap, poison_number_per_batch)
        else:
            data, targets = get_poison_batch_aaai(batch, poison_label_swap, poison_number_per_batch, evaluation=False)
        data = data.to(device)
        targets = targets.to(device)

        output, features = model(data)
        if attack_type == "CBI":
            nor_out = F.normalize(features, p=2.0, dim=1, eps=1e-12)
            scl_loss = scl.forward(nor_out.reshape(data.size(0), 1, -1), targets)
            loss = loss_fun(output, targets) + 0.5 * scl_loss
        else:
            loss = loss_fun(output, targets)

        loss.backward()
        if attack_type == "neurotoxin":
            _neurotoxin_project_grads(model, bottom_k_masks)
        optimizer.step()

        loss_all += loss.item()
        total += targets.size(0)
        pred = output.data.max(1)[1]
        correct += pred.eq(targets.view(-1)).sum().item()

    return loss_all / len(data_loader), correct / total


# Benign Test vpt
def test_vpt(model, data_loader, loss_fun, device):
    model.eval()
    loss_all = 0
    total = 0
    correct = 0
    for data, target in data_loader:

        data = data.to(device)
        target = target.to(device) 
        output, _ = model(data)
        # print(target)
        loss = loss_fun(output, target)

        loss_all += loss.item()
        total += target.size(0)
        pred = output.data.max(1)[1]
        correct += pred.eq(target.view(-1)).sum().item()

    return loss_all / len(data_loader), correct/total


def fast_test_vpt(model, data_loader, loss_fun, device, sample_ratio=0.3):
    """
    Fast test by evaluating a random subset of the test batches.

    ``max(1, int(len(data_loader) * sample_ratio))`` batch indices are drawn
    with ``random.sample`` (clamped to the number of batches, so an empty
    loader or ``sample_ratio > 1`` no longer raises). Non-sampled batches are
    still iterated but skipped.

    Returns:
        (mean loss over the sampled batches, accuracy over the sampled
        samples); both are 0 when nothing was sampled.
    """
    model.eval()
    loss_all = 0
    total = 0
    correct = 0

    total_batches = len(data_loader)
    # random.sample raises if asked for more items than exist (e.g. 1 of 0).
    sample_batches = min(total_batches, max(1, int(total_batches * sample_ratio)))
    # Set for O(1) membership tests in the loop below.
    sampled_indices = set(random.sample(range(total_batches), sample_batches))

    with torch.no_grad():
        for batch_id, batch in enumerate(data_loader):
            if batch_id not in sampled_indices:
                continue

            data, targets = batch
            data = data.to(device)
            targets = targets.to(device)

            output, _ = model(data)
            loss = loss_fun(output, targets)

            loss_all += loss.item()
            total += targets.size(0)
            pred = output.data.max(1)[1]
            correct += pred.eq(targets.view(-1)).sum().item()

    mean_loss = loss_all / sample_batches if sample_batches else 0.0
    accuracy = correct / total if total > 0 else 0
    return mean_loss, accuracy


def asr_test_vpt_backdoor_aaai(model, data_loaders, loss_fun, device, poison_label_swap):
    """Compute the attack success rate (ASR) of `model` on each loader.

    Every batch is converted with ``get_poison_batch_aaai(..., 4,
    evaluation=True)`` (project helper; presumably triggers the batch and
    swaps labels to the target -- TODO confirm), and ASR is the fraction of
    predictions matching the returned targets. Runs under ``torch.no_grad()``.

    Returns:
        (loss, list of per-loader ASRs). NOTE(review): ``loss`` is the mean
        per-batch loss of the LAST loader only, because the accumulator is
        reset per loader; kept as-is for compatibility.
    """
    model.eval()
    asrs = []

    with torch.no_grad():
        for loader in data_loaders:
            loss_all = 0
            total = 0
            correct = 0

            for batch in loader:
                data, targets = get_poison_batch_aaai(batch, poison_label_swap, 4, evaluation=True)
                data = data.to(device)
                targets = targets.to(device)

                output, _ = model(data)
                loss = loss_fun(output, targets)

                loss_all += loss.item()
                total += targets.size(0)
                pred = output.data.max(1)[1]
                correct += pred.eq(targets.view(-1)).sum().item()

            asrs.append(correct / total)

    return loss_all / len(loader), asrs


def fast_asr_test_vpt_backdoor_aaai(model, data_loaders, loss_fun, device, poison_label_swap, sample_ratio=0.3):
    """
    Fast ASR test by evaluating a random subset of each loader's batches.

    For each loader, ``max(1, int(len(loader) * sample_ratio))`` batch indices
    (clamped to the number of batches) are drawn with ``random.sample``.
    Sampled batches go through ``get_poison_batch_aaai(..., 4,
    evaluation=True)``. Runs under ``torch.no_grad()``.

    Returns:
        (loss, list of per-loader ASRs). ``loss`` is the mean over the sampled
        batches of the LAST loader only (as in the original); 0.0 when nothing
        was sampled. A loader with no sampled samples gets ASR 0.
    """
    model.eval()
    asrs = []
    # Defaults so an empty `data_loaders` returns (0.0, []) instead of raising.
    loss_all = 0
    sample_batches = 0

    with torch.no_grad():
        for loader in data_loaders:
            loss_all = 0
            total = 0
            correct = 0

            total_batches = len(loader)
            # random.sample raises if asked for more items than exist (e.g. 1 of 0).
            sample_batches = min(total_batches, max(1, int(total_batches * sample_ratio)))
            sampled_indices = set(random.sample(range(total_batches), sample_batches))

            for batch_id, batch in enumerate(loader):
                if batch_id not in sampled_indices:
                    continue

                data, targets = get_poison_batch_aaai(batch, poison_label_swap, 4, evaluation=True)
                data = data.to(device)
                targets = targets.to(device)

                output, _ = model(data)
                loss = loss_fun(output, targets)

                loss_all += loss.item()
                total += targets.size(0)
                pred = output.data.max(1)[1]
                correct += pred.eq(targets.view(-1)).sum().item()

            asrs.append(correct / total if total > 0 else 0)

    return (loss_all / sample_batches if sample_batches else 0.0), asrs


def robust_accuracy_test_vpt_backdoor_aaai(model, data_loaders, loss_fun, device):
    """
    Test robust accuracy: model's accuracy on triggered images with original labels.

    Every batch is converted with ``get_ra_batch_aaai(batch, 4,
    evaluation=True)`` (project helper). Runs under ``torch.no_grad()`` so no
    autograd graph is built.

    Returns:
        (loss, list of per-loader robust accuracies). NOTE(review): ``loss`` is
        the mean per-batch loss of the LAST loader only, because the
        accumulator is reset per loader; kept as-is for compatibility.
    """
    model.eval()
    robust_accuracies = []

    with torch.no_grad():
        for loader in data_loaders:
            loss_all = 0
            total = 0
            correct = 0

            for batch in loader:
                data, targets = get_ra_batch_aaai(batch, 4, evaluation=True)
                data = data.to(device)
                targets = targets.to(device)

                output, _ = model(data)
                loss = loss_fun(output, targets)

                loss_all += loss.item()
                total += targets.size(0)
                pred = output.data.max(1)[1]
                correct += pred.eq(targets.view(-1)).sum().item()

            robust_accuracies.append(correct / total)

    return loss_all / len(loader), robust_accuracies

def fast_robust_accuracy_test_vpt_backdoor_aaai(model, data_loaders, loss_fun, device, sample_ratio=0.3):
    """
    Fast robust accuracy test by evaluating a random subset of each loader's batches.

    For each loader, ``max(1, int(len(loader) * sample_ratio))`` batch indices
    (clamped to the number of batches) are drawn with ``random.sample``.
    Sampled batches go through ``get_ra_batch_aaai(batch, 4, evaluation=True)``.
    Runs under ``torch.no_grad()``.

    Returns:
        (loss, list of per-loader robust accuracies). ``loss`` is the mean
        over the sampled batches of the LAST loader only (as in the original);
        0.0 when nothing was sampled. A loader with no sampled samples gets 0.
    """
    model.eval()
    robust_accuracies = []
    # Defaults so an empty `data_loaders` returns (0.0, []) instead of raising.
    loss_all = 0
    sample_batches = 0

    with torch.no_grad():
        for loader in data_loaders:
            loss_all = 0
            total = 0
            correct = 0

            total_batches = len(loader)
            # random.sample raises if asked for more items than exist (e.g. 1 of 0).
            sample_batches = min(total_batches, max(1, int(total_batches * sample_ratio)))
            sampled_indices = set(random.sample(range(total_batches), sample_batches))

            for batch_id, batch in enumerate(loader):
                if batch_id not in sampled_indices:
                    continue

                data, targets = get_ra_batch_aaai(batch, 4, evaluation=True)
                data = data.to(device)
                targets = targets.to(device)

                output, _ = model(data)
                loss = loss_fun(output, targets)

                loss_all += loss.item()
                total += targets.size(0)
                pred = output.data.max(1)[1]
                correct += pred.eq(targets.view(-1)).sum().item()

            robust_accuracies.append(correct / total if total > 0 else 0)

    return (loss_all / sample_batches if sample_batches else 0.0), robust_accuracies

def _bottom_k_mask(tensor, k_percent):
    """Return a flat float32 mask marking the `k_percent`% smallest-magnitude entries of `tensor`."""
    # reshape (not view) so non-contiguous inputs work too.
    flat_abs = tensor.reshape(-1).abs()
    k = min(int(k_percent * flat_abs.numel() / 100), flat_abs.numel())
    if k <= 0:
        # Bottom k% rounds down to zero entries (e.g. a 1-element bias at k=98):
        # select nothing instead of calling max() on an empty tensor.
        return torch.zeros_like(flat_abs, dtype=torch.float32)
    # The k-th smallest magnitude is the cut-off; ties at the threshold are all kept.
    threshold = torch.topk(flat_abs, k, largest=False).values.max()
    return (flat_abs <= threshold).float()


def select_bottom_k_fedvpt(benign_gradient, k_percent):
    """Build Neurotoxin-style masks over the smallest-magnitude benign-update coordinates.

    Args:
        benign_gradient: ``{"head": {key: tensor}, "Prompt_Tokens": tensor}``.
        k_percent: percentage (0-100) of coordinates per tensor to keep;
            values above 100 keep everything.

    Returns:
        Dict with the same layout whose tensors are FLAT (1-D) float32 masks:
        1.0 where ``|g|`` is among the ``k_percent``% smallest entries of that
        tensor, else 0.0.
    """
    return {
        "head": {key: _bottom_k_mask(g, k_percent) for key, g in benign_gradient["head"].items()},
        "Prompt_Tokens": _bottom_k_mask(benign_gradient["Prompt_Tokens"], k_percent),
    }

def project_gradients_fedvpt(grads, bottom_k_masks):
    """Multiply each gradient element-wise by its flat mask, preserving the gradient's shape.

    Args:
        grads: ``{"head": {key: tensor}, "Prompt_Tokens": tensor or None}``.
        bottom_k_masks: masks with flat (1-D) tensors, e.g. as returned by
            ``select_bottom_k_fedvpt``; must contain every key in ``grads["head"]``.

    Returns:
        Dict with the same layout. ``"Prompt_Tokens"`` is None when
        ``grads["Prompt_Tokens"]`` is None (no prompt gradient was produced);
        the neurotoxin caller already checks for that case.
    """
    # reshape (not view) tolerates non-contiguous gradients; the product is contiguous, so view_as is safe.
    projected_grads = {
        "head": {
            key: (g.reshape(-1) * bottom_k_masks["head"][key]).view_as(g)
            for key, g in grads["head"].items()
        },
        "Prompt_Tokens": None,
    }

    prompt_grad = grads["Prompt_Tokens"]
    if prompt_grad is not None:
        projected_grads["Prompt_Tokens"] = (prompt_grad.reshape(-1) * bottom_k_masks["Prompt_Tokens"]).view_as(prompt_grad)

    return projected_grads
