import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import copy
from model_SSAGDA import Model
import pickle
import matplotlib.pyplot as plt
import random
from dataclass import Creatdata
from alg_SSAGDA_optimized import SSAGDA

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Report which GPU is in use. Calling get_device_name(0) unconditionally
# raises on CPU-only machines and would defeat the CPU fallback above; the
# original bare expression also discarded its result outside a notebook.
if device == 'cuda':
    print(torch.cuda.get_device_name(0))

data_name = 'a9a'

# True:  build `train_set` by running the script ./data/<data_name>.py
# False: load a previously pickled `train_set` from ./data/<data_name>/<data_name>
is_create_data = True

if is_create_data:
    data_path = './data/' + data_name + '.py'
    # The data script is executed in this module's namespace and must define
    # `train_set` (it is used right below).
    # NOTE(review): exec runs arbitrary code -- only point this at trusted scripts.
    with open(data_path) as fh:
        # compile() with the real path gives readable tracebacks on errors.
        exec(compile(fh.read(), data_path, 'exec'))
else:
    file_name = './data/' + data_name + '/' + data_name
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(file_name, "rb") as fp:   # Unpickling
        train_set = pickle.load(fp)

# Move the dataset's data and targets to the selected device.
train_set.data = train_set.data.to(device)
train_set.targets = train_set.targets.to(device)

# (Disabled) Take a random sample of 10000 data points to reduce computational cost
# Set the seed for reproducibility
# np.random.seed(13)
# # Extract data and targets
# data = train_set.data
# targets = train_set.targets

# # Create indices to sample from
# indices = np.arange(len(targets))

# size_of_sample = 10000

# # Select random indices
# random_indices = np.random.choice(indices, size_of_sample, replace=False)

# # Sample the data and targets using the random indices
# sampled_data = data[random_indices]
# sampled_targets = targets[random_indices]

# # Create a new Creatdata object with the sampled data and targets
# sampled_train_set = Creatdata(data=sampled_data, targets=sampled_targets)
# # sampled_train_set = train_set

# # Move data and targets to device if needed
# sampled_train_set.data = sampled_train_set.data.to(device)
# sampled_train_set.targets = sampled_train_set.targets.to(device)

# train_set = sampled_train_set

# Output / persistence switches forwarded to SSAGDA.
is_show_result = False
is_save_data = False
is_save_grad_data = True

# Hyperparameters forwarded unchanged to SSAGDA; their meaning is defined by
# alg_SSAGDA_optimized.SSAGDA.
p = 160
tau_1 = 0.1
tau_2 = 0.002
beta = 0.0001
b = 1028
max_epoch = 20
epoch_number = 30000
sim_time = 200

# Collect every argument in one mapping so the call site stays compact.
ssagda_kwargs = dict(
    train_set=train_set,
    data_name=data_name,
    p=p,
    tau_1=tau_1,
    tau_2=tau_2,
    beta=beta,
    b=b,
    sim_time=sim_time,
    max_epoch=max_epoch,
    epoch_number=epoch_number,
    is_show_result=is_show_result,
    is_save_data=is_save_data,
    is_save_grad_data=is_save_grad_data,
    device=device,
)
SSAGDA(**ssagda_kwargs)