Java实现的决策树算法完整实例

2025-05-27 0 71

本文实例讲述了Java实现的决策树算法。分享给大家供大家参考,具体如下:

决策树算法是一种逼近离散函数值的方法。它是一种典型的分类方法,首先对数据进行处理,利用归纳算法生成可读的规则和决策树,然后使用决策对新数据进行分析。本质上决策树是通过一系列规则对数据进行分类的过程。

决策树构造可以分两步进行。第一步,决策树的生成:由训练样本集生成决策树的过程。一般情况下,训练样本数据集是根据实际需要有历史的、有一定综合程度的,用于数据分析处理的数据集。第二步,决策树的剪枝:决策树的剪枝是对上一阶段生成的决策树进行检验、校正和修下的过程,主要是用新的样本数据集(称为测试数据集)中的数据校验决策树生成过程中产生的初步规则,将那些影响预衡准确性的分枝剪除。

java实现代码如下:

?

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

33

34

35

36

37

38

39

40

41

42

43

44

45

46

47

48

49

50

51

52

53

54

55

56

57

58

59

60

61

62

63

64

65

66

67

68

69

70

71

72

73

74

75

76

77

78

79

80

81

82

83

84

85

86

87

88

89

90

91

92

93

94

95

96

97

98

99

100

101

102

103

104

105

106

107

108

109

110

111

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

142

143

144

145

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

163

164

165

166

167

168

169

170

171

172

173

174

175

176

177

178

179

180

181

182

183

184

185

186

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225
package demo;

import java.util.HashMap;

import java.util.LinkedList;

import java.util.List;

import java.util.Map;

import java.util.Map.Entry;

import java.util.Set;

public class DicisionTree {

public static void main(String[] args) throws Exception {

System.out.print("快网idc测试结果:");

String[] attrNames = new String[] { "AGE", "INCOME", "STUDENT",

"CREDIT_RATING" };

// 读取样本集

Map<Object, List<Sample>> samples = readSamples(attrNames);

// 生成决策树

Object decisionTree = generateDecisionTree(samples, attrNames);

// 输出决策树

outputDecisionTree(decisionTree, 0, null);

}

/**

* 读取已分类的样本集,返回Map:分类 -> 属于该分类的样本的列表

*/

static Map<Object, List<Sample>> readSamples(String[] attrNames) {

// 样本属性及其所属分类(数组中的最后一个元素为样本所属分类)

Object[][] rawData = new Object[][] {

{ "<30 ", "High ", "No ", "Fair ", "0" },

{ "<30 ", "High ", "No ", "Excellent", "0" },

{ "30-40", "High ", "No ", "Fair ", "1" },

{ ">40 ", "Medium", "No ", "Fair ", "1" },

{ ">40 ", "Low ", "Yes", "Fair ", "1" },

{ ">40 ", "Low ", "Yes", "Excellent", "0" },

{ "30-40", "Low ", "Yes", "Excellent", "1" },

{ "<30 ", "Medium", "No ", "Fair ", "0" },

{ "<30 ", "Low ", "Yes", "Fair ", "1" },

{ ">40 ", "Medium", "Yes", "Fair ", "1" },

{ "<30 ", "Medium", "Yes", "Excellent", "1" },

{ "30-40", "Medium", "No ", "Excellent", "1" },

{ "30-40", "High ", "Yes", "Fair ", "1" },

{ ">40 ", "Medium", "No ", "Excellent", "0" } };

// 读取样本属性及其所属分类,构造表示样本的Sample对象,并按分类划分样本集

Map<Object, List<Sample>> ret = new HashMap<Object, List<Sample>>();

for (Object[] row : rawData) {

Sample sample = new Sample();

int i = 0;

for (int n = row.length - 1; i < n; i++)

sample.setAttribute(attrNames[i], row[i]);

sample.setCategory(row[i]);

List<Sample> samples = ret.get(row[i]);

if (samples == null) {

samples = new LinkedList<Sample>();

ret.put(row[i], samples);

}

samples.add(sample);

}

return ret;

}

/**

* 构造决策树

*/

static Object generateDecisionTree(

Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {

// 如果只有一个样本,将该样本所属分类作为新样本的分类

if (categoryToSamples.size() == 1)

return categoryToSamples.keySet().iterator().next();

// 如果没有供决策的属性,则将样本集中具有最多样本的分类作为新样本的分类,即投票选举出分类

if (attrNames.length == 0) {

int max = 0;

Object maxCategory = null;

for (Entry<Object, List<Sample>> entry : categoryToSamples

.entrySet()) {

int cur = entry.getValue().size();

if (cur > max) {

max = cur;

maxCategory = entry.getKey();

}

}

return maxCategory;

}

// 选取测试属性

Object[] rst = chooseBestTestAttribute(categoryToSamples, attrNames);

// 决策树根结点,分支属性为选取的测试属性

Tree tree = new Tree(attrNames[(Integer) rst[0]]);

// 已用过的测试属性不应再次被选为测试属性

String[] subA = new String[attrNames.length - 1];

for (int i = 0, j = 0; i < attrNames.length; i++)

if (i != (Integer) rst[0])

subA[j++] = attrNames[i];

// 根据分支属性生成分支

@SuppressWarnings("unchecked")

Map<Object, Map<Object, List<Sample>>> splits =

/* NEW LINE */(Map<Object, Map<Object, List<Sample>>>) rst[2];

for (Entry<Object, Map<Object, List<Sample>>> entry : splits.entrySet()) {

Object attrValue = entry.getKey();

Map<Object, List<Sample>> split = entry.getValue();

Object child = generateDecisionTree(split, subA);

tree.setChild(attrValue, child);

}

return tree;

}

/**

* 选取最优测试属性。最优是指如果根据选取的测试属性分支,则从各分支确定新样本

* 的分类需要的信息量之和最小,这等价于确定新样本的测试属性获得的信息增益最大

* 返回数组:选取的属性下标、信息量之和、Map(属性值->(分类->样本列表))

*/

static Object[] chooseBestTestAttribute(

Map<Object, List<Sample>> categoryToSamples, String[] attrNames) {

int minIndex = -1; // 最优属性下标

double minValue = Double.MAX_VALUE; // 最小信息量

Map<Object, Map<Object, List<Sample>>> minSplits = null; // 最优分支方案

// 对每一个属性,计算将其作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和,选取最小为最优

for (int attrIndex = 0; attrIndex < attrNames.length; attrIndex++) {

int allCount = 0; // 统计样本总数的计数器

// 按当前属性构建Map:属性值->(分类->样本列表)

Map<Object, Map<Object, List<Sample>>> curSplits =

/* NEW LINE */new HashMap<Object, Map<Object, List<Sample>>>();

for (Entry<Object, List<Sample>> entry : categoryToSamples

.entrySet()) {

Object category = entry.getKey();

List<Sample> samples = entry.getValue();

for (Sample sample : samples) {

Object attrValue = sample

.getAttribute(attrNames[attrIndex]);

Map<Object, List<Sample>> split = curSplits.get(attrValue);

if (split == null) {

split = new HashMap<Object, List<Sample>>();

curSplits.put(attrValue, split);

}

List<Sample> splitSamples = split.get(category);

if (splitSamples == null) {

splitSamples = new LinkedList<Sample>();

split.put(category, splitSamples);

}

splitSamples.add(sample);

}

allCount += samples.size();

}

// 计算将当前属性作为测试属性的情况下在各分支确定新样本的分类需要的信息量之和

double curValue = 0.0; // 计数器:累加各分支

for (Map<Object, List<Sample>> splits : curSplits.values()) {

double perSplitCount = 0;

for (List<Sample> list : splits.values())

perSplitCount += list.size(); // 累计当前分支样本数

double perSplitValue = 0.0; // 计数器:当前分支

for (List<Sample> list : splits.values()) {

double p = list.size() / perSplitCount;

perSplitValue -= p * (Math.log(p) / Math.log(2));

}

curValue += (perSplitCount / allCount) * perSplitValue;

}

// 选取最小为最优

if (minValue > curValue) {

minIndex = attrIndex;

minValue = curValue;

minSplits = curSplits;

}

}

return new Object[] { minIndex, minValue, minSplits };

}

/**

* 将决策树输出到标准输出

*/

static void outputDecisionTree(Object obj, int level, Object from) {

for (int i = 0; i < level; i++)

System.out.print("|-----");

if (from != null)

System.out.printf("(%s):", from);

if (obj instanceof Tree) {

Tree tree = (Tree) obj;

String attrName = tree.getAttribute();

System.out.printf("[%s = ?]\\n", attrName);

for (Object attrValue : tree.getAttributeValues()) {

Object child = tree.getChild(attrValue);

outputDecisionTree(child, level + 1, attrName + " = "

+ attrValue);

}

} else {

System.out.printf("[CATEGORY = %s]\\n", obj);

}

}

/**

* 样本,包含多个属性和一个指明样本所属分类的分类值

*/

static class Sample {

private Map<String, Object> attributes = new HashMap<String, Object>();

private Object category;

public Object getAttribute(String name) {

return attributes.get(name);

}

public void setAttribute(String name, Object value) {

attributes.put(name, value);

}

public Object getCategory() {

return category;

}

public void setCategory(Object category) {

this.category = category;

}

public String toString() {

return attributes.toString();

}

}

/**

* 决策树(非叶结点),决策树中的每个非叶结点都引导了一棵决策树

* 每个非叶结点包含一个分支属性和多个分支,分支属性的每个值对应一个分支,该分支引导了一棵子决策树

*/

static class Tree {

private String attribute;

private Map<Object, Object> children = new HashMap<Object, Object>();

public Tree(String attribute) {

this.attribute = attribute;

}

public String getAttribute() {

return attribute;

}

public Object getChild(Object attrValue) {

return children.get(attrValue);

}

public void setChild(Object attrValue, Object child) {

children.put(attrValue, child);

}

public Set<Object> getAttributeValues() {

return children.keySet();

}

}

}

运行结果:

Java实现的决策树算法完整实例

希望本文所述对大家java程序设计有所帮助。

原文链接:http://blog.csdn.net/u013058160/article/details/50035693

收藏 (0) 打赏

感谢您的支持,我会继续努力的!

打开微信/支付宝扫一扫,即可进行扫码打赏哦,分享从这里开始,精彩与您同在
点赞 (0)

声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。

快网idc优惠网 建站教程 Java实现的决策树算法完整实例 https://www.kuaiidc.com/77396.html

相关文章

发表评论
暂无评论