Add 10K-data summary table; match on num_epochs once the catalog has it
Browse files- gen_summary.py +19 -0
gen_summary.py
CHANGED
|
@@ -17,6 +17,22 @@ MODEL_REPO = "thoughtworks/arithmetic-sorl"
|
|
| 17 |
CACHE_DIR = "/tmp/hf_dash_cache"
|
| 18 |
|
| 19 |
SUMMARY_CONFIGS = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
{
|
| 21 |
"label": "Standard model (2L/3H/510d) at 25K data",
|
| 22 |
"baseline_match": {"mode": "baseline", "ops": "add_sub", "dataset_size": 25000, "n_layer": 2, "n_head": 3, "n_embd": 510},
|
|
@@ -78,6 +94,9 @@ def _entry_matches(entry, match_cfg):
|
|
| 78 |
pass
|
| 79 |
|
| 80 |
for k, v in match_cfg.items():
|
|
|
|
|
|
|
|
|
|
| 81 |
val = arch_map.get(k, entry.get(k))
|
| 82 |
if val != v:
|
| 83 |
return False
|
|
|
|
| 17 |
CACHE_DIR = "/tmp/hf_dash_cache"
|
| 18 |
|
| 19 |
SUMMARY_CONFIGS = [
|
| 20 |
+
{
|
| 21 |
+
"label": "Standard model (2L/3H/510d) at 10K data",
|
| 22 |
+
"baseline_match": {"mode": "baseline", "ops": "add_sub", "dataset_size": 10000, "n_layer": 2, "n_head": 3, "n_embd": 510, "num_epochs": 20},
|
| 23 |
+
"sorl_match": {"mode": "sorl", "ops": "add_sub", "dataset_size": 10000, "n_layer": 2, "n_head": 3, "n_embd": 510, "abs_vocab": 30, "K": 1},
|
| 24 |
+
"sorl_label": "SoRL K=1 abs30",
|
| 25 |
+
"splits": ["add_S0", "add_C1", "add_C3", "add_C5", "add_C6", "sub_M0", "sub_M4"],
|
| 26 |
+
"split_labels": {
|
| 27 |
+
"add_S0": "S0 (no carry, easy)",
|
| 28 |
+
"add_C1": "C1 (1 carry)",
|
| 29 |
+
"add_C3": "C3 (3 hot carries)",
|
| 30 |
+
"add_C5": "C5 (5 hot carries)",
|
| 31 |
+
"add_C6": "C6 (6 hot carries)",
|
| 32 |
+
"sub_M0": "sub_M0 (no borrow, easy)",
|
| 33 |
+
"sub_M4": "sub_M4 (4 borrows)",
|
| 34 |
+
},
|
| 35 |
+
},
|
| 36 |
{
|
| 37 |
"label": "Standard model (2L/3H/510d) at 25K data",
|
| 38 |
"baseline_match": {"mode": "baseline", "ops": "add_sub", "dataset_size": 25000, "n_layer": 2, "n_head": 3, "n_embd": 510},
|
|
|
|
| 94 |
pass
|
| 95 |
|
| 96 |
for k, v in match_cfg.items():
|
| 97 |
+
# Older catalog entries may not record num_epochs; when the entry has no value for it, skip this one constraint (the entry can still match on the remaining keys)
|
| 98 |
+
if k == "num_epochs" and entry.get(k) is None:
|
| 99 |
+
continue
|
| 100 |
val = arch_map.get(k, entry.get(k))
|
| 101 |
if val != v:
|
| 102 |
return False
|