import torch
import numpy as np
import pickle
import os
from Optimization_Method import projection_simplex_sort as pj
from model_SSAGDA import Model

# Select the compute device once at import time and report which GPU (if any)
# will be used; `device` is also the default device of SSAGDA below.
if torch.cuda.is_available():
    device = 'cuda'
    print(torch.cuda.get_device_name(0))
else:
    device = 'cpu'
    print("CUDA not available")

def _full_batch_metrics(model, data, targets, full_batch):
    """Return ``(loss, accuracy)`` of *model* on the whole training set as floats.

    Runs under ``torch.no_grad()`` because only the scalar values are kept;
    tracking gradients over the full dataset would only waste memory.
    """
    with torch.no_grad():
        loss_val = model.loss(model(data), full_batch, targets).item()
        acc_val = (model.predict(data) == targets).float().mean().item()
    return loss_val, acc_val


def SSAGDA(train_set, data_name, p, tau_1, tau_2, beta, b, sim_time, max_epoch, epoch_number,
           is_show_result=False, is_save_data=False, is_save_grad_data=False, device=device):
    """Run ``sim_time`` independent SSAGDA training runs on ``train_set``.

    Each step draws a mini-batch without replacement. It takes a descent step
    on every parameter except ``variable_y``, regularised by ``p * (x - z)``
    toward a proximal center ``z``. The center then moves toward the new
    iterate at rate ``beta``. Finally it takes an ascent step on
    ``variable_y`` followed by a projection with ``pj``.

    Args:
        train_set: object exposing ``.data`` and ``.targets`` tensors
            (assumed indexable along dim 0 — confirm against callers).
        data_name: results are written under ``./result_data/<data_name>``.
        p: weight of the proximal term in the primal step.
        tau_1: primal (descent) step size.
        tau_2: dual (ascent) step size for ``variable_y``.
        beta: rate at which each proximal center tracks its iterate.
        b: mini-batch size; must satisfy ``0 < b <= len(train_set.targets)``.
        sim_time: number of independent runs (fresh ``Model`` each run).
        max_epoch: each run continues until ``epoch >= max_epoch + 10``.
        epoch_number: number of processed samples that counts as one epoch.
        is_show_result: print loss/accuracy at every evaluation.
        is_save_data: after every run, pickle the loss records, ``[0]`` and
            the final accuracies collected so far.
        is_save_grad_data: at the end, pickle one uniformly sampled
            (primal gradient, projected dual gradient) pair per run.
        device: torch device for the data and the model.

    Raises:
        ValueError: if ``b`` or ``epoch_number`` is not positive, or ``b``
            exceeds the training-set size (no batch could ever be drawn and
            the training loop would never terminate).
    """
    n_samples = len(train_set.targets)
    if b <= 0:
        raise ValueError(f'batch size b must be positive, got {b}')
    if epoch_number <= 0:
        raise ValueError(f'epoch_number must be positive, got {epoch_number}')
    if b > n_samples:
        raise ValueError(
            f'batch size b={b} exceeds the training-set size {n_samples}; '
            'no full batch could be drawn and training would never end')

    directory = f'./result_data/{data_name}'
    os.makedirs(directory, exist_ok=True)
    final_acc_SSAGDA, record_SSAGDA = [], []
    all_x_grad, all_y_grad = [], []

    # Move the whole dataset to the device once; batches are index_select-ed.
    train_data_gpu = train_set.data.to(device)
    train_targets_gpu = train_set.targets.to(device)
    full_batch = torch.arange(n_samples, device=device)

    for s in range(sim_time):
        print(s)
        model = Model(data_size=(len(train_set.data[0]), n_samples)).to(device)
        record_SSAGDA_sub, acc = [], []
        # Never updated; saved as-is to keep the pickle layout unchanged.
        epoch_SSAGDA = [0]
        epoch, sample_complexity = 0, 0

        # One proximal center per primal parameter, initialised at the starting
        # iterate. They must be tensors separate from the parameters: if z were
        # the parameter itself, p * (x - z) and the z update would always be 0.
        prox_centers = {name: param.detach().clone()
                        for name, param in model.named_parameters()
                        if name != 'variable_y'}

        # Size-1 reservoir sampling. Only one uniformly random step's gradients
        # are kept, instead of storing every step's gradient on the device.
        n_steps = 0
        sampled_x_grad, sampled_y_grad = None, None

        while epoch < max_epoch + 10:
            perm = torch.randperm(n_samples, device=device)
            # Only full batches are used; a trailing partial batch is dropped.
            for batch_start in range(0, n_samples - b + 1, b):
                batch_index = perm[batch_start:batch_start + b]
                sample_complexity += b

                data = torch.index_select(train_data_gpu, 0, batch_index)
                target = torch.index_select(train_targets_gpu, 0, batch_index)

                model.zero_grad()
                loss = model.loss(model(data), batch_index, target)
                loss.backward()

                # All gradients come from the same backward pass, so the primal
                # and dual updates below are simultaneous.
                x_grads, y_grad = [], None
                for name, param in model.named_parameters():
                    grad = param.grad
                    if name != 'variable_y':
                        x_grads.append(grad)
                        z = prox_centers[name]
                        param.data -= tau_1 * (grad + p * (param.data - z))
                        z += beta * (param.data - z)
                    else:
                        y_grad = grad
                        param.data += tau_2 * grad
                        param.data = pj(param.data.cpu()).to(device)

                n_steps += 1
                if np.random.randint(0, n_steps) == 0:
                    # Clone so that later zero_grad/backward calls cannot alter
                    # the snapshot.
                    if len(x_grads) == 1:
                        sampled_x_grad = x_grads[0].detach().clone()
                    else:
                        sampled_x_grad = [g.detach().clone() for g in x_grads]
                    sampled_y_grad = (None if y_grad is None
                                      else pj(y_grad.detach().clone().cpu()))

            if sample_complexity // epoch_number > epoch:
                epoch = sample_complexity // epoch_number
                loss_val, acc_val = _full_batch_metrics(
                    model, train_data_gpu, train_targets_gpu, full_batch)
                record_SSAGDA_sub.append(loss_val)
                acc.append(acc_val)
                if is_show_result:
                    print(f'Sample complexity: {sample_complexity}, Epoch: {epoch}, Accuracy: {acc_val}, Loss: {loss_val}')

        all_x_grad.append(sampled_x_grad)
        all_y_grad.append(sampled_y_grad)

        final_acc_SSAGDA.append(acc[-1])
        record_SSAGDA.append(record_SSAGDA_sub)
        if is_save_data:
            # Rewritten after every run, so it always holds all runs so far.
            file_path = os.path.join(directory, f'SSAGDA_maxepoch={max_epoch}_epochnum={epoch_number}_tau1={tau_1}_tau2={tau_2}_beta={beta}_p={p}_b={b}.pkl')
            with open(file_path, 'wb') as fp:
                pickle.dump([record_SSAGDA, epoch_SSAGDA, final_acc_SSAGDA], fp)

    if is_save_grad_data:
        all_grad_samples_file = os.path.join(directory, f'TWENTY_SSAGDA_ALL_grad_sample_simtime={sim_time}_maxepoch={max_epoch}_epochnum={epoch_number}_tau1={tau_1}_tau2={tau_2}_beta={beta}_p={p}_b={b}.pkl')
        with open(all_grad_samples_file, 'wb') as fp:
            pickle.dump([all_x_grad, all_y_grad], fp)