#!/usr/bin/env python3
import re, sys, yaml
# Python regular expression used to spot {pattern_ids} in PCRE patterns.
# It is written in verbose mode, so the whitespace and the "#" comments inside
# the string are ignored by the regex engine. The identifier found between the
# braces is exposed through the named group "pattern_id".
pattern_id_re = r'''(?x) # extended/verbose/free-spacing mode
(?<! # Use a negative lookbehind to ensure pattern ids are NOT preceded with:
\\[xpPgk] # - \x (character hex code),
# - \p, \P (character property)
# - \g, \k (Perl / .NET reference by name)
)
\{ # opening brace
(?P<pattern_id>
[a-zA-Z] # pattern ids must start with a letter (to distinguish them from PCRE quantifiers)
[a-zA-Z0-9_+-]* # pattern ids may feature alphanumeric characters, underscores, dashes and plus signs
)
\} # closing brace
'''
# Number of expand_pattern() calls currently in progress (nesting depth):
EXPAND_PATTERN_RECURSION = 0
# Depth at which expand_pattern() raises instead of expanding any further:
EXPAND_PATTERN_RECURSION_LIMIT = 100
def is_single_group(pattern):
    """
    Tell whether the given pattern is made of exactly one parenthesised group.

    "(foo)" and "((foo)(bar))" qualify; "(foo)(bar)" does not. A backslash
    hides the character that follows it, so escaped parentheses are not
    counted.
    """
    # A single group necessarily opens on the first character and closes on the last one:
    if not (pattern.startswith('(') and pattern.endswith(')')):
        return False
    last_index = len(pattern) - 1
    depth = 0
    index = 0
    while index <= last_index:
        current = pattern[index]
        if current == '\\':
            # Jump over the backslash and the character it escapes:
            index += 2
            continue
        if current == '(':
            depth += 1
        elif current == ')':
            depth -= 1
            if depth == 0:
                # The first group just closed; it is the only group if and
                # only if this is the last character of the pattern:
                return index == last_index
        index += 1
    # The depth never came back to 0, e.g. the final parenthesis was escaped:
    return False
def group_pattern(pattern, force=False):
    """
    Wrap the pattern in a non-capturing group when grouping is needed.

    Grouping is needed when `force` is set or when the pattern contains a "|".
    A pattern that already is a single group is returned untouched, which
    prevents unnecessary double-grouping such as (?:(?:foo)).
    """
    needs_group = force or ('|' in pattern)
    if needs_group and not is_single_group(pattern):
        return '(?:%s)' % pattern
    return pattern
def anchor_pattern(pattern):
    """
    Return the pattern with a leading "^" and a trailing "$" anchor.

    Anchors that are already present are not duplicated. A trailing "$" that
    is escaped (i.e. preceded by an odd number of backslashes) is a literal
    dollar sign, not an anchor, so a real "$" anchor is appended after it.
    """
    if not pattern.startswith('^'):
        pattern = '^' + pattern
    if pattern.endswith('$'):
        before_dollar = pattern[:-1]
        # Count the backslashes immediately preceding the final "$":
        backslashes = len(before_dollar) - len(before_dollar.rstrip('\\'))
        if backslashes % 2 == 0:
            # The "$" is not escaped: the pattern is already anchored.
            return pattern
    return pattern + '$'
def add_options_to_pattern(pattern):
    """
    Prepend the PCRE options required by the contents of the pattern.

    - "(*UTF)" is added when the pattern uses a \\p{...} or \\P{...} character
      property;
    - "(?x)" (extended mode) is added when the pattern spans several lines.

    When both apply, "(*UTF)" comes first: PCRE only accepts this kind of
    special item at the very start of a pattern, so it must precede "(?x)".
    """
    options = ''
    if re.search(r'\\[pP]\{\w+\}', pattern):
        options += '(*UTF)'
    if '\n' in pattern:
        options += '(?x)'
    return options + pattern
def make_regex(values, group=False, anchor=False):
    """
    Build a pattern matching any of the given literal values.

    Each value is regex-escaped, then the values are joined with "|". The
    result is optionally passed through group_pattern() and anchor_pattern(),
    in that order.
    """
    alternatives = [escape_regex(value) for value in values]
    regex = '|'.join(alternatives)
    if group:
        regex = group_pattern(regex)
    if anchor:
        regex = anchor_pattern(regex)
    return regex
def get_pattern_by_id(pattern_id):
    """
    Return the fully expanded pattern registered under the given id.

    Patterns are looked up in the "pattern" mapping of the "common" section of
    the loaded document. A pattern defined as a list is first turned into an
    alternation of literal values. An unknown id, or a missing/empty "common"
    or "pattern" section, yields an empty pattern instead of raising.
    """
    common = doc.get('common') or {}
    known_patterns = common.get('pattern') or {}
    pattern = known_patterns.get(pattern_id, '')
    if isinstance(pattern, list):
        pattern = make_regex(pattern)
    # The pattern may itself refer to other patterns:
    return expand_pattern(pattern)
def get_pattern_by_match(rem):
    """
    Return the expanded pattern referenced by a match of pattern_id_re.
    """
    pattern_id = rem.group('pattern_id')
    return get_pattern_by_id(pattern_id)
def get_pattern_to_inject(rem):
    """
    Return the text that replaces a {pattern_id} reference.

    The referenced pattern is always enclosed within a non-capturing group
    (unless it already is a single group) to prevent these problematic cases:
      Case #1:
        foo: '[a-z]+[0-9]{2}'
        bar: '{foo}{5}' => '[a-z]+[0-9]{2}{5}' vs '(?:[a-z]+[0-9]{2}){5}'
      Case #2:
        foo: '(foo)(bar)'
        bar: '{foo}{5}' => '(foo)(bar){5}' vs '(?:(foo)(bar)){5}'
    """
    return group_pattern(get_pattern_by_match(rem), True)
def expand_pattern(pattern, group=False, anchor=False):
    """
    Replace every {pattern_id} reference found in the pattern.

    Referenced patterns are expanded recursively (through
    get_pattern_to_inject). The result is optionally passed through
    group_pattern() and anchor_pattern(), in that order.

    Raises Exception when the nesting depth reaches
    EXPAND_PATTERN_RECURSION_LIMIT, which most likely reveals a loop between
    patterns.
    """
    global EXPAND_PATTERN_RECURSION
    EXPAND_PATTERN_RECURSION += 1
    try:
        if EXPAND_PATTERN_RECURSION >= EXPAND_PATTERN_RECURSION_LIMIT:
            raise Exception('Recursion limit reached for pattern expansion. There likely is a loop between patterns.')
        pattern = re.sub(pattern_id_re, get_pattern_to_inject, pattern)
    finally:
        # Always restore the depth counter, even when an exception propagates,
        # so that a failed expansion does not poison subsequent calls:
        EXPAND_PATTERN_RECURSION -= 1
    if group:
        pattern = group_pattern(pattern)
    if anchor:
        pattern = anchor_pattern(pattern)
    return pattern
def looks_like_pcre(string):
    """
    Tell whether the string contains at least one regex metacharacter.

    The '.' metacharacter is deliberately left out because:
      1 - it is ubiquitous in non-regex strings
      2 - it rarely is the only metacharacter of a given regex
    """
    return any(metacharacter in string for metacharacter in '\\^$*+?()[]{}|')
def escape(value, metacharacters, prefix='\\'):
    """
    Return `value` with `prefix` inserted before each metacharacter.

    `metacharacters` is expected to be a string (or another iterable) of
    single characters. The value is processed in a single pass, so the result
    does not depend on the order of the metacharacters: a prefix inserted for
    one metacharacter is never escaped again, even when the prefix itself is
    listed among the metacharacters or when a metacharacter is listed twice.
    """
    specials = set(metacharacters)
    return ''.join(prefix + char if char in specials else char for char in value)
def escape_regex(value):
    """
    Return `value` with every regex metacharacter (including '.') backslash-escaped.
    """
    regex_metacharacters = '\\^$*+?()[]{}|.'
    return escape(value, regex_metacharacters)
def make_comparison(variable_name, value, negate=False):
    """
    Build an nginx condition comparing the variable $<variable_name> with a value.

    - a one-item list yields an equality test ("=") against that item;
    - a longer list yields a regex test ("~") against a grouped and anchored
      alternation of its items;
    - anything else is handled as a pattern: it is expanded, grouped, anchored,
      given its PCRE options and used in a regex test ("~").

    When `negate` is set, the operator is prefixed with "!".
    """
    operator = '~'
    if type(value) is not list:
        right_part = add_options_to_pattern(expand_pattern(value, True, True))
    elif len(value) == 1:
        operator = '='
        right_part = value[0]
    else:
        right_part = make_regex(value, True, True)
    negation = '!' if negate else ''
    return '%s %s%s %s' % ('$' + variable_name, negation, operator, nginx_quote(right_part))
def make_comparison2(var_type, var_name, value, negate=False):
    """
    Same as make_comparison(), for the variable named <var_type>_<normalised var_name>.
    """
    full_name = '_'.join((var_type, nginx_normalise(var_name)))
    return make_comparison(full_name, value, negate)
def nginx_normalise(name):
    """
    Return the name in lowercase, with dashes replaced by underscores.
    """
    lowered = name.lower()
    return '_'.join(lowered.split('-'))
def nginx_quote(value):
    """
    Return the value enclosed in double quotes, with backslashes and double
    quotes backslash-escaped.
    """
    escaped = escape(value, '\\"')
    return '"' + escaped + '"'
def get_item(obj, key, item_type=None):
    """
    Return obj[key], resolving it when it is a reference to a common item.

    - None is returned when the key/index does not exist, or when `obj` cannot
      be subscripted with it at all (typically because `obj` is None, e.g. the
      policy of a URI that does not declare any).
    - A string value is handled as a reference: it is looked up in the
      "<item_type or key>" mapping of the "common" section of the loaded
      document; None is returned when the reference cannot be resolved.
    - Any other value is returned as is.
    """
    try:
        value = obj[key]
    except (KeyError, IndexError, TypeError):
        return None
    if type(value) is str:  # the value happens to be a reference:
        return doc.get('common', {}).get(item_type or key, {}).get(value)
    return value
def check_non_mandatory_item(presence_condition, mismatch_condition, return_status):
    """
    Print the nginx directives that reject an optional item only when it is
    both present and mismatching.

    Two flags are set from the given conditions, then concatenated into $buf;
    `return_status` is returned when $buf equals "11".
    """
    directives = (
        'set $current_item_present "0";',
        'if (%s) { set $current_item_present "1"; }' % presence_condition,
        'set $current_item_mismatch "0";',
        'if (%s) { set $current_item_mismatch "1"; }' % mismatch_condition,
        'set $buf "${current_item_present}${current_item_mismatch}";',
        'if ($buf = "11") { return %d; }' % return_status,
    )
    for directive in directives:
        print('\t\t' + directive)
# Load the WAF description (YAML) from standard input, using the safe loader:
doc = yaml.safe_load(sys.stdin)

# Top-level settings, each with its fallback value:
uninitialized_variable_warn = doc.get('uninitialized_variable_warn', True)
variable = doc.get('variable', 'waf')
waf_prefix = doc.get('prefix', '/waf')
uri_prefix = doc.get('uri_prefix', '')
status = doc.get('status', 405)
debug = doc.get('debug', False)

# Lines emitted before the per-URI locations:
preamble = []
# NOTE(review): a truthy setting emits "uninitialized_variable_warn off;" --
# confirm that this is the intended meaning of the option.
if uninitialized_variable_warn:
    preamble.append('uninitialized_variable_warn off;')
if debug:
    preamble += [
        '',
        '# Make debug data available through an extra HTTP header:',
        'add_header X-WAF-Debug "${%s_debug}";' % variable,
    ]
preamble += [
    '',
    # Requests that did not go through the WAF yet are rewritten to its prefix:
    'if ($%s = "") {' % variable,
    '\trewrite ^ "%s${uri}" last;' % waf_prefix,
    '}',
    '',
    # Opening of the enclosing WAF location:
    'location %s {' % waf_prefix,
    '\tinternal;',
    '\t# Executed if none of the nested locations matched:',
    '\treturn %d;' % status,
    '',
]
print('\n'.join(preamble))
def _print_named_item_checks(items, item_type, var_type):
    """
    Print the nginx checks for a set of named items (headers or cookies).

    `items` is the resolved item set, `item_type` the name of the "common"
    section used to resolve items given as references, and `var_type` the
    prefix of the matching nginx variables ("http" or "cookie").
    Mandatory items must match their pattern; the other ones are only checked
    when they are present (i.e. when their variable is not empty).
    """
    for item_index, _ in enumerate(items):
        item = get_item(items, item_index, item_type)
        comparison = make_comparison2(var_type, item['name'], item['pattern'], True)
        return_status = item.get('status', status)
        if item.get('mandatory', False):
            print('\t\tif (%s) { return %d; } # %s must match the expected format' % (comparison, return_status, item['name']))
        else:
            presence_condition = '$%s_%s != ""' % (var_type, nginx_normalise(item['name']))
            check_non_mandatory_item(presence_condition, comparison, return_status)
            print('')


# Emit one nested location per declared URI:
for uri in doc.get('uri', []):
    full_uri_pattern = uri_prefix + uri['pattern']
    pattern = expand_pattern(full_uri_pattern)
    if looks_like_pcre(pattern):
        operator = '~'
        # Prefix and anchor the pattern, then prepend relevant PCRE options:
        pattern = add_options_to_pattern(anchor_pattern(waf_prefix + pattern))
    else:
        # No metacharacter: an exact-match location is enough.
        operator = '='
        pattern = waf_prefix + pattern
    # Remind the original (unexpanded) pattern as a comment, one "# " per line:
    print(re.sub(r'(?m)^', '\t# ', full_uri_pattern))
    print('\tlocation %s %s {' % (operator, nginx_quote(pattern)))
    print('\t\tinternal;')
    print('\t\tset $%s "1";' % variable)
    if debug:
        print('\t\tset $%s_debug "%s";' % (variable, full_uri_pattern.replace('\n', '')))

    # A URI without any (resolvable) policy gets no filter at all instead of
    # crashing the generator:
    policy = get_item(uri, 'policy') or {}

    methods = get_item(policy, 'method')
    if methods:
        print('')
        print('\t\t# Filter HTTP methods:')
        print('\t\tif (%s) { return 405; }' % make_comparison('request_method', methods, True))

    args = get_item(policy, 'arg', 'argset')
    if args is not None:  # None/null means "no policy", [] means "no args expected"
        print('')
        print('\t\t# Filter query string arguments:')
        print('\t\tset $new_args "";')
        for arg_index, _ in enumerate(args):
            arg = get_item(args, arg_index, 'arg')
            comparison = make_comparison2('arg', arg['name'], arg['pattern'], True)
            return_status = arg.get('status', status)
            print('')
            if arg.get('mandatory', False):
                print('\t\tif ($args !~ "(?:^|&)%s=") { return %d; } # %s must be present' % (arg['name'], return_status, arg['name']))
                print('\t\tif (%s) { return %d; } # %s must match the expected format' % (comparison, return_status, arg['name']))
                print('\t\tset $new_args "${new_args}&%s=${arg_%s}";' % (arg['name'], arg['name']))
            else:
                check_non_mandatory_item('$args ~ "(?:^|&)%s="' % arg['name'], comparison, return_status)
                print('\t\tif ($current_item_present) { set $new_args "${new_args}&%s=${arg_%s}"; }' % (arg['name'], arg['name']))
        print('')
        print('\t\t# Get rid of the leading ampersand, if any:')
        print('\t\tif ($new_args ~ "^&(?<buf>.+)") { set $new_args "${buf}"; }')
        print('\t\t# Normalise the query string:')
        print('\t\tset $args "${new_args}";')

    headers = get_item(policy, 'header', 'headerset')
    if headers:
        print('')
        print('\t\t# Check headers:')
        _print_named_item_checks(headers, 'header', 'http')

    cookies = get_item(policy, 'cookie', 'cookieset')
    if cookies:
        print('')
        print('\t\t# Check cookies:')
        # Each cookie is checked against its own name and "mandatory" flag
        # (not against those of the last header processed):
        _print_named_item_checks(cookies, 'cookie', 'cookie')
        # TODO rewrite cookie header?

    print('')
    print('\t\t# Exit from the WAF part and return to regular nginx locations:')
    print('\t\trewrite "^%s(?<tail>.*)" "${tail}" last;' % waf_prefix)
    print('\t}')
# Close the enclosing WAF location:
print('}')