import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import argparse
import os
import pickle
import torch.nn.functional as F
from tqdm import tqdm

def convert_matrix_to_mlp(matrix_model, input_dim=128, latent_dim=64):
    """Unpack a batched-parameter model into one standalone MLP per network.

    Each returned MLP is ``Linear -> BatchNorm1d -> GELU -> Linear``. The
    BatchNorm layer is kept only for structural compatibility: it is frozen in
    eval mode and configured as an identity so that the MLP computes the same
    function as the corresponding slice of ``matrix_model``.

    Args:
        matrix_model: Object exposing ``W1`` ``[num_nets, input_dim, latent_dim]``,
            ``b1`` ``[num_nets, 1, latent_dim]``, ``W2``
            ``[num_nets, latent_dim, input_dim]`` and ``b2``
            ``[num_nets, 1, input_dim]``.
        input_dim: Input (and output) feature size of every MLP.
        latent_dim: Hidden feature size of every MLP.

    Returns:
        nn.ModuleList: One ``nn.Sequential`` per network, ordered by network
        index. The number of networks is read from ``matrix_model.W1``.
    """
    # Derive the network count from the parameters instead of assuming 1024.
    num_nets = matrix_model.W1.shape[0]
    mlp_list = nn.ModuleList()

    for net_idx in range(num_nets):
        mlp = nn.Sequential(
            nn.Linear(input_dim, latent_dim),
            nn.BatchNorm1d(latent_dim),
            nn.GELU(),
            nn.Linear(latent_dim, input_dim),
        )
        first_linear, norm, second_linear = mlp[0], mlp[1], mlp[3]

        with torch.no_grad():
            # The matrix model stores weights as [in, out]; nn.Linear expects
            # [out, in], hence the transpose.
            first_linear.weight.copy_(matrix_model.W1[net_idx].T)
            first_linear.bias.copy_(matrix_model.b1[net_idx].flatten())
            second_linear.weight.copy_(matrix_model.W2[net_idx].T)
            second_linear.bias.copy_(matrix_model.b2[net_idx].flatten())

            # Make BatchNorm an identity: gamma=1, beta=0, mean=0 and
            # var + eps == 1. Using var=1 would scale by 1/sqrt(1 + eps).
            nn.init.ones_(norm.weight)
            nn.init.zeros_(norm.bias)
            norm.running_mean.zero_()
            norm.running_var.fill_(1.0 - norm.eps)

        # Freeze the statistics so the layer keeps using the values set above.
        norm.eval()
        mlp_list.append(mlp)

    return mlp_list

def validate_conversion(matrix_model, mlp_list, batch_size=32):
    """Check that the converted MLPs reproduce the matrix model's output.

    A random input of shape ``[batch_size, num_layer, num_head, in_dim]`` is
    pushed through ``matrix_model`` and, position by position, through the
    matching entry of ``mlp_list`` (index ``layer * num_head + head``). The
    largest absolute difference is printed and returned.

    Args:
        matrix_model: Callable model exposing ``W1`` (its second dimension is
            used as ``in_dim``). ``num_layer`` / ``num_head`` attributes are
            used when present and default to 32 otherwise.
        mlp_list: Sequence of per-position modules, one per (layer, head).
        batch_size: Number of random samples to compare.

    Returns:
        float: Maximum absolute difference between the two outputs.

    Raises:
        ValueError: If ``mlp_list`` does not hold ``num_layer * num_head``
            modules.
    """
    num_layer = getattr(matrix_model, 'num_layer', 32)
    num_head = getattr(matrix_model, 'num_head', 32)
    in_dim = matrix_model.W1.shape[1]

    if len(mlp_list) != num_layer * num_head:
        raise ValueError(
            f"expected {num_layer * num_head} MLPs, got {len(mlp_list)}"
        )

    dummy_input = torch.randn(batch_size, num_layer, num_head, in_dim)

    with torch.no_grad():
        # Reference output from the batched model.
        matrix_output = matrix_model(
            dummy_input.to(_module_device(matrix_model))
        ).cpu()

        # Output of the converted model, rebuilt in [batch, layer, head, dim]
        # layout so it lines up element-for-element with the reference.
        per_layer = []
        for i in range(num_layer):
            per_head = []
            for j in range(num_head):
                mlp = mlp_list[i * num_head + j]
                patch = dummy_input[:, i, j, :].to(_module_device(mlp))
                per_head.append(mlp(patch).cpu())
            per_layer.append(torch.stack(per_head, dim=1))
        converted_output = torch.stack(per_layer, dim=1)

    max_error = torch.max(torch.abs(matrix_output - converted_output))
    print(f"最大输出误差：{max_error.item():.2e}")
    return max_error.item()


def _module_device(module):
    """Return the device of the module's first parameter (CPU if it has none)."""
    first_param = next(module.parameters(), None)
    if first_param is None:
        return torch.device('cpu')
    return first_param.device


def get_args():
    """Parse the command-line options of the offset-generator training script.

    Returns:
        argparse.Namespace: Parsed options, with the defaults declared below
        for anything not given on the command line.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # Model / data selection.
    parser.add_argument('--model', type=str, default='llava_v1.5_7B', help="specifies the model to be evaluated.")
    parser.add_argument('--guard_path', type=str, default='/data/huggingface/Llama-Guard-3-8B', help="path to the guard model checkpoint.")
    parser.add_argument('--dataset_name', type=str, default='', help="specifies the path to the data")

    # Network sizes and optimisation hyper-parameters.
    parser.add_argument('--input_dim', type=int, default=4096)
    parser.add_argument('--hidden_dim', type=int, default=512)
    parser.add_argument('--latent_dim', type=int, default=20)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--device', type=str, default='0')

    # File naming and feature/probe locations.
    parser.add_argument("--subfix", type=str, default='')
    parser.add_argument('--use_center_of_mass', action='store_true', help='use center of mass direction')
    parser.add_argument("--mode", type=str, default='')
    parser.add_argument("--query_path", type=str, default='/data/multimodal_alignment/mm_iti/features')
    parser.add_argument("--vector_path", type=str, default='/data/multimodal_alignment/mm_iti/probes')

    return parser.parse_args()

class MultiOffsetGenerator(nn.Module):
    def __init__(self, num_nets=1024, in_dim=128, hidden_dim=64, num_layer=32, num_head=32):
        self.num_layer = num_layer
        self.num_head = num_head
        super().__init__()
        # 将1024个网络的参数矩阵化
        self.W1 = nn.Parameter(torch.randn(num_nets, in_dim, hidden_dim))
        self.b1 = nn.Parameter(torch.zeros(num_nets, 1, hidden_dim))
        self.W2 = nn.Parameter(torch.randn(num_nets, hidden_dim, in_dim))
        self.b2 = nn.Parameter(torch.zeros(num_nets, 1, in_dim))
        
        # 初始化参数
        nn.init.kaiming_normal_(self.W1, mode='fan_out', nonlinearity='relu')
        nn.init.kaiming_normal_(self.W2, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """
        x shape: [batch_size, 32, 32, 128]
        """
        # 转换为并行计算维度 [batch_size, 32, 32, 128] -> [batch_size, 1024, 128]
        x = x.view(x.size(0), 1024, 128)
        
        # 第一层并行计算
        h = torch.einsum('bni,nij->bnj', x, self.W1) + self.b1
        h = torch.nn.functional.gelu(h)
        
        # 第二层并行计算
        delta = torch.einsum('bni,nij->bnj', h, self.W2) + self.b2
        
        # 恢复原始形状
        return delta.view(x.size(0), 32, 32, 128)

class OffsetGenerator(nn.Module):
    """One small ``Linear -> GELU -> Linear`` MLP per (layer, head) position.

    Args:
        input_dim: Feature size consumed and produced by every MLP.
        latent_dim: Hidden feature size of every MLP.
        num_layer: Size of the first grid axis of the input.
        num_head: Size of the second grid axis of the input.
    """

    def __init__(self, input_dim=128, latent_dim=128, num_layer=32, num_head=32):
        super().__init__()
        self.num_layer = num_layer
        self.num_head = num_head
        self.nets = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, latent_dim, dtype=torch.float32),
                nn.GELU(),
                nn.Linear(latent_dim, input_dim, dtype=torch.float32),
            )
            for _ in range(num_layer * num_head)
        ])

        # Initialise for stable outputs.
        self._init_weights()

    def _init_weights(self):
        """Kaiming-initialise every Linear weight and zero every bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Run each position's MLP on its own feature vector.

        Args:
            x: Tensor of shape ``[batch_size, num_layer, num_head, input_dim]``.

        Returns:
            Tensor of the same shape; entry ``[:, i, j, :]`` is the output of
            network ``i * num_head + j`` applied to ``x[:, i, j, :]``.
        """
        layer_outputs = []
        for i in range(self.num_layer):
            head_outputs = [
                self.nets[i * self.num_head + j](x[:, i, j, :])
                for j in range(self.num_head)
            ]
            layer_outputs.append(torch.stack(head_outputs, dim=1))
        output = torch.stack(layer_outputs, dim=1)

        # NaN diagnostics: one reduction (and one host sync) for the whole
        # grid instead of one per sub-network. Positions are reported in the
        # same (layer, head) order as a per-patch check would produce.
        nan_positions = torch.isnan(output.detach()).any(dim=-1).any(dim=0)
        if nan_positions.any():
            for i, j in nan_positions.nonzero().tolist():
                print(f"NaN出现在layer{i},{j}!")

        return output


def train(offset_model, vector_data, train_dataloader, val_dataloader, optimizer, epochs, args):
    """Fit ``offset_model`` so that ``offset_model(query) + vector_data`` matches the target.

    Args:
        offset_model: Module mapping a query batch to an offset of the same shape.
        vector_data: Base vector added to every predicted offset (broadcast
            against the batch).
        train_dataloader: Iterable of ``(query, vector_gt)`` batches.
        val_dataloader: Currently unused; validation is not run by this function.
        optimizer: Optimizer over ``offset_model``'s parameters.
        epochs: Number of passes over ``train_dataloader``.
        args: Namespace providing ``save_path``, ``input_dim`` and ``latent_dim``
            for checkpointing.

    Returns:
        The trained ``offset_model`` (same object that was passed in).

    Side effects:
        Every 5 epochs a checkpoint is written to
        ``f'{args.save_path}_q_{epoch}.pth'``.
    """
    offset_model.train()
    loss_fn = nn.MSELoss()  # mean-reduced: suited to exact numeric regression

    for epoch in tqdm(range(epochs)):
        weighted_loss = 0.0
        num_samples = 0
        for query, vector_gt in train_dataloader:
            optimizer.zero_grad()

            offset = offset_model(query)
            pred = offset + vector_data
            loss = loss_fn(pred, vector_gt)

            loss.backward()
            optimizer.step()

            # The loss is a per-batch mean, so weight it by the batch size to
            # obtain a true per-sample average over the epoch.
            batch_samples = query.size(0)
            weighted_loss += loss.item() * batch_samples
            num_samples += batch_samples

        # Guard against an empty loader instead of dividing by zero.
        avg_loss = weighted_loss / max(num_samples, 1)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}')

        if (epoch + 1) % 5 == 0:
            checkpoint_path = f'{args.save_path}_q_{epoch + 1}.pth'
            print(f'Save {checkpoint_path}')
            torch.save({
                'model_state_dict': offset_model.state_dict(),
                'model_config': {
                    'input_dim': args.input_dim,
                    'latent_dim': args.latent_dim,
                    # Record the model's actual grid size; fall back to the
                    # historical 32x32 when the model does not expose it.
                    'num_layer': getattr(offset_model, 'num_layer', 32),
                    'num_head': getattr(offset_model, 'num_head', 32),
                },
            }, checkpoint_path)
    return offset_model

def val(offset_model, vector_data, val_dataloader):
    """Report how closely ``offset_model(query) + vector_data`` matches the targets.

    For every batch the mean cosine similarity and mean L2 distance (both over
    the last dimension) between prediction and target are computed; the
    per-batch means are then averaged and printed.

    Args:
        offset_model: Module mapping a query batch to an offset.
        vector_data: Base vector added to every predicted offset.
        val_dataloader: Iterable of ``(query, vector_gt)`` batches, or ``None``
            to skip validation.

    Returns:
        tuple[float, float] | None: ``(avg_sim, avg_dst)``, or ``None`` when
        there is no loader or it yields no batches.
    """
    if val_dataloader is None:
        return None

    sims = []
    dsts = []

    # Evaluate without building autograd graphs and in eval mode, then restore
    # whatever mode the caller had the model in.
    was_training = offset_model.training
    offset_model.eval()
    try:
        with torch.no_grad():
            for query, vector_gt in val_dataloader:
                pred = offset_model(query) + vector_data

                sim = torch.nn.functional.cosine_similarity(pred, vector_gt, dim=-1)
                dst = torch.norm(pred - vector_gt, p=2, dim=-1)

                sims.append(torch.mean(sim).item())
                dsts.append(torch.mean(dst).item())
    finally:
        offset_model.train(was_training)

    if not sims:
        # Nothing to average; avoid a division by zero.
        return None

    avg_sim = sum(sims) / len(sims)
    avg_dst = sum(dsts) / len(dsts)

    print(f'Val sim: {avg_sim:.4f}, Val dst: {avg_dst:.4f}')
    return avg_sim, avg_dst
    # 测试生成能力



def loss_function(recon_x, x, mu, logvar):
    """Return the VAE-style objective for one batch.

    The result is the summed squared reconstruction error between ``recon_x``
    and ``x`` plus 0.1 times the KL-divergence term computed from ``mu`` and
    ``logvar``.
    """
    reconstruction_term = nn.functional.mse_loss(recon_x, x, reduction='sum')
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(kl_inner)
    return reconstruction_term + 0.1 * kl_term

def prepare_data(vector_data_path, query_data_path, caption_data_path, batch_size=16, num_layer=32, num_head=32, device='cuda'):
    """Load the base vector and build the training ``DataLoader``.

    The regression target for each sample is ``caption - query``. Queries and
    targets are reshaped to ``[num_samples, num_layer, num_head, -1]`` and the
    base vector to ``[num_layer, num_head, -1]``; everything is converted to
    float32 and placed on ``device``.

    Args:
        vector_data_path: Path to a pickle file holding the base vector array.
        query_data_path: Path to a ``.npy`` file of query features.
        caption_data_path: Path to a ``.npy`` file of caption features.
        batch_size: Batch size of the returned loader.
        num_layer: Size of the first grid axis.
        num_head: Size of the second grid axis.
        device: Device the tensors are moved to. Defaults to ``'cuda'``.

    Returns:
        tuple: ``(vector_data, dataloader)`` where ``dataloader`` yields
        shuffled ``(query, vector_gt)`` batches.
    """
    # SECURITY: pickle.load can execute arbitrary code while deserialising.
    # Only point this at probe files produced by a trusted pipeline.
    with open(vector_data_path, 'rb') as f:
        vector_data = pickle.load(f)
    vector_data = np.asarray(vector_data).reshape(num_layer, num_head, -1)

    query_data = np.load(query_data_path)
    caption_data = np.load(caption_data_path)
    vector_gt_data = caption_data - query_data

    query_data = query_data.reshape(query_data.shape[0], num_layer, num_head, -1)
    vector_gt_data = vector_gt_data.reshape(vector_gt_data.shape[0], num_layer, num_head, -1)

    def _as_tensor(array):
        # Move to the target device and cast to float32 in a single call.
        return torch.from_numpy(array).to(device=device, dtype=torch.float32)

    dataset = TensorDataset(_as_tensor(query_data), _as_tensor(vector_gt_data))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    return _as_tensor(vector_data), loader

if __name__ == '__main__':
    args = get_args()

    # Hard-coded experiment configuration. These assignments override whatever
    # was supplied on the command line.
    args.input_dim = 128
    args.latent_dim = 128
    args.batch_size = 16
    args.epochs = 10
    args.lr = 1e-3
    args.device = 'cuda'
    args.model = 'shikra_7B'
    args.query_set = 'POPE_train_YR_I+Q'
    args.caption_set = 'POPE_train_YR_C_p2+Q_query'
    args.vector_set = 'POPE_train_YR_I+Q;C_p2+Q_best'
    # args.vector_set = ['POPE_train_I+Q','POPE_train_C+Q_best']
    args.save_path = f'/data/multimodal_alignment/mm_iti/probes/{args.model}_offset_generator_YR'
    args.subfix = ''
    args.use_center_of_mass = True
    args.mode = 'head'

    print(os.path.exists(args.save_path))

    # Build the model from the same sizes that train() records in the
    # checkpoint config, so the saved config always describes the real model.
    model = OffsetGenerator(
        input_dim=args.input_dim,
        latent_dim=args.latent_dim,
    ).to(args.device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # File-name prefix of the base-vector pickle.
    prefix = 'coms' if args.use_center_of_mass else 'probes'

    # File-name suffix of the feature arrays.
    if args.mode == 'mlp':
        prefix += '_mlp'
        subfix = f'mlp{args.subfix}'
    else:
        subfix = f'head_wise{args.subfix}'

    if isinstance(args.vector_set, list):
        # NOTE(review): prepare_data opens vector_data_path as a single file,
        # so the list form is not supported by it as written.
        vector_data_path = [
            os.path.join(args.vector_path, f'{prefix}_{args.model}_{v_set}.pkl')
            for v_set in args.vector_set
        ]
    else:
        vector_data_path = os.path.join(
            args.vector_path, f'{prefix}_{args.model}_{args.vector_set}.pkl'
        )

    caption_data_path = os.path.join(
        args.query_path, f'{args.model}_{args.caption_set}_{subfix}.npy'
    )
    query_data_path = os.path.join(
        args.query_path, f'{args.model}_{args.query_set}_{subfix}.npy'
    )

    vector_data, train_dataloader = prepare_data(vector_data_path, query_data_path, caption_data_path, args.batch_size)

    # caption_data_path = caption_data_path.replace('train', 'test')
    # query_data_path = query_data_path.replace('train', 'test')

    # _, val_dataloader = prepare_data(vector_data_path, query_data_path, caption_data_path, args.batch_size)
    val_dataloader = None

    # Training stage.
    print("Starting Training...")
    trained_model = train(model, vector_data, train_dataloader, val_dataloader, optimizer, args.epochs, args)

    
    
    # # 测试阶段
    # print("\nStarting Testing...")
    
    # caption_data_path = caption_data_path.replace('train', 'test')
    # query_data_path = query_data_path.replace('train', 'test')
    
    # vector_data, val_dataloader = prepare_data(vector_data_path, query_data_path, caption_data_path, args.batch_size)
    # test(trained_model)
    
    # 保存模型示例
    # torch.save(trained_model.state_dict(), "vae_model.pth")