#-*- coding: utf-8
import sys
from ConfigParser import RawConfigParser
from datetime import datetime, timedelta
import getopt
import dbengine
import json
import locale
import time
import cx_Oracle
from path import path
import re


VERSION = "1.5"

# Preferred encoding of the process locale; UTF-8 when the locale reports none.
SYSENCODING = locale.getdefaultlocale()[1] or "utf-8"


def json_minify(json, strip_space=True):
    """Remove ``//`` and ``/* */`` comments (and optionally whitespace) from JSON text.

    Comment markers and whitespace inside string literals are preserved;
    escaped quotes (an odd number of preceding backslashes) do not end a
    string.

    :param json: JSON text that may contain JavaScript-style comments.
        (The name shadows the ``json`` module only inside this function.)
    :param strip_space: when true, also drop the whitespace the JSON standard
        allows outside string literals (space, tab, CR, LF).
    :return: the minified text.
    """
    tokenizer = re.compile(r'"|(/\*)|(\*/)|(//)|\n|\r')
    whitespace = re.compile(r'[ \t\n\r]+')  # only white space defined in the standard
    in_string = False
    in_multiline_comment = False
    in_singleline_comment = False

    new_str = []
    from_index = 0  # "from" is a keyword in Python

    for match in tokenizer.finditer(json):
        token = match.group()
        start = match.start()
        in_comment = in_multiline_comment or in_singleline_comment

        if not in_comment:
            chunk = json[from_index:start]
            if not in_string and strip_space:
                chunk = whitespace.sub('', chunk)
            new_str.append(chunk)

        from_index = match.end()

        if token == '"' and not in_comment:
            # Count the backslashes directly before the quote; an odd count
            # means it is escaped. Walking backwards avoids copying and
            # re-scanning the whole prefix for every quote (quadratic).
            backslashes = 0
            pos = start - 1
            while pos >= 0 and json[pos] == '\\':
                backslashes += 1
                pos -= 1
            if not in_string or backslashes % 2 == 0:
                # start of a string, or an unescaped quote that ends it
                in_string = not in_string
            from_index -= 1  # include the " character in the next chunk
        elif token == '/*' and not in_string and not in_comment:
            in_multiline_comment = True
        elif token == '*/' and not in_string and in_multiline_comment and not in_singleline_comment:
            in_multiline_comment = False
        elif token == '//' and not in_string and not in_comment:
            in_singleline_comment = True
        elif token in ('\n', '\r') and not in_string and not in_multiline_comment and in_singleline_comment:
            in_singleline_comment = False
        elif not in_comment and (token not in ('\n', '\r', ' ', '\t') or not strip_space):
            new_str.append(token)

    # The tail after the last token gets the same treatment as every other
    # chunk: dropped if it is inside a comment (e.g. a final "// ..." with no
    # trailing newline), whitespace-stripped if outside a string.
    tail = json[from_index:]
    if not (in_multiline_comment or in_singleline_comment):
        if not in_string and strip_space:
            tail = whitespace.sub('', tail)
        new_str.append(tail)
    return ''.join(new_str)


class V2DBConvert(object):
    def __init__(self):
        super(V2DBConvert, self).__init__()
        self._src_db = None
        self._dst_db = None
        self._base_path = path(sys.argv[0]).realpath().dirname()
        self._base_name = path(sys.argv[0]).basename()
        self._cvt_process_file = self._base_path / '.cvtstep.txt'
        self._cvt_config_file = self._base_path / "cvtcfg.ini"
        self._cvt_rule_file = None
        self._cvt_rules = None
        self._cvt_sql_params = {}
        self._cvt_exec_rule_name = []
        self._list_rules = False
        self._manual_convert = False

    def run(self):
        if not self._parse_command_line():
            self.usage()
            sys.exit(1)

        current = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print u"启动转换程序 Version: {0}, {1}".format(VERSION, current)

        if not self._load_convert_file():
            self.die(u"不能加载转换规则文件")
        print u"加载转换规则文件成功<{0}>".format(self._cvt_rule_file.basename())

        if self._list_rules:
            for rule in self._cvt_rules:
                print u"转换规则[{0}]: {1}<{2}>".format(rule['name'], rule['desc'],
                        rule.get('remark', ''))
            return

        self._load_config()
        try:
            print u"等待连接原始数据..."
            param = dict(database=self._src_database, host=self._src_host,
                port=self._src_port, user=self._src_user, password=self._src_pswd,
                name=self._src_name)
            self._src_db = self._connect_db(self._src_type, **param)
            print u"连接源库成功"

            print u"等待连接目标数据..."
            param = dict(database=self._dst_database, host=self._dst_host,
                port=self._dst_port, user=self._dst_user, password=self._dst_pswd,
                name=self._dst_name)
            self._dst_db = self._connect_db(self._dst_type, **param)
            print u"连接目标库成功"
        except Exception, ex:
            print u"连接数据失败, {0}".format(str(ex).decode('utf-8'))
            return

        self._convert()

        self._src_db.close()
        self._dst_db.close()

    def die(self, msg, exit_code=1):
        print msg
        sys.exit(exit_code)

    def _convert(self):
        for rule in self._cvt_rules:
            if self._cvt_exec_rule_name:
                if rule['name'] not in self._cvt_exec_rule_name:
                    continue
            if self._manual_convert:
                msg = u"是否转换规则<{0}>[{1}](Y/n)".format(rule['name'], rule['desc'])
                ans = raw_input(msg.encode("utf-8"))
                if not (ans is None or ans in ('Y', 'y')):
                    return
            else:
                print u"等待转换<{0}>[{1}]".format(rule['name'], rule['desc'])
            if not self._do_convert(rule):
                return

    def _do_execute_sql(self, cursor, statment):
        try:
            cursor.prepare(statment)
            sql_params = {}
            for param in cursor.bindnames():
                if param not in self._cvt_sql_params:
                    print u"参数<{0}>未指定".format(param)
                    return False
                sql_params[param] = self._cvt_sql_params[param]
            # print sql_params
            if not cursor.execute(statment, **sql_params):
                return True

            return True
        except Exception as exc:
            try:
                if isinstance(exc, cx_Oracle.DatabaseError):
                    error, = exc.args
                    msg = error.message.decode("utf-8")
                else:
                    msg = "{0}".format(exc)
                print u"执行sql 失败, {0}".format(msg)
            except UnicodeError:
                print exc
                print u"执行语句[{0}]错误".format(statment)
            return False

    def _do_convert(self, rule):
        query_cursor = self._src_db.new_cursor()

        if rule['action'] == 'truncate':
            print u"Truncate 目标表数据..."
            self._dst_db.exec_sql('truncate table {0}'.format(rule['name']))

        if 'pre_exec' in rule:
            for statment in rule['pre_exec']:
                if not self._do_execute_sql(self._src_db.cursor, statment):
                    print u"执行 pre_exec 语句失败"
                    self._src_db.rollback()
                    return False
                self._src_db.commit()

        if 'pre_dst_exec' in rule:
            for statment in rule['pre_dst_exec']:
                if not self._do_execute_sql(self._dst_db.cursor, statment):
                    print u"执行 pre_dst_exec 语句失败"
                    self._dst_db.rollback()
                    return False
                self._dst_db.commit()

        if not self._do_execute_sql(query_cursor, rule['src_sql']):
            print u"查询原始数据失败"
            return False
        # print query_cursor.description
        count = 0
        self._dst_db.begin_transaction()
        for row in query_cursor.fetchall():
            v = []
            for data in row:
                if isinstance(data, int):
                    v.append('%d' % data)
                elif isinstance(data, float):
                    v.append('%f' % data)
                elif isinstance(data, str) or isinstance(data, unicode):
                    v.append("'{0}'".format(data.replace("'", "''")))
                elif data is None:
                    v.append('NULL')
                else:
                    v.append(data)
            values = ",".join(v)
            values = values.decode('utf-8')
            insert_sql = u'insert into {0} ({1}) values({2})'.format(rule['name'],
                         rule['dest_column'], values)
            # print insert_sql
            count += 1
            if not self._dst_db.exec_sql(insert_sql):
                print u"导入第{0}条记录错误".format(count)
                self._dst_db.rollback()
                print u"SQL<{0}>".format(insert_sql)
                return False

            if count % self._commit_count == 0:
                print u"导入数据 {0} 条".format(count)
                self._dst_db.commit()
                self._dst_db.begin_transaction()

        self._dst_db.commit()
        print u"导入数据 {0} 条".format(count)
        if 'post_exec' in rule:
            print u"执行 post_exec..."
            for statment in rule['post_exec']:
                if not self._do_execute_sql(self._dst_db.cursor, statment):
                    print u"执行 post_exec 语句失败"
                    self._dst_db.rollback()
                    return False
                self._dst_db.commit()
            print u"执行 post_exec 完成"
        query_cursor.close()
        return True

    def _load_convert_file(self):
        rule_doc = json_minify(self._cvt_rule_file.text(encoding="utf-8"))
        self._cvt_rules = json.loads(rule_doc, "utf-8")
        return True

    def _parse_command_line(self):
        execute_rules = ""
        optlist, args = getopt.getopt(sys.argv[1:], ":c:r:t:hlm")
        for k, v in optlist:
            if k == "-c":
                self._cvt_config_file = path(v)
            elif k == "-r":
                self._cvt_rule_file = path(v)
            elif k == "-l":
                self._list_rules = True
            elif k == "-t":
                execute_rules = v
            elif k == "-m":
                self._manual_convert = True
            elif k == "-h":
                self.usage()
                sys.exit(0)

        if not self._cvt_rule_file or not self._cvt_rule_file.exists():
            print u"请指定转换规则文件"
            return False

        if not self._cvt_config_file.exists():
            print u"请指定配置文件"
            return False

        if self._list_rules:
            return True

        if execute_rules:
            target_rule = execute_rules.split(',')
            self._cvt_exec_rule_name = [rule.upper() for rule in target_rule]
        # print self._cvt_exec_rule_name

        for v in args:
            param = v.split('=')
            if len(param) != 2:
                print u"参数[{0}]错误！".format(v)
                return False
            kn = param[0].strip(' ').upper()
            kv = param[1].strip(' ')
            self._cvt_sql_params[kn] = kv

        return True

    def usage(self):
        print u"转换程序 {0} Version: {1}".format(self._base_name, VERSION)
        print u"\t-c   转换配置文件默认为 cvtcfg.ini"
        print u"\t-r   转换配置规则文件"
        print u"\t-t   转换使用表名，多个表名用逗号分隔；例如 t_card, t_customer"
        print u"\t-l   只列出转换规则文件中所有的 target 不做转换"
        print u"\t-h   输出帮助"

    def _load_config(self):
        parser = RawConfigParser()
        parser.read(self._cvt_config_file)
        cfg = dict(parser.items('database'))
        if 'srcdsn' in cfg:
            self._src_database = cfg['srcdsn']
            self._src_host = None
            self._src_port = 0
            self._src_name = None
        else:
            self._src_host = cfg.get('srchost', '')
            self._src_port = int(cfg.get('srcport', 0))
            self._src_name = cfg.get('srcname', '')
            self._src_database = None
        self._src_user = cfg.get('srcuser', '')
        self._src_pswd = cfg.get('srcpswd', '')
        self._src_type = cfg.get('srctype', 'oracle')

        if 'dstdsn' in cfg:
            self._dst_database = cfg['dstdsn']
            self._dst_host = None
            self._dst_port = 0
            self._dst_name = None
        else:
            self._dst_host = cfg.get('dsthost', '')
            self._dst_port = int(cfg.get('dstport', 0))
            self._dst_name = cfg.get('dstname', '')
            self._dst_database = None
        self._dst_user = cfg.get('dstuser', '')
        self._dst_pswd = cfg.get('dstpswd', '')
        self._dst_type = cfg.get('dsttype', 'oracle')

        self._commit_count = int(cfg.get('commitcount', 1000))

    def _connect_db(self, dbtype, **kwargs):
        if dbtype == "oracle":
            db = dbengine.OraEngine()
        elif dbtype == "db2":
            db = dbengine.DB2Engine()
        else:
            raise ValueError(u"不支持数据类型[{0}]".format(dbtype))

        db.connect(database=kwargs['database'], user=kwargs['user'],
                   password=kwargs['password'], host=kwargs['host'],
                   port=kwargs['port'], name=kwargs['name'])
        return db


if __name__ == "__main__":
    begin = time.time()
    cvt = V2DBConvert()
    cvt.run()
    print u"执行时间 %.3f 秒" % ((time.time() - begin))
