背景
项目中需要对文件上传类型进行限制,同时需要进行文件的实际类型的判断,避免恶意修改文件后缀的情况,因此需要对文件的后缀与实际类型进行校验。
实现
经过查找,Hutool 工具提供了文件类型判断的工具 FileTypeUtil,里面支持了70 多种常见的文件类型,它是通过 FileMagic 类来实现的,FileMagic 里面通过文件的魔法值来进行判断,然而在进行实际测试时发现,某些文件它判断的并不是很准确,依旧需要我们进一步进行手动编码的判断。
在file signatures上可以找到几乎所有文件类型的校验方式,所以考虑对部分文件类型进行扩充。
同时还考虑到一个问题,目前我们仅支持一少部分的文件类型校验,如果后续我们需要增加或减少支持的文件类型,我们不能每次都停机去部署,需要动态的维护,那么我们就应该将自己实现的校验动态的加载到系统中,或者从系统中卸载,同时也要动态的增减系统支持的文件类型。
动态加载类,那么首先想到的就是通过 SPI 进行实现,需要定义接口,要动态的增减支持文件类型,那么就需要在接口中顶级当前校验的文件类型,如接口定义如下,接口应该放在一个独立的 maven 中,这样可以让具体的实现去引用它
/**
* 文件类型匹配器
*/
public interface FileTypeMatch {
/**
* 文件后缀
**/
List<String> extension();
/**
* 校验
**/
boolean match(InputStream inputStream);
}
接口有了,动态增减的前提是可以动态的感知到扩展类的变化,hutool 工具提供了封装过的 watch,可以用于监控指定目录中文件的变化情况,于是考虑将扩展 jar 放在指定目录下,通过 watch 去监控里面文件的变化。同时考虑到纯文本文件是没有魔法值的,所以需要将一些纯文本文件进行排除,那么添加配置类
@Data
@Configuration
@ConfigurationProperties(prefix = "young.common.filetype.matcher")
@Validated
public class FileTypeMatcherProperties {
/**
* 监听路径
*/
@NotBlank(message = "监听路径不能为空")
private String watchJarPath;
/**
* 跳过不认证的文件类型
*/
private List<String> skipFileType = new ArrayList<>(Arrays.asList("txt","sql","html","xml","yaml","json","md","log","rtf"));
}
对于一些简单的依靠偏移量和魔法值就可以判断的文件类型,可以定义一个类,让其也可以进行注册
@Data
@AllArgsConstructor
@NoArgsConstructor
public class FileTypeData {
private int offset;
private String header;
}
接着应该定义匹配器,匹配器应该支持类型的注册、移除,同时也要告知调用方支持哪些类型,最重要的还有校验逻辑
@Service
@RequiredArgsConstructor
public class FileTypeMatcher {
private final List<FileTypeMatch> fileTypeMatches;
private final FileTypeMatcherProperties properties;
/**
* 内置文件类型
*/
private Map<String, FileMagicNumber> innerFileExtension = Arrays.stream(FileMagicNumber.values())
.filter(e -> StringUtils.isNotBlank(e.getExtension()))
.collect(Collectors.toMap(e -> e.getExtension().toLowerCase(), e -> e));
/**
* 注册的文件类型及匹配规则
*/
private Map<String, Set<FileTypeData>> fileDataMapping = new ConcurrentHashMap<>();
/**
* 文件类型映射
* 例如 mp4文件,可能对应多个文件类型,例如:mp4、mp41、m4v等
*/
private Map<String, Set<String>> fileTypeMapping = new ConcurrentHashMap<>();
/**
* 自定义文件类型匹配器
*/
private Map<String, List<FileTypeMatch>> customerFileTypeMatcher = new ConcurrentHashMap<>();
@PostConstruct
public void init() {
// 初始化自定义匹配器
for (FileTypeMatch fileTypeMatch : fileTypeMatches) {
List<String> extensions = fileTypeMatch.extension();
for (String extension : extensions) {
extension = extension.toLowerCase();
List<FileTypeMatch> fileTypeMatches = customerFileTypeMatcher.computeIfAbsent(extension, k -> new ArrayList<>());
fileTypeMatches.add(fileTypeMatch);
}
}
}
public List<String> getSupportFileType() {
Set<String> supportFileType = new HashSet<>();
supportFileType.addAll(customerFileTypeMatcher.keySet());
supportFileType.addAll(innerFileExtension.keySet());
supportFileType.addAll(fileDataMapping.keySet());
supportFileType.addAll(properties.getSkipFileType());
return new ArrayList<>(supportFileType);
}
/**
* 注册文件类型映射,如映射m4v到mp4
*
* @param fileExtension 文件类型
* @param mappingType 映射类型
*/
public void registry(String fileExtension, String mappingType) {
Set<String> mappingSet = fileTypeMapping.computeIfAbsent(mappingType, k -> new ConcurrentHashSet<>());
mappingSet.add(fileExtension);
}
/**
* 注册文件类型
*
* @param fileExtension 文件后缀
* @param magicString 魔法值
* @param offset 偏移量
*/
public void registry(String fileExtension, String magicString, int offset) {
this.registry(fileExtension, magicString, offset, null);
}
/**
* 注册文件类型匹配器
*
* @param fileTypeMatch
*/
public void registry(FileTypeMatch fileTypeMatch) {
List<String> extensions = fileTypeMatch.extension();
for (String extension : extensions) {
extension = extension.toLowerCase();
List<FileTypeMatch> fileTypeMatches = customerFileTypeMatcher.computeIfAbsent(extension, k -> new ArrayList<>());
synchronized (FileTypeMatcher.class) {
fileTypeMatches.add(fileTypeMatch);
}
}
}
public void remove(FileTypeMatch fileTypeMatch) {
synchronized (FileTypeMatcher.class) {
List<String> extensions = fileTypeMatch.extension();
for (String extension : extensions) {
List<FileTypeMatch> fileTypeMatches = customerFileTypeMatcher.getOrDefault(extension, new ArrayList<>());
Iterator<FileTypeMatch> iterator = fileTypeMatches.iterator();
if (iterator.hasNext()) {
FileTypeMatch next = iterator.next();
if (next.getClass() == fileTypeMatch.getClass()) {
iterator.remove();
}
}
}
}
}
public void remove(Class<? extends FileTypeMatch> clazz) {
synchronized (FileTypeMatcher.class) {
Set<Map.Entry<String, List<FileTypeMatch>>> entries = customerFileTypeMatcher.entrySet();
for (Map.Entry<String, List<FileTypeMatch>> entry : entries) {
List<FileTypeMatch> value = entry.getValue();
if (value != null) {
Iterator<FileTypeMatch> iterator = value.iterator();
while (iterator.hasNext()) {
FileTypeMatch next = iterator.next();
if (next.getClass() == clazz) {
iterator.remove();
}
}
}
}
}
}
/**
* 注册文件类型
*
* @param fileExtension 文件后缀
* @param magicString 魔法值
* @param offset 偏移量
* @param mappingType 映射类型,比如m4v类型,映射到mp4类型上,如果不涉及,传空
*/
public void registry(String fileExtension, String magicString, int offset, String mappingType) {
fileExtension = fileExtension.toLowerCase();
mappingType = mappingType.toLowerCase();
magicString = magicString.toUpperCase();
Set<FileTypeData> fileTypeData = fileDataMapping.computeIfAbsent(fileExtension, k -> new ConcurrentHashSet<>());
fileTypeData.add(new FileTypeData(offset, magicString));
if (StringHelper.isNotBlank(mappingType)) {
Set<String> typeSet = fileTypeMapping.computeIfAbsent(mappingType, k -> new ConcurrentHashSet<>());
typeSet.add(fileExtension);
}
}
public boolean match(InputStream inputStream, String fileExtension) throws Exception {
try {
fileExtension = fileExtension.toLowerCase();
if (properties.getSkipFileType().contains(fileExtension)) {
String type = FileTypeUtil.getType(inputStream, "test." + fileExtension);
return StringUtils.equalsIgnoreCase(type, fileExtension);
}
Boolean match = null;
// 将流读取缓存,以便复用
byte[] bytes = IOHelper.inputStreamToBytes(inputStream);
// 使用自定义匹配器
if (fileDataMapping.containsKey(fileExtension)) {
// 使用偏移量匹配
ByteArrayInputStream bis = null;
try {
bis = new ByteArrayInputStream(bytes);
match = doOffsetMatch(bis, fileExtension);
} finally {
IOHelper.close(bis);
}
}
if (match != null && match) {
return true;
}
if (innerFileExtension.containsKey(fileExtension)) {
// 使用工具内置文件类型匹配
ByteArrayInputStream bis = null;
try {
bis = new ByteArrayInputStream(bytes);
match = doToolsMatch(bis, fileExtension);
} finally {
IOHelper.close(bis);
}
}
if (match != null && match) {
return true;
}
// 使用自定义处理器匹配
if (customerFileTypeMatcher.containsKey(fileExtension)) {
List<FileTypeMatch> fileTypeMatches = customerFileTypeMatcher.get(fileExtension);
match = fileTypeMatches.stream().anyMatch(e -> {
ByteArrayInputStream bis = null;
try {
bis = new ByteArrayInputStream(bytes);
return e.match(bis);
} finally {
IOHelper.close(bis);
}
});
}
if (match != null && match) {
return true;
}
if (properties.getSkipFileType().contains(fileExtension)) {
String type = FileTypeUtil.getType(inputStream, "a." + fileExtension);
if (StringUtils.equalsIgnoreCase(type, fileExtension)) {
match = true;
}
}
if (match == null) {
throw new UnSupportFileTypeException("不支持的文件类型");
}
return match;
} finally {
IOHelper.close(inputStream);
}
}
private boolean doToolsMatch(InputStream inputStream, String fileExtension) throws IOException {
// 获取文件类型
FileMagicNumber fileMagicNumber = innerFileExtension.get(fileExtension);
if (fileMagicNumber == null) {
return false;
}
byte[] byteArray = IOUtils.toByteArray(inputStream);
boolean match = fileMagicNumber.match(byteArray);
if (!match) {
ByteArrayInputStream bis = null;
try {
bis = new ByteArrayInputStream(byteArray);
String type = FileTypeUtil.getType(bis);
match = StringUtils.equalsIgnoreCase(type, fileExtension);
if (!match) {
// 如果不匹配,看是否满足映射关系,如m4v可以认为是mp4
Set<String> mappingExtension = fileTypeMapping.getOrDefault(fileExtension, new HashSet<>());
return mappingExtension.contains(type);
}
}catch (Exception e){
return false;
}finally {
IoUtil.close(bis);
}
}
return match;
}
private boolean doOffsetMatch(InputStream inputStream, String fileExtension) throws IOException {
// 获取文件类型对应的配置数据
Set<FileTypeData> fileTypeDatas = fileDataMapping.get(fileExtension);
// 读取文件数据
byte[] dataBytes = readFileBytes(inputStream, fileTypeDatas);
// 匹配配置数据
return fileTypeDatas.stream().anyMatch(e -> {
byte[] bytes = Arrays.copyOfRange(dataBytes, e.getOffset(), dataBytes.length);
String fileHeader = new String(HexUtil.encodeHex(bytes, false));
return fileHeader.startsWith(e.getHeader().replace(" ", ""));
});
}
private byte[] readFileBytes(InputStream inputStream, Set<FileTypeData> fileTypeDatas) throws IOException {
int maxOffset = 0;
int maxMagicLength = 0;
// 获取最大的偏移量和魔法值的长度
for (FileTypeData fileTypeData : fileTypeDatas) {
int offset = fileTypeData.getOffset();
int length = StringUtils.split(fileTypeData.getHeader(), " ").length;
maxOffset = Math.max(maxOffset, offset);
maxMagicLength = Math.max(maxMagicLength, length);
}
// 获取文件长度
int available = inputStream.available();
byte[] dataBytes;
// 如果偏移量加魔法值长度大于64
if (maxOffset + maxMagicLength > 64) {
// 文件长度小于8192
if (available < 8192) {
// 读取全部
dataBytes = IoUtil.readBytes(inputStream);
} else {
// 读取前8192位
dataBytes = IoUtil.readBytes(inputStream, 8192);
}
} else {
// 如果偏移量加魔法值长度小于64
if (available < 64) {
// 读取全部
dataBytes = IoUtil.readBytes(inputStream);
} else {
// 读取前64位
dataBytes = IoUtil.readBytes(inputStream, 64);
}
}
return dataBytes;
}
}
最后实现监听功能,根据监控目录中 jar 的变化,调用注册和移除接口,实现类型的动态增删
@Service
@RequiredArgsConstructor
public class FileTypeMatchJarWatch {
private final FileTypeMatcherProperties properties;
private final FileTypeMatcher fileTypeMatcher;
private static final String JAR_FILE_SUFFIX = ".jar";
private final Log log = Log.getInstance(this.getClass());
private final Map<String, List<FileTypeMatch>> filePathAndInstanceMapping = new ConcurrentHashMap<>();
private final String FILE_PROTOCOL_PREFIX = SystemUtils.IS_OS_WINDOWS ? "file:///" : "file://";
@PostConstruct
public void init() {
loadDirJar();
addJarWatch();
}
private void loadDirJar() {
File jarDir = new File(properties.getWatchJarPath());
log.info("load jar from {}", properties.getWatchJarPath());
File[] files = jarDir.listFiles();
if (CollectionHelper.isEmpty(files)) {
log.info("watch path jar is empty");
return;
}
for (File file : files) {
if (file.getName().endsWith(JAR_FILE_SUFFIX)) {
loadJar(file.getAbsolutePath());
}
}
}
private void addJarWatch() {
String path = properties.getWatchJarPath();
WatchMonitor watchMonitor = WatchMonitor.create(path, WatchMonitor.EVENTS_ALL);
watchMonitor.setWatcher(new Watcher() {
@Override
public void onCreate(WatchEvent<?> event, Path currentPath) {
log.info("create file: path:【{}】==》file:【{}】", currentPath, event.context());
loadJar(currentPath, event);
}
@Override
public void onModify(WatchEvent<?> event, Path currentPath) {
log.info("modify file: path:【{}】==>file:【{}】", currentPath, event.context());
modify(currentPath, event);
}
@Override
public void onDelete(WatchEvent<?> event, Path currentPath) {
log.info("delete file: path:【{}】==>file:【{}】", currentPath, event.context());
remove(currentPath, event);
}
@Override
public void onOverflow(WatchEvent<?> event, Path currentPath) {
log.info("overflow file: path:【{}】==>file:【{}】", currentPath, event.context());
}
});
log.info("watch jar path :{}", path);
watchMonitor.start();
}
private void modify(Path currentPath, WatchEvent<?> event) {
synchronized (FileTypeMatchJarWatch.class) {
remove(currentPath, event);
loadJar(currentPath, event);
}
}
private void remove(Path currentPath, WatchEvent<?> event) {
String path = String.valueOf(currentPath);
String filePath = path.endsWith("/") ? path + event.context() : path + "/" + event.context();
List<FileTypeMatch> fileTypeMatches = filePathAndInstanceMapping.get(filePath);
if (fileTypeMatches != null) {
synchronized (FileTypeMatchJarWatch.class) {
Iterator<FileTypeMatch> iterator = fileTypeMatches.iterator();
while (iterator.hasNext()) {
FileTypeMatch next = iterator.next();
fileTypeMatcher.remove(next);
log.info("remove class {}", next.getClass());
iterator.remove();
}
}
}
}
private void loadJar(Path currentPath, WatchEvent<?> event) {
String path = String.valueOf(currentPath);
String filePath = path.endsWith("/") ? path + event.context() : path + "/" + event.context();
if (filePath.endsWith(JAR_FILE_SUFFIX)) {
loadJar(filePath);
}
}
private void loadJar(String filePath) {
try {
log.info("load jar from path:{}", filePath);
File file = new File(filePath);
URL fileUrl = file.toURI().toURL();
// URL url = new URL(FILE_PROTOCOL_PREFIX + filePath);
log.info("fileUrl:{}", fileUrl);
try (URLClassLoader urlClassLoader = new URLClassLoader(new URL[]{fileUrl},FileTypeMatch.class.getClassLoader())){
ServiceLoader<FileTypeMatch> load = ServiceLoader.load(FileTypeMatch.class, urlClassLoader);
Iterator<FileTypeMatch> iterator = load.iterator();
while (iterator.hasNext()) {
FileTypeMatch next = iterator.next();
log.info("registry class {}", next.getClass());
fileTypeMatcher.registry(next);
List<FileTypeMatch> fileTypeMatches = filePathAndInstanceMapping.computeIfAbsent(filePath, k -> new ArrayList<>());
synchronized (FileTypeMatchJarWatch.class) {
fileTypeMatches.add(next);
}
}
}
} catch (Exception e) {
log.error("load {} error", filePath, e);
}
}
}
开发过程中遇见的问题可参考 https://yhsblog.cn/archives/ji-yi-ci-jia-zai-wai-bu-jar-yu-dao-de-wen-ti