kindwolf.org Git repositories nginxwaf / master nginxwaf
master

Tree @master (Download .tar.gz)

nginxwaf @masterraw · history · blame

#!/usr/bin/env python3
import re, sys, yaml

# Python regular expression used to spot {pattern_ids} in PCRE patterns:
# - expand_pattern() passes it to re.sub() to replace each {pattern_id} occurrence;
# - the id itself is exposed through the "pattern_id" named group, which
#   get_pattern_by_match() reads.
# The expression is written in verbose mode, so the comments below are part of
# the string literal but are ignored by the regex engine.
pattern_id_re = r'''(?x) # extended/verbose/free-spacing mode
(?<!                     # Use a negative lookbehind to ensure pattern ids are NOT preceded with:
	\\[xpPgk]              # - \x (character hex code),
	                       # - \p, \P (character property)
	                       # - \g, \k (Perl / .NET reference by name)
)                        
\{                       # opening brace
(?P<pattern_id>
	[a-zA-Z]             # pattern ids must start with a letter (to distinguish them from PCRE quantifiers)
	[a-zA-Z0-9_+-]*      # pattern ids may feature alphanumeric characters, underscores, dashes and plus signs
)
\}                       # closing brace
'''

# Current nesting depth of expand_pattern() calls: the function increments it
# on entry and decrements it before it returns.
EXPAND_PATTERN_RECURSION = 0
# Depth at which expand_pattern() gives up and raises, on the assumption that
# patterns reference each other in a loop.
EXPAND_PATTERN_RECURSION_LIMIT = 100

def is_single_group(pattern):
	"""
	Tell whether the whole pattern is enclosed in one pair of parentheses.

	"(foo)" and "((foo)(bar))" qualify, "(foo)(bar)" does not. A character
	preceded by a backslash is never counted as a parenthesis.
	"""
	# Anything that is not wrapped in parentheses cannot be a group:
	if not (pattern.startswith('(') and pattern.endswith(')')):
		return False
	last_index = len(pattern) - 1
	depth = 0
	index = 0
	while index <= last_index:
		current = pattern[index]
		if current == '\\':
			# Jump over the backslash and the character it escapes:
			index += 2
			continue
		if current == '(':
			depth += 1
		elif current == ')':
			depth -= 1
		if depth == 0:
			# The opening parenthesis was just closed; the pattern is a
			# single group only if that happened on the very last character:
			return index == last_index
		index += 1
	# The first group was never closed, e.g. escaped closing parenthesis:
	return False

def group_pattern(pattern, force=False):
	"""
	Enclose the pattern within a non-capturing group when needed.

	Grouping is considered when `force` is set or when the pattern contains
	an alternation ('|'); the pattern is returned untouched otherwise.
	"""
	grouping_wanted = force or '|' in pattern
	# Patterns that already are a single group are left alone to avoid
	# double-grouping, e.g. (?:(?:foo)):
	if grouping_wanted and not is_single_group(pattern):
		return '(?:%s)' % pattern
	return pattern

def anchor_pattern(pattern):
	"""
	Return the pattern anchored at both ends ('^' ... '$').

	An anchor is only added when it is missing. A trailing '$' preceded by an
	odd number of backslashes is an escaped, literal dollar sign rather than
	an anchor, so the end anchor is still appended in that case.
	"""
	if not pattern.startswith('^'):
		pattern = '^' + pattern
	if pattern.endswith('$'):
		before_dollar = pattern[:-1]
		# Count the backslashes immediately preceding the final '$':
		backslashes = len(before_dollar) - len(before_dollar.rstrip('\\'))
		if backslashes % 2 == 0:
			# Unescaped '$': the pattern is already anchored.
			return pattern
	return pattern + '$'

def add_options_to_pattern(pattern):
	"""
	Prepend the PCRE options required by the pattern and return the result.

	- '(*UTF)' is added when the pattern uses a \\p{...} or \\P{...}
	  character property;
	- '(?x)' (extended mode) is added when the pattern spans several lines.

	When both apply, '(*UTF)' comes first: PCRE only recognises such
	start-of-pattern items at the very beginning of the pattern.
	"""
	# Detect this on the original pattern, before any option is prepended:
	needs_utf = re.search(r'\\[pP]\{\w+\}', pattern) is not None
	if '\n' in pattern:
		pattern = '(?x)' + pattern
	if needs_utf:
		pattern = '(*UTF)' + pattern
	return pattern

def make_regex(values, group=False, anchor=False):
	"""
	Build a regex matching any of the given literal values.

	Each value is escaped and the results are joined as alternatives. The
	regex is then optionally grouped (`group`) and anchored (`anchor`).
	"""
	alternatives = [escape_regex(value) for value in values]
	regex = '|'.join(alternatives)
	if group:
		regex = group_pattern(regex)
	return anchor_pattern(regex) if anchor else regex

def get_pattern_by_id(pattern_id):
	"""
	Return the fully expanded pattern registered under common.pattern.<pattern_id>.

	A pattern declared as a list of values is first turned into an
	alternation of escaped literals. Unknown ids, as well as documents
	without a "common" or "common.pattern" section, resolve to an empty
	pattern.
	"""
	# "or {}" also covers sections that are present but empty (None):
	common = doc.get('common') or {}
	patterns = common.get('pattern') or {}
	pattern = patterns.get(pattern_id, '')
	if isinstance(pattern, list):
		pattern = make_regex(pattern)
	# The pattern may itself reference other patterns:
	return expand_pattern(pattern)

def get_pattern_by_match(rem):
	"""
	Resolve the pattern referenced by a regex match object, using the id
	captured in its "pattern_id" group.
	"""
	pattern_id = rem.group('pattern_id')
	return get_pattern_by_id(pattern_id)

def get_pattern_to_inject(rem):
	"""
	Return the text replacing a {pattern_id} reference: the referenced
	pattern, always enclosed within a group.
	"""
	# Grouping is forced because injecting the bare pattern would change its
	# meaning whenever something (e.g. a quantifier) follows the reference:
	# Case #1:
	#   foo: '[a-z]+[0-9]{2}'
	#   bar: '{foo}{5}' => '[a-z]+[0-9]{2}{5}' vs '(?:[a-z]+[0-9]{2}){5}'
	# Case #2:
	#   foo: '(foo)(bar)'
	#   bar: '{foo}{5}' => '(foo)(bar){5}' vs '(?:(foo)(bar)){5}'
	return group_pattern(get_pattern_by_match(rem), True)

def expand_pattern(pattern, group=False, anchor=False):
	"""
	Replace every {pattern_id} reference found in the pattern with the
	pattern it designates, then optionally group and anchor the result.

	Referenced patterns are expanded recursively; the nesting depth is
	tracked in the EXPAND_PATTERN_RECURSION global.

	Raises:
		Exception: when the nesting depth reaches
			EXPAND_PATTERN_RECURSION_LIMIT, which most likely means that
			patterns reference each other in a loop.
	"""
	global EXPAND_PATTERN_RECURSION
	EXPAND_PATTERN_RECURSION += 1
	try:
		if EXPAND_PATTERN_RECURSION >= EXPAND_PATTERN_RECURSION_LIMIT:
			raise Exception('Recursion limit reached for pattern expansion. There likely is a loop between patterns.')

		pattern = re.sub(pattern_id_re, get_pattern_to_inject, pattern)
		if group:
			pattern = group_pattern(pattern)
		if anchor:
			pattern = anchor_pattern(pattern)
		return pattern
	finally:
		# Always restore the depth counter, including when an exception
		# propagates, so that it never stays stuck above its real value:
		EXPAND_PATTERN_RECURSION -= 1

def looks_like_pcre(string):
	"""
	Guess whether the string is a regex: True as soon as it contains one of
	the metacharacters listed below.
	"""
	# '.' is deliberately absent from the list because:
	# 1 - it is ubiquitous in non-regex strings
	# 2 - it rarely is the only metacharacter of a given regex
	return any(symbol in string for symbol in '\\^$*+?()[]{}|')

def escape(value, metacharacters, prefix='\\'):
	"""
	Return `value` with every occurrence of a metacharacter preceded by `prefix`.

	The value is processed in a single pass, so the result does not depend
	on the order of `metacharacters`: a prefix inserted for one metacharacter
	is never escaped again, even when the prefix is itself a metacharacter.

	`metacharacters` is an iterable of strings (typically a string of single
	characters); empty entries are ignored.
	"""
	alternatives = [re.escape(metacharacter) for metacharacter in metacharacters if metacharacter]
	if not alternatives:
		return value
	# A function is used as replacement so that backslashes found in the
	# prefix are inserted verbatim instead of being read as regex escapes:
	return re.sub('|'.join(alternatives), lambda found: prefix + found.group(0), value)
	
def escape_regex(value):
	"""
	Backslash-escape the regex metacharacters found in the value, so that it
	can be used as a literal within a regex.
	"""
	regex_metacharacters = '\\^$*+?()[]{}|.'
	return escape(value, regex_metacharacters)
	
def make_comparison(variable_name, value, negate=False):
	"""
	Build an nginx condition comparing $<variable_name> with the value.

	- a list holding a single value yields an equality test ('=');
	- any other list yields a regex test ('~') against the grouped and
	  anchored alternation of its values;
	- any other value is handled as a pattern: it is expanded, grouped,
	  anchored and prefixed with the relevant options, then used in a
	  regex test ('~').

	With `negate`, the operator is prefixed with '!'.
	"""
	operator = '~'
	if type(value) is not list:
		operand = add_options_to_pattern(expand_pattern(value, True, True))
	elif len(value) == 1:
		operator = '='
		operand = value[0]
	else:
		operand = make_regex(value, True, True)
	negation = '!' if negate else ''
	return '$' + variable_name + ' ' + negation + operator + ' ' + nginx_quote(operand)

def make_comparison2(var_type, var_name, value, negate=False):
	"""
	Same as make_comparison(), for the variable named
	<var_type>_<normalised var_name>.
	"""
	full_name = '_'.join((var_type, nginx_normalise(var_name)))
	return make_comparison(full_name, value, negate)

def nginx_normalise(name):
	"""
	Return the name in lowercase with dashes turned into underscores.
	"""
	underscored = name.replace('-', '_')
	return underscored.lower()

def nginx_quote(value):
	"""
	Return the value as a double-quoted string, with its backslashes and
	double quotes escaped.
	"""
	escaped = escape(value, '\\"')
	return ''.join(('"', escaped, '"'))

def get_item(obj, key, item_type=None):
	"""
	Fetch obj[key], resolving references to shared items.

	A string value is a reference: it is looked up under
	common.<item_type> (common.<key> when `item_type` is not given) and the
	referenced item, or None when there is no such item, is returned.
	Any other value is returned as is.

	Returns None when the key/index is missing, or when `obj` cannot be
	subscripted at all (e.g. `obj` is None because the parent item is
	absent or could not be resolved).
	"""
	try:
		value = obj[key]
	except (KeyError, IndexError, TypeError):
		# TypeError covers obj being None, e.g. an URI without any policy.
		return None
	if isinstance(value, str): # the value happens to be a reference:
		return doc.get('common', {}).get(item_type or key, {}).get(value)
	return value

def check_non_mandatory_item(presence_condition, mismatch_condition, return_status):
	"""
	Print the nginx directives rejecting an optional item that is present
	but does not match.

	Two flags are set from `presence_condition` and `mismatch_condition`,
	then concatenated into $buf; `return_status` is returned when both
	flags are "1".
	"""
	directives = (
		('set $current_item_present "0";', ()),
		('if (%s) { set $current_item_present "1"; }', presence_condition),
		('set $current_item_mismatch "0";', ()),
		('if (%s) { set $current_item_mismatch "1"; }', mismatch_condition),
		('set $buf "${current_item_present}${current_item_mismatch}";', ()),
		('if ($buf = "11") { return %d; }', return_status),
	)
	# Each directive is formatted right before being printed:
	for template, argument in directives:
		print('\t\t' + template % argument)

# Parse the YAML document describing the WAF from standard input:
doc = yaml.load(sys.stdin, Loader=yaml.SafeLoader)

# Top-level settings, with the defaults applied when the document omits them:
# - when true, an "uninitialized_variable_warn off;" directive is emitted:
uninitialized_variable_warn = doc.get('uninitialized_variable_warn', True)
# - name of the nginx variable set to "1" within the generated locations:
variable = doc.get('variable', 'waf')
# - prefix of the internal locations requests are rewritten to:
waf_prefix = doc.get('prefix', '/waf')
# - string prepended to the pattern of each URI:
uri_prefix = doc.get('uri_prefix', '')
# - default HTTP status returned upon rejection:
status = doc.get('status', 405)
# - when true, debug data is exposed through an extra HTTP header:
debug = doc.get('debug', False)

# Emit the configuration preamble: global directives, the rewrite sending
# requests to the WAF location, and the opening of that location.
if uninitialized_variable_warn:
	print('uninitialized_variable_warn off;')
if debug:
	print('\n# Make debug data available through an extra HTTP header:')
	print('add_header X-WAF-Debug "${%s_debug}";\n' % variable)
# Requests for which the WAF variable is still empty are sent to the WAF location:
print('if ($%s = "") {' % variable)
print('\trewrite ^ "%s${uri}" last;\n}\n' % waf_prefix)
print('location %s {\n\tinternal;' % waf_prefix)
print('\t# Executed if none of the nested locations matched:')
print('\treturn %d;\n' % status)
# Emit one nested location per declared URI; each location enforces the policy
# attached to that URI (methods, query string arguments, headers, cookies) and
# then hands the request back to the regular nginx locations.
for uri in doc.get('uri', []):
	full_uri_pattern = uri_prefix + uri['pattern']
	pattern = expand_pattern(full_uri_pattern)
	if looks_like_pcre(pattern):
		operator = '~'
		# Prefix and anchor the pattern, then prepend relevant PCRE options:
		pattern = add_options_to_pattern(anchor_pattern(waf_prefix + pattern))
	else:
		operator = '='
		pattern = waf_prefix + pattern
	# Recall the original (non-expanded) pattern as a comment, one "# " per line:
	print(re.sub(r'(?m)^', '\t# ', full_uri_pattern))
	print('\tlocation %s %s {' % (operator, nginx_quote(pattern)))
	print('\t\tinternal;')
	print('\t\tset $%s "1";' % variable)
	if debug:
		print('\t\tset $%s_debug "%s";' % (variable, full_uri_pattern.replace('\n', '')))
	policy = get_item(uri, 'policy')

	methods = get_item(policy, 'method')
	if methods:
		print('')
		print('\t\t# Filter HTTP methods:')
		print('\t\tif (%s) { return 405; }' % make_comparison('request_method', methods, True))

	args = get_item(policy, 'arg', 'argset')
	if args is not None: # None/null means "no policy", [] means "no args expected"
		print('')
		print('\t\t# Filter query string arguments:')
		print('\t\tset $new_args "";')
		for arg_index, _ in enumerate(args):
			# Each entry is either an inline argument or a reference to common.arg:
			arg = get_item(args, arg_index, 'arg')
			comparison = make_comparison2('arg', arg['name'], arg['pattern'], True)
			return_status = arg.get('status', status)
			print('')
			if arg.get('mandatory', False):
				print('\t\tif ($args !~ "(?:^|&)%s=") { return %d; } # %s must be present' % (arg['name'], return_status, arg['name']))
				print('\t\tif (%s) { return %d; } # %s must match the expected format' % (comparison, return_status, arg['name']))
				print('\t\tset $new_args "${new_args}&%s=${arg_%s}";' % (arg['name'], arg['name']))
			else:
				check_non_mandatory_item('$args ~ "(?:^|&)%s="' % arg['name'], comparison, return_status)
				print('\t\tif ($current_item_present) { set $new_args "${new_args}&%s=${arg_%s}"; }' % (arg['name'], arg['name']))
		print('')
		print('\t\t# Get rid of the leading ampersand, if any:')
		print('\t\tif ($new_args ~ "^&(?<buf>.+)") { set $new_args "${buf}"; }')
		print('\t\t# Normalise the query string:')
		print('\t\tset $args "${new_args}";')

	headers = get_item(policy, 'header', 'headerset')
	if headers:
		print('')
		print('\t\t# Check headers:')
		for header_index, _ in enumerate(headers):
			# Each entry is either an inline header or a reference to common.header:
			header = get_item(headers, header_index, 'header')
			norm_name = nginx_normalise(header['name'])
			comparison = make_comparison2('http', header['name'], header['pattern'], True)
			return_status = header.get('status', status)
			if header.get('mandatory', False):
				print('\t\tif (%s) { return %d; } # %s must match the expected format' % (comparison, return_status, header['name']))
			else:
				check_non_mandatory_item('$http_%s != ""' % norm_name, comparison, return_status)
			print('')

	cookies = get_item(policy, 'cookie', 'cookieset')
	if cookies:
		print('')
		print('\t\t# Check cookies:')
		for cookie_index, _ in enumerate(cookies):
			# Each entry is either an inline cookie or a reference to common.cookie:
			cookie = get_item(cookies, cookie_index, 'cookie')
			# The name and the mandatory flag must come from the cookie itself,
			# not from the last header processed by the previous loop:
			norm_name = nginx_normalise(cookie['name'])
			comparison = make_comparison2('cookie', cookie['name'], cookie['pattern'], True)
			return_status = cookie.get('status', status)
			if cookie.get('mandatory', False):
				print('\t\tif (%s) { return %d; } # %s must match the expected format' % (comparison, return_status, cookie['name']))
			else:
				check_non_mandatory_item('$cookie_%s != ""' % norm_name, comparison, return_status)
			print('')

		# TODO rewrite cookie header?

	print('')
	print('\t\t# Exit from the WAF part and return to regular nginx locations:')
	print('\t\trewrite "^%s(?<tail>.*)" "${tail}" last;' % waf_prefix)
	print('\t}')
# Close the WAF location:
print('}')