/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.sql.connector.util;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.StringJoiner;
import org.apache.spark.SparkIllegalArgumentException;
import org.apache.spark.SparkUnsupportedOperationException;
import org.apache.spark.sql.connector.expressions.Cast;
import org.apache.spark.sql.connector.expressions.Expression;
import org.apache.spark.sql.connector.expressions.Extract;
import org.apache.spark.sql.connector.expressions.GeneralScalarExpression;
import org.apache.spark.sql.connector.expressions.GetArrayItem;
import org.apache.spark.sql.connector.expressions.Literal;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.connector.expressions.NullOrdering;
import org.apache.spark.sql.connector.expressions.SortDirection;
import org.apache.spark.sql.connector.expressions.SortOrder;
import org.apache.spark.sql.connector.expressions.UserDefinedScalarFunc;
import org.apache.spark.sql.connector.expressions.aggregate.Avg;
import org.apache.spark.sql.connector.expressions.aggregate.Count;
import org.apache.spark.sql.connector.expressions.aggregate.CountStar;
import org.apache.spark.sql.connector.expressions.aggregate.GeneralAggregateFunc;
import org.apache.spark.sql.connector.expressions.aggregate.Max;
import org.apache.spark.sql.connector.expressions.aggregate.Min;
import org.apache.spark.sql.connector.expressions.aggregate.Sum;
import org.apache.spark.sql.connector.expressions.aggregate.UserDefinedAggregateFunc;
import org.apache.spark.sql.types.DataType;

/**
 * Builds SQL text from a Data Source V2 connector {@link Expression}.
 *
 * <p>{@link #build(Expression)} dispatches on the concrete expression type and delegates to the
 * protected {@code visit*} hooks. Subclasses can override those hooks to change how individual
 * constructs are rendered. In this base implementation, expressions that are not handled are
 * reported through {@link #visitUnexpectedExpr(Expression)}.
 */
public class V2ExpressionSQLBuilder {

    /**
     * Escapes text so that it is matched literally inside a LIKE pattern declared with
     * {@code ESCAPE '\'}.
     *
     * <p>The wildcards {@code _} and {@code %} and the escape character {@code \} itself are each
     * prefixed with a backslash; all other characters are copied unchanged.
     *
     * @param str raw text to embed in a LIKE pattern
     * @return the escaped text
     */
    protected String escapeSpecialCharsForLikePattern(String str) {
        StringBuilder builder = new StringBuilder(str.length());
        for (char c : str.toCharArray()) {
            switch (c) {
                case '_' -> builder.append("\\_");
                case '%' -> builder.append("\\%");
                // The pattern is emitted with ESCAPE '\', so a literal backslash must be doubled;
                // otherwise it would escape whatever character follows it.
                case '\\' -> builder.append("\\\\");
                default -> builder.append(c);
            }
        }
        return builder.toString();
    }

    /**
     * Renders {@code expr} as SQL text.
     *
     * @param expr the expression to translate
     * @return the SQL text for {@code expr}
     * @throws IllegalArgumentException by default (via {@link #visitUnexpectedExpr}) when the
     *     expression kind, or the name of a general scalar expression, is not handled here
     */
    public String build(Expression expr) {
        if (expr instanceof Literal) {
            return visitLiteral((Literal<?>) expr);
        }
        if (expr instanceof NamedReference) {
            return visitNamedReference((NamedReference) expr);
        }
        if (expr instanceof Cast) {
            Cast cast = (Cast) expr;
            return visitCast(build(cast.expression()), cast.expressionDataType(), cast.dataType());
        }
        if (expr instanceof Extract) {
            return visitExtract((Extract) expr);
        }
        if (expr instanceof SortOrder) {
            SortOrder sortOrder = (SortOrder) expr;
            return visitSortOrder(
                build(sortOrder.expression()), sortOrder.direction(), sortOrder.nullOrdering());
        }
        if (expr instanceof GetArrayItem) {
            return visitGetArrayItem((GetArrayItem) expr);
        }
        if (expr instanceof GeneralScalarExpression) {
            return buildGeneralScalarExpression((GeneralScalarExpression) expr);
        }
        if (expr instanceof Min) {
            return visitAggregateFunction("MIN", false, ((Min) expr).children());
        }
        if (expr instanceof Max) {
            return visitAggregateFunction("MAX", false, ((Max) expr).children());
        }
        if (expr instanceof Count) {
            Count count = (Count) expr;
            return visitAggregateFunction("COUNT", count.isDistinct(), count.children());
        }
        if (expr instanceof Sum) {
            Sum sum = (Sum) expr;
            return visitAggregateFunction("SUM", sum.isDistinct(), sum.children());
        }
        if (expr instanceof CountStar) {
            return visitAggregateFunction("COUNT", false, ((CountStar) expr).children());
        }
        if (expr instanceof Avg) {
            Avg avg = (Avg) expr;
            return visitAggregateFunction("AVG", avg.isDistinct(), avg.children());
        }
        if (expr instanceof GeneralAggregateFunc) {
            GeneralAggregateFunc f = (GeneralAggregateFunc) expr;
            if (f.orderingWithinGroups().length == 0) {
                return visitAggregateFunction(f.name(), f.isDistinct(), f.children());
            }
            return visitInverseDistributionFunction(
                f.name(),
                f.isDistinct(),
                expressionsToStringArray(f.children()),
                expressionsToStringArray(f.orderingWithinGroups()));
        }
        if (expr instanceof UserDefinedScalarFunc) {
            UserDefinedScalarFunc f = (UserDefinedScalarFunc) expr;
            return visitUserDefinedScalarFunction(
                f.name(), f.canonicalName(), expressionsToStringArray(f.children()));
        }
        if (expr instanceof UserDefinedAggregateFunc) {
            UserDefinedAggregateFunc f = (UserDefinedAggregateFunc) expr;
            return visitUserDefinedAggregateFunction(
                f.name(), f.canonicalName(), f.isDistinct(), expressionsToStringArray(f.children()));
        }
        return visitUnexpectedExpr(expr);
    }

    /**
     * Renders a {@link GeneralScalarExpression} by dispatching on its {@code name()}.
     * Names that are not listed here are passed to {@link #visitUnexpectedExpr(Expression)}.
     */
    private String buildGeneralScalarExpression(GeneralScalarExpression e) {
        String name = e.name();
        Expression[] children = e.children();
        return switch (name) {
            case "IN" -> {
                // children[0] is the tested value; the rest are the candidate values.
                List<String> values = expressionsToStringList(children, 1, children.length - 1);
                yield visitIn(build(children[0]), values);
            }
            case "IS_NULL" -> visitIsNull(build(children[0]));
            case "IS_NOT_NULL" -> visitIsNotNull(build(children[0]));
            case "STARTS_WITH" -> visitStartsWith(build(children[0]), build(children[1]));
            case "ENDS_WITH" -> visitEndsWith(build(children[0]), build(children[1]));
            case "CONTAINS" -> visitContains(build(children[0]), build(children[1]));
            case "=", "<>", "<=>", "<", "<=", ">", ">=" ->
                visitBinaryComparison(name, children[0], children[1]);
            case "BOOLEAN_EXPRESSION" -> build(children[0]);
            case "+", "*", "/", "%", "&", "|", "^" ->
                visitBinaryArithmetic(name, inputToSQL(children[0]), inputToSQL(children[1]));
            case "-" -> {
                // "-" is either unary negation (one child) or binary subtraction (two children).
                if (children.length == 1) {
                    yield visitUnaryArithmetic(name, inputToSQL(children[0]));
                }
                yield visitBinaryArithmetic(name, inputToSQL(children[0]), inputToSQL(children[1]));
            }
            case "AND" -> visitAnd(name, build(children[0]), build(children[1]));
            case "OR" -> visitOr(name, build(children[0]), build(children[1]));
            case "NOT" -> visitNot(build(children[0]));
            case "~" -> visitUnaryArithmetic(name, inputToSQL(children[0]));
            case "ABS", "COALESCE", "GREATEST", "LEAST", "RAND", "LOG", "LOG10", "LOG2", "LN",
                 "EXP", "POWER", "SQRT", "FLOOR", "CEIL", "ROUND", "SIN", "SINH", "COS", "COSH",
                 "TAN", "TANH", "COT", "ASIN", "ASINH", "ACOS", "ACOSH", "ATAN", "ATANH", "ATAN2",
                 "CBRT", "DEGREES", "RADIANS", "SIGN", "WIDTH_BUCKET", "SUBSTRING", "UPPER",
                 "LOWER", "TRANSLATE", "DATE_ADD", "DATE_DIFF", "TRUNC", "AES_ENCRYPT",
                 "AES_DECRYPT", "SHA1", "SHA2", "MD5", "CRC32", "BIT_LENGTH", "CHAR_LENGTH",
                 "CONCAT", "RPAD", "LPAD" -> visitSQLFunction(name, children);
            case "CASE_WHEN" -> visitCaseWhen(expressionsToStringArray(children));
            case "TRIM" -> visitTrim("BOTH", expressionsToStringArray(children));
            case "LTRIM" -> visitTrim("LEADING", expressionsToStringArray(children));
            case "RTRIM" -> visitTrim("TRAILING", expressionsToStringArray(children));
            case "OVERLAY" -> visitOverlay(expressionsToStringArray(children));
            default -> visitUnexpectedExpr(e);
        };
    }

    /** Renders a literal using its {@code toString()}. */
    protected String visitLiteral(Literal<?> literal) {
        return literal.toString();
    }

    /** Renders a column/field reference using its {@code toString()}. */
    protected String visitNamedReference(NamedReference namedRef) {
        return namedRef.toString();
    }

    /**
     * Renders {@code v IN (list...)}.
     *
     * <p>An empty list has no valid {@code IN ()} form, so it is rendered as a CASE expression
     * that yields NULL when {@code v} is NULL and FALSE otherwise.
     */
    protected String visitIn(String v, List<String> list) {
        if (list.isEmpty()) {
            return "CASE WHEN " + v + " IS NULL THEN NULL ELSE FALSE END";
        }
        return joinListToString(list, ", ", v + " IN (", ")");
    }

    /** Renders {@code v IS NULL}. */
    protected String visitIsNull(String v) {
        return v + " IS NULL";
    }

    /** Renders {@code v IS NOT NULL}. */
    protected String visitIsNotNull(String v) {
        return v + " IS NOT NULL";
    }

    /**
     * Renders a prefix match as {@code l LIKE '<r>%' ESCAPE '\'}.
     *
     * @param l SQL for the string being tested
     * @param r SQL for the prefix; assumed to be a quoted string literal whose first and last
     *     characters are the quotes (not validated)
     */
    protected String visitStartsWith(String l, String r) {
        return l + " LIKE '" + escapedLikeBody(r) + "%' ESCAPE '\\'";
    }

    /**
     * Renders a suffix match as {@code l LIKE '%<r>' ESCAPE '\'}.
     *
     * @param l SQL for the string being tested
     * @param r SQL for the suffix; assumed to be a quoted string literal (not validated)
     */
    protected String visitEndsWith(String l, String r) {
        return l + " LIKE '%" + escapedLikeBody(r) + "' ESCAPE '\\'";
    }

    /**
     * Renders a substring match as {@code l LIKE '%<r>%' ESCAPE '\'}.
     *
     * @param l SQL for the string being tested
     * @param r SQL for the substring; assumed to be a quoted string literal (not validated)
     */
    protected String visitContains(String l, String r) {
        return l + " LIKE '%" + escapedLikeBody(r) + "%' ESCAPE '\\'";
    }

    /**
     * Drops the first and last characters of {@code quoted} (assumed to be the enclosing quotes of
     * a string literal) and escapes the remainder for use inside a LIKE pattern.
     */
    private String escapedLikeBody(String quoted) {
        String value = quoted.substring(1, quoted.length() - 1);
        return escapeSpecialCharsForLikePattern(value);
    }

    /**
     * Renders an operand of an arithmetic/comparison operator, wrapping it in parentheses when it
     * has more than one child so that operator precedence is preserved.
     */
    protected String inputToSQL(Expression input) {
        String sql = build(input);
        return input.children().length > 1 ? "(" + sql + ")" : sql;
    }

    /** Renders both operands with {@link #inputToSQL} and delegates to the string overload. */
    protected String visitBinaryComparison(String name, Expression le, Expression re) {
        return visitBinaryComparison(name, inputToSQL(le), inputToSQL(re));
    }

    /**
     * Renders {@code l <name> r}. The null-safe equality operator {@code <=>} is expanded into
     * an equivalent expression built from {@code IS NULL}, {@code IS NOT NULL} and {@code =}.
     */
    protected String visitBinaryComparison(String name, String l, String r) {
        if (name.equals("<=>")) {
            return "((" + l + " IS NOT NULL AND " + r + " IS NOT NULL AND " + l + " = " + r
                + ") OR (" + l + " IS NULL AND " + r + " IS NULL))";
        }
        return l + " " + name + " " + r;
    }

    /** Renders {@code l <name> r}. */
    protected String visitBinaryArithmetic(String name, String l, String r) {
        return l + " " + name + " " + r;
    }

    /**
     * Renders {@code CAST(expr AS <type>)} using {@link DataType#typeName()} of the target type.
     * {@code exprDataType} is not used by this default implementation.
     */
    protected String visitCast(String expr, DataType exprDataType, DataType targetDataType) {
        return "CAST(" + expr + " AS " + targetDataType.typeName() + ")";
    }

    /** Renders {@code (l) <name> (r)}. */
    protected String visitAnd(String name, String l, String r) {
        return "(" + l + ") " + name + " (" + r + ")";
    }

    /** Renders {@code (l) <name> (r)}. */
    protected String visitOr(String name, String l, String r) {
        return "(" + l + ") " + name + " (" + r + ")";
    }

    /** Renders {@code NOT (v)}. */
    protected String visitNot(String v) {
        return "NOT (" + v + ")";
    }

    /**
     * Renders a prefix unary operator applied to {@code v}.
     *
     * <p>If the operator ends with {@code -} and the operand starts with {@code -} (for example a
     * negative literal or a nested negation), plain concatenation would produce {@code --}. SQL
     * parses that as the start of a line comment, so the operand is parenthesized instead.
     */
    protected String visitUnaryArithmetic(String name, String v) {
        if (name.endsWith("-") && v.startsWith("-")) {
            return name + "(" + v + ")";
        }
        return name + v;
    }

    /**
     * Renders a searched CASE expression.
     *
     * @param children alternating condition/value pairs; a trailing unpaired element becomes the
     *     ELSE branch
     */
    protected String visitCaseWhen(String[] children) {
        StringBuilder sb = new StringBuilder("CASE");
        int i = 0;
        for (; i + 1 < children.length; i += 2) {
            sb.append(" WHEN ").append(children[i]).append(" THEN ").append(children[i + 1]);
        }
        if (i < children.length) {
            sb.append(" ELSE ").append(children[i]);
        }
        sb.append(" END");
        return sb.toString();
    }

    /** Builds each input and renders {@code funcName(in1, in2, ...)}. */
    protected String visitSQLFunction(String funcName, Expression[] inputs) {
        return visitSQLFunction(funcName, expressionsToStringArray(inputs));
    }

    /** Renders {@code funcName(in1, in2, ...)}. */
    protected String visitSQLFunction(String funcName, String[] inputs) {
        return joinArrayToString(inputs, ", ", funcName + "(", ")");
    }

    /**
     * Builds each input and renders an aggregate call. A {@code COUNT} with no inputs is
     * rendered as {@code COUNT(*)}.
     */
    protected String visitAggregateFunction(
            String funcName, boolean isDistinct, Expression[] inputs) {
        if (funcName.equals("COUNT") && inputs.length == 0) {
            return visitAggregateFunction(funcName, isDistinct, new String[] {"*"});
        }
        return visitAggregateFunction(funcName, isDistinct, expressionsToStringArray(inputs));
    }

    /** Renders {@code funcName([DISTINCT ]in1, in2, ...)}. */
    protected String visitAggregateFunction(String funcName, boolean isDistinct, String[] inputs) {
        String prefix = isDistinct ? funcName + "(DISTINCT " : funcName + "(";
        return joinArrayToString(inputs, ", ", prefix, ")");
    }

    /**
     * Renders {@code funcName(inputs...) WITHIN GROUP (ORDER BY orderings...)}.
     * DISTINCT is not supported here; it is guarded only by an {@code assert}.
     */
    protected String visitInverseDistributionFunction(
            String funcName, boolean isDistinct, String[] inputs, String[] orderingWithinGroups) {
        assert !isDistinct;
        String withinGroup =
            joinArrayToString(orderingWithinGroups, ", ", "WITHIN GROUP (ORDER BY ", ")");
        String functionCall = joinArrayToString(inputs, ", ", funcName + "(", ")");
        return functionCall + " " + withinGroup;
    }

    /**
     * Not supported by default.
     *
     * @throws SparkUnsupportedOperationException always
     */
    protected String visitUserDefinedScalarFunction(
            String funcName, String canonicalName, String[] inputs) {
        throw new SparkUnsupportedOperationException(
            "_LEGACY_ERROR_TEMP_3141",
            Map.of("class", getClass().getSimpleName(), "funcName", funcName));
    }

    /**
     * Not supported by default.
     *
     * @throws SparkUnsupportedOperationException always
     */
    protected String visitUserDefinedAggregateFunction(
            String funcName, String canonicalName, boolean isDistinct, String[] inputs) {
        throw new SparkUnsupportedOperationException(
            "_LEGACY_ERROR_TEMP_3142",
            Map.of("class", getClass().getSimpleName(), "funcName", funcName));
    }

    /**
     * Called for expressions this builder cannot render.
     *
     * @throws SparkIllegalArgumentException always, in this default implementation
     */
    protected String visitUnexpectedExpr(Expression expr) throws IllegalArgumentException {
        throw new SparkIllegalArgumentException(
            "_LEGACY_ERROR_TEMP_3207", Map.of("expr", String.valueOf(expr)));
    }

    /**
     * Renders {@code OVERLAY(input PLACING replace FROM pos [FOR len])} from 3 or 4 inputs
     * (the arity is only checked by an {@code assert}).
     */
    protected String visitOverlay(String[] inputs) {
        assert inputs.length == 3 || inputs.length == 4;
        String head = "OVERLAY(" + inputs[0] + " PLACING " + inputs[1] + " FROM " + inputs[2];
        return inputs.length == 3 ? head + ")" : head + " FOR " + inputs[3] + ")";
    }

    /**
     * Renders {@code TRIM(direction [chars] FROM source)}.
     *
     * @param direction one of the keywords passed by {@link #build} (BOTH, LEADING, TRAILING)
     * @param inputs {@code [source]} or {@code [source, chars]}
     */
    protected String visitTrim(String direction, String[] inputs) {
        assert inputs.length == 1 || inputs.length == 2;
        if (inputs.length == 1) {
            return "TRIM(" + direction + " FROM " + inputs[0] + ")";
        }
        return "TRIM(" + direction + " " + inputs[1] + " FROM " + inputs[0] + ")";
    }

    /**
     * Not supported by default.
     *
     * @throws SparkUnsupportedOperationException always
     */
    protected String visitGetArrayItem(GetArrayItem getArrayItem) {
        throw new SparkUnsupportedOperationException(
            "EXPRESSION_TRANSLATION_TO_V2_IS_NOT_SUPPORTED",
            Map.of("expr", getArrayItem.toString()));
    }

    /** Builds the source expression and delegates to {@link #visitExtract(String, String)}. */
    protected String visitExtract(Extract extract) {
        return visitExtract(extract.field(), build(extract.source()));
    }

    /** Renders {@code EXTRACT(field FROM source)}. */
    protected String visitExtract(String field, String source) {
        return "EXTRACT(" + field + " FROM " + source + ")";
    }

    /** Renders {@code sortKey <direction> <nullOrdering>} using their {@code toString()} forms. */
    protected String visitSortOrder(
            String sortKey, SortDirection sortDirection, NullOrdering nullOrdering) {
        return sortKey + " " + sortDirection + " " + nullOrdering;
    }

    /** Joins {@code inputs} with {@code delimiter}, wrapped in {@code prefix}/{@code suffix}. */
    private String joinArrayToString(
            String[] inputs, CharSequence delimiter, CharSequence prefix, CharSequence suffix) {
        StringJoiner joiner = new StringJoiner(delimiter, prefix, suffix);
        for (String input : inputs) {
            joiner.add(input);
        }
        return joiner.toString();
    }

    /** Joins {@code inputs} with {@code delimiter}, wrapped in {@code prefix}/{@code suffix}. */
    private String joinListToString(
            List<String> inputs, CharSequence delimiter, CharSequence prefix, CharSequence suffix) {
        StringJoiner joiner = new StringJoiner(delimiter, prefix, suffix);
        for (String input : inputs) {
            joiner.add(input);
        }
        return joiner.toString();
    }

    /** Builds every expression, preserving order. */
    protected String[] expressionsToStringArray(Expression[] expressions) {
        String[] result = new String[expressions.length];
        for (int i = 0; i < expressions.length; i++) {
            result[i] = build(expressions[i]);
        }
        return result;
    }

    /**
     * Builds up to {@code length} expressions starting at {@code offset}, stopping early at the
     * end of the array.
     */
    private List<String> expressionsToStringList(Expression[] expressions, int offset, int length) {
        List<String> list = new ArrayList<>(length);
        int end = Math.min(offset + length, expressions.length);
        for (int i = offset; i < end; i++) {
            list.add(build(expressions[i]));
        }
        return list;
    }
}

