import torch
import os
import argparse
import numpy as np
import random
from cuda_selector import auto_cuda
from tqdm import tqdm

# ------------------- Command-line arguments -------------------------------
# ArgumentDefaultsHelpFormatter only appends "(default: ...)" to options that
# carry a help string, so every option is given one.
parser = argparse.ArgumentParser(
    description="Simulation",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--seed", type=int, default=1234, help="random seed")
parser.add_argument("--k", type=int, default=3,
                    help="inner dimension k of the low-rank weight updates")
# The default is deliberately left as the int 1: argparse applies `type` only
# to string defaults, and the output CSV name is formatted from this value.
parser.add_argument("--ridge_lambda", type=float, default=1,
                    help="ridge penalty lambda forwarded to estimate_dW2")
parser.add_argument("--rank_P", type=int, default=75,
                    help="number of singular vectors used to build the projection P")
parser.add_argument("--l1", type=int, default=200,
                    help="input dimension (columns of W1)")
parser.add_argument("--l2", type=int, default=100,
                    help="hidden dimension (rows of W1, columns of W2)")
parser.add_argument("--l3", type=int, default=50,
                    help="output dimension (rows of W2)")
parser.add_argument("--M", type=int, default=50,
                    help="number of simulation repetitions averaged per sample size")
parser.add_argument("--SGD_batch_size", type=int, default=256,
                    help="mini-batch size of the SGD baseline")
parser.add_argument("--SGD_epochs", type=int, default=30,
                    help="number of passes over the training data for the SGD baseline")
parser.add_argument("--SGD_lr", type=float, default=0.001,
                    help="learning rate of the SGD baseline")
args = parser.parse_args()

# -------------------- Seed ------------------------------
# Seed NumPy, the stdlib RNG and PyTorch (CPU and CUDA) with the same value,
# calling the seeding functions in a fixed order.
for _seed_fn in (
    np.random.seed,
    random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
    torch.cuda.manual_seed,
    torch.manual_seed,
):
    _seed_fn(args.seed)

# Request deterministic cuDNN kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)

# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes -- confirm.
os.environ.update({
    "CUBLAS_WORKSPACE_CONFIG": ":4096:8",
    'PYTHONHASHSEED': str(args.seed),
})

cuda_available = torch.cuda.is_available()


# ------------------- Select Device -------------------------------
# With CUDA present, let cuda_selector choose a GPU and make it the current
# device; otherwise run on the CPU.
if cuda_available:
    # assumes auto_cuda() returns a string ending in the GPU index, such as
    # "cuda:3" -- TODO confirm against the cuda_selector documentation.
    selected = auto_cuda()
    # Parse the whole trailing run of digits. Taking only the last character
    # would turn "cuda:10" into GPU 0.
    index_digits = selected[len(selected.rstrip('0123456789')):]
    device_id = int(index_digits)
    torch.cuda.set_device(device_id)
    device = torch.device('cuda')
    print('Using cuda:{}'.format(torch.cuda.current_device()))

else:
    device = 'cpu'
    print('Using CPU')

# ------------------- Load Parameters -------------------------------
l1, l2, l3 = args.l1, args.l2, args.l3
rank_P = args.rank_P
M = args.M

# Validate with explicit raises: `assert` statements are stripped when Python
# runs with -O, which would let invalid settings through silently.
if rank_P > l1:
    # The projection P is built from the first rank_P columns of V, obtained
    # from the SVD of an (l2, l1) matrix.
    raise ValueError('rank_P must be <= l1')
if rank_P <= 0:
    raise ValueError('rank_P must be > 0')
if M <= 0:
    # With no repetitions the per-n averages stay plain ints and the later
    # `.cpu()` conversion fails; reject the setting up front instead.
    raise ValueError('M must be > 0')

from functions import estimate_dW2, estimate_dW1
import pandas as pd
# ------------------- Result accumulators -------------------------------
# Each list receives one entry per sample size n; together they form the
# columns of the results table.
n_list = []
loss1_list, loss2_list, loss3_list, SGD_loss_list = [], [], [], []

# Mean-squared-error criterion shared by every evaluation and by SGD training.
mse = torch.nn.MSELoss()

def simulate(n, device):
    """Run one simulated experiment and return four test-set MSE values.

    Draws base weights W1 (l2 x l1) and W2 (l3 x l2), perturbs them to obtain
    the data-generating weights, samples ``n`` noisy training pairs and 1000
    noisy test pairs from the resulting two-layer ReLU model, and scores four
    predictors on the test pairs.

    Reads the module-level ``l1``, ``l2``, ``l3``, ``rank_P``, ``args`` and
    ``mse``.

    Args:
        n: Number of training samples to draw.
        device: Device on which the tensors are created; also forwarded to
            ``estimate_dW1``.

    Returns:
        Tuple ``(loss1, loss2, loss3, SGD_loss)`` of scalar tensors holding
        the test MSE of, respectively:
        the unmodified base weights W1 and W2;
        W1 kept and W2 replaced by the output of ``estimate_dW2``;
        W1 replaced by the output of ``estimate_dW1`` and W2 re-estimated
        with ``estimate_dW2``;
        the base weights plus the products ``A @ B`` and ``C @ D`` (inner
        dimension ``args.k``) trained with mini-batch SGD.
    """
    relu = torch.nn.functional.relu

    def forward(first_layer, second_layer, inputs):
        # Two-layer ReLU map applied to the rows of `inputs`.
        return (second_layer @ relu(first_layer @ inputs.t())).t()

    # Everything up to the SGD baseline needs no gradients.
    with torch.no_grad():
        W1 = torch.randn((l2, l1), device=device) * 0.25
        W2 = torch.randn((l3, l2), device=device) * 0.25

        # P = V_r V_r^T, with V_r the first rank_P columns of V taken from the
        # SVD of a random (l2, l1) matrix; W1 + Z1 then equals W1 @ P.
        basis_source = torch.randn((l2, l1), device=device)
        Vh = torch.linalg.svd(basis_source)[2]
        V_r = Vh.t()[:, :rank_P]
        P = V_r @ V_r.t()
        Z1 = W1 @ P - W1

        # Dense random perturbation of the second layer.
        Z2 = torch.randn((l3, l2), device=device) * 0.25

        def draw(count):
            # Inputs are sampled first, then the additive output noise.
            inputs = torch.randn((count, l1), device=device)
            clean = forward(W1 + Z1, W2 + Z2, inputs)
            return inputs, clean + torch.randn_like(clean)

        X, Y = draw(n)
        X_test, Y_test = draw(1000)

        # (1) Base weights without any correction.
        loss1 = mse(Y_test, forward(W1, W2, X_test))

        # (2) Re-estimate only the second layer.
        updated_W2, _ = estimate_dW2(X, Y, W1, 0, W2, args.ridge_lambda, args.k)
        loss2 = mse(Y_test, forward(W1, updated_W2, X_test))

        # (3) Update the first layer, then re-estimate the second on top of it.
        updated_W1 = estimate_dW1(X, Y, W1, args.k, device)
        updated_W2, _ = estimate_dW2(X, Y, updated_W1, 0, W2, args.ridge_lambda, args.k)
        loss3 = mse(Y_test, forward(updated_W1, updated_W2, X_test))

    # (4) SGD baseline: additive factor pairs with one factor of each pair
    # initialised to zero, so training starts from the base weights.
    A = torch.randn((l2, args.k), device=device, requires_grad=True)
    B = torch.zeros((args.k, l1), device=device, requires_grad=True)
    C = torch.randn((l3, args.k), device=device, requires_grad=True)
    D = torch.zeros((args.k, l2), device=device, requires_grad=True)
    factors = [A, B, C, D]

    optimizer = torch.optim.SGD(factors, lr=args.SGD_lr)
    batch_size = args.SGD_batch_size
    for _ in range(args.SGD_epochs):
        # Fresh shuffle of the training set every epoch.
        order = torch.randperm(n)
        X_shuffled = X[order]
        Y_shuffled = Y[order]

        for start in range(0, n, batch_size):
            rows = slice(start, start + batch_size)
            optimizer.zero_grad()
            prediction = forward(W1 + A @ B, W2 + C @ D, X_shuffled[rows])
            batch_loss = mse(prediction, Y_shuffled[rows])
            batch_loss.backward()
            optimizer.step()

    # Freeze the trained factors so the final evaluation builds no graph.
    for factor in factors:
        factor.requires_grad_(False)
    SGD_loss = mse(Y_test, forward(W1 + A @ B, W2 + C @ D, X_test))
    return loss1, loss2, loss3, SGD_loss


# ------------------- Run the experiment for each sample size ----------------
for n in [5000, 10000, 15000, 25000, 30000,  40000, 50000, 100000]:
    # Running means of (loss1, loss2, loss3, SGD_loss) over the M repetitions;
    # each repetition contributes loss / M.
    running_means = [0, 0, 0, 0]
    for _ in tqdm(range(M)):
        for slot, value in enumerate(simulate(n, device)):
            running_means[slot] = running_means[slot] + value / M

    n_list.append(n)
    histories = (loss1_list, loss2_list, loss3_list, SGD_loss_list)
    for history, mean_value in zip(histories, running_means):
        history.append(float(mean_value.cpu()))

    # Rebuild the table and rewrite the CSV after every sample size, so
    # partial results are already on disk while later sizes are running.
    df = pd.DataFrame({
        'n': n_list,
        'loss1': loss1_list,
        'loss2': loss2_list,
        'loss3': loss3_list,
        'SGD_loss': SGD_loss_list,
    })

    print(df)
    df.to_csv('Simulation_result_k{}_lambda{}.csv'.format(args.k, args.ridge_lambda))


