from gurobipy import GRB
import torch
import torch.nn as nn
import torch.optim as optim
from openpto.method.Models.abcOptModel import optModel
from openpto.method.utils_method import do_reduction, to_tensor,to_array
import numpy as np
import torch
from torch.distributions import Categorical
from typing import List, Tuple, Dict, Callable, Optional

class PolicyNetwork(nn.Module):
    """Policy network that outputs a selection probability for every item.

    The model is a three-layer fully connected network: ReLU follows the
    first two layers and a sigmoid follows the last one, so each output
    component lies in [0, 1].
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        """Build the three linear layers.

        Args:
            state_dim: Size of the input (state) vector.
            action_dim: Number of outputs, one probability per item.
            hidden_dim: Width of both hidden layers.
        """
        super().__init__()
        self.fc1 = nn.Linear(state_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # The last layer emits one logit per item; forward() squashes it.
        self.fc3 = nn.Linear(hidden_dim, action_dim)

    def forward(self, x):
        """Return per-item selection probabilities for state(s) ``x``."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        logits = self.fc3(hidden)
        # Sigmoid keeps every component inside [0, 1].
        return logits.sigmoid()


class RewardNetwork(nn.Module):
    """Reward network that scores a (state, action) pair.

    The state and action are concatenated along the last dimension and fed
    through a three-layer fully connected network with ReLU activations on
    the two hidden layers. The output is a single unbounded value.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=64):
        """Build the three linear layers.

        Args:
            state_dim: Size of the state vector.
            action_dim: Size of the action vector.
            hidden_dim: Width of both hidden layers.
        """
        super().__init__()
        joint_dim = state_dim + action_dim
        self.fc1 = nn.Linear(joint_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, state, action):
        """Return the scalar score of ``action`` taken in ``state``."""
        joint = torch.cat((state, action), dim=-1)
        hidden = nn.functional.relu(self.fc1(joint))
        hidden = nn.functional.relu(self.fc2(hidden))
        return self.fc3(hidden)


class OnlineIRL:
    """Online inverse reinforcement learning with vector-valued expert actions.

    A policy network proposes per-item selection probabilities and a reward
    network scores (state, action) pairs. Each expert demonstration triggers
    one reward-network step followed by one policy-network step, and the
    demonstration is kept in a bounded evaluation buffer.
    """

    def __init__(self, state_dim, action_dim, lr=1e-3, eval_buffer_size=100):
        """Create the networks, their Adam optimizers and the buffers.

        Args:
            state_dim: Size of the state vector.
            action_dim: Size of the action (item selection) vector.
            lr: Learning rate shared by both Adam optimizers.
            eval_buffer_size: Maximum number of demonstrations retained in
                the evaluation buffers; the oldest entries are dropped first.

        Raises:
            ValueError: If ``eval_buffer_size`` is negative.
        """
        if eval_buffer_size < 0:
            raise ValueError("eval_buffer_size must be non-negative")

        self.policy = PolicyNetwork(state_dim, action_dim)
        self.reward = RewardNetwork(state_dim, action_dim)

        self.policy_optimizer = optim.Adam(self.policy.parameters(), lr=lr)
        self.reward_optimizer = optim.Adam(self.reward.parameters(), lr=lr)

        self.state_dim = state_dim
        self.action_dim = action_dim

        # Evaluation buffers (parallel lists of states and expert actions).
        self.eval_buffer_size = eval_buffer_size
        self.eval_buffer_states = []
        self.eval_buffer_actions = []

    @staticmethod
    def _as_float_tensor(values) -> torch.Tensor:
        """Convert array-like ``values`` to a float32 tensor."""
        # torch.as_tensor (unlike the legacy torch.FloatTensor constructor)
        # never interprets an int as a size and also accepts tensors.
        return torch.as_tensor(values, dtype=torch.float32)

    def _sample_action(self, state_tensor: torch.Tensor) -> torch.Tensor:
        """Sample a 0/1 selection tensor from the current policy."""
        with torch.no_grad():
            probs = self.policy(state_tensor)
        # One independent uniform draw per probability entry, so batched
        # states do not share the same random numbers.
        return (torch.rand_like(probs) < probs).float()

    def get_action(self, state) -> np.ndarray:
        """Sample an item-selection vector from the current policy.

        Args:
            state: Array-like state (a single state or a batch of states).

        Returns:
            A float32 NumPy array of 0/1 entries with the same shape as the
            policy output; each item is selected with its predicted
            probability.
        """
        return self._sample_action(self._as_float_tensor(state)).numpy()

    def get_deterministic_action(self, state) -> np.ndarray:
        """Return the deterministic item-selection vector (for evaluation).

        Args:
            state: Array-like state (a single state or a batch of states).

        Returns:
            A float32 NumPy array of 0/1 entries; an item is selected when
            its predicted probability is at least 0.5.
        """
        state_tensor = self._as_float_tensor(state)
        with torch.no_grad():
            probs = self.policy(state_tensor)
        return (probs >= 0.5).float().numpy()

    def train_with_expert_demonstration(self, state, expert_action):
        """Run one training step from an expert demonstration.

        The reward network is updated first, then the policy network, and
        finally the demonstration is appended to the evaluation buffers.

        Args:
            state: Array-like state (a single state or a batch of states).
            expert_action: Array-like 0/1 expert selection matching ``state``.
        """
        state_tensor = self._as_float_tensor(state)
        expert_action_tensor = self._as_float_tensor(expert_action)

        # 1. Update the reward network.
        self._update_reward_network(state_tensor, expert_action_tensor)

        # 2. Update the policy network.
        self._update_policy_network(state_tensor, expert_action_tensor)

        # 3. Store the demonstration (as passed in) for later evaluation.
        self.eval_buffer_states.append(state)
        self.eval_buffer_actions.append(expert_action)

        # Keep the buffers bounded by discarding the oldest entries.
        overflow = len(self.eval_buffer_states) - self.eval_buffer_size
        if overflow > 0:
            del self.eval_buffer_states[:overflow]
            del self.eval_buffer_actions[:overflow]

    def _update_reward_network(self, state_tensor, expert_action_tensor):
        """Train the reward network to separate expert from policy actions.

        The loss is a binary logistic loss that pushes the score of the
        expert action up and the score of a freshly sampled policy action
        down.
        """
        # Action sampled from the current policy (no gradient needed).
        policy_action_tensor = self._sample_action(state_tensor)

        expert_reward = self.reward(state_tensor, expert_action_tensor)
        policy_reward = self.reward(state_tensor, policy_action_tensor)

        # -log(sigmoid(e)) - log(1 - sigmoid(p)), written with logsigmoid so
        # saturated scores do not produce log(0) = -inf and NaN gradients.
        # log(1 - sigmoid(p)) == logsigmoid(-p).
        log_sigmoid = nn.functional.logsigmoid
        per_sample = -(log_sigmoid(expert_reward) + log_sigmoid(-policy_reward))
        # Reduce to a scalar so batched inputs can be back-propagated.
        reward_loss = per_sample.mean()

        self.reward_optimizer.zero_grad()
        reward_loss.backward()
        self.reward_optimizer.step()

    def _update_policy_network(self, state_tensor, expert_action_tensor):
        """Train the policy with a reward-weighted expert log-likelihood.

        The Bernoulli log-likelihood of the expert action under the current
        policy is weighted by the reward network's score for that action.
        """
        probs = self.policy(state_tensor)

        # Bernoulli log-probability of the expert action, per item; the
        # small epsilon guards against log(0).
        log_probs = expert_action_tensor * torch.log(probs + 1e-8) + \
            (1 - expert_action_tensor) * torch.log(1 - probs + 1e-8)

        # The reward is only a weight here: evaluate it without gradients so
        # this step does not write gradients into the reward network.
        with torch.no_grad():
            reward = self.reward(state_tensor, expert_action_tensor).squeeze(-1)

        # Average over items per sample, weight by that sample's reward, then
        # average over samples to obtain a scalar loss.
        policy_loss = -(log_probs.mean(dim=-1) * reward).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()