"""
Observing scaling laws associated with model size and training iterations
"""

# <codecell>
import jax
import pandas as pd
from tqdm import tqdm

import sys
sys.path.append('../../../')
sys.path.append('../../../../')
from common import *
from model.mlp import MlpConfig
from model.transformer import TransformerConfig
from task.function import ClassificationTask 

run_id = new_seed()
print('RUN ID', run_id)

# Set to True for a quick smoke test: one tiny configuration per architecture
# and 1% of the training iterations. This replaces hand-uncommenting a block
# of override assignments.
smoke_test = False

batch_size = 128

if smoke_test:
    train_iters_mlp = 640
    depths_mlp = [1]
    widths_mlp = [4]

    train_iters_trans = 1280
    depths_trans = [1]
    widths_trans = [8]

    n_dims = [2]
    n_classes = [2]
else:
    train_iters_mlp = 64_000
    depths_mlp = [1, 2, 4]
    widths_mlp = [4, 16, 64, 256]

    train_iters_trans = 128_000
    depths_trans = [1, 2, 4]
    widths_trans = [8, 32]

    # Broader input-dimension sweep option: [2, 4, 8, 16, 32, 64]
    n_dims = [64]
    # Single-value option: [16]
    n_classes = [2, 16, 64]

def _build_case(name, config, train_iters, common_args):
    """Wrap ``config`` in a ``Case`` with cross-entropy train/test settings.

    The train and test ``ClassificationTask``s receive the same
    ``common_args`` (including the seed). They differ only in batch size:
    the module-level ``batch_size`` for training and 1024 for testing.
    """
    train_args = {'train_iters': train_iters, 'test_iters': 1, 'test_every': 1000, 'loss': 'ce'}
    train_task = ClassificationTask(batch_size=batch_size, **common_args)
    test_task = ClassificationTask(batch_size=1024, **common_args)
    return Case(name, config, train_args=train_args, train_task=train_task, test_task=test_task)


all_cases = []

for n_d in n_dims:
    # MLP grid: depth x width x number of classes.
    for depth in depths_mlp:
        for width in widths_mlp:
            for c in n_classes:
                # A fresh seed is drawn for every case, before its model config is built.
                common_args = {'n_dims': n_d, 'seed': new_seed(), 'n_classes': c}
                all_cases.append(_build_case(
                    'MLP',
                    MlpConfig(n_out=c, n_layers=depth, n_hidden=width),
                    train_iters_mlp,
                    common_args,
                ))

    # Transformer grid: depth x width x number of classes, on tokenized inputs.
    for depth in depths_trans:
        for width in widths_trans:
            for c in n_classes:
                common_args = {'n_dims': n_d, 'tokenize': 1, 'seed': new_seed(), 'n_classes': c}
                all_cases.append(_build_case(
                    'Transformer',
                    TransformerConfig(n_out=c, n_layers=depth, n_hidden=width, pos_emb=True, n_mlp_layers=2),
                    train_iters_trans,
                    common_args,
                ))


# Train every configured case in order. tqdm.write is used instead of print
# so the status line does not corrupt the tqdm progress bar.
for case in tqdm(all_cases):
    tqdm.write(f'RUNNING {case.name}')
    case.run()

# Evaluate every case against its own test task; results are reported under
# the key 'acc' (exact storage is up to eval_cases — see common).
test_tasks = [case.test_task for case in all_cases]
eval_cases(all_cases, eval_task=test_tasks, key_name='acc')

# Record parameter count and FLOPs, then drop the trained state before the
# cases are converted to a DataFrame and pickled.
for case in all_cases:
    param_leaves = jax.tree_util.tree_leaves(case.state.params)
    case.info['size'] = sum(leaf.size for leaf in param_leaves)
    case.info['flops'] = case.get_flops()
    case.state = None

df = pd.DataFrame(all_cases)
out_path = f'res.{run_id}.pkl'
df.to_pickle(out_path)

print('done!')

# %%
