import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer
from torch.optim.lr_scheduler import StepLR
import numpy as np
import csv
from sklearn.model_selection import train_test_split
import random
from algs import *
import copy

# Seed every RNG in play (Python, NumPy, PyTorch) for reproducible runs.
seed = 2023
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Row 0 is discarded -- presumably the CSV header, which genfromtxt parses
# as NaNs (TODO confirm). Columns 0-5 are the features, column 6 the label.
data = np.genfromtxt('process_heart1.csv', delimiter=',')
X, y = data[1:, :6], data[1:, 6]

# 80/20 split; note it uses its own fixed random_state, independent of `seed`.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=10
)

# Cap the problem size: keep the first 200 training and 50 test samples.
X_train, y_train = X_train[:200], y_train[:200]
X_test, y_test = X_test[:50], y_test[:50]



class Net(nn.Module):
    """Single-hidden-layer MLP producing a probability per sample.

    Architecture: Linear(input_size, hidden_size) -> ReLU ->
    Linear(hidden_size, 1) -> Sigmoid, so inputs of shape
    ``(..., input_size)`` map to outputs of shape ``(..., 1)`` in (0, 1).

    Args:
        input_size: number of input features.
        hidden_size: width of the hidden layer. Defaults to 64, the value
            that was previously hard-coded.
    """

    def __init__(self, input_size, hidden_size=64):
        super().__init__()
        # Layers are created in the same order as before, so parameter
        # initialization under a fixed torch seed is unchanged.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid probabilities with a trailing dimension of size 1."""
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)
        return x

input_size = X_train.shape[1]
model = Net(input_size)
# Each optimizer gets its own deep copy of `model`, so all runs start from
# identical initial weights without sharing parameters.
model_gd, model_sgd, model_svrg = (copy.deepcopy(model) for _ in range(3))

criterion = nn.BCELoss()


def build_model(X_train, y_train, model, optimizer, scheduler, criterion, num_epochs,
                X_eval=None, y_eval=None):
    """Run ``num_epochs`` full-batch optimizer steps and track train/test loss.

    Each epoch computes the training loss, calls ``optimizer.step()`` (and
    ``scheduler.step()`` when a scheduler is given), then evaluates the model
    on the held-out set under ``torch.no_grad()``.

    Note: ``loss.backward()`` is deliberately not called, so the optimizer
    must update parameters without autograd gradients -- presumably the
    zeroth-order optimizers from ``algs`` estimate them from loss evaluations
    (confirm before using a first-order optimizer here).

    Args:
        X_train, y_train: NumPy training features and labels.
        model: module whose output has shape ``(N, 1)``.
        optimizer: object providing ``zero_grad()`` and ``step()``.
        scheduler: LR scheduler stepped once per epoch, or ``None``.
        criterion: loss taking ``(outputs, labels)``, e.g. ``nn.BCELoss``.
        num_epochs: number of optimizer steps to take.
        X_eval, y_eval: held-out features and labels. Each defaults to the
            module-level ``X_test`` / ``y_test`` when omitted.

    Returns:
        ``(train_losses, test_losses, gen_errors)``, one float per epoch each.
        The train loss is measured *before* that epoch's step and the test
        loss *after* it. ``gen_errors`` is their absolute difference.
    """
    if X_eval is None:
        X_eval = X_test
    if y_eval is None:
        y_eval = y_test

    # The data never changes across epochs, so convert it to tensors once.
    train_inputs = torch.from_numpy(X_train).float()
    train_labels = torch.from_numpy(y_train).float()
    eval_inputs = torch.from_numpy(X_eval).float()
    eval_labels = torch.from_numpy(y_eval).float()

    train_losses, test_losses, gen_errors = [], [], []

    for epoch in range(1, num_epochs + 1):
        outputs = model(train_inputs).squeeze(1)
        train_loss = criterion(outputs, train_labels)

        optimizer.zero_grad()
        # loss.backward() intentionally omitted (see docstring).
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        train_losses.append(train_loss.item())
        report = epoch % 10 == 0
        if report:
            print(f'Epoch {epoch}/{num_epochs}, Train Loss: {train_losses[-1]}')

        with torch.no_grad():
            # squeeze(1), not squeeze(): a 1-sample eval set must stay 1-D
            # so its shape matches the labels expected by the criterion.
            eval_outputs = model(eval_inputs).squeeze(1)
            test_loss = criterion(eval_outputs, eval_labels).item()
        test_losses.append(test_loss)

        gen_error = abs(test_loss - train_losses[-1])
        gen_errors.append(gen_error)

        if report:
            print(f'Test Loss: {test_loss}')
            print(f'Generalization Error: {gen_error}')

    return train_losses, test_losses, gen_errors





# ZO-GD
# optimizer_gd = ZO_GD(model_gd, model_gd.parameters(), X_train, y_train, criterion, lr=1e-5, use_true_grad=False)
# scheduler_gd = None
# train_losses_gd, test_losses_gd, gen_errors_gd = build_model(X_train,y_train,model_gd,optimizer_gd,scheduler_gd,criterion,num_epochs = 4000)

# # Save the loss values after the loop
# with open('train_loss_gd.csv', 'w', newline='') as f1:
#     writer = csv.writer(f1)
#     writer.writerows([[val] for val in train_losses_gd])

# with open('test_loss_gd.csv', 'w', newline='') as f2:
#     writer = csv.writer(f2)
#     writer.writerows([[val] for val in test_losses_gd])

# with open('test_train_loss_gd.csv', 'w', newline='') as f3:
#     writer = csv.writer(f3)
#     writer.writerows([[val] for val in gen_errors_gd])


#ZO-SGD

# optimizer_sgd = ZO_MiniBatch_SGD(model_sgd, model_sgd.parameters(), X_train, y_train, criterion, lr=1e-5, batch_size=1, use_true_grad=False)
# scheduler_sgd = StepLR(optimizer_sgd, step_size=100, gamma=0.7)
# train_losses_sgd, test_losses_sgd, gen_errors_sgd = build_model(X_train,y_train,model_sgd,optimizer_sgd,scheduler_sgd,criterion,num_epochs=4000)
# # Save the loss values after the loop
# with open('train_loss_sgd.csv', 'w', newline='') as f1:
#     writer = csv.writer(f1)
#     writer.writerows([[val] for val in train_losses_sgd])

# with open('test_loss_sgd.csv', 'w', newline='') as f2:
#     writer = csv.writer(f2)
#     writer.writerows([[val] for val in test_losses_sgd])

# with open('test_train_loss_sgd.csv', 'w', newline='') as f3:
#     writer = csv.writer(f3)
#     writer.writerows([[val] for val in gen_errors_sgd])


# ZO-SVRG
optimizer_svrg = ZO_SVRG(model_svrg, model_svrg.parameters(), X_train, y_train,
                         criterion, lr=1e-6, batch_size=1, m=100)
scheduler_svrg = StepLR(optimizer_svrg, step_size=200, gamma=0.9)  # or None: build_model then skips scheduler.step()
train_losses_svrg, test_losses_svrg, gen_errors_svrg = build_model(
    X_train, y_train, model_svrg, optimizer_svrg, scheduler_svrg, criterion,
    num_epochs=4000,
)

# Persist each per-epoch series as a single-column CSV file.
for out_path, series in (
    ('train_loss_svrg.csv', train_losses_svrg),
    ('test_loss_svrg.csv', test_losses_svrg),
    ('test_train_loss_svrg.csv', gen_errors_svrg),
):
    with open(out_path, 'w', newline='') as fh:
        csv.writer(fh).writerows([value] for value in series)

