python 比较 mysql 表结构差异

阅读 13

2024-06-09

最近在做项目的时候,需要比对两个数据库的表结构差异。由于表数量比较多,人工比对需要耗费大量时间,而且无法复用,于是想到用 Python 写一个脚本来满足这个需求,下次再有相同需求时只需修改 sql 文件名即可。

compare_diff.py:

import re
import json


class TableStmt:
    """Holder for one CREATE TABLE statement and the name of the table it creates."""

    # Name of the table the statement defines.
    table_name: str = ""
    # Full text of the CREATE TABLE statement.
    create_stmt: str = ""


class Table:
    """A parsed table: its name plus the fields and indexes extracted from its DDL."""

    # Name of the table; empty until a parser fills it in.
    table_name: str = ""
    # Parsed column objects and parsed index objects of this table.
    fields: list
    indexes: list

    def __init__(self):
        # Create the lists per instance. As class-level attributes they were
        # shared by every Table, so appending to one table's list silently
        # changed all the others.
        self.fields = []
        self.indexes = []


class Field:
    """A single table column: its name and its (base) type."""

    # Column name.
    field_name: str = ""
    # Column type text, e.g. "int(11)".
    field_type: str = ""


class Index:
    """A single table index: its name, its kind and its column list."""

    # Index name.
    name: str = ""
    # Kind of index, stored as a plain string.
    type: str = ""
    # Text of the index's column list.
    columns: str = ""


def obj_2_dict(obj):
    """Convert a Field or Index into a plain dict for ``json.dumps(default=...)``.

    Optional helper, only needed when printing parsed tables.

    Raises:
        TypeError: if ``obj`` is neither a Field nor an Index.
    """
    if isinstance(obj, Field):
        return dict(field_name=obj.field_name, field_type=obj.field_type)
    if isinstance(obj, Index):
        return dict(name=obj.name, type=obj.type, columns=obj.columns)
    raise TypeError(f"Type {type(obj)} is not serializable")


# Matches one complete CREATE TABLE statement: from the CREATE keyword, through
# the column/index list, up to the ";" that ends the table options after ENGINE.
#   * Whitespace between CREATE and TABLE is flexible and an optional
#     IF NOT EXISTS clause is accepted.
#   * The table name must be back-quoted; it is captured as ``table_name``.
#   * Table options may contain single-quoted strings (e.g. COMMENT='...') with
#     arbitrary content. Outside quotes only letters, digits, "=", "_",
#     whitespace and CJK characters are accepted. Without the quoted-string
#     branch, a comment containing punctuation made the lazy ".*?" run on into
#     the next statement, so two tables were merged into one match.
create_table_pattern = re.compile(
    r"CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?`(?P<table_name>\w+)`.*?\)\s*ENGINE"
    r"(?:[A-Za-z0-9=_ \t\n\r\u4e00-\u9fa5]|'(?:[^'\\]|\\.)*')+;",
    re.DOTALL | re.IGNORECASE
)

# Matches a column definition: a back-quoted column name at the start of a line
# (leading whitespace allowed), captured as ``field``, followed by its type,
# captured as ``type``. Only the base type with an optional "(n)" or "(n,m)"
# size is captured; whatever follows it (NOT NULL, DEFAULT, COMMENT, ...) is
# ignored. No whitespace is accepted inside the parentheses, so a type written
# as "decimal(10, 2)" is captured as just "decimal".
table_pattern = re.compile(
    r"^\s*`(?P<field>\w+)`\s+(?P<type>[a-zA-Z]+(?:\(\d+(?:,\d+)?\))?)",
    re.MULTILINE
)

# Matches index definitions. There are four alternatives and seven capture
# groups in total:
#   groups 1-2: plain "KEY name (columns)"; the look-behind rejects a KEY that
#               directly follows a back-quote
#   group  3  : "PRIMARY KEY (columns)"
#   groups 4-5: "UNIQUE KEY name (columns)"
#   groups 6-7: "FULLTEXT KEY name (columns)"
# The column list is captured as "[^)]+", i.e. everything up to the first ")".
index_pattern = re.compile(r'(?<!`)KEY\s+`?(\w+)`?\s*\(([^)]+)\)|'
                           r'PRIMARY\s+KEY\s*\(([^)]+)\)|'
                           r'UNIQUE\s+KEY\s+`?(\w+)`?\s*\(([^)]+)\)|'
                           r'FULLTEXT\s+KEY\s+`?(\w+)`?\s*\(([^)]+)\)',
                           re.IGNORECASE)


def extract_create_table_statements(sql_script):
    """Find every CREATE TABLE statement in ``sql_script``.

    Returns:
        A list of TableStmt objects in order of appearance. Each carries the
        lower-cased table name and the stripped text of the whole statement.
    """
    statements = []
    for found in create_table_pattern.finditer(sql_script):
        stmt = TableStmt()
        # Table names are always lower-cased before being stored.
        stmt.table_name = found.group('table_name').lower()
        # group(0) is the complete matched CREATE TABLE statement.
        stmt.create_stmt = found.group(0).strip()
        statements.append(stmt)
    return statements


def _normalize_index_columns(columns):
    """Lower-case a captured index column list and canonicalise comma spacing."""
    # "`a`,`b`" and "`a` , `b`" describe the same index, so both are rewritten
    # to "`a`, `b`". Otherwise formatting alone would be reported as a diff.
    return ", ".join(part.strip() for part in columns.lower().split(","))


def extract_indexes(sql):
    """Extract the index definitions found in one CREATE TABLE statement.

    Args:
        sql: text of a CREATE TABLE statement.

    Returns:
        A list of Index objects in order of appearance. ``type`` is one of
        'index', 'primary key', 'unique index' or 'fulltext index'. ``name`` is
        lower-cased ('primary' for the primary key). ``columns`` is the
        lower-cased column list with items separated by ", ".
    """
    indexes = []
    for match in index_pattern.findall(sql):
        # findall yields one 7-tuple per match; only the groups of the
        # alternative that matched are non-empty.
        (key_name, key_columns, primary_columns,
         unique_name, unique_columns,
         fulltext_name, fulltext_columns) = match
        index = Index()
        if key_name:  # plain index
            index.type = 'index'
            index.name = key_name.lower()
            index.columns = _normalize_index_columns(key_columns)
        elif primary_columns:  # primary key
            index.type = 'primary key'
            index.name = 'primary'
            index.columns = _normalize_index_columns(primary_columns)
        elif unique_name:  # unique index
            index.type = 'unique index'
            index.name = unique_name.lower()
            index.columns = _normalize_index_columns(unique_columns)
        elif fulltext_name:  # fulltext index
            index.type = 'fulltext index'
            index.name = fulltext_name.lower()
            index.columns = _normalize_index_columns(fulltext_columns)
        indexes.append(index)
    return indexes


def extract_fields(sql):
    """Extract the column definitions found in one CREATE TABLE statement.

    Returns:
        A list of Field objects in order of appearance, with both the column
        name and the column type lower-cased.
    """
    fields = []
    for column in table_pattern.finditer(sql):
        raw_name, raw_type = column.group('field', 'type')
        parsed = Field()
        parsed.field_name = raw_name.lower()
        parsed.field_type = raw_type.lower()
        fields.append(parsed)
    return fields


def extract_table_info(tableStmt: TableStmt):
    """Build a Table (name, fields, indexes) from one CREATE TABLE statement."""
    parsed = Table()
    parsed.table_name = tableStmt.table_name.lower()
    ddl = tableStmt.create_stmt
    # Columns first, then indexes, both parsed from the same statement text.
    parsed.fields = extract_fields(ddl)
    parsed.indexes = extract_indexes(ddl)
    return parsed


def get_all_tables(sql_script):
    """Parse every table defined in ``sql_script``.

    Returns:
        A dict mapping table name to its parsed Table. If the same name occurs
        more than once, the last definition wins.
    """
    parsed_tables = map(extract_table_info, extract_create_table_statements(sql_script))
    return {table.table_name: table for table in parsed_tables}


def compare_fields(source: Table, target: Table):
    """Compare the fields of two tables by field name.

    Returns:
        A 3-tuple of lists:
        * names of fields present in ``source`` but not in ``target``;
        * messages for fields present in both whose types differ;
        * names of fields present in ``target`` but not in ``source``.
    """
    target_by_name = {col.field_name: col for col in target.fields}
    source_names = {col.field_name for col in source.fields}

    only_in_source = []
    type_mismatches = []
    for col in source.fields:
        if col.field_name in target_by_name:
            other = target_by_name[col.field_name]
            if col.field_type != other.field_type:
                type_mismatches.append(
                    f"field={col.field_name}, source type: {col.field_type}"
                    f", target type: {other.field_type}")
        else:
            only_in_source.append(col.field_name)

    # Types need no second check here: fields present on both sides were
    # already compared in the loop above.
    only_in_target = [col.field_name for col in target.fields
                      if col.field_name not in source_names]

    return only_in_source, type_mismatches, only_in_target


def compare_indexes(source: Table, target: Table):
    """Compare the indexes of two tables by index name.

    Returns:
        A 4-tuple of lists:
        * names of indexes present in ``source`` but not in ``target``;
        * messages for indexes with the same name and type but different columns;
        * messages for indexes with the same name but a different type;
        * names of indexes present in ``target`` but not in ``source``.
    """
    target_by_name = {idx.name: idx for idx in target.indexes}
    source_names = {idx.name for idx in source.indexes}

    only_in_source = []
    column_mismatches = []
    type_mismatches = []
    for idx in source.indexes:
        if idx.name not in target_by_name:
            only_in_source.append(idx.name)
            continue

        other = target_by_name[idx.name]
        if idx.type != other.type:
            # A type mismatch takes precedence; columns are not checked then.
            type_mismatches.append(
                f"name={idx.name}, source type: {idx.type}, target type: {other.type}")
        elif idx.columns != other.columns:
            column_mismatches.append(
                f"name={idx.name}, source columns={idx.columns}"
                f", target columns={other.columns}")

    only_in_target = [idx.name for idx in target.indexes
                      if idx.name not in source_names]

    return only_in_source, column_mismatches, type_mismatches, only_in_target


def print_diff(desc, compare_result):
    """Print ``desc`` followed by ``compare_result``, unless the result is empty.

    An empty result means no difference was found, so nothing is printed.
    """
    if len(compare_result) == 0:
        return
    print(f"{desc} {compare_result}")


def _is_table_excluded(table_name, white_list, black_list):
    """Return True when the white/black lists say ``table_name`` must be skipped."""
    # A non-empty white list restricts the comparison to the tables it names.
    if white_list and table_name not in white_list:
        return True
    # Tables named in the black list are never compared.
    return table_name in black_list


def _print_table_diff(table_name, source_table, target_table):
    """Print the field and index differences of one table present on both sides."""
    (source_fields_not_in_target, fields_type_not_match,
     target_fields_not_in_source) = compare_fields(source_table, target_table)
    (source_indexes_not_in_target, index_column_not_match,
     index_type_not_match, target_indexes_not_in_source) = compare_indexes(source_table, target_table)

    print(f"====== table = {table_name} ======")
    print_diff("source field not in target, fields:", source_fields_not_in_target)
    print_diff("target field not in source, fields:", target_fields_not_in_source)
    print_diff("field type not match:", fields_type_not_match)
    print_diff("source index not in target, indexes:", source_indexes_not_in_target)
    print_diff("target index not in source, indexes:", target_indexes_not_in_source)
    print_diff("index type not match:", index_type_not_match)
    print_diff("index column not match:", index_column_not_match)
    print("")


def compare_table(source_sql_script, target_sql_script):
    """Compare every table defined in two SQL scripts and print the differences.

    For each table present in both scripts, a header and its field/index
    differences are printed. Afterwards the tables that exist in only one of
    the scripts are listed. The module-level ``white_list_tables`` and
    ``black_list_tables`` select which tables take part.

    Args:
        source_sql_script: text of the source SQL script.
        target_sql_script: text of the target SQL script.
    """
    source_table_map = get_all_tables(source_sql_script)
    target_table_map = get_all_tables(target_sql_script)

    # Parsed table names are lower-cased, so the configured lists are
    # lower-cased too. Otherwise an entry such as "Table1" could never match.
    white_list = {name.lower() for name in white_list_tables}
    black_list = {name.lower() for name in black_list_tables}

    source_table_not_in_target = []
    for key, source_table in source_table_map.items():
        if _is_table_excluded(key, white_list, black_list):
            continue

        if key not in target_table_map:
            # Table exists in source only.
            source_table_not_in_target.append(key)
            continue

        _print_table_diff(key, source_table, target_table_map[key])

    # Tables that exist in target only.
    target_table_not_in_source = [
        key for key in target_table_map
        if key not in source_table_map
        and not _is_table_excluded(key, white_list, black_list)
    ]

    print_diff("source table not in target, table list:", source_table_not_in_target)
    print_diff("target table not in source, table list:", target_table_not_in_source)


def sql_read(file_name):
    """Return the whole content of the file ``file_name``, read as UTF-8 text."""
    with open(file_name, mode="r", encoding='utf-8') as sql_file:
        content = sql_file.read()
    return content


def print_all_tables(file_name="sql1.sql"):
    """Print every table parsed from an SQL file, with its fields and indexes as JSON.

    Debugging helper.

    Args:
        file_name: path of the SQL script to parse. It defaults to "sql1.sql",
            the file that used to be hard-coded, so existing calls still work.
    """
    for table_name, table in get_all_tables(sql_read(file_name)).items():
        print(table_name)
        print(json.dumps(table.fields, default=obj_2_dict, ensure_ascii=False, indent=4))
        print(json.dumps(table.indexes, default=obj_2_dict, ensure_ascii=False, indent=4))
        print("")


# print_all_tables()

# White/black list settings, for comparing only a subset of all tables.
# Entries are checked against the lower-cased table names parsed from the scripts.
# White list: when non-empty, only the tables listed here are compared.
white_list_tables = []
# Black list: when non-empty, the tables listed here are not compared.
black_list_tables = []

if __name__ == '__main__':
    # NOTE: MySQL is case-insensitive by default, so every table, field and
    # index name is lower-cased before comparison. If the database is
    # configured to be case-sensitive, this script has to be adapted.
    compare_table(sql_read("sql1.sql"), sql_read("sql2.sql"))

运行效果如下:

====== table = table1 ======
source field not in target, fields: ['age', 'email']
target field not in source, fields: ['name']
field type not match: ['field=created_at, source type: date, target type: bigint(20)', 'field=updated_at, source type: timestamp, target type: date']
source index not in target, indexes: ['unique_name']
target index not in source, indexes: ['idx_country_env']

====== table = table2 ======
index type not match: ['name=fulltext_index, source type: fulltext index, target type: index']
index column not match: ['name=index, source columns=`age`, target columns=`description`']

====== table = table3 ======
index column not match: ['name=primary, source columns=`id`, `value`, target columns=`value`, `id`']

source table not in target, table list: ['activity_instance']
target table not in source, table list: ['table5']

结果说明:

  • 按照 table 来打印 source table 和 target table 的字段和索引差异,此时 table 在两个 sql 脚本里都存在
  • 最后打印只在其中一个 sql 脚本里存在的 table list

sql1.sql:

CREATE TABLE `table1` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `age` INT(11) DEFAULT NULL,
  `email` varchar(32)   DEFAULT NULL COMMENT '邮箱',
  `created_at` date DEFAULT NULL,
  `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  UNIQUE KEY `unique_name` (`name`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT ='测试表';

CREATE TABLE `table2` (
  `id` INT(11) NOT NULL,
  `description` TEXT NOT NULL,
  `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  UNIQUE KEY `unique_name` (`name`),
  KEY `index` (`age`),
  FULLTEXT KEY `fulltext_index` (`name`, `age`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `table3` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `value` DECIMAL(10,2) NOT NULL,
  `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`, `value`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

/******************************************/
/*   DatabaseName = database   */
/*   TableName = activity_instance   */
/******************************************/
CREATE TABLE `activity_instance`
(
    `id`                   bigint(20) unsigned NOT NULL AUTO_INCREMENT COMMENT '主键',
    `gmt_create`           bigint(20) NOT NULL COMMENT '创建时间',
    `gmt_modified`         bigint(20) NOT NULL COMMENT '修改时间',
    `activity_name`        varchar(400)  NOT NULL COMMENT '活动名称',
    `benefit_type`         varchar(16)   DEFAULT NULL,
    `benefit_id`           varchar(32)   DEFAULT NULL,
    PRIMARY KEY (`id`),
    KEY `idx_country_env` (`env`, `country_code`),
    KEY `idx_benefit_type_id` (`benefit_type`, `benefit_id`)
) ENGINE = InnoDB
  AUTO_INCREMENT = 139
  DEFAULT CHARSET = utf8mb4 COMMENT ='活动时间模板表'
;

sql2.sql:

CREATE TABLE `TABLE1` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `name` VARCHAR(255) NOT NULL,
  `created_at` bigint(20) DEFAULT NULL,
  `updated_at` date ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  KEY `idx_country_env` (`env`, `country_code`),
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COMMENT ='测试表';

CREATE TABLE `table2` (
  `id` INT(11) NOT NULL,
  `description` TEXT NOT NULL,
  `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  PRIMARY KEY (`id`),
  UNIQUE KEY `unique_name` (`name`),
  KEY `index` (`description`),
  KEY `fulltext_index` (`name`, `age`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `table3` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `value` DECIMAL(10,2) NOT NULL,
  `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  PRIMARY KEY (`value`, `id`)
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

CREATE TABLE `TABLE5` (
  `id` INT(11) NOT NULL AUTO_INCREMENT,
  `value` DECIMAL(10,2) NOT NULL,
  `updated_at` TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;

把 python 脚本和两个 sql 脚本拷贝下来,分别保存为同一个目录下的 compare_diff.py、sql1.sql、sql2.sql 这 3 个文件即可,示例在 python 3.12 环境上成功运行。

精彩评论(0)

0 0 举报