import math
import torch
from model.models.classifier import Classifier


def get_data(args, dt_dl):
    '''
    Get data for stability analysis.

    Options, selected by ``args.dataset``:
      - 'mnist', 'fmnist', 'svhn', 'cifar10', 'cifar100': real datasets;
        ``dt_dl`` is returned unchanged.
      - 'random': standard-normal inputs and standard-normal targets,
        drawn independently.
      - 'o2c' ("ordered to complex"): standard-normal inputs. The targets are
        the inputs themselves when ``args.o2c_idx == 1``; otherwise they are
        the first output of ``ff(x, args, L=args.o2c_idx)``.

    Args:
        args: config namespace. Reads ``dataset``, ``bsz`` and ``ds``; for
            'o2c' it also reads ``device`` and ``o2c_idx``.
        dt_dl: object returned as-is for the real-dataset options.

    Returns:
        ``dt_dl`` for the real-dataset options. Otherwise a pair
        ``(args.bsz, [(x, y, idx)])``, where ``x`` has shape
        ``(bsz, ds[0])`` and ``idx`` is ``list(range(bsz))``. For 'random',
        ``y`` has shape ``(bsz, ds[-1])``.

    Raises:
        ValueError: if ``args.dataset`` is not one of the options above.
    '''
    option = args.dataset
    if option in ('mnist', 'fmnist', 'svhn', 'cifar10', 'cifar100'):
        return dt_dl

    if option == 'random':
        # NOTE(review): unlike 'o2c', these tensors are not moved to
        # args.device; confirm callers handle device placement.
        x = torch.randn(args.bsz, args.ds[0])
        y = torch.randn(args.bsz, args.ds[-1])
    elif option == 'o2c':
        x = torch.randn(args.bsz, args.ds[0])\
            .to(args.device, non_blocking=True)
        if args.o2c_idx == 1:
            # Index 1 means the target is the input itself (identity mapping).
            y = x.detach()
        else:
            y, _ = ff(x, args, L=args.o2c_idx)
        # TODO: consider standardizing and/or ZCA-whitening y (currently
        # bsz << latent_dim), then check with the Frobenius norm whether the
        # covariance of y is the identity.
    else:
        # Without this branch, an unknown option would fall through and fail
        # with an UnboundLocalError on x/y below.
        raise ValueError(
            f"Unknown dataset option {option!r}; expected one of 'mnist', "
            "'fmnist', 'svhn', 'cifar10', 'cifar100', 'random', 'o2c'"
        )
    return args.bsz, [(x, y, list(range(args.bsz)))]


def ff(x, args, L=0):
    '''
    Run ``x`` through a freshly constructed ``Classifier(args, L)``.

    A new model is built on every call; ``L`` is passed straight to the
    ``Classifier`` constructor (presumably a depth/layer count — confirm
    against ``Classifier``).

    Returns:
        The two values produced by ``Classifier.forward_layer(x)``, as a
        ``(y_hat, hs)`` tuple.
    '''
    prediction, hidden_states = Classifier(args, L).forward_layer(x)
    return prediction, hidden_states
