import torch
from tools import *
import random
from options import args_parser
# Options object created once, when this module is imported (presumably the
# parsed command-line options — confirm in options.args_parser). Every
# generator below uses it as the default for its `args` parameter; Python binds
# that default when each `def` executes, so rebinding `args` later does not
# change those defaults.
args = args_parser()


def model_assign_generator(client_number, args=args):
    """Assign CNN1-CNN4 to clients in contiguous, equal-sized id ranges.

    With ``n = client_number // 4``, clients ``0..n-1`` get CNN1,
    ``n..2n-1`` get CNN2, ``2n..3n-1`` get CNN3 and ``3n..4n-1`` get CNN4.
    When ``client_number`` is not a multiple of 4, the leftover clients
    (ids ``4n`` and up, at most three of them) get CNN1, CNN2, CNN3 in that
    order, so every id in ``0..client_number-1`` has a model. No randomness
    is used for the assignment itself.

    Args:
        client_number: Number of clients; truncated with ``int()``, and
            non-positive values yield empty results.
        args: Options object forwarded to every model constructor.

    Returns:
        ``(result_dict, client_id_model_name)``: client id -> model
        instance, and client id -> model name string (e.g. ``'CNN2'``).
    """
    model_list = ['CNN1', 'CNN2', 'CNN3', 'CNN4']
    # Name -> constructor; the classes come from the `tools` star-import.
    model_builders = {'CNN1': CNN1, 'CNN2': CNN2, 'CNN3': CNN3, 'CNN4': CNN4}

    total_clients = max(int(client_number), 0)
    client_num_each_model = total_clients // len(model_list)

    # Contiguous blocks first: ids i*n .. (i+1)*n-1 use model_list[i].
    assignment = [
        model_name
        for model_name in model_list
        for _ in range(client_num_each_model)
    ]
    # Fewer than len(model_list) clients can remain; give them one model each
    # in list order instead of leaving them without a model.
    leftover = total_clients - len(assignment)
    assignment.extend(model_list[:leftover])

    result_dict = {}
    client_id_model_name = {}
    # Models are constructed in ascending client-id order.
    for client_id, model_name in enumerate(assignment):
        result_dict[client_id] = model_builders[model_name](args=args)
        client_id_model_name[client_id] = model_name
    return result_dict, client_id_model_name

def new_model_generator(client_number, args=args):
    """Randomly assign one of CNN1-CNN4 to each client and build the models.

    The nominal split is CNN1/CNN2/CNN3/CNN4 = 12/12/12/14 (total 50). If
    ``client_number`` differs from that total, a warning is printed and each
    count becomes ``count * client_number // total``. The clients lost to
    flooring are then added back one per model in the order CNN1, CNN2, ...

    Client ids ``0..client_number-1`` are shuffled with the module-level
    ``random`` generator and sliced into consecutive groups, one group per
    model in the order listed above.

    Args:
        client_number: Number of clients to assign (expected to be an int).
        args: Options object forwarded to every model constructor.

    Returns:
        ``(client_models, client_model_names)``: client id -> model instance,
        and client id -> model name string (e.g. ``'CNN4'``).

    Raises:
        ValueError: If the distribution names a model with no constructor.
    """
    model_distribution = {
        'CNN1': 12,
        'CNN2': 12,
        'CNN3': 12,
        'CNN4': 14,
    }
    # Name -> constructor; the classes come from the `tools` star-import.
    model_builders = {
        'CNN1': CNN1,
        'CNN2': CNN2,
        'CNN3': CNN3,
        'CNN4': CNN4,
    }

    total = sum(model_distribution.values())
    if total != client_number:
        print(f"Warning: Model distribution total ({total}) != client_number ({client_number})")
        # Integer floor division keeps the rescaled counts exact (no float
        # rounding of a client_number / total scale factor).
        model_distribution = {
            name: count * client_number // total
            for name, count in model_distribution.items()
        }
        # Flooring drops fewer than len(model_distribution) clients; hand them
        # back round-robin in distribution order.
        remainder = client_number - sum(model_distribution.values())
        model_names = list(model_distribution)
        for i in range(remainder):
            model_distribution[model_names[i % len(model_names)]] += 1

    client_index_list = list(range(client_number))
    random.shuffle(client_index_list)

    client_models = {}
    client_model_names = {}
    start = 0
    # Models are constructed in distribution order, then in shuffled client
    # order. This keeps runs reproducible if the constructors draw from a
    # seeded global RNG.
    for model_name, count in model_distribution.items():
        builder = model_builders.get(model_name)
        if builder is None:
            raise ValueError(f"Unknown model name: {model_name}")
        for client_id in client_index_list[start:start + count]:
            client_models[client_id] = builder(args=args)
            client_model_names[client_id] = model_name
        start += count

    return client_models, client_model_names

def new_model_generator_vgg2cnn3(client_number, args=args):
    """Randomly assign one of three CNNs or two VGGs to each client.

    The nominal split is CNN1/CNN4/CNN3/VGG11/VGG16 = 10 each (total 50). If
    ``client_number`` differs from that total, a warning is printed and each
    count becomes ``count * client_number // total``. The clients lost to
    flooring are then added back one per model in the order listed above.

    Client ids ``0..client_number-1`` are shuffled with the module-level
    ``random`` generator and sliced into consecutive groups, one group per
    model in the order listed above.

    Args:
        client_number: Number of clients to assign (expected to be an int).
        args: Options object forwarded to every model constructor.

    Returns:
        ``(client_models, client_model_names)``: client id -> model instance,
        and client id -> model name string (e.g. ``'VGG11'``).

    Raises:
        ValueError: If the distribution names a model with no constructor.
    """
    model_distribution = {
        'CNN1': 10,
        'CNN4': 10,
        'CNN3': 10,
        'VGG11': 10,
        'VGG16': 10,
    }
    # Name -> constructor; the classes come from the `tools` star-import.
    model_builders = {
        'CNN1': CNN1,
        'CNN4': CNN4,
        'CNN3': CNN3,
        'VGG11': VGG11,
        'VGG16': VGG16,
    }

    total = sum(model_distribution.values())
    if total != client_number:
        print(f"Warning: Model distribution total ({total}) != client_number ({client_number})")
        # Integer floor division keeps the rescaled counts exact (no float
        # rounding of a client_number / total scale factor).
        model_distribution = {
            name: count * client_number // total
            for name, count in model_distribution.items()
        }
        # Flooring drops fewer than len(model_distribution) clients; hand them
        # back round-robin in distribution order.
        remainder = client_number - sum(model_distribution.values())
        model_names = list(model_distribution)
        for i in range(remainder):
            model_distribution[model_names[i % len(model_names)]] += 1

    client_index_list = list(range(client_number))
    random.shuffle(client_index_list)

    client_models = {}
    client_model_names = {}
    start = 0
    # Models are constructed in distribution order, then in shuffled client
    # order. This keeps runs reproducible if the constructors draw from a
    # seeded global RNG.
    for model_name, count in model_distribution.items():
        builder = model_builders.get(model_name)
        if builder is None:
            raise ValueError(f"Unknown model name: {model_name}")
        for client_id in client_index_list[start:start + count]:
            client_models[client_id] = builder(args=args)
            client_model_names[client_id] = model_name
        start += count

    return client_models, client_model_names
def new_model_generator_vgg3cnn3(client_number, args=args):
    """Randomly assign one of three CNNs or three VGGs to each client.

    The nominal split is CNN1/CNN4/CNN3/VGG11/VGG13/VGG16 = 8/8/8/8/8/10
    (total 50). If ``client_number`` differs from that total, a warning is
    printed and each count becomes ``count * client_number // total``. The
    clients lost to flooring are then added back one per model in the order
    listed above.

    Client ids ``0..client_number-1`` are shuffled with the module-level
    ``random`` generator and sliced into consecutive groups, one group per
    model in the order listed above.

    Args:
        client_number: Number of clients to assign (expected to be an int).
        args: Options object forwarded to every model constructor.

    Returns:
        ``(client_models, client_model_names)``: client id -> model instance,
        and client id -> model name string (e.g. ``'VGG13'``).

    Raises:
        ValueError: If the distribution names a model with no constructor.
    """
    model_distribution = {
        'CNN1': 8,
        'CNN4': 8,
        'CNN3': 8,
        'VGG11': 8,
        'VGG13': 8,
        'VGG16': 10,
    }
    # Name -> constructor; the classes come from the `tools` star-import.
    model_builders = {
        'CNN1': CNN1,
        'CNN4': CNN4,
        'CNN3': CNN3,
        'VGG11': VGG11,
        'VGG13': VGG13,
        'VGG16': VGG16,
    }

    total = sum(model_distribution.values())
    if total != client_number:
        print(f"Warning: Model distribution total ({total}) != client_number ({client_number})")
        # Integer floor division keeps the rescaled counts exact (no float
        # rounding of a client_number / total scale factor).
        model_distribution = {
            name: count * client_number // total
            for name, count in model_distribution.items()
        }
        # Flooring drops fewer than len(model_distribution) clients; hand them
        # back round-robin in distribution order.
        remainder = client_number - sum(model_distribution.values())
        model_names = list(model_distribution)
        for i in range(remainder):
            model_distribution[model_names[i % len(model_names)]] += 1

    client_index_list = list(range(client_number))
    random.shuffle(client_index_list)

    client_models = {}
    client_model_names = {}
    start = 0
    # Models are constructed in distribution order, then in shuffled client
    # order. This keeps runs reproducible if the constructors draw from a
    # seeded global RNG.
    for model_name, count in model_distribution.items():
        builder = model_builders.get(model_name)
        if builder is None:
            raise ValueError(f"Unknown model name: {model_name}")
        for client_id in client_index_list[start:start + count]:
            client_models[client_id] = builder(args=args)
            client_model_names[client_id] = model_name
        start += count

    return client_models, client_model_names
def new_model_generator_vgg3cnn4(client_number, args=args):
    """Randomly assign one of four CNNs or three VGGs to each client.

    The nominal split is CNN1/CNN4/CNN2/CNN3/VGG11/VGG13/VGG16 =
    7/7/7/7/7/7/8 (total 50). If ``client_number`` differs from that total,
    a warning is printed and each count becomes
    ``count * client_number // total``. The clients lost to flooring are then
    added back one per model in the order listed above.

    Client ids ``0..client_number-1`` are shuffled with the module-level
    ``random`` generator and sliced into consecutive groups, one group per
    model in the order listed above.

    Args:
        client_number: Number of clients to assign (expected to be an int).
        args: Options object forwarded to every model constructor.

    Returns:
        ``(client_models, client_model_names)``: client id -> model instance,
        and client id -> model name string (e.g. ``'CNN2'``).

    Raises:
        ValueError: If the distribution names a model with no constructor.
    """
    model_distribution = {
        'CNN1': 7,
        'CNN4': 7,
        'CNN2': 7,
        'CNN3': 7,
        'VGG11': 7,
        'VGG13': 7,
        'VGG16': 8,
    }
    # Name -> constructor; the classes come from the `tools` star-import.
    model_builders = {
        'CNN1': CNN1,
        'CNN4': CNN4,
        'CNN2': CNN2,
        'CNN3': CNN3,
        'VGG11': VGG11,
        'VGG13': VGG13,
        'VGG16': VGG16,
    }

    total = sum(model_distribution.values())
    if total != client_number:
        print(f"Warning: Model distribution total ({total}) != client_number ({client_number})")
        # Integer floor division keeps the rescaled counts exact (no float
        # rounding of a client_number / total scale factor).
        model_distribution = {
            name: count * client_number // total
            for name, count in model_distribution.items()
        }
        # Flooring drops fewer than len(model_distribution) clients; hand them
        # back round-robin in distribution order.
        remainder = client_number - sum(model_distribution.values())
        model_names = list(model_distribution)
        for i in range(remainder):
            model_distribution[model_names[i % len(model_names)]] += 1

    client_index_list = list(range(client_number))
    random.shuffle(client_index_list)

    client_models = {}
    client_model_names = {}
    start = 0
    # Models are constructed in distribution order, then in shuffled client
    # order. This keeps runs reproducible if the constructors draw from a
    # seeded global RNG.
    for model_name, count in model_distribution.items():
        builder = model_builders.get(model_name)
        if builder is None:
            raise ValueError(f"Unknown model name: {model_name}")
        for client_id in client_index_list[start:start + count]:
            client_models[client_id] = builder(args=args)
            client_model_names[client_id] = model_name
        start += count

    return client_models, client_model_names



def new_model_generatorhom(client_number, args=args):
    """Homogeneous setup: give every client a CNN4 model.

    The distribution table holds a single entry, CNN4 = 50. If
    ``client_number`` differs from 50, a warning is printed and the count is
    rescaled to ``50 * client_number // 50`` (i.e. ``client_number``), with
    any flooring remainder added back. The shuffle is still performed, so
    the module-level ``random`` generator advances exactly as in the
    heterogeneous generators.

    Args:
        client_number: Number of clients to assign (expected to be an int).
        args: Options object forwarded to every model constructor.

    Returns:
        ``(client_models, client_model_names)``: client id -> model instance,
        and client id -> model name string (always ``'CNN4'`` with the
        current table).

    Raises:
        ValueError: If the distribution names a model with no constructor.
    """
    model_distribution = {
        'CNN4': 50,
    }
    # Name -> constructor; the classes come from the `tools` star-import.
    # CNN1-CNN3 are kept so the table above can be edited without touching this.
    model_builders = {
        'CNN1': CNN1,
        'CNN2': CNN2,
        'CNN3': CNN3,
        'CNN4': CNN4,
    }

    total = sum(model_distribution.values())
    if total != client_number:
        print(f"Warning: Model distribution total ({total}) != client_number ({client_number})")
        # Integer floor division keeps the rescaled counts exact (no float
        # rounding of a client_number / total scale factor).
        model_distribution = {
            name: count * client_number // total
            for name, count in model_distribution.items()
        }
        # Flooring drops fewer than len(model_distribution) clients; hand them
        # back round-robin in distribution order.
        remainder = client_number - sum(model_distribution.values())
        model_names = list(model_distribution)
        for i in range(remainder):
            model_distribution[model_names[i % len(model_names)]] += 1

    client_index_list = list(range(client_number))
    random.shuffle(client_index_list)

    client_models = {}
    client_model_names = {}
    start = 0
    # Models are constructed in distribution order, then in shuffled client
    # order. This keeps runs reproducible if the constructors draw from a
    # seeded global RNG.
    for model_name, count in model_distribution.items():
        builder = model_builders.get(model_name)
        if builder is None:
            raise ValueError(f"Unknown model name: {model_name}")
        for client_id in client_index_list[start:start + count]:
            client_models[client_id] = builder(args=args)
            client_model_names[client_id] = model_name
        start += count

    return client_models, client_model_names