/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.sql.expression.function.CollectionUDF;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;
import org.apache.calcite.adapter.enumerable.NotNullImplementor;
import org.apache.calcite.adapter.enumerable.NullPolicy;
import org.apache.calcite.adapter.enumerable.RexToLixTranslator;
import org.apache.calcite.linq4j.function.Function1;
import org.apache.calcite.linq4j.function.Function2;
import org.apache.calcite.linq4j.tree.Expression;
import org.apache.calcite.linq4j.tree.Expressions;
import org.apache.calcite.linq4j.tree.Types;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexCallBinding;
import org.apache.calcite.rex.RexLambda;
import org.apache.calcite.sql.type.ArraySqlType;
import org.apache.calcite.sql.type.SqlReturnTypeInference;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.opensearch.sql.expression.function.CollectionUDF.LambdaUtils;
import org.opensearch.sql.expression.function.ImplementorUDF;
import org.opensearch.sql.expression.function.UDFOperandMetadata;

/**
 * UDF that applies a lambda to every element of a list and returns the list of results.
 *
 * <p>The function is registered with {@link NullPolicy#ANY} and is code-generated through
 * {@link TransformImplementor}, which forwards the translated operands to
 * {@link #eval(Object...)} followed by the SQL type name of the result array's component type.
 */
public class TransformFunctionImpl extends ImplementorUDF {
    /** Creates the UDF backed by a {@link TransformImplementor} and {@link NullPolicy#ANY}. */
    public TransformFunctionImpl() {
        super(new TransformImplementor(), NullPolicy.ANY);
    }

    /**
     * Derives the return type from the lambda passed as the second operand.
     *
     * @return an inference that yields an array type (created with {@code nullable = true}) whose
     *     component type is the lambda body's type with nullability forced to {@code true}
     */
    @Override
    public SqlReturnTypeInference getReturnTypeInference() {
        return sqlOperatorBinding -> {
            RelDataTypeFactory typeFactory = sqlOperatorBinding.getTypeFactory();
            // The binding is cast to RexCallBinding to reach the lambda operand at index 1.
            RexCallBinding rexCallBinding = (RexCallBinding)sqlOperatorBinding;
            List<?> operands = rexCallBinding.operands();
            RelDataType lambdaReturnType = ((RexLambda)operands.get(1)).getExpression().getType();
            RelDataType nullableElementType = typeFactory.createTypeWithNullability(lambdaReturnType, true);
            return SqlTypeUtil.createArrayType(typeFactory, nullableElementType, true);
        };
    }

    /**
     * @return always {@code null}; this function declares no operand metadata
     */
    @Override
    public UDFOperandMetadata getOperandMetadata() {
        return null;
    }

    /**
     * Runtime entry point invoked by the generated code.
     *
     * <p>Expected argument layout:
     * <ul>
     *   <li>{@code args[0]} - the input {@link List};</li>
     *   <li>{@code args[1]} - the lambda, either a {@link Function1} or a {@link Function2};</li>
     *   <li>{@code args[2]} - a captured variable, only read when more than three arguments are
     *       supplied and the lambda is a {@link Function2};</li>
     *   <li>last argument - the {@link SqlTypeName} each lambda result is converted to.</li>
     * </ul>
     *
     * <p>A {@link Function1} receives each element. A {@link Function2} receives each element and
     * either the captured variable or, when none was supplied, the element's zero-based position.
     *
     * @param args the arguments laid out as described above
     * @return a new list holding the converted lambda result for every input element, in order
     * @throws IllegalArgumentException if {@code args[1]} is neither a {@link Function1} nor a
     *     {@link Function2}
     * @throws RuntimeException wrapping any exception raised while iterating the input or
     *     applying/converting the lambda result
     */
    public static Object eval(Object ... args) {
        List<?> target = (List<?>)args[0];
        SqlTypeName returnType = (SqlTypeName)args[args.length - 1];
        Object lambda = args[1];
        boolean hasCapturedVars = args.length > 3;

        // Reject unsupported lambdas up front so this error is never wrapped by the catch below.
        if (!(lambda instanceof Function1) && !(lambda instanceof Function2)) {
            throw new IllegalArgumentException("wrong lambda function input");
        }

        try {
            List<Object> results = new ArrayList<>(target.size());
            if (lambda instanceof Function1) {
                // Function1 is checked first, so it wins if an object implements both interfaces.
                @SuppressWarnings("unchecked")
                Function1<Object, Object> unary = (Function1<Object, Object>)lambda;
                for (Object candidate : target) {
                    results.add(LambdaUtils.transferLambdaOutputToTargetType(unary.apply(candidate), returnType));
                }
                return results;
            }

            @SuppressWarnings("unchecked")
            Function2<Object, Object, Object> binary = (Function2<Object, Object, Object>)lambda;
            if (hasCapturedVars) {
                Object capturedVar = args[2];
                for (Object candidate : target) {
                    results.add(LambdaUtils.transferLambdaOutputToTargetType(binary.apply(candidate, capturedVar), returnType));
                }
            } else {
                // Walk with an iterator and a running index instead of target.get(i), which is
                // linear per call on non-random-access lists.
                int index = 0;
                for (Object candidate : target) {
                    results.add(LambdaUtils.transferLambdaOutputToTargetType(binary.apply(candidate, index), returnType));
                    index++;
                }
            }
            return results;
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Code generator for the transform function: emits a call to
     * {@link TransformFunctionImpl#eval(Object...)}.
     */
    public static class TransformImplementor
    implements NotNullImplementor {
        /**
         * Builds the call expression to {@code eval}.
         *
         * @param translator the translator supplied by the code generator (unused here)
         * @param call the call being implemented; its type is cast to {@link ArraySqlType}
         * @param translatedOperands the already translated operand expressions
         * @return a call to {@code eval} with the operands followed by a constant holding the
         *     SQL type name of the array's component type
         */
        public Expression implement(RexToLixTranslator translator, RexCall call, List<Expression> translatedOperands) {
            ArraySqlType arrayType = (ArraySqlType)call.getType();
            // Copy the operands so the caller's list is not mutated when the type name is appended.
            List<Expression> withReturnTypeList = new ArrayList<>(translatedOperands);
            withReturnTypeList.add(Expressions.constant(arrayType.getComponentType().getSqlTypeName()));
            return Expressions.call(Types.lookupMethod(TransformFunctionImpl.class, "eval", Object[].class), withReturnTypeList);
        }
    }
}

