/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.federated;

import java.io.Serializable;
import java.net.InetSocketAddress;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Future;
import javax.net.ssl.SSLException;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.utils.Statistics;

/**
 * Gathers runtime statistics from registered federated workers and formats them as text.
 *
 * <p>Worker addresses are recorded with {@link #registerFedWorker(String, int)}.
 * {@link #displayFedStatistics(int)} sends a {@link FedStatsCollectFunction} UDF to every
 * registered worker, sums the returned {@link FedStatsCollection}s and renders cache, JIT,
 * GC and heavy-hitter statistics.
 */
public class FederatedStatistics {
    /** Registered (host, port) pairs; a plain {@link HashSet}, no synchronization is applied here. */
    private static final Set<Pair<String, Integer>> _fedWorkerAddresses = new HashSet<Pair<String, Integer>>();

    /**
     * Records a federated worker address so that its statistics are queried later.
     * Duplicate (host, port) pairs are collapsed by the backing set.
     *
     * @param host worker host name or address
     * @param port worker port
     */
    public static void registerFedWorker(String host, int port) {
        _fedWorkerAddresses.add(new ImmutablePair<String, Integer>(host, port));
    }

    /**
     * Lists every registered worker as {@code host:port}, one per line, in the
     * (unspecified) iteration order of the backing hash set.
     *
     * @return the formatted address list, starting with a header line
     */
    public static String displayFedWorkers() {
        StringBuilder sb = new StringBuilder();
        sb.append("Federated Worker Addresses:\n");
        for (Pair<String, Integer> fedAddr : _fedWorkerAddresses) {
            sb.append(String.format("  %s:%d", fedAddr.getLeft(), fedAddr.getRight()));
            sb.append("\n");
        }
        return sb.toString();
    }

    /**
     * Queries all registered workers, aggregates their statistics and formats the result.
     *
     * @param numHeavyHitters maximum number of heavy-hitter instructions to list
     * @return the formatted federated statistics
     * @throws DMLRuntimeException if a worker response cannot be retrieved
     */
    public static String displayFedStatistics(int numHeavyHitters) {
        StringBuilder sb = new StringBuilder();
        FedStatsCollection fedStats = FederatedStatistics.collectFedStats();
        sb.append("SystemDS Federated Statistics:\n");
        sb.append(FederatedStatistics.displayCacheStats(fedStats.cacheStats));
        sb.append(String.format("Total JIT compile time:\t\t%.3f sec.\n", fedStats.jitCompileTime));
        sb.append(FederatedStatistics.displayGCStats(fedStats.gcStats));
        sb.append(FederatedStatistics.displayHeavyHitters(fedStats.heavyHitters, numHeavyHitters));
        return sb.toString();
    }

    /**
     * Formats aggregated cache hit, write and timing counters (three lines).
     *
     * @param csc the cache statistics to render
     * @return the formatted cache statistics
     */
    public static String displayCacheStats(FedStatsCollection.CacheStatsCollection csc) {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("Cache hits (Mem/Li/WB/FS/HDFS):\t%d/%d/%d/%d/%d.\n", csc.memHits, csc.linHits, csc.fsBuffHits, csc.fsHits, csc.hdfsHits));
        sb.append(String.format("Cache writes (Li/WB/FS/HDFS):\t%d/%d/%d/%d.\n", csc.linWrites, csc.fsBuffWrites, csc.fsWrites, csc.hdfsWrites));
        sb.append(String.format("Cache times (ACQr/m, RLS, EXP):\t%.3f/%.3f/%.3f/%.3f sec.\n", csc.acqRTime, csc.acqMTime, csc.rlsTime, csc.expTime));
        return sb.toString();
    }

    /**
     * Formats aggregated JVM garbage-collection count and time (two lines).
     *
     * @param gcsc the GC statistics to render
     * @return the formatted GC statistics
     */
    public static String displayGCStats(FedStatsCollection.GCStatsCollection gcsc) {
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("Total JVM GC count:\t\t%d.\n", gcsc.gcCount));
        sb.append(String.format("Total JVM GC time:\t\t%.3f sec.\n", gcsc.gcTime));
        return sb.toString();
    }

    /**
     * Formats the top 10 heavy-hitter instructions.
     *
     * @param heavyHitters map from instruction to (count, time)
     * @return the formatted heavy-hitter table
     */
    public static String displayHeavyHitters(HashMap<String, Pair<Long, Double>> heavyHitters) {
        return FederatedStatistics.displayHeavyHitters(heavyHitters, 10);
    }

    /**
     * Formats the {@code num} most expensive instructions as an aligned table with the
     * columns {@code #}, {@code Instruction}, {@code Time(s)} and {@code Count}. Long
     * instructions are wrapped to at most {@code DMLScript.STATISTICS_MAX_WRAP_LEN} characters.
     *
     * @param heavyHitters map from instruction string to a pair whose left value is printed
     *                     as the count and whose right value as the time; {@code null} is
     *                     treated as empty
     * @param num          maximum number of instructions to list; values below 0 list none
     * @return the formatted heavy-hitter table
     */
    public static String displayHeavyHitters(HashMap<String, Pair<Long, Double>> heavyHitters, int num) {
        ArrayList<Map.Entry<String, Pair<Long, Double>>> entries = heavyHitters == null
            ? new ArrayList<Map.Entry<String, Pair<Long, Double>>>()
            : new ArrayList<Map.Entry<String, Pair<Long, Double>>>(heavyHitters.entrySet());
        // Stable ascending sort by time; the top entries are then taken from the end so the
        // most expensive instruction comes first (ties keep the original tail-first order).
        entries.sort((e1, e2) -> Double.compare(e1.getValue().getRight(), e2.getValue().getRight()));

        int numHittersToDisplay = Math.max(0, Math.min(num, entries.size()));
        ArrayList<Map.Entry<String, Pair<Long, Double>>> top = new ArrayList<Map.Entry<String, Pair<Long, Double>>>(numHittersToDisplay);
        for (int i = 1; i <= numHittersToDisplay; ++i) {
            top.add(entries.get(entries.size() - i));
        }

        String numCol = "#";
        String instCol = "Instruction";
        String timeSCol = "Time(s)";
        String countCol = "Count";
        DecimalFormat sFormat = new DecimalFormat("#,##0.000");

        // Column widths: the widest of header and displayed values.
        int maxNumLen = String.valueOf(numHittersToDisplay).length();
        int maxInstLen = instCol.length();
        int maxTimeSLen = timeSCol.length();
        int maxCountLen = countCol.length();
        for (Map.Entry<String, Pair<Long, Double>> hh : top) {
            maxInstLen = Math.max(maxInstLen, hh.getKey().length());
            maxTimeSLen = Math.max(maxTimeSLen, sFormat.format(hh.getValue().getRight()).length());
            maxCountLen = Math.max(maxCountLen, String.valueOf(hh.getValue().getLeft()).length());
        }
        maxInstLen = Math.min(maxInstLen, DMLScript.STATISTICS_MAX_WRAP_LEN);

        // Header and continuation rows print strings only; the first row of each entry prints
        // the rank and count with %d.
        String textRowFmt = " %" + maxNumLen + "s  %-" + maxInstLen + "s  %" + maxTimeSLen + "s  %" + maxCountLen + "s";
        String firstRowFmt = " %" + maxNumLen + "d  %-" + maxInstLen + "s  %" + maxTimeSLen + "s  %" + maxCountLen + "d";

        StringBuilder sb = new StringBuilder();
        sb.append("Heavy hitter instructions:\n");
        sb.append(String.format(textRowFmt, numCol, instCol, timeSCol, countCol));
        sb.append("\n");
        for (int rank = 0; rank < top.size(); ++rank) {
            Map.Entry<String, Pair<Long, Double>> hh = top.get(rank);
            String[] wrappedInstruction = Statistics.wrap(hh.getKey(), maxInstLen);
            String timeSString = sFormat.format(hh.getValue().getRight());
            long count = hh.getValue().getLeft();
            for (int wrapIter = 0; wrapIter < wrappedInstruction.length; ++wrapIter) {
                if (wrapIter == 0) {
                    sb.append(String.format(firstRowFmt, rank + 1, wrappedInstruction[wrapIter], timeSString, count));
                } else {
                    sb.append(String.format(textRowFmt, "", wrappedInstruction[wrapIter], "", ""));
                }
                sb.append("\n");
            }
        }
        return sb.toString();
    }

    /**
     * Requests statistics from every registered worker and sums all responses whose first
     * data element is a {@link FedStatsCollection}; other or empty payloads are skipped.
     *
     * @return the aggregated statistics (all zero / empty when no worker answered)
     * @throws DMLRuntimeException if waiting for or reading a response fails
     */
    private static FedStatsCollection collectFedStats() {
        Future<FederatedResponse>[] responses = FederatedStatistics.getFederatedResponses();
        FedStatsCollection aggFedStats = new FedStatsCollection();
        for (Future<FederatedResponse> res : responses) {
            try {
                Object[] tmp = res.get().getData();
                if (tmp != null && tmp.length > 0 && tmp[0] instanceof FedStatsCollection) {
                    aggFedStats.aggregate((FedStatsCollection) tmp[0]);
                }
            }
            catch (InterruptedException ie) {
                // Restore the interrupt flag so callers up the stack can still observe it.
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException("Exception of type " + ie.getClass().toString() + " thrown while getting the federated stats of the federated response: ", ie);
            }
            catch (Exception e) {
                throw new DMLRuntimeException("Exception of type " + e.getClass().toString() + " thrown while getting the federated stats of the federated response: ", e);
            }
        }
        return aggFedStats;
    }

    /**
     * Sends a {@link FedStatsCollectFunction} request to every registered worker. Workers
     * whose request cannot be issued are skipped (best effort); SSL and unexpected failures
     * are reported on standard output.
     *
     * @return one pending response per worker whose request was issued
     */
    private static Future<FederatedResponse>[] getFederatedResponses() {
        ArrayList<Future<FederatedResponse>> ret = new ArrayList<Future<FederatedResponse>>();
        for (Pair<String, Integer> fedAddr : _fedWorkerAddresses) {
            InetSocketAddress isa = new InetSocketAddress(fedAddr.getLeft(), fedAddr.getRight());
            FederatedRequest frUDF = new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1L, new FedStatsCollectFunction());
            try {
                ret.add(FederatedData.executeFederatedOperation(isa, frUDF));
            }
            catch (SSLException ssle) {
                System.out.println("SSLException while getting the federated stats from " + isa.toString() + ": " + ssle.getMessage());
            }
            catch (DMLRuntimeException ignored) {
                // Deliberately skipped without a message (best effort).
                // NOTE(review): the reason for silencing only this type is not visible here
                // (possibly workers that are already shut down) — confirm.
            }
            catch (Exception e) {
                System.out.println("Exeption of type " + e.getClass().getName() + " thrown while getting stats from federated worker: " + e.getMessage());
            }
        }
        // Generic array creation is impossible; the list only holds Future<FederatedResponse>.
        @SuppressWarnings("unchecked")
        Future<FederatedResponse>[] retArr = ret.toArray(new Future[0]);
        return retArr;
    }

    /**
     * Serializable bundle of statistics collected on one worker, or the sum over several
     * workers after {@link #aggregate(FedStatsCollection)}.
     */
    protected static class FedStatsCollection
    implements Serializable {
        private static final long serialVersionUID = 1L;
        private CacheStatsCollection cacheStats = new CacheStatsCollection();
        /** JIT compile time: {@code Statistics.getJITCompileTime() / 1000}, displayed as seconds. */
        private double jitCompileTime = 0.0;
        private GCStatsCollection gcStats = new GCStatsCollection();
        /** Instruction -> (count, time), as returned by {@code Statistics.getHeavyHittersHashMap()}. */
        private HashMap<String, Pair<Long, Double>> heavyHitters = new HashMap<String, Pair<Long, Double>>();

        protected FedStatsCollection() {
        }

        /** Fills this bundle from the statistics of the JVM it runs in. */
        private void collectStats() {
            this.cacheStats.collectStats();
            // NOTE(review): the /1000 presumably converts milliseconds to seconds — confirm.
            this.jitCompileTime = (double) Statistics.getJITCompileTime() / 1000.0;
            this.gcStats.collectStats();
            this.heavyHitters = Statistics.getHeavyHittersHashMap();
        }

        /**
         * Adds {@code that} into this bundle; heavy hitters present in both are merged by
         * summing their counts and times.
         */
        private void aggregate(FedStatsCollection that) {
            this.cacheStats.aggregate(that.cacheStats);
            this.jitCompileTime += that.jitCompileTime;
            this.gcStats.aggregate(that.gcStats);
            that.heavyHitters.forEach((key, value) -> this.heavyHitters.merge(key, value,
                (v1, v2) -> new ImmutablePair<Long, Double>(v1.getLeft() + v2.getLeft(), v1.getRight() + v2.getRight())));
        }

        /** JVM garbage-collection count and time. */
        protected static class GCStatsCollection
        implements Serializable {
            private static final long serialVersionUID = 1L;
            private long gcCount = 0L;
            /** {@code Statistics.getJVMgcTime() / 1000}, displayed as seconds. */
            private double gcTime = 0.0;

            protected GCStatsCollection() {
            }

            private void collectStats() {
                this.gcCount = Statistics.getJVMgcCount();
                this.gcTime = (double) Statistics.getJVMgcTime() / 1000.0;
            }

            private void aggregate(GCStatsCollection that) {
                this.gcCount += that.gcCount;
                this.gcTime += that.gcTime;
            }
        }

        /** Buffer-pool cache hit/write counters and acquire/release/export times. */
        protected static class CacheStatsCollection
        implements Serializable {
            private static final long serialVersionUID = 1L;
            private long memHits = 0L;
            private long linHits = 0L;
            private long fsBuffHits = 0L;
            private long fsHits = 0L;
            private long hdfsHits = 0L;
            private long linWrites = 0L;
            private long fsBuffWrites = 0L;
            private long fsWrites = 0L;
            private long hdfsWrites = 0L;
            // Times below are CacheStatistics values divided by 1e9 and displayed as seconds
            // (presumably nanoseconds at the source — confirm).
            private double acqRTime = 0.0;
            private double acqMTime = 0.0;
            private double rlsTime = 0.0;
            private double expTime = 0.0;

            protected CacheStatsCollection() {
            }

            private void collectStats() {
                this.memHits = CacheStatistics.getMemHits();
                this.linHits = CacheStatistics.getLinHits();
                this.fsBuffHits = CacheStatistics.getFSBuffHits();
                this.fsHits = CacheStatistics.getFSHits();
                this.hdfsHits = CacheStatistics.getHDFSHits();
                this.linWrites = CacheStatistics.getLinWrites();
                this.fsBuffWrites = CacheStatistics.getFSBuffWrites();
                this.fsWrites = CacheStatistics.getFSWrites();
                this.hdfsWrites = CacheStatistics.getHDFSWrites();
                this.acqRTime = (double) CacheStatistics.getAcquireRTime() / 1.0E9;
                this.acqMTime = (double) CacheStatistics.getAcquireMTime() / 1.0E9;
                this.rlsTime = (double) CacheStatistics.getReleaseTime() / 1.0E9;
                this.expTime = (double) CacheStatistics.getExportTime() / 1.0E9;
            }

            private void aggregate(CacheStatsCollection that) {
                this.memHits += that.memHits;
                this.linHits += that.linHits;
                this.fsBuffHits += that.fsBuffHits;
                this.fsHits += that.fsHits;
                this.hdfsHits += that.hdfsHits;
                this.linWrites += that.linWrites;
                this.fsBuffWrites += that.fsBuffWrites;
                this.fsWrites += that.fsWrites;
                this.hdfsWrites += that.hdfsWrites;
                this.acqRTime += that.acqRTime;
                this.acqMTime += that.acqMTime;
                this.rlsTime += that.rlsTime;
                this.expTime += that.expTime;
            }
        }
    }

    /**
     * UDF executed on a federated worker: collects that worker's statistics and returns
     * them as a successful response carrying a {@link FedStatsCollection}.
     */
    private static class FedStatsCollectFunction
    extends FederatedUDF {
        private static final long serialVersionUID = 1L;

        public FedStatsCollectFunction() {
            // Empty array passed to FederatedUDF; presumably no input data IDs — confirm.
            super(new long[0]);
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            FedStatsCollection fedStats = new FedStatsCollection();
            fedStats.collectStats();
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, fedStats);
        }

        /** Produces no lineage item for this statistics request. */
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }
}

