/*
 * Decompiled with CFR 0.152.
 */
package org.apache.iotdb.db.queryengine.plan.relational.planner.ir;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.DeterminismEvaluator;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.ExpressionRewriter;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.ExpressionTreeRewriter;
import org.apache.iotdb.db.queryengine.plan.relational.planner.ir.IrUtils;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.LogicalExpression;

public final class ExtractCommonPredicatesExpressionRewriter {
    public static Expression extractCommonPredicates(Expression expression) {
        return ExpressionTreeRewriter.rewriteWith(new Visitor(), expression, NodeContext.ROOT_NODE);
    }

    private ExtractCommonPredicatesExpressionRewriter() {
    }

    private static class Visitor
    extends ExpressionRewriter<NodeContext> {
        private Visitor() {
        }

        @Override
        protected Expression rewriteExpression(Expression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter) {
            if (context.isRootNode()) {
                return treeRewriter.rewrite(node, NodeContext.NOT_ROOT_NODE);
            }
            return null;
        }

        @Override
        public Expression rewriteLogicalExpression(LogicalExpression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter) {
            Expression expression = IrUtils.combinePredicates(node.getOperator(), (Collection)IrUtils.extractPredicates(node.getOperator(), node).stream().map(subExpression -> treeRewriter.rewrite(subExpression, NodeContext.NOT_ROOT_NODE)).collect(ImmutableList.toImmutableList()));
            if (!(expression instanceof LogicalExpression)) {
                return expression;
            }
            Expression simplified = this.extractCommonPredicates((LogicalExpression)expression);
            if (context.isRootNode() && simplified instanceof LogicalExpression && ((LogicalExpression)simplified).getOperator() == LogicalExpression.Operator.OR) {
                return this.distributeIfPossible((LogicalExpression)simplified);
            }
            return simplified;
        }

        private Expression extractCommonPredicates(LogicalExpression node) {
            List<List<Expression>> subPredicates = Visitor.getSubPredicates(node);
            ImmutableSet commonPredicates = ImmutableSet.copyOf((Collection)subPredicates.stream().map(this::filterDeterministicPredicates).reduce(Sets::intersection).orElse(Collections.emptySet()));
            List uncorrelatedSubPredicates = (List)subPredicates.stream().map(arg_0 -> Visitor.lambda$extractCommonPredicates$1((Set)commonPredicates, arg_0)).collect(ImmutableList.toImmutableList());
            LogicalExpression.Operator flippedOperator = node.getOperator().flip();
            List uncorrelatedPredicates = (List)uncorrelatedSubPredicates.stream().map(predicate -> IrUtils.combinePredicates(flippedOperator, predicate)).collect(ImmutableList.toImmutableList());
            Expression combinedUncorrelatedPredicates = IrUtils.combinePredicates(node.getOperator(), uncorrelatedPredicates);
            return IrUtils.combinePredicates(flippedOperator, (Collection<Expression>)ImmutableList.builder().addAll((Iterable)commonPredicates).add((Object)combinedUncorrelatedPredicates).build());
        }

        private static List<List<Expression>> getSubPredicates(LogicalExpression expression) {
            return (List)IrUtils.extractPredicates(expression.getOperator(), expression).stream().map(predicate -> predicate instanceof LogicalExpression ? IrUtils.extractPredicates((LogicalExpression)predicate) : ImmutableList.of((Object)predicate)).collect(ImmutableList.toImmutableList());
        }

        private Expression distributeIfPossible(LogicalExpression expression) {
            int newBaseExpressions;
            if (!DeterminismEvaluator.isDeterministic(expression)) {
                return expression;
            }
            List subPredicates = Visitor.getSubPredicates(expression).stream().map(ImmutableSet::copyOf).collect(Collectors.toList());
            int originalBaseExpressions = subPredicates.stream().mapToInt(Set::size).sum();
            try {
                newBaseExpressions = Math.multiplyExact(subPredicates.stream().mapToInt(Set::size).reduce(Math::multiplyExact).getAsInt(), subPredicates.size());
            }
            catch (ArithmeticException e) {
                return expression;
            }
            if (newBaseExpressions > originalBaseExpressions * 2) {
                return expression;
            }
            Set crossProduct = Sets.cartesianProduct(subPredicates);
            return IrUtils.combinePredicates(expression.getOperator().flip(), (Collection)crossProduct.stream().map(expressions -> IrUtils.combinePredicates(expression.getOperator(), expressions)).collect(ImmutableList.toImmutableList()));
        }

        private Set<Expression> filterDeterministicPredicates(List<Expression> predicates) {
            return predicates.stream().filter(DeterminismEvaluator::isDeterministic).collect(Collectors.toSet());
        }

        private static <T> List<T> removeAll(Collection<T> collection, Collection<T> elementsToRemove) {
            return (List)collection.stream().filter(element -> !elementsToRemove.contains(element)).collect(ImmutableList.toImmutableList());
        }

        private static /* synthetic */ List lambda$extractCommonPredicates$1(Set commonPredicates, List predicateList) {
            return Visitor.removeAll(predicateList, commonPredicates);
        }
    }

    private static enum NodeContext {
        ROOT_NODE,
        NOT_ROOT_NODE;


        boolean isRootNode() {
            return this == ROOT_NODE;
        }
    }
}

