import matplotlib.pyplot as plt
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import pandas as pd
from iams.iams_opt import IAMS
from models.poisson import PoissonRegressionModel, poisson_loss
from iams.train import  train, train_iams, train_full_batch
from iams.poisson.data import get_bike_share_data, get_SEER_cancer_data, get_diabetes_data
import copy
import numpy as np


# --- Data ----------------------------------------------------------------
# Bike-share experiment. To run the diabetes experiment instead, replace the
# two assignments below with:
#     X, y, dataname = get_diabetes_data()
#     num_epochs = 15
X, y, hours, dataname = get_bike_share_data()
num_epochs = 7

# Hold out 20% of the samples for evaluation (fixed random_state so the
# split is reproducible between runs).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit the standardizer on the training split only and reuse its statistics
# for the test split, so no test-set information leaks into training.
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Every split becomes a float32 tensor.
X_train_tensor, y_train_tensor, X_test_tensor, y_test_tensor = (
    torch.tensor(split, dtype=torch.float32)
    for split in (X_train, y_train, X_test, y_test)
)

# Mini-batch loaders: shuffle the training data each epoch, keep the test
# order fixed.
batch_size = 64
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
train_loader, test_loader = (
    DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
    DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
)

# --- Model initialisation ------------------------------------------------
# Build the Poisson regression model for the (standardized) feature width and
# keep a pristine deep copy of its initial weights. Every optimizer below
# restarts from model_copy, so all runs share the same starting point.
input_dim = X_train_tensor.shape[1]
model = PoissonRegressionModel(input_dim)
model_copy = copy.deepcopy(model)

# Lowest final training loss seen so far across optimizers; the model that
# reaches it is stored in model_star (later handed to train_iams).
best_loss = float('inf')

# --- L-BFGS: trained on the whole training tensors (no DataLoader) -------
optimizer = optim.LBFGS(model.parameters(), lr=0.01)
model, train_losses, test_losses = train_full_batch(
    num_epochs,
    model,
    optimizer,
    poisson_loss,
    X_train_tensor,
    y_train_tensor,
    X_test_tensor,
    y_test_tensor,
)
# Keep private copies of the L-BFGS curves for the final plot, because the
# generic train_losses / test_losses names are reassigned by later runs.
LBFGS_train_losses, LBFGS_test_losses = train_losses.copy(), test_losses.copy()
model_LBFGS = copy.deepcopy(model)
if LBFGS_train_losses[-1] < best_loss:
    best_loss = LBFGS_train_losses[-1]
    model_star = copy.deepcopy(model_LBFGS)

# --- Adam: mini-batch training through train_loader ----------------------
# Restart from the shared initial weights.
model = copy.deepcopy(model_copy)
optimizer = optim.Adam(model.parameters(), lr=0.01)
model, train_losses, test_losses = train(
    num_epochs, model, optimizer, poisson_loss, train_loader, test_loader
)
model_Adam = copy.deepcopy(model)
# Adam's model becomes model_star if its final training loss beats the best
# final training loss recorded so far.
if train_losses[-1] < best_loss:
    best_loss = train_losses[-1]
    model_star = copy.deepcopy(model_Adam)
# --- SGD: grid search over the learning-rate multiplier c ----------------
# Each candidate trains from the shared initial weights with lr = 0.001 * c.
# Two independent "best" trackers are maintained:
#   * best_sgd_loss / best_c  -> which c gives SGD its lowest final train loss;
#   * best_loss / model_star  -> the best model across ALL optimizers so far
#     (L-BFGS and Adam runs above may already hold it).
c_values = [0.1, 0.5, 1.0, 2.0, 5.0]
best_c = None
best_sgd_loss = float('inf')
for c in c_values:
    model = copy.deepcopy(model_copy)
    optimizer = optim.SGD(model.parameters(), lr=0.001 * c)
    model, train_losses, test_losses = train(num_epochs, model, optimizer, poisson_loss, train_loader, test_loader)
    final_loss = train_losses[-1]
    if final_loss < best_sgd_loss:
        best_sgd_loss = final_loss
        best_c = c
    # Only replace the global best when SGD actually improves on it; never
    # overwrite best_loss with a worse value.
    if final_loss < best_loss:
        best_loss = final_loss
        model_star = copy.deepcopy(model)

if best_c is None:
    # Every candidate ended with a non-comparable loss (e.g. NaN/inf), so
    # there is no learning rate to select.
    raise RuntimeError('SGD grid search failed: no value of c produced a final training loss below inf')
print(f'Best constant c found: {best_c}')

# Re-train SGD with the selected c from the shared initial weights to obtain
# the curves that are plotted below.
model = copy.deepcopy(model_copy)
optimizer = optim.SGD(model.parameters(), lr=0.001 * best_c)
model_SGD, SGD_train_losses, SGD_test_losses = train(num_epochs, model, optimizer, poisson_loss, train_loader, test_loader)
# --- IAMS ----------------------------------------------------------------
# Run IAMS from the shared initial weights. model_star (the best model found
# by the baseline optimizers above) is passed to train_iams alongside the
# model being trained.
model = copy.deepcopy(model_copy)
IAMS_optimizer = IAMS(model.parameters(), lmbda=None)
model, train_losses, test_losses = train_iams(
    num_epochs,
    model,
    model_star,
    IAMS_optimizer,
    poisson_loss,
    train_loader,
    test_loader,
)
# --- Plot training and test loss curves ----------------------------------
# At this point train_losses / test_losses hold the IAMS run's curves.
import os

plt.rcParams.update({
    'font.size': 18,        # Global font size
    'axes.titlesize': 24,   # Title font size
    'axes.labelsize': 15,   # X and Y label font size
    'xtick.labelsize': 12,  # X-axis tick label font size
    'ytick.labelsize': 12,  # Y-axis tick label font size
    'legend.fontsize': 18,  # Legend font size
    'lines.linewidth': 3,   # Default line width
})

# Place roughly four markers per curve. Clamp to 1: for num_epochs < 4 the
# integer division yields 0, and a markevery step of 0 is invalid.
marker_step = max(1, num_epochs // 4)

# (losses, label, marker, color, linestyle): solid lines are training
# losses, dashed lines are test losses; one color per optimizer.
loss_curves = [
    (SGD_train_losses, 'SGD train', 'o', 'blue', '-'),
    (SGD_test_losses, 'SGD test', 's', 'blue', '--'),
    (train_losses, 'IAM train', 'x', 'red', '-'),
    (test_losses, 'IAM test', '^', 'red', '--'),
    (LBFGS_train_losses, 'LBFGS train', 'D', 'black', '-'),
    (LBFGS_test_losses, 'LBFGS test', '>', 'black', '--'),
]

plt.figure(figsize=(10, 5))
for losses, label, marker, color, linestyle in loss_curves:
    plt.plot(losses, label=label, marker=marker, markevery=marker_step,
             color=color, linestyle=linestyle, linewidth=3)

plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# savefig does not create missing directories; make sure the output
# directory exists so a fresh checkout does not fail here.
os.makedirs('figures', exist_ok=True)
plt.savefig('figures/' + dataname + '.pdf', format='pdf', bbox_inches='tight')
plt.show()



