1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152
| import java.util.Arrays;
/**
 * A generic segment tree supporting inclusive range queries and point updates over a fixed-size
 * sequence of elements.
 *
 * <p>Segments are combined with the supplied {@code Merger<E>}. The merge operation is assumed to
 * be associative (required for range results to be independent of how the range is split).
 *
 * <p>Nodes use an implicit array layout: the root is at index 0 and the children of node {@code i}
 * are at {@code 2i + 1} and {@code 2i + 2}. The tree array has {@code 4n} slots for {@code n}
 * elements; slots not used by the layout stay {@code null}.
 *
 * @param <E> element type
 */
public class SegmentTree<E> {
  /** Copy of the input elements; backs {@link #get(int)}. */
  private final E[] data;
  /** Implicit binary tree holding the merged value of each segment. */
  private final E[] tree;
  /** Operation used to combine the values of two adjacent segments. */
  private final Merger<E> merger;

  /**
   * Builds a segment tree over a copy of {@code arr}.
   *
   * @param arr source elements; copied, so later changes to {@code arr} are not reflected here. May
   *     be empty, in which case every index-based operation throws.
   * @param merger operation used to combine adjacent segments
   * @throws NullPointerException if {@code arr} is null
   */
  @SuppressWarnings("unchecked") // E is erased; the Object[] storage never escapes as an E[].
  public SegmentTree(E[] arr, Merger<E> merger) {
    this.merger = merger;
    // Store into Object[] rather than arr's runtime component type so set() can hold any E
    // without an ArrayStoreException when arr is a covariant subtype array.
    data = (E[]) new Object[arr.length];
    System.arraycopy(arr, 0, data, 0, arr.length);
    tree = (E[]) new Object[4 * arr.length];
    // An empty input has no segments; recursing on [0, -1] would index past the empty tree.
    if (arr.length > 0) {
      buildSegmentTree(0, 0, arr.length - 1);
    }
  }

  /**
   * Recursively fills {@code tree[treeIndex]} with the merged value of {@code data[l..r]}.
   *
   * @param treeIndex position of the current node in {@link #tree}
   * @param l left bound of the node's segment (inclusive)
   * @param r right bound of the node's segment (inclusive)
   */
  private void buildSegmentTree(int treeIndex, int l, int r) {
    if (l == r) {
      tree[treeIndex] = data[l];
      return;
    }
    int mid = l + (r - l) / 2; // avoids int overflow of (l + r)
    buildSegmentTree(leftChild(treeIndex), l, mid);
    buildSegmentTree(rightChild(treeIndex), mid + 1, r);
    tree[treeIndex] = merger.merge(tree[leftChild(treeIndex)], tree[rightChild(treeIndex)]);
  }

  /**
   * Returns the merged value of the elements in {@code [queryL, queryR]}.
   *
   * @param queryL left bound of the range (inclusive)
   * @param queryR right bound of the range (inclusive)
   * @return the merge of all elements in the range
   * @throws IllegalArgumentException if either bound is outside {@code [0, size)} or
   *     {@code queryL > queryR}
   */
  public E query(int queryL, int queryR) {
    if (queryL < 0
        || queryL >= data.length
        || queryR < 0
        || queryR >= data.length
        || queryL > queryR) {
      throw new IllegalArgumentException("索引不正确");
    }
    return query(0, 0, data.length - 1, queryL, queryR);
  }

  /**
   * Returns the merged value of {@code [queryL, queryR]} within the node covering {@code [l, r]}.
   * Callers guarantee that {@code [queryL, queryR]} lies inside {@code [l, r]}.
   */
  private E query(int treeIndex, int l, int r, int queryL, int queryR) {
    if (l == queryL && r == queryR) {
      return tree[treeIndex];
    }
    int mid = l + (r - l) / 2;
    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    if (queryL >= mid + 1) {
      // Range lies entirely in the right half.
      return query(rightTreeIndex, mid + 1, r, queryL, queryR);
    } else if (queryR <= mid) {
      // Range lies entirely in the left half.
      return query(leftTreeIndex, l, mid, queryL, queryR);
    }
    // Range straddles mid: split at the boundary and merge the two partial results.
    E leftResult = query(leftTreeIndex, l, mid, queryL, mid);
    E rightResult = query(rightTreeIndex, mid + 1, r, mid + 1, queryR);
    return merger.merge(leftResult, rightResult);
  }

  /**
   * Replaces the element at {@code index} and updates every segment that contains it.
   *
   * @param index position to update
   * @param e new value
   * @throws IllegalArgumentException if {@code index} is outside {@code [0, size)}
   */
  public void set(int index, E e) {
    if (index < 0 || index >= data.length) {
      throw new IllegalArgumentException("索引不正确");
    }
    data[index] = e;
    set(0, 0, data.length - 1, index, e);
  }

  /** Writes {@code e} at leaf {@code index} under the node covering {@code [l, r]}, re-merging upward. */
  private void set(int treeIndex, int l, int r, int index, E e) {
    if (l == r) {
      tree[treeIndex] = e;
      return;
    }
    int mid = l + (r - l) / 2;
    int leftTreeIndex = leftChild(treeIndex);
    int rightTreeIndex = rightChild(treeIndex);
    if (index >= mid + 1) {
      set(rightTreeIndex, mid + 1, r, index, e);
    } else {
      set(leftTreeIndex, l, mid, index, e);
    }
    // Recompute this node from its (possibly updated) children.
    tree[treeIndex] = merger.merge(tree[leftTreeIndex], tree[rightTreeIndex]);
  }

  /**
   * Returns the element currently stored at {@code index}.
   *
   * @param index position to read
   * @return the element at {@code index}
   * @throws IllegalArgumentException if {@code index} is outside {@code [0, size)}
   */
  public E get(int index) {
    if (index < 0 || index >= data.length) {
      throw new IllegalArgumentException("索引不正确");
    }
    return data[index];
  }

  /** Index of the left child of node {@code index} in the implicit tree layout. */
  private int leftChild(int index) {
    return 2 * index + 1;
  }

  /** Index of the right child of node {@code index} in the implicit tree layout. */
  private int rightChild(int index) {
    return 2 * index + 2;
  }

  /** Returns the raw tree array, including {@code null} padding slots. */
  @Override
  public String toString() {
    return "SegmentTree{" + "tree=" + Arrays.toString(tree) + '}';
  }
}
|