import os
import time

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from torchvision import datasets
from torchvision import transforms

import matplotlib.pyplot as plt
from PIL import Image

from proposed import RSDM
from baselines import RGD_QR

# Ask cuDNN for deterministic algorithms when a GPU is present, so repeated
# runs of the experiments below are reproducible on CUDA.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

import pickle
import os
# Workaround flag for the Intel OpenMP "duplicate runtime" abort that can occur
# when several libraries (e.g. torch and numpy/MKL) each ship an OpenMP runtime.
# NOTE(review): this is set after torch/numpy are already imported; if the
# runtime check happens at library load time the flag may need to be exported
# before those imports (or in the shell) to take effect — confirm.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"


class DNN_cifar(nn.Module):
    """Six-layer fully connected ReLU classifier with 1200-d inputs and 10 outputs.

    The 1200-d input is presumably a flattened 20x20x3 CIFAR-10 image (the
    commented-out CIFAR transform resizes to 20x20); confirm against the loader.

    Each layer ``k`` owns ``linear{k}_weight`` of shape ``(in_features,
    out_features)`` — the transpose of the ``nn.Linear`` convention — plus
    ``linear{k}_bias``. The forward pass therefore hands ``weight.T`` to
    ``F.linear``. Hidden weights 1-5 are initialised orthogonally, the output
    weight with Xavier-uniform, and every bias is set to zero.

    Parameter names matter: the training code selects the manifold-constrained
    parameters by matching ``'weight'`` and excluding ``'linear6'``.
    """

    def __init__(self):
        super().__init__()

        widths = (1200, 1024, 1024, 1024, 1024, 1024, 10)
        for layer, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            # Weight before bias, layer by layer: keeps the registration order
            # (named_parameters / state_dict) and the torch.randn RNG draws
            # identical to spelling each attribute out by hand. The random
            # values are placeholders overwritten by _initialize_weights.
            self.register_parameter(f'linear{layer}_weight',
                                    nn.Parameter(torch.randn(n_in, n_out)))
            self.register_parameter(f'linear{layer}_bias',
                                    nn.Parameter(torch.randn(n_out)))

        self._initialize_weights()

    def _initialize_weights(self):
        """Orthogonal hidden weights, Xavier-uniform output weight, zero biases."""
        for layer in range(1, 6):
            nn.init.orthogonal_(getattr(self, f'linear{layer}_weight'))
        nn.init.xavier_uniform_(self.linear6_weight)

        for layer in range(1, 7):
            nn.init.constant_(getattr(self, f'linear{layer}_bias'), 0)

    def forward(self, x):
        """Apply five ReLU hidden layers and a linear output layer (logits)."""
        hidden = x
        for layer in range(1, 6):
            weight = getattr(self, f'linear{layer}_weight')
            bias = getattr(self, f'linear{layer}_bias')
            hidden = F.relu(F.linear(hidden, weight.T, bias))
        return F.linear(hidden, self.linear6_weight.T, self.linear6_bias)

class DNN(nn.Module):
    """Six-layer fully connected ReLU classifier with 1024-d inputs and 10 outputs.

    The 1024-d input is presumably a flattened 32x32 single-channel MNIST image
    (the commented-out MNIST transform resizes to 32x32); confirm against the
    loader.

    Layer ``k`` stores ``linear{k}_weight`` as an ``(in_features,
    out_features)`` matrix and ``linear{k}_bias``. The forward pass feeds
    ``weight.T`` to ``F.linear``. The five hidden weights start orthogonal, the
    output weight Xavier-uniform, and all biases at zero.

    Keep the attribute names: the training code picks the constrained parameters
    by name (``'weight'`` present, ``'linear6'`` absent).
    """

    def __init__(self):
        super().__init__()

        widths = (1024, 1024, 1024, 1024, 1024, 1024, 10)
        for layer, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            # Register weight then bias for each layer in turn so parameter
            # order and RNG consumption match explicit attribute assignment.
            # The torch.randn values are overwritten by _initialize_weights.
            self.register_parameter(f'linear{layer}_weight',
                                    nn.Parameter(torch.randn(n_in, n_out)))
            self.register_parameter(f'linear{layer}_bias',
                                    nn.Parameter(torch.randn(n_out)))

        self._initialize_weights()

    def _initialize_weights(self):
        """Orthogonal hidden weights, Xavier-uniform output weight, zero biases."""
        for layer in range(1, 6):
            nn.init.orthogonal_(getattr(self, f'linear{layer}_weight'))
        nn.init.xavier_uniform_(self.linear6_weight)

        for layer in range(1, 7):
            nn.init.constant_(getattr(self, f'linear{layer}_bias'), 0)

    def forward(self, x):
        """Apply five ReLU hidden layers and a linear output layer (logits)."""
        hidden = x
        for layer in range(1, 6):
            weight = getattr(self, f'linear{layer}_weight')
            bias = getattr(self, f'linear{layer}_bias')
            hidden = F.relu(F.linear(hidden, weight.T, bias))
        return F.linear(hidden, self.linear6_weight.T, self.linear6_bias)



def evaluate(net=None, loader=None, dev=None):
    """Return the top-1 accuracy, in percent, of a classifier over a data loader.

    Called with no arguments it behaves as before and uses the module-level
    ``model``, ``test_loader`` and ``device`` globals (created by the training
    code under ``__main__``). Passing ``net`` / ``loader`` / ``dev`` overrides
    the corresponding global.

    The network's train/eval mode is restored on exit. The training loop calls
    ``model.train()`` only at the start of each epoch but calls this function
    every 100 steps, so leaving the model in eval mode would make the remaining
    steps of the epoch train in eval mode (relevant for any mode-dependent
    layers such as dropout or batch norm).

    Args:
        net: classifier returning per-class scores of shape ``(batch, classes)``;
            defaults to the global ``model``.
        loader: iterable of ``(data, targets)`` batches; defaults to the global
            ``test_loader``.
        dev: device to move each batch to; defaults to the global ``device``.

    Returns:
        Accuracy as a float in ``[0, 100]``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if net is None:
        net = model
    if loader is None:
        loader = test_loader
    if dev is None:
        dev = device

    was_training = net.training
    net.eval()
    num_correct = 0
    num_samples = 0
    try:
        with torch.no_grad():
            for data, targets in loader:
                data = data.to(dev)
                targets = targets.to(dev)

                scores = net(data)
                _, predictions = torch.max(scores, 1)
                num_correct += int((predictions == targets).sum().item())
                num_samples += predictions.size(0)
    finally:
        # Put the network back in whatever mode the caller had it in.
        net.train(was_training)

    if num_samples == 0:
        raise ValueError('evaluate(): the loader yielded no samples')
    return num_correct / num_samples * 100



def load_results(method, dataset, results_dir='results'):
    """Load one method's pickled DNN results and summarise them over repetitions.

    Reads ``<results_dir>/<method>_dnn_<dataset>.pkl`` (the file name written by
    the commented-out training loop in the ``__main__`` section). The pickle is
    expected to hold a dict with ``'loss'``, ``'time'`` and ``'acc'`` entries,
    each array-like with repetitions along axis 0 and logging steps along
    axis 1 — TODO confirm for result files produced elsewhere.

    Security: ``pickle.load`` can execute arbitrary code; only load result
    files you produced yourself.

    Args:
        method: optimiser tag used in the file name, e.g. ``'RGD'``.
        dataset: dataset tag used in the file name, e.g. ``'mnist'``.
        results_dir: directory holding the pickles.

    Returns:
        dict with ``'loss_mean'``, ``'loss_std'``, ``'time_mean'``,
        ``'acc_mean'`` and ``'acc_std'``, each reduced over axis 0.
    """
    path = os.path.join(results_dir, f'{method}_dnn_{dataset}.pkl')
    with open(path, 'rb') as f:
        raw = pickle.load(f)

    loss = np.asarray(raw['loss'])
    step_time = np.asarray(raw['time'])
    acc = np.asarray(raw['acc'])
    return {
        'loss_mean': loss.mean(axis=0),
        'loss_std': loss.std(axis=0),
        'time_mean': step_time.mean(axis=0),
        'acc_mean': acc.mean(axis=0),
        'acc_std': acc.std(axis=0),
    }


if __name__ == '__main__':

    DATASET = 'mnist'
    RANDOM_SEED = 42
    LEARNING_RATE = 0.1
    BATCH_SIZE = 16
    NUM_EPOCHS = 2

    # Apply the declared seed. RANDOM_SEED used to be defined but never used,
    # so re-enabling the training section below produced unreproducible runs
    # even though cudnn.deterministic is requested at import time.
    torch.manual_seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(device)

    # if DATASET == 'mnist':
    #     transform = transforms.Compose([
    #         transforms.Resize((32, 32)),
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.5,), (0.5,)),
    #         transforms.Lambda(lambda x: x.view(-1))
    #     ])
    #
    #     train_dataset = datasets.MNIST(root='data',
    #                                    train=True,
    #                                    transform=transform,
    #                                    download=True)
    #
    #     test_dataset = datasets.MNIST(root='data',
    #                                   train=False,
    #                                   transform=transform)
    #
    #     train_loader = DataLoader(dataset=train_dataset,
    #                               batch_size=BATCH_SIZE,
    #                               shuffle=True)
    #
    #     test_loader = DataLoader(dataset=test_dataset,
    #                              batch_size=BATCH_SIZE,
    #                              shuffle=False)
    # elif DATASET == 'cifar':
    #     transform = transforms.Compose([
    #         transforms.Resize((20, 20)),
    #         transforms.ToTensor(),
    #         transforms.Normalize((0.5,), (0.5,)),
    #         transforms.Lambda(lambda x: x.view(-1))
    #     ])
    #
    #     train_dataset = datasets.CIFAR10(root='data',
    #                                    train=True,
    #                                    transform=transform,
    #                                    download=True)
    #
    #     test_dataset = datasets.CIFAR10(root='data',
    #                                   train=False,
    #                                   transform=transform)
    #
    #     train_loader = DataLoader(dataset=train_dataset,
    #                               batch_size=BATCH_SIZE,
    #                               shuffle=True)
    #
    #     test_loader = DataLoader(dataset=test_dataset,
    #                              batch_size=BATCH_SIZE,
    #                              shuffle=False)
    #
    #
    # method = 'RSDM-P'
    #
    # num_rep = 5
    #
    # loss_all = []
    # acc_all = []
    # time_all = []
    #
    # for rep in range(num_rep):
    #
    #     loss_rep = []
    #     acc_rep = []
    #     time_rep = []
    #
    #     # set up model and optimizer
    #     if DATASET == 'mnist':
    #         model = DNN()
    #         model = model.to(device)
    #     elif DATASET == 'cifar':
    #         model = DNN_cifar()
    #         model = model.to(device)
    #
    #     # Loss and optimizer
    #     criterion = nn.CrossEntropyLoss()
    #
    #     stiefel_params = []
    #     other_params = []
    #
    #     for name, param in model.named_parameters():
    #         if 'weight' in name and 'linear6' not in name:
    #             stiefel_params.append(param)
    #         else:
    #             other_params.append(param)
    #
    #     print(method)
    #
    #     if method == 'RSDM-P':
    #         if DATASET == 'mnist':
    #             optimizer_stiefel = RSDM(stiefel_params, lr=0.01, r=250, use_permutation=True)
    #         elif DATASET == 'cifar':
    #             optimizer_stiefel = RSDM(stiefel_params, lr=0.01, r=250, use_permutation=True)
    #     elif method == 'RSDM-O':
    #         optimizer_stiefel = RSDM(stiefel_params, lr=0.01, r=300, use_permutation=False)
    #     elif method == 'RGD':
    #         optimizer_stiefel = RGD_QR(stiefel_params, lr=0.01)
    #     optimizer_other = torch.optim.SGD(other_params, lr=LEARNING_RATE)
    #
    #     start_time = time.time()
    #     for epoch in range(NUM_EPOCHS):
    #         model.train()
    #         for batch_idx, (data, targets) in enumerate(train_loader):
    #             data = data.to(device)
    #             targets = targets.to(device)
    #
    #             scores = model(data)
    #             loss = criterion(scores, targets)
    #
    #             optimizer_stiefel.zero_grad()
    #             optimizer_other.zero_grad()
    #             loss.backward()
    #
    #             optimizer_stiefel.step()
    #             optimizer_other.step()
    #
    #             if batch_idx % 100 == 0:
    #                 elapsed_time = time.time() - start_time
    #                 acc = evaluate()
    #                 print(f'Epoch [{epoch + 1}/{NUM_EPOCHS}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}, Test ACC: {acc:.2f}%, '
    #                       f'Time: {elapsed_time:.4f}')
    #
    #                 loss_rep.append(loss.item())
    #                 time_rep.append(elapsed_time)
    #                 acc_rep.append(acc)
    #
    #                 start_time = time.time()
    #
    #     loss_all.append(loss_rep)
    #     time_all.append(time_rep)
    #     acc_all.append(acc_rep)
    #
    # loss_all = np.array(loss_all)
    # time_all = np.array(time_all)
    # acc_all = np.array(acc_all)
    #
    # arrays = {'loss': loss_all, 'time': time_all, 'acc': acc_all }
    #
    # # Save the arrays to a pickle file
    # with open(f'results/{method}_dnn_{DATASET}.pkl', 'wb') as f:
    #     pickle.dump(arrays, f)


    ##
    # Mean / std over repetitions for each optimiser. The variable names are
    # the ones the (commented-out) time-axis plots below expect.
    rsdm_stats = load_results('RSDM-P', DATASET)
    loss_rsdm_mean = rsdm_stats['loss_mean']
    loss_rsdm_std = rsdm_stats['loss_std']
    time_rsdm_mean = rsdm_stats['time_mean']
    acc_rsdm_mean = rsdm_stats['acc_mean']
    acc_rsdm_std = rsdm_stats['acc_std']

    rgd_stats = load_results('RGD', DATASET)
    loss_rgd_mean = rgd_stats['loss_mean']
    loss_rgd_std = rgd_stats['loss_std']
    time_rgd_mean = rgd_stats['time_mean']
    acc_rgd_mean = rgd_stats['acc_mean']
    acc_rgd_std = rgd_stats['acc_std']


    # # loss
    # plt.figure(figsize=(5.5, 4.5))
    # plt.plot(time_rgd_mean.cumsum(), loss_rgd_mean, label='RGD', color='tab:blue')
    # plt.fill_between(time_rgd_mean.cumsum(), loss_rgd_mean - loss_rgd_std,
    #                  loss_rgd_mean + loss_rgd_std,
    #                  color='tab:blue', alpha=0.3)
    # plt.plot(time_rsdm_mean.cumsum(), loss_rsdm_mean, label='RSDM-P', color='tab:purple')
    # plt.fill_between(time_rsdm_mean.cumsum(), loss_rsdm_mean - loss_rsdm_std,
    #                  loss_rsdm_mean + loss_rsdm_std,
    #                  color='tab:purple', alpha=0.3)
    # plt.legend(loc=7, prop={'size': 16})
    # plt.xticks(fontsize=13)
    # plt.yticks(fontsize=13)
    # plt.xlabel("Time", fontsize=20)
    # plt.ylabel("Loss", fontsize=20)
    # plt.tight_layout()
    # plt.savefig(f'dnn_loss_{DATASET}.pdf', bbox_inches='tight')
    # plt.close()
    #
    #
    # # time
    # plt.figure(figsize=(5.5, 4.5))
    # plt.plot(time_rgd_mean.cumsum(), acc_rgd_mean, label='RGD', color='tab:blue')
    # plt.fill_between(time_rgd_mean.cumsum(), acc_rgd_mean - acc_rgd_std,
    #                  acc_rgd_mean + acc_rgd_std,
    #                  color='tab:blue', alpha=0.3)
    # plt.plot(time_rsdm_mean.cumsum(), acc_rsdm_mean, label='RSDM-P', color='tab:purple')
    # plt.fill_between(time_rsdm_mean.cumsum(), acc_rsdm_mean - acc_rsdm_std,
    #                  acc_rsdm_mean + acc_rsdm_std,
    #                  color='tab:purple', alpha=0.3)
    # plt.legend(loc=7, prop={'size': 16}, bbox_to_anchor=(1, 0.7))
    # # plt.legend(loc='best')
    # plt.xticks(fontsize=13)
    # if DATASET == 'mnist':
    #     plt.xlim([0, 100])
    #     plt.yticks([10, 30, 50, 70, 90], fontsize=13)
    # elif DATASET == 'cifar':
    #     plt.xlim([0, 80])
    # plt.xlabel("Time", fontsize=20)
    # plt.ylabel("ACC (%)", fontsize=20)
    #
    #
    # # plt.yscale('log')
    # # plot zoomed in
    # if DATASET == 'mnist':
    #     # plt.arrow(10, 80, 12, -18, head_width=1, width=0.7, head_length=2, fc='tab:red', ec='tab:red')
    #     inset_axes = plt.axes((0.3, 0.2, 0.3, 0.35))
    #     inset_axes.plot(time_rgd_mean.cumsum(), acc_rgd_mean, color='tab:blue')
    #     inset_axes.fill_between(time_rgd_mean.cumsum(), acc_rgd_mean - acc_rgd_std,
    #                      acc_rgd_mean + acc_rgd_std,
    #                      color='tab:blue', alpha=0.3)
    #     inset_axes.plot(time_rsdm_mean.cumsum(), acc_rsdm_mean, color='tab:purple')
    #     inset_axes.fill_between(time_rsdm_mean.cumsum(), acc_rsdm_mean - acc_rsdm_std,
    #                      acc_rsdm_mean + acc_rsdm_std,
    #                      color='tab:purple', alpha=0.3)
    #     # Set the axis limits for the zoomed-in region
    #     inset_axes.set_xlim(0, 25)  # Zoom in on the x-axis from 2 to 6
    #     # # inset_axes.set_ylim(-0.5, 0.5)  # Zoom in on the y-axis from -0.5 to 0.5
    #
    # elif DATASET == 'cifar':
    #     inset_axes = plt.axes((0.3, 0.2, 0.3, 0.35))
    #     inset_axes.plot(time_rgd_mean.cumsum(), acc_rgd_mean, color='tab:blue')
    #     inset_axes.fill_between(time_rgd_mean.cumsum(), acc_rgd_mean - acc_rgd_std,
    #                             acc_rgd_mean + acc_rgd_std,
    #                             color='tab:blue', alpha=0.3)
    #     inset_axes.plot(time_rsdm_mean.cumsum(), acc_rsdm_mean, color='tab:purple')
    #     inset_axes.fill_between(time_rsdm_mean.cumsum(), acc_rsdm_mean - acc_rsdm_std,
    #                             acc_rsdm_mean + acc_rsdm_std,
    #                             color='tab:purple', alpha=0.3)
    #     # Set the axis limits for the zoomed-in region
    #     inset_axes.set_xlim(0, 25)  # Zoom in on the x-axis from 2 to 6
    #
    #
    #
    #
    # # plt.tight_layout()
    # plt.savefig(f'dnn_acc_{DATASET}.pdf', bbox_inches='tight')
    # plt.close()



    # acc in iteration: mean accuracy with a +/- one-std band per optimiser.
    # NOTE(review): each x step is one logged evaluation (every 100 training
    # batches in the training loop above), not a single optimiser iteration.
    plt.figure(figsize=(5.5, 4.5))
    for label, color, mean, std in (
        ('RGD', 'tab:blue', acc_rgd_mean, acc_rgd_std),
        ('RSDM-P', 'tab:purple', acc_rsdm_mean, acc_rsdm_std),
    ):
        steps = np.arange(len(mean))
        plt.plot(steps, mean, label=label, color=color)
        plt.fill_between(steps, mean - std, mean + std, color=color, alpha=0.3)
    plt.legend(loc=7, prop={'size': 16})  # loc=7 is 'center right'
    plt.xticks(fontsize=13)
    plt.yticks(fontsize=13)
    plt.xlabel("Iteration", fontsize=20)
    plt.ylabel("ACC (%)", fontsize=20)
    plt.tight_layout()
    plt.savefig(f'dnn_acc_{DATASET}_iter.pdf', bbox_inches='tight')
    plt.close()





