
import torch
from experiments.Networks import CA_CNN_Convolutional
from experiments.Automaton import CoarseWrapper, LifeLikeAutomaton2D, CellularAutomaton2D, create_init_only_center, init_512_rule
from experiments.Trainer import Trainer, DataLogger, SGDTrainer
from experiments.AlternativeNetworks import PixelTransformerClassifier
from tqdm import tqdm



def _train_one(rule_index, iters, coarse, train_length):
    """Train one fresh PixelTransformerClassifier on a coarse-grained automaton.

    Args:
        rule_index: Position of the sampled rule in the sweep; used only to
            build the run name.
        iters: Number of automaton iterations given to ``CoarseWrapper``; also
            passed to the network as ``k``.
        coarse: The ``CoarseWrapper`` whose data the network is trained on.
        train_length: Exponent for the first positional argument of
            ``Trainer.train`` (``512 * 2**train_length``).

    The network, logger and trainer are locals, so they are released when this
    function returns (unless something else keeps a reference). This replaces
    the explicit ``del`` statements of the inline version.
    """
    net = PixelTransformerClassifier(
        k=iters, num_layers=3, embed_dim=64, num_heads=4, num_classes=2
    ).cuda()
    # Use one name for both the local log and the wandb run. Including `iters`
    # keeps the several runs per rule from sharing the same local log name.
    run_name = f"Rule{rule_index}_{iters}"
    data_logger = DataLogger(
        use_wandb=True,
        use_local=True,
        localPath="results/",
        localName=run_name,
        console=False,
        wandbProject="CA_Transformer_AllAuto",
        wandbRunName=run_name,
    )
    trainer = Trainer(net, coarse, data_logger)
    # Positional arguments are passed through unchanged; see Trainer.train for
    # their meaning.
    trainer.train(512 * (2 ** train_length), 1, 3e-4, early_stopping=None, nat_its=0)


def main(num_rules=512, iters_values=(2, 3, 4, 5), spatial_factors=(1,),
         train_lengths=(3,), grid_size=16):
    """Sample automaton rules and train a transformer classifier on each.

    For each of ``num_rules`` rules drawn with
    ``init_512_rule(uniform_lambda=True)``, and for every combination of
    ``iters_values`` x ``spatial_factors`` x ``train_lengths``, a new network
    is trained on a ``CoarseWrapper`` around a ``CellularAutomaton2D`` built
    with ``grid_size``. The network is moved with ``.cuda()``, so a CUDA device
    is required.

    Run names encode only the rule index and ``iters``. If you sweep more than
    one spatial factor or train length, run names will repeat.
    """
    # Notes from earlier experiments:
    # - Train for a long time: accuracy sometimes takes a while to start rising.
    # - Networks with residual=False appeared to train faster (this script does
    #   not pass a residual option).
    # - The learning rate (3e-4) is fairly high but seems to work, probably
    #   because of the large batch size.
    for rule_index in tqdm(range(num_rules)):
        rule = init_512_rule(uniform_lambda=True)
        for iters in iters_values:
            for spatial_factor in spatial_factors:
                automaton = CellularAutomaton2D(rule, grid_size)
                coarse = CoarseWrapper(automaton, iters, spatial_factor,
                                       only_output_coarse=True, init_function=None)
                for train_length in train_lengths:
                    _train_one(rule_index, iters, coarse, train_length)
        # tqdm.write prints above the progress bar instead of corrupting it.
        tqdm.write(f"Finished rule {rule}")
        tqdm.write(f"Iteration  {rule_index}")


if __name__ == "__main__":
    main()
