220 lines
6.0 KiB
Python
220 lines
6.0 KiB
Python
|
from math_verify import parse, verify
|
||
|
def compute_score(solution_str, ground_truth) -> float:
    """Score a model solution against the reference answer.

    Three checks are tried in order; the first one that succeeds yields 1.0:

    1. exact string equality of ``solution_str`` and ``ground_truth``;
    2. symbolic verification through ``math_verify`` (``parse`` + ``verify``);
    3. normalized string comparison (``is_equiv``) between the reference and
       the content of the last boxed expression in the solution, or the whole
       solution when it contains no boxed expression.

    Args:
        solution_str: The model output to grade.
        ground_truth: The reference answer.

    Returns:
        1.0 when any check accepts the solution, otherwise 0.0.
    """
    if solution_str == ground_truth:
        return 1.0

    # The symbolic check is best-effort: if parsing or verification raises,
    # report it and fall through to the string-based check instead of
    # aborting the whole scoring call.
    try:
        # verify() is order-sensitive: the gold answer goes first.
        if float(verify(parse(ground_truth), parse(solution_str))) > 0:
            return 1.0
    except Exception as e:
        print(e)

    try:
        answer = solution_str
        string_in_last_boxed = last_boxed_only_string(solution_str)
        if string_in_last_boxed is not None:
            answer = remove_boxed(string_in_last_boxed)

        if is_equiv(answer, ground_truth):
            return 1.0
    except Exception as e:
        # Malformed boxed expressions trip assertions in remove_boxed; such
        # solutions simply score zero.
        print(e)

    return 0.0
|
||
|
|
||
|
|
||
|
def remove_boxed(s):
    r"""Strip the boxed wrapper from ``s`` and return the inner content.

    Two forms are handled: the space-delimited ``\boxed <content>`` and the
    braced ``\boxed{<content>}``.

    Raises:
        AssertionError: If ``s`` does not start with the expected prefix, or
            (braced form) does not end with ``}``.
    """
    spaced_prefix = "\\boxed "
    if spaced_prefix in s:
        assert s.startswith(spaced_prefix)
        return s[len(spaced_prefix):]

    braced_prefix = "\\boxed{"
    assert s.startswith(braced_prefix)
    assert s.endswith("}")
    return s[len(braced_prefix):-1]
|
||
|
|
||
|
def last_boxed_only_string(string):
    r"""Extract the last boxed expression from ``string``.

    If the space-delimited form ``\boxed `` occurs anywhere, the text after
    its last occurrence up to the next ``$`` is returned with the prefix
    re-attached. Otherwise the last ``\boxed`` (or, failing that, the last
    ``\fbox``) is located and returned together with its brace-balanced
    argument.

    Returns:
        The boxed substring including its command, or None when no boxed
        command is present or its braces never balance.
    """
    start = string.rfind("\\boxed")
    if "\\boxed " in string:
        tail = string.split("\\boxed ")[-1]
        return "\\boxed " + tail.split("$")[0]

    if start < 0:
        start = string.rfind("\\fbox")
        if start < 0:
            return None

    # Walk forward from the command, tracking brace depth; the expression
    # ends at the closing brace that brings the depth back to zero.
    depth = 0
    for pos in range(start, len(string)):
        char = string[pos]
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return string[start:pos + 1]

    return None
|
||
|
|
||
|
# string normalization from https://github.com/EleutherAI/lm-evaluation-harness/blob/master/lm_eval/tasks/hendrycks_math.py
|
||
|
def is_equiv(str1, str2, verbose=False):
    """Compare two answer strings after normalization.

    Args:
        str1: First answer string, or None.
        str2: Second answer string, or None.
        verbose: When true, print both normalized strings before comparing.

    Returns:
        True when both inputs are None (a warning is printed), False when
        exactly one is None, otherwise whether the ``strip_string``-normalized
        forms are equal. If normalization raises, the raw strings are
        compared instead.
    """
    missing = (str1 is None, str2 is None)
    if all(missing):
        print("WARNING: Both None")
        return True
    if any(missing):
        return False

    try:
        normalized = [strip_string(str1), strip_string(str2)]
        if verbose:
            print(*normalized)
        return normalized[0] == normalized[1]
    except Exception:
        # Normalization is best-effort; fall back to a literal comparison.
        return str1 == str2
|
||
|
|
||
|
|
||
|
|
||
|
def fix_fracs(string):
    r"""Add the missing braces to shorthand ``\frac`` arguments.

    ``\frac12`` becomes ``\frac{1}{2}`` and ``\frac1{72}`` becomes
    ``\frac{1}{72}``; occurrences whose first argument is already braced are
    left as they are. Each unbraced argument is taken to be a single
    character.

    Args:
        string: LaTeX text to rewrite.

    Returns:
        The rewritten text, or ``string`` unchanged when some ``\frac`` is
        followed by fewer than two characters (including none at all), since
        its arguments cannot be determined.
    """
    pieces = string.split("\\frac")
    rebuilt = pieces[0]
    for piece in pieces[1:]:
        rebuilt += "\\frac"
        if piece.startswith("{"):
            # First argument already braced: keep the remainder verbatim.
            rebuilt += piece
            continue
        if len(piece) < 2:
            # Covers the empty piece too (trailing or doubled \frac), which
            # would otherwise fail on piece[0].
            return string
        numerator, second, rest = piece[0], piece[1], piece[2:]
        if second != "{":
            rebuilt += "{" + numerator + "}{" + second + "}" + rest
        else:
            # Denominator is already braced; only wrap the numerator.
            rebuilt += "{" + numerator + "}" + second + rest
    return rebuilt
|
||
|
|
||
|
|
||
|
def fix_a_slash_b(string):
    r"""Rewrite a plain integer ratio ``a/b`` as ``\frac{a}{b}``.

    The conversion only happens when ``string`` is exactly two integer
    literals separated by a single slash, in canonical form (no padding,
    whitespace or leading zeros). Anything else is returned unchanged.

    Args:
        string: Candidate text such as ``"3/4"``.

    Returns:
        The ``\frac`` form for a canonical integer ratio, otherwise
        ``string`` itself.
    """
    parts = string.split("/")
    if len(parts) != 2:
        return string
    try:
        a = int(parts[0])
        b = int(parts[1])
    except ValueError:
        # Non-integer operands such as "x/y" or "1.5/2" are not simple
        # ratios; leave them alone rather than raising.
        return string
    # int() tolerates whitespace, "+" signs and leading zeros; require an
    # exact round-trip so only canonical "a/b" text is rewritten.
    if string != "{}/{}".format(a, b):
        return string
    return "\\frac{" + str(a) + "}{" + str(b) + "}"
|
||
|
|
||
|
|
||
|
def remove_right_units(string):
    r"""Drop a trailing unit annotation introduced by ``\text{ ``.

    Returns everything before the marker when it is present, otherwise the
    input unchanged.

    Raises:
        AssertionError: If the marker occurs more than once.
    """
    # "\\text{ " only ever occurs (at least in the val set) when describing units
    marker = "\\text{ "
    if marker not in string:
        return string
    pieces = string.split(marker)
    assert len(pieces) == 2
    return pieces[0]
|
||
|
|
||
|
|
||
|
def fix_sqrt(string):
    r"""Brace the argument of shorthand ``\sqrt`` occurrences.

    ``\sqrt3`` becomes ``\sqrt{3}``; only the single character following the
    command is wrapped. Occurrences that already have a braced argument, and
    a ``\sqrt`` with nothing after it, are left unchanged.

    Args:
        string: LaTeX text to rewrite.

    Returns:
        The text with every unbraced ``\sqrt`` argument wrapped in braces.
    """
    if "\\sqrt" not in string:
        return string
    head, *tails = string.split("\\sqrt")
    rebuilt = head
    for tail in tails:
        # An empty tail means \sqrt ends the string or is directly followed
        # by another \sqrt: there is no character to wrap.
        if tail and tail[0] != "{":
            rebuilt += "\\sqrt{" + tail[0] + "}" + tail[1:]
        else:
            rebuilt += "\\sqrt" + tail
    return rebuilt
|
||
|
|
||
|
|
||
|
def strip_string(string):
    r"""Normalize a LaTeX answer string for equality comparison.

    The steps run in a fixed order: remove line breaks and ``\!``, collapse
    double backslashes, map ``tfrac``/``dfrac`` to ``frac``, drop ``\left``
    and ``\right``, degree marks, escaped dollar and percent signs and
    trailing units, add a leading zero to bare decimals, strip a short
    ``x =`` prefix, brace ``\sqrt`` and ``\frac`` shorthand, remove spaces,
    and rewrite ``0.5`` and simple ``a/b`` ratios as ``\frac`` forms.

    Args:
        string: The raw answer text.

    Returns:
        The normalized string. An input that becomes empty after the removal
        steps is returned as the empty string.

    Exceptions raised by the helper functions are not handled here.
    """
    # linebreaks
    string = string.replace("\n", "")

    # remove inverse spaces
    string = string.replace("\\!", "")

    # collapse a doubled backslash into a single one
    string = string.replace("\\\\", "\\")

    # replace tfrac and dfrac with frac
    string = string.replace("tfrac", "frac")
    string = string.replace("dfrac", "frac")

    # remove \left and \right
    string = string.replace("\\left", "")
    string = string.replace("\\right", "")

    # Remove circ (degrees)
    string = string.replace("^{\\circ}", "")
    string = string.replace("^\\circ", "")

    # remove dollar signs
    string = string.replace("\\$", "")

    # remove units (on the right)
    string = remove_right_units(string)

    # remove percentage; one replace suffices, because the former second
    # literal "\%" was an invalid escape that evaluates to this same string
    string = string.replace("\\%", "")

    # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
    string = string.replace(" .", " 0.")
    string = string.replace("{.", "{0.")
    # if empty, return empty string
    if not string:
        return string
    if string[0] == ".":
        string = "0" + string

    # get rid of e.g. "k = " or "q = " at beginning: only when there is a
    # single "=" and at most two characters in front of it
    sides = string.split("=")
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # fix sqrt3 --> sqrt{3}
    string = fix_sqrt(string)

    # remove spaces
    string = string.replace(" ", "")

    # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1)
    string = fix_fracs(string)

    # manually change 0.5 --> \frac{1}{2}
    if string == "0.5":
        string = "\\frac{1}{2}"

    # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y
    string = fix_a_slash_b(string)

    return string
|