文件类型校验

young 156 2024-11-23

背景

项目中需要对文件上传类型进行限制,同时需要进行文件的实际类型的判断,避免恶意修改文件后缀的情况,因此需要对文件的后缀与实际类型进行校验。

实现

经过查找,Hutool 工具提供了文件类型判断的工具 FileTypeUtil,里面支持了70 多种常见的文件类型,它是通过 FileMagic 类来实现的,FileMagic 里面通过文件的魔法值来进行判断,然而在进行实际测试时发现,某些文件它判断的并不是很准确,依旧需要我们进一步进行手动编码的判断。

file signatures上可以找到几乎所有文件类型的校验方式,所以考虑对部分文件类型进行扩充。

同时还考虑到一个问题,目前我们仅支持一少部分的文件类型校验,如果后续我们需要增加或减少支持的文件类型,我们不能每次都停机去部署,需要动态的维护,那么我们就应该将自己实现的校验动态的加载到系统中,或者从系统中卸载,同时也要动态的增减系统支持的文件类型。

动态加载类,那么首先想到的就是通过 SPI 进行实现,需要定义接口,要动态的增减支持文件类型,那么就需要在接口中顶级当前校验的文件类型,如接口定义如下,接口应该放在一个独立的 maven 中,这样可以让具体的实现去引用它

/**
 * 文件类型匹配器
 */
public interface FileTypeMatch {
  	/**
  	 * 文件后缀
  	 **/
    List<String> extension();
		/** 
		 * 校验
		 **/
    boolean match(InputStream inputStream);
}

接口有了,动态增减的前提是可以动态的感知到扩展类的变化,hutool 工具提供了封装过的 watch,可以用于监控指定目录中文件的变化情况,于是考虑将扩展 jar 放在指定目录下,通过 watch 去监控里面文件的变化。同时考虑到纯文本文件是没有魔法值的,所以需要将一些纯文本文件进行排除,那么添加配置类

@Data
@Configuration
@ConfigurationProperties(prefix = "young.common.filetype.matcher")
@Validated
public class FileTypeMatcherProperties {


    /**
     * 监听路径
     */
    @NotBlank(message = "监听路径不能为空")
    private String watchJarPath;

    /**
     * 跳过不认证的文件类型
     */
    private List<String> skipFileType = new ArrayList<>(Arrays.asList("txt","sql","html","xml","yaml","json","md","log","rtf"));
}

对于一些简单的依靠偏移量和魔法值就可以判断的文件类型,可以定义一个类,让其也可以进行注册

@Data
@AllArgsConstructor
@NoArgsConstructor
public class FileTypeData {
    private int offset;
    private String header;
}

接着应该定义匹配器,匹配器应该支持类型的注册、移除,同时也要告知调用方支持哪些类型,最重要的还有校验逻辑

@Service
@RequiredArgsConstructor
public class FileTypeMatcher {

    private final List<FileTypeMatch> fileTypeMatches;
    private final FileTypeMatcherProperties properties;

    /**
     * 内置文件类型
     */
    private Map<String, FileMagicNumber> innerFileExtension = Arrays.stream(FileMagicNumber.values())
            .filter(e -> StringUtils.isNotBlank(e.getExtension()))
            .collect(Collectors.toMap(e -> e.getExtension().toLowerCase(), e -> e));

    /**
     * 注册的文件类型及匹配规则
     */
    private Map<String, Set<FileTypeData>> fileDataMapping = new ConcurrentHashMap<>();
    /**
     * 文件类型映射
     * 例如 mp4文件,可能对应多个文件类型,例如:mp4、mp41、m4v等
     */
    private Map<String, Set<String>> fileTypeMapping = new ConcurrentHashMap<>();

    /**
     * 自定义文件类型匹配器
     */
    private Map<String, List<FileTypeMatch>> customerFileTypeMatcher = new ConcurrentHashMap<>();

    @PostConstruct
    public void init() {
        // 初始化自定义匹配器
        for (FileTypeMatch fileTypeMatch : fileTypeMatches) {
            List<String> extensions = fileTypeMatch.extension();
            for (String extension : extensions) {
                extension = extension.toLowerCase();
                List<FileTypeMatch> fileTypeMatches = customerFileTypeMatcher.computeIfAbsent(extension, k -> new ArrayList<>());
                fileTypeMatches.add(fileTypeMatch);
            }
        }
    }

    public List<String> getSupportFileType() {
        Set<String> supportFileType = new HashSet<>();
        supportFileType.addAll(customerFileTypeMatcher.keySet());
        supportFileType.addAll(innerFileExtension.keySet());
        supportFileType.addAll(fileDataMapping.keySet());
        supportFileType.addAll(properties.getSkipFileType());
        return new ArrayList<>(supportFileType);
    }


    /**
     * 注册文件类型映射,如映射m4v到mp4
     *
     * @param fileExtension 文件类型
     * @param mappingType   映射类型
     */
    public void registry(String fileExtension, String mappingType) {
        Set<String> mappingSet = fileTypeMapping.computeIfAbsent(mappingType, k -> new ConcurrentHashSet<>());
        mappingSet.add(fileExtension);
    }

    /**
     * 注册文件类型
     *
     * @param fileExtension 文件后缀
     * @param magicString   魔法值
     * @param offset        偏移量
     */
    public void registry(String fileExtension, String magicString, int offset) {
        this.registry(fileExtension, magicString, offset, null);
    }

    /**
     * 注册文件类型匹配器
     *
     * @param fileTypeMatch
     */
    public void registry(FileTypeMatch fileTypeMatch) {
        List<String> extensions = fileTypeMatch.extension();
        for (String extension : extensions) {
            extension = extension.toLowerCase();
            List<FileTypeMatch> fileTypeMatches = customerFileTypeMatcher.computeIfAbsent(extension, k -> new ArrayList<>());
            synchronized (FileTypeMatcher.class) {
                fileTypeMatches.add(fileTypeMatch);
            }
        }
    }

    public void remove(FileTypeMatch fileTypeMatch) {
        synchronized (FileTypeMatcher.class) {
            List<String> extensions = fileTypeMatch.extension();
            for (String extension : extensions) {
                List<FileTypeMatch> fileTypeMatches = customerFileTypeMatcher.getOrDefault(extension, new ArrayList<>());
                Iterator<FileTypeMatch> iterator = fileTypeMatches.iterator();
                if (iterator.hasNext()) {
                    FileTypeMatch next = iterator.next();
                    if (next.getClass() == fileTypeMatch.getClass()) {
                        iterator.remove();
                    }
                }
            }
        }
    }

    public void remove(Class<? extends FileTypeMatch> clazz) {
        synchronized (FileTypeMatcher.class) {
            Set<Map.Entry<String, List<FileTypeMatch>>> entries = customerFileTypeMatcher.entrySet();
            for (Map.Entry<String, List<FileTypeMatch>> entry : entries) {
                List<FileTypeMatch> value = entry.getValue();
                if (value != null) {
                    Iterator<FileTypeMatch> iterator = value.iterator();
                    while (iterator.hasNext()) {
                        FileTypeMatch next = iterator.next();
                        if (next.getClass() == clazz) {
                            iterator.remove();
                        }
                    }
                }
            }
        }
    }

    /**
     * 注册文件类型
     *
     * @param fileExtension 文件后缀
     * @param magicString   魔法值
     * @param offset        偏移量
     * @param mappingType   映射类型,比如m4v类型,映射到mp4类型上,如果不涉及,传空
     */
    public void registry(String fileExtension, String magicString, int offset, String mappingType) {
        fileExtension = fileExtension.toLowerCase();
        mappingType = mappingType.toLowerCase();
        magicString = magicString.toUpperCase();
        Set<FileTypeData> fileTypeData = fileDataMapping.computeIfAbsent(fileExtension, k -> new ConcurrentHashSet<>());
        fileTypeData.add(new FileTypeData(offset, magicString));
        if (StringHelper.isNotBlank(mappingType)) {
            Set<String> typeSet = fileTypeMapping.computeIfAbsent(mappingType, k -> new ConcurrentHashSet<>());
            typeSet.add(fileExtension);
        }
    }

    public boolean match(InputStream inputStream, String fileExtension) throws Exception {
        try {
            fileExtension = fileExtension.toLowerCase();
            if (properties.getSkipFileType().contains(fileExtension)) {
                String type = FileTypeUtil.getType(inputStream, "test." + fileExtension);
                return StringUtils.equalsIgnoreCase(type, fileExtension);

            }
            Boolean match = null;
            // 将流读取缓存,以便复用
            byte[] bytes = IOHelper.inputStreamToBytes(inputStream);
            // 使用自定义匹配器
            if (fileDataMapping.containsKey(fileExtension)) {
                // 使用偏移量匹配
                ByteArrayInputStream bis = null;
                try {
                    bis = new ByteArrayInputStream(bytes);
                    match = doOffsetMatch(bis, fileExtension);
                } finally {
                    IOHelper.close(bis);
                }

            }
            if (match != null && match) {
                return true;
            }

            if (innerFileExtension.containsKey(fileExtension)) {
                // 使用工具内置文件类型匹配
                ByteArrayInputStream bis = null;
                try {
                    bis = new ByteArrayInputStream(bytes);
                    match = doToolsMatch(bis, fileExtension);
                } finally {
                    IOHelper.close(bis);
                }
            }

            if (match != null && match) {
                return true;
            }

            // 使用自定义处理器匹配
            if (customerFileTypeMatcher.containsKey(fileExtension)) {
                List<FileTypeMatch> fileTypeMatches = customerFileTypeMatcher.get(fileExtension);
                match = fileTypeMatches.stream().anyMatch(e -> {
                    ByteArrayInputStream bis = null;
                    try {
                        bis = new ByteArrayInputStream(bytes);
                        return e.match(bis);
                    } finally {
                        IOHelper.close(bis);
                    }
                });
            }

            if (match != null && match) {
                return true;
            }
            if (properties.getSkipFileType().contains(fileExtension)) {
                String type = FileTypeUtil.getType(inputStream, "a." + fileExtension);
                if (StringUtils.equalsIgnoreCase(type, fileExtension)) {
                    match = true;
                }
            }
            if (match == null) {
                throw new UnSupportFileTypeException("不支持的文件类型");
            }
            return match;
        } finally {
            IOHelper.close(inputStream);
        }


    }

    private boolean doToolsMatch(InputStream inputStream, String fileExtension) throws IOException {
        // 获取文件类型
        FileMagicNumber fileMagicNumber = innerFileExtension.get(fileExtension);
        if (fileMagicNumber == null) {
            return false;
        }
        byte[] byteArray = IOUtils.toByteArray(inputStream);
        boolean match = fileMagicNumber.match(byteArray);
        if (!match) {
            ByteArrayInputStream bis = null;
            try {
                bis = new ByteArrayInputStream(byteArray);
                String type = FileTypeUtil.getType(bis);
                match = StringUtils.equalsIgnoreCase(type, fileExtension);
                if (!match) {
                    // 如果不匹配,看是否满足映射关系,如m4v可以认为是mp4
                    Set<String> mappingExtension = fileTypeMapping.getOrDefault(fileExtension, new HashSet<>());
                    return mappingExtension.contains(type);
                }
            }catch (Exception e){
                return false;
            }finally {
                IoUtil.close(bis);
            }
        }
        return match;
    }

    private boolean doOffsetMatch(InputStream inputStream, String fileExtension) throws IOException {
        // 获取文件类型对应的配置数据
        Set<FileTypeData> fileTypeDatas = fileDataMapping.get(fileExtension);
        // 读取文件数据
        byte[] dataBytes = readFileBytes(inputStream, fileTypeDatas);
        // 匹配配置数据
        return fileTypeDatas.stream().anyMatch(e -> {
            byte[] bytes = Arrays.copyOfRange(dataBytes, e.getOffset(), dataBytes.length);
            String fileHeader = new String(HexUtil.encodeHex(bytes, false));
            return fileHeader.startsWith(e.getHeader().replace(" ", ""));
        });
    }

    private byte[] readFileBytes(InputStream inputStream, Set<FileTypeData> fileTypeDatas) throws IOException {
        int maxOffset = 0;
        int maxMagicLength = 0;
        // 获取最大的偏移量和魔法值的长度
        for (FileTypeData fileTypeData : fileTypeDatas) {
            int offset = fileTypeData.getOffset();
            int length = StringUtils.split(fileTypeData.getHeader(), " ").length;
            maxOffset = Math.max(maxOffset, offset);
            maxMagicLength = Math.max(maxMagicLength, length);
        }
        // 获取文件长度
        int available = inputStream.available();
        byte[] dataBytes;
        // 如果偏移量加魔法值长度大于64
        if (maxOffset + maxMagicLength > 64) {
            // 文件长度小于8192
            if (available < 8192) {
                // 读取全部
                dataBytes = IoUtil.readBytes(inputStream);
            } else {
                // 读取前8192位
                dataBytes = IoUtil.readBytes(inputStream, 8192);
            }
        } else {
            // 如果偏移量加魔法值长度小于64
            if (available < 64) {
                // 读取全部
                dataBytes = IoUtil.readBytes(inputStream);
            } else {
                // 读取前64位
                dataBytes = IoUtil.readBytes(inputStream, 64);
            }
        }
        return dataBytes;
    }

}

最后实现监听功能,根据监控目录中 jar 的变化,调用注册和移除接口,实现类型的动态增删

@Service
@RequiredArgsConstructor
public class FileTypeMatchJarWatch {


    private final FileTypeMatcherProperties properties;
    private final FileTypeMatcher fileTypeMatcher;
    private static final String JAR_FILE_SUFFIX = ".jar";
    private final Log log = Log.getInstance(this.getClass());
    private final Map<String, List<FileTypeMatch>> filePathAndInstanceMapping = new ConcurrentHashMap<>();
    private final String FILE_PROTOCOL_PREFIX = SystemUtils.IS_OS_WINDOWS ? "file:///" : "file://";

    @PostConstruct
    public void init() {
        loadDirJar();
        addJarWatch();
    }

    private void loadDirJar() {
        File jarDir = new File(properties.getWatchJarPath());
        log.info("load jar from {}", properties.getWatchJarPath());
        File[] files = jarDir.listFiles();
        if (CollectionHelper.isEmpty(files)) {
            log.info("watch path jar is empty");
            return;
        }
        for (File file : files) {
            if (file.getName().endsWith(JAR_FILE_SUFFIX)) {
                loadJar(file.getAbsolutePath());
            }
        }
    }


    private void addJarWatch() {
        String path = properties.getWatchJarPath();
        WatchMonitor watchMonitor = WatchMonitor.create(path, WatchMonitor.EVENTS_ALL);

        watchMonitor.setWatcher(new Watcher() {
            @Override
            public void onCreate(WatchEvent<?> event, Path currentPath) {
                log.info("create file: path:【{}】==》file:【{}】", currentPath, event.context());
                loadJar(currentPath, event);
            }

            @Override
            public void onModify(WatchEvent<?> event, Path currentPath) {
                log.info("modify file: path:【{}】==>file:【{}】", currentPath, event.context());
                modify(currentPath, event);
            }

            @Override
            public void onDelete(WatchEvent<?> event, Path currentPath) {
                log.info("delete file: path:【{}】==>file:【{}】", currentPath, event.context());
                remove(currentPath, event);
            }

            @Override
            public void onOverflow(WatchEvent<?> event, Path currentPath) {
                log.info("overflow file: path:【{}】==>file:【{}】", currentPath, event.context());
            }
        });
        log.info("watch jar path :{}", path);
        watchMonitor.start();
    }

    private void modify(Path currentPath, WatchEvent<?> event) {
        synchronized (FileTypeMatchJarWatch.class) {
            remove(currentPath, event);
            loadJar(currentPath, event);
        }
    }

    private void remove(Path currentPath, WatchEvent<?> event) {
        String path = String.valueOf(currentPath);
        String filePath = path.endsWith("/") ? path + event.context() : path + "/" + event.context();
        List<FileTypeMatch> fileTypeMatches = filePathAndInstanceMapping.get(filePath);
        if (fileTypeMatches != null) {
            synchronized (FileTypeMatchJarWatch.class) {
                Iterator<FileTypeMatch> iterator = fileTypeMatches.iterator();
                while (iterator.hasNext()) {
                    FileTypeMatch next = iterator.next();
                    fileTypeMatcher.remove(next);
                    log.info("remove class {}", next.getClass());
                    iterator.remove();
                }
            }
        }
    }


    private void loadJar(Path currentPath, WatchEvent<?> event) {
        String path = String.valueOf(currentPath);
        String filePath = path.endsWith("/") ? path + event.context() : path + "/" + event.context();
        if (filePath.endsWith(JAR_FILE_SUFFIX)) {
            loadJar(filePath);

        }
    }

    private void loadJar(String filePath) {
        try {
            log.info("load jar from path:{}", filePath);
            File file = new File(filePath);
            URL fileUrl = file.toURI().toURL();
//            URL url = new URL(FILE_PROTOCOL_PREFIX + filePath);
            log.info("fileUrl:{}", fileUrl);
            try (URLClassLoader urlClassLoader = new URLClassLoader(new URL[]{fileUrl},FileTypeMatch.class.getClassLoader())){
                ServiceLoader<FileTypeMatch> load = ServiceLoader.load(FileTypeMatch.class, urlClassLoader);
                Iterator<FileTypeMatch> iterator = load.iterator();
                while (iterator.hasNext()) {
                    FileTypeMatch next = iterator.next();
                    log.info("registry class {}", next.getClass());
                    fileTypeMatcher.registry(next);
                    List<FileTypeMatch> fileTypeMatches = filePathAndInstanceMapping.computeIfAbsent(filePath, k -> new ArrayList<>());
                    synchronized (FileTypeMatchJarWatch.class) {
                        fileTypeMatches.add(next);
                    }
                }
            }
        } catch (Exception e) {
            log.error("load {} error", filePath, e);
        }
    }

}

开发过程中遇见的问题可参考 https://yhsblog.cn/archives/ji-yi-ci-jia-zai-wai-bu-jar-yu-dao-de-wen-ti