# from ctypes.wintypes import MAX_PATH
# from typing import final
# import torch
# import numpy as np
# import copy
# import pickle
# from Optimization_Method import projection_simplex_sort as pj
# from model_SSAGDA import Model
# from dataclass import Creatdata
# import os

# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# torch.cuda.get_device_name(0)

# #main of Stochastic smoothed AGDA (SSAGDA)  
# def SSAGDA(train_set, data_name, p, tau_1, tau_2, beta, T, sigma,
#         b, sim_time, max_epoch, epoch_number, 
#         is_show_result = False, is_save_data = False, is_save_grad_data = False, device = 'cuda'): 
#     #initialize key parameters
#     # tau_1 = min(np.sqrt(Delta)/(2*sigma*np.sqrt(T*L)), 1/(3*L))
#     # tau_2 = min(np.sqrt(Delta)/(96*sigma*np.sqrt(T*L)), 1/(144*L))
#     # tau_1 = 0.01
#     # tau_2 = tau_1/48
#     # beta = 0.00001

#     # Initialize directories for saving later
#     directory = f'./result_data/{data_name}'
#     os.makedirs(directory, exist_ok=True)

#     #initialize the result
#     final_acc_SSAGDA = []
#     record_SSAGDA = []
#     acc = []
#     data_size = (len(train_set.data[0]),len(train_set.targets))

#     # keeping track of random samples
#     all_x_grad_samples = []
#     all_y_grad_samples = []

#     for s in range(sim_time):
#         #initialize for this simulation
#         epoch_SSAGDA = []
#         record_SSAGDA_sub = []
#         epoch, sample_complexity = 0, 0

#         #load the start model
#         test1 = Model(data_size=data_size).to(device)
#         # test1.load_state_dict(copy.deepcopy(test1.state_dict()))
#         epoch_SSAGDA.append(0)
#         full_batch = torch.arange(len(train_set.targets)).to(device)

#         cpu_test =  Model(data_size=data_size)
#         # cpu_test.load_state_dict(copy.deepcopy(test1.state_dict()))
#         record_SSAGDA_sub.append(cpu_test.testloss(cpu_test.forward(train_set.data.to('cpu')),full_batch.to('cpu'),train_set.targets.to('cpu')).detach().numpy())

#         #initilize the output
#         output = [param.clone().detach() for param in test1.parameters()]
#         Z0 = test1.w.clone().detach().requires_grad_(False) #: Z0 is the proximal center

#         x_iterates_grad = []
#         y_iterates_grad = []

#         # outer loop
#         while 1:
             
#             iter = 0 #:the iteration number for the inner loop
#             batch_start = 0 #:the start index of batch
#             data_loader_dumb = torch.randperm(len(train_set.targets), device=device)

#             while 1:   
#                 #re-initial the output as 0 vectors
#                 output = []
#                 Z0_updated = []

#                 for param in test1.parameters():
#                     output.append(torch.zeros_like(param))

#                 #generate the batch select data by batch index
#                 if batch_start+b < len(data_loader_dumb):
#                     batch_index = data_loader_dumb[batch_start:batch_start+b]
#                     batch_start += b
#                     sample_complexity += b
#                 else:
#                     #drop the incomplete data if they can not form a full batch
#                     data_loader_dumb = torch.randperm(len(train_set)).to(device)
#                     batch_start = 0
#                     continue

#                 data = torch.index_select(train_set.data,0,index=batch_index) 
#                 target = torch.index_select(train_set.targets,0,index=batch_index)

#                 #compute the primal gradient of x_k,y_k
#                 test1.zero_grad()
#                 test1.loss(test1.forward(data),batch_index, target).backward()
#                 oracles_primal = []
#                 for name,param in test1.named_parameters():
#                     oracles_primal.append(param.grad) #:although we don't need primal here, but to use zip properly later, we need to have the same shape

#                 #update primal variables to x_t+1 based on x_t,y_t,z_t
#                 for (name,param), grad, z0 in zip(test1.named_parameters(), oracles_primal, Z0):
#                     if name != 'variable_y':
#                         new_val = param - tau_1*(grad + p*(param - z0))
#                         z0_updated = z0 + beta*(new_val - z0)
#                         Z0_updated.append(z0_updated)
#                         param.data = new_val.data
#                         x_iterates_grad.append(grad)

#                 Z0 = Z0_updated
                
#                 #compute the dual gradient of x_k,y_k+1 only for dual variable
#                 test1.zero_grad()
#                 test1.loss(test1.forward(data),batch_index, target).backward()
#                 oracles_dual = [] #:oracles_dual is vt
#                 for name,param in test1.named_parameters():
#                     oracles_dual.append(param.grad)

#                 #update dual variables to y_k+1
#                 for (name,param),grad in zip(test1.named_parameters(),oracles_dual):
#                     if name == 'variable_y':
#                         new_val = param + tau_2*grad
#                         param.data = torch.tensor(pj(new_val.cpu().detach().numpy()),dtype=torch.float32).to(device)
#                         y_iterates_grad.append(grad)

#                 #record the output of each inner loop and take the average
#                 for (name,param1), param2 in zip(test1.named_parameters(),output):
#                     param2.data =  param1.data

#                 iter += 1
#                 if iter == T:
#                     break
                
#             if sample_complexity//epoch_number>epoch:
#                 epoch = sample_complexity//epoch_number
#                 epoch_SSAGDA.append(epoch)
#                 cpu_test =  Model(data_size=data_size)
#                 cpu_test.load_state_dict(copy.deepcopy(test1.state_dict()))
 
#                 record_SSAGDA_sub.append(cpu_test.testloss(cpu_test.forward(train_set.data.to('cpu')),full_batch.to('cpu'),train_set.targets.to('cpu')).detach().numpy())
#                 acc.append(torch.sum(cpu_test.predict(train_set.data.to('cpu'))==train_set.targets.to('cpu'))/len(train_set.data))
#                 if is_show_result:
#                     print('sample complexity is', sample_complexity, ', epoch is', epoch, ', acc is', acc[-1], ', loss is', record_SSAGDA_sub[-1])
                
#                 if epoch >= max_epoch+10:
#                     break

#         # Randomly sample one iterate from x and y iterates
#         random_index = np.random.randint(0, T)
#         all_x_grad_samples.append(x_iterates_grad[random_index])
#         all_y_grad_samples.append(y_iterates_grad[random_index])

#         #save this simulation result
#         print('')
#         print('Simulation time ', s+1, ' is done.....')
#         print('sample complexity is', sample_complexity, ', epoch is', epoch, ', acc is', acc[-1], ', loss is', record_SSAGDA_sub[-1])
#         print('')
#         record_SSAGDA.append(record_SSAGDA_sub)
#         final_acc_SSAGDA.append(acc[-1])

#         if is_save_data:
#             # Now construct the full file path
#             file_name = directory + '/SSAGDA_' + 'T =' + str(T) + 'tau_1 =' + str(tau_1) + 'tau_2 = ' + str(tau_2) + 'beta =' + str(beta) + 'p =' + str(p) + 'b =' + str(b)

#             # Safely open the file and write the data
#             with open(file_name, "wb") as fp:
#                 pickle.dump([record_SSAGDA, epoch_SSAGDA, final_acc_SSAGDA], fp)

#     # Save random iterates data to a file
#     if is_save_grad_data:
#         # Now construct the full file path
#         file_name = directory + '/SSAGDA_' + 'rand_samples' + 'T =' + str(T) + 'tau_1 =' + str(tau_1) + 'tau_2 = ' + str(tau_2) + 'beta =' + str(beta) + 'p =' + str(p) + 'b =' + str(b)
#         # Safely open the file and write the data
#         with open(file_name, "wb") as fp:
#             pickle.dump([all_x_grad_samples, all_y_grad_samples], fp)

import torch
import numpy as np
import pickle
import os
from Optimization_Method import projection_simplex_sort as pj
from model_SSAGDA import Model

# Run on the first CUDA device when one is present, otherwise fall back to the
# CPU; print the CUDA device name, or a notice that CUDA is unavailable.
if torch.cuda.is_available():
    device = 'cuda'
    print(torch.cuda.get_device_name(0))
else:
    device = 'cpu'
    print("CUDA not available")

def _evaluate_full(model, data, targets, full_batch):
    """Return ``(loss, accuracy)`` of *model* on the full training set.

    Both values are plain Python floats. The forward passes run under
    ``torch.no_grad()`` so evaluation does not build an autograd graph over
    the whole dataset.
    """
    with torch.no_grad():
        loss_val = model.loss(model(data), full_batch, targets).item()
        acc_val = (model.predict(data) == targets).float().mean().item()
    return loss_val, acc_val


def _ssagda_step(model, data, target, batch_index, centers, p, tau_1, tau_2, beta, keep_grads):
    """Perform one Smoothed-AGDA iteration on *model* in place.

    Primal step for every parameter except ``variable_y``, using the gradient
    at (x_t, y_t)::

        x_{t+1} = x_t - tau_1 * (grad_x + p * (x_t - z_t))
        z_{t+1} = z_t + beta * (x_{t+1} - z_t)

    Dual step for ``variable_y``, using a gradient recomputed at
    (x_{t+1}, y_t) so that the primal and dual updates alternate::

        y_{t+1} = pj(y_t + tau_2 * grad_y)

    Args:
        centers: maps primal parameter names to their proximal centres z_t;
            updated in place.
        keep_grads: when true, CPU copies of this step's gradients are returned.

    Returns:
        ``(x_grads, y_grad)``: the primal gradients (in ``named_parameters``
        order) and the dual gradient if *keep_grads* is true, otherwise
        ``([], None)``.
    """
    x_grads, y_grad = [], None

    # Primal descent step at (x_t, y_t).
    model.zero_grad()
    model.loss(model(data), batch_index, target).backward()
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name == 'variable_y':
                continue
            grad = param.grad
            if keep_grads:
                x_grads.append(grad.detach().to('cpu', copy=True))
            center = centers[name]
            param.sub_(tau_1 * (grad + p * (param - center)))
            # Move the proximal centre towards the *new* primal iterate.
            center.add_(beta * (param - center))

    # Dual ascent step at (x_{t+1}, y_t): the loss is re-evaluated after the
    # primal update, which is what makes the scheme alternating.
    model.zero_grad()
    model.loss(model(data), batch_index, target).backward()
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name != 'variable_y':
                continue
            grad = param.grad
            if keep_grads:
                y_grad = grad.detach().to('cpu', copy=True)
            ascended = (param + tau_2 * grad).cpu().numpy()
            # pj is handed a NumPy array; torch.as_tensor converts its result
            # back whether it returns an ndarray or a tensor.
            projected = torch.as_tensor(pj(ascended), dtype=param.dtype, device=param.device)
            param.copy_(projected)

    return x_grads, y_grad


def SSAGDA(train_set, data_name, p, tau_1, tau_2, beta, b, sim_time, max_epoch, epoch_number,
           is_show_result=False, is_save_data=False, is_save_grad_data=False, device=device):
    """Train ``Model`` on *train_set* with stochastic smoothed AGDA (SSAGDA).

    Each of *sim_time* independent simulations starts from a fresh ``Model``.
    It repeatedly reshuffles the data and runs one ``_ssagda_step`` per full
    mini-batch of size *b*; a trailing incomplete batch is dropped. After each
    pass over the data, if ``sample_complexity // epoch_number`` has
    increased, the full-data loss and accuracy are recorded. A simulation
    stops once that epoch counter reaches ``max_epoch + 10``.

    Args:
        train_set: dataset exposing ``.data`` and ``.targets`` tensors.
        data_name: sub-directory of ``./result_data`` used for output files.
        p: coefficient of the proximal term ``p * (x - z)`` in the primal step.
        tau_1: primal step size.
        tau_2: dual step size.
        beta: rate at which the proximal centre ``z`` tracks the primal iterate.
        b: mini-batch size; must lie in ``[1, len(train_set.targets)]``.
        sim_time: number of independent simulations.
        max_epoch: each simulation runs until the epoch counter reaches
            ``max_epoch + 10``.
        epoch_number: number of processed samples counted as one epoch.
        is_show_result: print progress every time metrics are recorded.
        is_save_data: after every simulation, pickle
            ``[record_SSAGDA, epoch_SSAGDA, final_acc_SSAGDA]``, where
            ``epoch_SSAGDA`` belongs to the latest simulation and aligns
            index-by-index with each loss record (index 0 = initial model).
        is_save_grad_data: at the end, pickle one uniformly sampled, paired
            (primal gradient, dual gradient) per simulation, as CPU tensors.
            If the model has a single primal parameter the primal entry is
            that tensor, otherwise a list of tensors.
        device: torch device used for the model and the data.

    Raises:
        ValueError: if *b* is outside ``[1, len(train_set.targets)]`` or
            *epoch_number* is not positive; either would otherwise make the
            training loop never terminate or divide by zero.

    NOTE(review): assumes ``Model`` has a parameter named ``variable_y`` (the
    dual variable) and methods ``loss(output, batch_index, target)`` and
    ``predict(data)``. This is inferred from usage here; confirm in model_SSAGDA.
    """
    n_samples = len(train_set.targets)
    if not 1 <= b <= n_samples:
        raise ValueError(f'batch size b must be in [1, {n_samples}], got {b}')
    if epoch_number <= 0:
        raise ValueError(f'epoch_number must be positive, got {epoch_number}')

    directory = f'./result_data/{data_name}'
    os.makedirs(directory, exist_ok=True)
    name_suffix = (f'maxepoch={max_epoch}_epochnum={epoch_number}_tau1={tau_1}'
                   f'_tau2={tau_2}_beta={beta}_p={p}_b={b}.pkl')

    # Move the whole dataset onto the target device once.
    train_data = train_set.data.to(device)
    train_targets = train_set.targets.to(device)
    full_batch = torch.arange(n_samples, device=device)
    data_size = (len(train_set.data[0]), n_samples)

    final_acc_SSAGDA, record_SSAGDA = [], []
    all_x_grad, all_y_grad = [], []

    for s in range(sim_time):
        model = Model(data_size=data_size).to(device)
        # Proximal centres z_0 = x_0. They must be separate copies: measuring
        # the proximal term against the live parameter makes p * (x - z) vanish.
        centers = {name: param.detach().clone()
                   for name, param in model.named_parameters() if name != 'variable_y'}

        # Index 0 of the loss record corresponds to epoch 0 (the initial model).
        init_loss, _ = _evaluate_full(model, train_data, train_targets, full_batch)
        record_SSAGDA_sub, acc, epoch_SSAGDA = [init_loss], [], [0]
        epoch, sample_complexity, n_steps = 0, 0, 0
        sampled_x_grad = sampled_y_grad = None

        while epoch < max_epoch + 10:
            perm = torch.randperm(n_samples, device=device)
            # Only full batches: the trailing incomplete batch is dropped.
            for batch_start in range(0, n_samples - b + 1, b):
                batch_index = perm[batch_start:batch_start + b]
                sample_complexity += b
                n_steps += 1
                # Reservoir sampling of size 1: step k replaces the kept pair
                # with probability 1/k, so the final pair is uniform over all
                # steps without storing every gradient.
                keep = np.random.randint(0, n_steps) == 0
                data = torch.index_select(train_data, 0, batch_index)
                target = torch.index_select(train_targets, 0, batch_index)
                x_grads, y_grad = _ssagda_step(model, data, target, batch_index, centers,
                                               p, tau_1, tau_2, beta, keep)
                if keep:
                    sampled_x_grad = x_grads[0] if len(x_grads) == 1 else x_grads
                    sampled_y_grad = y_grad

            if sample_complexity // epoch_number > epoch:
                epoch = sample_complexity // epoch_number
                loss_val, acc_val = _evaluate_full(model, train_data, train_targets, full_batch)
                epoch_SSAGDA.append(epoch)
                record_SSAGDA_sub.append(loss_val)
                acc.append(acc_val)
                if is_show_result:
                    print(f'Sample complexity: {sample_complexity}, Epoch: {epoch}, Accuracy: {acc_val}, Loss: {loss_val}')

        all_x_grad.append(sampled_x_grad)
        all_y_grad.append(sampled_y_grad)
        final_acc_SSAGDA.append(acc[-1] if acc else None)
        record_SSAGDA.append(record_SSAGDA_sub)
        print(f'Simulation {s + 1}/{sim_time} done: sample complexity {sample_complexity}, '
              f'epoch {epoch}, accuracy {final_acc_SSAGDA[-1]}, loss {record_SSAGDA_sub[-1]}')

        if is_save_data:
            file_path = os.path.join(directory, 'SSAGDA_' + name_suffix)
            with open(file_path, 'wb') as fp:
                pickle.dump([record_SSAGDA, epoch_SSAGDA, final_acc_SSAGDA], fp)

    if is_save_grad_data:
        grad_samples_file = os.path.join(directory, 'SSAGDA_rand_samples_' + name_suffix)
        with open(grad_samples_file, 'wb') as fp:
            pickle.dump([all_x_grad, all_y_grad], fp)






