import torch
import os
import json
import matplotlib.pyplot as plt 
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import torch.nn as nn 
import torch.nn.functional as F 
import torch.nn.init as init 
import torch.optim as optim
import wandb  # 新增导入
import random
import numpy as np 
from scipy import stats
import numpy as np 
from torch_geometric.nn import GCNConv
import os
# Ask wandb to run in disabled mode for this process.
os.environ.update({"WANDB_DISABLED": "true"})
# Single compute device used by the training / export stages below.
device = torch.device("cuda")
def seed_torch(seed: int = 42) -> None:
    """Seed every RNG used by the script and switch PyTorch to deterministic kernels.

    Args:
        seed: value fed to Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).
    """
    env_overrides = {
        # Only affects child processes; this interpreter's hash seed is fixed at startup.
        'PYTHONHASHSEED': str(seed),
        "CUDA_LAUNCH_BLOCKING": "1",
        # Workspace setting required for deterministic cuBLAS when deterministic algorithms are on.
        "CUBLAS_WORKSPACE_CONFIG": ":16:8",
    }
    for key, value in env_overrides.items():
        os.environ[key] = value

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.use_deterministic_algorithms(True)
    print(f"Random seed set as {seed}")

# Fix all RNGs before any model or data object is created.
seed_torch(seed=3407)

import torch


class Mydataset(Dataset):
    """Paired-sample dataset: each JSON record describes two samples (suffixes _0 and _1).

    For each side ``k`` in (0, 1) a record must provide file paths under ``keypoint_k``,
    ``feature_k``, ``knowledge_k`` and ``comment_k`` (each read with ``torch.load``) and a
    ``score_k`` value that is passed through unchanged.
    """

    def __init__(self, json_path):
        super().__init__()
        # Context manager so the index file handle is closed right away.
        with open(json_path) as f:
            self.data = json.load(f)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return a dict with keypoint/feature/knowledge/score/text entries for both sides."""
        sample = self.data[index]
        item = {}
        for side in (0, 1):
            item[f'keypoint_{side}'] = torch.load(sample[f'keypoint_{side}'])
            item[f'feature_{side}'] = torch.load(sample[f'feature_{side}'])
            item[f'knowledge_{side}'] = torch.load(sample[f'knowledge_{side}'])
            item[f'score_{side}'] = sample[f'score_{side}']
            # The text embedding is stored under "comment_k" but exposed as "text_k".
            item[f'text_{side}'] = torch.load(sample[f'comment_{side}'])
        return item

class CrossAttentionModule(nn.Module):
    """Post-norm cross-attention block: attention residual, then feed-forward residual.

    Args:
        embed_dim: width of the query stream and of ``encoder_output``.
        num_heads: number of attention heads.
        dropout: dropout probability inside the attention and on both residual branches.
    """

    def __init__(self, embed_dim=4096, num_heads=16, dropout=0.1):
        super().__init__()
        self.cross_attention = nn.MultiheadAttention(
            embed_dim=embed_dim, num_heads=num_heads, dropout=dropout, batch_first=True,
        )
        hidden_width = 2048
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_width), nn.ReLU(), nn.Linear(hidden_width, embed_dim),
        )
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal weights and zero biases for every nn.Linear submodule."""
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            init.zeros_(layer.bias)

    def forward(self, x, encoder_output, src_mask=None):
        """Let ``x`` (batch-first) attend to ``encoder_output``; ``src_mask`` masks key padding."""
        attended, _ = self.cross_attention(
            query=x, key=encoder_output, value=encoder_output, key_padding_mask=src_mask,
        )
        hidden = self.norm1(x + self.dropout(attended))
        return self.norm2(hidden + self.dropout(self.ffn(hidden)))

class Mymodel(nn.Module):
    """Score regressor fusing keypoints, visual features, knowledge and (optionally) text.

    Pipeline shared by both ``forward`` modes:
      1. each of the 17 keypoints -> ``mlp1`` -> 4096-d embedding;
      2. flattened ``knowledge`` -> ``mlp2`` -> (B, 16, 4096); ``feature`` cross-attends to it
         through ``ca_2``;
      3. the token-mean of that fused feature (as query) cross-attends to each keypoint
         embedding through ``ca_1``, giving one 4096-d node per keypoint;
      4. ``mlp3`` scores every edge of the learned graph from its two node vectors; the
         scores weight a 2-layer GCN on that graph, while a second 2-layer GCN runs on the
         fixed complementary "org" graph;
      5. both GCN outputs are flattened, concatenated and mapped by ``mlp4`` to a score in
         (1, 10).
    """

    NUM_KEYPOINTS = 17

    # Shared 17x17 keypoint adjacency (row i, column j). A 1 marks an edge of the
    # learned-weight graph (``edge_connect``); an off-diagonal 0 marks an edge of the
    # fixed "org" graph (``org_edge_connect``). The diagonal is all zeros.
    _ADJACENCY = (
        (0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1),
        (0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
        (1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
        (1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
        (0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
        (1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
        (1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1),
        (0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1),
        (1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1),
        (1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1),
        (1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1),
        (1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1),
        (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1),
        (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1),
        (1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1),
        (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0),
        (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0),
    )

    def __init__(self, input_shape_0=3, input_shape_1=64*40, output_shape=16*4096):
        super().__init__()

        # Edge lists stay plain CPU attributes (not buffers) so state_dict keys are unchanged.
        self.org_edge_connect = self._edges_from_adjacency(connected=False)  # (2, 31)
        self.edge_connect = self._edges_from_adjacency(connected=True)  # (2, 241)
        # Derived rather than hard-coded so it can never disagree with edge_connect.
        self.edge_number = self.edge_connect.size(1)
        # Python-int (src, dst) pairs: avoids two .item() calls per edge on every forward.
        self._edge_pairs = [tuple(pair) for pair in self.edge_connect.t().tolist()]

        # Submodules are created in the original order so seeded initialisation is unchanged.
        self.reshape_text = nn.Sequential(
            nn.Linear(768, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, output_shape),
            nn.BatchNorm1d(output_shape),
            nn.ReLU()
        )

        self.mlp1 = nn.Sequential(
            nn.Linear(input_shape_0, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU()
        )

        self.ca_1 = CrossAttentionModule()

        # NOTE(review): 8196 looks like a typo for 8192; kept because changing it would
        # break existing checkpoints.
        self.mlp2 = nn.Sequential(
            nn.Linear(input_shape_1, 8196),
            nn.BatchNorm1d(8196),
            nn.ReLU(),
            nn.Linear(8196, output_shape),
            nn.BatchNorm1d(output_shape),
            nn.ReLU()
        )

        self.ca_2 = CrossAttentionModule()

        self.mlp3 = nn.Sequential(
            nn.Linear(2 * 4096, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

        self.gnn_layer1 = GCNConv(4096, 2048)
        self.gnn_layer2 = GCNConv(2048, 1024)

        self.org_gnn_layer1 = GCNConv(4096, 2048)
        self.org_gnn_layer2 = GCNConv(2048, 1024)

        self.mlp4 = nn.Sequential(
            nn.Linear(17 * 1024 * 2, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 8),
            nn.BatchNorm1d(8),
            nn.ReLU(),
            nn.Linear(8, 1)
        )

        self._initialize_weights()

    @classmethod
    def _edges_from_adjacency(cls, connected):
        """Return a (2, E) long edge index from ``_ADJACENCY``, row-major (i, then j).

        ``connected=True`` keeps entries equal to 1; ``connected=False`` keeps off-diagonal
        entries equal to 0.
        """
        edges = []
        for i, row in enumerate(cls._ADJACENCY):
            for j, flag in enumerate(row):
                if connected and flag == 1:
                    edges.append([i, j])
                elif not connected and flag == 0 and i != j:
                    edges.append([i, j])
        return torch.tensor(edges, dtype=torch.long).t().contiguous()

    @staticmethod
    def _batch_edges(edge_index, batch_size, device, num_nodes=17):
        """Tile a (2, E) edge list over ``batch_size`` disjoint graphs of ``num_nodes`` nodes.

        Graph b's node ids are shifted by ``b * num_nodes``. The result is (2, batch_size * E),
        graph-major, matching a row-major flattening of (B, E) edge weights.
        """
        offsets = torch.arange(batch_size, dtype=edge_index.dtype, device=edge_index.device) * num_nodes
        tiled = edge_index.unsqueeze(1) + offsets.view(1, -1, 1)  # (2, B, E)
        return tiled.reshape(2, -1).to(device)

    def _initialize_weights(self):
        """Kaiming-normal weights and zero biases for every nn.Linear submodule."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
                init.zeros_(m.bias)

    def _encode_inputs(self, keypoint, feature, knowledge):
        """Embed keypoints with mlp1 and fuse the mlp2-projected knowledge into ``feature`` via ca_2.

        Returns (list of 17 (B, 4096) keypoint embeddings, new_knowledge, knowledge_vision_feature).
        """
        B = keypoint.shape[0]
        # One mlp1 call per keypoint: its BatchNorm statistics are computed per keypoint.
        keypoint_embeds = [self.mlp1(keypoint[:, k, :]) for k in range(self.NUM_KEYPOINTS)]
        # (B, 16, 4096): matches the default output_shape of mlp2.
        new_knowledge = self.mlp2(knowledge.view(B, -1)).view(B, 16, 4096)
        knowledge_vision_feature = self.ca_2(feature, new_knowledge)
        return keypoint_embeds, new_knowledge, knowledge_vision_feature

    def _graph_score(self, keypoint_embeds, knowledge_vision_feature):
        """Build keypoint nodes, weight the learned graph, run both GCN branches and score."""
        B = keypoint_embeds[0].shape[0]
        n = self.NUM_KEYPOINTS

        avg_feature = knowledge_vision_feature.mean(dim=1, keepdim=True)  # (B, 1, 4096)
        ca_tensor = torch.cat(
            [self.ca_1(avg_feature, emb.unsqueeze(1)) for emb in keypoint_embeds], dim=1
        )  # (B, 17, 4096)

        # One mlp3 call per edge: its BatchNorm sees the B node pairs of a single edge.
        edge_scores = [
            self.mlp3(torch.cat([ca_tensor[:, i, :], ca_tensor[:, j, :]], dim=1))
            for i, j in self._edge_pairs
        ]
        # mlp3 already ends in a Sigmoid, so this second shifted sigmoid maps weights into
        # roughly (0.38, 0.62). Flattening (B, E) row-major matches _batch_edges' ordering.
        edge_weight = torch.sigmoid(torch.cat(edge_scores, dim=1).view(-1) - 0.5)

        device = edge_weight.device
        batch_edge_index = self._batch_edges(self.edge_connect, B, device, n)
        batch_org_edge_index = self._batch_edges(self.org_edge_connect, B, device, n)

        gnn_input = ca_tensor.view(B * n, -1)

        gnn_output = F.relu(self.gnn_layer1(gnn_input, batch_edge_index, edge_weight))
        gnn_output = F.relu(self.gnn_layer2(gnn_output, batch_edge_index, edge_weight))

        org_gnn_output = F.relu(self.org_gnn_layer1(gnn_input, batch_org_edge_index))
        org_gnn_output = F.relu(self.org_gnn_layer2(org_gnn_output, batch_org_edge_index))

        final_output = torch.cat([org_gnn_output.view(B, -1), gnn_output.view(B, -1)], dim=1)
        return 1 + 9 * torch.sigmoid(self.mlp4(final_output))

    def forward(self, keypoint, feature, knowledge, text_embedding=None, inference=False):
        """Predict one score in (1, 10) per sample.

        Args:
            keypoint: indexed as (B, 17, input_shape_0).
            feature: query tokens for ``ca_2``; presumably (B, L, 4096) — confirm with data.
            knowledge: flattened per sample to ``input_shape_1`` values.
            text_embedding: required unless ``inference``; flattened per sample to 768 values.
            inference: when true, skip the text branch and return only the score.

        Returns:
            With ``inference``: ``pred_score`` of shape (B, 1).
            Otherwise: ``(pred_score, knowledge_vision_feature, knowledge_text_feature)``,
            the last two L2-normalised along the feature axis.

        Raises:
            ValueError: ``inference`` is false and ``text_embedding`` is None.
        """
        B = keypoint.shape[0]
        if inference:
            keypoint_embeds, _, knowledge_vision_feature = self._encode_inputs(keypoint, feature, knowledge)
            return self._graph_score(keypoint_embeds, knowledge_vision_feature)

        if text_embedding is None:
            raise ValueError("text_embedding is required when inference=False")

        new_text_embedding = self.reshape_text(text_embedding.view(B, -1)).view(B, 16, 4096)
        keypoint_embeds, new_knowledge, knowledge_vision_feature = self._encode_inputs(keypoint, feature, knowledge)
        knowledge_text_feature = self.ca_2(new_text_embedding, new_knowledge)
        pred_score = self._graph_score(keypoint_embeds, knowledge_vision_feature)

        knowledge_text_feature = F.normalize(knowledge_text_feature, p=2, dim=-1)
        knowledge_vision_feature = F.normalize(knowledge_vision_feature, p=2, dim=-1)
        return pred_score, knowledge_vision_feature, knowledge_text_feature


def cosine_loss(tensor1, tensor2):
    """Mean cosine distance ``1 - cos(tensor1, tensor2)`` computed along the last axis."""
    similarity = F.cosine_similarity(tensor1, tensor2, dim=-1, eps=1e-8)
    return (1 - similarity).mean()



# The logged config mirrors the hyperparameters the score-model training loop actually
# uses: SGD lr 3e-4, DataLoader batch_size 9 and 7000 epochs.
wandb.init(project="5_5",
           name="model_transformer",
           config={
               "learning_rate": 3e-4,
               "batch_size": 9,
               "epochs": 7000,
               "architecture": "CrossAttention"
           })

import itertools
import torch

import math
import torch
import torch.nn as nn

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence, then apply dropout.

    Args:
        d_model: feature width of the sequences passed to ``forward``; may be odd.
        dropout: dropout probability applied after the addition.
        max_len: longest sequence length supported by ``forward``.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 128):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32) *
            (-math.log(10000.0) / d_model)
        )  # (ceil(d_model / 2),)
        angles = position * div_term

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # Odd columns number floor(d_model / 2); trimming keeps odd d_model from raising a
        # shape mismatch and is a no-op for even d_model.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add the encodings of the first ``x.size(1)`` positions to ``x`` of shape (B, L, D).

        Raises:
            ValueError: if ``L`` exceeds ``max_len``.
        """
        if x.size(1) > self.pe.size(1):
            raise ValueError(f"sequence length {x.size(1)} exceeds max_len {self.pe.size(1)}")
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)

class reward_model(nn.Module):
    """Transformer that scores a (knowledge tree, visual feature) pair with one scalar.

    ``tree`` tokens are 40-d and ``feature`` tokens 4096-d. Both are projected to
    ``d_model``, position-encoded separately, concatenated along the sequence axis and
    encoded jointly. The flattened encoding goes through an MLP head whose input width,
    ``(64 + 16) * d_model``, requires exactly 64 tree tokens and 16 feature tokens.
    """

    def __init__(
        self,
        d_model: int = 256,
        nhead: int = 8,
        num_layers: int = 4,
        dim_feedforward: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        tree_len, feat_len = 64, 16

        # Token projections into the shared model width.
        self.tree_proj = nn.Linear(40, d_model)
        self.feat_proj = nn.Linear(4096, d_model)

        # Each stream gets its own sinusoidal table sized to its fixed length.
        self.tree_posenc = PositionalEncoding(d_model, dropout, max_len=tree_len)
        self.feat_posenc = PositionalEncoding(d_model, dropout, max_len=feat_len)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
                dropout=dropout,
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        self.head = nn.Sequential(
            nn.Linear((tree_len + feat_len) * d_model, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        # Re-initialise every nn.Linear (including those inside the encoder layers).
        for module in self.modules():
            if not isinstance(module, nn.Linear):
                continue
            nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
            nn.init.zeros_(module.bias)

    def forward(self, tree: torch.Tensor, feature: torch.Tensor) -> torch.Tensor:
        """Return a (B, 1) score for ``tree`` (B, 64, 40) and ``feature`` (B, 16, 4096)."""
        tokens = torch.cat(
            [
                self.tree_posenc(self.tree_proj(tree)),
                self.feat_posenc(self.feat_proj(feature)),
            ],
            dim=1,
        )
        encoded = self.encoder(tokens)
        return self.head(encoded.reshape(tree.size(0), -1))



# Every way to fill 4 slots with one of 4 part groups, in itertools.product order
# ((0,0,0,0), (0,0,0,1), ..., (3,3,3,3)): a (256, 4) int64 tensor on `device`.
_COMBOS = torch.cartesian_prod(*(torch.arange(4),) * 4).to(device)

def get_best_knowledge(model, knowledge, feature, keypoint, gd_score, reward_model):
    """Score every recombination of the knowledge's four part groups and blend them.

    The last axis of ``knowledge`` is cut into four 10-wide groups (elbow, shoulder, hip,
    knee — each two 5-wide ranges concatenated). Each row of ``_COMBOS`` picks one group
    per output slot, and the four picks are concatenated back to 40 features, giving one
    candidate per combo. ``reward_model`` scores every candidate against ``feature``.

    Args:
        model, keypoint, gd_score: accepted for call-site compatibility; not read here.
        knowledge: indexed as (B, C, >=40).
        feature: repeated per candidate and reshaped to (B * 256, 16, 4096).
        reward_model: callable mapping (N, C, 40) and (N, 16, 4096) to N scores.

    Returns:
        best_knowledge: (B, C, 40), candidates summed with their sigmoid scores as weights.
        norm_weights: (B, 256), the sigmoid scores normalised to sum to ~1 per sample.
    """
    dev = knowledge.device
    batch, channels, _ = knowledge.shape

    def span(lo, hi):
        return knowledge[:, :, lo:hi]

    groups = [
        torch.cat([span(25, 30), span(35, 40)], dim=2),  # elbow
        torch.cat([span(20, 25), span(30, 35)], dim=2),  # shoulder
        torch.cat([span(10, 15), span(0, 5)], dim=2),    # hip
        torch.cat([span(15, 20), span(5, 10)], dim=2),   # knee
    ]
    parts = torch.stack(groups, dim=1)  # (B, 4, C, 10)

    combos = _COMBOS.to(dev)
    n_cand = combos.shape[0]
    # Slot s of candidate k takes group combos[k, s]: (B, n_cand, C, 40).
    cands = torch.cat([parts[:, combos[:, slot], :, :] for slot in range(4)], dim=-1)

    scores = reward_model(
        cands.view(batch * n_cand, channels, 40),
        feature.unsqueeze(1).repeat(1, n_cand, 1, 1).view(batch * n_cand, 16, 4096),
    ).view(batch, n_cand)

    raw_weights = torch.sigmoid(scores)
    norm_weights = raw_weights / (raw_weights.sum(dim=1, keepdim=True) + 1e-8)

    # NOTE(review): the blend uses the un-normalised sigmoid weights (so its scale grows
    # with the candidate count) while norm_weights is only returned — confirm intended.
    best_knowledge = (cands * raw_weights[:, :, None, None]).sum(dim=1)
    return best_knowledge, norm_weights



 
ngpus = torch.cuda.device_count()
print(f"Found {ngpus} GPU(s).")


model = Mymodel().to(torch.float32)
reward = reward_model().to(torch.float32)


model = model.to(device)
reward = reward.to(device)
model.train()
mseloss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=3e-4)
reward_optim = optim.SGD(reward.parameters(), lr=3e-4)
wandb.watch(model, log="all", log_freq=100)
wandb.watch(reward, log="all", log_freq=100)

dataset = Mydataset("/root/4_29_train.json")
# drop_last: Mymodel's BatchNorm1d layers raise in train mode on a batch of one sample,
# which the final batch would be whenever len(dataset) % 9 == 1.
dataloader = DataLoader(dataset, batch_size=9, shuffle=True, drop_last=True)
test_dataset = Mydataset("/root/test_4_12.json")
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)
# Per-sample validation records: read once instead of at every evaluation.
with open("/root/test_4_12.json") as f:
    test_gd = json.load(f)

loss_item = []
epoch_loss_history = []
num_epoch = 7000
epoch_num = 0
iter_num = 0
epoch_loss = 0
val_epoch_loss = 0
val_iter_num = 0
all_count = 0
# -inf (not 0) guarantees the first evaluation writes a checkpoint, which the export
# stage after training loads unconditionally.
best_rho = float("-inf")


def _run_side(batch, side):
    """Move side ``side`` (0 or 1) of a paired batch to ``device`` and run reward + model.

    Returns (score, pred_score, vision_feature, text_feature, combo_weights).
    """
    keypoint = batch[f'keypoint_{side}'].to(device).to(torch.float32)
    feature = batch[f'feature_{side}'].to(device).to(torch.float32)
    knowledge = batch[f'knowledge_{side}'].to(device).to(torch.float32)
    text = batch[f'text_{side}'].to(device).to(torch.float32)
    # as_tensor: for numeric scores the default collate already yields a tensor, which
    # torch.tensor() would copy with a warning.
    score = torch.as_tensor(batch[f'score_{side}']).to(device).to(torch.float32).view(-1, 1)
    new_knowledge, comb = get_best_knowledge(model, knowledge, feature, keypoint, score, reward_model=reward)
    pred_score, vision_feature, text_feature = model(keypoint, feature, new_knowledge, text)
    return score, pred_score, vision_feature, text_feature, comb


for epoch in range(num_epoch):
    model.train()
    reward.train()
    epoch_num += 1
    for batch in dataloader:
        iter_num += 1
        optimizer.zero_grad()
        reward_optim.zero_grad()

        score_0, pred_score_0, knowledge_vison_feature_0, knowledge_text_feature_0, comb_0 = _run_side(batch, 0)
        score_1, pred_score_1, knowledge_vison_feature_1, knowledge_text_feature_1, comb_1 = _run_side(batch, 1)

        cos_loss = cosine_loss(knowledge_vison_feature_0, knowledge_text_feature_0) + cosine_loss(knowledge_vison_feature_1, knowledge_text_feature_1)
        mse_loss = mseloss(pred_score_0, score_0) + mseloss(pred_score_1, score_1)
        # Symmetric KL between the two samples' combination distributions.
        kl_ab = F.kl_div(comb_0.log(), comb_1, reduction="batchmean")
        kl_ba = F.kl_div(comb_1.log(), comb_0, reduction="batchmean")
        kl_loss = 0.5 * (kl_ab + kl_ba)
        loss = mse_loss + cos_loss + kl_loss
        loss.backward()
        loss_item.append(loss.item())
        epoch_loss += loss.item()
        optimizer.step()
        reward_optim.step()
        print(f'epoch {epoch}, loss {loss.item()}, mse loss {mse_loss.item()}, cos_loss {cos_loss.item()}')
        wandb.log({"batch_loss": loss.item(), 'mse_loss': mse_loss.item(), 'cos_loss': cos_loss.item(), 'kl_loss': kl_loss.item()})

    # Evaluate every 10 epochs; the logged loss is the mean over those epochs' batches.
    if epoch_num == 10:
        model.eval()
        reward.eval()
        epoch_num = 0
        epoch_avg_loss = epoch_loss / max(iter_num, 1)
        epoch_loss = 0
        iter_num = 0
        wandb.log({"epoch_loss": epoch_avg_loss, "epoch": epoch})
        with torch.no_grad():
            gd_data = []
            pred_data = []
            for record in test_gd:
                keypoint = torch.load(record['keypoint']).unsqueeze(0).to(device).to(torch.float32)
                feature = torch.load(record['feature']).unsqueeze(0).to(device).to(torch.float32)
                tree = torch.load(record['knowledge']).unsqueeze(0).to(device).to(torch.float32)
                text = torch.load(record['comment']).unsqueeze(0).to(device).to(torch.float32)
                score = record['score']
                judge_score = torch.tensor(score).to(device).to(torch.float32).view(-1, 1)
                new_knowledge, _ = get_best_knowledge(model, tree, feature, keypoint, judge_score, reward_model=reward)
                gd_data.append(score)
                pred_score = model(keypoint, feature, new_knowledge, text, inference=True)
                pred_data.append(pred_score.item())

            pred = np.array(pred_data)
            gd = np.array(gd_data)

            rho, p = stats.spearmanr(pred, gd)
            # Mean squared error normalised by the squared ground-truth range.
            RL2 = np.power((pred - gd) / (gd.max() - gd.min()), 2).sum() / gd.shape[0]
            if rho > best_rho:
                torch.save(model.state_dict(), "/root/autodl-tmp/best_model_4_28_w_kl_v2.pt")
                torch.save(reward.state_dict(), "/root/autodl-tmp/best_reward_4_28_w_kl_v2.pt")
                best_rho = rho
            print("yes")
        wandb.log({"rho": rho, "epoch": epoch})
        wandb.log({'RL2': RL2})

model = Mymodel().to(torch.float32)
reward = reward_model().to(torch.float32)
# map_location makes loading independent of the device the checkpoint was saved from.
model_checkpoint = torch.load("/root/autodl-tmp/best_model_4_28_w_kl_v2.pt", map_location=device)
reward_checkpoint = torch.load("/root/autodl-tmp/best_reward_4_28_w_kl_v2.pt", map_location=device)

model.load_state_dict(model_checkpoint)
reward.load_state_dict(reward_checkpoint)
# The rebuilt modules start on the CPU while every input below is moved to `device`.
model = model.to(device)
reward = reward.to(device)

model.eval()
reward.eval()


def _export_gnn_features(json_path, save_root):
    """Run every record in ``json_path`` through reward + model and cache its GNN feature.

    For each record, saves ``<save_root>/<name>.pt`` holding the (1, 17 * 1024 * 2) tensor
    Mymodel feeds into ``mlp4`` (org-graph and learned-graph GCN outputs concatenated).
    That width equals the prefix MLPs' ``hidden_dim``.
    """
    with open(json_path) as f:
        records = json.load(f)
    os.makedirs(save_root, exist_ok=True)

    captured = {}

    def _grab_mlp4_input(module, inputs, output):
        captured["gnn"] = inputs[0]

    # model(...) returns (score, vision, text) here, not the GNN feature, so the feature
    # is captured from mlp4's input with a forward hook.
    handle = model.mlp4.register_forward_hook(_grab_mlp4_input)
    try:
        with torch.no_grad():
            for data in tqdm(records):
                name = data['name']
                comment = torch.load(data['comment']).to(device).unsqueeze(0).to(torch.float32)
                keypoint = torch.load(data['keypoint']).to(device).unsqueeze(0).to(torch.float32)
                knowledge = torch.load(data['knowledge']).to(device).unsqueeze(0).to(torch.float32)
                feature = torch.load(data['feature']).to(device).unsqueeze(0).to(torch.float32)

                save_path = os.path.join(save_root, name + ".pt")
                # get_best_knowledge never reads gd_score, so None is passed; it returns
                # (best_knowledge, combo_weights).
                new_knowledge, _ = get_best_knowledge(model, knowledge, feature, keypoint, None, reward_model=reward)
                model(keypoint, feature, new_knowledge, comment)
                torch.save(captured["gnn"].to("cpu"), save_path)
    finally:
        handle.remove()


save_root = "/root/autodl-tmp/new_gnn_ablation"
_export_gnn_features("/root/train_4_8.json", save_root)
# NOTE(review): this second pass (formerly "test_data") reads train_4_12.json while
# validation used test_4_12.json — confirm the intended file.
_export_gnn_features("/root/train_4_12.json", save_root)
        

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "/root/autodl-tmp/qwen_7b/models--Qwen--Qwen2.5-7B/snapshots/d149729398750b98c0af14eb82c78cfe92750796"

# Load straight into float16 rather than loading in the checkpoint dtype and casting the
# device-mapped model afterwards with .to().
qwen_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Frozen backbone: only the prefix MLPs defined below receive gradients.
qwen_model.requires_grad_(False)
qwen_model.eval()

class PrefixMLP(nn.Module):
    """Project a flat embedding into ``prefix_len`` soft-prompt vectors of width ``d_model``.

    One linear layer maps ``hidden_dim`` inputs to ``prefix_len * d_model`` outputs, which
    are reshaped to (N, prefix_len, d_model).
    """

    def __init__(self, hidden_dim, d_model, prefix_len):
        super().__init__()
        self.prefix_len, self.d_model = prefix_len, d_model
        self.layers = nn.Linear(hidden_dim, prefix_len * d_model)

    def forward(self, gnn_embeddings):
        projected = self.layers(gnn_embeddings)
        return projected.view(-1, self.prefix_len, self.d_model)
    
d_model = qwen_model.config.hidden_size
# Must equal the flattened length of each cached feature tensor fed to the prefix MLPs.
hidden_dim = 17 * 1024 * 2
prefix_len = 20
prefix_mlp_main = PrefixMLP(hidden_dim, d_model, prefix_len).to(torch.float16)
prefix_mlp_res = PrefixMLP(hidden_dim, d_model, prefix_len).to(torch.float16)

device = torch.device("cuda")
prefix_mlp_main = prefix_mlp_main.to(device)
prefix_mlp_res = prefix_mlp_res.to(device)

# NOTE(review): these weights are stored and updated in float16 by plain SGD; small
# updates may round away. Consider fp32 master weights with autocast.
optim_main = optim.SGD(prefix_mlp_main.parameters(), lr=3e-4)
optim_res = optim.SGD(prefix_mlp_res.parameters(), lr=3e-4)

# Close the run opened for the score-model stage so this init starts its own run
# rather than possibly returning the still-active one.
wandb.finish()
# The logged config mirrors the caption stage's actual SGD lr, DataLoader batch size and epoch count.
wandb.init(project="text_5_4_ablation",
           name="text",
           config={
               "learning_rate": 3e-4,
               "batch_size": 2,
               "epochs": 8,
               "architecture": "CrossAttention"
           })
wandb.watch(prefix_mlp_main, log="all", log_freq=100)
wandb.watch(prefix_mlp_res, log="all", log_freq=100)
    
import torch
import os
import json
from torch.utils.data import Dataset, DataLoader

class Mydataset(Dataset):
    """Caption dataset pairing two cached feature tensors with a reference description.

    Each JSON record must contain ``gd_text``, ``video_0_name`` and ``video_1_name``; the
    tensors are read from ``<save_root>/<name>.pt``.

    Args:
        json_path: path of the JSON list of records.
        save_root: directory holding the cached ``.pt`` tensors.
    """

    def __init__(self, json_path, save_root="/root/autodl-tmp/new_gnn_ablation"):
        super().__init__()
        # Context manager so the index file handle is closed right away.
        with open(json_path, "r") as f:
            self.data = json.load(f)
        self.save_root = save_root

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return (tensor_0, tensor_1, ground_truth_text) for record ``index``."""
        sample = self.data[index]
        ground_truth_describe = sample['gd_text']

        name_0 = sample['video_0_name']
        name_1 = sample['video_1_name']

        tensor_0 = torch.load(os.path.join(self.save_root, name_0 + ".pt"))
        tensor_1 = torch.load(os.path.join(self.save_root, name_1 + ".pt"))

        return tensor_0, tensor_1, ground_truth_describe


def collate_fn(batch):
    """Flatten each sample's two tensors and stack them row-wise.

    Args:
        batch: sequence of (tensor_0, tensor_1, gd_text) triples.

    Returns:
        (tensor_0_batch, tensor_1_batch, gd_text_list): two (batch, numel) tensors and the
        texts as a plain list.
    """
    tensor_0_batch = torch.stack([first.view(-1) for first, _, _ in batch], dim=0)
    tensor_1_batch = torch.stack([second.view(-1) for _, second, _ in batch], dim=0)
    gd_text_list = [text for _, _, text in batch]
    return tensor_0_batch, tensor_1_batch, gd_text_list

# Caption training data: pairs of cached feature tensors plus reference text.
train_dataset = Mydataset(json_path="/root/all_gd.json")
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
)

epoch_loss = 0
num_epochs = 8
import math
scale = math.sqrt(d_model)


def l2_norm(x):
    """Rescale every vector along the last axis to L2 norm ``scale`` (sqrt of d_model)."""
    unit = F.normalize(x, p=2, dim=-1)
    return unit * scale

# Prompts are identical for every sample, so they are built once.
system_prompt = "You are Qwen, a helpful assistant created by Alibaba Cloud."
user_prompt = (
    "Please describe the action difference between player 1 and player 2 given their respective videos. "
    "The description should be no more than one sentence and should include specific comparisons on the two players "
    "regarding which part does player 1 do better than player 2, and which part player 2 do better than player 1."
)

for epoch in range(num_epochs):
    prefix_mlp_main.train()
    prefix_mlp_res.train()
    epoch_loss = 0
    iter_num = 0
    for tensor_0_batch, tensor_1_batch, gd_text_list in train_dataloader:
        iter_num += 1
        tensor_0_batch = tensor_0_batch.to(device).to(torch.float16)
        tensor_1_batch = tensor_1_batch.to(device).to(torch.float16)

        # Soft prompt: [sample 0 | 0 minus 1 | 1 minus 0 | sample 1], each prefix_len
        # vectors rescaled by l2_norm.
        tensor_res = tensor_0_batch - tensor_1_batch
        tensor_res_rev = tensor_1_batch - tensor_0_batch
        prefix_0 = l2_norm(prefix_mlp_main(tensor_0_batch))
        prefix_1 = l2_norm(prefix_mlp_main(tensor_1_batch))
        prefix_res = l2_norm(prefix_mlp_res(tensor_res))
        prefix_res_rev = l2_norm(prefix_mlp_res(tensor_res_rev))

        prefix_embeds = torch.concat([prefix_0, prefix_res, prefix_res_rev, prefix_1], dim=1)

        messages_list = [
            tokenizer.apply_chat_template(
                [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                    {"role": "assistant", "content": gd_text},
                ],
                tokenize=False,
                # The reference answer is already the last turn; a generation prompt would
                # append an extra, empty assistant header after it and train on that.
                add_generation_prompt=False,
            )
            for gd_text in gd_text_list
        ]

        tokens = tokenizer(messages_list, return_tensors='pt', padding=True)
        input_ids = tokens["input_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)

        with torch.no_grad():
            input_embeds = qwen_model.get_input_embeddings()(input_ids)

        combined_embeds = torch.cat([prefix_embeds, input_embeds], dim=1)

        batch_size = input_ids.size(0)
        # -100 is ignored by the loss: prefix positions carry no target token, and padded
        # positions must not be learned as targets.
        prefix_mask = torch.full((batch_size, 4 * prefix_len), -100, dtype=torch.long, device=device)
        text_labels = input_ids.masked_fill(attention_mask == 0, -100)
        new_labels = torch.cat([prefix_mask, text_labels], dim=1)
        # Prefix positions are always attended; padded text positions are not.
        combined_attention_mask = torch.cat([torch.ones_like(prefix_mask), attention_mask], dim=1)

        outputs = qwen_model(
            inputs_embeds=combined_embeds,
            attention_mask=combined_attention_mask,
            labels=new_labels
        )
        loss = outputs.loss

        optim_main.zero_grad()
        optim_res.zero_grad()
        loss.backward()
        wandb.log({"batch_loss": loss.item()})
        epoch_loss += loss.item()
        optim_main.step()
        optim_res.step()
        print(f'epoch {epoch}, loss {loss.item()}')

    avg_epoch_loss = epoch_loss / max(iter_num, 1)
    wandb.log({"epoch_loss": avg_epoch_loss, "epoch": epoch})


torch.save(prefix_mlp_main.state_dict(), "/root/autodl-tmp/predix_main_5_1_ablation.pt")
torch.save(prefix_mlp_res.state_dict(), "/root/autodl-tmp/predix_res_5_1_ablation.pt")
wandb.finish()
