/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.lib;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockMCSR;
import org.apache.sysds.runtime.functionobjects.SwapIndex;
import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Reorganization operations on compressed matrix blocks.
 *
 * <p>Transposes ({@link SwapIndex}) are handled with compression-aware code paths. Each column group
 * decompresses itself directly into a transposed output block. Every other reorg operator falls back to
 * decompressing the input and delegating to the uncompressed {@link MatrixBlock} implementation.
 */
public class CLALibReorg {
    protected static final Log LOG = LogFactory.getLog(CLALibReorg.class.getName());

    /**
     * Set after the decompression fallback in {@link #reorg} has run once.
     *
     * <p>Only the first fallback passes a message to {@code getUncompressed}. This is a plain
     * (non-volatile, unsynchronized) static field, so concurrent first calls may both observe {@code false}.
     */
    public static boolean warned = false;

    /**
     * Applies a reorg operation to a compressed matrix.
     *
     * @param cmb         the compressed input matrix
     * @param op          the reorg operator; a {@link SwapIndex} function is treated as transpose
     * @param ret         optional output block to reuse (may be {@code null}); not used on the
     *                    single-column transpose path
     * @param startRow    forwarded to the uncompressed {@code reorgOperations}
     * @param startColumn forwarded to the uncompressed {@code reorgOperations}
     * @param length      forwarded to the uncompressed {@code reorgOperations}
     * @return the reorganized result as an uncompressed {@link MatrixBlock}
     */
    public static MatrixBlock reorg(CompressedMatrixBlock cmb, ReorgOperator op, MatrixBlock ret, int startRow, int startColumn, int length) {
        if (op.fn instanceof SwapIndex && cmb.getNumColumns() == 1) {
            // Transpose of a single column: decompress once, then reinterpret the dense value array
            // as a single row (a column vector and a row vector share the same linear layout).
            MatrixBlock tmp = cmb.decompress(op.getNumThreads());
            final long nz = tmp.getNonZeros();
            if (tmp.isInSparseFormat()) {
                return LibMatrixReorg.transpose(tmp);
            }
            tmp = new MatrixBlock(tmp.getNumColumns(), tmp.getNumRows(), tmp.getDenseBlockValues());
            // The freshly constructed block does not carry the count over, so set it explicitly.
            tmp.setNonZeros(nz);
            return tmp;
        }
        if (op.fn instanceof SwapIndex) {
            // Reuse an already-decompressed copy when one is cached; otherwise transpose group by group.
            final MatrixBlock cached = cmb.getCachedDecompressed();
            if (cached != null) {
                return cached.reorgOperations(op, ret, startRow, startColumn, length);
            }
            return transpose(cmb, ret, op.getNumThreads());
        }
        // Generic fallback: decompress and use the uncompressed implementation. The message is only
        // supplied on the first fallback.
        final String message = !warned ? op.getClass().getSimpleName() + " -- " + op.fn.getClass().getSimpleName() : null;
        final MatrixBlock tmp = cmb.getUncompressed(message, op.getNumThreads());
        warned = true;
        return tmp.reorgOperations(op, ret, startRow, startColumn, length);
    }

    /**
     * Transposes a compressed matrix into an uncompressed block.
     *
     * <p>Sparse or dense output is chosen from the output shape.
     */
    private static MatrixBlock transpose(CompressedMatrixBlock cmb, MatrixBlock ret, int k) {
        final long nnz = cmb.getNonZeros();
        final int nRow = cmb.getNumRows();
        final int nCol = cmb.getNumColumns();
        // The output has nCol rows and nRow columns, so evaluate sparsity on the transposed shape.
        final boolean sparseOut = MatrixBlock.evalSparseFormatInMemory(nCol, nRow, nnz);
        return sparseOut
            ? transposeSparse(cmb, ret, k, nRow, nCol, nnz)
            : transposeDense(cmb, ret, k, nRow, nCol, nnz);
    }

    /**
     * Transposes into an MCSR sparse output block, in parallel over column groups when {@code k > 1}.
     */
    private static MatrixBlock transposeSparse(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol, long nnz) {
        if (ret == null) {
            ret = new MatrixBlock(nCol, nRow, true, nnz);
        } else {
            ret.reset(nCol, nRow, true, nnz);
        }
        ret.allocateAndResetSparseBlock(true, SparseBlock.Type.MCSR);
        final int nColOut = ret.getNumColumns();
        final SparseBlockMCSR out = (SparseBlockMCSR) ret.getSparseBlock();
        if (k > 1 && cmb.getColGroups().size() > 1) {
            decompressToTransposedSparseParallel(out, cmb.getColGroups(), nColOut, k);
        } else {
            decompressToTransposedSparseSingleThread(out, cmb.getColGroups(), nColOut);
        }
        // The groups write into the raw sparse block, bypassing ret's non-zero bookkeeping (which
        // construction/reset leaves at its initial state), so recompute it from the filled block.
        ret.recomputeNonZeros();
        return ret;
    }

    /**
     * Transposes into a dense output block.
     *
     * <p>Decompression here is currently single-threaded, so {@code k} is not used.
     */
    private static MatrixBlock transposeDense(CompressedMatrixBlock cmb, MatrixBlock ret, int k, int nRow, int nCol, long nnz) {
        if (ret == null) {
            ret = new MatrixBlock(nCol, nRow, false, nnz);
        } else {
            ret.reset(nCol, nRow, false, nnz);
        }
        ret.allocateDenseBlock();
        decompressToTransposedDense(ret.getDenseBlock(), cmb.getColGroups(), 0, nRow);
        // As in the sparse path, the groups write into the raw dense block, so refresh the count.
        ret.recomputeNonZeros();
        return ret;
    }

    /** Lets each column group write its input rows [rl, ru) transposed into {@code ret}. */
    private static void decompressToTransposedDense(DenseBlock ret, List<AColGroup> groups, int rl, int ru) {
        for (AColGroup g : groups) {
            g.decompressToDenseBlockTransposed(ret, rl, ru);
        }
    }

    /** Lets each column group decompress itself, transposed, into {@code ret}, one group at a time. */
    private static void decompressToTransposedSparseSingleThread(SparseBlockMCSR ret, List<AColGroup> groups, int nColOut) {
        for (AColGroup g : groups) {
            g.decompressToSparseBlockTransposed(ret, nColOut);
        }
    }

    /**
     * Submits one task per column group to a pool of {@code k} threads and waits for all of them.
     *
     * <p>NOTE(review): this relies on the groups writing disjoint output rows (their own columns,
     * transposed). That presumably holds only for non-overlapping column groups; confirm.
     *
     * @throws DMLCompressionException if any task fails or the waiting thread is interrupted
     */
    private static void decompressToTransposedSparseParallel(SparseBlockMCSR ret, List<AColGroup> groups, int nColOut, int k) {
        final ExecutorService pool = CommonThreadPool.get(k);
        try {
            final List<Future<?>> tasks = new ArrayList<>(groups.size());
            for (AColGroup g : groups) {
                tasks.add(pool.submit(() -> g.decompressToSparseBlockTransposed(ret, nColOut)));
            }
            for (Future<?> task : tasks) {
                task.get();
            }
        }
        catch (InterruptedException e) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLCompressionException("Failed to parallel decompress transpose sparse", e);
        }
        catch (Exception e) {
            throw new DMLCompressionException("Failed to parallel decompress transpose sparse", e);
        }
        finally {
            pool.shutdown();
        }
    }
}

