/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.wlm;

import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.TransportInterceptor;
import org.opensearch.transport.TransportRequest;
import org.opensearch.transport.TransportRequestHandler;
import org.opensearch.wlm.WorkloadGroupService;
import org.opensearch.wlm.WorkloadGroupTask;

/**
 * Transport interceptor that wraps every registered request handler in a
 * {@link RequestHandler}, so that tasks of type {@link WorkloadGroupTask} are
 * tagged with a workload group id and passed through
 * {@link WorkloadGroupService#rejectIfNeeded} before the real handler runs.
 */
public class WorkloadManagementTransportInterceptor implements TransportInterceptor {
    private final ThreadPool threadPool;
    private final WorkloadGroupService workloadGroupService;

    /**
     * @param threadPool           source of the thread context handed to the task
     * @param workloadGroupService service consulted before a workload-group task is handled
     */
    public WorkloadManagementTransportInterceptor(ThreadPool threadPool, WorkloadGroupService workloadGroupService) {
        this.threadPool = threadPool;
        this.workloadGroupService = workloadGroupService;
    }

    /**
     * Wraps {@code actualHandler} in a {@link RequestHandler}. The {@code action},
     * {@code executor} and {@code forceExecution} arguments are not consulted;
     * every handler is wrapped the same way.
     *
     * @return a new wrapping handler that delegates to {@code actualHandler}
     */
    @Override
    public <T extends TransportRequest> TransportRequestHandler<T> interceptHandler(
        String action,
        String executor,
        boolean forceExecution,
        TransportRequestHandler<T> actualHandler
    ) {
        return new RequestHandler<T>(threadPool, actualHandler, workloadGroupService);
    }

    /**
     * Handler decorator that performs the workload-group bookkeeping for
     * {@link WorkloadGroupTask}s and then forwards the message to the wrapped handler.
     *
     * @param <T> the transport request type handled by the wrapped handler
     */
    public static class RequestHandler<T extends TransportRequest> implements TransportRequestHandler<T> {
        private final ThreadPool threadPool;
        TransportRequestHandler<T> actualHandler;
        private final WorkloadGroupService workloadGroupService;

        /**
         * @param threadPool           source of the thread context handed to the task
         * @param actualHandler        handler that ultimately processes the request
         * @param workloadGroupService service consulted before a workload-group task is handled
         */
        public RequestHandler(
            ThreadPool threadPool,
            TransportRequestHandler<T> actualHandler,
            WorkloadGroupService workloadGroupService
        ) {
            this.threadPool = threadPool;
            this.actualHandler = actualHandler;
            this.workloadGroupService = workloadGroupService;
        }

        /**
         * Runs the workload-group admission step when
         * {@link #isSearchWorkloadRequest(Task)} accepts the task, then delegates
         * to the wrapped handler.
         *
         * @throws Exception anything raised by the admission step or by the wrapped handler;
         *                   if the admission step throws, the wrapped handler is not invoked
         */
        @Override
        public void messageReceived(T request, TransportChannel channel, Task task) throws Exception {
            if (isSearchWorkloadRequest(task)) {
                admit((WorkloadGroupTask) task);
            }
            actualHandler.messageReceived(request, channel, task);
        }

        /**
         * Stamps the task with a workload group id taken from the current thread
         * context, then hands that id to the service's reject-if-needed check.
         */
        private void admit(WorkloadGroupTask workloadTask) {
            workloadTask.setWorkloadGroupId(threadPool.getThreadContext());
            // Read the id back from the task rather than the context: the task decides the final value.
            final String groupId = workloadTask.getWorkloadGroupId();
            workloadGroupService.rejectIfNeeded(groupId);
        }

        /**
         * @return {@code true} when {@code task} is a {@link WorkloadGroupTask}
         */
        boolean isSearchWorkloadRequest(Task task) {
            return task instanceof WorkloadGroupTask;
        }
    }
}

