import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
from model_SSAGDA import Model
from Optimization_Method import projection_simplex_sort as pj
import pickle
import matplotlib.pyplot as plt
import random
from dataclass import Creatdata
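# Example of pickling / unpickling a Python object (the same mechanism is
# used below to load a cached dataset when is_create_data is False):
# l = [1,2,3,4]
# with open("test", "wb") as fp:   # Pickling
#     pickle.dump(l, fp)

# with open("test", "rb") as fp:   # Unpickling
#     b = pickle.load(fp)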
# Use the CUDA device when one is available; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# torch.cuda.get_device_name(0) raises when no CUDA device exists, so only
# query (and report) the GPU name when CUDA is actually being used.
if device == 'cuda':
    print('Using GPU:', torch.cuda.get_device_name(0))

# Dataset identifier; selects both the builder script and the pickle path below.
data_name = 'gisette'

# True:  build the dataset by executing ./data/<data_name>.py in this module's
#        namespace; that script must define `train_set`.
# False: load a previously pickled dataset from ./data/<data_name>/<data_name>.
is_create_data = True

if is_create_data:
    data_path = './data/' + data_name + '.py'
    # Read as bytes inside a context manager so the handle is always closed and
    # compile() applies Python's own source-encoding rules (UTF-8 / PEP 263
    # coding cookie) rather than the platform's locale encoding. Passing the
    # real path makes tracebacks from the builder script show its file/lines.
    # NOTE(review): exec runs arbitrary code from this path; only use trusted files.
    with open(data_path, 'rb') as fp:
        data_source = fp.read()
    exec(compile(data_source, data_path, 'exec'))
    if 'train_set' not in globals():
        raise RuntimeError(data_path + ' did not define `train_set`')
else:
    file_name = './data/' + data_name + '/' + data_name
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(file_name, "rb") as fp:   # Unpickling
        train_set = pickle.load(fp)

# Move the feature tensor and the label tensor onto `device`
# ('data' first, then 'targets').
for tensor_attr in ('data', 'targets'):
    setattr(train_set, tensor_attr, getattr(train_set, tensor_attr).to(device))

# Sanity preview of what was loaded.
print('Example of the data')
first_sample = train_set.data[0]
print(first_sample)
print(len(train_set))
print(train_set.targets)

# Length of the first sample, followed by the number of targets.
print(len(first_sample), len(train_set.targets))

# --- Run switches (forwarded to SSAGDA) ---
is_show_result = False
is_save_result = False      # forwarded to SSAGDA under the keyword `is_save_data`
is_save_grad_data = True

# --- Algorithm hyperparameters ---
# NOTE(review): semantics are defined by alg_SSAGDA_optimized.SSAGDA; confirm there.
p = 160
tau_1 = 1e-3
tau_2 = 2e-4
beta = 1e-5
b = 256

# --- Run length ---
max_epoch = 20
epoch_number = 6000
sim_time = 200

from alg_SSAGDA_optimized import SSAGDA

ssagda_config = dict(
    train_set=train_set,
    data_name=data_name,
    p=p,
    tau_1=tau_1,
    tau_2=tau_2,
    beta=beta,
    b=b,
    sim_time=sim_time,
    max_epoch=max_epoch,
    epoch_number=epoch_number,
    is_show_result=is_show_result,
    is_save_data=is_save_result,
    is_save_grad_data=is_save_grad_data,
    device=device,
)
SSAGDA(**ssagda_config)