from ctypes.wintypes import MAX_PATH
from typing import final
import torch
import numpy as np
import copy
import pickle
import argparse
import os
from model_SSAGDA import Model
from dataclass import Creatdata

# Select the GPU when one is available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device == 'cuda':
    # CUDA-only calls: querying the device name raises when no GPU is present, so
    # these must not run on CPU-only machines. The returned name is unused
    # (presumably just a sanity check that device 0 is reachable).
    torch.cuda.get_device_name(0)
    # Release any cached, unused GPU memory held by the caching allocator.
    torch.cuda.empty_cache()

def _cpu_snapshot(model, data_size):
    """Return a fresh CPU ``Model`` loaded with a copy of *model*'s current state dict."""
    cpu_model = Model(data_size=data_size)
    cpu_model.load_state_dict(copy.deepcopy(model.state_dict()))
    return cpu_model


def _full_train_loss(cpu_model, train_set):
    """Return ``cpu_model.testloss`` over the whole training set, as a NumPy value.

    Every sample index is passed as the batch, so the loss covers the complete
    training set; all tensors are moved to the CPU first.
    """
    full_batch = torch.arange(len(train_set.targets))
    data_cpu = train_set.data.to('cpu')
    targets_cpu = train_set.targets.to('cpu')
    return cpu_model.testloss(cpu_model.forward(data_cpu), full_batch, targets_cpu).detach().numpy()


def _full_train_accuracy(cpu_model, train_set):
    """Return the fraction of training samples ``cpu_model.predict`` gets right (0-dim tensor)."""
    correct = cpu_model.predict(train_set.data.to('cpu')) == train_set.targets.to('cpu')
    return torch.sum(correct) / len(train_set.data)


def _pickle_to(save_dir, data_name, file_stem, payload):
    """Pickle *payload* to ``save_dir/data_name/file_stem``, creating the directory if needed."""
    directory = os.path.join(save_dir, data_name)
    os.makedirs(directory, exist_ok=True)
    with open(os.path.join(directory, file_stem), "wb") as fp:
        pickle.dump(payload, fp)


# Main function of Stochastic Smoothed AGDA (SSAGDA)
def SSAGDA(train_set, data_name, p, tau_1, tau_2, beta, T, sigma,
           b, sim_time, max_epoch, epoch_number,
           is_show_result=False, is_save_data=False, is_save_grad_data=False, device='cuda',
           save_dir='/home/yas33/StochSmoothedAGDA/DRO/result_data'):
    """Run ``sim_time`` independent simulations of Stochastic Smoothed AGDA.

    Each inner step draws a minibatch of ``b`` indices from a random permutation
    (without replacement) and performs an alternating update:

    * primal (every parameter not named ``variable_y``): a descent step of size
      ``tau_1`` on ``grad + p * (x - z)``, i.e. the loss gradient plus the
      gradient of the proximal term ``p/2 * ||x - z||^2``; the proximal centre
      ``z`` is then moved toward the new iterate, ``z <- z + beta * (x_new - z)``;
    * dual (``variable_y``): an ascent step of size ``tau_2`` using the gradient
      recomputed at the *updated* primal point.

    An outer loop runs ``T`` inner steps from a freshly shuffled permutation.
    Whenever ``sample_complexity // epoch_number`` increases, the full-training-set
    loss and accuracy are evaluated on a CPU copy of the model. A simulation
    stops once that epoch counter reaches ``max_epoch + 10``.

    Args:
        train_set: dataset exposing ``.data`` / ``.targets`` tensors and ``len()``.
            Assumed to live on ``device`` (they are ``index_select``-ed with
            on-device indices) — TODO confirm.
        data_name: sub-directory name under ``save_dir`` for saved results.
        p: proximal coefficient.
        tau_1: primal step size.
        tau_2: dual step size.
        beta: averaging rate of the proximal centre.
        T: inner steps per outer loop; gradient samples are drawn from the first T.
        sigma: unused; kept for interface compatibility.
        b: minibatch size, ``1 <= b <= len(train_set)``.
        sim_time: number of independent simulations.
        max_epoch: a simulation stops once its epoch counter reaches ``max_epoch + 10``.
        epoch_number: number of drawn samples that counts as one epoch.
        is_show_result: print progress at every recorded epoch.
        is_save_data: pickle ``[record_SSAGDA, epoch_SSAGDA, final_acc_SSAGDA]``
            after every simulation.
        is_save_grad_data: pickle the sampled primal/dual gradients at the end.
        device: torch device the training model lives on.
        save_dir: root directory for pickled results.

    Raises:
        ValueError: if ``T < 1`` or ``b`` is not in ``[1, len(train_set)]``
            (either would otherwise make the inner loop spin forever).
    """
    if T < 1:
        raise ValueError(f'T must be a positive integer, got {T}')
    n_samples = len(train_set)
    if not 1 <= b <= n_samples:
        raise ValueError(f'batch size b must be in [1, {n_samples}], got {b}')

    # Results across simulations.
    final_acc_SSAGDA = []
    record_SSAGDA = []
    acc = []  # accuracy at every recorded epoch of every simulation (0-dim tensors)
    data_size = (len(train_set.data[0]), len(train_set.targets))
    run_tag = f'sim_time={sim_time}_T={T}_tau_1={tau_1}_tau_2={tau_2}_beta={beta}_p={p}_b={b}'

    # One randomly chosen stored gradient per simulation.
    all_x_grad_samples = []
    all_y_grad_samples = []

    for s in range(sim_time):
        print(f'Starting simulation {s}')
        start_model = Model(data_size=data_size).to(device)
        epoch_SSAGDA = []
        record_SSAGDA_sub = []
        epoch, sample_complexity = 0, 0

        # Train on a copy of the freshly initialised start model.
        test1 = Model(data_size=data_size).to(device)
        test1.load_state_dict(copy.deepcopy(start_model.state_dict()))
        epoch_SSAGDA.append(0)
        record_SSAGDA_sub.append(_full_train_loss(_cpu_snapshot(test1, data_size), train_set))

        named = list(test1.named_parameters())
        primal_params = [param for name, param in named if name != 'variable_y']
        dual_params = [param for name, param in named if name == 'variable_y']
        # Proximal centres: one detached copy per primal parameter, aligned by position.
        Z0 = [param.detach().clone() for param in primal_params]

        # Only indices < T are ever sampled below, so store at most T gradients
        # (storing every iteration's gradient would grow memory without bound).
        x_iterates_grad = []
        y_iterates_grad = []

        # Outer loop
        while True:
            inner_iter = 0
            batch_start = 0
            data_loader_dumb = torch.randperm(n_samples).to(device)

            while inner_iter < T:
                # Reshuffle when the remainder of the permutation cannot form a
                # full batch; those leftover samples are dropped.
                if batch_start + b > len(data_loader_dumb):
                    data_loader_dumb = torch.randperm(n_samples).to(device)
                    batch_start = 0
                    continue
                batch_index = data_loader_dumb[batch_start:batch_start + b]
                batch_start += b
                sample_complexity += b

                data = torch.index_select(train_set.data, 0, index=batch_index)
                target = torch.index_select(train_set.targets, 0, index=batch_index)

                # Primal step at (x_t, y_t, z_t). Done under no_grad so neither the
                # parameters nor the proximal centres accumulate autograd history.
                test1.zero_grad()
                test1.loss(test1.forward(data), batch_index, target).backward()
                with torch.no_grad():
                    new_centres = []
                    for param, z0 in zip(primal_params, Z0):
                        grad = param.grad
                        new_val = param - tau_1 * (grad + p * (param - z0))
                        new_centres.append(z0 + beta * (new_val - z0))
                        param.copy_(new_val)
                        if len(x_iterates_grad) < T:
                            # Clone: zero_grad may reuse/zero the .grad tensor in place.
                            x_iterates_grad.append(grad.detach().clone())
                    Z0 = new_centres

                # Dual ascent step using the gradient at (x_{t+1}, y_t).
                test1.zero_grad()
                test1.loss(test1.forward(data), batch_index, target).backward()
                with torch.no_grad():
                    for param in dual_params:
                        grad = param.grad
                        param.add_(tau_2 * grad)
                        if len(y_iterates_grad) < T:
                            y_iterates_grad.append(grad.detach().clone())

                inner_iter += 1

            if sample_complexity // epoch_number > epoch:
                epoch = sample_complexity // epoch_number
                epoch_SSAGDA.append(epoch)
                cpu_test = _cpu_snapshot(test1, data_size)
                record_SSAGDA_sub.append(_full_train_loss(cpu_test, train_set))
                acc.append(_full_train_accuracy(cpu_test, train_set))
                if is_show_result:
                    print('sample complexity is', sample_complexity, ', epoch is', epoch, ', acc is', acc[-1], ', loss is', record_SSAGDA_sub[-1])

                # NOTE(review): runs 10 epochs past max_epoch; confirm the margin is intended.
                if epoch >= max_epoch + 10:
                    break

        # NOTE(review): randint(0, T) only reaches the gradients of the first outer
        # loop; confirm this (rather than sampling over the whole run) is intended.
        random_index = np.random.randint(0, T)
        all_x_grad_samples.append(x_iterates_grad[random_index])
        all_y_grad_samples.append(y_iterates_grad[random_index])

        print('')
        print('Simulation time ', s + 1, ' is done.....')
        print('sample complexity is', sample_complexity, ', epoch is', epoch, ', acc is', acc[-1], ', loss is', record_SSAGDA_sub[-1])
        print('')
        record_SSAGDA.append(record_SSAGDA_sub)
        final_acc_SSAGDA.append(acc[-1])

        if is_save_data:
            # Rewritten after every simulation, so partial results survive an interruption.
            _pickle_to(save_dir, data_name, f'SSAGDA_{run_tag}',
                       [record_SSAGDA, epoch_SSAGDA, final_acc_SSAGDA])

    if is_save_grad_data:
        _pickle_to(save_dir, data_name, f'SSAGDA_rand_samples_{run_tag}',
                   [all_x_grad_samples, all_y_grad_samples])