import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from m_ctc import ctc_cost  # 确保这个函数可以在CUDA上运行
from torch.autograd import Variable

# Select the compute device once for the whole script: the default CUDA
# device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class RelationModel(nn.Module):
    """Quadratic "relation" score over a selection vector.

    Holds a fixed random ``size x size`` matrix ``M`` whose entries are
    ``torch.rand`` draws multiplied by ``scale`` (i.e. uniform on
    ``[0, scale)``), and scores an input ``x`` as ``mean((x @ M) * x)``.

    Args:
        size: Side length of the square relation matrix; the last dimension
            of the input must equal it. Defaults to 100.
        scale: Multiplier applied to the uniform draws. Defaults to 10.
    """

    def __init__(self, size=100, scale=10):
        super().__init__()
        # Registered as a buffer (not a Parameter): the matrix is fixed and
        # receives no gradients, but unlike a plain attribute it now follows
        # ``.to(...)``/``.cuda()``/``.cpu()`` and is included in state_dict().
        # It is still drawn directly on ``device`` so the random values for a
        # given seed come from the same generator as before.
        self.register_buffer(
            "relation_matrix", torch.rand(size, size, device=device) * scale
        )

    def forward(self, x):
        """Return ``mean((x @ relation_matrix) * x)`` as a 0-d tensor."""
        score = torch.matmul(x, self.relation_matrix)
        return (score * x).mean()

def Loss_EM(prob, select_band=2):
    """CTC-based penalty comparing ``prob`` with a target of ``select_band`` ones.

    Each entry ``p`` of the 1-D tensor ``prob`` becomes one CTC time step with
    a two-class distribution ``[1 - p, p]`` (class 0 carries ``1 - p``,
    class 1 carries ``p``). The resulting log-probabilities are scored against
    a target sequence of ``select_band`` ones by ``ctc_cost`` from ``m_ctc``.

    Args:
        prob: 1-D tensor of per-band values, expected to lie in ``[0, 1]``.
        select_band: Length of the all-ones target sequence. Defaults to 2.

    Returns:
        Whatever ``ctc_cost`` returns for this single-sequence batch.
    """
    # Follow the module-level ``device`` rather than a hard-coded ``.cuda()``:
    # identical on CUDA machines, and no longer crashes on CPU-only ones.
    # NOTE(review): CPU execution also requires ``ctc_cost`` to accept CPU
    # tensors -- confirm against m_ctc.
    prob = prob.to(device)
    # (T,) -> (T, 1, 1): T time steps, a batch of a single sequence.
    prob = prob.reshape(-1, 1, 1)
    band_number = prob.shape[0]
    # (T, 1, 2) log-probabilities; the epsilon keeps log() finite when an
    # entry is exactly 0 or 1 (e.g. after the caller clamps to [0, 1]).
    log_probs = (torch.cat((1 - prob, prob), dim=2) + 0.00001).log()
    token = torch.ones(select_band, device=device)
    sizes = torch.tensor([band_number], dtype=torch.int32, device=device)
    target_sizes = torch.tensor([select_band], dtype=torch.int32, device=device)
    return ctc_cost(log_probs, token, sizes, target_sizes)

def custom_loss(output, x, select_num, em_weight=5):
    """Combine the CTC selection penalty with the negated relation score.

    Args:
        output: Score to be maximized (in this script, the model's output for
            ``x``); it enters the loss with a minus sign.
        x: Selection vector passed to ``Loss_EM``.
        select_num: Target number of selected bands, forwarded to ``Loss_EM``
            as its ``select_band``.
        em_weight: Weight of the ``Loss_EM`` term. Defaults to 5.

    Returns:
        ``em_weight * Loss_EM(x, select_num) - output``.
    """
    em_loss = Loss_EM(x, select_num)
    return em_weight * em_loss - output

def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def _binarize_top_k(vector, k):
    """Return a 0/1 mask over the 1-D ``vector`` selecting at most ``k`` entries.

    If ``vector`` has more than ``k`` non-zero entries, the ``k`` largest
    (per ``torch.topk``) are set to 1; otherwise every non-zero entry is.
    """
    mask = torch.zeros_like(vector)
    nonzero_idx = vector.nonzero(as_tuple=True)[0]
    if len(nonzero_idx) > k:
        _, chosen_idx = torch.topk(vector, k)
    else:
        chosen_idx = nonzero_idx
    mask[chosen_idx] = 1
    return mask


# cuDNN flags are process-wide, so set them once instead of once per seed.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

final_outputs = []
seed_values = range(42, 82)  # 40 different seeds: 42 through 81
select_num = 10  # number of bands kept in the final binary selection

for seed in seed_values:
    _seed_everything(seed)

    # Every run starts from the same uniform 0.5 vector; only the random
    # relation matrix differs between seeds.
    input_vector = torch.full(
        (100,), 0.5, requires_grad=True, dtype=torch.float32, device=device
    )

    model = RelationModel().to(device)
    model.train()
    # Only the input vector is optimized; the model holds no trainable weights.
    optimizer = optim.Adam([input_vector], lr=0.01)

    for epoch in range(300):
        optimizer.zero_grad()
        output = model(input_vector)
        loss = custom_loss(output, input_vector, select_num)
        loss.backward()
        optimizer.step()

        # Project back onto the [0, 1] box after every step.
        with torch.no_grad():
            input_vector.clamp_(0, 1)

        if epoch % 50 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item()}, Output: {output.item()}')

    new_vector = _binarize_top_k(input_vector.detach(), select_num)
    # Evaluation only: no autograd graph is needed for the final score.
    with torch.no_grad():
        final_output = model(new_vector).item()
    final_outputs.append(final_output)

# Save the outputs to a text file, one value per line.
with open('model_outputs_EM.txt', 'w', encoding='utf-8') as f:
    f.writelines(f"{value}\n" for value in final_outputs)

print("Final outputs for different seeds saved to 'model_outputs_EM.txt'.")
