import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
from model import Model
from Optimization_Method import projection_simplex_sort as pj
import pickle
import matplotlib.pyplot as plt
import random
from dataclass import Creatdata

import os


# print("Current working directory:", os.getcwd())

# # Specify the new directory path
# new_directory = 'C:/Users/sysa1/Documents/Research/Optimization/Research code/Stochastic smoothed AGDA/DRO'

# # Change the current working directory
# os.chdir(new_directory)

# Example of pickle usage:
# l = [1,2,3,4]
# with open("test", "wb") as fp:   #Pickling
#     pickle.dump(l, fp)

# with open("test", "rb") as fp:   # Unpickling
#     b = pickle.load(fp)
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Report which GPU is being used. Previously get_device_name(0) was called
# unconditionally, which raises on CPU-only machines (defeating the fallback
# above), and its return value was discarded. Guard it and print the name.
if device == 'cuda':
    print("Using device:", torch.cuda.get_device_name(0))

# Dataset identifier; selects both the generator script and the pickle path.
data_name = 'sido0'

# True:  build the dataset by executing ./data/<data_name>.py, which is
#        expected to define `train_set` in this module's namespace.
# False: load a previously pickled dataset from ./data/<data_name>/<data_name>.
is_create_data = True

if is_create_data:
    data_path = './data/' + data_name + '.py'
    print(data_path)
    # Read the script as bytes inside `with` so the handle is closed promptly.
    # The old `open(...).read()` leaked the handle until garbage collection.
    # compile() on bytes decodes per PEP 263 (UTF-8 by default) instead of the
    # locale encoding, and passing the real path gives readable tracebacks.
    # NOTE(review): exec runs arbitrary code -- only point this at trusted scripts.
    with open(data_path, 'rb') as fp:
        data_source = fp.read()
    exec(compile(data_source, data_path, 'exec'))
else:
    file_name = './data/' + data_name + '/' + data_name
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(file_name, "rb") as fp:   # Unpickling
        train_set = pickle.load(fp)

# Move features and labels onto the selected device.
train_set.data = train_set.data.to(device)
train_set.targets = train_set.targets.to(device)

# print('Example of the data')
# print(train_set.data[0])
# print(len(train_set.data))

# Output / persistence switches forwarded to the solver.
is_show_result = False
is_save_data = False
is_save_grad_data = True

# Solver hyper-parameters; their meaning is defined by
# alg_SSAGDA_optimized.SSAGDA (not visible here).
p = 160
tau_1 = 0.001
tau_2 = 0.0002
beta = 0.00001
b = 1028

# Run-length settings.
max_epoch = 20
epoch_number = 12678
sim_time = 200

from alg_SSAGDA_optimized import SSAGDA

SSAGDA(
    train_set=train_set,
    data_name=data_name,
    p=p,
    tau_1=tau_1,
    tau_2=tau_2,
    beta=beta,
    b=b,
    sim_time=sim_time,
    max_epoch=max_epoch,
    epoch_number=epoch_number,
    is_show_result=is_show_result,
    is_save_data=is_save_data,
    is_save_grad_data=is_save_grad_data,
    device=device,
)