Java获取所有类的class对象

今天需求需要获取到项目里所有所有类的class对象,查阅了一番资料后写了一个工具类,如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package com.xiaoyun.utils;

import com.xiaoyun.annotation.Quartz;
import com.xiaoyun.holder.QuartzHolder;
import org.springframework.stereotype.Component;

import java.io.File;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.net.URL;
import java.util.*;

/**
* 获取class类对象的工具类
*/
@Component
public class ClassUtils {

/**
* 获取class类对象
*
* @param basePackage 要扫描的包
* @return 文件列表
*/
public List<Class<?>> getAllClasses(String basePackage) {
basePackage = basePackage.replace(".", "/");
Enumeration<URL> resources = null;
try {
resources = Thread.currentThread().getContextClassLoader().getResources(basePackage);
} catch (IOException e) {
throw new RuntimeException(e);
}

ArrayList<String> classes = new ArrayList<>();
while (resources.hasMoreElements()) {
URL url = resources.nextElement();
if ("file".equals(url.getProtocol())) {
String file = url.getFile();
findFiles(basePackage, file, classes);
}
}

return this.loadClazz(classes);
}

/**
* 获取所有类的包名
*
* @param basePackage 磁盘路径
* @param file 文件路径
* @param classes class包名
*/
private void findFiles(String basePackage, String file, ArrayList<String> classes) {
File f = new File(file);
if (f.isDirectory()) {
File[] fileArr = f.listFiles();
if (fileArr != null) {
for (File afile : fileArr) {
findFiles(basePackage, afile.getAbsolutePath(), classes);
}
}
} else {
// 处理class文件
if (file.toLowerCase().endsWith(".class")) {
file = file.replace('\\', '/');
String className = file.substring(file.lastIndexOf(basePackage), file.length() - 6)
.replace('/', '.');
classes.add(className);
}
}
}

/**
* 加载class
*
* @param classes 包名
* @return class对象
*/
private List<Class<?>> loadClazz(ArrayList<String> classes) {
List<Class<?>> resArr = new ArrayList<>();
try {
for (String classpath : classes) {
Class<?> aClass = Class.forName(classpath);
resArr.add(aClass);
}
} catch (Exception e) {
throw new RuntimeException("load class error", e);
}
return resArr;
}

/**
* 寻找所有带有某个注解的所有方法
*
* @param anno 注解名字
* @return key:className, value:methodName
*/
public Map<String, List<String>> findHasAnnotationMethods(String basePackage, Class<? extends Annotation> anno) {
List<Class<?>> allClasses = this.getAllClasses(basePackage);
return scanMethods(anno, allClasses);
}

/**
* 寻找所有带有某个注解的所有方法
*
* @param anno 注解名字
* @return key:className, value:methodName
*/
public Map<String, List<String>> findHasAnnotationMethods(List<Class<?>> allClasses, Class<? extends Annotation> anno) {
return scanMethods(anno, allClasses);
}

/**
* 扫描class的方法,扫描其中带有给定注解的方法,返回以class包名,方法名为键值对的map
*
* @param anno 注解
* @param allClasses class集合
* @return key:className, value:methodName
*/
private Map<String, List<String>> scanMethods(Class<? extends Annotation> anno, List<Class<?>> allClasses) {
Map<String, List<String>> map = new HashMap<>();
for (Class<?> clazz : allClasses) {
Method[] methods = clazz.getMethods();
for (Method method : methods) {
Annotation annotation = method.getAnnotation(anno);
if (annotation != null) {
String name = clazz.getName();
if (map.containsKey(name)) {
map.get(name).add(method.getName());
} else {
List<String> list = new ArrayList<>();
list.add(method.getName());
map.put(name, list);
}
}
}
}
return map;
}

/*
* @Author TryGo
* @Description 根据类的全包名转换成默认的包名
* @Date 2024/10/31 22:33
* @Param [className] 类名称
* @return bean的名字
**/
public String formatBeanName(String className) {
try {
String[] split = className.split("\\.");
String clazzName = split[split.length - 1];
clazzName = clazzName.substring(0, 1).toLowerCase() + clazzName.substring(1);
return clazzName;
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}