Skip to content

Commit c364a49

Browse files
wzx140王子旋
authored and committed
[FLINK-38432][python] Fix missing finish/reset for keyWriter/valueWriter in MapWriter
Signed-off-by: 王子旋 <wangzixuan5@xiaohongshu.com>
1 parent 04e6437 commit c364a49

File tree

2 files changed

+149
-0
lines changed

2 files changed

+149
-0
lines changed

flink-python/src/main/java/org/apache/flink/table/runtime/arrow/writers/MapWriter.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,20 @@ public void doWrite(T in, int ordinal) {
8484
}
8585
}
8686

87+
@Override
88+
public void finish() {
89+
super.finish();
90+
keyWriter.finish();
91+
valueWriter.finish();
92+
}
93+
94+
@Override
95+
public void reset() {
96+
super.reset();
97+
keyWriter.reset();
98+
valueWriter.reset();
99+
}
100+
87101
// ------------------------------------------------------------------------------------------
88102

89103
/** {@link MapWriter} for {@link RowData} input. */
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.flink.table.runtime.arrow.writers;
20+
21+
import org.apache.flink.table.data.GenericMapData;
22+
import org.apache.flink.table.data.GenericRowData;
23+
import org.apache.flink.table.data.MapData;
24+
import org.apache.flink.table.data.RowData;
25+
import org.apache.flink.table.runtime.arrow.ArrowUtils;
26+
27+
import org.apache.flink.shaded.guava33.com.google.common.collect.ImmutableMap;
28+
29+
import org.apache.arrow.memory.BufferAllocator;
30+
import org.apache.arrow.vector.IntVector;
31+
import org.apache.arrow.vector.complex.MapVector;
32+
import org.apache.arrow.vector.complex.StructVector;
33+
import org.apache.arrow.vector.types.pojo.ArrowType;
34+
import org.apache.arrow.vector.types.pojo.Field;
35+
import org.apache.arrow.vector.types.pojo.FieldType;
36+
import org.junit.jupiter.api.AfterEach;
37+
import org.junit.jupiter.api.Assertions;
38+
import org.junit.jupiter.api.BeforeEach;
39+
import org.junit.jupiter.api.Test;
40+
41+
import java.util.Arrays;
42+
import java.util.HashMap;
43+
import java.util.List;
44+
import java.util.Map;
45+
46+
import static org.junit.jupiter.api.Assertions.assertEquals;
47+
import static org.junit.jupiter.api.Assertions.assertFalse;
48+
49+
/** Tests for {@link MapWriter}. */
50+
class MapWriterTest {
51+
52+
private BufferAllocator allocator;
53+
54+
@BeforeEach
55+
void setUp() {
56+
allocator = ArrowUtils.getRootAllocator().newChildAllocator("test", 0, Long.MAX_VALUE);
57+
}
58+
59+
@AfterEach
60+
void tearDown() {
61+
if (allocator != null) {
62+
allocator.close();
63+
}
64+
}
65+
66+
@Test
67+
void testMapWriterMultipleWritesAndReset() {
68+
// Create Map<int, int> vector
69+
Field keyField = new Field("key", FieldType.notNullable(new ArrowType.Int(32, true)), null);
70+
Field valueField =
71+
new Field("value", FieldType.nullable(new ArrowType.Int(32, true)), null);
72+
FieldType mapType = new FieldType(true, new ArrowType.Map(false), null);
73+
Field mapField =
74+
new Field(
75+
"myMap",
76+
mapType,
77+
List.of(
78+
new Field(
79+
"entries",
80+
FieldType.notNullable(ArrowType.Struct.INSTANCE),
81+
Arrays.asList(keyField, valueField))));
82+
// Create arrow writer
83+
try (MapVector mapVector = (MapVector) mapField.createVector(allocator)) {
84+
StructVector structVector = (StructVector) mapVector.getDataVector();
85+
IntVector keyVector = (IntVector) structVector.getChild(MapVector.KEY_NAME);
86+
IntVector valueVector = (IntVector) structVector.getChild(MapVector.VALUE_NAME);
87+
MapWriter<RowData> mapWriter =
88+
MapWriter.forRow(
89+
mapVector, IntWriter.forArray(keyVector), IntWriter.forArray(valueVector));
90+
// Write once
91+
mapWriter.write(createRowData(ImmutableMap.of(1, 1, 2, 2)), 0);
92+
mapWriter.write(createRowData(ImmutableMap.of(1, 1)), 0);
93+
mapWriter.finish();
94+
assertEquals(2, mapVector.getValueCount());
95+
checkMapVector(mapVector, keyVector, valueVector, 0, ImmutableMap.of(1, 1, 2, 2));
96+
checkMapVector(mapVector, keyVector, valueVector, 1, ImmutableMap.of(1, 1));
97+
// Reset and write again
98+
mapWriter.reset();
99+
mapWriter.write(createRowData(ImmutableMap.of(1, 1, 2, 2)), 0);
100+
mapWriter.write(createRowData(ImmutableMap.of(1, 1)), 0);
101+
mapWriter.finish();
102+
assertEquals(2, mapVector.getValueCount());
103+
checkMapVector(mapVector, keyVector, valueVector, 0, ImmutableMap.of(1, 1, 2, 2));
104+
checkMapVector(mapVector, keyVector, valueVector, 1, ImmutableMap.of(1, 1));
105+
}
106+
}
107+
108+
private RowData createRowData(Map<Integer, Integer> map) {
109+
MapData mapData = new GenericMapData(map);
110+
return GenericRowData.of(mapData);
111+
}
112+
113+
private void checkMapVector(
114+
MapVector mapVector,
115+
IntVector keyVector,
116+
IntVector valueVector,
117+
int rowIndex,
118+
Map<Integer, Integer> expected) {
119+
if (mapVector.isNull(rowIndex)) {
120+
Assertions.assertNull(expected);
121+
}
122+
123+
int start = mapVector.getOffsetBuffer().getInt(rowIndex * MapVector.OFFSET_WIDTH);
124+
int end = mapVector.getOffsetBuffer().getInt((rowIndex + 1) * MapVector.OFFSET_WIDTH);
125+
126+
Map<Integer, Integer> result = new HashMap<>();
127+
for (int i = start; i < end; i++) {
128+
assertFalse(keyVector.isNull(i));
129+
int key = keyVector.get(i);
130+
Integer value = valueVector.isNull(i) ? null : valueVector.get(i);
131+
result.put(key, value);
132+
}
133+
assertEquals(expected, result);
134+
}
135+
}

0 commit comments

Comments
 (0)