import os
import copy
import fire
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import matplotlib.pyplot as plt
from models.vpt_structure import build_promptmodel
from utils import public_utils
from utils import domainnet_data_utils, data_poison, PACS_utils, Officehome_utils, office_data_utils
from defenses.defense_align_new import defense_tdf_chif
import torch.nn.functional as F

# Module-wide compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU. Every model in this script is moved to this device.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def setup_seed(seed = 42):
    """Seed every random number generator this script relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (CPU, current CUDA
    device and all CUDA devices), exports ``PYTHONHASHSEED`` and configures
    the cuDNN backend flags for deterministic behaviour.

    Args:
        seed: Integer seed applied to all generators. Defaults to 42.
    """
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
    # is only visible to child processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN is switched off entirely; the other two flags additionally disable
    # autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.enabled = False
    cudnn.benchmark = False
    cudnn.deterministic = True


# Per-dataset federation layout:
#   dataset_name -> (label used in the start-up banner,
#                    ordered domain names,
#                    num_classes handed to build_promptmodel)
_DATASET_SPECS = {
    "domainnet": ("DomainNet", ('Clipart', 'Infograph', 'Painting', 'Quickdraw', 'Real', 'Sketch'), 10),
    "PACS": ("PACS", ('Art_painting', 'Cartoon', 'Photo', 'Sketch'), 7),
    "officehome": ("OfficeHome", ('Art', 'Clipart', 'Product', 'RealWorld'), 10),
    "office_caltech": ("Office-Caltech", ('amazon', 'caltech', 'dslr', 'webcam'), 10),
}


def _prepare_clean_loaders(dataset_name, data_base_path, batchsize, clients_per_domain):
    """Return ``(train_loaders, test_loaders)`` from the dataset's multi-client helper."""
    if dataset_name == "domainnet":
        # DomainNet additionally returns validation loaders; this script does not use them.
        train_loaders, _val_loaders, test_loaders = domainnet_data_utils.prepare_data_multi(
            data_base_path, batchsize, clients_per_domain=clients_per_domain)
    elif dataset_name == "PACS":
        train_loaders, test_loaders = PACS_utils.prepare_PACS_multi(
            data_base_path, batchsize, clients_per_domain=clients_per_domain)
    elif dataset_name == "officehome":
        train_loaders, test_loaders = Officehome_utils.prepare_OH_multi(
            data_base_path, batchsize, clients_per_domain=clients_per_domain)
    elif dataset_name == "office_caltech":
        train_loaders, test_loaders = office_data_utils.prepare_data_multi(
            data_base_path, batchsize, train_ratio=0.5, clients_per_domain=clients_per_domain)
    else:
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected one of {sorted(_DATASET_SPECS)}")
    return train_loaders, test_loaders


def _normalize_client_indices(poison_client_indices):
    """Return the poisoned-client indices as a fresh list (accepts None, an int, or any iterable)."""
    if poison_client_indices is None:
        return []
    if isinstance(poison_client_indices, int):
        # A single index given on the command line arrives as a bare int.
        return [poison_client_indices]
    return list(poison_client_indices)


def _to_scalar(value):
    """Return ``value`` unchanged when it is a float, otherwise its ``np.mean``."""
    return value if isinstance(value, float) else np.mean(value)


def _append_log(logout, lines):
    """Append ``lines`` (one per row) to ``{logout}.txt``; do nothing when ``logout`` is falsy."""
    if not logout:
        return
    with open(f'{logout}.txt', 'a+') as f_out:
        for line in lines:
            f_out.write(f'{line}\n')


def main(
        # base para
        poison: bool = False,
        poison_label_swap: int = 0,
        poison_number_per_batch: int = 8,
        seed: int = 1,
        batchsize: int = 32,
        epochs: int = 51,
        lr: float = 0.005,
        poison_lr: float = 0.005,
        local_epochs: int = 1,
        poison_local_epochs: int = 2,
        logout: str = "domainnet_vpt_res",
        data_base_path: str = 'data/PACS',
        dataset_name: str = "PACS",
        VPT_type: str = "Deep",
        defense: bool = False,
        pre_trained: bool = True,
        poison_client_indices: list = [3,6],  # Support multiple poisoned clients, pass list directly, e.g. [0,2,4]
        attack_type: str = "badnets",
        defense_index: int = 0,
        clients_per_domain: int = 3,  # Number of clients created for each domain
        mode: str = "both",
    ):
    """Run multi-client federated prompt tuning with optional backdoor attack and defense.

    Every domain of the chosen dataset is split into ``clients_per_domain``
    clients. Each round trains all clients locally (poisoned clients use the
    backdoor training routine from round 1 onwards), optionally runs the
    selected defense, and aggregates with ``public_utils.communication_fedvpt``.
    Evaluation happens on the first client of each domain in the last round.

    Args:
        poison: Enable poisoned-client training and poisoned-test evaluation.
        poison_label_swap: Forwarded to the poisoned test-set builder, the
            poisoned training routine and the ASR test (presumably the
            backdoor target label - confirm against ``data_poison``).
        poison_number_per_batch: Forwarded to the poisoned training routine.
        seed: Seed passed to ``setup_seed``.
        batchsize: Batch size passed to all data-loader builders.
        epochs: Number of federated rounds.
        lr: SGD learning rate for clean local training.
        poison_lr: SGD learning rate for poisoned local training.
        local_epochs: Local training passes per round for clean training.
        poison_local_epochs: Local training passes per round for poisoned training.
        logout: Results are appended to ``{logout}.txt``; a falsy value disables file logging.
        data_base_path: Dataset root passed to the data-loader builders.
        dataset_name: One of ``domainnet``, ``PACS``, ``officehome``, ``office_caltech``.
        VPT_type: Forwarded to ``build_promptmodel``.
        defense: Run the selected defense before aggregation (rounds >= 1 only).
        pre_trained: Load the prompt checkpoint into every model; when falsy,
            the server prompt is saved after the final round instead.
        poison_client_indices: Indices (into the flat client list) of the
            poisoned clients. The default list is copied, never mutated.
        attack_type: Forwarded to the poisoned training routine.
        defense_index: Index into the list of available defenses.
        clients_per_domain: Number of clients per domain; must be >= 1.
        mode: Only forwarded to the defense when ``defense_index == 7``.

    Raises:
        ValueError: For an unsupported ``dataset_name``, ``clients_per_domain < 1``,
            or (when ``poison`` is set) a poisoned-client index outside
            ``[0, client_num)``.
    """
    setup_seed(seed)

    poison_client_indices_list = _normalize_client_indices(poison_client_indices)

    print("Attack type:", attack_type)
    print("data_base_path:", data_base_path)
    print("dataset_name:", dataset_name)
    print("logout:", logout)
    print("pre_trained:", pre_trained)
    print("poison:", poison)
    print("poison_label_swap:", poison_label_swap)
    print("poison_number_per_batch:", poison_number_per_batch)
    print("poison_client_indices:", poison_client_indices)

    # defense
    defense_list = [defense_tdf_chif]
    defense_name = defense_list[defense_index]
    print("defense_name:", defense_name)

    # Fail early with a clear message instead of a NameError / range() error later on.
    if dataset_name not in _DATASET_SPECS:
        raise ValueError(
            f"Unsupported dataset_name {dataset_name!r}; expected one of {sorted(_DATASET_SPECS)}")
    if clients_per_domain < 1:
        raise ValueError(f"clients_per_domain must be >= 1, got {clients_per_domain}")

    dataset_label, spec_domains, num_classes = _DATASET_SPECS[dataset_name]
    domain_names = list(spec_domains)
    domain_num = len(domain_names)
    client_num = domain_num * clients_per_domain
    client_weights = [1/client_num for _ in range(client_num)]

    # Client names are "<domain>_<k>" with k starting at 1, grouped by domain.
    client_names = [
        f"{domain}_{i+1}"
        for domain in domain_names
        for i in range(clients_per_domain)
    ]

    print(f"{dataset_label} - Domains: {domain_num}, Clients per domain: {clients_per_domain}, Total clients: {client_num}")

    train_loaders, test_loaders = _prepare_clean_loaders(
        dataset_name, data_base_path, batchsize, clients_per_domain)

    # The poisoned test loaders are indexed per domain; repeat each one so the
    # resulting list can be indexed by client, like the clean loaders.
    domain_poison_test_loaders = data_poison.poison_test_dataset(
        dataset_name, data_base_path, batchsize, poison_label_swap)
    poison_test_loaders = [
        domain_poison_test_loaders[domain_idx]
        for domain_idx in range(domain_num)
        for _ in range(clients_per_domain)
    ]

    # server and clients model
    server_model = build_promptmodel(num_classes=num_classes, edge_size=224, patch_size=16,
                        Prompt_Token_num=15, VPT_type=VPT_type).to(device)
    client_models = [copy.deepcopy(server_model).to(device) for _ in range(client_num)]

    print(f"Total domains: {domain_num}, Total clients: {client_num}")
    print(f"Client names: {client_names}")

    if poison:
        # Negative indices would never match a client index in the training
        # loop and too-large ones would raise a bare IndexError; reject both.
        out_of_range = [idx for idx in poison_client_indices_list if not 0 <= idx < client_num]
        if out_of_range:
            raise ValueError(
                f"poison_client_indices {out_of_range} out of range for {client_num} clients")
        print(f"Poison client indices: {poison_client_indices_list}")
        poison_client_names = [client_names[idx] for idx in poison_client_indices_list]
        print(f"Poison clients: {poison_client_names}")
    else:
        print("No poison clients (poison=False)")

    if pre_trained:
        # Read the checkpoint once, mapped onto the active device so a
        # CUDA-saved file also loads on a CPU-only machine.
        loaded_prompt_state_dict = torch.load(
            f'checkpoint/fedvpt_prompt_dict_{dataset_name}_multi_50_{clients_per_domain}.pth',
            map_location=device)
        # Each model gets its own deep copy so no two models can end up
        # sharing tensors that originate from the same loaded dict.
        for client_model in client_models:
            client_model.load_prompt(copy.deepcopy(loaded_prompt_state_dict))
        server_model.load_prompt(copy.deepcopy(loaded_prompt_state_dict))
        print(f"Loaded pre-trained model for all {client_num} clients and server")

    # Layered testing strategy: fast test + full test.
    # NOTE: range(epochs+1, epochs, 100) is empty, so no fast-test rounds are
    # currently scheduled; only the last round is evaluated (full test).
    fast_test_values = list(range(epochs+1, epochs, 100))
    full_test_values = [epochs-1]

    # Training
    best_acc = 0.
    loss_fun = nn.CrossEntropyLoss()
    # Produced by the aggregation step; only read by poisoned training in rounds >= 1.
    benign_gradients = None
    for a_iter in range(epochs):
        optimizers = [optim.SGD(params=client_models[idx].parameters(), lr=lr) for idx in range(client_num)]

        # Create poison optimizers for all poisoned clients
        poison_optimizers = {}
        if poison:
            for poison_idx in poison_client_indices_list:
                poison_optimizers[poison_idx] = optim.SGD(params=client_models[poison_idx].parameters(), lr=poison_lr)

        # train
        for client_idx, model in enumerate(client_models):
            is_poison_client = poison and client_idx in poison_client_indices_list
            # Poisoning starts in round 1: round 0 is trained cleanly, and
            # benign_gradients is only available after the first aggregation.
            if is_poison_client and a_iter > 0:
                for wi in range(poison_local_epochs):
                    train_loss, train_acc = public_utils.vpt_poison_train_backdoor_aaai(
                        model=model,
                        data_loader=train_loaders[client_idx],
                        optimizer=poison_optimizers[client_idx],
                        loss_fun=loss_fun,
                        device=device,
                        poison_number_per_batch=poison_number_per_batch,
                        poison_label_swap=poison_label_swap,
                        attack_type=attack_type,
                        benign_gradients=benign_gradients,
                    )
                    print(f'Train Epoch:{a_iter} | {client_names[client_idx]} (POISONED) | Train iters: {wi} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}')

                # badnets attack doesn't need model replacement, remove constrain-and-scale related code

            # clean train
            else:
                for wi in range(local_epochs):
                    train_loss, train_acc = public_utils.vpt_train(model, train_loaders[client_idx], optimizers[client_idx], loss_fun, device)
                    poison_status = " (POISONED)" if is_poison_client else ""
                    print(f'Train Epoch:{a_iter} | {client_names[client_idx]}{poison_status} | Train iters: {wi} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}')

        d_weight = [1] * client_num
        # defense Algorithm
        if defense:
            if a_iter > 0:
                # Collect a detached snapshot of every client's head and prompt tokens.
                update_dict = []
                for i in range(client_num):
                    update_dict_one = {}
                    client_prompt_state_dict = client_models[i].obtain_prompt()
                    # head
                    for key in client_prompt_state_dict["head"]:  # weight, bias
                        update_dict_one[key] = client_prompt_state_dict["head"][key].clone().detach()
                    update_dict_one["Prompt_Tokens"] = client_prompt_state_dict["Prompt_Tokens"].clone().detach()

                    update_dict.append(update_dict_one)

                print(f"Applying defense: {defense_name.__name__} with {client_num} clients")
                # NOTE(review): defense_list has a single entry, so index 7 cannot
                # currently be selected; the branch is kept for a defense taking `mode`.
                if defense_index == 7:
                    update_dict, benign_name_keys, d_weight = defense_name(dataset_name, client_num, server_model, client_models, update_dict, mode=mode)
                else:
                    update_dict, benign_name_keys, d_weight = defense_name(dataset_name, client_num, server_model, client_models, update_dict)
                print(f"Defense completed. Benign clients: {len(benign_name_keys)}/{client_num}")

        # aggregate
        benign_gradients = public_utils.communication_fedvpt(server_model, client_models, client_weights, client_num, d_weight)

        if a_iter in fast_test_values or a_iter in full_test_values:
            test_acc_mean = []
            test_asrs = []
            test_robust_accs = []

            # Determine if it's fast test or full test
            is_full_test = a_iter in full_test_values
            test_type = "Full" if is_full_test else "Fast"
            tag = f'[{test_type} Test] Epoch:{a_iter}'

            # In multi-client setup, only test the first client of each domain (domain representative)
            domain_representative_clients = list(range(0, client_num, clients_per_domain))  # 0, clients_per_domain, 2*clients_per_domain, ...

            for client_idx in domain_representative_clients:
                eval_model = client_models[client_idx]
                if poison:
                    if is_full_test:
                        # full test
                        _, test_acc = public_utils.test_vpt(eval_model, poison_test_loaders[client_idx], loss_fun, device)
                        _, test_asr = public_utils.asr_test_vpt_backdoor_aaai(eval_model, poison_test_loaders, loss_fun, device, poison_label_swap)
                        _, test_robust_acc = public_utils.robust_accuracy_test_vpt_backdoor_aaai(eval_model, poison_test_loaders, loss_fun, device)
                    else:
                        # fast test (30% data sampling)
                        _, test_acc = public_utils.fast_test_vpt(eval_model, poison_test_loaders[client_idx], loss_fun, device, sample_ratio=0.3)
                        _, test_asr = public_utils.fast_asr_test_vpt_backdoor_aaai(eval_model, poison_test_loaders, loss_fun, device, poison_label_swap, sample_ratio=0.3)
                        _, test_robust_acc = public_utils.fast_robust_accuracy_test_vpt_backdoor_aaai(eval_model, poison_test_loaders, loss_fun, device, sample_ratio=0.3)

                    # The ASR / robust-accuracy helpers may return one value per client; pick this client's.
                    test_asrs.append(test_asr[client_idx] if isinstance(test_asr, list) else test_asr)
                    test_robust_accs.append(test_robust_acc[client_idx] if isinstance(test_robust_acc, list) else test_robust_acc)
                else:
                    if is_full_test:
                        _, test_acc = public_utils.test_vpt(eval_model, test_loaders[client_idx], loss_fun, device)
                    else:
                        _, test_acc = public_utils.fast_test_vpt(eval_model, test_loaders[client_idx], loss_fun, device, sample_ratio=0.3)

                    test_asrs.append(0.0)  # ASR is 0 when not poisoned
                    test_robust_accs.append(_to_scalar(test_acc))  # Robust ACC equals normal ACC when not poisoned

                test_acc_mean.append(test_acc)

                # Representative clients are the first client of each domain,
                # so the domain is recovered directly from the client index.
                domain_name = domain_names[client_idx // clients_per_domain]
                poison_status = " (POISONED)" if poison and client_idx in poison_client_indices_list else ""
                test_acc_val = _to_scalar(test_acc)
                client_line = f'{tag} | {domain_name}{poison_status} | Test Acc: {test_acc_val:.4f} | Test ASR: {test_asrs[-1]:.4f} | Robust Acc: {test_robust_accs[-1]:.4f}'
                print(client_line)
                _append_log(logout, [client_line])

            # Calculate overall average ASR and Robust Accuracy
            overall_asr = np.mean(test_asrs)
            overall_robust_acc = np.mean(test_robust_accs)
            overall_line = f'{tag} | Overall Average ASR: {overall_asr:.4f} | Overall Average Robust Acc: {overall_robust_acc:.4f}'

            # Since we have tested by domain, test_acc_mean, test_asrs and test_robust_accs are the results of each domain
            domain_lines = []
            for domain_idx in range(domain_num):
                domain_acc = _to_scalar(test_acc_mean[domain_idx])
                domain_asr = test_asrs[domain_idx]
                domain_robust_acc = test_robust_accs[domain_idx]
                domain_lines.append(f'{tag} | Domain {domain_names[domain_idx]} | Test Acc: {domain_acc:.4f} | Test ASR: {domain_asr:.4f} | Robust Acc: {domain_robust_acc:.4f}')

            print(overall_line)
            print(f'{tag} | Domain Summary:')  # header goes to stdout only, not to the log file
            for line in domain_lines:
                print(line)
            _append_log(logout, [overall_line] + domain_lines)

            # best acc
            global_acc_mean = np.mean([_to_scalar(acc) for acc in test_acc_mean])
            if best_acc < global_acc_mean:
                best_acc = global_acc_mean
            global_line = f'{tag} | Global | Test Acc Mean: {global_acc_mean:.4f} | Best Acc: {best_acc:.4f}'
            print(global_line)
            _append_log(logout, [global_line])

        # save model
        if a_iter == epochs - 1 and not pre_trained:
            # Make sure the target directory exists so a finished run is not lost at save time.
            os.makedirs('checkpoint', exist_ok=True)
            torch.save(server_model.obtain_prompt(), f'checkpoint/fedvpt_prompt_dict_{dataset_name}_multi_{a_iter + 1}_{clients_per_domain}.pth')


# Script entry point: python-fire exposes main()'s keyword arguments as
# command-line flags (e.g. --dataset_name=PACS --poison=True).
if __name__ == '__main__':
    fire.Fire(main)