#-*- coding: utf-8 | |
import sys | |
from ConfigParser import RawConfigParser | |
from datetime import datetime, timedelta | |
import getopt | |
import dbengine | |
import json | |
import locale | |
import time | |
from path import path | |
import re | |
VERSION = "1.1"
# Encoding reported by the user's locale; fall back to UTF-8 when the locale
# does not name one (getdefaultlocale() may return None for it).
SYSENCODING = locale.getdefaultlocale()[1] or "utf-8"
def json_minify(json, strip_space=True): | |
tokenizer = re.compile('"|(/\*)|(\*/)|(//)|\n|\r') | |
in_string = False | |
in_multiline_comment = False | |
in_singleline_comment = False | |
new_str = [] | |
from_index = 0 # from is a keyword in Python | |
for match in re.finditer(tokenizer, json): | |
if not in_multiline_comment and not in_singleline_comment: | |
tmp2 = json[from_index:match.start()] | |
if not in_string and strip_space: | |
tmp2 = re.sub('[ \t\n\r]*', '', tmp2) # replace only white space defined in standard | |
new_str.append(tmp2) | |
from_index = match.end() | |
if match.group() == '"' and not in_multiline_comment and not in_singleline_comment: | |
escaped = re.search('(\\\\)*$', json[:match.start()]) | |
if not in_string or escaped is None or len(escaped.group()) % 2 == 0: | |
# start of string with ", or unescaped " character found to end string | |
in_string = not in_string | |
from_index -= 1 # include " character in next catch | |
elif match.group() == '/*' and not in_string and not in_multiline_comment and not in_singleline_comment: | |
in_multiline_comment = True | |
elif match.group() == '*/' and not in_string and in_multiline_comment and not in_singleline_comment: | |
in_multiline_comment = False | |
elif match.group() == '//' and not in_string and not in_multiline_comment and not in_singleline_comment: | |
in_singleline_comment = True | |
elif (match.group() == '\n' or match.group() == '\r') and not in_string and not in_multiline_comment and in_singleline_comment: | |
in_singleline_comment = False | |
elif not in_multiline_comment and not in_singleline_comment and ( | |
match.group() not in ['\n', '\r', ' ', '\t'] or not strip_space): | |
new_str.append(match.group()) | |
new_str.append(json[from_index:]) | |
return ''.join(new_str) | |
class V2DBConvert(object): | |
def __init__(self): | |
super(V2DBConvert, self).__init__() | |
self._src_db = None | |
self._dst_db = None | |
self._base_path = path(sys.argv[0]).realpath().dirname() | |
self._base_name = path(sys.argv[0]).basename() | |
self._cvt_process_file = self._base_path / '.cvtstep.txt' | |
self._cvt_config_file = self._base_path / "cvtcfg.ini" | |
self._cvt_rule_file = None | |
self._cvt_rules = None | |
self._cvt_sql_params = {} | |
self._cvt_exec_rule_name = [] | |
self._list_rules = False | |
self._manual_convert = False | |
def run(self): | |
if not self._parse_command_line(): | |
self.usage() | |
sys.exit(1) | |
current = datetime.now().strftime("%Y-%m-%d %H:%M:%S") | |
print "启动转换程序 Version: {0}, {1}".format(VERSION, current) | |
if not self._load_convert_file(): | |
self.die(u"不能加载转换规则文件") | |
print u"加载转换规则文件成功<{0}>".format(self._cvt_rule_file.basename()) | |
if self._list_rules: | |
for rule in self._cvt_rules: | |
print u"转换规则[{0}]: {1}<{2}>".format(rule['name'], rule['desc'], | |
rule.get('remark', '')) | |
return | |
self._load_config() | |
try: | |
print u"等待连接原始数据..." | |
param = dict(database=self._src_database, host=self._src_host, | |
port=self._src_port, user=self._src_user, password=self._src_pswd, | |
name=self._src_name) | |
self._src_db = self._connect_db(self._src_type, **param) | |
print u"连接源库成功" | |
print u"等待连接目标数据..." | |
param = dict(database=self._dst_database, host=self._dst_host, | |
port=self._dst_port, user=self._dst_user, password=self._dst_pswd, | |
name=self._dst_name) | |
self._dst_db = self._connect_db(self._dst_type, **param) | |
print u"连接目标库成功" | |
except Exception, ex: | |
print u"连接数据失败, {0}".format(ex) | |
return | |
self._convert() | |
self._src_db.close() | |
self._dst_db.close() | |
def die(self, msg, exit_code=1): | |
print msg | |
sys.exit(exit_code) | |
def _convert(self): | |
for rule in self._cvt_rules: | |
if self._cvt_exec_rule_name: | |
if rule['name'] not in self._cvt_exec_rule_name: | |
continue | |
if self._manual_convert: | |
msg = u"是否转换规则<{0}>[{1}](Y/n)".format(rule['name'], rule['desc']) | |
ans = raw_input(msg.encode("utf-8")) | |
if not (ans is None or ans in ('Y', 'y')): | |
return | |
else: | |
print u"等待转换<{0}>[{1}]".format(rule['name'], rule['desc']) | |
if not self._do_convert(rule): | |
return | |
def _do_execute_sql(self, cursor, statment): | |
try: | |
cursor.prepare(statment) | |
sql_params = {} | |
for param in cursor.bindnames(): | |
if param not in self._cvt_sql_params: | |
print u"参数<{0}>未指定".format(param) | |
return False | |
sql_params[param] = self._cvt_sql_params[param] | |
# print sql_params | |
if not cursor.execute(statment, **sql_params): | |
return True | |
return True | |
except Exception as e: | |
print u"执行sql 失败, {0}".format(e) | |
return False | |
def _do_convert(self, rule): | |
query_cursor = self._src_db.new_cursor() | |
if rule['action'] == 'truncate': | |
print u"Truncate 目标表数据..." | |
self._dst_db.exec_sql('truncate table {0}'.format(rule['name'])) | |
if 'pre_exec' in rule: | |
for statment in rule['pre_exec']: | |
if not self._do_execute_sql(self._src_db.cursor, statment): | |
print u"执行 pre_exec 语句失败" | |
self._src_db.rollback() | |
return False | |
self._src_db.commit() | |
if not self._do_execute_sql(query_cursor, rule['src_sql']): | |
print u"查询原始数据失败" | |
return False | |
# print query_cursor.description | |
count = 0 | |
self._dst_db.begin_transaction() | |
for row in query_cursor.fetchall(): | |
v = [] | |
for data in row: | |
if isinstance(data, int): | |
v.append('%d' % data) | |
elif isinstance(data, float): | |
v.append('%f' % data) | |
elif isinstance(data, str) or isinstance(data, unicode): | |
v.append("'{0}'".format(data)) | |
elif data is None: | |
v.append('NULL') | |
else: | |
v.append(data) | |
values = ",".join(v) | |
insert_sql = 'insert into {0} ({1}) values({2})'.format(rule['name'], | |
rule['dest_column'], values) | |
# print insert_sql | |
if not self._dst_db.exec_sql(insert_sql): | |
self._dst_db.rollback() | |
return False | |
count += 1 | |
if count % self._commit_count == 0: | |
print u"导入数据 {0} 条".format(count) | |
self._dst_db.commit() | |
self._dst_db.begin_transaction() | |
self._dst_db.commit() | |
print u"导入数据 {0} 条".format(count) | |
if 'post_exec' in rule: | |
print u"执行 post_exec..." | |
for statment in rule['post_exec']: | |
if not self._do_execute_sql(self._dst_db.cursor, statment): | |
print u"执行 post_exec 语句失败" | |
self._dst_db.rollback() | |
return False | |
self._dst_db.commit() | |
print u"执行 post_exec 完成" | |
query_cursor.close() | |
return True | |
def _load_convert_file(self): | |
rule_doc = json_minify(self._cvt_rule_file.text(encoding="utf-8")) | |
self._cvt_rules = json.loads(rule_doc, "utf-8") | |
return True | |
def _parse_command_line(self): | |
execute_rules = "" | |
optlist, args = getopt.getopt(sys.argv[1:], ":c:r:t:hlm") | |
for k, v in optlist: | |
if k == "-c": | |
self._cvt_config_file = path(v) | |
elif k == "-r": | |
self._cvt_rule_file = path(v) | |
elif k == "-l": | |
self._list_rules = True | |
elif k == "-t": | |
execute_rules = v | |
elif k == "-m": | |
self._manual_convert = True | |
elif k == "-h": | |
self.usage() | |
sys.exit(0) | |
if not self._cvt_rule_file or not self._cvt_rule_file.exists(): | |
print u"请指定转换规则文件" | |
return False | |
if not self._cvt_config_file.exists(): | |
print u"请指定配置文件" | |
return False | |
if self._list_rules: | |
return True | |
if execute_rules: | |
target_rule = execute_rules.split(',') | |
self._cvt_exec_rule_name = [rule.upper() for rule in target_rule] | |
# print self._cvt_exec_rule_name | |
for v in args: | |
param = v.split('=') | |
if len(param) != 2: | |
print u"参数[{0}]错误!".format(v) | |
return False | |
kn = param[0].strip(' ').upper() | |
kv = param[1].strip(' ') | |
self._cvt_sql_params[kn] = kv | |
return True | |
def usage(self): | |
print u"转换程序 {0} Version: {1}".format(self._base_name, VERSION) | |
print u"\t-c 转换配置文件默认为 cvtcfg.ini" | |
print u"\t-r 转换配置规则文件" | |
print u"\t-t 转换使用表名,多个表名用逗号分隔;例如 t_card, t_customer" | |
print u"\t-l 只列出转换规则文件中所有的 target 不做转换" | |
print u"\t-h 输出帮助" | |
def _load_config(self): | |
parser = RawConfigParser() | |
parser.read(self._cvt_config_file) | |
cfg = dict(parser.items('database')) | |
if 'srcdsn' in cfg: | |
self._src_database = cfg['srcdsn'] | |
self._src_host = None | |
self._src_port = 0 | |
self._src_name = None | |
else: | |
self._src_host = cfg.get('srchost', '') | |
self._src_port = int(cfg.get('srcport', 0)) | |
self._src_name = cfg.get('srcname', '') | |
self._src_database = None | |
self._src_user = cfg.get('srcuser', '') | |
self._src_pswd = cfg.get('srcpswd', '') | |
self._src_type = cfg.get('srctype', 'oracle') | |
if 'dstdsn' in cfg: | |
self._dst_database = cfg['dstdsn'] | |
self._dst_host = None | |
self._dst_port = 0 | |
self._dst_name = None | |
else: | |
self._dst_host = cfg.get('dsthost', '') | |
self._dst_port = int(cfg.get('dstport', 0)) | |
self._dst_name = cfg.get('dstname', '') | |
self._dst_database = None | |
self._dst_user = cfg.get('dstuser', '') | |
self._dst_pswd = cfg.get('dstpswd', '') | |
self._dst_type = cfg.get('dsttype', 'oracle') | |
self._commit_count = int(cfg.get('commitcount', 1000)) | |
def _connect_db(self, dbtype, **kwargs): | |
if dbtype == "oracle": | |
db = dbengine.OraEngine() | |
elif dbtype == "db2": | |
db = dbengine.DB2Engine() | |
else: | |
raise ValueError(u"不支持数据类型[{0}]".format(dbtype)) | |
db.connect(database=kwargs['database'], user=kwargs['user'], | |
password=kwargs['password'], host=kwargs['host'], | |
port=kwargs['port'], name=kwargs['name']) | |
return db | |
if __name__ == "__main__": | |
begin = time.time() | |
cvt = V2DBConvert() | |
cvt.run() | |
print u"执行时间 %.3f 秒" % ((time.time() - begin)) |