amirali1985 commited on
Commit
c1fc0b6
·
verified ·
1 Parent(s): 3131b3d

add 10K table; match on num_epochs once catalog has it

Browse files
Files changed (1) hide show
  1. gen_summary.py +19 -0
gen_summary.py CHANGED
@@ -17,6 +17,22 @@ MODEL_REPO = "thoughtworks/arithmetic-sorl"
17
  CACHE_DIR = "/tmp/hf_dash_cache"
18
 
19
  SUMMARY_CONFIGS = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  {
21
  "label": "Standard model (2L/3H/510d) at 25K data",
22
  "baseline_match": {"mode": "baseline", "ops": "add_sub", "dataset_size": 25000, "n_layer": 2, "n_head": 3, "n_embd": 510},
@@ -78,6 +94,9 @@ def _entry_matches(entry, match_cfg):
78
  pass
79
 
80
  for k, v in match_cfg.items():
 
 
 
81
  val = arch_map.get(k, entry.get(k))
82
  if val != v:
83
  return False
 
17
  CACHE_DIR = "/tmp/hf_dash_cache"
18
 
19
  SUMMARY_CONFIGS = [
20
+ {
21
+ "label": "Standard model (2L/3H/510d) at 10K data",
22
+ "baseline_match": {"mode": "baseline", "ops": "add_sub", "dataset_size": 10000, "n_layer": 2, "n_head": 3, "n_embd": 510, "num_epochs": 20},
23
+ "sorl_match": {"mode": "sorl", "ops": "add_sub", "dataset_size": 10000, "n_layer": 2, "n_head": 3, "n_embd": 510, "abs_vocab": 30, "K": 1},
24
+ "sorl_label": "SoRL K=1 abs30",
25
+ "splits": ["add_S0", "add_C1", "add_C3", "add_C5", "add_C6", "sub_M0", "sub_M4"],
26
+ "split_labels": {
27
+ "add_S0": "S0 (no carry, easy)",
28
+ "add_C1": "C1 (1 carry)",
29
+ "add_C3": "C3 (3 hot carries)",
30
+ "add_C5": "C5 (5 hot carries)",
31
+ "add_C6": "C6 (6 hot carries)",
32
+ "sub_M0": "sub_M0 (no borrow, easy)",
33
+ "sub_M4": "sub_M4 (4 borrows)",
34
+ },
35
+ },
36
  {
37
  "label": "Standard model (2L/3H/510d) at 25K data",
38
  "baseline_match": {"mode": "baseline", "ops": "add_sub", "dataset_size": 25000, "n_layer": 2, "n_head": 3, "n_embd": 510},
 
94
  pass
95
 
96
  for k, v in match_cfg.items():
97
+ # num_epochs may not exist in older catalog entries — skip if missing
98
+ if k == "num_epochs" and entry.get(k) is None:
99
+ continue
100
  val = arch_map.get(k, entry.get(k))
101
  if val != v:
102
  return False