import torch
import torch.nn as nn
import numpy as np
import wandb
import os
import matplotlib.pyplot as plt


class PI_controller(nn.Module):
  
  def __init__(self,
               model:nn.Module,
               k_p = 0.01,
               k_i = 0.001,
               lamba_min = 1e-1,
               lambda_ini =0,
               lamba_max = 1,
               target = 1e-4,
               moe_ploicy = "Stragey_1"
               ):
        """
        初始化PI控制器
        :param model: 要控制的模型
        :param kp: 比例增益
        :param ki: 积分增益
        :param target_loss: 目标损失值
        moe_ploicy   
        按照论文里面是暂时两个策略
        Stragey_1:分治,top_k 
        或者Stragey_2:按照权重送给每一个模型
        """
        super(PI_controller, self).__init__()
        self.base_model = "PID-M2M"
        self.model = model
        self.kp = k_p
      
        self.ki = k_i
        print("kp",self.kp)
        self.moe_policy = moe_ploicy
        self.target_loss = target
        self.integral = 0  # 初始化积分项
        self.error = 0  # 上一次的误差
      
        
        self.lamba = lambda_ini 
        self.lamba_min = lamba_min
        self.lamba_max = lamba_max

        self.model = model
        
        #保存起来
        self.save_dict = {"target": self.target_loss,
                          "lamba_t": [], # along t
                          "K_p": self.kp,
                          "K_i": self.ki,
                          "error_t": [] # along t
                          }
        # 初始的lambad
        self.save_dict["lamba_t"].append( self.lamba )

      
    
  def update(self, current_loss):
        """
        更新控制器参数
        :param current_loss: 当前模型的损失
        :return: 调整后的参数值
        """
        # 计算误差
        self.error = self.target_loss - current_loss
        
        # 计算比例项
        p = self.kp / (1 + np.exp(self.error))
        
        # 检查 lambda 是否在允许的范围内，以决定是否更新积分项
        if self.lamba_min < self.lamba < self.lamba_max:
            self.integral += self.ki * self.error #注意+法
        else:
            # 反积分饱和处理
            self.integral = self.integral  # 保持不变
        
        # 更新 lambda
        self.lamba = p + self.integral + self.lamba_min
        
        # 限制 lambda 在指定的范围内
        self.lamba = max(self.lamba_min, min(self.lamba_max, self.lamba))
        
        #保存更新
        self.save_dict["lamba_t"] = np.append(self.save_dict["lamba_t"], self.lamba)

        self.save_dict["error_t"] = np.append(self.save_dict["error_t"], self.error)
        return  self.lamba
      
  def forward(self,input_tensors,static_tensors = None,
                           boundary_tensors = None,
                           pred_steps=[1],
                           multi_sclale=1,**kwargs):
    '''
      按照策略送入模型，然后得到输出,   到时候用模型的forward
    '''
    out = self.model(input_tensors)
    
    return out
  
  @property
  def Moe_weight_info(self):
    #下游模型的property
    return self.model.Moe_weight_info

  @property
  def load_balancing_loss(self):

    return self.model.load_balancing_loss
  

  def count_paras(self):
    return sum(p.numel() for p in self.parameters())
  
  @property
  def MOE_name(self):
    return self.model.experts_name
  
  def save_pid(self,save_path:str,epoch:int):

    # 把 lamba_t 和 error_t 转换为 numpy 数组
    self.save_dict["lamba_t"] = np.array(self.save_dict["lamba_t"])
    self.save_dict["error_t"] = np.array(self.save_dict["error_t"])
    # 保存到 .npz 文件
    np.savez(f'{save_path}/PID_dict_file.npz', **self.save_dict)
    
    # 创建图像和左侧 y 轴
    fig, ax1 = plt.subplots()

    # 在左侧 y 轴上绘制 error_t
    ax1.plot(self.save_dict["error_t"], label="error(t)", color='blue')
    ax1.set_xlabel("Epoch(t)")
    ax1.set_ylabel("Error", color='blue')
    ax1.tick_params(axis='y', labelcolor='blue')

    # 假设 self.save_dict["target"] 是你希望绘制的 y 轴上的固定值
    target_value = self.save_dict["target"]

    # 使用 axhline 在 y=target_value 处绘制红色的水平线
    ax1.axhline(y=target_value, color='red', linestyle='--', label='target')

    # 创建右侧 y 轴
    ax2 = ax1.twinx()

    # 在右侧 y 轴上绘制 lamba_t
    ax2.plot(self.save_dict["lamba_t"], label=r"$\lambda(t)$", color='green')
    ax2.set_ylabel(r"$\lambda(t)$", color='green')
    ax2.tick_params(axis='y', labelcolor='green')

    # 添加图例
    fig.tight_layout()  # 防止布局重叠
   # 设置图例，自动选择最佳位置
  # 设置图例位置，避免重叠
    ax1.legend(loc='upper right')  # 手动指定右上角
    ax2.legend(loc='lower left')   # 手动指定左下角

    # 添加标题
    plt.title(f"PID: K_p:{self.kp}, K_i:{self.ki}")

    plt.tight_layout()
    img_path = os.path.join(save_path, "pid_show.png")  # 使用 os.path.join 拼接路径
    plt.savefig(img_path,dpi=300)
    #使用 wandb.log 上传图像
    # 正确的代码
    wandb.log({"PID-image": wandb.Image(img_path)}, step=epoch)
    plt.close()
    
    
