import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from scipy.optimize import minimize
from sklearn.preprocessing import StandardScaler
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def dataframe_to_dict(df,X):
    """
    Convert a data frame into a dictionary of raw value arrays.

    The entry "X" holds ``df[X].values`` (``X`` is whatever column selector
    the caller passes); the entries "A", "Y", "R" and "S" hold the values
    of the identically named columns.
    """
    # Feature selection first, so the key order stays X, A, Y, R, S.
    result_dict = {"X": df[X].values}

    # The remaining columns are copied over under their own names.
    for column in ("A", "Y", "R", "S"):
        result_dict[column] = df[column].values

    return result_dict

class MLP1(nn.Module):
    """Single-hidden-layer perceptron (Linear -> ReLU -> Linear) with a
    built-in MSE training loop and a NumPy-returning ``predict``.

    NOTE(review): inputs are moved to the module-level ``device`` but this
    class never moves its own parameters; presumably the caller does
    ``model.to(device)`` after construction -- confirm.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Create the two linear layers.

        :param input_dim: size of each input feature vector
        :param hidden_dim: width of the hidden layer
        :param output_dim: size of the output vector
        """
        super().__init__()
        self.relu = torch.nn.ReLU()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the network on ``x`` and drop all size-1 dimensions of the output."""
        x = x.to(device)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)

        # squeeze() without a dim also removes a batch dimension of size 1.
        return torch.squeeze(x)

    def fit(self, X_train, y_train, batch_size = 64, num_epoch=10000, lr=0.005, lamb=1e-3, tol=1e-5, verbose=True):
        """Train with Adam on the MSE loss until the loss stops improving.

        :param X_train: training features (array-like, converted to float32)
        :param y_train: training targets (array-like, converted to float32)
        :param batch_size: mini-batch size of the data loader
        :param num_epoch: maximum number of epochs
        :param lr: Adam learning rate
        :param lamb: Adam ``weight_decay`` coefficient
        :param tol: minimum decrease of the mean epoch loss that counts as
            an improvement
        :param verbose: print progress every 50 epochs and on early stop
        """
        X_train = torch.as_tensor(X_train, dtype=torch.float32)
        y_train = torch.as_tensor(y_train, dtype=torch.float32)

        train_data = TensorDataset(X_train, y_train)
        train_loader = DataLoader(train_data, batch_size=batch_size)

        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=lamb)

        last_loss = 1e9
        # Initialise the patience counter up front: previously it was only
        # created by the first improving epoch, so a NaN or huge first loss
        # raised UnboundLocalError.
        early_stop = 0
        for epoch in range(num_epoch):
            epoch_loss = 0
            for X, y in train_loader:
                X = X.to(device)
                y = y.to(device)
                pred = self(X)
                loss = criterion(pred, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.detach().cpu().numpy()

            # Mean of the per-batch losses.
            epoch_loss /= len(train_loader)

            if epoch_loss < last_loss - tol:
                last_loss = epoch_loss
                early_stop = 0
            else:
                early_stop += 1

            # Stop after more than 20 consecutive epochs without improvement.
            if early_stop > 20:
                if verbose:
                    print("[MLP_model] epoch:{}, MSE:{}".format(epoch, epoch_loss))
                break

            if verbose and epoch % 50 == 0:
                print("[MLP_model] epoch:{}, MSE:{}".format(epoch, epoch_loss))


    def predict(self, x):
        """Return the network output for ``x`` as a NumPy array."""
        x = torch.as_tensor(x, dtype=torch.float32)
        x = x.to(device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            x = self(x)
        return x.detach().cpu().numpy()


class MLP2(nn.Module):
    """Two-hidden-layer perceptron (Linear -> ReLU -> Linear -> ReLU ->
    Linear) with a built-in MSE training loop and a NumPy-returning
    ``predict``.

    NOTE(review): inputs are moved to the module-level ``device`` but this
    class never moves its own parameters; presumably the caller does
    ``model.to(device)`` after construction -- confirm.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        """Create the three linear layers.

        :param input_dim: size of each input feature vector
        :param hidden_dim: width of both hidden layers
        :param output_dim: size of the output vector
        """
        super().__init__()
        self.relu = torch.nn.ReLU()
        # NOTE(review): this dropout layer is created but never applied in
        # forward(); confirm whether that is intentional.
        self.dropout = nn.Dropout(0.2)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim,hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Run the network on ``x`` and drop all size-1 dimensions of the output."""
        x = x.to(device)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        # squeeze() without a dim also removes a batch dimension of size 1.
        return torch.squeeze(x)

    def fit(self, X_train, y_train, batch_size = 64, num_epoch=10000, lr=0.005, lamb=1e-3, tol=1e-3, verbose=True):
        """Train with Adam on the MSE loss until the loss stops improving.

        :param X_train: training features (array-like, converted to float32)
        :param y_train: training targets (array-like, converted to float32)
        :param batch_size: mini-batch size of the data loader
        :param num_epoch: maximum number of epochs
        :param lr: Adam learning rate
        :param lamb: Adam ``weight_decay`` coefficient
        :param tol: minimum decrease of the summed epoch loss that counts
            as an improvement
        :param verbose: print progress every 50 epochs and on early stop
        """
        X_train = torch.as_tensor(X_train, dtype=torch.float32)
        y_train = torch.as_tensor(y_train, dtype=torch.float32)

        train_data = TensorDataset(X_train, y_train)
        train_loader = DataLoader(train_data, batch_size=batch_size)

        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=lr, weight_decay=lamb)

        last_loss = 1e9
        # Initialise the patience counter up front: previously it was only
        # created by the first improving epoch, so a NaN or huge first loss
        # raised UnboundLocalError.
        early_stop = 0
        for epoch in range(num_epoch):
            epoch_loss = 0
            for X, y in train_loader:
                X = X.to(device)
                y = y.to(device)
                pred = self(X)
                loss = criterion(pred, y)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                epoch_loss += loss.detach().cpu().numpy()

            # NOTE(review): epoch_loss is the SUM of the per-batch losses
            # here (it is not divided by len(train_loader)), although it is
            # logged as "MSE" and compared against tol. Left unchanged so
            # the stopping behaviour stays the same -- confirm the intent.
            if epoch_loss < last_loss - tol:
                last_loss = epoch_loss
                early_stop = 0
            else:
                early_stop += 1

            # Stop after more than 20 consecutive epochs without improvement.
            if early_stop > 20:
                if verbose:
                    print("[MLP_model] epoch:{}, MSE:{}".format(epoch, epoch_loss))
                break

            if verbose and epoch % 50 == 0:
                print("[MLP_model] epoch:{}, MSE:{}".format(epoch, epoch_loss))


    def predict(self, x):
        """Return the network output for ``x`` as a NumPy array."""
        x = torch.as_tensor(x, dtype=torch.float32)
        x = x.to(device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            x = self(x)
        return x.detach().cpu().numpy()


class Exponential_regression:
    """Least-squares fit of ``y ~ a * exp(X @ b)`` via ``scipy.optimize.minimize``.

    NOTE(review): with ``normalize`` enabled the target is standardised as
    well, so the single-signed model ``a * exp(...)`` is fitted to
    zero-mean targets -- confirm this is the intended use.
    """

    # Methods for which scipy ignores a supplied gradient and emits a
    # RuntimeWarning if one is passed.
    _GRADIENT_FREE_METHODS = frozenset({"nelder-mead", "powell", "cobyla", "cobyqa"})

    def __init__(self, normalize=None, random_state=None):
        """
        Initialise the model.
        :param normalize: if truthy, standardise both features and target
        :param random_state: seed passed to ``np.random.seed`` in ``fit``
        """
        self.params = None  # model parameters [a, b1, b2, ..., bn]
        self.normalize = normalize
        self.scaler_X = StandardScaler() if normalize else None
        self.scaler_y = StandardScaler() if normalize else None
        self.random_state = random_state
        self.loss_history = []

    def _model(self, params, X):
        """Evaluate ``a * exp(X @ b)`` for ``params = [a, b1, ..., bn]``."""
        a = params[0]
        b = params[1:]
        # Clip the exponent so that np.exp cannot overflow.
        exponent = np.dot(X, b)
        exponent = np.clip(exponent, -100, 100)
        return a * np.exp(exponent)

    def _objective(self, params, X, y):
        """Objective: half the residual sum of squares; also logs the loss."""
        y_pred = self._model(params, X)
        # Clip the residuals so that squaring them cannot overflow.
        diff = np.clip(y - y_pred, -1e6, 1e6)
        loss = 0.5 * np.sum(diff ** 2)
        self.loss_history.append(loss)
        return loss

    def _gradient(self, params, X, y):
        """Gradient of the objective with respect to ``[a, b1, ..., bn]``.

        The exponent is clipped as in ``_model``; the residual clipping of
        ``_objective`` is not mirrored here.
        """
        a = params[0]
        b = params[1:]
        exponent = np.dot(X, b)
        exponent = np.clip(exponent, -100, 100)
        exp_term = np.exp(exponent)
        error = y - a * exp_term

        grad_a = -np.sum(error * exp_term)
        grad_b = -np.dot(X.T, error * a * exp_term)

        return np.concatenate(([grad_a], grad_b))

    def fit(self, X, y, initial_params=None, method='BFGS', maxiter=1000, tol=1e-4):
        """
        Fit the model.
        :param X: feature matrix (n_samples, n_features), array-like
        :param y: target values (n_samples,), array-like
        :param initial_params: initial guess ``[a, b1, ..., bn]``
        :param method: optimisation method ('BFGS', 'L-BFGS-B', 'Nelder-Mead', ...)
        :param maxiter: maximum number of iterations
        :param tol: tolerance passed to the optimiser
        :return: the fitted model (``self``)
        :raises ValueError: if ``X`` is not 2-D or ``X`` and ``y`` disagree
            on the number of samples
        """
        if self.random_state is not None:
            np.random.seed(self.random_state)

        # Accept lists / Series / column vectors: coerce to float arrays and
        # flatten y so that ``y - y_pred`` can never broadcast to (n, n).
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float).ravel()
        if X.ndim != 2:
            raise ValueError("X must be 2-dimensional (n_samples, n_features).")
        if X.shape[0] != y.shape[0]:
            raise ValueError("X and y must contain the same number of samples.")

        # Pre-processing
        if self.normalize:
            X = self.scaler_X.fit_transform(X)
            y = self.scaler_y.fit_transform(y.reshape(-1, 1)).ravel()

        # Initial parameters
        n_features = X.shape[1]
        if initial_params is None:
            a0 = float(np.mean(y))
            scale = float(np.mean(np.abs(y)))
            # a = 0 with b = 0 is a stationary point: grad_b carries a
            # factor a and grad_a equals -sum(y), which is 0 whenever
            # mean(y) is 0 (always the case after standardisation). The
            # optimiser would then return immediately with a constant
            # model, so start from a non-zero amplitude instead.
            if abs(a0) <= 1e-8 * scale:
                a0 = scale if scale > 0 else 1.0
            initial_params = np.concatenate(([a0], np.zeros(n_features)))

        # Only hand the analytic gradient to methods that actually use it.
        gradient_free = isinstance(method, str) and method.lower() in self._GRADIENT_FREE_METHODS
        jac = None if gradient_free else self._gradient

        # Optimise the parameters
        self.loss_history = []
        result = minimize(self._objective, initial_params, args=(X, y),
                          method=method, jac=jac,
                          options={'maxiter': maxiter, 'disp': False},
                          tol=tol)

        self.params = result.x
        return self

    def predict(self, X):
        """
        Predict.
        :param X: feature matrix (n_samples, n_features), array-like
        :return: predictions (n_samples,)
        :raises ValueError: if the model has not been fitted
        """
        if self.params is None:
            raise ValueError("Model not fitted yet. Call fit() first.")

        X = np.asarray(X, dtype=float)
        if self.normalize:
            X = self.scaler_X.transform(X)

        y_pred = self._model(self.params, X)
        if self.normalize:
            # Map predictions back to the original target scale.
            y_pred = self.scaler_y.inverse_transform(y_pred.reshape(-1, 1)).ravel()
        return y_pred
