import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from m_ctc import ctc_cost  # make sure this function can run on CUDA
from torch.autograd import Variable

# Pick the compute device: CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class RelationModel(nn.Module):
    """Score a vector against a fixed random pairwise relation matrix.

    The matrix is drawn once at construction time from ``U[0, 1) * scale`` and
    is never trained. It is registered as a buffer (not a plain attribute) so
    that ``.to()``, ``.double()``, ``state_dict()`` and friends handle it
    together with the rest of the module.

    Args:
        num_items: Side length of the square relation matrix, i.e. the length
            of the vectors accepted by ``forward``.
        scale: Factor applied to the uniform random entries.
    """

    def __init__(self, num_items: int = 100, scale: float = 10.0):
        super().__init__()
        # Created directly on the module-level ``device`` so the random draw
        # comes from the same generator as before for a given seed.
        relation_matrix = torch.rand(num_items, num_items, device=device) * scale
        self.register_buffer("relation_matrix", relation_matrix)

    def forward(self, x):
        """Return the mean of ``(x @ relation_matrix) * x``.

        For a 1-D ``x`` this is the quadratic form ``x^T M x`` divided by the
        number of elements of ``x``.
        """
        score = torch.matmul(x, self.relation_matrix)
        final_score = score * x
        return final_score.mean()

class BS_Layer_Gumble(nn.Module):
    """Noisy sigmoid relaxation of a vector of selection weights.

    The input is first limited to ``[0, 1]`` with two ReLUs, converted to a
    logit, perturbed with random noise and squashed back through a sigmoid
    whose sharpness is controlled by ``T``.
    """

    def __init__(self):
        super().__init__()
        self.relu_1 = nn.ReLU()
        self.relu_2 = nn.ReLU()

    def forward(self, X, T):
        """Return ``sigmoid((logit(clip(X)) + noise) * T)``.

        Args:
            X: Tensor of raw weights; values outside ``[0, 1]`` are clipped.
            T: Multiplier applied inside the sigmoid; larger values push the
                output closer to 0 or 1.

        Returns:
            A tensor with the same shape as ``X``.
        """
        tiny = 1e-10  # keeps every log() argument strictly positive

        # relu(1 - X) is zero for X >= 1, so 1 - relu(1 - X) caps at 1;
        # the outer relu then floors the result at 0.
        above_cap = self.relu_1(1 - X)
        clipped = self.relu_2(1 - above_cap)

        # Two independent uniform draws (the draw order matters for seeding).
        # -log(log(u_a) / log(u_b)) equals the difference of two Gumbel(0, 1)
        # samples, up to the ``tiny`` offsets.
        u_a = torch.rand_like(clipped)
        u_b = torch.rand_like(clipped)
        log_ratio = torch.log(u_a + tiny) / torch.log(u_b + tiny)
        perturbation = -torch.log(log_ratio + tiny)

        logit = torch.log(clipped + tiny) - torch.log(1.0 - clipped + tiny)
        return torch.sigmoid((logit + perturbation) * T)

def Loss_EM(prob, select_band=2):
    """Return ``ctc_cost`` for selecting ``select_band`` entries out of ``prob``.

    ``prob`` must be a 1-D tensor (the reshaping below only works for that
    rank); its length is used as the input sequence length. Each entry ``p``
    is expanded into the pair ``(1 - p, p)``, a small constant is added and
    the log is taken, giving a tensor of shape ``(len(prob), 1, 2)``. The
    target passed to ``ctc_cost`` is a vector of ``select_band`` ones.

    All tensors are placed on the module-level ``device`` instead of calling
    ``.cuda()`` unconditionally, so the function no longer fails outright on
    hosts without CUDA.
    NOTE(review): ``ctc_cost`` comes from ``m_ctc`` and is not visible here;
    confirm that it supports CPU tensors before relying on the CPU path.

    Args:
        prob: 1-D tensor of per-entry selection probabilities.
        select_band: Number of entries in the all-ones target sequence.

    Returns:
        Whatever ``ctc_cost`` returns for these inputs.
    """
    # (N,) -> (1, N, 1, 1); band_number is N.
    prob = prob.unsqueeze(0).unsqueeze(2).unsqueeze(3).to(device)
    band_number = prob.shape[1]
    prob = prob.squeeze(2)  # (1, N, 1)

    # Pair every probability with its complement along the last axis, move
    # the sequence axis first -> (N, 1, 2), then take a numerically safe log.
    log_prob = (torch.cat((1 - prob, prob), dim=2).permute(1, 0, 2) + 0.00001).log()

    token = torch.ones(select_band, device=device)
    sizes = torch.tensor([band_number], dtype=torch.int32, device=device)
    target_sizes = torch.tensor([select_band], dtype=torch.int32, device=device)
    return ctc_cost(log_prob, token, sizes, target_sizes)
    
def custom_loss(output, x, select_num):
    """Combine a selection-count penalty with a reward on ``output``.

    The result is ``10 * |sum(x) - select_num| - output``: minimising it
    drives the total mass of ``x`` toward ``select_num`` while pushing
    ``output`` up.

    Args:
        output: Value to maximise (enters the loss with a negative sign).
        x: Tensor whose sum should approach ``select_num``.
        select_num: Target value for ``x.sum()``.

    Returns:
        The combined loss.
    """
    count_gap = x.sum() - select_num
    # The gap is a single value, so mean() does not change it.
    count_penalty = count_gap.abs().mean()
    reward_term = -output
    return 10 * count_penalty + reward_term

# Driver: for each seed, optimise a relaxed selection vector against a fresh
# random RelationModel, binarise the result and record the model's score.
final_outputs = []
seed_values = range(42, 82)  # 40 seeds: 42 through 81 inclusive

select_num = 10          # number of entries that should end up selected
m_epoch = 2000           # maximum optimisation steps per seed
active_threshold = 0.01  # an entry counts as "active" above this value

for seed in seed_values:
    # Seed every RNG in play so each run is reproducible.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # The selection vector itself is the only optimised quantity.
    input_vector = torch.full((100,), 0.5, requires_grad=True, dtype=torch.float32, device=device)
    model = RelationModel().to(device)
    model_ly = BS_Layer_Gumble().to(device)
    optimizer = optim.Adam([input_vector], lr=0.01)

    model.train()
    for epoch in range(m_epoch):
        optimizer.zero_grad()
        # Sharpness schedule: starts at 1 for epoch 0 and grows as training
        # proceeds (the denominator shrinks toward 0.001).
        T = 1 / ((1 - 0.001) * (1 - epoch / m_epoch) + 0.001)
        weight = model_ly(input_vector, T)
        output = model(weight)

        loss = custom_loss(output, input_vector, select_num)
        loss.backward()
        optimizer.step()

        # Keep the selection vector inside [0, 1] after the unconstrained step.
        with torch.no_grad():
            input_vector.clamp_(0, 1)

        if epoch % 50 == 0:
            print(f'Epoch {epoch}, Loss: {loss.item()}, Output: {output.item()}')

        # Stop early once exactly ``select_num`` entries are active.
        non_zero_indices = (input_vector > active_threshold).nonzero(as_tuple=True)[0]
        if len(non_zero_indices) == select_num:
            print(
                f"Optimization stopped as exactly {select_num} elements "
                f"exceed {active_threshold}."
            )
            break

    # Binarise: keep the ``select_num`` largest entries when more than that
    # many are non-zero, otherwise keep every non-zero entry (which may be
    # fewer than ``select_num``).
    optimized_vector = input_vector.detach().clone()
    non_zero_indices = optimized_vector.nonzero(as_tuple=True)[0]
    new_vector = torch.zeros_like(optimized_vector)
    if len(non_zero_indices) > select_num:
        _, top_indices = torch.topk(optimized_vector, select_num)
        new_vector[top_indices] = 1
    else:
        new_vector[non_zero_indices] = 1

    final_output = model(new_vector).item()
    final_outputs.append(final_output)

# Save the outputs to a text file, one score per line.
with open('model_outputs_gumble.txt', 'w', encoding='utf-8') as f:
    f.writelines(f"{value}\n" for value in final_outputs)

print("Final outputs for different seeds saved to 'model_outputs_gumble.txt'.")
