import os, sys, time
from communications import mix_partial_parameters
from Clients.groups import get_groups
import numpy as np
import argparse
import importlib
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

def parse_args(argv=None):
    """Parse the experiment's command-line options.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``, which is the original behaviour.

    Returns:
        argparse.Namespace with one attribute per option defined below.
    """
    parser = argparse.ArgumentParser(description="Configure the experiment settings")

    # Define the arguments
    parser.add_argument('--seed',
        type=int, default=42, help='Random seed')
    parser.add_argument('--n_clients',
        type=int, default=30, help='Number of clients (total)')
    # har, cap, ave, modelnet
    parser.add_argument('--dataset',
        type=str, default="cap", help='Dataset name')
    parser.add_argument('--non_iid_alpha',
        type=float, default=0.5, help='Non-IID alpha')
    parser.add_argument('--batch_size',
        type=int, default=64, help='Batch size')
    parser.add_argument('--n_epochs',
        type=int, default=1, help='Number of local epochs')
    parser.add_argument('--rounds',
        type=int, default=200, help='Number of communication rounds')
    parser.add_argument('--topology',
        type=str, default="ring", help='Network topology for each modality group')
    parser.add_argument('--n_neighbours',
        type=int, default=2, help='Number of neighbors to communicate in each round (Random Gossip)')
    parser.add_argument('--group_size',
        type=int, default=10, help='Number of clients in each unimodal group')
    parser.add_argument('--learning_rate',
        type=float, default=0.1, help='Learning rate')
    parser.add_argument('--sample_type',
        type=str, default="ring", help='Sample type for FedLay, sample a ring or just neighbors')
    # Known modes: "robust", "modality", "task", "hybrid". Deliberately not
    # enforced through `choices=` so that existing invocations keep working.
    parser.add_argument('--mode',
        type=str, default="hybrid", help='Mode of operation (robust, modality, task, hybrid)')
    parser.add_argument('--contrastive',
        type=float, default=0.7, help='contrastive ratio')

    return parser.parse_args(argv)

###### 1. Parse arguments ######
configs = parse_args()
print("\nParsed arguments:")
for key, value in vars(configs).items():
    print(f"{key}: {value}")

###### 2. Get the data loaders ######
# The loader factory is looked up dynamically in `datasets.<dataset name>`.
data_module = importlib.import_module(f'datasets.{configs.dataset}')

try:
    get_loaders = getattr(data_module, 'get_loaders')
except AttributeError as err:
    # Report the attribute that was actually looked up and keep the cause.
    raise ImportError(f"`get_loaders` not found in datasets.{configs.dataset}") from err

n_clients = configs.n_clients
client_dataloaders, test_dataloader = get_loaders(n_clients, configs)

###### 3. Get the client setup for the selected dataset ######
# The setup factory is looked up dynamically in `Clients.<dataset name>`.
setup_module = importlib.import_module(f'Clients.{configs.dataset}')
try:
    get_setup = getattr(setup_module, 'get_setup')
except AttributeError as err:
    # Name the module that was actually imported (`Clients.*`).
    raise ImportError(f"`get_setup` not found in Clients.{configs.dataset}") from err

encoder_fns, classifier_dim, output_dim = get_setup(configs)

###### 4. Setup the communication networks ######
# NOTE(review): the meaning of each returned value is inferred from its name
# and from how the training loop uses it -- confirm against
# Clients.groups.get_groups.
(
    modality_groups,
    client_groups,
    modality_topologies,
    task_topologies,
    whole_group,
    global_topologies,
) = get_groups(configs, client_dataloaders, encoder_fns, classifier_dim, output_dim)


# client_results = evaluate_models(test_dataloader, clients)
# print(f"Round 0- Test Loss: {client_results['loss']:.4f}, Test Accuracy: {client_results['accuracy']:.2f}")


"""train the clients"""
# One entry per round; each entry holds the mean test result of every group
# in `client_groups`, in order.
acc_list = []
for r in range(configs.rounds):
    time_start = time.time()

    # Train each client
    for group in client_groups:
        for client in group:
            client.train(configs.n_epochs)

    # Mix per-modality parts (every mode except "task")
    if configs.mode != "task":
        for i, group in enumerate(modality_groups):
            W = modality_topologies[i].mixing_matrix().cuda()
            mix_partial_parameters(group, W, f"encoders.{i}")
            if configs.mode != "hybrid":
                mix_partial_parameters(group, W, f"classifiers.{i}")
            else:
                mix_partial_parameters(group, W, "classifier")

    # Robust setting: mix shared + synergy modules
    if configs.mode == "robust":
        # Simulated broadcast over the whole group
        W = global_topologies[0].mixing_matrix().cuda()
        mix_partial_parameters(whole_group, W, "shared_classifier")
        # Only the last client group (the multimodal one, per the original
        # comment) shares the synergy classifier
        W = task_topologies[-1].mixing_matrix().cuda()
        mix_partial_parameters(client_groups[-1], W, "synergy_classifier")
        # Optional: also mix the per-modality projection modules
        for i in range(len(modality_groups)):
            mix_partial_parameters(client_groups[-1], W, f"projects.{i}")

    # Task-based: mix all modules
    if configs.mode == "task":
        for i, group in enumerate(client_groups):
            W = task_topologies[i].mixing_matrix().cuda()
            mix_partial_parameters(group, W, "all")

    # Test the clients: average the per-client test result within each group
    client_results = []
    for group in client_groups:
        group_n = len(group)
        group_acc = 0.0
        for client in group:
            group_acc += client.test(test_dataloader)
        # Exact mean. The previous `group_n + 0.0001` denominator biased every
        # reported value slightly low; an empty group still reports 0.0.
        client_results.append(group_acc / group_n if group_n else 0.0)
    acc_list.append(client_results)
    print(f"Round {r+1} - Group Test Accuracy: " + " | ".join([f"{acc:.4f}" for acc in client_results]))
    time_end = time.time()
    print(f"Time taken: {time_end - time_start:.2f} seconds")

save_path = f"./save/{configs.dataset}/{configs.mode}_{configs.topology}_{configs.non_iid_alpha}_groupn_{configs.group_size}_list_{configs.contrastive}.npy"
# Create the output directory first so a missing folder cannot discard the
# results of the whole run at the very last step.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
np.save(save_path, acc_list)





