import torch
from torch.utils.data import Subset, DataLoader, Dataset, random_split
from torchvision import datasets, transforms
import numpy as np
from torch import nn


def load_MNIST_Cornercrop(split='train', target_class=1, data_dir='dataset', seed=0):
    """Load one MNIST digit class as a corner-inpainting regression dataset.

    Each sample ``(x, y)`` is built from one MNIST image ``img`` produced by
    ``transforms.ToTensor()`` (a ``(1, 28, 28)`` tensor for MNIST):
    ``x`` is ``img`` with the bottom-left 14x14 quadrant (rows 14:, cols :14)
    set to zero, and ``y`` is a copy of that original quadrant.

    Args:
        split: ``'train'`` for the training subset, ``'test'`` for the MNIST
            test set, anything else for the validation subset. Matching is
            case-insensitive, so ``'Train'`` / ``'Test'`` keep working.
        target_class: digit label to keep; all other labels are dropped.
        data_dir: root directory for ``torchvision.datasets.MNIST``; the data
            is downloaded there if missing.
        seed: seed for the 80/20 train/validation split of the MNIST training
            set. A fixed seed makes separate ``'train'`` and validation calls
            return disjoint subsets. Pass ``None`` to use torch's global RNG
            (non-reproducible; splits from separate calls may overlap).

    Returns:
        A ``torch.utils.data.Dataset`` yielding ``(x, y)`` tensor pairs.
    """

    class TransformDataset(Dataset):
        """Wrap an image dataset as (masked image, hidden quadrant) pairs."""

        def __init__(self, dataset):
            self.dataset = dataset

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, index):
            img, _ = self.dataset[index]  # the digit label is discarded
            x = img.detach().clone()
            x[:, 14:, :14] = 0  # hide the bottom-left quadrant from the input
            y = img[:, 14:, :14].detach().clone()  # hidden quadrant is the target
            return x, y

    def load_filtered_mnist(train):
        """Load an MNIST split and keep only samples labelled ``target_class``."""
        mnist = datasets.MNIST(
            root=data_dir,
            train=train,
            transform=transforms.ToTensor(),
            download=True
        )
        mask = mnist.targets == target_class
        mnist.targets = mnist.targets[mask]
        mnist.data = mnist.data[mask]
        return mnist

    split_key = str(split).lower()
    if split_key == 'test':
        return TransformDataset(load_filtered_mnist(train=False))

    full_dataset = load_filtered_mnist(train=True)
    train_size = int(0.8 * len(full_dataset))
    val_size = len(full_dataset) - train_size

    # Seeding the split keeps train and validation disjoint across separate
    # calls. An unseeded split draws a new permutation on every call.
    if seed is None:
        generator = torch.default_generator
    else:
        generator = torch.Generator().manual_seed(seed)
    train_subset, val_subset = random_split(
        full_dataset, [train_size, val_size], generator=generator
    )

    if split_key == 'train':
        return TransformDataset(train_subset)
    return TransformDataset(val_subset)


def evaluate_loss(X_test, Y_test, W1, W2, B0):
    """Return the mean-squared error of the two-layer ReLU model on a test set.

    The prediction is ``(W2 @ relu(W1 @ X_test.T)).T + B0``, compared against
    ``Y_test`` with mean reduction. Everything runs under ``torch.no_grad()``,
    so the returned scalar tensor carries no autograd graph.
    """
    with torch.no_grad():
        hidden = torch.nn.functional.relu(W1 @ X_test.t())
        prediction = (W2 @ hidden).t() + B0
        return torch.nn.functional.mse_loss(prediction, Y_test)

def estimate_dW2(X, Y, W1, B0, W2, ridge_lambda, k):
    """Estimate a rank-``k`` ridge-regression update for the output weights W2.

    The model is ``Y ≈ relu(X @ W1.T) @ W2.T + B0`` (hidden activations
    ``a = relu(W1 @ X.T).T``, shape ``(n, h)``).

    Steps:
    1. Center ``Y - B0`` and ``a``.
    2. Regress the residual left by the current ``W2`` on the centered
       activations with ridge penalty ``ridge_lambda * I``.
    3. Project the coefficient matrix onto the top-``k`` right singular
       directions of its fitted values (a reduced-rank fit).

    Args:
        X: inputs, one row per sample.
        Y: targets, one row per sample.
        W1: first-layer weights (applied as ``W1 @ X.T``).
        B0: output bias, broadcast against rows of ``Y``.
        W2: current output weights, shape ``(l3, h)``.
        ridge_lambda: ridge strength. A scalar, or a 1-D tensor of
            per-feature strengths, is placed on the diagonal. A 2-D tensor is
            used verbatim as the penalty matrix.
        k: rank of the returned update.

    Returns:
        ``(updated_W2, Z0, dW2)``:
        - ``updated_W2 = W2 + dW2``.
        - ``Z0`` is the mean residual of ``Y - B0`` under ``updated_W2``, an
          intercept correction to add to ``B0``.
        - ``dW2`` is the rank-``k`` update.
    """
    with torch.no_grad():
        a = torch.nn.functional.relu(W1 @ X.t()).t()  # (n, h)
        a_centered = a - torch.mean(a, dim=0)
        T = Y - B0
        T_centered = T - torch.mean(T, dim=0)
        # Residual the current W2 leaves unexplained (intercept removed by centering).
        target = T_centered - a_centered @ W2.t()

        h = a_centered.size(1)
        penalty = torch.as_tensor(ridge_lambda, dtype=a_centered.dtype,
                                  device=a_centered.device)
        if penalty.dim() < 2:
            # Ridge adds lambda on the diagonal only. Adding a bare scalar to the
            # Gram matrix would perturb every entry and would not regularize.
            penalty = penalty * torch.eye(h, dtype=a_centered.dtype,
                                          device=a_centered.device)
        gram = a_centered.t() @ a_centered + penalty
        Z2 = torch.linalg.solve(gram, a_centered.t() @ target)  # (h, l3)

        # Reduced-rank step: keep only the top-k output directions of the fit.
        Y_hat = a_centered @ Z2  # (n, l3)
        _, _, Vh = torch.linalg.svd(Y_hat.t() @ Y_hat)
        Vh_k = Vh[:k, :]
        Z2_k = Z2 @ (Vh_k.t() @ Vh_k)

        dW2 = Z2_k.t()
        updated_W2 = W2 + dW2
        Z0 = torch.mean(Y - B0 - a @ updated_W2.t(), dim=0).view(-1)
        return updated_W2, Z0, dW2

def estimate_dW1(X, Y, W1, k, device, var_threshold=0.98):
    """Estimate a rank-``k`` update that moves W1 toward a data-driven input subspace.

    For each output column ``y`` of ``Y``, the function averages
    ``y * (P (x - x̄)(x - x̄)^T P - P)`` over samples, where ``P`` is the
    inverse sample covariance of ``X``. This is the form of a second-order
    Stein score under a Gaussian model of ``X``.

    The ``l3`` resulting ``(l1, l1)`` matrices are stacked side by side and
    decomposed with an SVD. The smallest number of left singular vectors
    whose squared singular values reach ``var_threshold`` of the total is
    kept as the subspace ``V_r``.

    The proposed change is ``Z1 = -W1 (I - V_r V_r^T)``, i.e. removing W1's
    component outside that subspace, truncated to rank ``k``.

    Args:
        X: inputs, shape ``(n, l1)``. The sample covariance must be
            invertible (``torch.linalg.inv`` is used).
        Y: targets, shape ``(n, l3)``.
        W1: first-layer weights with ``l1`` columns.
        k: rank of the returned update.
        device: kept for backward compatibility. Working tensors are created
            on the device/dtype of the computed singular vectors.
        var_threshold: fraction of squared singular mass the kept subspace
            must capture (default 0.98, the previous hard-coded value).

    Returns:
        ``(updated_W1, Z1_k)`` with ``updated_W1 = W1 + Z1_k``. When the
        selected rank equals ``l1``, ``Z1_k`` is all zeros.
    """
    n, l1 = X.size()
    l3 = Y.size(1)
    with torch.no_grad():
        X_dm = X - torch.mean(X, dim=0)
        cov = (X_dm.t() @ X_dm) / n
        prec = torch.linalg.inv(cov)
        scores = prec @ X_dm.t()  # (l1, n): column j is P (x_j - x̄)

        # Per-output average of y * (s s^T - P), computed as scores @ diag(y) @ scores^T.
        stack_list = [
            (scores @ (Y[:, i].view(-1, 1) * scores.t()) - torch.sum(Y[:, i]) * prec) / n
            for i in range(l3)
        ]
        EYSx_stacked = torch.concat(stack_list, dim=1)  # (l1, l1 * l3)

        # Only the left singular vectors are needed. full_matrices=False avoids
        # materializing an (l1*l3) x (l1*l3) right-singular matrix.
        V, S, _ = torch.linalg.svd(EYSx_stacked, full_matrices=False)
        Var = (S ** 2).cpu()
        Var_per = Var / torch.sum(Var)
        cum_per = torch.cumsum(Var_per, dim=0)
        # argmax returns the first index where the cumulative share crosses the threshold.
        best_rank = torch.argmax((cum_per >= var_threshold).float()).item() + 1

        print('Best rank is {}'.format(best_rank))
        W1_copy = W1.detach()
        if best_rank == l1:
            # The subspace is the whole input space, so there is nothing to remove.
            updated_W1 = W1_copy
            Z1_k = torch.zeros_like(W1_copy)
        else:
            Vk = V[:, :best_rank]
            ImP = torch.eye(l1, dtype=Vk.dtype, device=Vk.device) - Vk @ Vk.t()
            Z1 = -W1_copy @ ImP
            U1, s1, V1h = torch.linalg.svd(Z1, full_matrices=False)
            Z1_k = U1[:, :k] @ torch.diag(s1[:k]) @ V1h[:k, :]
            updated_W1 = W1_copy + Z1_k
    return updated_W1, Z1_k


def update_SGD(A, B, C, D, Z0, W1, W2, B0, X, Y, lr, batch_size, epochs):
    """Fine-tune low-rank corrections of a two-layer ReLU model with minibatch SGD.

    The trained model is
    ``((W2 + C @ D) @ relu((W1 + A @ B) @ X.T)).T + (B0 + Z0)``.
    Only ``A``, ``B``, ``C``, ``D`` and ``Z0`` are handed to the optimizer and
    updated in place. ``W1``, ``W2`` and ``B0`` are not optimized.
    The trainable tensors presumably arrive with ``requires_grad=True``
    (TODO confirm with callers). They are frozen (``requires_grad = False``)
    before returning.

    Each epoch reshuffles the samples with ``torch.randperm`` and takes one SGD
    step on the MSE loss per contiguous batch of ``batch_size`` rows.

    Returns:
        ``(W1 + A @ B, W2 + C @ D, Z0)`` computed after training.
    """
    trainable = [A, B, C, D, Z0]
    sgd = torch.optim.SGD(trainable, lr=lr)
    loss_fn = nn.MSELoss()
    relu = torch.nn.functional.relu
    num_samples = X.size(0)

    for _ in range(epochs):
        order = torch.randperm(num_samples)
        X_shuffled = X[order]
        Y_shuffled = Y[order]

        for start in range(0, num_samples, batch_size):
            stop = start + batch_size
            xb = X_shuffled[start:stop]
            yb = Y_shuffled[start:stop]

            sgd.zero_grad()
            hidden = relu((W1 + A @ B) @ xb.t())
            prediction = ((W2 + C @ D) @ hidden).t() + (B0 + Z0)
            loss_fn(prediction, yb).backward()
            sgd.step()

    # Freeze the learned corrections so later arithmetic builds no graph.
    for tensor in trainable:
        tensor.requires_grad = False

    return W1 + A @ B, W2 + C @ D, Z0
