Here are examples of the Java API org.apache.flink.types.Row, taken from open source projects. By voting up you can indicate which examples are most useful and appropriate.
1058 Examples
19
Source : JsonRowSerializationSchemaTest.java
with Apache License 2.0
from ljygz
with Apache License 2.0
from ljygz
// --------------------------------------------------------------------------------------------
/**
 * Round-trips a row through a JSON serialization schema and a JSON deserialization schema
 * that are both built from the same type information.
 *
 * @param rowSchema type information handed to both schema constructors
 * @param row the row to serialize
 * @return the row produced by deserializing the serialized bytes
 * @throws IOException propagated from the schema calls
 */
private Row serializeAndDeserialize(TypeInformation<Row> rowSchema, Row row) throws IOException {
    // Build both schemas up front, then chain serialize -> deserialize.
    final JsonRowSerializationSchema toJson = new JsonRowSerializationSchema(rowSchema);
    final JsonRowDeserializationSchema fromJson = new JsonRowDeserializationSchema(rowSchema);
    return fromJson.deserialize(toJson.serialize(row));
}
19
Source : AvroRowDeSerializationSchemaTest.java
with Apache License 2.0
from ljygz
with Apache License 2.0
from ljygz
/**
 * Checks that both Avro schemas can be Java-serialized and that the deserialized copies
 * still round-trip a row.
 *
 * @param ser serialization schema to copy via Java serialization
 * @param deser deserialization schema to copy via Java serialization
 * @param data row that is serialized with the copied serializer and expected back unchanged
 * @throws Exception propagated from the serialization utilities or the schemas
 */
private void testSerializability(AvroRowSerializationSchema ser, AvroRowDeserializationSchema deser, Row data) throws Exception {
    final byte[] serBytes = InstantiationUtil.serializeObject(ser);
    final byte[] deserBytes = InstantiationUtil.serializeObject(deser);
    final ClassLoader contextClassLoader = Thread.currentThread().getContextClassLoader();
    final AvroRowSerializationSchema serCopy = InstantiationUtil.deserializeObject(serBytes, contextClassLoader);
    final AvroRowDeserializationSchema deserCopy = InstantiationUtil.deserializeObject(deserBytes, contextClassLoader);
    final byte[] bytes = serCopy.serialize(data);
    // The copy is invoked three times on the same bytes; presumably this exercises
    // repeated use of one schema instance (TODO confirm intent).
    deserCopy.deserialize(bytes);
    deserCopy.deserialize(bytes);
    final Row actual = deserCopy.deserialize(bytes);
    assertEquals(data, actual);
}
19
Source : RowComparatorWithManyFieldsTests.java
with Apache License 2.0
from ljygz
with Apache License 2.0
from ljygz
/**
* Tests {@link RowComparator} for wide rows.
*/
public clreplaced RowComparatorWithManyFieldsTests extends ComparatorTestBase<Row> {
private static final int numberOfFields = 10;
private static RowTypeInfo typeInfo;
private static final Row[] data = new Row[] { createRow(null, "b0", "c0", "d0", "e0", "f0", "g0", "h0", "i0", "j0"), createRow("a1", "b1", "c1", "d1", "e1", "f1", "g1", "h1", "i1", "j1"), createRow("a2", "b2", "c2", "d2", "e2", "f2", "g2", "h2", "i2", "j2"), createRow("a3", "b3", "c3", "d3", "e3", "f3", "g3", "h3", "i3", "j3") };
@BeforeClreplaced
public static void setUp() throws Exception {
TypeInformation<?>[] fieldTypes = new TypeInformation[numberOfFields];
for (int i = 0; i < numberOfFields; i++) {
fieldTypes[i] = BasicTypeInfo.STRING_TYPE_INFO;
}
typeInfo = new RowTypeInfo(fieldTypes);
}
@Override
protected void deepEquals(String message, Row should, Row is) {
int arity = should.getArity();
replacedertEquals(message, arity, is.getArity());
for (int i = 0; i < arity; i++) {
Object copiedValue = should.getField(i);
Object element = is.getField(i);
replacedertEquals(message, element, copiedValue);
}
}
@Override
protected TypeComparator<Row> createComparator(boolean ascending) {
return typeInfo.createComparator(new int[] { 0 }, new boolean[] { ascending }, 0, new ExecutionConfig());
}
@Override
protected TypeSerializer<Row> createSerializer() {
return typeInfo.createSerializer(new ExecutionConfig());
}
@Override
protected Row[] getSortedTestData() {
return data;
}
@Override
protected boolean supportsNullKeys() {
return true;
}
private static Row createRow(Object... values) {
checkNotNull(values);
checkArgument(values.length == numberOfFields);
Row row = new Row(numberOfFields);
for (int i = 0; i < values.length; i++) {
row.setField(i, values[i]);
}
return row;
}
}
19
Source : RowComparatorTest.java
with Apache License 2.0
from ljygz
with Apache License 2.0
from ljygz
public clreplaced RowComparatorTest extends ComparatorTestBase<Row> {
private static final RowTypeInfo typeInfo = new RowTypeInfo(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.DOUBLE_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO, new TupleTypeInfo<Tuple3<Integer, Boolean, Short>>(BasicTypeInfo.INT_TYPE_INFO, BasicTypeInfo.BOOLEAN_TYPE_INFO, BasicTypeInfo.SHORT_TYPE_INFO), TypeExtractor.createTypeInfo(MyPojo.clreplaced));
private static MyPojo testPojo1 = new MyPojo();
private static MyPojo testPojo2 = new MyPojo();
private static MyPojo testPojo3 = new MyPojo();
private static final Row[] data = new Row[] { createRow(null, null, null, null, null), createRow(0, null, null, null, null), createRow(0, 0.0, null, null, null), createRow(0, 0.0, "a", null, null), createRow(1, 0.0, "a", null, null), createRow(1, 1.0, "a", null, null), createRow(1, 1.0, "b", null, null), createRow(1, 1.0, "b", new Tuple3<>(1, false, (short) 2), null), createRow(1, 1.0, "b", new Tuple3<>(2, false, (short) 2), null), createRow(1, 1.0, "b", new Tuple3<>(2, true, (short) 2), null), createRow(1, 1.0, "b", new Tuple3<>(2, true, (short) 3), null), createRow(1, 1.0, "b", new Tuple3<>(2, true, (short) 3), testPojo1), createRow(1, 1.0, "b", new Tuple3<>(2, true, (short) 3), testPojo2), createRow(1, 1.0, "b", new Tuple3<>(2, true, (short) 3), testPojo3) };
@BeforeClreplaced
public static void init() {
// TODO we cannot test null here as PojoComparator has no support for null keys
testPojo1.name = "";
testPojo2.name = "Test1";
testPojo3.name = "Test2";
}
@Override
protected void deepEquals(String message, Row should, Row is) {
int arity = should.getArity();
replacedertEquals(message, arity, is.getArity());
for (int i = 0; i < arity; i++) {
Object copiedValue = should.getField(i);
Object element = is.getField(i);
replacedertEquals(message, element, copiedValue);
}
}
@Override
protected TypeComparator<Row> createComparator(boolean ascending) {
return typeInfo.createComparator(new int[] { 0, 1, 2, 3, 4, 5, 6 }, new boolean[] { ascending, ascending, ascending, ascending, ascending, ascending, ascending }, 0, new ExecutionConfig());
}
@Override
protected TypeSerializer<Row> createSerializer() {
return typeInfo.createSerializer(new ExecutionConfig());
}
@Override
protected Row[] getSortedTestData() {
return data;
}
@Override
protected boolean supportsNullKeys() {
return true;
}
private static Row createRow(Object f0, Object f1, Object f2, Object f3, Object f4) {
Row row = new Row(5);
row.setField(0, f0);
row.setField(1, f1);
row.setField(2, f2);
row.setField(3, f3);
row.setField(4, f4);
return row;
}
public static clreplaced MyPojo implements Serializable, Comparable<MyPojo> {
// we cannot use null because the PojoComparator does not support null properly
public String name = "";
@Override
public int compareTo(MyPojo o) {
if (name == null && o.name == null) {
return 0;
} else if (name == null) {
return -1;
} else if (o.name == null) {
return 1;
} else {
return name.compareTo(o.name);
}
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClreplaced() != o.getClreplaced()) {
return false;
}
MyPojo myPojo = (MyPojo) o;
return name != null ? name.equals(myPojo.name) : myPojo.name == null;
}
}
}
19
Source : OrcBatchReader.java
with Apache License 2.0
from ljygz
with Apache License 2.0
from ljygz
/**
 * Writes one and the same value into the first {@code childCount} slots of a column.
 *
 * @param vals The target array; holds either plain objects or Rows.
 * @param fieldIdx Index of the Row field to set when {@code vals} is an array of Row.
 *                 Pass -1 to store the value directly into {@code vals} instead.
 * @param repeatingValue The value to store; the same reference is used for every slot.
 * @param childCount The number of slots (or rows) that receive the value.
 */
private static void fillColumnWithRepeatingValue(Object[] vals, int fieldIdx, Object repeatingValue, int childCount) {
    if (fieldIdx != -1) {
        // Row mode: set the value as a field of each row
        final Row[] rowBatch = (Row[]) vals;
        for (int pos = 0; pos < childCount; pos++) {
            rowBatch[pos].setField(fieldIdx, repeatingValue);
        }
        return;
    }
    // object mode: the array slots themselves hold the values
    Arrays.fill(vals, 0, childCount, repeatingValue);
}
19
Source : OrcBatchReader.java
with Apache License 2.0
from ljygz
with Apache License 2.0
from ljygz
/**
 * Converts a double column vector without consulting its null flags.
 *
 * @param vals target array of objects or Rows
 * @param fieldIdx Row field to set, or -1 to store directly into {@code vals}
 * @param vector source column vector
 * @param childCount number of entries to produce
 * @param reader conversion applied to each double value
 */
private static <T> void readNonNullDoubleColumn(Object[] vals, int fieldIdx, DoubleColumnVector vector, int childCount, DoubleFunction<T> reader) {
    if (vector.isRepeating) {
        // only entry 0 is read; its converted value is shared by all slots
        final T shared = reader.apply(vector.vector[0]);
        fillColumnWithRepeatingValue(vals, fieldIdx, shared, childCount);
        return;
    }
    if (fieldIdx == -1) {
        // object mode
        for (int pos = 0; pos < childCount; pos++) {
            vals[pos] = reader.apply(vector.vector[pos]);
        }
        return;
    }
    // Row mode
    final Row[] rowBatch = (Row[]) vals;
    for (int pos = 0; pos < childCount; pos++) {
        rowBatch[pos].setField(fieldIdx, reader.apply(vector.vector[pos]));
    }
}
19
Source : OrcBatchReader.java
with Apache License 2.0
from ljygz
with Apache License 2.0
from ljygz
/**
 * Converts a double column vector, mapping entries flagged as null to {@code null}.
 *
 * @param vals target array of objects or Rows
 * @param fieldIdx Row field to set, or -1 to store directly into {@code vals}
 * @param vector source column vector
 * @param childCount number of entries to produce
 * @param reader conversion applied to each non-null double value
 */
private static <T> void readDoubleColumn(Object[] vals, int fieldIdx, DoubleColumnVector vector, int childCount, DoubleFunction<T> reader) {
    if (vector.isRepeating) {
        // a repeating vector is fully described by entry 0
        if (!vector.isNull[0]) {
            // repeating non-null value: forward to the non-null reader
            readNonNullDoubleColumn(vals, fieldIdx, vector, childCount, reader);
        } else {
            // repeating null: every slot becomes null
            fillColumnWithRepeatingValue(vals, fieldIdx, null, childCount);
        }
        return;
    }
    final boolean[] nullFlags = vector.isNull;
    if (fieldIdx == -1) {
        // object mode
        for (int pos = 0; pos < childCount; pos++) {
            vals[pos] = nullFlags[pos] ? null : reader.apply(vector.vector[pos]);
        }
        return;
    }
    // Row mode
    final Row[] rowBatch = (Row[]) vals;
    for (int pos = 0; pos < childCount; pos++) {
        rowBatch[pos].setField(fieldIdx, nullFlags[pos] ? null : reader.apply(vector.vector[pos]));
    }
}
19
Source : OrcBatchReader.java
with Apache License 2.0
from ljygz
with Apache License 2.0
from ljygz
/**
 * Converts a bytes column vector to strings without consulting its null flags.
 *
 * @param vals target array of objects or Rows
 * @param fieldIdx Row field to set, or -1 to store directly into {@code vals}
 * @param bytes source column vector
 * @param childCount number of entries to produce
 */
private static void readNonNullBytesColumnreplacedtring(Object[] vals, int fieldIdx, BytesColumnVector bytes, int childCount) {
    if (bytes.isRepeating) {
        // only entry 0 is read; the resulting string is shared by all slots
        final String shared = readString(bytes.vector[0], bytes.start[0], bytes.length[0]);
        fillColumnWithRepeatingValue(vals, fieldIdx, shared, childCount);
        return;
    }
    if (fieldIdx == -1) {
        // object mode
        for (int pos = 0; pos < childCount; pos++) {
            vals[pos] = readString(bytes.vector[pos], bytes.start[pos], bytes.length[pos]);
        }
        return;
    }
    // Row mode
    final Row[] rowBatch = (Row[]) vals;
    for (int pos = 0; pos < childCount; pos++) {
        rowBatch[pos].setField(fieldIdx, readString(bytes.vector[pos], bytes.start[pos], bytes.length[pos]));
    }
}
19
Source : OrcBatchReader.java
with Apache License 2.0
from ljygz
with Apache License 2.0
from ljygz
/**
 * Converts a timestamp column vector without consulting its null flags.
 *
 * <p>For a repeating vector, entry 0 is converted anew for every slot instead of sharing one
 * result object, so that no two slots reference the same mutable value.
 *
 * @param vals target array of objects or Rows
 * @param fieldIdx Row field to set, or -1 to store directly into {@code vals}
 * @param vector source column vector
 * @param childCount number of entries to produce
 */
private static void readNonNullTimestampColumn(Object[] vals, int fieldIdx, TimestampColumnVector vector, int childCount) {
    // a repeating vector always reads from entry 0, otherwise from the current position
    final boolean repeating = vector.isRepeating;
    if (fieldIdx == -1) {
        // object mode
        for (int pos = 0; pos < childCount; pos++) {
            final int src = repeating ? 0 : pos;
            vals[pos] = readTimestamp(vector.time[src], vector.nanos[src]);
        }
        return;
    }
    // Row mode
    final Row[] rowBatch = (Row[]) vals;
    for (int pos = 0; pos < childCount; pos++) {
        final int src = repeating ? 0 : pos;
        rowBatch[pos].setField(fieldIdx, readTimestamp(vector.time[src], vector.nanos[src]));
    }
}
19
Source : OrcBatchReader.java
with Apache License 2.0
from ljygz
with Apache License 2.0
from ljygz
/**
 * Converts a bytes column vector to strings, mapping entries flagged as null to {@code null}.
 *
 * @param vals target array of objects or Rows
 * @param fieldIdx Row field to set, or -1 to store directly into {@code vals}
 * @param bytes source column vector
 * @param childCount number of entries to produce
 */
private static void readBytesColumnreplacedtring(Object[] vals, int fieldIdx, BytesColumnVector bytes, int childCount) {
    if (bytes.isRepeating) {
        // a repeating vector is fully described by entry 0
        if (!bytes.isNull[0]) {
            // repeating non-null value: forward to the non-null reader
            readNonNullBytesColumnreplacedtring(vals, fieldIdx, bytes, childCount);
        } else {
            // repeating null: every slot becomes null
            fillColumnWithRepeatingValue(vals, fieldIdx, null, childCount);
        }
        return;
    }
    final boolean[] nullFlags = bytes.isNull;
    if (fieldIdx == -1) {
        // object mode
        for (int pos = 0; pos < childCount; pos++) {
            vals[pos] = nullFlags[pos] ? null : readString(bytes.vector[pos], bytes.start[pos], bytes.length[pos]);
        }
        return;
    }
    // Row mode
    final Row[] rowBatch = (Row[]) vals;
    for (int pos = 0; pos < childCount; pos++) {
        rowBatch[pos].setField(fieldIdx, nullFlags[pos] ? null : readString(bytes.vector[pos], bytes.start[pos], bytes.length[pos]));
    }
}
19
Source : OrcBatchReader.java
with Apache License 2.0
from ljygz
with Apache License 2.0
from ljygz
/**
 * Converts a bytes column vector to binary values, mapping entries flagged as null to
 * {@code null}.
 *
 * @param vals target array of objects or Rows
 * @param fieldIdx Row field to set, or -1 to store directly into {@code vals}
 * @param bytes source column vector
 * @param childCount number of entries to produce
 */
private static void readBytesColumnAsBinary(Object[] vals, int fieldIdx, BytesColumnVector bytes, int childCount) {
    if (bytes.isRepeating) {
        // a repeating vector is fully described by entry 0
        if (!bytes.isNull[0]) {
            // repeating non-null value: forward to the non-null reader
            readNonNullBytesColumnAsBinary(vals, fieldIdx, bytes, childCount);
        } else {
            // repeating null: every slot becomes null
            fillColumnWithRepeatingValue(vals, fieldIdx, null, childCount);
        }
        return;
    }
    final boolean[] nullFlags = bytes.isNull;
    if (fieldIdx == -1) {
        // object mode
        for (int pos = 0; pos < childCount; pos++) {
            vals[pos] = nullFlags[pos] ? null : readBinary(bytes.vector[pos], bytes.start[pos], bytes.length[pos]);
        }
        return;
    }
    // Row mode
    final Row[] rowBatch = (Row[]) vals;
    for (int pos = 0; pos < childCount; pos++) {
        rowBatch[pos].setField(fieldIdx, nullFlags[pos] ? null : readBinary(bytes.vector[pos], bytes.start[pos], bytes.length[pos]));
    }
}
19
Source : OrcBatchReader.java
with Apache License 2.0
from ljygz
with Apache License 2.0
from ljygz
/**
 * Converts a bytes column vector to binary values without consulting its null flags.
 *
 * <p>For a repeating vector, entry 0 is read anew for every slot instead of sharing one
 * result object, so that no two slots reference the same mutable value.
 *
 * @param vals target array of objects or Rows
 * @param fieldIdx Row field to set, or -1 to store directly into {@code vals}
 * @param bytes source column vector
 * @param childCount number of entries to produce
 */
private static void readNonNullBytesColumnAsBinary(Object[] vals, int fieldIdx, BytesColumnVector bytes, int childCount) {
    // a repeating vector always reads from entry 0, otherwise from the current position
    final boolean repeating = bytes.isRepeating;
    if (fieldIdx == -1) {
        // object mode
        for (int pos = 0; pos < childCount; pos++) {
            final int src = repeating ? 0 : pos;
            vals[pos] = readBinary(bytes.vector[src], bytes.start[src], bytes.length[src]);
        }
        return;
    }
    // Row mode
    final Row[] rowBatch = (Row[]) vals;
    for (int pos = 0; pos < childCount; pos++) {
        final int src = repeating ? 0 : pos;
        rowBatch[pos].setField(fieldIdx, readBinary(bytes.vector[src], bytes.start[src], bytes.length[src]));
    }
}
19
Source : ReplicateRows.java
with Apache License 2.0
from flink-tpc-ds
with Apache License 2.0
from flink-tpc-ds
/**
* Replicate the row N times. N is specified as the first argument to the function.
* This is an internal function solely used by optimizer to rewrite EXCEPT ALL AND
* INTERSECT ALL queries.
*/
public clreplaced ReplicateRows extends TableFunction<Row> {
private static final long serialVersionUID = 1L;
private final TypeInformation[] fieldTypes;
private transient Row reuseRow;
public ReplicateRows(TypeInformation[] fieldTypes) {
this.fieldTypes = fieldTypes;
}
public void eval(Object... inputs) {
checkArgument(inputs.length == fieldTypes.length + 1);
long numRows = (long) inputs[0];
if (reuseRow == null) {
reuseRow = new Row(fieldTypes.length);
}
for (int i = 0; i < fieldTypes.length; i++) {
reuseRow.setField(i, inputs[i + 1]);
}
for (int i = 0; i < numRows; i++) {
collect(reuseRow);
}
}
@Override
public TypeInformation<Row> getResultType() {
return new RowTypeInfo(fieldTypes);
}
@Override
public TypeInformation<?>[] getParameterTypes(Clreplaced<?>[] signature) {
TypeInformation[] paraTypes = new TypeInformation[1 + fieldTypes.length];
paraTypes[0] = Types.LONG;
System.arraycopy(fieldTypes, 0, paraTypes, 1, fieldTypes.length);
return paraTypes;
}
}
19
Source : JDBCUpsertOutputFormatTest.java
with Apache License 2.0
from flink-tpc-ds
with Apache License 2.0
from flink-tpc-ds
/**
 * Convenience overload: delegates to the four-argument {@code check}, supplying
 * {@code DB_URL}, {@code OUTPUT_TABLE} and {@code fieldNames}.
 *
 * @param rows rows handed on to the four-argument overload
 * @throws SQLException propagated from the delegate
 */
private void check(Row[] rows) throws SQLException {
check(rows, DB_URL, OUTPUT_TABLE, fieldNames);
}
19
Source : RestapiOutputFormat.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Logs the field at {@code index} as dirty data when that index lies inside the row.
 *
 * @param e the exception that triggered the report (not used by this method)
 * @param index position of the offending field
 * @param row the row being written
 */
private void requestErrorMessage(Exception e, int index, Row row) {
    if (index >= row.getArity()) {
        // index does not point at a field of this row: nothing to report
        return;
    }
    // NOTE(review): the message returned by this call is discarded
    recordConvertDetailErrorMessage(index, row);
    LOG.warn("添加脏数据:" + row.getField(index));
}
19
Source : RedisOutputFormat.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Joins the result of {@code getValues(row)} into one string, using
 * {@code valueFieldDelimiter} as the separator.
 *
 * @param row the row whose values are concatenated
 * @return the joined string
 */
private String concatValues(Row row) {
return StringUtils.join(getValues(row), valueFieldDelimiter);
}
19
Source : JdbcOutputFormat.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Translates a failure raised while writing a row into the exception reported to the caller.
 * This method never returns normally.
 *
 * @param e the failure raised while writing
 * @param index position of the field being processed when the failure happened
 * @param row the row being written
 * @throws RuntimeException if {@code e} is a {@link SQLException} whose message contains
 *         {@code CONN_CLOSE_ERROR_MSG}
 * @throws WriteRecordException in every other case; it carries the field index and the row
 *         when {@code index} lies inside the row
 */
protected void processWriteException(Exception e, int index, Row row) throws WriteRecordException {
    if (e instanceof SQLException) {
        // getMessage() may be null; guard it so that the original failure is not
        // replaced by a NullPointerException
        String sqlMessage = e.getMessage();
        if (sqlMessage != null && sqlMessage.contains(CONN_CLOSE_ERROR_MSG)) {
            throw new RuntimeException("Connection maybe closed", e);
        }
    }
    if (index < row.getArity()) {
        String message = recordConvertDetailErrorMessage(index, row);
        LOG.error(message, e);
        throw new WriteRecordException(message, e, index, row);
    }
    throw new WriteRecordException(e.getMessage(), e);
}
19
Source : KingbaseOutputFormat.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Writes a single row and commits it immediately.
 *
 * <p>If the write fails with a {@link WriteRecordException}, the connection is rolled back
 * and the exception is rethrown. A failing commit or rollback is escalated as a
 * {@link RuntimeException}.
 *
 * @param row the row to write
 * @throws WriteRecordException if the write failed and the rollback succeeded
 */
private void writeSingleRecordCommit(Row row) throws WriteRecordException {
    try {
        super.writeSingleRecordInternal(row);
        try {
            dbConn.commit();
        } catch (Exception e) {
            // a failed commit ends the task immediately
            throw new RuntimeException(e);
        }
    } catch (WriteRecordException e) {
        try {
            dbConn.rollback();
        } catch (Exception e1) {
            // a failed rollback ends the task immediately; keep the write failure as the
            // cause and attach the rollback failure so that neither is lost
            RuntimeException failure = new RuntimeException(e);
            failure.addSuppressed(e1);
            throw failure;
        }
        throw e;
    }
}
19
Source : HiveOutputFormat.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * No-op override: this implementation does nothing with the row.
 * NOTE(review): presumably the actual writing happens elsewhere in this class - confirm.
 *
 * @param row the row to write (ignored here)
 * @throws WriteRecordException declared by the overridden method; never thrown by this body
 */
@Override
protected void writeSingleRecordInternal(Row row) throws WriteRecordException {
}
19
Source : ErrorLimiter.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Stores the given row in the {@code errorData} field.
 *
 * @param errorData the row to store
 */
public void setErrorData(Row errorData) {
this.errorData = errorData;
}
19
Source : BaseRichOutputFormat.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Returns a copy of the row without its last field.
 *
 * @param row the source row
 * @return a new row holding all fields of {@code row} except the last one
 */
private Row setChannelInfo(Row row) {
    final Row trimmed = new Row(row.getArity() - 1);
    final int keptFields = trimmed.getArity();
    int pos = 0;
    while (pos < keptFields) {
        trimmed.setField(pos, row.getField(pos));
        pos++;
    }
    return trimmed;
}
19
Source : BaseRichOutputFormat.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Builds the error message used when converting a field of a row fails.
 *
 * @param pos index of the field that could not be converted
 * @param row the row being written
 * @return a message naming the runtime class of this object, the field index and the row
 */
protected String recordConvertDetailErrorMessage(int pos, Row row) {
    return getClass().getName() + " WriteRecord error: when converting field[" + pos + "] in Row(" + row + ")";
}
19
Source : BaseFileOutputFormat.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
* @author jiangbo
* @date 2019/8/28
*/
public abstract clreplaced BaseFileOutputFormat extends BaseRichOutputFormat {
protected Row lastRow;
protected String currentBlockFileNamePrefix;
protected String currentBlockFileName;
protected long sumRowsOfBlock;
protected long rowsOfCurrentBlock;
protected long maxFileSize;
protected long flushInterval = 0;
protected static final String APPEND_MODE = "APPEND";
protected static final String DATA_SUBDIR = ".data";
protected static final String FINISHED_SUBDIR = ".finished";
protected static final String ACTION_FINISHED = ".action_finished";
protected static final String RESTART_FILE_NAME_SUFFIX = "restart";
protected static final String JOB_ID_DELIMITER = "_";
protected static final int SECOND_WAIT = 30;
protected static final String SP = "/";
protected String charsetName = "UTF-8";
protected String outputFilePath;
protected String path;
protected String fileName;
protected String tmpPath;
protected String finishedPath;
protected String actionFinishedTag;
/**
* 写入模式
*/
protected String writeMode;
/**
* 压缩方式
*/
protected String compress;
protected boolean readyCheckpoint;
protected int blockIndex = 0;
protected boolean makeDir = true;
private long nextNumForCheckDataSize = 1000;
private long lastWriteSize;
protected long lastWriteTime = System.currentTimeMillis();
@Override
protected void openInternal(int taskNumber, int numTasks) throws IOException {
initFileIndex();
initPath();
openSource();
actionBeforeWriteData();
nextBlock();
}
protected void initPath() {
if (StringUtils.isNotBlank(fileName)) {
outputFilePath = path + SP + fileName;
} else {
outputFilePath = path;
}
currentBlockFileNamePrefix = taskNumber + "." + jobId;
tmpPath = outputFilePath + SP + DATA_SUBDIR;
finishedPath = outputFilePath + SP + FINISHED_SUBDIR + SP + taskNumber;
actionFinishedTag = tmpPath + SP + ACTION_FINISHED + "_" + jobId;
LOG.info("Channel:[{}], currentBlockFileNamePrefix:[{}], tmpPath:[{}], finishedPath:[{}]", taskNumber, currentBlockFileNamePrefix, tmpPath, finishedPath);
}
protected void initFileIndex() {
if (null != formatState && formatState.getFileIndex() > -1) {
blockIndex = formatState.getFileIndex() + 1;
}
LOG.info("Start block index:{}", blockIndex);
}
protected void actionBeforeWriteData() {
if (taskNumber > 0) {
waitForActionFinishedBeforeWrite();
return;
}
checkOutputDir();
try {
// 覆盖模式并且不是从检查点恢复时先删除数据目录
boolean isCoverageData = !APPEND_MODE.equalsIgnoreCase(writeMode) && (formatState == null || formatState.getState() == null);
if (isCoverageData) {
coverageData();
}
// 处理上次任务因异常失败产生的脏数据
if (restoreConfig.isRestore() && formatState != null) {
cleanDirtyData();
}
} catch (Exception e) {
LOG.error("e = {}", ExceptionUtil.getErrorMessage(e));
throw new RuntimeException(e);
}
try {
LOG.info("Delete [.data] dir before write records");
clearTemporaryDataFiles();
} catch (Exception e) {
LOG.warn("Clean temp dir error before write records:{}", e.getMessage());
} finally {
createActionFinishedTag();
}
}
@Override
public void writeSingleRecordInternal(Row row) throws WriteRecordException {
if (restoreConfig.isRestore() && !restoreConfig.isStream()) {
if (lastRow != null) {
readyCheckpoint = !ObjectUtils.equals(lastRow.getField(restoreConfig.getRestoreColumnIndex()), row.getField(restoreConfig.getRestoreColumnIndex()));
}
}
checkSize();
writeSingleRecordToFile(row);
lastWriteTime = System.currentTimeMillis();
}
private void checkSize() {
if (numWriteCounter.getLocalValue() < nextNumForCheckDataSize) {
return;
}
if (getCurrentFileSize() > maxFileSize) {
try {
flushData();
LOG.info("Flush data by check file size");
} catch (Exception e) {
throw new RuntimeException("Flush data error", e);
}
lastWriteSize = bytesWriteCounter.getLocalValue();
}
nextNumForCheckDataSize = getNextNumForCheckDataSize();
}
private long getCurrentFileSize() {
return (long) (getDeviation() * (bytesWriteCounter.getLocalValue() - lastWriteSize));
}
private long getNextNumForCheckDataSize() {
long totalBytesWrite = bytesWriteCounter.getLocalValue();
long totalRecordWrite = numWriteCounter.getLocalValue();
float eachRecordSize = (totalBytesWrite * getDeviation()) / totalRecordWrite;
long currentFileSize = getCurrentFileSize();
long recordNum = (long) ((maxFileSize - currentFileSize) / eachRecordSize);
return totalRecordWrite + recordNum;
}
protected void nextBlock() {
if (restoreConfig.isRestore()) {
currentBlockFileName = "." + currentBlockFileNamePrefix + "." + blockIndex + getExtension();
} else {
currentBlockFileName = currentBlockFileNamePrefix + "." + blockIndex + getExtension();
}
}
@Override
public FormatState getFormatState() {
if (!restoreConfig.isRestore() || lastRow == null) {
return null;
}
if (restoreConfig.isStream() || readyCheckpoint) {
try {
flushData();
lastWriteSize = bytesWriteCounter.getLocalValue();
} catch (Exception e) {
throw new RuntimeException("Flush data error when create snapshot:", e);
}
try {
if (sumRowsOfBlock != 0) {
moveTemporaryDataFileToDirectory();
}
} catch (Exception e) {
throw new RuntimeException("Move temporary file to data directory error when create snapshot:", e);
}
snapshotWriteCounter.add(sumRowsOfBlock);
numWriteCounter.add(sumRowsOfBlock);
formatState.setNumberWrite(numWriteCounter.getLocalValue());
if (!restoreConfig.isStream()) {
formatState.setState(lastRow.getField(restoreConfig.getRestoreColumnIndex()));
}
sumRowsOfBlock = 0;
formatState.setJobId(jobId);
formatState.setFileIndex(blockIndex - 1);
LOG.info("jobId = {}, blockIndex = {}", jobId, blockIndex);
super.getFormatState();
return formatState;
}
return null;
}
@Override
public void closeInternal() throws IOException {
readyCheckpoint = false;
// 最后触发一次 block文件重命名,为 .data 目录下的文件移动到数据目录做准备
if (isTaskEndsNormally()) {
flushData();
// restore == false 需要主动执行
if (!restoreConfig.isRestore()) {
moveTemporaryDataBlockFileToDirectory();
}
}
numWriteCounter.add(sumRowsOfBlock);
}
@Override
protected void afterCloseInternal() {
try {
if (!isTaskEndsNormally()) {
return;
}
if (!restoreConfig.isStream()) {
createFinishedTag();
if (taskNumber == 0) {
waitForAllTasksToFinish();
// 正常被close,触发 .data 目录下的文件移动到数据目录
moveAllTemporaryDataFileToDirectory();
LOG.info("The task ran successfully,clear temporary data files");
closeSource();
clearTemporaryDataFiles();
}
} else {
closeSource();
}
} catch (Exception ex) {
throw new RuntimeException(ex);
}
}
protected boolean isTaskEndsNormally() throws IOException {
String state = getTaskState();
LOG.info("State of current task is:[{}]", state);
if (!RUNNING_STATE.equals(state)) {
if (!restoreConfig.isRestore()) {
LOG.info("The task does not end normally, clear the temporary data file");
clearTemporaryDataFiles();
}
closeSource();
return false;
}
return true;
}
@Override
public void tryCleanupOnError() throws Exception {
if (!restoreConfig.isRestore()) {
LOG.info("Clean temporary data in method tryCleanupOnError");
clearTemporaryDataFiles();
}
}
@Override
protected boolean needWaitAfterCloseInternal() {
return true;
}
public String getPath() {
return path;
}
public void flushData() throws IOException {
if (rowsOfCurrentBlock != 0) {
flushDataInternal();
if (restoreConfig.isRestore()) {
moveTemporaryDataBlockFileToDirectory();
sumRowsOfBlock += rowsOfCurrentBlock;
LOG.info("flush file:{} rows:{} sumRowsOfBlock:{}", currentBlockFileName, rowsOfCurrentBlock, sumRowsOfBlock);
}
rowsOfCurrentBlock = 0;
}
}
public long getLastWriteTime() {
return lastWriteTime;
}
/**
* 清除脏数据文件
*/
protected abstract void cleanDirtyData();
/**
* 写数据前由第一个通道完成指定操作之后调用此方法创建结束标制通知其它通道开始写数据
*/
protected abstract void createActionFinishedTag();
/**
* 等待第一个通道完成写数据前的操作
*/
protected abstract void waitForActionFinishedBeforeWrite();
/**
* flush数据到存储介质
*
* @throws IOException 输出异常
*/
protected abstract void flushDataInternal() throws IOException;
/**
* 单条数据写入文件
*
* @param row 要写入的数据
* @throws WriteRecordException 脏数据异常
*/
protected abstract void writeSingleRecordToFile(Row row) throws WriteRecordException;
/**
* 每个通道写完数据后关闭资源前创建结束标制
*
* @throws IOException 创建异常
*/
protected abstract void createFinishedTag() throws IOException;
/**
* 移动临时数据文件
*/
protected abstract void moveTemporaryDataBlockFileToDirectory();
/**
* 等待所有通道操作完成
*
* @throws IOException 超时异常
*/
protected abstract void waitForAllTasksToFinish() throws IOException;
/**
* 覆盖数据操作
*
* @throws IOException 删除数据异常
*/
protected abstract void coverageData() throws IOException;
/**
* 移动所有的临时数据文件
*
* @throws IOException 重命名文件异常
*/
protected abstract void moveTemporaryDataFileToDirectory() throws IOException;
/**
* 正常被close,触发 .data 目录下的文件移动到数据目录
*
* @throws IOException 重命名文件异常
*/
protected abstract void moveAllTemporaryDataFileToDirectory() throws IOException;
/**
* 检查写入路径是否存在,是否为目录
*/
protected abstract void checkOutputDir();
/**
* 打开资源
*
* @throws IOException 打开连接异常
*/
protected abstract void openSource() throws IOException;
/**
* 关闭资源
*
* @throws IOException 关闭连接异常
*/
protected abstract void closeSource() throws IOException;
/**
* 清除临时数据文件
*
* @throws IOException 删除数据异常
*/
protected abstract void clearTemporaryDataFiles() throws IOException;
/**
* 获取文件压缩比
* @return 压缩比 < 1
*/
public abstract float getDeviation();
/**
* 获取文件后缀
*
* @return .gz
*/
protected abstract String getExtension();
}
19
Source : WriteRecordException.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
* The Exception describing errors when writing a record
*
* Company: www.dtstack.com
* @author [email protected]
*/
public clreplaced WriteRecordException extends Exception {
private int colIndex = -1;
private Row row;
public int getColIndex() {
return colIndex;
}
public Row getRow() {
return row;
}
public WriteRecordException(String message, Throwable cause, int colIndex, Row row) {
super(message, cause);
this.colIndex = colIndex;
this.row = row;
}
public WriteRecordException(String message, Throwable cause) {
this(message, cause, -1, null);
}
@Override
public String toString() {
return super.toString() + "\n" + getCause().toString();
}
}
19
Source : PreRowKeyModeDealerDealer.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Handles a failed side-table lookup: logs the input record and the failure, then completes
 * the future with an empty result.
 *
 * @param arg2 the failure information to log
 * @param input the input record whose lookup failed
 * @param resultFuture future completed with an empty collection
 * @return always the empty string
 */
private String dealFail(Object arg2, Row input, ResultFuture<Row> resultFuture) {
    LOG.error("record:" + input);
    LOG.error("get side record exception:" + arg2);
    // type-safe empty list instead of the raw Collections.EMPTY_LIST constant
    resultFuture.complete(Collections.<Row>emptyList());
    return "";
}
19
Source : BaseAsyncTableFunction.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Pre-query hook. This implementation is empty.
 * NOTE(review): presumably meant to be overridden by subclasses - confirm.
 *
 * @param input the input row
 * @param resultFuture the future for the query result
 * @throws InvocationTargetException declared for overriding implementations; not thrown here
 * @throws IllegalAccessException declared for overriding implementations; not thrown here
 */
protected void preInvoke(Row input, ResultFuture<Row> resultFuture) throws InvocationTargetException, IllegalAccessException {
}
19
Source : BaseAsyncReqRow.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Registers a processing-time timer that calls {@code timeout(input, resultFuture)} when it fires.
 *
 * @param input the input row passed on to the timeout handler
 * @param resultFuture the result future passed on to the timeout handler
 * @return the future returned by the processing time service for the registered timer
 */
protected ScheduledFuture<?> registerTimer(Row input, ResultFuture<Row> resultFuture) {
// fire time = the side table's async timeout added to the current processing time
long timeoutTimestamp = sideInfo.getSideTableInfo().getAsyncTimeout() + getProcessingTimeService().getCurrentProcessingTime();
return getProcessingTimeService().registerTimer(timeoutTimestamp, timestamp -> timeout(input, resultFuture));
}
19
Source : BaseAsyncReqRow.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Pre-query hook: delegates to {@code registerTimerAndAddToHandler(input, resultFuture)}.
 *
 * @param input the input row
 * @param resultFuture the future for the query result
 * @throws InvocationTargetException propagated from the delegate
 * @throws IllegalAccessException propagated from the delegate
 */
protected void preInvoke(Row input, ResultFuture<Row> resultFuture) throws InvocationTargetException, IllegalAccessException {
registerTimerAndAddToHandler(input, resultFuture);
}
19
Source : BaseAllReqRow.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Emits the joined row for an input row and its side-table match.
 *
 * <p>When there is no side input, a row is emitted only for a LEFT join; for any other join
 * type the input is dropped.
 *
 * @param value the input row
 * @param sideInput the matching side-table data, or {@code null} if there is none
 * @param out collector receiving the joined row
 */
protected void sendOutputRow(Row value, Object sideInput, Collector<Row> out) {
    // emit when a match exists, or when the join type keeps unmatched rows
    if (sideInput != null || sideInfo.getJoinType() == JoinType.LEFT) {
        out.collect(fillData(value, sideInput));
    }
}
19
Source : CassandraOutputFormat.java
with Apache License 2.0
from DTStack
with Apache License 2.0
from DTStack
/**
 * Builds the CQL INSERT statement for a row. Fields whose value is {@code null} are left out
 * of the statement; String, Time, Date and Timestamp values are written as quoted literals.
 *
 * @param row the row to insert
 * @return the INSERT statement text
 * @throws IllegalArgumentException if the row has no non-null field, because no valid
 *         statement can be built for it
 */
private String buildSql(Row row) {
    StringBuilder fields = new StringBuilder();
    StringBuilder values = new StringBuilder();
    int columnCount = 0;
    for (int index = 0; index < row.getArity(); index++) {
        Object value = row.getField(index);
        if (value == null) {
            // null fields are omitted from the statement
            continue;
        }
        if (columnCount > 0) {
            fields.append(',');
            values.append(',');
        }
        columnCount++;
        fields.append(fieldNames[index]);
        if (value instanceof String || value instanceof Time || value instanceof Date || value instanceof Timestamp) {
            // quoted literal; an embedded single quote is escaped by doubling it so that
            // it cannot terminate the literal
            values.append('\'').append(String.valueOf(value).replace("'", "''")).append('\'');
        } else {
            values.append(value);
        }
    }
    if (columnCount == 0) {
        throw new IllegalArgumentException("Cannot build INSERT statement: row has no non-null fields: " + row);
    }
    return "INSERT INTO " + database + "." + tableName + " (" + fields.toString() + ") " + " VALUES (" + values.toString() + ")";
}
19
Source : MysqlSideFunction.java
with Apache License 2.0
from binglind
with Apache License 2.0
from binglind
/**
 * Collects the join-condition values of an input row as query parameters.
 *
 * @param input the input row
 * @param conditionIndexs indexes of the row fields used in the join condition
 * @return the parameter array, or {@code null} as soon as one condition field is null
 */
private JsonArray createParams(Row input, List<Integer> conditionIndexs) {
    final JsonArray queryArgs = new JsonArray();
    for (Integer fieldPos : conditionIndexs) {
        final Object joinKey = input.getField(fieldPos);
        if (joinKey != null) {
            queryArgs.add(joinKey);
            continue;
        }
        // a null join key aborts parameter collection
        LOG.warn("join condition is null ,index:{}", fieldPos);
        return null;
    }
    return queryArgs;
}
19
Source : MysqlSideFunction.java
with Apache License 2.0
from binglind
with Apache License 2.0
from binglind
/**
 * Entry point for an asynchronous request on one row. It wraps {@code resultFuture} in a
 * new {@code Future} and delegates to the {@code asyncInvoke} overload that accepts it.
 *
 * @param input        the incoming row
 * @param resultFuture the future to be completed for this row
 * @throws Exception declared by this method's signature
 */
public void asyncInvoke(Row input, ResultFuture<Row> resultFuture) throws Exception {
asyncInvoke(input, new Future<>(resultFuture));
}
19
Source : Elasticsearch6SinkFunction.java
with Apache License 2.0
from binglind
with Apache License 2.0
from binglind
/**
 * Gets the custom index name for a row.
 *
 * @param row the row being written
 * @return the configured {@code index} when no {@code fieldIndex} is set; otherwise the
 *         value of the row's field at {@code fieldIndex}, cast to {@code String}
 */
private String getIndex(Row row) {
    return fieldIndex == null ? this.index : (String) row.getField(fieldIndex);
}
19
Source : GridSearchTVSplitTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
public clreplaced GridSearchTVSplitTest extends AlinkTestBase {
private Row[] testArray;
private MemSourceBatchOp memSourceBatchOp;
private String[] colNames;
@Before
public void setUp() throws Exception {
testArray = new Row[] { Row.of(1, 2, 0), Row.of(1, 2, 0), Row.of(0, 3, 1), Row.of(0, 2, 0), Row.of(1, 3, 1), Row.of(4, 3, 1), Row.of(4, 4, 1), Row.of(5, 3, 0), Row.of(5, 4, 0), Row.of(5, 2, 1) };
colNames = new String[] { "col0", "col1", "label" };
memSourceBatchOp = new MemSourceBatchOp(Arrays.asList(testArray), colNames);
}
@Test
public void findBest() {
GbdtClreplacedifier gbdtClreplacedifier = new GbdtClreplacedifier().setFeatureCols(colNames[0], colNames[1]).setLabelCol(colNames[2]).setMinSamplesPerLeaf(1).setPredictionCol("pred").setPredictionDetailCol("pred_detail");
ParamGrid grid = new ParamGrid().addGrid(gbdtClreplacedifier, GbdtClreplacedifier.NUM_TREES, new Integer[] { 1, 2 });
GridSearchTVSplit gridSearchTVSplit = new GridSearchTVSplit().setEstimator(gbdtClreplacedifier).setParamGrid(grid).setTuningEvaluator(new BinaryClreplacedificationTuningEvaluator().setTuningBinaryClreplacedMetric(TuningBinaryClreplacedMetric.ACCURACY).setLabelCol(colNames[2]).setPositiveLabelValueString("1").setPredictionDetailCol("pred_detail"));
GridSearchTVSplitModel model = gridSearchTVSplit.fit(memSourceBatchOp);
replacedert.replacedertEquals(testArray.length, model.transform(memSourceBatchOp).collect().size());
}
@Test
public void findBestMulti() {
GbdtClreplacedifier gbdtClreplacedifier = new GbdtClreplacedifier().setFeatureCols(colNames[0], colNames[1]).setLabelCol(colNames[2]).setMinSamplesPerLeaf(1).setPredictionCol("pred").setPredictionDetailCol("pred_detail");
ParamGrid grid = new ParamGrid().addGrid(gbdtClreplacedifier, GbdtClreplacedifier.NUM_TREES, new Integer[] { 1, 2 });
GridSearchTVSplit gridSearchTVSplit = new GridSearchTVSplit().setEstimator(gbdtClreplacedifier).setParamGrid(grid).setTuningEvaluator(new MultiClreplacedClreplacedificationTuningEvaluator().setTuningMultiClreplacedMetric(TuningMultiClreplacedMetric.ACCURACY).setLabelCol(colNames[2]).setPredictionDetailCol("pred_detail"));
GridSearchTVSplitModel model = gridSearchTVSplit.fit(memSourceBatchOp);
replacedert.replacedertEquals(testArray.length, model.transform(memSourceBatchOp).collect().size());
}
@Test
public void findBestReg() {
GbdtRegressor gbdtClreplacedifier = new GbdtRegressor().setFeatureCols(colNames[0], colNames[1]).setLabelCol(colNames[2]).setMinSamplesPerLeaf(1).setPredictionCol("pred");
ParamGrid grid = new ParamGrid().addGrid(gbdtClreplacedifier, GbdtClreplacedifier.NUM_TREES, new Integer[] { 1, 2 });
GridSearchTVSplit gridSearchTVSplit = new GridSearchTVSplit().setEstimator(gbdtClreplacedifier).setParamGrid(grid).setTuningEvaluator(new RegressionTuningEvaluator().setTuningRegressionMetric(TuningRegressionMetric.RMSE).setLabelCol(colNames[2]).setPredictionCol("pred"));
GridSearchTVSplitModel model = gridSearchTVSplit.fit(memSourceBatchOp);
replacedert.replacedertEquals(testArray.length, model.transform(memSourceBatchOp).collect().size());
}
@Test
public void findBestCluster() throws Exception {
ColumnsToVector columnsToVector = new ColumnsToVector().setSelectedCols(colNames[0], colNames[1]).setVectorCol("vector");
KMeans kMeans = new KMeans().setVectorCol("vector").setPredictionCol("pred");
ParamGrid grid = new ParamGrid().addGrid(kMeans, "distanceType", new HasKMeansDistanceType.DistanceType[] { EUCLIDEAN, COSINE });
Pipeline pipeline = new Pipeline().add(columnsToVector).add(kMeans);
GridSearchTVSplit gridSearchTVSplit = new GridSearchTVSplit().setEstimator(pipeline).setParamGrid(grid).setTrainRatio(0.5).setTuningEvaluator(new ClusterTuningEvaluator().setTuningClusterMetric(TuningClusterMetric.RI).setPredictionCol("pred").setVectorCol("vector").setLabelCol("label"));
GridSearchTVSplitModel model = gridSearchTVSplit.fit(memSourceBatchOp);
replacedert.replacedertEquals(testArray.length, model.transform(memSourceBatchOp).collect().size());
}
}
19
Source : GridSearchCVTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
public clreplaced GridSearchCVTest extends AlinkTestBase {
private Row[] testArray;
private MemSourceBatchOp memSourceBatchOp;
private String[] colNames;
@Before
public void setUp() throws Exception {
testArray = new Row[] { Row.of(1, 2, 0), Row.of(1, 2, 0), Row.of(0, 3, 1), Row.of(0, 2, 0), Row.of(1, 3, 1), Row.of(4, 3, 1), Row.of(4, 4, 1), Row.of(5, 3, 0), Row.of(5, 4, 0), Row.of(5, 2, 1) };
colNames = new String[] { "col0", "col1", "label" };
memSourceBatchOp = new MemSourceBatchOp(Arrays.asList(testArray), colNames);
}
@Test
public void testSplit() throws Exception {
List<Row> rows = Arrays.asList(Row.of(1.0, "A", 0, 0, 0), Row.of(2.0, "B", 1, 1, 0), Row.of(3.0, "C", 2, 2, 1), Row.of(4.0, "D", 3, 3, 1), Row.of(1.0, "A", 0, 0, 0), Row.of(2.0, "B", 1, 1, 0), Row.of(3.0, "C", 2, 2, 1), Row.of(4.0, "D", 3, 3, 1), Row.of(1.0, "A", 0, 0, 0), Row.of(2.0, "B", 1, 1, 0), Row.of(3.0, "C", 2, 2, 1));
String[] colNames = new String[] { "f0", "f1", "f2", "f3", "label" };
MemSourceBatchOp data = new MemSourceBatchOp(rows, colNames);
String[] featureColNames = new String[] { colNames[0], colNames[1], colNames[2], colNames[3] };
String[] categoricalColNames = new String[] { colNames[1] };
String labelColName = colNames[4];
RandomForestClreplacedifier rf = new RandomForestClreplacedifier().setFeatureCols(featureColNames).setCategoricalCols(categoricalColNames).setLabelCol(labelColName).setPredictionCol("pred_result").setPredictionDetailCol("pred_detail").setSubsamplingRatio(1.0);
Pipeline pipeline = new Pipeline(rf);
ParamGrid paramGrid = new ParamGrid().addGrid(rf, "SUBSAMPLING_RATIO", new Double[] { 1.0 }).addGrid(rf, "NUM_TREES", new Integer[] { 3 });
BinaryClreplacedificationTuningEvaluator tuning_evaluator = new BinaryClreplacedificationTuningEvaluator().setLabelCol(labelColName).setPredictionDetailCol("pred_detail").setTuningBinaryClreplacedMetric("Accuracy");
GridSearchTVSplit cv = new GridSearchTVSplit().setEstimator(pipeline).setParamGrid(paramGrid).setTuningEvaluator(tuning_evaluator).setTrainRatio(0.8);
ModelBase cvModel = cv.fit(data);
cvModel.transform(data).print();
}
@Test
public void findBest() {
GbdtClreplacedifier gbdtClreplacedifier = new GbdtClreplacedifier().setFeatureCols(colNames[0], colNames[1]).setLabelCol(colNames[2]).setMinSamplesPerLeaf(1).setPredictionCol("pred").setPredictionDetailCol("pred_detail");
ParamGrid grid = new ParamGrid().addGrid(gbdtClreplacedifier, GbdtClreplacedifier.NUM_TREES, new Integer[] { 1, 2 }).addGrid(gbdtClreplacedifier, GbdtClreplacedifier.MAX_DEPTH, new Integer[] { 3, -1 });
GridSearchCV gridSearchCV = new GridSearchCV().setEstimator(gbdtClreplacedifier).setParamGrid(grid).setNumFolds(2).enableLazyPrintTrainInfo().setTuningEvaluator(new BinaryClreplacedificationTuningEvaluator().setTuningBinaryClreplacedMetric(TuningBinaryClreplacedMetric.ACCURACY).setLabelCol(colNames[2]).setPositiveLabelValueString("1").setPredictionDetailCol("pred_detail"));
GridSearchCVModel model = gridSearchCV.fit(memSourceBatchOp);
replacedert.replacedertEquals(testArray.length, model.transform(memSourceBatchOp).collect().size());
}
@Test
public void findBestMulti() {
GbdtClreplacedifier gbdtClreplacedifier = new GbdtClreplacedifier().setFeatureCols(colNames[0], colNames[1]).setLabelCol(colNames[2]).setMinSamplesPerLeaf(1).setPredictionCol("pred").setPredictionDetailCol("pred_detail");
ParamGrid grid = new ParamGrid().addGrid(gbdtClreplacedifier, GbdtClreplacedifier.NUM_TREES, new Integer[] { 1, 2 });
GridSearchCV gridSearchCV = new GridSearchCV().setEstimator(gbdtClreplacedifier).setParamGrid(grid).setNumFolds(2).setTuningEvaluator(new MultiClreplacedClreplacedificationTuningEvaluator().setTuningMultiClreplacedMetric(TuningMultiClreplacedMetric.ACCURACY).setLabelCol(colNames[2]).setPredictionDetailCol("pred_detail"));
GridSearchCVModel model = gridSearchCV.fit(memSourceBatchOp);
replacedert.replacedertEquals(testArray.length, model.transform(memSourceBatchOp).collect().size());
}
@Test
public void findBestReg() {
GbdtRegressor gbdtClreplacedifier = new GbdtRegressor().setFeatureCols(colNames[0], colNames[1]).setLabelCol(colNames[2]).setMinSamplesPerLeaf(1).setPredictionCol("pred");
ParamGrid grid = new ParamGrid().addGrid(gbdtClreplacedifier, GbdtClreplacedifier.NUM_TREES, new Integer[] { 1, 2 });
GridSearchCV gridSearchCV = new GridSearchCV().setEstimator(gbdtClreplacedifier).setParamGrid(grid).setNumFolds(2).setTuningEvaluator(new RegressionTuningEvaluator().setTuningRegressionMetric(TuningRegressionMetric.RMSE).setLabelCol(colNames[2]).setPredictionCol("pred"));
GridSearchCVModel model = gridSearchCV.fit(memSourceBatchOp);
replacedert.replacedertEquals(testArray.length, model.transform(memSourceBatchOp).collect().size());
}
@Test
public void findBestCluster() {
ColumnsToVector columnsToVector = new ColumnsToVector().setSelectedCols(colNames[0], colNames[1]).setVectorCol("vector");
KMeans kMeans = new KMeans().setVectorCol("vector").setPredictionCol("pred");
ParamGrid grid = new ParamGrid().addGrid(kMeans, KMeans.DISTANCE_TYPE, new HasKMeansDistanceType.DistanceType[] { EUCLIDEAN, COSINE });
Pipeline pipeline = new Pipeline().add(columnsToVector).add(kMeans);
GridSearchCV gridSearchCV = new GridSearchCV().setEstimator(pipeline).setParamGrid(grid).setNumFolds(2).setTuningEvaluator(new ClusterTuningEvaluator().setTuningClusterMetric(TuningClusterMetric.RI).setPredictionCol("pred").setVectorCol("vector").setLabelCol("label"));
GridSearchCVModel model = gridSearchCV.fit(memSourceBatchOp);
replacedert.replacedertEquals(testArray.length, model.transform(memSourceBatchOp).collect().size());
}
}
19
Source : StringNearestNeighborTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
public clreplaced StringNearestNeighborTest extends AlinkTestBase {
@Test
public void testString() {
BatchOperator dict = new MemSourceBatchOp(Arrays.asList(StringNearestNeighborBatchOpTest.dictRows), new String[] { "id", "str" });
BatchOperator query = new MemSourceBatchOp(Arrays.asList(StringNearestNeighborBatchOpTest.queryRows), new String[] { "id", "str" });
BatchOperator neareastNeighbor = new StringNearestNeighbor().setIdCol("id").setSelectedCol("str").setTopN(3).setOutputCol("output").fit(dict).transform(query);
List<Row> res = neareastNeighbor.collect();
Map<Object, Double[]> score = new HashMap<>();
score.put(1, new Double[] { 0.75, 0.667, 0.333 });
score.put(2, new Double[] { 0.667, 0.667, 0.5 });
score.put(3, new Double[] { 0.333, 0.333, 0.25 });
score.put(4, new Double[] { 0.75, 0.333, 0.333 });
score.put(5, new Double[] { 0.333, 0.25, 0.25 });
score.put(6, new Double[] { 0.333, 0.333, 0.333 });
for (Row row : res) {
Double[] actual = StringNearestNeighborBatchOpTest.extractScore((String) row.getField(2));
Double[] expect = score.get(row.getField(0));
for (int i = 0; i < actual.length; i++) {
replacedert.replacedertEquals(actual[i], expect[i], 0.01);
}
}
}
@Test
public void testStringApprox() {
BatchOperator dict = new MemSourceBatchOp(Arrays.asList(StringNearestNeighborBatchOpTest.dictRows), new String[] { "id", "str" });
BatchOperator query = new MemSourceBatchOp(Arrays.asList(StringNearestNeighborBatchOpTest.queryRows), new String[] { "id", "str" });
BatchOperator neareastNeighbor = new StringApproxNearestNeighbor().setIdCol("id").setSelectedCol("str").setTopN(3).setOutputCol("output").fit(dict).transform(query);
List<Row> res = neareastNeighbor.collect();
Map<Object, Double[]> score = new HashMap<>();
score.put(1, new Double[] { 0.984375, 0.953125, 0.9375 });
score.put(2, new Double[] { 0.984375, 0.953125, 0.9375 });
score.put(3, new Double[] { 0.921875, 0.875, 0.875 });
score.put(4, new Double[] { 0.9375, 0.890625, 0.8125 });
score.put(5, new Double[] { 0.890625, 0.84375, 0.8125 });
score.put(6, new Double[] { 0.9375, 0.890625, 0.8125 });
for (Row row : res) {
Double[] actual = extractScore((String) row.getField(2));
Double[] expect = score.get(row.getField(0));
for (int i = 0; i < actual.length; i++) {
replacedert.replacedertEquals(actual[i], expect[i], 0.01);
}
}
}
@Test
public void testText() {
BatchOperator dict = new MemSourceBatchOp(Arrays.asList(TextApproxNearestNeighborBatchOpTest.dictRows), new String[] { "id", "str" });
BatchOperator query = new MemSourceBatchOp(Arrays.asList(TextApproxNearestNeighborBatchOpTest.queryRows), new String[] { "id", "str" });
BatchOperator neareastNeighbor = new TextNearestNeighbor().setIdCol("id").setSelectedCol("str").setTopN(3).setOutputCol("output").fit(dict).transform(query);
List<Row> res = neareastNeighbor.collect();
Map<Object, Double[]> score = new HashMap<>();
score.put(1, new Double[] { 0.75, 0.667, 0.333 });
score.put(2, new Double[] { 0.667, 0.667, 0.5 });
score.put(3, new Double[] { 0.333, 0.333, 0.25 });
score.put(4, new Double[] { 0.75, 0.333, 0.333 });
score.put(5, new Double[] { 0.333, 0.25, 0.25 });
score.put(6, new Double[] { 0.333, 0.333, 0.333 });
for (Row row : res) {
Double[] actual = StringNearestNeighborBatchOpTest.extractScore((String) row.getField(2));
Double[] expect = score.get(row.getField(0));
for (int i = 0; i < actual.length; i++) {
replacedert.replacedertEquals(actual[i], expect[i], 0.01);
}
}
}
@Test
public void testTextApprox() {
BatchOperator dict = new MemSourceBatchOp(Arrays.asList(TextApproxNearestNeighborBatchOpTest.dictRows), new String[] { "id", "str" });
BatchOperator query = new MemSourceBatchOp(Arrays.asList(TextApproxNearestNeighborBatchOpTest.queryRows), new String[] { "id", "str" });
BatchOperator neareastNeighbor = new TextApproxNearestNeighbor().setIdCol("id").setSelectedCol("str").setTopN(3).setOutputCol("output").fit(dict).transform(query);
List<Row> res = neareastNeighbor.collect();
Map<Object, Double[]> score = new HashMap<>();
score.put(1, new Double[] { 0.984375, 0.953125, 0.9375 });
score.put(2, new Double[] { 0.984375, 0.953125, 0.9375 });
score.put(3, new Double[] { 0.921875, 0.875, 0.875 });
score.put(4, new Double[] { 0.9375, 0.890625, 0.8125 });
score.put(5, new Double[] { 0.890625, 0.84375, 0.8125 });
score.put(6, new Double[] { 0.9375, 0.890625, 0.8125 });
for (Row row : res) {
Double[] actual = StringNearestNeighborBatchOpTest.extractScore((String) row.getField(2));
Double[] expect = score.get(row.getField(0));
for (int i = 0; i < actual.length; i++) {
replacedert.replacedertEquals(actual[i], expect[i], 0.01);
}
}
}
public static Row[] dictRows = new Row[] { Row.of("dict1", "0 0 0"), Row.of("dict2", "0.1 0.1 0.1"), Row.of("dict3", "0.2 0.2 0.2"), Row.of("dict4", "9 9 9"), Row.of("dict5", "9.1 9.1 9.1"), Row.of("dict6", "9.2 9.2 9.2") };
public static Row[] queryRows = new Row[] { Row.of(1, "0 0 0"), Row.of(2, "0.1 0.1 0.1"), Row.of(3, "0.2 0.2 0.2"), Row.of(4, "9 9 9"), Row.of(5, "9.1 9.1 9.1"), Row.of(6, "9.2 9.2 9.2") };
@Test
public void testVector() {
BatchOperator dict = new MemSourceBatchOp(Arrays.asList(dictRows), new String[] { "id", "vec" });
BatchOperator query = new MemSourceBatchOp(Arrays.asList(queryRows), new String[] { "id", "vec" });
BatchOperator neareastNeighbor = new VectorNearestNeighbor().setIdCol("id").setSelectedCol("vec").setTopN(3).setOutputCol("output").fit(dict).transform(query);
List<Row> res = neareastNeighbor.collect();
Map<Object, Double[]> score = new HashMap<>();
score.put(1, new Double[] { 0.0, 0.17320508075688776, 0.3464101615137755 });
score.put(2, new Double[] { 0.0, 0.17320508075688773, 0.17320508075688776 });
score.put(3, new Double[] { 0.0, 0.17320508075688776, 0.3464101615137755 });
score.put(4, new Double[] { 0.0, 0.17320508075680896, 0.346410161513782 });
score.put(5, new Double[] { 0.0, 0.17320508075680896, 0.17320508075680896 });
score.put(6, new Double[] { 0.0, 0.17320508075680896, 0.346410161513782 });
for (Row row : res) {
Double[] actual = StringNearestNeighborBatchOpTest.extractScore((String) row.getField(2));
Double[] expect = score.get(row.getField(0));
for (int i = 0; i < actual.length; i++) {
replacedert.replacedertEquals(actual[i], expect[i], 0.01);
}
}
}
@Test
public void testVectorApprox() {
BatchOperator dict = new MemSourceBatchOp(Arrays.asList(dictRows), new String[] { "id", "vec" });
BatchOperator query = new MemSourceBatchOp(Arrays.asList(queryRows), new String[] { "id", "vec" });
BatchOperator neareastNeighbor = new VectorApproxNearestNeighbor().setIdCol("id").setSelectedCol("vec").setTopN(3).setOutputCol("output").fit(dict).transform(query);
List<Row> res = neareastNeighbor.collect();
Map<Object, Double[]> score = new HashMap<>();
score.put(1, new Double[] { 0.0, 0.17320508075688776, 0.3464101615137755 });
score.put(2, new Double[] { 0.0, 0.17320508075688773, 0.17320508075688776 });
score.put(3, new Double[] { 0.0, 0.17320508075688776, 0.3464101615137755 });
score.put(4, new Double[] { 0.0, 0.17320508075680896, 0.346410161513782 });
score.put(5, new Double[] { 0.0, 0.17320508075680896, 0.17320508075680896 });
score.put(6, new Double[] { 0.0, 0.17320508075680896, 0.346410161513782 });
for (Row row : res) {
Double[] actual = StringNearestNeighborBatchOpTest.extractScore((String) row.getField(2));
Double[] expect = score.get(row.getField(0));
for (int i = 0; i < actual.length; i++) {
replacedert.replacedertEquals(actual[i], expect[i], 0.01);
}
}
}
}
19
Source : RidgeRegressionTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
public clreplaced RidgeRegressionTest extends AlinkTestBase {
Row[] vecrows = new Row[] { Row.of("$3$0:1.0 1:7.0 2:9.0", "1.0 7.0 9.0", 1.0, 7.0, 9.0, 16.8), Row.of("$3$0:1.0 1:3.0 2:3.0", "1.0 3.0 3.0", 1.0, 3.0, 3.0, 6.7), Row.of("$3$0:1.0 1:2.0 2:4.0", "1.0 2.0 4.0", 1.0, 2.0, 4.0, 6.9), Row.of("$3$0:1.0 1:3.0 2:4.0", "1.0 3.0 4.0", 1.0, 3.0, 4.0, 8.0) };
String[] veccolNames = new String[] { "svec", "vec", "f0", "f1", "f2", "label" };
@Test
public void regressionPipelineTest() throws Exception {
BatchOperator vecdata = new MemSourceBatchOp(Arrays.asList(vecrows), veccolNames);
StreamOperator svecdata = new MemSourceStreamOp(Arrays.asList(vecrows), veccolNames);
String[] xVars = new String[] { "f0", "f1", "f2" };
String yVar = "label";
String vec = "vec";
String svec = "svec";
RidgeRegression ridge = new RidgeRegression().setLabelCol(yVar).setFeatureCols(xVars).setLambda(0.01).setMaxIter(10).setPredictionCol("linpred");
RidgeRegression vridge = new RidgeRegression().setLabelCol(yVar).setVectorCol(vec).setLambda(0.01).setMaxIter(10).setOptimMethod("newton").setPredictionCol("vlinpred");
RidgeRegression svridge = new RidgeRegression().setLabelCol(yVar).setVectorCol(svec).setLambda(0.01).setMaxIter(10).setPredictionCol("svlinpred");
Pipeline pl = new Pipeline().add(ridge).add(vridge).add(svridge);
PipelineModel model = pl.fit(vecdata);
BatchOperator result = model.transform(vecdata).select(new String[] { "label", "linpred", "vlinpred", "svlinpred" });
result.lazyCollect(new Consumer<List<Row>>() {
@Override
public void accept(List<Row> d) {
for (Row row : d) {
if ((double) row.getField(0) == 16.8000) {
replacedert.replacedertEquals((double) row.getField(1), 16.77322547668301, 0.01);
replacedert.replacedertEquals((double) row.getField(2), 16.77322547668301, 0.01);
replacedert.replacedertEquals((double) row.getField(3), 16.384437074591887, 0.01);
} else if ((double) row.getField(0) == 6.7000) {
replacedert.replacedertEquals((double) row.getField(1), 6.932628087721653, 0.01);
replacedert.replacedertEquals((double) row.getField(2), 6.932628087721653, 0.01);
replacedert.replacedertEquals((double) row.getField(3), 7.425378715755974, 0.01);
}
}
}
});
// below is stream test code
// model.transform(svecdata).print();
// StreamOperator.execute();
}
}
19
Source : LinearRegressionTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
public clreplaced LinearRegressionTest extends AlinkTestBase {
Row[] vecrows = new Row[] { Row.of("$3$0:1.0 1:7.0 2:9.0", "1.0 7.0 9.0", 1.0, 7.0, 9.0, 16.8), Row.of("$3$0:1.0 1:3.0 2:3.0", "1.0 3.0 3.0", 1.0, 3.0, 3.0, 6.7), Row.of("$3$0:1.0 1:2.0 2:4.0", "1.0 2.0 4.0", 1.0, 2.0, 4.0, 6.9), Row.of("$3$0:1.0 1:3.0 2:4.0", "1.0 3.0 4.0", 1.0, 3.0, 4.0, 8.0) };
String[] veccolNames = new String[] { "svec", "vec", "f0", "f1", "f2", "label" };
@Test
public void regressionPipelineTest() throws Exception {
BatchOperator vecdata = new MemSourceBatchOp(Arrays.asList(vecrows), veccolNames);
StreamOperator svecdata = new MemSourceStreamOp(Arrays.asList(vecrows), veccolNames);
String[] xVars = new String[] { "f0", "f1", "f2" };
String yVar = "label";
String vec = "vec";
String svec = "svec";
LinearRegression linear = new LinearRegression().setLabelCol(yVar).setFeatureCols(xVars).setMaxIter(20).setOptimMethod("newton").setPredictionCol("linpred");
LinearRegression vlinear = new LinearRegression().setLabelCol(yVar).setVectorCol(vec).setMaxIter(20).setPredictionCol("vlinpred");
LinearRegression svlinear = new LinearRegression().setLabelCol(yVar).setVectorCol(svec).setMaxIter(20).setPredictionCol("svlinpred");
svlinear.enableLazyPrintModelInfo();
svlinear.enableLazyPrintTrainInfo();
Pipeline pl = new Pipeline().add(linear).add(vlinear).add(svlinear);
PipelineModel model = pl.fit(vecdata);
BatchOperator result = model.transform(vecdata).select(new String[] { "label", "linpred", "vlinpred", "svlinpred" });
List<Row> data = result.collect();
for (Row row : data) {
if ((double) row.getField(0) == 16.8000) {
replacedert.replacedertEquals((double) row.getField(1), 16.814789059973744, 0.01);
replacedert.replacedertEquals((double) row.getField(2), 16.814789059973744, 0.01);
replacedert.replacedertEquals((double) row.getField(3), 16.814788687904162, 0.01);
} else if ((double) row.getField(0) == 6.7000) {
replacedert.replacedertEquals((double) row.getField(1), 6.773942836224718, 0.01);
replacedert.replacedertEquals((double) row.getField(2), 6.773942836224718, 0.01);
replacedert.replacedertEquals((double) row.getField(3), 6.773943529327923, 0.01);
}
}
// below is stream test code
// model.transform(svecdata).print();
// StreamOperator.execute();
}
}
19
Source : LassoRegressionTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
public clreplaced LreplacedoRegressionTest extends AlinkTestBase {
Row[] vecrows = new Row[] { Row.of("$3$0:1.0 1:7.0 2:9.0", "1.0 7.0 9.0", 1.0, 7.0, 9.0, 16.8), Row.of("$3$0:1.0 1:3.0 2:3.0", "1.0 3.0 3.0", 1.0, 3.0, 3.0, 6.7), Row.of("$3$0:1.0 1:2.0 2:4.0", "1.0 2.0 4.0", 1.0, 2.0, 4.0, 6.9), Row.of("$3$0:1.0 1:3.0 2:4.0", "1.0 3.0 4.0", 1.0, 3.0, 4.0, 8.0) };
String[] veccolNames = new String[] { "svec", "vec", "f0", "f1", "f2", "label" };
@Test
public void regressionPipelineTest() throws Exception {
BatchOperator vecdata = new MemSourceBatchOp(Arrays.asList(vecrows), veccolNames);
// StreamOperator svecdata = new MemSourceStreamOp(Arrays.asList(vecrows), veccolNames);
String[] xVars = new String[] { "f0", "f1", "f2" };
String yVar = "label";
String vec = "vec";
String svec = "svec";
LreplacedoRegression lreplacedo = new LreplacedoRegression().setLabelCol(yVar).setFeatureCols(xVars).setLambda(0.01).setMaxIter(20).setOptimMethod("owlqn").setPredictionCol("linpred");
LreplacedoRegression vlreplacedo = new LreplacedoRegression().setLabelCol(yVar).setVectorCol(vec).setMaxIter(20).setLambda(0.01).setOptimMethod("newton").setPredictionCol("vlinpred").enableLazyPrintModelInfo();
LreplacedoRegression svlreplacedo = new LreplacedoRegression().setLabelCol(yVar).setVectorCol(svec).setMaxIter(20).setLambda(0.01).setPredictionCol("svlinpred");
Pipeline pl = new Pipeline().add(lreplacedo).add(vlreplacedo).add(svlreplacedo);
PipelineModel model = pl.fit(vecdata);
BatchOperator result = model.transform(vecdata).select(new String[] { "label", "linpred", "vlinpred", "svlinpred" });
List<Row> data = result.collect();
for (Row row : data) {
if ((double) row.getField(0) == 16.8000) {
replacedert.replacedertEquals((double) row.getField(1), 16.784611802507232, 0.01);
replacedert.replacedertEquals((double) row.getField(2), 16.784611802507232, 0.01);
replacedert.replacedertEquals((double) row.getField(3), 16.78209421260283, 0.01);
} else if ((double) row.getField(0) == 6.7000) {
replacedert.replacedertEquals((double) row.getField(1), 6.7713287283076, 0.01);
replacedert.replacedertEquals((double) row.getField(2), 6.7713287283076, 0.01);
replacedert.replacedertEquals((double) row.getField(3), 6.826846826823054, 0.01);
}
}
// below is stream test code
// model.transform(svecdata).print();
// StreamOperator.execute();
}
}
19
Source : IsotonicRegressionTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
/**
* Test for IsotonicRegression.
*/
public clreplaced IsotonicRegressionTest extends AlinkTestBase {
private Row[] rows = new Row[] { Row.of(0, 0.35, 1), Row.of(1, 0.6, 1), Row.of(2, 0.55, 1), Row.of(3, 0.5, 1), Row.of(4, 0.18, 0), Row.of(5, 0.1, 1), Row.of(6, 0.8, 1), Row.of(7, 0.45, 0), Row.of(8, 0.4, 1), Row.of(9, 0.7, 0), Row.of(10, 0.02, 1), Row.of(11, 0.3, 0), Row.of(12, 0.27, 1), Row.of(13, 0.2, 0), Row.of(14, 0.9, 1) };
@Test
public void testIsotonicReg() throws Exception {
Table data = MLEnvironmentFactory.getDefault().createBatchTable(rows, new String[] { "id", "feature", "label" });
Table dataStream = MLEnvironmentFactory.getDefault().createStreamTable(rows, new String[] { "id", "feature", "label" });
IsotonicRegression op = new IsotonicRegression().setFeatureCol("feature").setLabelCol("label").setPredictionCol("result");
PipelineModel model = new Pipeline().add(op).fit(data);
BatchOperator<?> res = model.transform(new TableSourceBatchOp(data));
List<Row> list = res.select(new String[] { "id", "result" }).collect();
double[] actual = new double[] { 0.66, 0.75, 0.75, 0.75, 0.5, 0.5, 0.75, 0.66, 0.66, 0.75, 0.5, 0.5, 0.5, 0.5, 0.75 };
for (int i = 0; i < actual.length; i++) {
replacedert.replacedertEquals((Double) list.get(i).getField(1), actual[(int) list.get(i).getField(0)], 0.01);
}
// StreamOperator<?> resStream = model.transform(new TableSourceStreamOp(dataStream));
// resStream.print();
// MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().execute();
}
}
19
Source : AFTRegTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
/**
* Test for AFTRegression.
*/
public clreplaced AFTRegTest extends AlinkTestBase {
private static Row[] rows = new Row[] { Row.of(0, 1.218, 1.0, "1.560,-0.605"), Row.of(1, 2.949, 0.0, "0.346,2.158"), Row.of(2, 3.627, 0.0, "1.380,0.231"), Row.of(3, 0.273, 1.0, "0.520,1.151"), Row.of(4, 4.199, 0.0, "0.795,-0.226") };
private static Row[] rowsSparse = new Row[] { Row.of(1.218, 1.0, "$10$3:1.560,7:-0.605"), Row.of(2.949, 0.0, "$10$3:0.346,7:2.158"), Row.of(3.627, 0.0, "$10$3:1.380,7:0.231"), Row.of(0.273, 1.0, "$10$3:0.520,7:1.151"), Row.of(4.199, 0.0, "$10$3:0.795,7:-0.226") };
private static Row[] rowsFeatures = new Row[] { Row.of(1.218, 1.0, 1.560, -0.605), Row.of(2.949, 0.0, 0.346, 2.158), Row.of(3.627, 0.0, 1.380, 0.231), Row.of(0.273, 1.0, 0.520, 1.151), Row.of(4.199, 0.0, 0.795, -0.226) };
@Test
public void testPipeline() throws Exception {
MemSourceBatchOp data = new MemSourceBatchOp(Arrays.asList(rows), new String[] { "id", "label", "censor", "features" });
AftSurvivalRegression reg = new AftSurvivalRegression().setVectorCol("features").setLabelCol("label").setCensorCol("censor").setPredictionCol("result").enableLazyPrintModelInfo().enableLazyPrintTrainInfo();
PipelineModel model = new Pipeline().add(reg).fit(data);
BatchOperator<?> res = model.transform(data);
List<Row> list = res.select(new String[] { "id", "result" }).collect();
double[] actual = new double[] { 5.70, 18.10, 7.36, 13.62, 9.03 };
for (int i = 0; i < actual.length; i++) {
replacedert.replacedertEquals((Double) list.get(i).getField(1), actual[(int) list.get(i).getField(0)], 0.1);
}
}
@Test
public void testOp() throws Exception {
MemSourceBatchOp data = new MemSourceBatchOp(Arrays.asList(rows), new String[] { "id", "label", "censor", "features" });
AftSurvivalRegTrainBatchOp trainBatchOp = new AftSurvivalRegTrainBatchOp().setVectorCol("features").setLabelCol("label").setCensorCol("censor").linkFrom(data);
AftSurvivalRegPredictBatchOp pred = new AftSurvivalRegPredictBatchOp().setPredictionCol("result").setPredictionDetailCol("detail");
pred.linkFrom(trainBatchOp, data).lazyCollect();
BatchOperator.execute();
}
}
19
Source : DocHashCountVectorizerTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
// import com.alibaba.alink.common.utils.RowTypeDataStream;
/**
* Test for DocHashIDFVectorizer.
*/
public clreplaced DocHashCountVectorizerTest extends AlinkTestBase {
private static Row[] rows = new Row[] { Row.of(0, "a b c d a a", 1), Row.of(1, "c c b a e", 1) };
@Test
public void testIdf() throws Exception {
Table data = MLEnvironmentFactory.getDefault().createBatchTable(rows, new String[] { "id", "sentence", "label" });
Table dataStream = MLEnvironmentFactory.getDefault().createStreamTable(rows, new String[] { "id", "sentence", "label" });
DocHashCountVectorizer op = new DocHashCountVectorizer().setSelectedCol("sentence").setNumFeatures(10).setOutputCol("res");
DocHashCountVectorizerModel model = op.fit(data);
Table res = model.transform(data);
replacedert.replacedertArrayEquals(MLEnvironmentFactory.getDefault().getBatchTableEnvironment().toDataSet(res.select("res"), new RowTypeInfo(VectorTypes.SPARSE_VECTOR)).collect().stream().map(row -> (SparseVector) row.getField(0)).toArray(SparseVector[]::new), new SparseVector[] { new SparseVector(10, new int[] { 3, 4, 5, 7 }, new double[] { 1.0, 3.0, 1.0, 1.0 }), new SparseVector(10, new int[] { 4, 5, 6, 7 }, new double[] { 1.0, 2.0, 1.0, 1.0 }) });
res = model.transform(dataStream);
DataStreamConversionUtil.fromTable(MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID, res).print();
MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().execute();
}
@Test
public void testException() throws Exception {
BatchOperator data = new MemSourceBatchOp(rows, new String[] { "id", "sentence", "label" });
DocHashCountVectorizerTrainBatchOp op = new DocHashCountVectorizerTrainBatchOp().setSelectedCol("sentence").setMinDF(2.).setFeatureType("TF").linkFrom(data);
op.collect();
}
@Test
public void testInitializer() {
DocHashCountVectorizerModel model = new DocHashCountVectorizerModel();
replacedert.replacedertEquals(model.getParams().size(), 0);
DocHashCountVectorizer op = new DocHashCountVectorizer(new Params());
replacedert.replacedertEquals(op.getParams().size(), 0);
BatchOperator b = new DocHashCountVectorizerTrainBatchOp();
replacedert.replacedertEquals(b.getParams().size(), 0);
b = new DocHashCountVectorizerTrainBatchOp(new Params());
replacedert.replacedertEquals(b.getParams().size(), 0);
b = new DocHashCountVectorizerPredictBatchOp();
replacedert.replacedertEquals(b.getParams().size(), 0);
b = new DocHashCountVectorizerPredictBatchOp(new Params());
replacedert.replacedertEquals(b.getParams().size(), 0);
StreamOperator s = new DocHashCountVectorizerPredictStreamOp(b);
replacedert.replacedertEquals(s.getParams().size(), 0);
s = new DocHashCountVectorizerPredictStreamOp(b, new Params());
replacedert.replacedertEquals(s.getParams().size(), 0);
}
}
19
Source : DocCountVectorizerTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
/**
* Test for DocCountVectorizer.
*/
public clreplaced DocCountVectorizerTest extends AlinkTestBase {
@Rule
public ExpectedException thrown = ExpectedException.none();
private Row[] rows = new Row[] { Row.of(0, "That is an English book", 1), Row.of(1, "Have a good day", 1) };
@Test
public void testDefault() throws Exception {
Table data = MLEnvironmentFactory.getDefault().createBatchTable(rows, new String[] { "id", "sentence", "label" });
Table dataStream = MLEnvironmentFactory.getDefault().createStreamTable(rows, new String[] { "id", "sentence", "label" });
DocCountVectorizer op = new DocCountVectorizer().setSelectedCol("sentence").setOutputCol("features").setFeatureType("TF");
PipelineModel model = new Pipeline().add(op).fit(data);
Table res = model.transform(data);
List<SparseVector> list = MLEnvironmentFactory.getDefault().getBatchTableEnvironment().toDataSet(res.select("features"), new RowTypeInfo(VectorTypes.SPARSE_VECTOR)).collect().stream().map(row -> (SparseVector) row.getField(0)).collect(Collectors.toList());
replacedert.replacedertEquals(list.size(), 2);
replacedert.replacedertEquals(list.get(0).getValues().length, 5);
replacedert.replacedertEquals(list.get(1).getValues().length, 4);
for (int i = 0; i < list.get(0).getValues().length; i++) {
replacedert.replacedertEquals(list.get(0).getValues()[i], 0.2, 0.1);
}
for (int i = 0; i < list.get(1).getValues().length; i++) {
replacedert.replacedertEquals(list.get(1).getValues()[i], 0.25, 0.1);
}
res = model.transform(dataStream);
DataStreamConversionUtil.fromTable(MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID, res).print();
MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().execute();
}
@Test
public void testException() throws Exception {
BatchOperator data = new MemSourceBatchOp(rows, new String[] { "id", "sentence", "label" });
thrown.expect(RuntimeException.clreplaced);
DocCountVectorizerTrainBatchOp op = new DocCountVectorizerTrainBatchOp().setSelectedCol("sentence").setMinDF(2.).setMaxDF(1.).setFeatureType("TF").linkFrom(data);
op.collect();
}
@Test
public void testInitializer() {
DocCountVectorizerModel model = new DocCountVectorizerModel();
replacedert.replacedertEquals(model.getParams().size(), 0);
DocCountVectorizer op = new DocCountVectorizer(new Params());
replacedert.replacedertEquals(op.getParams().size(), 0);
BatchOperator b = new DocCountVectorizerTrainBatchOp();
replacedert.replacedertEquals(b.getParams().size(), 0);
b = new DocCountVectorizerTrainBatchOp(new Params());
replacedert.replacedertEquals(b.getParams().size(), 0);
b = new DocCountVectorizerPredictBatchOp();
replacedert.replacedertEquals(b.getParams().size(), 0);
b = new DocCountVectorizerPredictBatchOp(new Params());
replacedert.replacedertEquals(b.getParams().size(), 0);
StreamOperator s = new DocCountVectorizerPredictStreamOp(b);
replacedert.replacedertEquals(s.getParams().size(), 0);
s = new DocCountVectorizerPredictStreamOp(b, new Params());
replacedert.replacedertEquals(s.getParams().size(), 0);
}
}
19
Source : OneHotTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
/**
 * Builds the ten-row fixture (some cells are {@code null}) and wraps it in an in-memory
 * source using the {@code schema} defined elsewhere in this class.
 *
 * @param isBatch {@code true} for a {@link MemSourceBatchOp}, {@code false} for a
 *                {@link MemSourceStreamOp}
 * @return the source operator holding the fixture rows
 */
AlgoOperator getData(boolean isBatch) {
    final Row[] fixture = {
        Row.of("0", "doc0", "天", 4L),
        Row.of("1", "doc0", "地", 5L),
        Row.of("2", "doc0", "人", 1L),
        Row.of("3", "doc1", null, 3L),
        Row.of("4", null, "人", 2L),
        Row.of("5", "doc1", "合", 4L),
        Row.of("6", "doc1", "一", 4L),
        Row.of("7", "doc2", "清", 3L),
        Row.of("8", "doc2", "一", 2L),
        Row.of("9", "doc2", "色", 2L)
    };
    if (!isBatch) {
        return new MemSourceStreamOp(Arrays.asList(fixture), schema);
    }
    return new MemSourceBatchOp(Arrays.asList(fixture), schema);
}
19
Source : BucketizerTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
/**
* Test for Bucketizer.
*/
public clreplaced BucketizerTest extends AlinkTestBase {
private Row[] rows = new Row[] { Row.of(-999.9, -999.9), Row.of(-0.5, -0.2), Row.of(-0.3, -0.1), Row.of(0.0, 0.0), Row.of(0.2, 0.4), Row.of(999.9, 999.9) };
private static double[][] cutsArray = new double[][] { { -0.5, 0.0, 0.5 }, { -0.3, 0.0, 0.3, 0.4 } };
@Test
public void testBucketizer() throws Exception {
Table data = MLEnvironmentFactory.getDefault().createBatchTable(rows, new String[] { "features1", "features2" });
Table dataStream = MLEnvironmentFactory.getDefault().createStreamTable(rows, new String[] { "features1", "features2" });
Bucketizer op = new Bucketizer().setSelectedCols(new String[] { "features1", "features2" }).setOutputCols(new String[] { "bucket1", "bucket2" }).setCutsArray(cutsArray);
Table res = op.transform(data);
List<Long> list = MLEnvironmentFactory.getDefault().getBatchTableEnvironment().toDataSet(res.select("bucket1"), Long.clreplaced).collect();
replacedert.replacedertArrayEquals(list.toArray(new Long[0]), new Long[] { 0L, 0L, 1L, 1L, 2L, 3L });
res = op.transform(dataStream);
DataStreamConversionUtil.fromTable(MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID, res).print();
MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().execute();
}
@Test
public void testInitializer() {
Bucketizer op = new Bucketizer(new Params());
replacedert.replacedertEquals(op.getParams().size(), 0);
BatchOperator b = new BucketizerBatchOp();
replacedert.replacedertEquals(b.getParams().size(), 0);
b = new BucketizerBatchOp(new Params());
replacedert.replacedertEquals(b.getParams().size(), 0);
StreamOperator s = new BucketizerStreamOp();
replacedert.replacedertEquals(s.getParams().size(), 0);
s = new BucketizerStreamOp(new Params());
replacedert.replacedertEquals(s.getParams().size(), 0);
}
}
19
Source : BinarizerTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
/**
* Test for Binarizer.
*/
public clreplaced BinarizerTest extends AlinkTestBase {
Row[] rows = new Row[] { Row.of(1.218, 16.0, "1.560 -0.605"), Row.of(2.949, 4.0, "0.346 2.158"), Row.of(3.627, 2.0, "1.380 0.231"), Row.of(0.273, 15.0, "0.520 1.151"), Row.of(4.199, 7.0, "0.795 -0.226") };
Table data = MLEnvironmentFactory.getDefault().createBatchTable(rows, new String[] { "label", "censor", "features" });
Table dataStream = MLEnvironmentFactory.getDefault().createStreamTable(rows, new String[] { "label", "censor", "features" });
@Test
public void test() throws Exception {
Binarizer op = new Binarizer().setSelectedCol("censor").setThreshold(8.0);
Table res = op.transform(data);
List<Double> list = MLEnvironmentFactory.getDefault().getBatchTableEnvironment().toDataSet(res.select("censor"), Double.clreplaced).collect();
replacedert.replacedertEquals(list.toArray(new Double[0]), new Double[] { 1.0, 0.0, 0.0, 1.0, 0.0 });
res = op.transform(dataStream);
DataStreamConversionUtil.fromTable(MLEnvironmentFactory.DEFAULT_ML_ENVIRONMENT_ID, res).print();
MLEnvironmentFactory.getDefault().getStreamExecutionEnvironment().execute();
}
@Test
public void testInitializer() {
Binarizer op = new Binarizer(new Params());
replacedert.replacedertEquals(op.getParams().size(), 0);
BatchOperator b = new BinarizerBatchOp();
replacedert.replacedertEquals(b.getParams().size(), 0);
b = new BinarizerBatchOp(new Params());
replacedert.replacedertEquals(b.getParams().size(), 0);
StreamOperator s = new BinarizerStreamOp();
replacedert.replacedertEquals(s.getParams().size(), 0);
s = new BinarizerStreamOp(new Params());
replacedert.replacedertEquals(s.getParams().size(), 0);
}
}
19
Source : StringIndexerTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
/**
* Test cases for {@link StringIndexer}.
*/
public clreplaced StringIndexerTest extends AlinkTestBase {
private static Row[] rows = new Row[] { Row.of("football"), Row.of("football"), Row.of("football"), Row.of("basketball"), Row.of("basketball"), Row.of("tennis") };
private static void checkResult(List<Row> prediction, String[] actualOrderedTokens) {
Map<String, Long> actual = new HashMap<>();
for (int i = 0; i < actualOrderedTokens.length; i++) {
actual.put(actualOrderedTokens[i], (long) i);
}
prediction.forEach(row -> {
String token = (String) row.getField(0);
Long id = (Long) row.getField(1);
replacedert.replacedertEquals(id, actual.get(token));
});
}
@Test
public void testRandom() throws Exception {
BatchOperator data = new MemSourceBatchOp(Arrays.asList(rows), new String[] { "f0" });
StringIndexer stringIndexer = new StringIndexer().setSelectedCol("f0").setOutputCol("f0_indexed").setStringOrderType("random");
replacedert.replacedertEquals(stringIndexer.fit(data).getModelData().collect().size(), 3);
}
@Test
public void testFrequencyAsc() throws Exception {
BatchOperator data = new MemSourceBatchOp(Arrays.asList(rows), new String[] { "f0" });
StringIndexer stringIndexer = new StringIndexer().setSelectedCol("f0").setOutputCol("f0_indexed").setStringOrderType("frequency_asc");
List<Row> prediction = stringIndexer.fit(data).transform(data).collect();
checkResult(prediction, new String[] { "tennis", "basketball", "football" });
}
@Test
public void testAlphabetDesc() throws Exception {
BatchOperator data = new MemSourceBatchOp(Arrays.asList(rows), new String[] { "f0" });
StringIndexer stringIndexer = new StringIndexer().setSelectedCol("f0").setOutputCol("f0_indexed").setStringOrderType("alphabet_desc");
List<Row> prediction = stringIndexer.fit(data).transform(data).collect();
checkResult(prediction, new String[] { "tennis", "football", "basketball" });
}
}
19
Source : IndexToStringTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
/**
* Test cases for {@link IndexToString}.
*/
public clreplaced IndexToStringTest extends AlinkTestBase {
private static Row[] rows = new Row[] { Row.of("football"), Row.of("football"), Row.of("football"), Row.of("basketball"), Row.of("basketball"), Row.of("tennis") };
@Test
public void testIndexToString() throws Exception {
BatchOperator data = new MemSourceBatchOp(Arrays.asList(rows), new String[] { "f0" });
StringIndexer stringIndexer = new StringIndexer().setModelName("string_indexer_model").setSelectedCol("f0").setOutputCol("f0_indexed").setStringOrderType("frequency_asc");
BatchOperator indexed = stringIndexer.fit(data).transform(data);
IndexToString indexToString = new IndexToString().setModelName("string_indexer_model").setSelectedCol("f0_indexed").setOutputCol("f0_indxed_unindexed");
List<Row> unindexed = indexToString.transform(indexed).collect();
unindexed.forEach(row -> {
replacedert.replacedertEquals(row.getField(0), row.getField(2));
});
}
}
19
Source : SvmTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
/**
 * Builds the eight-row fixture with columns (svec, vec, f0, f1, f2, f3, labels) and wraps
 * it in an in-memory source.
 *
 * @param isBatch {@code true} for a {@link MemSourceBatchOp}, {@code false} for a
 *                {@link MemSourceStreamOp}
 * @return the source operator holding the fixture rows
 */
AlgoOperator getData(boolean isBatch) {
    final Row[] fixture = {
        Row.of("$31$0:1.0 1:1.0 2:1.0 30:1.0", "1.0 1.0 1.0 1.0", 1.0, 1.0, 1.0, 1.0, 1),
        Row.of("$31$0:1.0 1:1.0 2:0.0 30:1.0", "1.0 1.0 0.0 1.0", 1.0, 1.0, 0.0, 1.0, 1),
        Row.of("$31$0:1.0 1:0.0 2:1.0 30:1.0", "1.0 0.0 1.0 1.0", 1.0, 0.0, 1.0, 1.0, 1),
        Row.of("$31$0:1.0 1:0.0 2:1.0 30:1.0", "1.0 0.0 1.0 1.0", 1.0, 0.0, 1.0, 1.0, 1),
        Row.of("$31$0:0.0 1:1.0 2:1.0 30:0.0", "0.0 1.0 1.0 0.0", 0.0, 1.0, 1.0, 0.0, 0),
        Row.of("$31$0:0.0 1:1.0 2:1.0 30:0.0", "0.0 1.0 1.0 0.0", 0.0, 1.0, 1.0, 0.0, 0),
        Row.of("$31$0:0.0 1:1.0 2:1.0 30:0.0", "0.0 1.0 1.0 0.0", 0.0, 1.0, 1.0, 0.0, 0),
        Row.of("$31$0:0.0 1:1.0 2:1.0 30:0.0", "0.0 1.0 1.0 0.0", 0.0, 1.0, 1.0, 0.0, 0)
    };
    final String[] colNames = { "svec", "vec", "f0", "f1", "f2", "f3", "labels" };
    if (!isBatch) {
        return new MemSourceStreamOp(Arrays.asList(fixture), colNames);
    }
    return new MemSourceBatchOp(Arrays.asList(fixture), colNames);
}
19
Source : SoftmaxTest.java
with Apache License 2.0
from alibaba
with Apache License 2.0
from alibaba
public clreplaced SoftmaxTest extends AlinkTestBase {
String labelColName = "label";
Row[] vecrows = new Row[] { Row.of("0:1.0 2:7.0 15:9.0", "1.0 7.0 9.0", 1.0, 7.0, 9.0, 2), Row.of("0:1.0 2:3.0 12:3.0", "1.0 3.0 3.0", 1.0, 3.0, 3.0, 3), Row.of("0:1.0 2:2.0 10:4.0", "1.0 2.0 4.0", 1.0, 2.0, 4.0, 1), Row.of("0:1.0 2:2.0 7:4.0", "1.0 2.0 4.0", 1.0, 2.0, 4.0, 1) };
String[] veccolNames = new String[] { "svec", "vec", "f0", "f1", "f2", "label" };
Row[] vecmrows = new Row[] { Row.of("0:1.0 2:7.0 15:9.0", "1.0 1.0 9.0", 1.0, 7.0, 9.0, 2), Row.of("0:2.0 2:3.0 12:3.0", "1.0 2.0 3.0", 1.0, 3.0, 5.0, 3), Row.of("0:3.0 2:2.0 10:4.0", "1.0 3.0 4.0", 1.0, 2.0, 6.0, 1), Row.of("0:4.0 2:3.0 12:3.0", "1.0 4.0 3.0", 1.0, 3.0, 7.0, 4), Row.of("0:5.0 2:2.0 10:4.0", "1.0 5.0 4.0", 1.0, 2.0, 40.0, 5), Row.of("0:6.0 2:3.0 12:3.0", "1.0 6.0 3.0", 1.0, 3.0, 9.0, 6), Row.of("0:7.0 2:2.0 10:4.0", "1.0 7.0 4.0", 1.0, 2.0, 0.0, 7), Row.of("0:8.0 2:3.0 12:3.0", "1.0 8.0 3.0", 1.0, 3.0, 888.0, 8), Row.of("0:9.0 2:2.0 10:4.0", "1.0 9.0 4.0", 1.0, 2.0, 77.0, 9), Row.of("0:10.0 2:2.0 7:4.0", "1.0 12.0 4.0", 1.0, 2.0, 766.0, 1) };
Softmax softmax;
Softmax vsoftmax;
Softmax vssoftmax;
Softmax svsoftmax;
@Before
public void setUp() {
softmax = new Softmax(new Params()).setFeatureCols(new String[] { "f0", "f1", "f2" }).setStandardization(true).setWithIntercept(true).setEpsilon(1.0e-20).setLabelCol(labelColName).enableLazyPrintModelInfo().setPredictionCol("predLr").setMaxIter(10);
vsoftmax = new Softmax().setVectorCol("vec").setStandardization(true).setWithIntercept(true).setEpsilon(1.0e-20).setLabelCol(labelColName).setPredictionCol("vpredLr").enableLazyPrintModelInfo().setMaxIter(10);
vssoftmax = new Softmax().setVectorCol("svec").setStandardization(true).setWithIntercept(true).setEpsilon(1.0e-20).setLabelCol(labelColName).setPredictionCol("vsspredLr").enableLazyPrintModelInfo().setOptimMethod("newton").setMaxIter(10);
svsoftmax = new Softmax().setVectorCol("svec").setStandardization(true).setWithIntercept(true).setEpsilon(1.0e-20).setLabelCol(labelColName).setPredictionCol("svpredLr").setPredictionDetailCol("svpredDetail").enableLazyPrintModelInfo().setMaxIter(10);
}
@Test
public void pipelineTest() throws Exception {
BatchOperator vecdata = new MemSourceBatchOp(Arrays.asList(vecrows), veccolNames);
StreamOperator svecdata = new MemSourceStreamOp(Arrays.asList(vecrows), veccolNames);
Pipeline pl = new Pipeline().add(softmax).add(vsoftmax).add(svsoftmax).add(vssoftmax);
PipelineModel model = pl.fit(vecdata);
BatchOperator result = model.transform(vecdata).select(new String[] { "label", "predLr", "vpredLr", "svpredLr" });
List<Row> data = result.collect();
for (Row row : data) {
for (int i = 1; i < 3; ++i) {
replacedert.replacedertEquals(row.getField(0), row.getField(i));
}
}
// below is stream test code
model.transform(svecdata).print();
StreamOperator.execute();
}
@Test
public void pipelineTest1() throws Exception {
BatchOperator vecmdata = new MemSourceBatchOp(Arrays.asList(vecmrows), veccolNames);
Pipeline pl = new Pipeline().add(softmax).add(vsoftmax).add(svsoftmax).add(vssoftmax);
PipelineModel modelm = pl.fit(vecmdata);
modelm.transform(vecmdata).select(new String[] { "label", "predLr", "vpredLr", "svpredLr" }).print();
}
}
See More Examples