from ast import arg
from header import *
from local_datasets import *
from model import *
from config import *
from sklearn.model_selection import train_test_split
from peft import get_peft_model_state_dict, set_peft_model_state_dict
from torch.utils.data import random_split, Subset
from transformers import AutoTokenizer, AutoModel
import gc
import numpy as np
def parser_args():
    """Parse the command-line options of the training script.

    Returns:
        argparse.Namespace: one attribute per flag declared below. Flags
        without an explicit default resolve to ``None`` when omitted.
    """
    parser = argparse.ArgumentParser(description='train parameters')

    # Run-level options.
    parser.add_argument('--model', type=str)
    parser.add_argument('--data_path', type=str)
    parser.add_argument('--local_rank', default=0, type=int)
    parser.add_argument('--save_path', type=str)
    parser.add_argument('--log_path', type=str)

    # Model configuration, registered in this exact order.
    model_options = (
        ('--image_root_path', str),      # the directory that stores all images
        ('--imagebind_ckpt_path', str),  # the path that stores the imagebind checkpoint
        ('--vicuna_ckpt_path', str),     # the path that stores the vicuna checkpoint
        ('--delta_ckpt_path', str),      # the delta parameters trained in stage 1
        ('--max_tgt_len', int),          # the maximum sequence length
        ('--stage', int),                # training stage; main() uses it to pick the DeepSpeed config file
    )
    for flag, value_type in model_options:
        parser.add_argument(flag, type=value_type)

    # LoRA configuration.
    modality_choices = ["vision", "text", "audio", "thermal", "depth", "imu"]
    parser.add_argument("--lora_modality_names", nargs="+", type=str,
                        default=["vision", "text", "audio"],
                        choices=modality_choices,
                        help="Modality names to apply LoRA")
    parser.add_argument("--lora_layer_idxs", nargs="+", type=int,
                        help="Layer indices to apply LoRA")
    # One override flag per modality, e.g. --lora_layer_idxs_vision.
    for modality in modality_choices:
        parser.add_argument(
            f"--lora_layer_idxs_{modality}", nargs="+", type=int,
            help=f"Layer indices to apply LoRA for {modality} modality. "
                 "Overrides lora_layer_idxs if specified",
        )
    parser.add_argument("--lora", action="store_true", help="Use LoRA")
    return parser.parse_args()

def initialize_distributed(args):
    """Record the distributed-run settings in ``args`` and start DeepSpeed.

    Reads ``MASTER_ADDR``, ``MASTER_PORT``, ``WORLD_SIZE`` and ``RANK`` from
    the environment (falling back to single-process defaults), stores them
    under ``master_ip``, ``master_port``, ``world_size`` and ``local_rank``,
    selects the matching CUDA device and initialises the NCCL backend.

    Args:
        args: mutable mapping of run options; updated in place.
    """
    environ = os.environ
    args['master_ip'] = environ.get('MASTER_ADDR', 'localhost')
    args['master_port'] = environ.get('MASTER_PORT', '6000')
    args['world_size'] = int(environ.get('WORLD_SIZE', '1'))

    global_rank = int(environ.get('RANK', '0'))
    gpu_count = torch.cuda.device_count()
    # The rank is folded onto the visible GPUs; the result is already a valid
    # device index, so it is used directly for set_device.
    local_rank = global_rank % gpu_count
    args['local_rank'] = local_rank
    torch.cuda.set_device(local_rank)
    deepspeed.init_distributed(dist_backend='nccl')

def set_random_seed(seed):
    """Seed the Python, NumPy and PyTorch (CPU and CUDA) random generators.

    Args:
        seed: integer seed. ``None`` or a value that is not strictly
            positive leaves every generator untouched.
    """
    if seed is None or not seed > 0:
        return

    random.seed(seed)
    np.random.seed(seed)
    # Each torch seeding entry point used by the script is called explicitly.
    for seed_fn in (
        torch.manual_seed,
        torch.random.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seed_fn(seed)

def config_env(args):
    """Prepare the run environment and fold the loaded config into ``args``.

    Sets ``root_dir`` and ``mode``, merges the mapping returned by
    ``load_config(args)`` into ``args``, initialises the distributed backend
    and seeds the random generators with ``args['seed']``.

    Args:
        args: mutable mapping of run options; updated in place.
    """
    args.update(root_dir='../', mode='train')
    args.update(load_config(args))
    initialize_distributed(args)
    set_random_seed(args['seed'])

def build_directory(path):
    """Create ``path`` (and any missing parents) unless it already exists.

    Args:
        path: directory to create. A falsy value (``None`` or ``''``) is
            ignored, because the path options of this script are optional on
            the command line and default to ``None``.
    """
    if not path:
        # Nothing to create; os.path.exists(None) would raise TypeError.
        return
    if not os.path.exists(path):
        # Recursively construct the directory; exist_ok covers the race where
        # another process creates it between the check and this call.
        os.makedirs(path, exist_ok=True)

def _build_federated_loaders(dataset, args, num_clients, samples_per_client, test_size):
    """Split ``dataset`` into per-client training loaders and one test loader.

    A random permutation of the dataset indices is drawn with ``np.random``.
    The first ``num_clients * samples_per_client`` indices form the training
    pool, which is cut into equal contiguous shards (one per client); the
    next ``test_size`` indices form the held-out test split.

    Args:
        dataset: dataset object exposing ``__len__`` and a ``collate`` callable.
        args: run options forwarded to ``build_dataloader``.
        num_clients: number of client shards to create.
        samples_per_client: number of samples in every client shard.
        test_size: maximum number of samples in the test split.

    Returns:
        tuple: ``(train_data, client_loaders, test_loader)`` where
        ``train_data`` is the ``Subset`` holding the whole training pool.

    Raises:
        ValueError: if the dataset has fewer samples than the client shards need.
    """
    collate_fn = dataset.collate
    total_client_data = num_clients * samples_per_client
    if len(dataset) < total_client_data:
        # Fail early: short shards would otherwise surface as an IndexError
        # only once the affected client's loader is iterated.
        raise ValueError(
            f'dataset has {len(dataset)} samples but {num_clients} clients x '
            f'{samples_per_client} samples = {total_client_data} are required'
        )

    all_indices = np.random.permutation(len(dataset))
    train_indices = [int(i) for i in all_indices[:total_client_data]]
    test_indices = [int(i) for i in all_indices[total_client_data:total_client_data + test_size]]
    train_data = Subset(dataset, train_indices)
    test_data = Subset(dataset, test_indices)

    client_loaders = []
    for client_idx in range(num_clients):
        start = client_idx * samples_per_client
        shard = Subset(train_data, list(range(start, start + samples_per_client)))
        _, client_loader, _ = build_dataloader(shard, args, collate_fn=collate_fn)
        client_loaders.append(client_loader)

    # The test loader keeps a fixed order (shuffle=False).
    _, test_loader, _ = build_dataloader(test_data, args, collate_fn=collate_fn, shuffle=False)
    return train_data, client_loaders, test_loader


def main(**args):
    """Run federated LoRA fine-tuning on the audio clients and evaluate each round.

    Steps:
      1. Configure the environment, DeepSpeed config, output directories and logging.
      2. Build per-client loaders and a test loader for the image dataset and
         for the audio dataset.
      3. For every federated round: each client starts from the current global
         LoRA weights, trains on its own audio shard, and its resulting LoRA
         weights are collected; the collected weights are aggregated with
         ``federated_regularized_lora_aggregation`` and loaded back as the new
         global weights, the model is saved, and both test loaders are run
         through ``predict_model`` / ``predict_model_1``.

    NOTE: client training on the image shards is currently disabled; the image
    client loaders are still built, and only the image test loader is used.

    Args:
        **args: run options as produced by ``parser_args`` (passed as keywords).
    """
    config_env(args)
    args['ds_config_path'] = f'dsconfig/{args["model"]}_stage_{args["stage"]}.json'
    dschf = HfDeepSpeedConfig(args['ds_config_path'])
    args['dschf'] = dschf

    build_directory(args['save_path'])

    if args['log_path']:
        # The log directory is only needed (and only valid) when a log path was given.
        build_directory(args['log_path'])
        logging.basicConfig(
            format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
            level=logging.DEBUG,
            filename=f'{args["log_path"]}/train_{time.asctime()}.log',
            filemode='w'
        )

    # ---- data: the same split layout is used for both modalities ----
    num_clients = 5
    samples_per_client = 2000
    test_size = 100

    full_data, _, _ = load_sft_dataset(args)
    # client_loaders (image shards) is unused while image training is disabled.
    train_data, client_loaders, test_iter = _build_federated_loaders(
        full_data, args, num_clients, samples_per_client, test_size
    )

    full_data_audio, _, _ = load_sft_dataset_audio(args)
    train_data_audio, client_loaders_audio, test_iter_audio = _build_federated_loaders(
        full_data_audio, args, num_clients, samples_per_client, test_size
    )

    # ---- step bookkeeping ----
    micro_batch_size = dschf.config['train_micro_batch_size_per_gpu']
    global_batch_size = dschf.config['train_batch_size']

    total_steps = args['epochs'] * len(train_data) // global_batch_size
    args['total_steps'] = total_steps

    length_audio = args['epochs'] * len(train_data_audio) // args['world_size'] // micro_batch_size
    args['total_steps_audio'] = args['epochs'] * len(train_data_audio) // global_batch_size
    torch.distributed.barrier()

    # ---- model ----
    global_model = load_model(args)
    freeze_non_lora_params(global_model.model)
    base_lora_state = extract_lora_state_dict(global_model.model)

    # ---- federated rounds ----
    num_rounds = 10  # number of federated-learning rounds
    for round_idx in range(num_rounds):
        lora_dicts = []
        pbar = tqdm(total=length_audio)
        current_step = 0
        for _ in tqdm(range(args['epochs'])):
            for client_idx in range(num_clients):
                # Every client starts from the current global LoRA weights.
                load_lora_weights(global_model.model, base_lora_state)

                for batch in client_loaders_audio[client_idx]:
                    # LoRA A stays frozen; only B is trained. The first 20% of
                    # the steps train the ImageBind LoRA, the rest the LLM LoRA.
                    # NOTE(review): the threshold uses total_steps (image
                    # split) while iterating audio loaders; both splits use
                    # the same size constants here — confirm this is intended.
                    if current_step < 0.2 * total_steps:
                        lora_mode = "imagebind"
                    else:
                        lora_mode = "llm"
                    set_lora_trainable(global_model.model, mode=lora_mode, train_A=False, train_B=True)

                    global_model.train_model(
                        batch,
                        current_step=current_step,
                        pbar=pbar
                    )
                    current_step += 1

                # Collect this client's weights before the next client resets
                # the model to the global state; otherwise its training is lost.
                print(f"Client {client_idx + 1} finished training for round {round_idx + 1}/{num_rounds}.")
                lora_dicts.append(extract_lora_state_dict(global_model.model))
                torch.cuda.empty_cache()

        # Server side: aggregate the client LoRA weights and make the result
        # the starting point of the next round.
        avg_lora = federated_regularized_lora_aggregation(lora_dicts, lambda_reg=1e-4)
        load_lora_weights(global_model.model, avg_lora)
        base_lora_state = extract_lora_state_dict(global_model.model)

        torch.distributed.barrier()
        global_model.save_model(args['save_path'], 0)

        # ---- evaluation ----
        print("Start evaluating on test set...")
        # NOTE(review): the model is switched to eval() here and this function
        # never switches it back; presumably train_model() re-enables training
        # mode — confirm.
        global_model.model.eval()
        with torch.no_grad():
            for batch in test_iter_audio:
                global_model.predict_model(batch, round_id=round_idx, pbar=pbar)
                global_model.predict_model_1(batch, round_id=round_idx, pbar=pbar)
            for batch in test_iter:
                global_model.predict_model(batch, round_id=round_idx, pbar=pbar)
                global_model.predict_model_1(batch, round_id=round_idx, pbar=pbar)
        pbar.close()
        print(f"Round {round_idx + 1}/{num_rounds} completed.")


def freeze_non_lora_params(model):
    """Disable gradients for every parameter whose name lacks ``'lora'``.

    Parameters whose name contains ``'lora'`` are left exactly as they are.
    """
    for param_name, param in model.named_parameters():
        if 'lora' in param_name:
            continue
        param.requires_grad_(False)

def extract_lora_state_dict(model):
    """Return detached CPU copies of all state-dict tensors whose key contains ``'lora'``."""
    lora_state = {}
    for key, tensor in model.state_dict().items():
        if 'lora' in key:
            lora_state[key] = tensor.detach().cpu().clone()
    return lora_state
def extract_lora_state_dict_A(model):
    """Return detached CPU copies of all state-dict tensors whose key contains ``'lora_A'``."""
    lora_a_state = {}
    for key, tensor in model.state_dict().items():
        if 'lora_A' in key:
            lora_a_state[key] = tensor.detach().cpu().clone()
    return lora_a_state
def extract_lora_state_dict_B(model):
    """Return detached CPU copies of all state-dict tensors whose key contains ``'lora_B'``."""
    lora_b_state = {}
    for key, tensor in model.state_dict().items():
        if 'lora_B' in key:
            lora_b_state[key] = tensor.detach().cpu().clone()
    return lora_b_state
# def extract_lora_state_dict(model):
#     # Extract only the LoRA adapter weights; PEFT handles the keys/shapes, which is more robust
#     return get_peft_model_state_dict(model)
# def load_lora_weights(model, lora_state_dict, adapter_name="default"):
#     # strict=False skips missing keys or slightly mismatched shapes instead of raising
#     set_peft_model_state_dict(model, lora_state_dict, adapter_name=adapter_name, strict=False)

def load_lora_weights(model, lora_state_dict):
    """Copy the tensors of ``lora_state_dict`` into the matching model entries in place.

    Args:
        model: ``torch.nn.Module`` whose state dict contains every key of
            ``lora_state_dict``.
        lora_state_dict: mapping from state-dict key to the tensor to load.
            Tensors may live on another device or use another dtype.

    Raises:
        KeyError: if a key of ``lora_state_dict`` is missing from the model.
    """
    # Build the state dict once instead of once per key. Its tensors alias
    # the module's parameters/buffers (standard torch.nn.Module behaviour),
    # so the in-place copy below updates the model itself.
    model_state = model.state_dict()
    with torch.no_grad():
        for key, value in lora_state_dict.items():
            # copy_ moves/casts the source to the destination's device and
            # dtype, so no ``model.device`` attribute is required.
            model_state[key].copy_(value)


def federated_average_lora(lora_dicts):
    """Average the client LoRA state dicts key by key (plain FedAvg).

    Args:
        lora_dicts: non-empty list of mappings that all provide the keys of
            the first mapping.

    Returns:
        dict: for every key of ``lora_dicts[0]``, the sum of the clients'
        values divided by the number of clients.
    """
    num_clients = len(lora_dicts)
    return {
        key: sum(client_state[key] for client_state in lora_dicts) / num_clients
        for key in lora_dicts[0]
    }

def federated_average(client_agents):
    """Return a copy of the first agent whose LoRA tensors are the client mean.

    Args:
        client_agents: non-empty list of objects exposing a ``model``
            attribute with a ``state_dict()`` method.

    Returns:
        A deep copy of ``client_agents[0]``. Every state-dict entry whose key
        contains ``'lora'`` is overwritten with the element-wise mean over all
        agents; all other entries keep the first agent's values. The input
        agents are not modified.
    """
    avg_agent = copy.deepcopy(client_agents[0])
    # Materialise each state dict once; rebuilding it for every key and every
    # agent is needless repeated work on large models.
    avg_state = avg_agent.model.state_dict()
    client_states = [agent.model.state_dict() for agent in client_agents]
    num_clients = len(client_agents)
    with torch.no_grad():
        for key, target in avg_state.items():
            if 'lora' not in key:
                continue
            mean_param = sum(state[key] for state in client_states) / num_clients
            # ``target`` aliases the copied model's tensor, so this updates it.
            target.copy_(mean_param)
    return avg_agent



def federated_regularized_lora_aggregation(lora_dicts, lambda_reg=0.1):
    """Aggregate client LoRA factors so ``B_avg @ A_global`` fits the mean client update.

    For every key containing ``'lora_A'`` (paired with the key obtained by
    replacing ``'lora_A'`` with ``'lora_B'``) the following is computed in
    float32::

        B_avg    = mean_i(B_i)
        U        = mean_i(B_i @ A_i)
        A_global = (B_avg^T B_avg)^-1 B_avg^T U

    i.e. the least-squares solution of ``B_avg @ A = U``. The ridge term
    ``lambda_reg * I`` is added to ``B_avg^T B_avg`` only when the plain
    inverse fails because the matrix is singular.

    Args:
        lora_dicts: non-empty list of client state dicts sharing the keys of
            the first one; each ``lora_A`` entry needs its ``lora_B`` partner.
        lambda_reg: ridge coefficient used for the singular fallback.

    Returns:
        dict: the aggregated ``lora_A`` and ``lora_B`` entries, converted back
        to the dtype and device of the first client's ``lora_A`` tensor.
        Other keys of the inputs are not included.

    Raises:
        ValueError: if ``lora_dicts`` is empty.
        KeyError: if a client lacks one of the paired keys.
    """
    if not lora_dicts:
        raise ValueError("lora_dicts must contain at least one client state dict")

    # torch.linalg.inv reports a singular matrix with LinAlgError (a
    # RuntimeError subclass). Catching only that keeps unrelated runtime
    # failures (e.g. out-of-memory) from being retried silently; RuntimeError
    # is the fallback on torch builds that predate LinAlgError.
    singular_error = getattr(torch.linalg, "LinAlgError", RuntimeError)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_clients = len(lora_dicts)

    global_lora = {}
    for A_key, ref_tensor in lora_dicts[0].items():
        if 'lora_A' not in A_key:
            continue
        B_key = A_key.replace('lora_A', 'lora_B')
        orig_dtype, orig_device = ref_tensor.dtype, ref_tensor.device

        # Work in float32 on the compute device regardless of the stored dtype.
        A_mats = [state[A_key].to(dtype=torch.float32, device=device) for state in lora_dicts]
        B_mats = [state[B_key].to(dtype=torch.float32, device=device) for state in lora_dicts]

        B_avg = sum(B_mats) / num_clients
        U = sum(B_i @ A_i for B_i, A_i in zip(B_mats, A_mats)) / num_clients

        Bt = B_avg.T
        BTB = Bt @ B_avg

        # Try the plain inverse first; regularise only if it is singular.
        # NOTE(review): when B_avg is all zeros (e.g. a LoRA B that no client
        # has updated), the fallback gives A_global = 0 and the clients' A is
        # discarded — confirm this is acceptable.
        try:
            BTB_inv = torch.linalg.inv(BTB)
        except singular_error:
            ridge = lambda_reg * torch.eye(BTB.shape[0], device=BTB.device, dtype=BTB.dtype)
            BTB_inv = torch.linalg.inv(BTB + ridge)

        A_global = (BTB_inv @ Bt) @ U

        global_lora[A_key] = A_global.to(dtype=orig_dtype, device=orig_device)
        global_lora[B_key] = B_avg.to(dtype=orig_dtype, device=orig_device)

    return global_lora
def set_lora_trainable(model, mode="imagebind"):
    """Enable gradients only for the LoRA parameters selected by ``mode``.

    mode = "imagebind" -> train only the ImageBind LoRA parameters
    mode = "llm"       -> train only the LLM LoRA parameters
    any other mode     -> freeze every parameter

    Parameters whose (lower-cased) name does not contain "lora" are always
    frozen.

    NOTE: a later definition of ``set_lora_trainable`` in this file (the one
    taking ``train_A`` / ``train_B``) rebinds this name at import time, so
    this version is not the one callers get.
    """
    # Name fragments identifying ImageBind-side parameters.
    imagebind_markers = ("visual_encoder", "modality_trunks", "modality_heads", "postprocessors")
    # Name fragments identifying LLM-side parameters.
    llm_markers = ("llama_model", "base_model.model")

    for param_name, param in model.named_parameters():
        lowered = param_name.lower()
        if "lora" not in lowered:
            param.requires_grad = False
        elif mode == "imagebind":
            param.requires_grad = any(marker in lowered for marker in imagebind_markers)
        elif mode == "llm":
            param.requires_grad = any(marker in lowered for marker in llm_markers)
        else:
            param.requires_grad = False

# def set_lora_trainable(model, mode="imagebind"):
#     """
#     mode = "imagebind" → train only the LoRA inside ImageBind
#     mode = "llm" → train only the LoRA inside LLaMA / Vicuna
#     """
#     for name, param in model.named_parameters():
#         if "lora" in name:
#             if mode == "imagebind":
#                 param.requires_grad = ("imagebind" in name.lower())
#             elif mode == "llm":
#                 param.requires_grad = ("model" in name.lower() or "vicuna" in name.lower() or "llama" in name.lower())
#             else:
#                 param.requires_grad = False  # by default nothing is trained
#         else:
#             param.requires_grad = False

def set_lora_trainable(model, mode="imagebind", train_A=False, train_B=True):
    """Enable gradients for the selected LoRA matrices of one sub-model.

    mode = "imagebind" -> only LoRA parameters of the ImageBind modules
    mode = "llm"       -> only LoRA parameters of the LLM modules
    train_A / train_B  -> whether the LoRA A / B matrices are trained

    Every other parameter (non-LoRA, LoRA outside the selected sub-model, or
    any parameter under an unknown mode) gets ``requires_grad = False``.
    """
    # Pick the name fragments that identify the sub-model selected by ``mode``.
    if mode == "imagebind":
        active_markers = ("visual_encoder", "modality_trunks", "modality_heads", "postprocessors")
    elif mode == "llm":
        active_markers = ("llama_model", "base_model.model")
    else:
        active_markers = ()

    for param_name, param in model.named_parameters():
        lowered = param_name.lower()
        selected = "lora" in lowered and any(marker in lowered for marker in active_markers)
        if not selected:
            param.requires_grad = False
            continue

        # A/B detection is case-sensitive and uses the original name.
        is_A = "lora_A" in param_name
        is_B = "lora_B" in param_name
        param.requires_grad = (is_A and train_A) or (is_B and train_B)




if __name__ == "__main__":
    # Convert the parsed Namespace into a plain dict and hand every option
    # to main() as a keyword argument.
    args = vars(parser_args())
    main(**args)
