import numpy as np
import torch
import torch.nn.functional as F
import itertools
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import re
import random
from torch.optim import Adam
import json
import ast

def sort_layers(input_layers):
    """Return a new list of layers ordered by the last character of each layer name.

    NOTE: a second ``def sort_layers`` further down this module rebinds the name,
    so that later definition is the one callers actually get. This one is kept
    non-mutating (it used to bubble-sort its argument in place) so that both
    definitions behave the same and callers that ignore the return value never
    see their list reordered.

    Args:
        input_layers: sequence of layer entries whose ``[1]`` item is the layer
            name, e.g. ``['5', 'conv2']``.

    Returns:
        A new list. Ties keep their input order, and ``input_layers`` is not
        modified. The key is a single character, so e.g. ``'conv10'`` sorts as ``'0'``.
    """
    if len(input_layers) < 2:
        # Nothing to compare; do not inspect the elements.
        return list(input_layers)
    return sorted(input_layers, key=lambda layer: layer[1][-1])

def get_layer_index_grouplist(output_top_layers, output_bottom_layers):
    """Group top layers by client id and prefix every group with the bottom layers.

    Args:
        output_top_layers: iterable of layer entries; each valid entry is a list or
            tuple whose ``[0]`` is the client id and ``[1]`` the layer name. Other
            entries are reported with a print and skipped.
        output_bottom_layers: list of layers prepended to every group.

    Returns:
        One list per client, in the order each client id (compared as ``str``) is
        first seen. Each list is ``output_bottom_layers`` followed by that client's
        ``[client_id, layer_name]`` pairs.
    """
    # dicts keep insertion order, so iteration order equals first-seen client order.
    groups_by_client = {}
    for entry in output_top_layers:
        if not (isinstance(entry, (list, tuple)) and len(entry) >= 2):
            print(f"跳过无效项目: {entry}")
            continue
        owner, name = entry[0], entry[1]
        groups_by_client.setdefault(str(owner), []).append([owner, name])
    return [output_bottom_layers + group for group in groups_by_client.values()]

def newsample_models(input_cluster, min_layers, max_layers, expected_num_models, global_layer_order,
                     center_layer_name, max_rounds=10000):
    """Randomly assemble ``expected_num_models`` candidate models from clustered layers.

    ``preprocess_clusters`` picks the anchor block (the cluster-center layer with the
    smallest position in ``global_layer_order``) and moves its cluster to the front.
    Each candidate then:

    1. starts with the anchor block;
    2. takes at most one layer from every cluster after the first, restricted to
       layers still unused whose global position is not before the last selected one;
    3. if shorter than the current target size, is padded with leftover layers;
    4. is kept only if it has at least one ``conv*`` layer and at least one
       ``fc*``/``linear`` layer, and is stored as its conv layers followed by its
       fc/linear layers (each group ordered by ``sort_layers``).

    Target sizes cycle through ``min_layers..max_layers`` (inclusive), round after
    round, until enough candidates have been kept.

    Args:
        input_cluster: dict mapping cluster id -> list of ``[client_id, layer_name]``.
        min_layers: smallest padding target size.
        max_layers: largest padding target size (inclusive).
        expected_num_models: number of candidates to return.
        global_layer_order: dict mapping layer name -> global position.
        center_layer_name: dict mapping cluster id -> list of center layers; only the
            first entry of each list is considered.
        max_rounds: maximum number of passes over the size range before giving up;
            ``None`` means no limit.

    Returns:
        A list of exactly ``expected_num_models`` candidates (``[]`` when
        ``expected_num_models <= 0``).

    Raises:
        ValueError: if ``min_layers > max_layers`` (no size would ever be tried), if no
            anchor block can be determined, or if the anchor is in no cluster.
        RuntimeError: if ``max_rounds`` passes end without enough valid candidates.
    """
    if min_layers > max_layers:
        raise ValueError(
            f"min_layers ({min_layers}) must not exceed max_layers ({max_layers})")

    anchor_block, sorted_clusters = preprocess_clusters(
        input_cluster, global_layer_order, center_layer_name)
    if anchor_block is None:
        raise ValueError(
            "no anchor block: no cluster center layer has a position in global_layer_order")
    if not any(anchor_block in cluster for cluster in sorted_clusters):
        raise ValueError(f"anchor block {anchor_block} does not belong to any cluster")
    if expected_num_models <= 0:
        return []

    output_models = []
    completed_rounds = 0
    while max_rounds is None or completed_rounds < max_rounds:
        for target_size in range(min_layers, max_layers + 1):
            candidate = _sample_candidate_model(
                anchor_block, sorted_clusters, target_size, global_layer_order)
            conv_layers, fc_layers = _split_conv_fc_layers(candidate)
            if not conv_layers or not fc_layers:
                # Skip just this size: padding at a larger target size may still
                # supply the missing kind of layer.
                continue
            output_models.append(sort_layers(conv_layers) + sort_layers(fc_layers))
            if len(output_models) >= expected_num_models:
                return output_models[:expected_num_models]
        completed_rounds += 1

    raise RuntimeError(
        f"only {len(output_models)} of {expected_num_models} valid models sampled "
        f"after {max_rounds} rounds")


def _sample_candidate_model(anchor_block, sorted_clusters, target_size, global_layer_order):
    """Draw one unsorted candidate model (anchor, per-cluster picks, padding)."""
    # Layers not yet used by this candidate; each layer can be picked at most once.
    remaining = [layer for cluster in sorted_clusters for layer in cluster]
    current_max_index = global_layer_order.get(anchor_block[1], -1)
    candidate = [anchor_block]
    remaining.remove(anchor_block)

    # sorted_clusters[0] is skipped: it is the anchor's cluster when
    # preprocess_clusters could locate it.
    for cluster in sorted_clusters[1:]:
        valid_layers = get_valid_layers(cluster, remaining, current_max_index, global_layer_order)
        if not valid_layers:
            continue
        selected_layer = _pick_cluster_layer(valid_layers)
        candidate.append(selected_layer)
        # Later picks must not precede this layer in the global order.
        current_max_index = global_layer_order.get(selected_layer[1], current_max_index)
        if selected_layer in remaining:
            remaining.remove(selected_layer)

    if len(candidate) < target_size:
        candidate.extend(_pick_completion_layers(
            remaining, current_max_index, target_size - len(candidate), global_layer_order))
    return candidate


def _pick_cluster_layer(valid_layers):
    """Randomly choose one layer, preferring layers whose client id ([0]) is 1 or 2."""
    # NOTE(review): why clients 1 and 2 are favoured is not explained in this module; confirm.
    priority_layers = []
    other_layers = []
    for layer in valid_layers:
        is_priority = False
        if isinstance(layer, (list, tuple)) and len(layer) > 0:
            try:
                is_priority = 0 < int(layer[0]) < 3
            except (ValueError, TypeError):
                is_priority = False  # non-numeric client ids are never prioritized
        if is_priority:
            priority_layers.append(layer)
        else:
            other_layers.append(layer)
    return random.choice(priority_layers if priority_layers else other_layers)


def _pick_completion_layers(remaining, current_max_index, needed, global_layer_order):
    """Choose up to ``needed`` padding layers from ``remaining``, each tier in global order.

    Tiers, in order of preference:
      1. layers whose global index is > 4 and > ``current_max_index``;
      2. other layers whose global index is >= ``current_max_index``;
      3. only if still short: any other remaining layer (this may break the order).
    For the tier filters, layers missing from ``global_layer_order`` count as index -1.
    Within a tier they sort last.
    """
    def position(layer):
        return global_layer_order.get(layer[1], float('inf'))

    def filter_index(layer):
        return global_layer_order.get(layer[1], -1)

    # NOTE(review): with the order from build_global_layer_order_detailed
    # (conv1..conv6 -> 0..5, fc1 -> 6, ...), "> 4" also admits conv6; confirm
    # whether this tier was meant to hold only fc/linear layers.
    tier_one = sorted(
        (layer for layer in remaining
         if filter_index(layer) > 4 and filter_index(layer) > current_max_index),
        key=position)
    tier_two = sorted(
        (layer for layer in remaining
         if layer not in tier_one and filter_index(layer) >= current_max_index),
        key=position)
    candidates = tier_one + tier_two
    if len(candidates) >= needed:
        return candidates[:needed]
    tier_three = sorted((layer for layer in remaining if layer not in candidates), key=position)
    return candidates + tier_three[:needed - len(candidates)]


def _split_conv_fc_layers(model):
    """Partition layers into (conv, fc/linear) lists; any other name is reported and dropped."""
    conv_layers = []
    fc_layers = []
    for layer in model:
        layer_name = layer[1]
        if layer_name.startswith('conv'):
            conv_layers.append(layer)
        elif layer_name.startswith('fc') or layer_name == 'linear':
            fc_layers.append(layer)
        else:
            print('invalid layer, pass')
    return conv_layers, fc_layers


def preprocess_clusters(input_cluster, global_layer_order, center_layer_name):
    """Pick the anchor block and move its cluster to the front of the cluster list.

    The anchor block is the first center layer (``centers[0]``) of the cluster whose
    center appears earliest in ``global_layer_order``. On ties, the cluster met first
    wins. Centers whose layer name is missing from ``global_layer_order`` are never
    chosen.

    Args:
        input_cluster: dict mapping cluster id -> list of layers.
        global_layer_order: dict mapping layer name -> position.
        center_layer_name: dict mapping cluster id -> list of center layers
            (``[client_id, layer_name]``); empty lists are ignored.

    Returns:
        ``(anchor_block, sorted_clusters)``. ``anchor_block`` is ``None`` when no
        center qualifies. ``sorted_clusters`` has the anchor's cluster first, then the
        others in their original order. If the anchor's cluster id is not a key of
        ``input_cluster``, the original order is kept unchanged.
    """
    missing = float('inf')
    ranked_centers = []
    for cluster_key, centers in center_layer_name.items():
        if not centers:
            continue
        head = centers[0]
        position = global_layer_order.get(head[1], missing)
        if position < missing:
            ranked_centers.append((position, cluster_key, head))

    if ranked_centers:
        # min() keeps the first of several equal positions.
        _, anchor_cluster_id, anchor_block = min(ranked_centers, key=lambda entry: entry[0])
    else:
        anchor_cluster_id, anchor_block = None, None

    print('anchor block:', anchor_block, 'anchor cluster id:', anchor_cluster_id)

    if anchor_cluster_id is None or anchor_cluster_id not in input_cluster:
        sorted_clusters = list(input_cluster.values())
    else:
        sorted_clusters = [input_cluster[anchor_cluster_id]]
        sorted_clusters.extend(
            members for key, members in input_cluster.items() if key != anchor_cluster_id)
    print('sorted clusters:', sorted_clusters)
    return anchor_block, sorted_clusters

def get_valid_layers(cluster, all_layers, current_max_index, global_layer_order):
    """Return the layers of ``cluster`` that are still available and not out of order.

    A layer qualifies when it is still present in ``all_layers`` and its position in
    ``global_layer_order`` is at least ``current_max_index``. Unknown layer names count
    as position -1. Cluster order is preserved.
    """
    return [
        candidate
        for candidate in cluster
        if candidate in all_layers
        and global_layer_order.get(candidate[1], -1) >= current_max_index
    ]


def sort_layers(input_layers):
    """Return a copy of ``input_layers`` ordered by the last character of each layer name.

    Each element's ``[1]`` item is the layer name, e.g. ``['5', 'conv2']``. Ties keep
    their input order, and the input list is not modified. The key is a single
    character, so e.g. ``'conv10'`` sorts as ``'0'``.
    """
    ordered = input_layers.copy()
    if len(ordered) < 2:
        # Nothing to compare; the elements are not inspected.
        return ordered
    # list.sort is stable, matching the adjacent-swap ordering on equal keys.
    ordered.sort(key=lambda layer: layer[1][-1])
    return ordered


def build_global_layer_order_detailed(model_layer_orders=None):
    """Map every known layer name to a global position index.

    Conv layers come first, ordered by their numeric suffix (``conv1``, ``conv2``,
    ...). Then come fc layers ordered by numeric suffix, and ``linear`` comes last.

    Args:
        model_layer_orders: optional mapping of model name -> list of layer names.
            Defaults to the built-in CNN1-CNN4 architectures. For those the result is
            ``conv1..conv6 -> 0..5``, ``fc1..fc4 -> 6..9`` and ``linear -> 10``.
            Names are de-duplicated across models. ``conv`` names without a numeric
            suffix sort as 0. ``fc`` names without a numeric suffix are left out, as
            are names that are neither conv, fc nor ``linear``.

    Returns:
        dict mapping layer name -> int position.
    """
    if model_layer_orders is None:
        model_layer_orders = {
            'CNN1': ['conv1', 'conv2', 'fc1', 'fc2'],
            'CNN2': ['conv1', 'conv2', 'fc1', 'fc2'],
            'CNN3': ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'fc1', 'fc2', 'fc3', 'fc4'],
            'CNN4': ['conv1', 'conv2', 'conv3', 'conv4', 'conv5', 'conv6', 'fc1', 'fc2', 'linear'],
        }

    # First-seen, de-duplicated names of each layer family.
    conv_layers = []
    fc_layers = []
    for layers in model_layer_orders.values():
        for layer_name in layers:
            if layer_name.startswith('conv') and layer_name not in conv_layers:
                conv_layers.append(layer_name)
            elif (layer_name.startswith('fc') or layer_name == 'linear') and layer_name not in fc_layers:
                fc_layers.append(layer_name)

    conv_layers.sort(key=lambda name: int(name[4:]) if name[4:].isdigit() else 0)
    numbered_fc = sorted(
        (name for name in fc_layers if name != 'linear' and name[2:].isdigit()),
        key=lambda name: int(name[2:]))
    ordered_names = conv_layers + numbered_fc + (['linear'] if 'linear' in fc_layers else [])
    return {name: position for position, name in enumerate(ordered_names)}

def get_best_combined_model(selection_info, candidate_model_combine_show_model):
    """Return the candidate chosen by ``selection_info['best_model_index']``.

    Args:
        selection_info: mapping expected to contain ``'best_model_index'``.
        candidate_model_combine_show_model: sequence of candidate models.

    Returns:
        ``(best_model_components, best_model_index)``. The chosen candidate is also
        printed.

    Raises:
        ValueError: if ``'best_model_index'`` is missing or ``None``.
        IndexError: if the index lies outside ``[0, len(candidates) - 1]``.
    """
    chosen_index = selection_info.get('best_model_index')
    if chosen_index is None:
        raise ValueError("selection_info中未找到'best_model_index'")

    candidate_count = len(candidate_model_combine_show_model)
    if not 0 <= chosen_index < candidate_count:
        raise IndexError(f"最佳模型索引 {chosen_index} 超出候选模型列表范围 [0, {candidate_count-1}]")

    chosen_components = candidate_model_combine_show_model[chosen_index]
    print(f"global blocks: {chosen_components}")
    return chosen_components, chosen_index

def gettoplayers(local_model_index_to_client_name, output_bottom_layers):
    """Attach each client's fc/linear (top) layers to a shared set of bottom layers.

    Args:
        local_model_index_to_client_name: iterable of local models, each an iterable
            of ``[client_id, layer_name]`` entries.
        output_bottom_layers: list of layers prepended to every client's group.

    Returns:
        ``get_layer_index_grouplist(top_layers, output_bottom_layers)``: one list per
        client (first-seen order), holding the bottom layers followed by that client's
        top layers.
    """
    # Only head layers (fc*/linear) are collected; conv layers come from
    # output_bottom_layers. (The previous unused sort_layers() call on the raw input
    # was discarded immediately and raised IndexError for local models with < 2 layers.)
    output_top_layers = [
        layer
        for layer_info in local_model_index_to_client_name
        for layer in layer_info
        if layer[1].startswith('fc') or layer[1] == 'linear'
    ]
    output_combined_models = get_layer_index_grouplist(output_top_layers, output_bottom_layers)
    print('local model top layers:', output_top_layers)
    return output_combined_models
    
###personalized_model_reassembly
def personalized_model_reassembly(best_global_model_index_to_best_candidate_model, local_model_index_to_client_name):
    """Rebuild one personalized model per client from the best global model.

    The conv layers of the selected global model become a shared bottom. Each
    client's own fc/linear layers, taken from its local model, are stacked on top
    (via ``gettoplayers``).

    Args:
        best_global_model_index_to_best_candidate_model: the selected global model, an
            iterable of ``[client_id, layer_name]`` entries.
        local_model_index_to_client_name: iterable of local models, each an iterable
            of ``[client_id, layer_name]`` entries.

    Returns:
        One list per client (first-seen order). Each list holds the bottom layers
        (ordered by ``sort_layers``) followed by that client's fc/linear layers.
    """
    # Every conv layer is kept regardless of its position in the global model.
    # A former in-loop "break if nothing collected yet" check dropped all of them
    # whenever the first layer was not a conv layer.
    output_bottom_layers = sort_layers([
        layer
        for layer in best_global_model_index_to_best_candidate_model
        if layer[1].startswith('conv')
    ])
    return gettoplayers(local_model_index_to_client_name, output_bottom_layers)