import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import copy
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd
from iams.iams_opt import IAMS
from models.poisson import PoissonRegressionModel, poisson_loss
from iams.train import  train, train_iams, train_full_batch
from iams.poisson.data import get_bike_share_data, get_SEER_cancer_data, get_diabetes_data  # dataset loaders; the script below uses the diabetes one

# Dataset selection. To run on the bike-share data instead, swap in:
#   X, y, hours, dataname = get_bike_share_data()
#   num_epochs = 7
num_epochs = 15
X, y, dataname = get_diabetes_data()

# Hold out 20% of the rows for testing; the fixed random_state keeps the
# split identical from run to run.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Standardize the features: statistics are fitted on the training split only
# and then applied unchanged to the test split.
scaler = StandardScaler().fit(X_train)
X_train = scaler.transform(X_train)
X_test = scaler.transform(X_test)

# Convert every split to a float32 PyTorch tensor
X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor = (
    torch.tensor(split, dtype=torch.float32)
    for split in (X_train, y_train, X_test, y_test)
)

# Mini-batch loaders: training batches are reshuffled, test batches keep
# their original order.
batch_size = 64
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Hyperparameters and results storage.
# `learning_rates` holds base values; each optimizer receives the base value
# multiplied by its own scale factor below.
learning_rates = [0.01, 0.1, 0.5, 1.0, 2.0, 5.0, 10, 20, 50]
SGD_LR_SCALE = 0.001
LBFGS_LR_SCALE = 0.01
results = {'SGD': {'train_errors': [], 'test_errors': []},
           'LBFGS': {'train_errors': [], 'test_errors': []}}
input_dim = X_train_tensor.shape[1]

# Testing convergence for each learning rate using both optimizers.
# `model_star` tracks a copy of whichever run ends with the lowest final
# training loss. It starts as None so that "no usable run" can be detected:
# a NaN or inf final loss never satisfies `loss < best_loss`.
best_loss = float('inf')
model_star = None
for lr in learning_rates:
    # SGD optimizer (mini-batch training through the data loaders)
    model_sgd = PoissonRegressionModel(input_dim=input_dim)
    optimizer_sgd = optim.SGD(model_sgd.parameters(), lr=lr * SGD_LR_SCALE)
    model_sgd, train_error_sgd, test_error_sgd = train(
        num_epochs, model_sgd, optimizer_sgd, poisson_loss,
        train_loader, test_loader,
    )
    if train_error_sgd[-1] < best_loss:
        best_loss = train_error_sgd[-1]
        model_star = copy.deepcopy(model_sgd)
    # Only the final-epoch errors are kept for the sensitivity plot
    results['SGD']['train_errors'].append(train_error_sgd[-1])
    results['SGD']['test_errors'].append(test_error_sgd[-1])

    # L-BFGS optimizer (trained on the full tensors rather than the loaders)
    model_lbfgs = PoissonRegressionModel(input_dim=input_dim)
    optimizer_lbfgs = optim.LBFGS(model_lbfgs.parameters(), lr=lr * LBFGS_LR_SCALE)
    model, LBFGS_train_error, LBFGS_test_error = train_full_batch(
        num_epochs, model_lbfgs, optimizer_lbfgs, poisson_loss,
        X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor,
    )
    # Snapshot of the most recent L-BFGS model
    model_LBFGS = copy.deepcopy(model)
    if LBFGS_train_error[-1] < best_loss:
        best_loss = LBFGS_train_error[-1]
        model_star = copy.deepcopy(model_LBFGS)
    results['LBFGS']['train_errors'].append(LBFGS_train_error[-1])
    results['LBFGS']['test_errors'].append(LBFGS_test_error[-1])

# Without this guard a sweep in which every run diverged would surface as an
# opaque NameError on `model_star` in the train_iams call below.
if model_star is None:
    raise RuntimeError(
        "No SGD or L-BFGS run finished with a finite training loss; "
        "cannot select a reference model for IAMS training."
    )

# IAMS run on a fresh model; the best baseline model found above is handed to
# train_iams (how it is used is defined in iams.train).
model = PoissonRegressionModel(input_dim=input_dim)
IAMS_optimizer = IAMS(model.parameters(), lmbda=None)
model, train_losses, test_losses = train_iams(
    num_epochs, model, model_star, IAMS_optimizer, poisson_loss,
    train_loader, test_loader,
)

# Global matplotlib styling applied before the figure is created
for rc_key, rc_value in (
    ('font.size', 18),        # global font size
    ('axes.titlesize', 24),   # title font size
    ('axes.labelsize', 15),   # x and y label font size
    ('xtick.labelsize', 12),  # x-axis tick label font size
    ('ytick.labelsize', 12),  # y-axis tick label font size
    ('legend.fontsize', 18),  # legend font size
    ('lines.linewidth', 3),   # default line width
):
    plt.rcParams[rc_key] = rc_value

# Sensitivity figure: final train/test loss of each optimizer per sweep value
plt.figure(figsize=(12, 6))

# One row per curve: (results key, split key, legend label, marker, line style, color).
# Solid lines are training loss, dashed lines are test loss.
# NOTE(review): the x values are the base entries of `learning_rates`; the
# sweep passes lr*0.001 to SGD and lr*0.01 to L-BFGS, so the axis shows the
# unscaled base value rather than the optimizer's actual step size.
curve_specs = (
    ('SGD', 'train_errors', 'SGD train', 'o', '-', 'blue'),
    ('SGD', 'test_errors', 'SGD test', 'o', '--', 'blue'),
    ('LBFGS', 'train_errors', 'L-BFGS train', 's', '-', 'black'),
    ('LBFGS', 'test_errors', 'L-BFGS test', 's', '--', 'black'),
)
for optimizer_key, split_key, curve_label, curve_marker, curve_style, curve_color in curve_specs:
    plt.plot(
        learning_rates,
        results[optimizer_key][split_key],
        label=curve_label,
        marker=curve_marker,
        linestyle=curve_style,
        color=curve_color,
    )

# IAMS has no swept learning rate here, so its final losses are drawn as
# horizontal reference lines.
for final_loss, reference_style, reference_label in (
    (train_losses[-1], '-', 'IAM train'),
    (test_losses[-1], '--', 'IAM test'),
):
    plt.axhline(y=final_loss, color='red', linestyle=reference_style,
                linewidth=2, label=reference_label)

plt.xscale('log')  # log scale for the swept values
plt.xlabel('Learning Rate')
plt.ylabel('Loss')
plt.grid(True)
plt.legend()

# Write the figure to figures/ (created if missing), then display it
os.makedirs('figures', exist_ok=True)
figure_path = 'figures/sensitivity-' + dataname + '-plot.pdf'
plt.savefig(figure_path, format='pdf', bbox_inches='tight')
plt.show()
