Skip to content

Commit d85c05e

Browse files
committed
feat: add lru_cache
1 parent 9acd37b commit d85c05e

6 files changed

Lines changed: 144 additions & 264 deletions

File tree

smartdns/checkconfig.py

Lines changed: 0 additions & 123 deletions
This file was deleted.
Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# -*- coding: utf-8 -*-
2+
import bisect
3+
import logging
24
import re
3-
import yaml
45
import sys
56
import time
6-
import bisect
7-
import logging
7+
import yaml
8+
from functools import lru_cache
89
from os.path import isfile
910

1011
logger = logging.getLogger(__name__)
@@ -25,7 +26,7 @@ def long2ip(num):
2526
return '.'.join(iplist)
2627

2728

28-
class BaseIPPool(object):
29+
class BaseFinder(object):
2930

3031
def __init__(self, ipfile, recordfile):
3132
if not isfile(ipfile) or not isfile(recordfile):
@@ -45,14 +46,14 @@ def __init__(self, ipfile, recordfile):
4546
self.locmapip = {}
4647

4748
# load record data
48-
self.LoadRecord()
49+
self.loadRecord()
4950

5051
# load ip data
51-
self.LoadIP()
52+
self.loadIP()
5253

5354
print('Init IP pool finished !')
5455

55-
def LoadIP(self):
56+
def loadIP(self):
5657
f = open(self.ipfile, 'r')
5758
logger.warning("before load: %s" % (time.time()))
5859
for eachline in f:
@@ -72,14 +73,14 @@ def LoadIP(self):
7273
self.iphash[ipstart] = [ipstart, ipend,
7374
country, province, city, sp, {}]
7475
# 最好合并后再计算
75-
self.JoinIP(ipstart)
76+
self.joinIP(ipstart)
7677
f.close()
7778
logger.warning("after load: %s" % (time.time()))
7879
self.iplist.sort()
7980
logger.warning("after sort: %s" % (time.time()))
8081

8182
# 重写LoadRecord和JoinIP,提升启动效率
82-
def LoadRecord(self):
83+
def loadRecord(self):
8384
Add = [8, 4, 2, 1]
8485
f = open(self.recordfile, 'r')
8586
self.record = yaml.load(f, Loader=yaml.FullLoader)
@@ -126,31 +127,31 @@ def LoadRecord(self):
126127
f.close()
127128
# logger.warning(self.locmapip)
128129

129-
def JoinIP(self, ip):
130+
def joinIP(self, ip):
130131
for fqdnk, fqdnv in self.locmapip.items():
131132
l1 = []
132133
l2 = []
133134
l3 = []
134135
weight = 0
135-
#logger.warning("l1 : %s, %s" %(self.iphash[ip][2], fqdnv.keys()))
136+
# logger.warning("l1 : %s, %s" %(self.iphash[ip][2], fqdnv.keys()))
136137
if "" in fqdnv and "" != self.iphash[ip][2]:
137138
l1.append(fqdnv[""])
138139
if self.iphash[ip][2] in fqdnv:
139140
l1.append(fqdnv[self.iphash[ip][2]])
140141
for k in l1:
141-
#logger.warning("l2 : %s, %s" %(self.iphash[ip][3], k.keys()))
142+
# logger.warning("l2 : %s, %s" %(self.iphash[ip][3], k.keys()))
142143
if "" in k and "" != self.iphash[ip][3]:
143144
l2.append(k[""])
144145
if self.iphash[ip][3] in k:
145146
l2.append(k[self.iphash[ip][3]])
146147
for k in l2:
147-
#logger.warning("l3 : %s, %s" %(self.iphash[ip][4], k.keys()))
148+
# logger.warning("l3 : %s, %s" %(self.iphash[ip][4], k.keys()))
148149
if "" in k and "" != self.iphash[ip][4]:
149150
l3.append(k[""])
150151
if self.iphash[ip][4] in k:
151152
l3.append(k[self.iphash[ip][4]])
152153
for k in l3:
153-
#logger.warning("l4 : %s, %s" %(self.iphash[ip][5], k.keys()))
154+
# logger.warning("l4 : %s, %s" %(self.iphash[ip][5], k.keys()))
154155
if "" in k and k[""][1] > weight:
155156
self.iphash[ip][6][fqdnk] = k[""]
156157
weight = k[""][1]
@@ -160,14 +161,15 @@ def JoinIP(self, ip):
160161
if fqdnk not in self.iphash[ip][6]:
161162
self.iphash[ip][6][fqdnk] = [self.record[fqdnk]['default'], 0]
162163

163-
def ListIP(self):
164+
def listIP(self):
164165
for key in self.iphash:
165166
print("ipstart: %s ipend: %s country: %s province: %s city: %s sp: %s" % (
166-
key, self.iphash[key][1], self.iphash[key][2], self.iphash[key][3], self.iphash[key][4], self.iphash[key][5]))
167+
key, self.iphash[key][1], self.iphash[key][2], self.iphash[key][3], self.iphash[key][4],
168+
self.iphash[key][5]))
167169
for i in self.iphash[key][6]:
168170
print("[domain:%s ip: %s]" % (i, self.iphash[key][6][i][0]))
169171

170-
def SearchLocation(self, ip):
172+
def searchLocation(self, ip):
171173
ipnum = ip2long(ip)
172174
ip_point = bisect.bisect_right(self.iplist, ipnum)
173175
i = self.iplist[ip_point - 1]
@@ -178,8 +180,9 @@ def SearchLocation(self, ip):
178180

179181
return i, j, ipnum
180182

181-
def FindIP(self, ip, name):
182-
i, _, ipnum = self.SearchLocation(ip)
183+
@lru_cache(maxsize=2048 * 2048, typed=True)
184+
def findIP(self, ip, name):
185+
i, _, ipnum = self.searchLocation(ip)
183186
ip_list = None
184187
if i in self.iphash:
185188
ipstart = self.iphash[i][0]
@@ -202,16 +205,14 @@ def FindIP(self, ip, name):
202205
return ip_list
203206

204207

205-
class IPPool(object):
208+
class Finder(object):
206209

207210
def __init__(self, ipfile, recordfile, monitor_mapping):
208211
self.monitor_mapping = monitor_mapping
209-
self.finder = BaseIPPool(ipfile, recordfile)
212+
self.finder = BaseFinder(ipfile, recordfile)
210213

211-
def FindIP(self, ip, name):
212-
start_time = time.time()
213-
tmp_ip_list = self.finder.FindIP(ip, name)
214-
logger.warning("use time: %s" % (time.time() - start_time))
214+
def findIP(self, ip, name):
215+
tmp_ip_list = self.finder.findIP(ip, name)
215216
ip_list = [
216217
tmp_ip for tmp_ip in tmp_ip_list if self.monitor_mapping.check(name, tmp_ip)]
217218
if len(ip_list) == 0:

smartdns/monitor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from urllib.parse import urlparse
2-
from twisted.web.iweb import IPolicyForHTTPS
3-
from twisted.web.client import Agent, BrowserLikePolicyForHTTPS
41
from twisted.internet import task, ssl, reactor
2+
from twisted.web.client import Agent, BrowserLikePolicyForHTTPS
53
from twisted.web.http_headers import Headers
4+
from twisted.web.iweb import IPolicyForHTTPS
5+
from urllib.parse import urlparse
66
from zope.interface import implementer
77

88

@@ -29,7 +29,7 @@ def _check(self):
2929
if ip not in self.black_mapping:
3030
self.black_mapping[ip] = 0
3131
url = self.monitor['url'].replace(host, ip, 1).encode("utf8")
32-
agent=Agent(reactor, contextFactory=SmartClientContextFactory(), connectTimeout=30)
32+
agent = Agent(reactor, contextFactory=SmartClientContextFactory(), connectTimeout=30)
3333
agent.request(b'GET', url, headers=Headers({"host": [host, ]})).addCallbacks(
3434
BlackMappingRemover(ip, self.black_mapping), BlackMappingAdder(ip, self.black_mapping))
3535

smartdns/sdns.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,17 @@
33
# version
44
__version__ = '1.0.1.1'
55

6-
import sys
6+
import logging
77
import os
8+
import sys
89
import yaml
9-
import logging
10+
from multiprocessing import cpu_count, Process
1011
from os.path import isfile
11-
from zope.interface import implements
12-
from twisted.internet import defer, interfaces
13-
from twisted.python import failure
14-
from twisted.internet.protocol import DatagramProtocol
15-
from twisted.application import service, internet
16-
from twisted.names import dns, server, client, cache, common, resolve
1712
from twisted.internet import reactor
18-
from multiprocessing import cpu_count, Process
19-
from . import dnsserver, ippool, monitor
13+
from twisted.names import dns
2014

15+
from . import server, monitor
16+
from .finder import Finder
2117

2218
logger = logging.getLogger(__name__)
2319

@@ -44,34 +40,32 @@ def prepare_run(run_env):
4440
monitor_mapping = monitor.MonitorMapping(monitor_config, a_mapping)
4541
# load dns record config file
4642
logger.info('start to init IP pool ......')
47-
finder = ippool.IPPool(
43+
finder = Finder(
4844
os.path.join(run_env['conf'], 'ip.csv'),
4945
os.path.join(run_env['conf'], 'a.yaml'),
5046
monitor_mapping)
5147

5248
run_env['finder'] = finder
5349

5450
# set up a resolver that uses the mapping or a secondary nameserver
55-
dnsforward = []
56-
for i in conf['dnsforward']:
57-
dnsforward_ip, dnsforward_port = i.split(":")
58-
dnsforward.append((dnsforward_ip, int(dnsforward_port)))
51+
dns_forwards = []
52+
for i in conf['dns_forwards']:
53+
dns_forward_ip, dns_forward_port = i.split(":")
54+
dns_forwards.append((dns_forward_ip, int(dns_forward_port)))
5955

6056
# create the protocols
6157
for listen_tcp in conf['listen']['tcp']:
6258
listen_tcp_ip, listen_tcp_port = listen_tcp.split(":")
63-
f = dnsserver.SmartDNSFactory(
64-
caches=[cache.CacheResolver()], clients=[
65-
dnsserver.MapResolver(
66-
finder, a_mapping, ns_mapping, soa_mapping, servers=dnsforward)])
59+
f = server.SmartDNSFactory(clients=[
60+
server.MapResolver(
61+
finder, a_mapping, ns_mapping, soa_mapping, servers=dns_forwards)])
6762
f.noisy = False
6863
run_env['tcp'].append([int(listen_tcp_port), f, listen_tcp_ip])
6964
for listen_udp in conf['listen']['tcp']:
7065
listen_udp_ip, listen_udp_port = listen_udp.split(":")
71-
p = dns.DNSDatagramProtocol(dnsserver.SmartDNSFactory(
72-
caches=[cache.CacheResolver()], clients=[
73-
dnsserver.MapResolver(
74-
finder, a_mapping, ns_mapping, soa_mapping, servers=dnsforward)]))
66+
p = dns.DNSDatagramProtocol(server.SmartDNSFactory(clients=[
67+
server.MapResolver(
68+
finder, a_mapping, ns_mapping, soa_mapping, servers=dns_forwards)]))
7569
p.noisy = False
7670
run_env['udp'].append([int(listen_udp_port), p, listen_udp_ip])
7771
return conf

0 commit comments

Comments
 (0)