import torch
from torch.utils.data import Subset, DataLoader, Dataset, random_split
import os
import argparse
import numpy as np
import random
from cuda_selector import auto_cuda
from tqdm import tqdm


# Command-line interface.
# The script loads MNIST corner-crop data and MNIST checkpoints, so the
# description and help strings refer to MNIST. Every option carries a help
# string so that ArgumentDefaultsHelpFormatter prints its default in --help.
parser = argparse.ArgumentParser(
    description="MNIST corner-crop WAE evaluation",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--data_dir", type=str, default="dataset",
                    help="root directory of the MNIST data")
parser.add_argument("--save_dir", type=str, default="models/",
                    help="directory to save models")
parser.add_argument("--seed", type=int, default=12345678, help="random seed")
parser.add_argument("--ridge_lambda", type=float, default=0.1,
                    help="ridge_lambda value passed to estimate_dW1 / estimate_dW2")
parser.add_argument("--z_dim", type=int, default=20,
                    help="latent dimension of the encoder (also part of the checkpoint file names)")
parser.add_argument("--batch_size", type=int, default=64,
                    help="batch size of the data loaders used to encode the images")
parser.add_argument("--target_class", type=int, default=0,
                    help="target class passed to the data loader (also part of the checkpoint/result file names)")
parser.add_argument("--SGD_batch_size", type=int, default=32,
                    help="batch size passed to update_SGD")
parser.add_argument("--SGD_epochs", type=int, default=100,
                    help="number of epochs passed to update_SGD")
parser.add_argument("--SGD_lr", type=float, default=0.01,
                    help="learning rate passed to update_SGD")

args = parser.parse_args()

# -------------------- Seed ------------------------------
# Environment switches are set first so they are in place before any CUDA /
# cuBLAS work is triggered further down the script.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# NOTE: assigning PYTHONHASHSEED at runtime does not change string hashing in
# the interpreter that is already running; it is only inherited by child
# processes started from here.
os.environ['PYTHONHASHSEED'] = str(args.seed)

# Seed every random number generator the script may touch. Each generator is
# seeded exactly once; re-seeding with the same value would be a no-op.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)  # cpu
torch.cuda.manual_seed_all(args.seed)  # every visible GPU

# Ask PyTorch / cuDNN for deterministic kernels.
# torch.backends.cudnn.enabled = False
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

cuda_available = torch.cuda.is_available()

save_dir = args.save_dir
data_root = args.data_dir

# ------------------- Output directory -------------------------------
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(save_dir, exist_ok=True)

# ------------------- Select Device -------------------------------
if cuda_available:
    # Parse the full index after the last ':' instead of only the final
    # character, so indices >= 10 are handled as well.
    # NOTE(review): assumes auto_cuda() returns a string such as 'cuda:<idx>'
    # (or a bare index string) -- confirm against cuda_selector.
    device_id = int(auto_cuda().rsplit(':', 1)[-1])
    torch.cuda.set_device(device_id)
    device = torch.device('cuda')
    print('Using cuda:{}'.format(torch.cuda.current_device()))
else:
    # Use a torch.device on both branches so `device` always has one type.
    device = torch.device('cpu')
    print('Using CPU')

# ------------------- Load Parameters -------------------------------
z_dim = args.z_dim
batch_size = args.batch_size


# ------------------- Load Models -------------------------------
from models import MNISTEncoder
from models import MNIST_NN_CornerCrop

# Frozen encoder: maps images to latent codes that serve as the inputs X below.
# map_location lets a checkpoint saved on a GPU be loaded on the selected
# device (including CPU-only machines).
WAE_dir = 'MNIST_WAE_zdim{}_cornercrop_all.pth'.format(z_dim)
checkpoint = torch.load(WAE_dir, map_location=device)
EX = MNISTEncoder(z_dim=args.z_dim).to(device)
EX.load_state_dict(checkpoint['encoder'])
# The encoder is only used for inference, so switch it to eval mode (as is
# done for nn_model below) and freeze its parameters.
EX.eval()
EX.requires_grad_(False)

# Network whose weights are read out and updated in the experiments below.
nn_model = MNIST_NN_CornerCrop(z_dim=z_dim).to(device)
checkpoint = torch.load(
    'MNIST_NN_zdim{}_cornercrop_all_{}_WAE.pth'.format(z_dim, args.target_class),
    map_location=device,
)
nn_model.load_state_dict(checkpoint['nn_model'])
nn_model.eval()

# Detached copies of the weights of main[0] and main[2] and the bias of main[2].
W1 = nn_model.main[0].weight.detach()
W2 = nn_model.main[2].weight.detach()
B0 = nn_model.main[2].bias.detach()
# Sizes taken from the weight shapes: W1 is (l2, l1) and W2 is (l3, l2).
l2, l1 = W1.size()
l3, l2 = W2.size()

# ------------------- Load and preprocess data -------------------------------
from functions import (
    load_MNIST_Cornercrop,
    evaluate_loss,
    estimate_dW1,
    estimate_dW2,
    update_SGD,
)
import pandas as pd

# Result columns; one entry is appended to each list per evaluated k.
target_class_list = []
target_class_loss1 = []
target_class_loss2 = []
target_class_loss3 = []
target_class_loss4 = []
target_class_loss5 = []

# The three splits for the selected target class.
attr_train_dataset = load_MNIST_Cornercrop(
    split='Train', target_class=args.target_class, data_dir=data_root)
attr_valid_dataset = load_MNIST_Cornercrop(
    split='Valid', target_class=args.target_class, data_dir=data_root)
attr_test_dataset = load_MNIST_Cornercrop(
    split='Test', target_class=args.target_class, data_dir=data_root)

# Loaders iterate in dataset order (no shuffling) with the CLI batch size.
attr_train_loader = DataLoader(
    attr_train_dataset, batch_size=args.batch_size, shuffle=False)
attr_val_loader = DataLoader(
    attr_valid_dataset, batch_size=args.batch_size, shuffle=False)
attr_test_loader = DataLoader(
    attr_test_dataset, batch_size=args.batch_size, shuffle=False)

def _encode_split(loader):
    """Run every batch of ``loader`` through the encoder ``EX``.

    Returns two lists with one entry per batch: the encoder outputs and the
    targets reshaped to ``(-1, 14 * 14)``. Both stay on ``device``.
    """
    codes = []
    targets = []
    for img, y in tqdm(loader):
        img, y = img.to(device), y.to(device)
        codes.append(EX(img))
        targets.append(y.view(-1, 14 * 14))
    return codes, targets


with torch.no_grad():
    # Same order as before: train, then test, then validation.
    X_list, Y_list = _encode_split(attr_train_loader)
    X_test_list, Y_test_list = _encode_split(attr_test_loader)
    X_valid_list, Y_valid_list = _encode_split(attr_val_loader)

    # Stack the per-batch pieces into one tensor per split; targets are cast
    # to float.
    X = torch.concat(X_list, dim=0)
    Y = torch.concat(Y_list, dim=0).float()
    X_test = torch.concat(X_test_list, dim=0)
    Y_test = torch.concat(Y_test_list, dim=0).float()
    X_valid = torch.concat(X_valid_list, dim=0)
    Y_valid = torch.concat(Y_valid_list, dim=0).float()

    # n: number of training rows, l3: target width, p: encoder output width.
    n, l3 = Y.size()
    _, p = X.size()

# Values of k to evaluate. This single list drives both the loop and the 'k'
# column of the result table, so the two cannot get out of sync.
k_values = [1, 2, 3]

for k in k_values:

    # loss1: test loss with the loaded weights, no update applied.
    loss1 = evaluate_loss(X_test.detach(), Y_test.detach(), W1.detach(), W2.detach(), B0.detach())

    # loss2: only the second layer is re-estimated. estimate_dW2 returns the
    # updated W2, a term Z0 that is added to the bias B0, and Z2_k.
    updated_W2, Z0, Z2_k = estimate_dW2(X.detach(), Y.detach(), W1.detach(), B0.detach(), W2.detach(),
                                        args.ridge_lambda, k)
    loss2 = evaluate_loss(X_test.detach(), Y_test.detach(), W1.detach(), updated_W2.detach(), (B0 + Z0).detach())

    # loss3: the first layer is estimated first, then the second layer is
    # re-estimated on top of the updated first layer.
    updated_W1, Z1_k = estimate_dW1(X.detach(), Y.detach(), W1.detach(), k, args.ridge_lambda, device)
    updated_W2, Z0, Z2_k = estimate_dW2(X.detach(), Y.detach(), updated_W1.detach(),
                                        B0.detach(), W2.detach(), args.ridge_lambda, k)
    loss3 = evaluate_loss(X_test.detach(), Y_test.detach(), updated_W1.detach(),
                          updated_W2.detach(), (B0 + Z0).detach())

    # SGD -- Stein Init
    # loss4: the SGD factors are initialised from the SVD of the estimated
    # Z1_k / Z2_k (singular values folded into the right factor), scaled down.
    A, S1, B = torch.linalg.svd(Z1_k, full_matrices=False)
    A = A * 0.001
    B = torch.diag(S1) @ B * 0.1

    C, S2, D = torch.linalg.svd(Z2_k, full_matrices=False)
    C = C * 0.001
    D = torch.diag(S2) @ D * 0.1

    Z0 = Z0 * 0.001

    for factor in (A, B, C, D, Z0):
        factor.requires_grad_(True)

    updated_W1, updated_W2, Z0 = update_SGD(A, B, C, D, Z0, W1.detach(), W2.detach(), B0.detach(),
                                            X.detach(), Y.detach(),
                                            args.SGD_lr, args.SGD_batch_size, epochs=args.SGD_epochs)

    loss4 = evaluate_loss(X_test.detach(), Y_test.detach(), updated_W1.detach(),
                          updated_W2.detach(), (B0 + Z0).detach())

    # SGD -- Zero init
    # loss5: small random left factors and all-zero right factors, so the
    # products A @ B and C @ D start at zero.
    A = torch.randn((l2, k), device=device) * 0.01
    B = torch.zeros((k, l1), device=device)

    C = torch.randn((l3, k), device=device) * 0.01
    D = torch.zeros((k, l2), device=device)

    # All zeros; randn_like is kept (rather than zeros_like) so the random
    # stream consumed before update_SGD is the same as in earlier runs.
    Z0 = torch.randn_like(B0) * 0

    for factor in (A, B, C, D, Z0):
        factor.requires_grad_(True)

    updated_W1, updated_W2, Z0 = update_SGD(A, B, C, D, Z0, W1.detach(), W2.detach(), B0.detach(),
                                            X.detach(), Y.detach(),
                                            args.SGD_lr, args.SGD_batch_size, epochs=args.SGD_epochs)

    loss5 = evaluate_loss(X_test.detach(), Y_test.detach(), updated_W1.detach(),
                          updated_W2.detach(), (B0 + Z0).detach())

    # Record one row per k.
    target_class_list.append(args.target_class)
    target_class_loss1.append(float(loss1.cpu()))
    target_class_loss2.append(float(loss2.cpu()))
    target_class_loss3.append(float(loss3.cpu()))
    target_class_loss4.append(float(loss4.cpu()))
    target_class_loss5.append(float(loss5.cpu()))

# Collect the per-k losses into a table, write it to the working directory
# and print it.
df = pd.DataFrame({'k': k_values, 'target_class': target_class_list, 'loss1': target_class_loss1,
                   'loss2': target_class_loss2, 'loss3': target_class_loss3, 'loss4': target_class_loss4,
                   'loss5': target_class_loss5,})

df.to_csv('MNIST_result_WAE_class{}.csv'.format(args.target_class))
print(df)

