-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparse_generalizations.py
More file actions
100 lines (95 loc) · 3.69 KB
/
Copy pathparse_generalizations.py
File metadata and controls
100 lines (95 loc) · 3.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import re
import pandas as pd
import argparse
parser = argparse.ArgumentParser(
prog="parse_generalisations",
description="Parses and processes results from a generalisation run",
)
parser.add_argument("filename")
parser.add_argument("-s", "--string", default="1 surround DAX after DAX thrice")
parser.add_argument("--full_table", action="store_true")
args = parser.parse_args()
rule_string = args.string.split(" ")
with open(args.filename) as f:
evaluations = []
evaluation = {}
episode_number = None
state = None
colour_count = 0
retrieval_line_count = 0
for line in f:
line = line.strip()
if m := re.match(r"Evaluation episode (\d+)", line):
episode_number = int(m.group(1))
state = None
if len(evaluation) > 0:
evaluations.append(evaluation)
evaluation = {}
colour_count = 0
retrieval_line_count = 0
elif episode_number is None:
# Ignore header
continue
elif line == "support items;":
state = "support"
elif line.startswith("retrieval items;"):
state = "retrieval"
elif line.startswith("generalization items;"):
state = "generalize"
elif state is not None:
if state == "support":
continue
elif state == "retrieval":
if retrieval_line_count >= 14:
if m := re.match(r"(\w+) -> ([A-Z]+)$", line):
evaluation[
f"colour_word_{colour_count + 1 if colour_count < 3 else 'h'}"
] = m.group(1)
evaluation[
f"colour_{colour_count+ 1 if colour_count < 3 else 'h'}"
] = m.group(2)
colour_count += 1
retrieval_line_count += 1
elif state == "generalize":
if m := re.match(
r" ".join([r"(\w+)" for w in rule_string]) + r" ->", line
):
for x in filter(
lambda x: x in rule_string,
["DAX", "surround", "after", "thrice"],
):
evaluation[x] = m.group(rule_string.index(x) + 1)
if "target" in line:
evaluation["correct"] = False
string = line.split("->")[1].split("(")[0].strip()
for i in range(4):
string = string.replace(
evaluation[f"colour_{i + 1 if i < 3 else 'h'}"],
str(i + 1) if i != 3 else "h",
)
evaluation["generalization"] = string
else:
evaluation["correct"] = True
string = line.split("->")[1].strip()
for i in range(4):
string = string.replace(
evaluation[f"colour_{i + 1 if i < 3 else 'h'}"],
str(i + 1) if i != 3 else "h",
)
evaluation["generalization"] = string
episode_number = None
state = None
evaluations.append(evaluation)
df = pd.DataFrame(evaluations)
print(df["generalization"].value_counts())
df["correct"] = df["correct"] * 100
print(
df.groupby(["colour_word_h", "colour_h"] if args.full_table else ["colour_h"])[
"correct"
].mean()
)
print(
df.groupby(["colour_word_h", "colour_h"] if args.full_table else ["colour_h"])[
"correct"
].sem()
)