-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathapply-firewall.py
More file actions
executable file
·214 lines (165 loc) · 6.61 KB
/
apply-firewall.py
File metadata and controls
executable file
·214 lines (165 loc) · 6.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
#!/usr/bin/env python
"""
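Apply a "Security Group" (a set of iptables rules) to the members of an etcd cluster.

Usage: apply-firewall.py --private-key KEY [--private] [--discovery-url URL] [--hosts IP [IP ...]]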
"""
import os
import re
import string
import argparse
from threading import Thread
import uuid
import colorama
from colorama import Fore, Style
import paramiko
import requests
import sys
import yaml
def get_nodes_from_args(args):
    """
    Resolve the cluster member IPs for the parsed command-line arguments.

    The discovery url given on the command line is used when present;
    otherwise the url is obtained from get_discovery_url_from_user_data().

    :param args: parsed arguments exposing a ``discovery_url`` attribute.
    :return: whatever get_nodes_from_discovery_url() returns for that url.
    """
    discovery_url = args.discovery_url
    if discovery_url is None:
        discovery_url = get_discovery_url_from_user_data()
    return get_nodes_from_discovery_url(discovery_url)
def get_nodes_from_discovery_url(discovery_url):
    """
    Fetch the IP addresses of the cluster members from an etcd discovery url.

    :param discovery_url: url whose JSON response lists the members under
        ``node.nodes``; the first dotted-quad found in each member's
        ``value`` is taken as its address.
    :return: list of IP address strings, in the order given by the response.
    :raises IOError: if the url cannot be fetched or the response does not
        have the expected shape.
    """
    try:
        response = requests.get(discovery_url).json()
        nodes = []
        for node in response['node']['nodes']:
            # Raw string so that '\.' reaches the regex engine untouched.
            ip = re.search(r'([0-9]{1,3}\.){3}[0-9]{1,3}', node['value']).group(0)
            nodes.append(ip)
        return nodes
    except Exception:
        # Exception (not a bare except) so that KeyboardInterrupt and
        # SystemExit are not converted into an IOError.
        raise IOError('Could not load nodes from discovery url %s' % (discovery_url,))
def get_discovery_url_from_user_data():
    """
    Read the etcd discovery url from ``linode-user-data.yaml``.

    The file is looked up in the directory containing this script and the
    url is taken from its ``coreos`` -> ``etcd`` -> ``discovery`` key.

    :return: the discovery url stored in the user-data file.
    :raises IOError: if the file cannot be read or parsed, or the key is
        missing.
    """
    name = 'linode-user-data.yaml'
    log_info('Loading discovery url from ' + name)
    try:
        current_dir = os.path.dirname(__file__)
        user_data_path = os.path.abspath(os.path.join(current_dir, name))
        # open() works on both Python 2 and 3 (file() is Python 2 only), and
        # the with-block guarantees the handle is closed.
        with open(user_data_path, 'r') as user_data_file:
            user_data = yaml.safe_load(user_data_file)
        return user_data['coreos']['etcd']['discovery']
    except Exception:
        # Exception (not a bare except) so that KeyboardInterrupt and
        # SystemExit are not converted into an IOError.
        raise IOError('Could not load discovery url from ' + name)
def validate_ip_address(ip):
    """
    Check whether ``ip`` is a dotted-quad IPv4 address.

    The whole string must be the address (no trailing characters) and every
    octet must be in the range 0-255. This matters because accepted values
    are written into iptables rules and used as SSH targets.

    :param ip: candidate address string.
    :return: True if ``ip`` is a valid IPv4 address, False otherwise.
    """
    # \Z (rather than $) so that a trailing newline is rejected too.
    match = re.match(r'([0-9]{1,3})\.([0-9]{1,3})\.([0-9]{1,3})\.([0-9]{1,3})\Z', ip)
    if match is None:
        return False
    return all(int(octet) <= 255 for octet in match.groups())
def get_firewall_contents(node_ips, private=False):
    """
    Render the iptables rule set for the cluster.

    :param node_ips: iterable of IP address strings of the cluster members;
        traffic from these sources is accepted unconditionally.
    :param private: when true, new connections to ports 22, 2222, 80 and 443
        are only accepted from 192.168.0.0/16 instead of from anywhere.
    :return: the rules as a single string in iptables-save format.
    """
    rules_template_text = """*filter
:INPUT DROP [0:0]
:FORWARD DROP [0:0]
:OUTPUT ACCEPT [0:0]
:DOCKER - [0:0]
:Firewall-INPUT - [0:0]
-A INPUT -j Firewall-INPUT
-A FORWARD -j Firewall-INPUT
-A Firewall-INPUT -i lo -j ACCEPT
-A Firewall-INPUT -p icmp --icmp-type echo-reply -j ACCEPT
-A Firewall-INPUT -p icmp --icmp-type destination-unreachable -j ACCEPT
-A Firewall-INPUT -p icmp --icmp-type time-exceeded -j ACCEPT
# Ping
-A Firewall-INPUT -p icmp --icmp-type echo-request -j ACCEPT
# Accept any established connections
-A Firewall-INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT
# Enable the traffic between the nodes of the cluster
-A Firewall-INPUT -s $node_ips -j ACCEPT
# Allow connections from docker container
-A Firewall-INPUT -i docker0 -j ACCEPT
# Accept ssh, http, https and git
-A Firewall-INPUT -m conntrack --ctstate NEW -m multiport$multiport_private -p tcp --dports 22,2222,80,443 -j ACCEPT
# Log and drop everything else
-A Firewall-INPUT -j REJECT
COMMIT
"""
    # The leading space matters: the value is spliced directly after
    # "-m multiport" in the template.
    multiport_private = ' -s 192.168.0.0/16' if private else ''
    rules_template = string.Template(rules_template_text)
    # str.join works on Python 2 and 3; string.join exists only on Python 2.
    return rules_template.substitute(
        node_ips=','.join(node_ips),
        multiport_private=multiport_private,
    )
def apply_rules_to_all(host_ips, rules, private_key):
    """
    Apply the rules to every host in parallel, one thread per host.

    :param host_ips: iterable of host IP address strings.
    :param rules: the rule set text to install on each host.
    :param private_key: open private key file, passed to
        detect_and_create_private_key() to build the SSH key.
    :raises RuntimeError: after all hosts have been attempted, if applying
        the rules failed on at least one of them.
    """
    pkey = detect_and_create_private_key(private_key)

    def worker(ip, errors):
        # An exception raised inside a thread never reaches the caller, so
        # record it here for the calling thread to report after join().
        try:
            apply_rules(ip, rules, pkey)
        except Exception as err:
            errors.append(err)

    jobs = []
    for ip in host_ips:
        errors = []  # one list per thread, so no state is shared between workers
        thread = Thread(target=worker, args=(ip, errors))
        thread.daemon = False  # attribute form; setDaemon() is deprecated
        thread.start()
        jobs.append((ip, thread, errors))

    failed_hosts = []
    for ip, thread, errors in jobs:
        thread.join()
        if errors:
            log_error('Failed to apply rules to ' + ip + ': ' + str(errors[0]))
            failed_hosts.append(ip)

    if failed_hosts:
        raise RuntimeError('Could not apply rules to: ' + ', '.join(failed_hosts))
def detect_and_create_private_key(private_key):
    """
    Build a paramiko key object from an open private key file.

    The file is read once to look for its header line and then rewound so
    that paramiko can parse it from the start.

    :param private_key: readable, seekable file object holding the key; its
        ``name`` attribute is used in the error message.
    :return: a paramiko RSAKey or DSSKey, depending on the header found.
    :raises ValueError: if neither an RSA nor a DSA header is present.
    """
    contents = private_key.read()
    private_key.seek(0)

    # Checked in order; the paramiko class is looked up only for the header
    # that actually matches.
    known_headers = (
        ('-----BEGIN RSA PRIVATE KEY-----', 'RSAKey'),
        ('-----BEGIN DSA PRIVATE KEY-----', 'DSSKey'),
    )
    for header, key_class_name in known_headers:
        if header in contents:
            key_class = getattr(paramiko, key_class_name)
            return key_class.from_private_key(private_key)

    raise ValueError('Invalid private key file ' + private_key.name)
def apply_rules(host_ip, rules, private_key):
    """
    Install the rules on one host over SSH and (re)start iptables-restore.

    The rules are uploaded to a uniquely named file under /tmp, moved to
    /var/lib/iptables/rules-save, and the iptables-restore service is
    enabled and started.

    :param host_ip: address of the host to connect to (as user ``core``).
    :param rules: the rule set text to install.
    :param private_key: paramiko key object used to authenticate.
    :raises IOError: if one of the remote commands exits with a non-zero
        status. Errors raised by paramiko while connecting or uploading
        propagate unchanged.
    """
    ssh = paramiko.SSHClient()
    # NOTE(review): AutoAddPolicy accepts unknown host keys without
    # verification, so the connection is not protected against a
    # man-in-the-middle. Left as is; confirm this is acceptable.
    ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    try:
        # connect to the server via ssh
        ssh.connect(host_ip, username='core', allow_agent=False, look_for_keys=False, pkey=private_key)

        # copy the rules to the temp directory
        temp_file = '/tmp/' + str(uuid.uuid4())
        sftp = ssh.open_sftp()
        try:
            remote_file = sftp.open(temp_file, 'w')
            try:
                remote_file.write(rules)
            finally:
                # Close before running "mv" so the upload is complete.
                remote_file.close()
        finally:
            sftp.close()

        # move the rules in to place and enable and run the iptables-restore.service
        commands = [
            'sudo mv ' + temp_file + ' /var/lib/iptables/rules-save',
            'sudo chown root:root /var/lib/iptables/rules-save',
            'sudo systemctl enable iptables-restore.service',
            'sudo systemctl start iptables-restore.service'
        ]
        for command in commands:
            stdin, stdout, stderr = ssh.exec_command(command)
            # Wait for the command to finish and stop at the first failure
            # instead of reporting success for rules that were never applied.
            exit_status = stdout.channel.recv_exit_status()
            if exit_status != 0:
                raise IOError('Command "%s" failed on %s with exit status %d'
                              % (command, host_ip, exit_status))
    finally:
        # Always release the connection, including on errors.
        ssh.close()

    log_success('Applied rule to ' + host_ip)
def _filter_valid_ips(ips, warning_prefix):
    """Return the entries of ``ips`` accepted by validate_ip_address(), warning about the rest."""
    valid_ips = []
    for ip in ips:
        if validate_ip_address(ip):
            valid_ips.append(ip)
        else:
            log_warning(warning_prefix + ip)
    return valid_ips


def main():
    """
    Parse the command line, generate the firewall rules and apply them.

    The cluster members (the addresses allowed by the rules) come from the
    discovery url. The hosts the rules are applied to are the ``--hosts``
    values when given, otherwise the cluster members themselves.

    :raises ValueError: if no valid member address or no valid host address
        remains after validation.
    """
    colorama.init()

    parser = argparse.ArgumentParser(description='Apply a "Security Group" to a Deis cluster')
    # The key is taken as a path and opened below with open(); the file()
    # builtin previously used as the argparse type exists only on Python 2.
    parser.add_argument('--private-key', required=True, dest='private_key', help='Cluster SSH Private Key')
    parser.add_argument('--private', action='store_true', dest='private', help='Only allow access to the cluster from the private network')
    parser.add_argument('--discovery-url', dest='discovery_url', help='Etcd discovery url')
    parser.add_argument('--hosts', nargs='+', dest='hosts', help='The IP addresses of the hosts to apply rules to')
    args = parser.parse_args()

    # Open the key straight after parsing so an unreadable key file fails
    # before any network access; the with-block closes it when done.
    with open(args.private_key, 'r') as private_key_file:
        nodes = get_nodes_from_args(args)
        hosts = args.hosts if args.hosts is not None else nodes

        node_ips = _filter_valid_ips(nodes, 'Invalid IP will not be added to security group: ')
        if not node_ips:
            raise ValueError('No valid IP addresses in security group.')

        host_ips = _filter_valid_ips(hosts, 'Host has invalid IP address: ')
        if not host_ips:
            raise ValueError('No valid host addresses.')

        log_info('Generating iptables rules...')
        rules = get_firewall_contents(node_ips, args.private)
        log_success('Generated rules:')
        log_debug(rules)

        log_info('Applying rules...')
        apply_rules_to_all(host_ips, rules, private_key_file)

    log_success('Done!')
def log_debug(message):
    """Print ``message`` dimmed and in magenta, then reset colour and style."""
    prefix = Style.DIM + Fore.MAGENTA
    suffix = Fore.RESET + Style.RESET_ALL
    print(prefix + message + suffix)
def log_info(message):
    """Print ``message`` in cyan, then reset the foreground colour."""
    colored = Fore.CYAN + message + Fore.RESET
    print(colored)
def log_warning(message):
    """Print ``message`` in yellow, then reset the foreground colour."""
    colored = Fore.YELLOW + message + Fore.RESET
    print(colored)
def log_success(message):
    """Print ``message`` in bright green, then reset colour and style."""
    prefix = Style.BRIGHT + Fore.GREEN
    suffix = Fore.RESET + Style.RESET_ALL
    print(prefix + message + suffix)
def log_error(message):
    """Print ``message`` in bright red, then reset colour and style."""
    prefix = Style.BRIGHT + Fore.RED
    suffix = Fore.RESET + Style.RESET_ALL
    print(prefix + message + suffix)
# Script entry point: run main() and turn any failure into a red error
# message and a non-zero exit status.
if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # str(e) rather than e.message: the attribute does not exist on
        # Python 3 and is empty for errno-style errors such as IOError.
        log_error(str(e))
        sys.exit(1)