import torch
import torch.nn as nn
from layers.QANLayer import QAN as  DRQC
class Model(nn.Module):
    """MLP forecaster with a learnable periodic query and a DRQC quantum channel aggregator.

    ``forward`` pipeline:
      1. optional instance normalisation over the time axis (``use_revin``);
      2. an additive "channel information" term built from a learnable
         periodic query table (``temporalQuery``, shape ``(cycle, enc_in)``)
         and/or the ``DRQC`` quantum layer;
      3. a linear projection ``seq_len -> d_model``;
      4. a residual two-layer GELU MLP;
      5. a dropout + linear head ``d_model -> pred_len``.

    Required ``configs`` attributes: ``seq_len``, ``pred_len``, ``enc_in``,
    ``cycle``, ``model_type``, ``d_model``, ``dropout``, ``use_revin``,
    ``num_qubits``, ``num_layers``.

    Optional ``configs`` attributes (read with ``getattr``; default in parens):
    ``use_lpv`` (True), ``channel_aggre`` (True), ``qiskit_remote`` (True),
    ``ibm_channel`` ("ibm_quantum"), ``ibm_instance`` ("PQ-Net"),
    ``ibm_backend_name`` ("ibm_brisbane"), ``ibm_min_qubits`` (``num_qubits``),
    ``q_shots`` (1024), ``q_optimization_level`` (1), ``q_resilience_level`` (1).
    """

    def __init__(self, configs):
        super(Model, self).__init__()

        self.seq_len = configs.seq_len
        self.pred_len = configs.pred_len
        self.enc_in = configs.enc_in
        self.cycle_len = configs.cycle
        self.model_type = configs.model_type
        self.d_model = configs.d_model
        self.dropout = configs.dropout
        self.use_revin = configs.use_revin
        self.num_qubits = configs.num_qubits
        self.num_layers = configs.num_layers

        # Ablation switches. They used to be hard-coded to True; they can now be
        # overridden from configs, and still default to True so existing runs
        # are unchanged.
        self.use_lpv = getattr(configs, "use_lpv", True)
        self.channel_aggre = getattr(configs, "channel_aggre", True)

        # Optional IBM / Qiskit remote-execution settings forwarded to DRQC.
        # qiskit_remote was previously forced to True despite being documented
        # as configurable; True remains the default.
        self.qiskit_remote = getattr(configs, "qiskit_remote", True)
        self.ibm_channel = getattr(configs, "ibm_channel", "ibm_quantum")
        self.ibm_instance = getattr(configs, "ibm_instance", "PQ-Net")
        self.ibm_backend_name = getattr(configs, "ibm_backend_name", "ibm_brisbane")
        self.ibm_min_qubits = getattr(configs, "ibm_min_qubits", self.num_qubits)
        self.q_shots = getattr(configs, "q_shots", 1024)
        self.q_optimization_level = getattr(configs, "q_optimization_level", 1)
        self.q_resilience_level = getattr(configs, "q_resilience_level", 1)

        if self.use_lpv:
            # One learnable vector per position in the cycle; zero-initialised.
            self.temporalQuery = torch.nn.Parameter(torch.zeros(self.cycle_len, self.enc_in), requires_grad=True)

        if self.channel_aggre:
            # DRQC quantum layer (replaces the earlier MultiheadAttention
            # aggregator), configured for optional remote execution.
            self.channelAggregator = DRQC(
                input_dim=self.seq_len,
                output_dim=self.seq_len,
                hidden_dim=self.num_qubits,
                num_layers=self.num_layers,
                use_remote=self.qiskit_remote,
                ibm_channel=self.ibm_channel,
                ibm_instance=self.ibm_instance,
                ibm_backend_name=self.ibm_backend_name,
                ibm_min_qubits=self.ibm_min_qubits,
                shots=self.q_shots,
                optimization_level=self.q_optimization_level,
                resilience_level=self.q_resilience_level,
            )

        self.input_proj = nn.Linear(self.seq_len, self.d_model)

        self.model = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.GELU(),
            nn.Linear(self.d_model, self.d_model),
            nn.GELU(),
        )

        self.output_proj = nn.Sequential(
            nn.Dropout(self.dropout),
            nn.Linear(self.d_model, self.pred_len)
        )

    def _channel_information(self, x_input, cycle_index):
        """Return the additive term mixed into ``x_input`` before ``input_proj``.

        Args:
            x_input: normalised input laid out as (batch, channels, seq_len).
            cycle_index: one phase offset per batch element.

        Returns:
            A tensor broadcastable to ``x_input``, or the scalar ``0`` when both
            ``use_lpv`` and ``channel_aggre`` are disabled. The DRQC output is
            assumed to keep ``x_input``'s shape (input_dim == output_dim ==
            seq_len); TODO confirm against layers.QANLayer.
        """
        if self.use_lpv:
            # Periodic lookup: time step t of sample i reads row
            # (cycle_index[i] + t) % cycle_len of the query table.
            offsets = torch.arange(self.seq_len, device=cycle_index.device)
            gather_index = (cycle_index.view(-1, 1) + offsets.view(1, -1)) % self.cycle_len
            query_input = self.temporalQuery[gather_index].permute(0, 2, 1)  # (b, c, s)
            if self.channel_aggre:
                # Only the query is aggregated; x_input is not concatenated in.
                return self.channelAggregator(query_input)
            return query_input
        if self.channel_aggre:
            # Without the temporal query, aggregate the input series directly.
            return self.channelAggregator(x_input)
        return 0

    def forward(self, x, cycle_index):
        """Forecast ``pred_len`` future steps for every channel.

        Args:
            x: input window laid out as (batch, seq_len, channels).
            cycle_index: per-sample phase offset into ``temporalQuery``;
                flattened with ``view(-1, 1)``, so one integer per batch element
                is expected (TODO confirm it is an integer/long tensor).

        Returns:
            Tensor of shape (batch, pred_len, channels). It is de-normalised
            back to the input scale when ``use_revin`` is set.
        """
        if self.use_revin:
            # Instance normalisation over the time axis. The epsilon keeps the
            # divisor non-zero for constant series.
            seq_mean = torch.mean(x, dim=1, keepdim=True)
            seq_var = torch.var(x, dim=1, keepdim=True) + 1e-5
            x = (x - seq_mean) / torch.sqrt(seq_var)

        # (b, s, c) -> (b, c, s): each channel becomes one length-seq_len series.
        x_input = x.permute(0, 2, 1)

        channel_information = self._channel_information(x_input, cycle_index)

        projected = self.input_proj(x_input + channel_information)
        hidden = self.model(projected)
        # Residual connection around the MLP, then the forecasting head.
        output = self.output_proj(hidden + projected).permute(0, 2, 1)

        if self.use_revin:
            # Undo the instance normalisation.
            output = output * torch.sqrt(seq_var) + seq_mean

        return output


# import torch
# import torch.nn as nn

# class Model(nn.Module):
#     def __init__(self, configs):
#         super(Model, self).__init__()

#         self.seq_len = configs.seq_len
#         self.pred_len = configs.pred_len
#         self.enc_in = configs.enc_in
#         self.cycle_len = configs.cycle
#         self.model_type = configs.model_type
#         self.d_model = configs.d_model
#         self.dropout = configs.dropout
#         self.use_revin = configs.use_revin

#         self.use_tq = True  # ablation parameter, default: True
#         self.channel_aggre = True   # ablation parameter, default: True

#         if self.use_tq:
#             self.temporalQuery = torch.nn.Parameter(torch.zeros(self.cycle_len, self.enc_in), requires_grad=True)

#         if self.channel_aggre:
#             self.channelAggregator = nn.MultiheadAttention(embed_dim=self.seq_len, num_heads=4, batch_first=True, dropout=0.5)

#         self.input_proj = nn.Linear(self.seq_len, self.d_model)

#         self.model = nn.Sequential(
#             nn.Linear(self.d_model, self.d_model),
#             nn.GELU(),
#             nn.Linear(self.d_model, self.d_model),
#             nn.GELU(),
#         )

#         self.output_proj = nn.Sequential(
#             nn.Dropout(self.dropout),
#             nn.Linear(self.d_model, self.pred_len)
#         )


#     def forward(self, x, cycle_index):

#         # instance norm
#         if self.use_revin:
#             seq_mean = torch.mean(x, dim=1, keepdim=True)
#             seq_var = torch.var(x, dim=1, keepdim=True) + 1e-5
#             x = (x - seq_mean) / torch.sqrt(seq_var)

#         # b,s,c -> b,c,s
#         x_input = x.permute(0, 2, 1)

#         if self.use_tq:
#             gather_index = (cycle_index.view(-1, 1) + torch.arange(self.seq_len, device=cycle_index.device).view(1, -1)) % self.cycle_len
#             query_input = self.temporalQuery[gather_index].permute(0, 2, 1)  # (b, c, s)
#             if self.channel_aggre:
#                 channel_information = self.channelAggregator(query=query_input, key=x_input, value=x_input)[0]
#             else:
#                 channel_information = query_input
#         else:
#             if self.channel_aggre:
#                 channel_information = self.channelAggregator(query=x_input, key=x_input, value=x_input)[0]
#             else:
#                 channel_information = 0

#         input = self.input_proj(x_input+channel_information)

#         hidden = self.model(input)

#         output = self.output_proj(hidden+input).permute(0, 2, 1)

#         # instance denorm
#         if self.use_revin:
#             output = output * torch.sqrt(seq_var) + seq_mean

#         return output

