# blob: 4c09bd2e088ef5a2c1f0b164d371906c1d1d35a2 [file] [log] [blame]
#-*- coding: utf-8
import getopt
import json
import locale
import numbers
import re
import sys
import time
from ConfigParser import RawConfigParser
from datetime import datetime, timedelta

from path import path

import dbengine
# Version string reported in the start-up banner and in the usage text.
VERSION = "1.1"

# Encoding reported by the default locale; falls back to UTF-8 when the
# locale does not report one (None or empty string).
SYSENCODING = locale.getdefaultlocale()[1] or "utf-8"
def json_minify(json, strip_space=True):
    """Strip ``//`` and ``/* */`` comments (and optionally whitespace) from JSON text.

    The text is scanned token by token (quotes, comment delimiters and line
    breaks) while tracking whether the scanner is inside a string literal, a
    multi-line comment or a single-line comment.  Comment delimiters that
    appear inside a string literal are kept verbatim.

    Args:
        json: The JSON-with-comments text to minify.  (The parameter name
            shadows the ``json`` module inside this function only.)
        strip_space: When true (default), whitespace outside string literals
            is removed as well; when false only comments are removed.

    Returns:
        The minified text.
    """
    tokenizer = re.compile(r'"|(/\*)|(\*/)|(//)|\n|\r')
    # Only the whitespace characters defined by the JSON standard.
    whitespace = re.compile(r'[ \t\n\r]+')
    in_string = False
    in_multiline_comment = False
    in_singleline_comment = False
    new_str = []
    from_index = 0  # "from" is a keyword in Python

    for match in tokenizer.finditer(json):
        token = match.group()
        # Snapshot of the comment state *before* this token is processed;
        # every branch below must test the state the token was found in.
        in_comment = in_multiline_comment or in_singleline_comment

        if not in_comment:
            chunk = json[from_index:match.start()]
            if not in_string and strip_space:
                chunk = whitespace.sub('', chunk)
            new_str.append(chunk)
        from_index = match.end()

        if token == '"' and not in_comment:
            # Count the backslashes directly in front of the quote: an odd
            # count means the quote itself is escaped.  Walking backwards is
            # proportional to the run length instead of re-scanning the whole
            # prefix for every quote.
            cursor = match.start() - 1
            while cursor >= 0 and json[cursor] == '\\':
                cursor -= 1
            backslashes = match.start() - 1 - cursor
            if not in_string or backslashes % 2 == 0:
                # Opening quote, or an unescaped quote that ends the string.
                in_string = not in_string
            from_index -= 1  # include the " character in the next chunk
        elif token == '/*' and not in_string and not in_comment:
            in_multiline_comment = True
        elif (token == '*/' and not in_string and in_multiline_comment
                and not in_singleline_comment):
            in_multiline_comment = False
        elif token == '//' and not in_string and not in_comment:
            in_singleline_comment = True
        elif (token in ('\n', '\r') and not in_string
                and not in_multiline_comment and in_singleline_comment):
            in_singleline_comment = False
        elif not in_comment and (
                token not in ('\n', '\r', ' ', '\t') or not strip_space):
            new_str.append(token)

    # Text after the last token: drop it when the input ends inside a comment
    # (e.g. a trailing "// note" with no final newline), otherwise treat it
    # exactly like every other chunk.
    if not (in_multiline_comment or in_singleline_comment):
        tail = json[from_index:]
        if not in_string and strip_space:
            tail = whitespace.sub('', tail)
        new_str.append(tail)
    return ''.join(new_str)
class V2DBConvert(object):
    """Command-line tool that copies rows between two databases by rule.

    A run parses the command line, loads a JSON rule file (comments are
    allowed and removed with ``json_minify``), reads the connection settings
    from an INI file, connects to the source and destination databases and
    then executes the selected rules in file order.

    Rule keys read by this class: ``name`` (destination table), ``desc``,
    ``remark`` (optional), ``action`` (``'truncate'`` empties the destination
    table first), ``pre_exec`` (optional statements run on the source),
    ``src_sql`` (source query), ``dest_column`` (column list of the INSERT)
    and ``post_exec`` (optional statements run on the destination).
    """

    def __init__(self):
        super(V2DBConvert, self).__init__()
        self._src_db = None
        self._dst_db = None
        self._base_path = path(sys.argv[0]).realpath().dirname()
        self._base_name = path(sys.argv[0]).basename()
        # Not referenced anywhere else in this class.
        self._cvt_process_file = self._base_path / '.cvtstep.txt'
        # Default config file next to the script; overridden by -c.
        self._cvt_config_file = self._base_path / "cvtcfg.ini"
        self._cvt_rule_file = None
        self._cvt_rules = None
        # Bind values for the rule SQL, taken from KEY=VALUE arguments.
        self._cvt_sql_params = {}
        # Upper-cased rule names selected with -t; empty means "all rules".
        self._cvt_exec_rule_name = []
        self._list_rules = False
        self._manual_convert = False

    def run(self):
        """Run the whole conversion; exits the process on usage errors."""
        if not self._parse_command_line():
            self.usage()
            sys.exit(1)
        current = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print("启动转换程序 Version: {0}, {1}".format(VERSION, current))
        if not self._load_convert_file():
            self.die(u"不能加载转换规则文件")
        print(u"加载转换规则文件成功<{0}>".format(
            self._cvt_rule_file.basename()))
        if self._list_rules:
            for rule in self._cvt_rules:
                print(u"转换规则[{0}]: {1}<{2}>".format(
                    rule['name'], rule['desc'], rule.get('remark', '')))
            return
        self._load_config()
        try:
            try:
                print(u"等待连接原始数据...")
                param = dict(database=self._src_database,
                             host=self._src_host, port=self._src_port,
                             user=self._src_user, password=self._src_pswd,
                             name=self._src_name)
                self._src_db = self._connect_db(self._src_type, **param)
                print(u"连接源库成功")
                print(u"等待连接目标数据...")
                param = dict(database=self._dst_database,
                             host=self._dst_host, port=self._dst_port,
                             user=self._dst_user, password=self._dst_pswd,
                             name=self._dst_name)
                self._dst_db = self._connect_db(self._dst_type, **param)
                print(u"连接目标库成功")
            except Exception as ex:
                print(u"连接数据失败, {0}".format(ex))
                return
            self._convert()
        finally:
            # Close whatever was opened, including the source connection
            # when only the destination failed to connect or when the
            # conversion itself raised.
            for db in (self._src_db, self._dst_db):
                if db is not None:
                    db.close()

    def die(self, msg, exit_code=1):
        """Print ``msg`` and terminate the process with ``exit_code``."""
        print(msg)
        sys.exit(exit_code)

    def _convert(self):
        """Execute every selected rule in order; stop at the first failure."""
        for rule in self._cvt_rules:
            if self._cvt_exec_rule_name:
                # The -t names are stored upper-cased, so compare
                # case-insensitively.
                if rule['name'].upper() not in self._cvt_exec_rule_name:
                    continue
            if self._manual_convert:
                msg = u"是否转换规则<{0}>[{1}](Y/n)".format(
                    rule['name'], rule['desc'])
                ans = raw_input(msg.encode("utf-8"))
                # "(Y/n)": just pressing Enter means yes.  Any other answer
                # stops the whole run, as before.
                if ans.strip() not in ('', 'Y', 'y'):
                    return
            else:
                print(u"等待转换<{0}>[{1}]".format(rule['name'], rule['desc']))
            if not self._do_convert(rule):
                return

    def _do_execute_sql(self, cursor, statment):
        """Execute ``statment`` on ``cursor`` with the command-line bind values.

        The cursor must provide ``prepare``, ``bindnames`` and ``execute``
        (this looks like the cx_Oracle cursor API -- confirm against
        ``dbengine``).

        Returns:
            True when the statement was executed; False when a bind name has
            no value in the command-line parameters or execution raised.
        """
        try:
            cursor.prepare(statment)
            sql_params = {}
            for param in cursor.bindnames():
                if param not in self._cvt_sql_params:
                    print(u"参数<{0}>未指定".format(param))
                    return False
                sql_params[param] = self._cvt_sql_params[param]
            # The return value of execute() is not a success flag here;
            # failures are reported through exceptions.
            cursor.execute(statment, **sql_params)
            return True
        except Exception as e:
            print(u"执行sql 失败, {0}".format(e))
            return False

    @staticmethod
    def _sql_literal(data):
        """Render one fetched column value as an SQL literal.

        ``None`` becomes ``NULL``; numbers are written unquoted (floats with
        ``repr`` so no digits are lost); every other value is converted to
        text and single-quoted with embedded quotes doubled.
        """
        if data is None:
            return 'NULL'
        if isinstance(data, float):
            return repr(data)
        if isinstance(data, numbers.Integral):
            return '%d' % data
        if isinstance(data, numbers.Number):
            # e.g. decimal.Decimal
            return str(data)
        # NOTE(review): non-string values such as datetimes are quoted in
        # their default text form; this relies on the destination database
        # converting that text implicitly -- confirm for the column types
        # actually migrated.
        text = "{0}".format(data)
        return "'{0}'".format(text.replace("'", "''"))

    def _do_convert(self, rule):
        """Copy the rows selected by one rule into its destination table.

        Returns:
            True when all steps of the rule succeeded, False otherwise.
        """
        query_cursor = self._src_db.new_cursor()
        try:
            return self._run_rule(rule, query_cursor)
        finally:
            # Release the cursor on the failure paths as well.
            query_cursor.close()

    def _run_rule(self, rule, query_cursor):
        """Run truncate, pre_exec, row copy and post_exec for ``rule``."""
        if rule.get('action') == 'truncate':
            print(u"Truncate 目标表数据...")
            self._dst_db.exec_sql('truncate table {0}'.format(rule['name']))
        if 'pre_exec' in rule:
            for statment in rule['pre_exec']:
                if not self._do_execute_sql(self._src_db.cursor, statment):
                    print(u"执行 pre_exec 语句失败")
                    self._src_db.rollback()
                    return False
            self._src_db.commit()
        if not self._do_execute_sql(query_cursor, rule['src_sql']):
            print(u"查询原始数据失败")
            return False
        count = 0
        self._dst_db.begin_transaction()
        for row in query_cursor.fetchall():
            values = ",".join([self._sql_literal(data) for data in row])
            insert_sql = 'insert into {0} ({1}) values({2})'.format(
                rule['name'], rule['dest_column'], values)
            if not self._dst_db.exec_sql(insert_sql):
                self._dst_db.rollback()
                return False
            count += 1
            # Commit in batches; a non-positive batch size means a single
            # commit at the end instead of a division by zero.
            if self._commit_count > 0 and count % self._commit_count == 0:
                print(u"导入数据 {0} 条".format(count))
                self._dst_db.commit()
                self._dst_db.begin_transaction()
        self._dst_db.commit()
        print(u"导入数据 {0} 条".format(count))
        if 'post_exec' in rule:
            print(u"执行 post_exec...")
            for statment in rule['post_exec']:
                if not self._do_execute_sql(self._dst_db.cursor, statment):
                    print(u"执行 post_exec 语句失败")
                    self._dst_db.rollback()
                    return False
            self._dst_db.commit()
            print(u"执行 post_exec 完成")
        return True

    def _load_convert_file(self):
        """Read and parse the rule file into ``self._cvt_rules``.

        Returns:
            True on success, False when the file cannot be read, decoded or
            parsed as JSON.
        """
        try:
            rule_doc = json_minify(self._cvt_rule_file.text(encoding="utf-8"))
            self._cvt_rules = json.loads(rule_doc)
        except (IOError, OSError, ValueError) as ex:
            print(u"解析转换规则文件失败, {0}".format(ex))
            return False
        return True

    def _parse_command_line(self):
        """Parse ``sys.argv`` into the instance settings.

        Positional arguments of the form ``KEY=VALUE`` become bind values for
        the rule SQL (keys are upper-cased).

        Returns:
            True when the arguments are usable, False otherwise (the caller
            then prints the usage text).
        """
        execute_rules = ""
        try:
            optlist, args = getopt.getopt(sys.argv[1:], ":c:r:t:hlm")
        except getopt.GetoptError as ex:
            print(u"参数[{0}]错误!".format(ex.msg))
            return False
        for k, v in optlist:
            if k == "-c":
                self._cvt_config_file = path(v)
            elif k == "-r":
                self._cvt_rule_file = path(v)
            elif k == "-l":
                self._list_rules = True
            elif k == "-t":
                execute_rules = v
            elif k == "-m":
                self._manual_convert = True
            elif k == "-h":
                self.usage()
                sys.exit(0)
        if not self._cvt_rule_file or not self._cvt_rule_file.exists():
            print(u"请指定转换规则文件")
            return False
        if not self._cvt_config_file.exists():
            print(u"请指定配置文件")
            return False
        if self._list_rules:
            return True
        if execute_rules:
            # Tolerate blanks around the commas ("t_card, t_customer").
            self._cvt_exec_rule_name = [
                name.strip().upper()
                for name in execute_rules.split(',') if name.strip()]
        for v in args:
            # Split on the first "=" only so values may contain "=".
            key, sep, value = v.partition('=')
            if not sep:
                print(u"参数[{0}]错误!".format(v))
                return False
            self._cvt_sql_params[key.strip(' ').upper()] = value.strip(' ')
        return True

    def usage(self):
        """Print the command-line help."""
        print(u"转换程序 {0} Version: {1}".format(self._base_name, VERSION))
        print(u"\t-c 转换配置文件默认为 cvtcfg.ini")
        print(u"\t-r 转换配置规则文件")
        print(u"\t-t 转换使用表名,多个表名用逗号分隔;例如 t_card, t_customer")
        print(u"\t-l 只列出转换规则文件中所有的 target 不做转换")
        print(u"\t-m 转换每个规则前需手工确认")
        print(u"\t-h 输出帮助")

    def _load_config(self):
        """Load connection settings from the ``[database]`` INI section.

        For each side (``src``/``dst``) a ``<side>dsn`` entry takes
        precedence over the ``<side>host``/``<side>port``/``<side>name``
        entries.
        """
        parser = RawConfigParser()
        parser.read(self._cvt_config_file)
        cfg = dict(parser.items('database'))

        if 'srcdsn' in cfg:
            self._src_database = cfg['srcdsn']
            self._src_host = None
            self._src_port = 0
            self._src_name = None
        else:
            self._src_host = cfg.get('srchost', '')
            self._src_port = int(cfg.get('srcport', 0))
            self._src_name = cfg.get('srcname', '')
            self._src_database = None
        self._src_user = cfg.get('srcuser', '')
        self._src_pswd = cfg.get('srcpswd', '')
        self._src_type = cfg.get('srctype', 'oracle')

        if 'dstdsn' in cfg:
            self._dst_database = cfg['dstdsn']
            self._dst_host = None
            self._dst_port = 0
            self._dst_name = None
        else:
            self._dst_host = cfg.get('dsthost', '')
            self._dst_port = int(cfg.get('dstport', 0))
            self._dst_name = cfg.get('dstname', '')
            self._dst_database = None
        self._dst_user = cfg.get('dstuser', '')
        self._dst_pswd = cfg.get('dstpswd', '')
        self._dst_type = cfg.get('dsttype', 'oracle')

        # Number of inserted rows between two commits.
        self._commit_count = int(cfg.get('commitcount', 1000))

    def _connect_db(self, dbtype, **kwargs):
        """Create and connect a ``dbengine`` engine for ``dbtype``.

        Args:
            dbtype: ``"oracle"`` or ``"db2"``.
            **kwargs: ``database``, ``user``, ``password``, ``host``,
                ``port`` and ``name``, passed on to the engine's ``connect``.

        Returns:
            The connected engine object.

        Raises:
            ValueError: If ``dbtype`` is not supported.
        """
        if dbtype == "oracle":
            db = dbengine.OraEngine()
        elif dbtype == "db2":
            db = dbengine.DB2Engine()
        else:
            raise ValueError(u"不支持数据类型[{0}]".format(dbtype))
        db.connect(database=kwargs['database'], user=kwargs['user'],
                   password=kwargs['password'], host=kwargs['host'],
                   port=kwargs['port'], name=kwargs['name'])
        return db
if __name__ == "__main__":
    # Script entry point: run one conversion and report the wall-clock time
    # it took, in seconds.
    started_at = time.time()
    V2DBConvert().run()
    elapsed = time.time() - started_at
    print(u"执行时间 %.3f 秒" % elapsed)