
import torch
from experiments.Networks import CA_CNN_Convolutional
from experiments.Automaton import CoarseWrapper, LifeLikeAutomaton2D, CellularAutomaton2D, create_init_only_center
from experiments.Trainer import Trainer, DataLogger, SGDTrainer

"""
Experiments with localized initialization on a 32x32 grid
"""
if __name__ == "__main__":
    # Stdlib imports kept local to the script entry point.
    import random
    import traceback

    # How many random rules to sample in this sweep.
    num_rules = 512
    grid_size = 32
    # Rule numbers are drawn uniformly from [0, 2**18).
    # NOTE(review): 18 bits matches the usual life-like B/S encoding
    # (9 birth + 9 survival bits) — confirm against LifeLikeAutomaton2D.
    # random.randint(a, b) includes b, so randint(0, 2**18) could yield the
    # out-of-range value 2**18; randrange excludes the upper bound.
    num_rule_values = 2**18

    for i in range(num_rules):
        num = random.randrange(num_rule_values)
        for iters in [4, 5, 6]:
            for spatial_factor in [1]:
                # Only a centered square is initialized, leaving a border of
                # `iters` cells on each side. Presumably this keeps the
                # evolving pattern away from the grid edge during `iters`
                # steps — TODO confirm against create_init_only_center.
                init_size = grid_size - 2 * iters

                CA = LifeLikeAutomaton2D(num, grid_size)
                coarse = CoarseWrapper(
                    CA,
                    iters,
                    spatial_factor,
                    only_output_coarse=True,
                    init_function=create_init_only_center(init_size),
                    init_size=init_size,
                )

                for residual in [False]:
                    for extra_width in [1]:
                        for train_length in [3]:

                            net_iters = iters + spatial_factor // 2

                            net = CA_CNN_Convolutional(
                                64 * (2**extra_width),
                                net_iters,
                                num_classes=2,
                                extra_depth=2,
                                residual=residual,
                                use_bn=True,
                                device='cuda',
                            )

                            # NOTE(review): localName and wandbRunName depend
                            # only on `num`, so the three `iters` runs of one
                            # rule share the same names — confirm DataLogger
                            # does not overwrite earlier runs' local results.
                            data_logger = DataLogger(
                                use_wandb=True,
                                use_local=True,
                                localPath='results/',
                                localName='Rule' + str(num),
                                console=False,
                                wandbProject="CA_test_timeCG_centerStrong32_noNat",
                                wandbRunName="RuleLong" + str(num),
                            )
                            trainer = Trainer(net, coarse, data_logger)
                            try:
                                trainer.train(
                                    512 * (2**train_length),
                                    32,
                                    1e-4,
                                    early_stopping=0.98,
                                    nat_its=0,
                                )
                            except Exception:
                                # Skip a failing rule but keep the traceback
                                # visible. Catching Exception (not a bare
                                # except) lets KeyboardInterrupt/SystemExit
                                # actually stop the sweep.
                                print("Error in training, skipping rule", num)
                                traceback.print_exc()

                            # Drop references before the next run is built.
                            del net
                            del data_logger
                            del trainer

        print("Finished rule", num)
        print("Iteration ", i)
    # Notes from earlier runs:
    # - Training should run long; accuracy sometimes stays flat for a while
    #   before it starts to increase.
    # - For unclear reasons, residual=False trains faster.
    # - The learning rate is fairly high but works, probably because of the
    #   large batch size.
