/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv.dp;

import java.util.List;
import java.util.concurrent.Future;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;

public class BalanceToAvgFederatedScheme
extends DataPartitionFederatedScheme {
    @Override
    public DataPartitionFederatedScheme.Result partition(MatrixObject features, MatrixObject labels, int seed) {
        List<MatrixObject> pFeatures = BalanceToAvgFederatedScheme.sliceFederatedMatrix(features);
        List<MatrixObject> pLabels = BalanceToAvgFederatedScheme.sliceFederatedMatrix(labels);
        DataPartitionFederatedScheme.BalanceMetrics balanceMetricsBefore = BalanceToAvgFederatedScheme.getBalanceMetrics(pFeatures);
        List<Double> weightingFactors = BalanceToAvgFederatedScheme.getWeightingFactors(pFeatures, balanceMetricsBefore);
        int average_num_rows = (int)balanceMetricsBefore._avgRows;
        for (int i = 0; i < pFeatures.size(); ++i) {
            FederatedData featuresData = pFeatures.get(i).getFedMapping().getFederatedData()[0];
            FederatedData labelsData = pLabels.get(i).getFedMapping().getFederatedData()[0];
            Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, featuresData.getVarID(), new balanceDataOnFederatedWorker(new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, average_num_rows)));
            try {
                FederatedResponse response = udfResponse.get();
                if (!response.isSuccessful()) {
                    throw new DMLRuntimeException("FederatedDataPartitioner BalanceFederatedScheme: balance UDF returned fail");
                }
            }
            catch (Exception e) {
                throw new DMLRuntimeException("FederatedDataPartitioner BalanceFederatedScheme: executing balance UDF failed" + e.getMessage());
            }
            DataCharacteristics update = pFeatures.get(i).getDataCharacteristics().setRows(average_num_rows);
            pFeatures.get(i).updateDataCharacteristics(update);
            update = pLabels.get(i).getDataCharacteristics().setRows(average_num_rows);
            pLabels.get(i).updateDataCharacteristics(update);
        }
        return new DataPartitionFederatedScheme.Result(pFeatures, pLabels, pFeatures.size(), BalanceToAvgFederatedScheme.getBalanceMetrics(pFeatures), weightingFactors);
    }

    private static class balanceDataOnFederatedWorker
    extends FederatedUDF {
        private static final long serialVersionUID = 6631958250346625546L;
        private final int _seed;
        private final int _average_num_rows;

        protected balanceDataOnFederatedWorker(long[] inIDs, int seed, int average_num_rows) {
            super(inIDs);
            this._seed = seed;
            this._average_num_rows = average_num_rows;
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            MatrixObject features = (MatrixObject)data[0];
            MatrixObject labels = (MatrixObject)data[1];
            if (features.getNumRows() > (long)this._average_num_rows) {
                MatrixBlock subsampleMatrixBlock = ParamservUtils.generateSubsampleMatrix(this._average_num_rows, Math.toIntExact(features.getNumRows()), this._seed);
                DataPartitionFederatedScheme.subsampleTo(features, subsampleMatrixBlock);
                DataPartitionFederatedScheme.subsampleTo(labels, subsampleMatrixBlock);
            } else if (features.getNumRows() < (long)this._average_num_rows) {
                int num_rows_needed = this._average_num_rows - Math.toIntExact(features.getNumRows());
                MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(num_rows_needed, Math.toIntExact(features.getNumRows()), this._seed);
                DataPartitionFederatedScheme.replicateTo(features, replicateMatrixBlock);
                DataPartitionFederatedScheme.replicateTo(labels, replicateMatrixBlock);
            }
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
        }

        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }
}

