# File: updated-G.py
# Computes just G from an input x, and also G + G' for an input x.
# The task is to see whether these are as good as an MLP.
import torch 
from models.sl2models import Irreps2SymGramMatrix
from cg import transvectant
from generate_data import pickle_to_dataloader
from compare_to_baselines import load_model_from_directory
from utils import normalized_mse_loss
from tests.bernstein_comparison import ispsd
# Choose the compute device once at import time, together with the matching
# accelerator name ("gpu" when CUDA is available, otherwise "cpu").
device, accelerator = (
    (torch.device('cuda'), "gpu")
    if torch.cuda.is_available()
    else (torch.device('cpu'), "cpu")
)


def computeG(x, precomputed_data_dir=None):
    """Compute G directly from the input, with all lower-degree slots zeroed.

    The degree-``xdeg`` slot passed to ``Irreps2SymGramMatrix.apply`` holds
    ``x``; every even-degree slot ``0, 2, ..., < xdeg`` holds zeros.

    Args:
        x: 2-D tensor of shape ``(batch, xdeg + 1)``; each row is used as the
            degree-``xdeg`` input (presumably polynomial coefficients -- TODO
            confirm against the dataset format).
        precomputed_data_dir: Optional directory forwarded to
            ``Irreps2SymGramMatrix``.

    Returns:
        Whatever ``Irreps2SymGramMatrix.apply`` returns for the batch.
    """
    batch, degp1 = x.shape
    xdeg = degp1 - 1
    T = Irreps2SymGramMatrix(xdeg, device, precomputed_data_dir)
    # Zero placeholders for the lower even-degree slots. Build them with x's
    # dtype and device: the torch.zeros defaults (float32, CPU) would not
    # match a float64 or CUDA input once T.apply combines the slots.
    x_dict = {
        deg: torch.zeros(batch, 1, deg + 1, dtype=x.dtype, device=x.device)
        for deg in range(0, xdeg, 2)
    }
    x_dict[xdeg] = x.reshape(batch, 1, -1)
    return T.apply(x_dict)

def computeUpdatedG(x, precomputed_data_dir=None):
    """Compute G + G' for each row of ``x``, one sample at a time.

    For every row ``p`` of ``x`` an irrep dictionary is assembled and passed
    to ``Irreps2SymGramMatrix.apply``. With ``I = transvectant(p, p, xdeg)``
    the dictionary holds, in this key order:

    * ``xdeg``: ``p`` itself;
    * ``2*xdeg - 2*k`` for ``k`` in ``xdeg//2 .. xdeg - 1`` (except the key
      ``xdeg``, which stays ``p``): ``transvectant(p, p, k) / sqrt(|I|)``;
    * ``0``: ``I / sqrt(|I|)``.

    The degree-0 and degree-4 entries are then multiplied by constants copied
    from computations with Julia.

    Args:
        x: 2-D tensor of shape ``(batch, xdeg + 1)``.
        precomputed_data_dir: Optional directory forwarded to
            ``Irreps2SymGramMatrix``.

    Returns:
        A ``(batch, xdeg//2 + 1, xdeg//2 + 1)`` tensor created with the
        ``torch.zeros`` defaults and filled row by row from ``T.apply``.

    Raises:
        ValueError: If ``xdeg < 3``. No degree-4 entry exists then, so the
            hard-coded degree-4 scaling cannot be applied.
    """
    batch, degp1 = x.shape
    xdeg = degp1 - 1
    # A degree-4 entry exists only when xdeg >= 3: it is p itself when
    # xdeg == 4, otherwise transvectant(p, p, xdeg - 2).
    if xdeg < 3:
        raise ValueError(
            f"computeUpdatedG needs input degree >= 3 (got {xdeg}): "
            "the hard-coded scaling expects a degree-4 irrep slot"
        )
    T = Irreps2SymGramMatrix(xdeg, device, precomputed_data_dir)
    # Magic scaling: these constants are copied from computations with Julia.
    scale_deg0 = -0.6802879458389796
    scale_deg4 = -0.9524031434519564
    outtensor = torch.zeros((batch, xdeg // 2 + 1, xdeg // 2 + 1))
    for pi in range(batch):
        p = x[pi, :]
        # Insert p first so the key order of x_dict stays
        # [xdeg, <transvectant degrees, descending>, 0].
        x_dict = {xdeg: p.reshape(1, 1, xdeg + 1)}
        invar = transvectant(p, p, xdeg)
        # NOTE(review): if invar == 0 (e.g. p is all zeros) the divisions
        # below yield inf/nan that propagate into this row of the output.
        rtinvar = torch.sqrt(abs(invar))
        # Original note: works for degrees multiple of 4 right now.
        for tni in range(xdeg // 2, xdeg):
            degtni = 2 * xdeg - 2 * tni
            if degtni == xdeg:
                # This slot holds p itself; the transvectant would be discarded.
                continue
            x_dict[degtni] = (transvectant(p, p, tni) / rtinvar).reshape(1, 1, -1)
        # tni == xdeg gives the degree-0 slot, which is the invariant already
        # computed above, so reuse it instead of recomputing it.
        x_dict[0] = (invar / rtinvar).reshape(1, 1, -1) * scale_deg0
        # Scale out of place: when xdeg == 4 this entry is a view of the
        # caller's x, and an in-place *= would silently rescale their tensor.
        x_dict[4] = x_dict[4] * scale_deg4
        outtensor[pi, :, :] = T.apply(x_dict)
    return outtensor
    
    
    

def main(
    precomputed_data_dir="/home/user/TransvectantNets_shared/precomputations",
    training_data_dir="/home/user/TransvectantNets_shared/data/equivariant/deg_6_train_5000_val_100_test_100",
    model_dir="trained_models_and_logs/maxlogdet500train500val500testdeg8/March27/epochs100maxirrep16layers3channels10invar2020/my_model/version_3/",
):
    """Compare G, the updated G, and a trained SL2Net on the training split.

    Prints three values, in order: the normalized MSE of ``computeG(x)``, of
    ``computeUpdatedG(x)``, and of the loaded model's prediction, each taken
    against the training targets ``y``.

    Args:
        precomputed_data_dir: Directory of precomputed data forwarded to both
            G computations.
        training_data_dir: Directory read by ``pickle_to_dataloader``.
        model_dir: Directory of the trained model to load.
    """
    # Training command kept for reference:
    # python train.py --batch_size 32 --max_epochs 100 --data_dir /home/user/TransvectantNets_shared/data/equivariant/deg_6_train_5000_val_100_test_100 --save_dir trained_models_and_logs/maxlogdet500train500val500testdeg8/April13/epochs100maxirrep16layers3channels10invar2020 --precompute_dir ../../TransvectantNets_shared/precomputations --mode max_det --max_irrep 16 --num_layers 3 --num_internal_channels 10 --invar_arch 20 20 --no_batch_norm
    # NOTE(review): that command's --save_dir points at .../April13/... while
    # the default model_dir points at .../March27/...; confirm which run is
    # meant to be evaluated.

    model = load_model_from_directory(model_dir, model_type='SL2Net', device=device)
    dsets, _ = pickle_to_dataloader(training_data_dir,
                                    8, return_datasets=True, transvectant_scaling=False)

    x = dsets['train'].tensors[0]
    y = dsets['train'].tensors[1]
    # Pure evaluation: do not build an autograd graph over the whole
    # training set.
    with torch.no_grad():
        Gbasic = computeG(x, precomputed_data_dir=precomputed_data_dir)
        Gpredict = computeUpdatedG(x, precomputed_data_dir=precomputed_data_dir)
        # NOTE(review): x is not moved to `device`; confirm the loaded model
        # lives on the same device as x.
        modelpredict = model(x)
        print(normalized_mse_loss(Gbasic, y))
        print(normalized_mse_loss(Gpredict, y))
        print(normalized_mse_loss(modelpredict, y))

# Run the comparison only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()