SqlValidateUtil.java 3.98 KB
package com.xly.util;


import com.xly.exception.sqlexception.SqlValidateException;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.StatementVisitorAdapter;
import net.sf.jsqlparser.statement.delete.Delete;
import net.sf.jsqlparser.statement.drop.Drop;
import net.sf.jsqlparser.statement.insert.Insert;
import net.sf.jsqlparser.statement.update.Update;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.List;
import java.util.regex.Pattern;

/**
 * SQL强校验工具(MySQL专属)
 */
public class SqlValidateUtil {
    private static final Logger log = LoggerFactory.getLogger(SqlValidateUtil.class);

    // 危险SQL关键词(生产可根据业务扩展)
    private static final List<String> DANGER_KEYWORDS = Arrays.asList(
            "DROP", "ALTER", "TRUNCATE", "DELETE", "INSERT", "UPDATE", "CREATE",
            "RENAME", "REPLACE", "GRANT", "REVOKE", "CALL", "SHUTDOWN", "LOAD"
    );
    // 关键词匹配正则(忽略大小写,单词边界匹配,避免误判)
    private static final Pattern DANGER_KEYWORD_PATTERN = Pattern.compile(
            "\\b(" + String.join("|", DANGER_KEYWORDS) + ")\\b",
            Pattern.CASE_INSENSITIVE
    );

    /**
     * MySQL SQL全量强校验
     * @param sql 生成的SQL语句
     */
    public static void validateMysqlSql(String sql) {
        log.info("开始MySQL SQL强校验,待校验SQL:{}", sql);
        // 1. 空值/空白校验
        if (sql == null || sql.trim().isEmpty()) {
            throw new SqlValidateException("校验失败:生成的SQL语句为空");
        }
        String cleanSql = sql.trim();

        // 2. 危险关键词过滤
        if (DANGER_KEYWORD_PATTERN.matcher(cleanSql).find()) {
            throw new SqlValidateException("校验失败:SQL包含危险关键词,禁止执行!危险关键词:" + DANGER_KEYWORDS);
        }

        // 3. 语法校验 + 非SELECT语句精准拦截
        try {
            Statement statement = CCJSqlParserUtil.parse(cleanSql);
            // 遍历SQL语句,拦截INSERT/UPDATE/DELETE/DROP等非查询语句
            statement.accept(new StatementVisitorAdapter() {
                @Override
                public void visit(Insert insert) {
                    throw new SqlValidateException("校验失败:禁止执行INSERT语句");
                }

                @Override
                public void visit(Update update) {
                    throw new SqlValidateException("校验失败:禁止执行UPDATE语句");
                }

                @Override
                public void visit(Delete delete) {
                    throw new SqlValidateException("校验失败:禁止执行DELETE语句");
                }

                @Override
                public void visit(Drop drop) {
                    throw new SqlValidateException("校验失败:禁止执行DROP语句");
                }
            });
        } catch (JSQLParserException e) {
            throw new SqlValidateException("校验失败:SQL语法错误!错误信息:" + e.getMessage(), e);
        } catch (SqlValidateException e) {
            throw e; // 抛出拦截的非SELECT异常
        } catch (Exception e) {
            throw new SqlValidateException("校验失败:SQL解析异常!错误信息:" + e.getMessage(), e);
        }

        log.info("MySQL SQL强校验通过");
    }

    /**
     * 清理模型生成的多余符号(```sql/```/换行),避免SQL执行报错
     */
    public static String cleanSqlSymbol(String sql) {
        if (sql == null) {
            return "";
        }
        if(sql.contains("```sql") && sql.contains("```")){
            sql=sql.split("```sql")[1];
            sql=sql.split("```")[0];
        }
        return sql.replace("```sql", "")
                .replace("```", "")
                .replaceAll("\\n|\\r", " ")
                .trim();
    }
}