Skip to content

Commit 442dbec

Browse files
committedMar 10, 2017
Day7 PatriciaTrie
1 parent 2a92c3e commit 442dbec

File tree

7 files changed

+1033
-2
lines changed

7 files changed

+1033
-2
lines changed
 

‎build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
buildscript {
2-
ext.kotlin_version = "1.0.6"
2+
ext.kotlin_version = "1.1.0"
33

44
repositories {
55
mavenCentral()

‎src/main/kotlin/mbc/trie/PatriciaTrie.kt

Lines changed: 657 additions & 0 deletions
Large diffs are not rendered by default.

‎src/main/kotlin/mbc/trie/Trie.kt

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
package mbc.trie
2+
3+
import mbc.util.CodecUtil
4+
import mbc.util.CryptoUtil
5+
import org.spongycastle.asn1.ASN1EncodableVector
6+
import org.spongycastle.asn1.DERSequence
7+
import org.spongycastle.util.encoders.Hex
8+
9+
class Trie<Value> {
10+
11+
/**
12+
* Radix of the Trie.
13+
*/
14+
val radix = 256
15+
16+
/**
17+
* Root TrieNode.
18+
*/
19+
var root: TrieNode<Value>? = null
20+
21+
/**
22+
* Trie TrieNode.
23+
*/
24+
class TrieNode<Value> {
25+
26+
val radix = 256
27+
val EMPTY_VALUE = ByteArray(0)
28+
29+
var value: Value? = null
30+
val next = kotlin.arrayOfNulls<TrieNode<Value>>(radix)
31+
32+
fun hash(): ByteArray {
33+
val bin = encode()
34+
35+
val hash = CryptoUtil.sha256(bin)
36+
return hash
37+
}
38+
39+
private fun encode(): ByteArray {
40+
val vec = ASN1EncodableVector()
41+
42+
if (value != null) {
43+
vec.add(CodecUtil.asn1Encode(value!!))
44+
} else {
45+
vec.add(CodecUtil.asn1Encode(EMPTY_VALUE))
46+
}
47+
48+
val nextVec = ASN1EncodableVector()
49+
next.forEach {
50+
// 如果是Null就序列化为Byte(0)
51+
if (it == null) {
52+
nextVec.add(CodecUtil.asn1Encode(EMPTY_VALUE))
53+
} else if (it is TrieNode) {
54+
nextVec.add(CodecUtil.asn1Encode(it.hash()))
55+
}
56+
}
57+
vec.add(DERSequence(nextVec))
58+
59+
val bin = DERSequence(vec).encoded
60+
return bin
61+
}
62+
}
63+
64+
fun get(key: String): Value? {
65+
val x = getSubNode(root, key, 0) ?: return null
66+
return x.value
67+
}
68+
69+
private fun getSubNode(x: TrieNode<Value>?, key: String, d: Int): TrieNode<Value>? {
70+
if (x == null) return null
71+
if (d == key.length) return x
72+
val c = key[d]
73+
return getSubNode(x.next[c.toInt()], key, d + 1)
74+
}
75+
76+
fun put(key: String, v: Value?) {
77+
if (v == null)
78+
delete(key)
79+
else
80+
root = putSubNode(root, key, v, 0)
81+
}
82+
83+
private fun putSubNode(x: TrieNode<Value>?, key: String, v: Value, d: Int): TrieNode<Value> {
84+
var node: TrieNode<Value>?
85+
86+
if (x == null) {
87+
node = TrieNode<Value>()
88+
} else {
89+
node = x
90+
}
91+
92+
// 如果到达Radix的最后一位,完成节点构造并返回。
93+
if (d == key.length) {
94+
node.value = v
95+
return node
96+
}
97+
98+
// 继续构造节点。
99+
val c = key[d]
100+
node.next[c.toInt()] = putSubNode(node.next[c.toInt()], key, v, d + 1)
101+
return node
102+
}
103+
104+
fun delete(key: String) {
105+
root = deleteNode(root, key, 0)
106+
}
107+
108+
private fun deleteNode(x: TrieNode<Value>?, key: String, d: Int): TrieNode<Value>? {
109+
/**
110+
* 1. 检查节点是否为Null。
111+
*/
112+
if (x == null) return null
113+
114+
/**
115+
* 2. 如果到达Radix的最后一位,删除节点数据。否则继续执行删除节点操作。
116+
*/
117+
if (d == key.length) {
118+
x.value = null
119+
} else {
120+
val c = key[d]
121+
x.next[c.toInt()] = deleteNode(x.next[c.toInt()], key, d + 1)
122+
}
123+
124+
/**
125+
* 3. 遍历删除后当前节点的Value如果不为空,说明Root没有变化。否则Root节点替换为子节点(Value不为空)。
126+
*/
127+
if (x.value != null) {
128+
return x
129+
} else {
130+
for (c in 0..radix - 1) {
131+
if (x.next[c] != null) {
132+
return x
133+
}
134+
}
135+
}
136+
return null
137+
}
138+
139+
}
140+
141+
fun main(args: Array<String>) {
142+
val trie1 = Trie<Int>()
143+
144+
trie1.put("hello", 342)
145+
println(Hex.toHexString(trie1.root?.hash()))
146+
147+
trie1.put("message", 432)
148+
println(Hex.toHexString(trie1.root?.hash()))
149+
150+
trie1.put("message2", 456)
151+
println(Hex.toHexString(trie1.root?.hash()))
152+
153+
trie1.put("message3", 555)
154+
println(Hex.toHexString(trie1.root?.hash()))
155+
156+
trie1.delete("message2")
157+
println(Hex.toHexString(trie1.root?.hash()))
158+
159+
println(trie1.get("hello"))
160+
}

‎src/main/kotlin/mbc/util/CodecUtil.kt

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,4 +256,26 @@ object CodecUtil {
256256
return BigInteger(b).toInt()
257257
}
258258

259+
fun asn1Encode(v: Any): ASN1Object {
260+
if (v is ByteArray) {
261+
return DERBitString(v)
262+
} else if (v is String) {
263+
return DERUTF8String(v)
264+
} else if (v is Int) {
265+
return ASN1Integer(v.toLong())
266+
} else if (v is Long) {
267+
return ASN1Integer(v)
268+
} else if (v is BigInteger) {
269+
return ASN1Integer(v)
270+
} else if (v is Array<*>) {
271+
val vec = ASN1EncodableVector()
272+
273+
v.forEach { vec.add(it?.let { asn1Encode(it) }) }
274+
275+
return DERSequence(vec)
276+
} else {
277+
throw Exception("Can not convert type ${v.javaClass} to ASN1 object.")
278+
}
279+
}
280+
259281
}

‎src/main/kotlin/mbc/util/CryptoUtil.kt

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ class CryptoUtil {
103103
}
104104

105105
/**
106-
* SHA-256预算
106+
* SHA-256
107107
*/
108108
fun sha256(msg: ByteArray): ByteArray {
109109
val digest = MessageDigest.getInstance("SHA-256", "SC")
@@ -113,6 +113,13 @@ class CryptoUtil {
113113
return hash
114114
}
115115

116+
/**
117+
* SHA3
118+
*/
119+
fun sha3(msg: ByteArray): ByteArray {
120+
return sha256((msg))
121+
}
122+
116123
fun deserializePrivateKey(bytes: ByteArray): PrivateKey {
117124
val kf = KeyFactory.getInstance("EC", "SC")
118125
return kf.generatePrivate(PKCS8EncodedKeySpec(bytes))
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package mbc.core
2+
3+
import mbc.storage.LevelDbDataSource
4+
import mbc.trie.*
5+
import org.junit.Assert.assertArrayEquals
6+
import org.junit.Test
7+
import kotlin.test.assertEquals
8+
import kotlin.test.assertNotNull
9+
10+
class PatriciaTrieTest {
11+
12+
@Test fun testBinToNibbles() {
13+
val res0 = binToNibbbles("".toByteArray())
14+
assertArrayEquals(res0, arrayOf<Int>())
15+
16+
val res1 = binToNibbbles("h".toByteArray())
17+
assertArrayEquals(res1, arrayOf(6, 8))
18+
19+
val res2 = binToNibbbles("he".toByteArray())
20+
assertArrayEquals(res2, arrayOf(6, 8, 6, 5))
21+
22+
val res3 = binToNibbbles("hello".toByteArray())
23+
assertArrayEquals(res2, arrayOf(6, 8, 6, 5, 6, 12, 6, 12, 6, 15))
24+
}
25+
26+
@Test fun testNibblesToBin() {
27+
val res0 = nibblesToBin(arrayOf())
28+
assertArrayEquals(res0, "".toByteArray())
29+
30+
val res1 = nibblesToBin(arrayOf(6, 8))
31+
assertArrayEquals(res1, "h".toByteArray())
32+
33+
val res2 = nibblesToBin(arrayOf(6, 8, 6, 5))
34+
assertArrayEquals(res2, "he".toByteArray())
35+
36+
val res3 = nibblesToBin(arrayOf(6, 8, 6, 5, 6, 12, 6, 12, 6, 15))
37+
assertArrayEquals(res3, "hello".toByteArray())
38+
}
39+
40+
@Test fun testPackNibbles() {
41+
val key = arrayOf(0, 1, 0, 1, 0, 2)
42+
val packed = packNibbles(withTerminator(key))
43+
44+
val expected = arrayOf(0x20.toByte(), 0x01.toByte(), 0x01.toByte(), 0x02.toByte()).toByteArray()
45+
assertArrayEquals(packed, expected)
46+
}
47+
48+
@Test fun testUnpackToNibbles() {
49+
val bin = arrayOf(0x20.toByte(), 0x01.toByte(), 0x01.toByte(), 0x02.toByte()).toByteArray()
50+
val nibbles = unpackToNibbles(bin)
51+
52+
assertArrayEquals(nibbles, arrayOf(0, 1, 0, 1, 0, 2, NIBBLE_TERMINATOR))
53+
}
54+
55+
@Test fun testTrieNodeEncodeDecode() {
56+
val blankNode = BLANK_NODE
57+
val trie = PatriciaTrie(LevelDbDataSource("test", "test-database"))
58+
val blankNodeEncoded = trie.encodeNode(blankNode)
59+
assertNotNull(blankNodeEncoded)
60+
assertEquals(trie.decodeToNode(blankNodeEncoded), BLANK_NODE)
61+
62+
val leafNode = TrieNode(arrayOf(0x20.toByte(), 0x01.toByte(), 0x01.toByte(), 0x02.toByte()).toByteArray(),
63+
"hello".toByteArray(), null)
64+
assertEquals(leafNode.type, NODE_TYPE.NODE_TYPE_LEAF)
65+
66+
val leafNodeEncoded = trie.encodeNode(leafNode)
67+
assertNotNull(leafNodeEncoded)
68+
assertEquals(trie.decodeToNode(leafNodeEncoded), leafNode)
69+
70+
val extensionNode = TrieNode(arrayOf(0x01.toByte(), 0x01.toByte(), 0x01.toByte(), 0x02.toByte()).toByteArray(),
71+
EMPTY_VALUE, null)
72+
assertEquals(extensionNode.type, NODE_TYPE.NODE_TYPE_EXTENSION)
73+
74+
val extensionNodeEncoded = trie.encodeNode(extensionNode)
75+
assertNotNull(extensionNodeEncoded)
76+
assertEquals(trie.decodeToNode(extensionNodeEncoded), extensionNode)
77+
78+
val branchNode = TrieNode(EMPTY_VALUE, "hello".toByteArray(),
79+
arrayOf(EMPTY_VALUE, "234".toByteArray(), "3233".toByteArray(), EMPTY_VALUE, EMPTY_VALUE,
80+
EMPTY_VALUE, EMPTY_VALUE, EMPTY_VALUE, EMPTY_VALUE, EMPTY_VALUE, EMPTY_VALUE,
81+
EMPTY_VALUE, EMPTY_VALUE, EMPTY_VALUE, EMPTY_VALUE, EMPTY_VALUE))
82+
assertEquals(branchNode.type, NODE_TYPE.NODE_TYPE_BRANCH)
83+
84+
val branchNodeEncoded = trie.encodeNode(branchNode)
85+
assertNotNull(branchNodeEncoded)
86+
assertEquals(trie.decodeToNode(branchNodeEncoded), branchNode)
87+
}
88+
89+
@Test fun testPutGet() {
90+
val db = LevelDbDataSource("test", "test-database")
91+
db.init()
92+
val trie = PatriciaTrie(db)
93+
trie.update(arrayOf<Byte>(0x01, 0x01, 0x02).toByteArray(), "11111".toByteArray())
94+
trie.update(arrayOf<Byte>(0x01, 0x01, 0x03).toByteArray(), "22222".toByteArray())
95+
96+
assertArrayEquals(trie.get(arrayOf<Byte>(0x01, 0x01, 0x02).toByteArray()), "11111".toByteArray())
97+
assertArrayEquals(trie.get(arrayOf<Byte>(0x01, 0x01, 0x03).toByteArray()), "22222".toByteArray())
98+
99+
trie.delete(arrayOf<Byte>(0x01, 0x01, 0x03).toByteArray())
100+
101+
assertArrayEquals(trie.get(arrayOf<Byte>(0x01, 0x01, 0x03).toByteArray()), EMPTY_VALUE)
102+
assertArrayEquals(trie.get(arrayOf<Byte>(0x01, 0x01, 0x02).toByteArray()), "11111".toByteArray())
103+
}
104+
105+
@Test fun testChangeRoot() {
106+
val db = LevelDbDataSource("test", "test-database")
107+
db.init()
108+
val trie = PatriciaTrie(db)
109+
trie.update(arrayOf<Byte>(0x01, 0x01, 0x02).toByteArray(), "11111".toByteArray())
110+
assertArrayEquals(trie.get(arrayOf<Byte>(0x01, 0x01, 0x02).toByteArray()), "11111".toByteArray())
111+
val rootHash1 = trie.rootHash
112+
113+
114+
trie.update(arrayOf<Byte>(0x01, 0x01, 0x02).toByteArray(), "22222".toByteArray())
115+
val rootHash2 = trie.rootHash
116+
assertArrayEquals(trie.get(arrayOf<Byte>(0x01, 0x01, 0x02).toByteArray()), "22222".toByteArray())
117+
118+
trie.changeRoot(rootHash1)
119+
assertArrayEquals(trie.get(arrayOf<Byte>(0x01, 0x01, 0x02).toByteArray()), "11111".toByteArray())
120+
121+
trie.changeRoot(rootHash2)
122+
assertArrayEquals(trie.get(arrayOf<Byte>(0x01, 0x01, 0x02).toByteArray()), "22222".toByteArray())
123+
}
124+
}

‎src/test/kotlin/mbc/core/TrieTest.kt

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package mbc.core
2+
3+
import mbc.storage.MemoryDataSource
4+
import mbc.trie.Trie
5+
import org.iq80.leveldb.CompressionType
6+
import org.iq80.leveldb.DB
7+
import org.iq80.leveldb.Options
8+
import org.iq80.leveldb.impl.Iq80DBFactory
9+
import org.junit.Assert.assertFalse
10+
import org.junit.Assert.assertTrue
11+
import org.junit.Test
12+
import org.spongycastle.util.encoders.Hex
13+
import java.io.File
14+
import java.util.*
15+
import kotlin.test.assertEquals
16+
import kotlin.test.assertNull
17+
18+
class TrieTest {
19+
20+
@Test fun testTrie() {
21+
val trie1 = Trie<Int>()
22+
23+
trie1.put("hello", 342)
24+
println(Hex.toHexString(trie1.root?.hash()))
25+
26+
trie1.put("message", 432)
27+
println(Hex.toHexString(trie1.root?.hash()))
28+
29+
trie1.put("message2", 456)
30+
println(Hex.toHexString(trie1.root?.hash()))
31+
32+
trie1.put("message3", 555)
33+
println(Hex.toHexString(trie1.root?.hash()))
34+
35+
trie1.delete("message2")
36+
println(Hex.toHexString(trie1.root?.hash()))
37+
38+
assertEquals(trie1.get("hello"), 342)
39+
assertNull(trie1.get("helo"))
40+
}
41+
42+
@Test fun readDb() {
43+
val db: DB
44+
val options = Options()
45+
options.createIfMissing(false)
46+
47+
val factory = Iq80DBFactory.factory
48+
db = factory.open(File("/Users/qikh/github/understanding_ethereum_trie/triedb"), options)
49+
50+
db.iterator().use { iterator ->
51+
val result = HashSet<ByteArray>()
52+
iterator.seekToFirst()
53+
while (iterator.hasNext()) {
54+
result.add(iterator.peekNext().key)
55+
iterator.next()
56+
}
57+
58+
result.forEach { println(it) }
59+
}
60+
}
61+
}

0 commit comments

Comments
 (0)
Please sign in to comment.