import os
import torch
import numpy as np
from torch import nn
from collections import OrderedDict
from args import get_args 
from data import get_dataset
from sparse_gradient_reconstruction import *
from PIL import Image
from utils import report_metrics

args = get_args()
# Seed NumPy and torch RNGs from the configured seed.
np.random.seed(args.random_seed)
torch.manual_seed(args.random_seed)

train_loader, test_loader, img_size, num_classes, inv_transform = get_dataset(args)
# Iterate the train split only when explicitly requested; any other value selects the test split.
loader = train_loader if args.ds_type == 'train' else test_loader

class Net(nn.Module):
    """Fully-connected ReLU network whose first linear layer is ``model.fc1``.

    Layers are named ``fc1, relu1, fc2, relu2, ..., fc{num_layers}``; there is
    no activation after the last linear layer.

    Args:
        num_layers: total number of ``nn.Linear`` layers; must be >= 1.
        size: width of every hidden layer.
        img_size: per-example input shape (e.g. ``(C, H, W)``); inputs are
            flattened, so ``fc1`` takes ``prod(img_size)`` features.
        num_classes: number of output features of the last layer.
        bias: whether every linear layer (including ``fc1``) has a bias term.

    Raises:
        ValueError: if ``num_layers < 1``.
    """

    def __init__(self, num_layers, size, img_size, num_classes, bias=True):
        super().__init__()
        if num_layers < 1:
            raise ValueError(f'num_layers must be >= 1, got {num_layers}')
        # np.prod returns a NumPy scalar; cast so nn.Linear receives a plain int.
        in_features = int(np.prod(img_size))
        # Consecutive (in, out) widths: input -> hidden * (num_layers - 1) -> classes.
        widths = [in_features] + [size] * (num_layers - 1) + [num_classes]
        layers = []
        # Layers are created first-to-last, so seeded initialisation order is stable.
        for i in range(num_layers):
            layers.append((f'fc{i + 1}', nn.Linear(widths[i], widths[i + 1], bias=bias)))
            if i < num_layers - 1:
                layers.append((f'relu{i + 1}', nn.ReLU()))
        self.model = nn.Sequential(OrderedDict(layers))

    def forward(self, x):
        """Flatten all dims after the leading batch dim and return the network output."""
        b = x.shape[0]
        return self.model(x.reshape(b, -1))

for batch_idx, (example_data, example_targets) in enumerate(loader):
    # Only process batches whose index lies in the half-open range [args.st, args.en).
    if batch_idx < args.st:
        continue
    if batch_idx >= args.en:
        break
    print(f'\n\n\nImage {batch_idx}\n\n\n')
    if args.neptune:
        args.neptune['step'].log(batch_idx)
    # Re-seed per batch so the freshly initialised network depends only on batch_idx.
    random_seed = batch_idx
    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    net = Net(args.L, args.W, img_size, num_classes)

    # Populate .grad on every parameter with the gradient of the batch loss.
    l = torch.nn.functional.cross_entropy(net(example_data), example_targets)
    l.backward()

    B_true = example_targets.shape[0]
    # With args.true_B the batch size is passed in; otherwise B=None is passed.
    B_est, params, grad_params, LR, LR_inv = get_layer_decomp(
        args.neptune, net, 'model.fc1', B=args.B if args.true_B else None, device='cuda')
    print( f'B_est: {B_est} vs B_true: {B_true}' )

    # Ground-truth inputs, one flattened row per example.
    X_true = example_data.reshape(B_true, -1).cuda()
    # Least-squares Q with Q @ X_true ~= LR[1], computed from the ground-truth inputs.
    # CUDA lstsq only supports the 'gels' driver.
    Q_opt = torch.linalg.lstsq(X_true.T, LR[1].T, driver='gels').solution.T.detach()
    Q_opt_error = (Q_opt @ X_true - LR[1]).abs().max().item()
    dZ_opt = LR[0] @ Q_opt
    # Per-column count of (near-)zero entries of dZ_opt.
    dZ_true_sparsity = (dZ_opt.abs() < 1e-6).sum(0)
    treshold = get_tau( args.pFN, LR[0].shape[0] )
    min_sparsity = dZ_true_sparsity.min().item()
    min_allowed_sparsity = LR[0].shape[0] * treshold
    print( f'Q_opt num error: {Q_opt_error}, Min sparsity: {min_sparsity}/{min_allowed_sparsity}' )
    if args.neptune:
        args.neptune['result/Q_opt_error'].log( Q_opt_error )
        args.neptune['result/min_sparsity'].log( min_sparsity )
        args.neptune['result/min_allowed_sparsity'].log( min_allowed_sparsity )
        args.neptune['parameters/treshold'].log( treshold )

    try:
        Q_rec, Q_rec_inv = getQ( args.neptune, params, grad_params, LR, LR_inv, Q_opt, device='cuda', N=args.N, par_SVD=args.par_SVD,
            treshold=treshold, cond=args.cond, sigma_tol=args.sigma_tol, sigma_treshold=args.sigma_treshold, sparsity_tol=args.sparsity_tol, count_hack=args.count_hack )
    except Exception as err:
        # Treat a failed reconstruction as "no result" and keep going with the next batch,
        # but do not swallow KeyboardInterrupt/SystemExit and do not hide the cause.
        print(f'getQ failed on batch {batch_idx}: {err!r}')
        Q_rec, Q_rec_inv = None, None
        if args.neptune:
            args.neptune['result/failed'].log( batch_idx )
    if Q_rec_inv is None or Q_rec_inv.sum().isnan().item():
        Q_rec, Q_rec_inv = None, None

    if Q_rec is None:
        # No usable reconstruction: score an all-zero guess with one row per example
        # in *this* batch (which can be smaller than args.B, e.g. the last batch).
        X_rec = torch.zeros(*example_data.shape).reshape(B_true, -1)
    else:
        X_rec = (Q_rec_inv[0] @ LR[1]).cpu()

    report_metrics( net, X_rec, example_data, example_targets,
                f'./result/{batch_idx}', f'result/rec/batch {batch_idx}', f'result/gt/batch {batch_idx}',
                img_size, inv_transform, args.neptune )

    for p in net.parameters():
        p.grad = None
