/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ad.correlation;

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeSet;
import org.opensearch.ad.correlation.Anomaly;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.timeseries.model.Feature;
import org.opensearch.timeseries.model.IntervalTimeConfiguration;
import org.opensearch.timeseries.model.TimeConfiguration;

/**
 * Clusters anomalies, possibly produced by different detectors, into groups of temporally correlated events.
 *
 * <p>Every anomaly's data window {@code [dataStartTime, dataEndTime]} is dilated by a tolerance on both sides.
 * The start side may be widened further for some detectors; see {@code backwardDilation}. Pairs of anomalies whose
 * dilated windows overlap by at least a minimum amount are scored by {@code similarity}. Pairs whose score reaches
 * a threshold become edges of an undirected graph, and the clusters are the connected components of that graph.
 *
 * <p>This is a static utility class and cannot be instantiated.
 */
public final class AnomalyCorrelation {
    /** Sweep order for the graph build: dilated start, then dilated end, then input index (a total order). */
    private static final Comparator<DilatedAnomaly> BY_START_THEN_END_THEN_IDX = Comparator
        .comparing((DilatedAnomaly d) -> d.start)
        .thenComparing(d -> d.end)
        .thenComparingInt(d -> d.idx);
    /** Active-set order: earliest dilated end first, so entries that can no longer overlap sit at the head. */
    private static final Comparator<DilatedAnomaly> BY_END_THEN_START_THEN_IDX = Comparator
        .comparing((DilatedAnomaly d) -> d.end)
        .thenComparing(d -> d.start)
        .thenComparingInt(d -> d.idx);
    /** Detector intervals at or above this size replace the backward tolerance (see {@code backwardDilation}). */
    private static final Duration COARSE_INTERVAL_THRESHOLD = Duration.ofMinutes(30L);
    /** Default dilation tolerance applied to both ends of every anomaly window. */
    private static final Duration DELTA_TOL = Duration.ofMinutes(5L);
    /** Default minimum similarity (inclusive) for two anomalies to be linked. */
    private static final double ALPHA = 0.3;
    /** Default duration-penalty scale. */
    private static final Duration KAPPA = Duration.ofMinutes(30L);
    /** Default minimum overlap of two dilated windows for the pair to be scored at all. */
    private static final Duration MIN_OVERLAP = Duration.ofMinutes(3L);
    /** Default temporal similarity mode. */
    private static final TemporalMode MODE = TemporalMode.HYBRID;
    /** Default weight of the overlap coefficient in HYBRID mode under strong containment. */
    private static final double LAM = 0.6;
    /** Default overlap-coefficient threshold for one window to count as covering the other. */
    private static final double TAU_CONTAIN = 0.8;
    /** Default upper bound on (shorter / longer) dilated length for durations to count as very different. */
    private static final double RHO_MAX = 0.25;
    /** Default exponent applied to the duration penalty under strong containment. */
    private static final double CONTAINMENT_RELAX = 0.45;
    /** Member order inside an emitted cluster; List.sort is stable, so ties keep ascending input-index order. */
    private static final Comparator<Anomaly> CLUSTER_MEMBER_ORDER = Comparator.comparing(Anomaly::getModelId);

    private AnomalyCorrelation() {
    }

    /**
     * Maps each detector id to its detection interval.
     *
     * <p>Null detectors, detectors with a null id, and detectors whose interval is not an
     * {@link IntervalTimeConfiguration} are skipped.
     *
     * @throws NullPointerException if {@code detectors} is null
     * @throws IllegalArgumentException if {@code detectors} is empty or no detector contributes an interval
     */
    private static Map<String, Duration> detectorIntervalsById(List<AnomalyDetector> detectors) {
        Objects.requireNonNull(detectors, "detectors");
        if (detectors.isEmpty()) {
            throw new IllegalArgumentException("detectors must not be empty");
        }
        Map<String, Duration> intervals = new HashMap<>(detectors.size());
        for (AnomalyDetector detector : detectors) {
            if (detector == null) {
                continue;
            }
            String detectorId = detector.getId();
            if (detectorId == null) {
                continue;
            }
            TimeConfiguration intervalConfig = detector.getInterval();
            if (intervalConfig instanceof IntervalTimeConfiguration) {
                intervals.put(detectorId, ((IntervalTimeConfiguration) intervalConfig).toDuration());
            }
        }
        if (intervals.isEmpty()) {
            throw new IllegalArgumentException("detectors must include interval configurations");
        }
        return intervals;
    }

    /** Returns the interval of the anomaly's detector (keyed by config id), or null when unknown. */
    private static Duration detectorIntervalForAnomaly(Anomaly anomaly, Map<String, Duration> detectorIntervals) {
        if (detectorIntervals == null || detectorIntervals.isEmpty()) {
            return null;
        }
        return detectorIntervals.get(anomaly.getConfigId());
    }

    /**
     * Maps each detector id to {@link #hasCommunityAggregation} of its feature list.
     *
     * <p>Null detectors and detectors with a null id are skipped.
     *
     * @throws NullPointerException if {@code detectors} is null
     * @throws IllegalArgumentException if {@code detectors} is empty
     */
    private static Map<String, Boolean> detectorUsesCommunityAggregationById(List<AnomalyDetector> detectors) {
        Objects.requireNonNull(detectors, "detectors");
        if (detectors.isEmpty()) {
            throw new IllegalArgumentException("detectors must not be empty");
        }
        Map<String, Boolean> usesCommunityAggregation = new HashMap<>(detectors.size());
        for (AnomalyDetector detector : detectors) {
            if (detector == null) {
                continue;
            }
            String detectorId = detector.getId();
            if (detectorId == null) {
                continue;
            }
            usesCommunityAggregation.put(detectorId, hasCommunityAggregation(detector.getFeatureAttributes()));
        }
        return usesCommunityAggregation;
    }

    /**
     * Reports whether any enabled feature uses an aggregation other than max, min or percentiles.
     *
     * <p>A feature is skipped only when {@code getEnabled()} is {@code Boolean.FALSE}; a null flag counts as
     * enabled. A null feature, or an enabled feature with a null aggregation, makes the result true. An
     * aggregation tree that branches (more than one sub-aggregation at some level) also makes the result true.
     * A null or empty feature list yields false.
     */
    private static boolean hasCommunityAggregation(List<Feature> features) {
        if (features == null || features.isEmpty()) {
            return false;
        }
        for (Feature feature : features) {
            if (feature == null) {
                return true;
            }
            if (Boolean.FALSE.equals(feature.getEnabled())) {
                continue;
            }
            AggregationBuilder aggregation = feature.getAggregation();
            if (aggregation == null) {
                return true;
            }
            AggregationBuilder metricAggregation = unwrapSingleMetricAggregation(aggregation);
            if (metricAggregation == null || !isMaxMinPercentileAggregation(metricAggregation)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Follows a chain of single sub-aggregations down to its leaf.
     *
     * @return the leaf aggregation, or null if {@code aggregation} is null or some level has more than one
     *     sub-aggregation
     */
    private static AggregationBuilder unwrapSingleMetricAggregation(AggregationBuilder aggregation) {
        AggregationBuilder current = aggregation;
        while (current != null) {
            Collection<AggregationBuilder> subAggregations = current.getSubAggregations();
            if (subAggregations == null || subAggregations.isEmpty()) {
                return current;
            }
            if (subAggregations.size() != 1) {
                return null;
            }
            current = subAggregations.iterator().next();
        }
        return null;
    }

    /** Returns true when the aggregation type string is exactly "max", "min" or "percentiles". */
    private static boolean isMaxMinPercentileAggregation(AggregationBuilder aggregation) {
        String type = aggregation.getType();
        return "max".equals(type) || "min".equals(type) || "percentiles".equals(type);
    }

    /**
     * Chooses how far to extend an anomaly's start backwards.
     *
     * <p>The default is {@code deltaTol}. A detector flagged by {@link #hasCommunityAggregation} whose positive
     * interval is at least {@link #COARSE_INTERVAL_THRESHOLD} uses its whole interval instead.
     * NOTE(review): presumably such aggregations can only localize an anomaly to a whole (coarse) bucket, so the
     * true onset may precede the reported start by up to one interval. Confirm with the model owners.
     */
    private static Duration backwardDilation(Anomaly a, Duration deltaTol, Map<String, Duration> detectorIntervals, Map<String, Boolean> detectorUsesCommunityAggregation) {
        if (detectorUsesCommunityAggregation == null || detectorUsesCommunityAggregation.isEmpty() || !Boolean.TRUE.equals(detectorUsesCommunityAggregation.get(a.getConfigId()))) {
            return deltaTol;
        }
        Duration interval = detectorIntervalForAnomaly(a, detectorIntervals);
        if (!isPositive(interval)) {
            return deltaTol;
        }
        return interval.compareTo(COARSE_INTERVAL_THRESHOLD) >= 0 ? interval : deltaTol;
    }

    /** Returns {@code [s - deltaStart, e + deltaEnd]}. */
    private static Interval dilate(Instant s, Instant e, Duration deltaStart, Duration deltaEnd) {
        return new Interval(s.minus(deltaStart), e.plus(deltaEnd));
    }

    /** The anomaly's data window, dilated backwards by {@code backwardDilation} and forwards by {@code delta}. */
    private static Interval dilatedWindow(Anomaly a, Duration delta, Map<String, Duration> detectorIntervals, Map<String, Boolean> detectorUsesCommunityAggregation) {
        Duration deltaStart = backwardDilation(a, delta, detectorIntervals, detectorUsesCommunityAggregation);
        return dilate(a.getDataStartTime(), a.getDataEndTime(), deltaStart, delta);
    }

    /** Length of the intersection of two intervals, or {@link Duration#ZERO} when they do not overlap. */
    private static Duration overlapLength(Interval a, Interval b) {
        Instant s = a.start.isAfter(b.start) ? a.start : b.start;
        Instant e = a.end.isBefore(b.end) ? a.end : b.end;
        if (!e.isAfter(s)) {
            return Duration.ZERO;
        }
        return Duration.between(s, e);
    }

    /** Signed length of an interval in nanoseconds. */
    private static long lengthNanos(Interval interval) {
        return Duration.between(interval.start, interval.end).toNanos();
    }

    /** Intersection over union of two dilated windows; 0 when they do not overlap or the union is empty. */
    private static double temporalIou(Interval ad, Interval bd) {
        long overlapNanos = overlapLength(ad, bd).toNanos();
        if (overlapNanos <= 0L) {
            return 0.0;
        }
        long unionNanos = lengthNanos(ad) + lengthNanos(bd) - overlapNanos;
        if (unionNanos <= 0L) {
            return 0.0;
        }
        return (double) overlapNanos / (double) unionNanos;
    }

    /** Overlap divided by the shorter window length; 0 when they do not overlap or the shorter is empty. */
    private static double overlapCoefficient(Interval ad, Interval bd) {
        long overlapNanos = overlapLength(ad, bd).toNanos();
        if (overlapNanos <= 0L) {
            return 0.0;
        }
        long denom = Math.min(lengthNanos(ad), lengthNanos(bd));
        if (denom <= 0L) {
            return 0.0;
        }
        return (double) overlapNanos / (double) denom;
    }

    /** True when {@code d} is non-null and strictly positive. */
    private static boolean isPositive(Duration d) {
        return d != null && !d.isZero() && !d.isNegative();
    }

    /**
     * Penalizes a mismatch between two anomalies' durations.
     *
     * @param a first anomaly
     * @param b second anomaly
     * @param kappa penalty scale; null, zero or negative disables the penalty
     * @return {@code exp(-|durA - durB| / kappa)} using {@link Anomaly#getDuration()}, in (0, 1]; 1.0 when
     *     {@code kappa} disables the penalty
     */
    public static double durationPenalty(Anomaly a, Anomaly b, Duration kappa) {
        if (!isPositive(kappa)) {
            return 1.0;
        }
        long durANanos = a.getDuration().toNanos();
        long durBNanos = b.getDuration().toNanos();
        long diffNanos = Math.abs(durANanos - durBNanos);
        long kappaNanos = kappa.toNanos();
        if (kappaNanos <= 0L) {
            return 1.0;
        }
        return Math.exp(-((double) diffNanos / (double) kappaNanos));
    }

    /**
     * Picks the duration-penalty scale for a strongly-contained pair.
     *
     * <p>The scale is raised to the larger of the two detector intervals when both are known and positive and
     * that interval exceeds {@code kappa}. A larger scale means a milder penalty, so this is a relaxation. When
     * {@code kappa} already disables the penalty (null, zero or negative) there is nothing to relax, so it is
     * returned unchanged. Previously a null {@code kappa} threw an NPE here, and a non-positive one was replaced
     * by a positive interval, which re-enabled the penalty.
     */
    private static Duration containmentKappa(Anomaly a, Anomaly b, Duration kappa, Map<String, Duration> detectorIntervals) {
        if (!isPositive(kappa)) {
            return kappa;
        }
        Duration ia = detectorIntervalForAnomaly(a, detectorIntervals);
        Duration ib = detectorIntervalForAnomaly(b, detectorIntervals);
        if (!isPositive(ia) || !isPositive(ib)) {
            return kappa;
        }
        Duration maxInt = ia.compareTo(ib) >= 0 ? ia : ib;
        return maxInt.compareTo(kappa) > 0 ? maxInt : kappa;
    }

    /**
     * Scores two anomalies given their precomputed dilated windows.
     *
     * <p>The score is {@code t * pen}:
     * <ul>
     *   <li><b>Strong containment</b> holds when the overlap coefficient is at least {@code tauContain} and both
     *       dilated lengths are positive with shorter/longer at most {@code rhoMax}.</li>
     *   <li><b>{@code t}</b> is the IoU in IOU mode and the overlap coefficient in OVL mode. In HYBRID mode it is
     *       the IoU, or {@code (1 - lam) * iou + lam * ovl} under strong containment.</li>
     *   <li><b>{@code pen}</b> is {@link #durationPenalty}. Under strong containment the scale is widened by
     *       {@code containmentKappa} and the penalty is raised to the power {@code containmentRelax}; it is 1.0
     *       when that exponent is 0.</li>
     * </ul>
     *
     * @return 0.0 when {@code t <= 0}
     * @throws NullPointerException if {@code temporalMode} is null
     */
    private static double similarity(Anomaly a, Anomaly b, Interval ad, Interval bd, Duration kappa, TemporalMode temporalMode, double lam, double tauContain, double rhoMax, double containmentRelax, Map<String, Duration> detectorIntervals) {
        double iou = temporalIou(ad, bd);
        double ovl = overlapCoefficient(ad, bd);
        double lenA = lengthNanos(ad);
        double lenB = lengthNanos(bd);
        boolean durationVeryDifferent = lenA > 0.0 && lenB > 0.0 && Math.min(lenA, lenB) / Math.max(lenA, lenB) <= rhoMax;
        boolean covered = ovl >= tauContain;
        boolean strongContainment = covered && durationVeryDifferent;
        double t = switch (temporalMode) {
            case IOU -> iou;
            case OVL -> ovl;
            case HYBRID -> strongContainment ? (1.0 - lam) * iou + lam * ovl : iou;
        };
        if (t <= 0.0) {
            return 0.0;
        }
        Duration kappaEff = strongContainment ? containmentKappa(a, b, kappa, detectorIntervals) : kappa;
        double basePen = durationPenalty(a, b, kappaEff);
        double pen = strongContainment ? (containmentRelax == 0.0 ? 1.0 : Math.pow(basePen, containmentRelax)) : basePen;
        return t * pen;
    }

    /**
     * Builds the similarity graph over {@code anomalies}.
     *
     * @param anomalies anomalies to link; node {@code i} of the result is {@code anomalies.get(i)}
     * @param detectors detectors supplying intervals and aggregation types, matched by {@code getConfigId()}
     * @param delta dilation applied to both window ends (the start may be widened, see the class doc)
     * @param kappa duration-penalty scale; null, zero or negative disables the penalty
     * @param minSimilarity inclusive score threshold for an edge
     * @param minOverlap minimum overlap of two dilated windows for the pair to be scored
     * @param temporalMode which temporal overlap measure to use
     * @param lam HYBRID-mode weight of the overlap coefficient under strong containment
     * @param tauContain overlap-coefficient threshold for strong containment
     * @param rhoMax shorter/longer dilated-length ceiling for strong containment
     * @param containmentRelax exponent applied to the duration penalty under strong containment
     * @return adjacency lists, one per anomaly, each sorted ascending
     * @throws NullPointerException if {@code detectors}, {@code anomalies}, {@code delta} or {@code minOverlap}
     *     is null, or if {@code temporalMode} is null and some pair is scored
     * @throws IllegalArgumentException if {@code detectors} is empty or none has an interval configuration
     */
    public static List<List<Integer>> buildThresholdGraph(List<Anomaly> anomalies, List<AnomalyDetector> detectors, Duration delta, Duration kappa, double minSimilarity, Duration minOverlap, TemporalMode temporalMode, double lam, double tauContain, double rhoMax, double containmentRelax) {
        Map<String, Duration> detectorIntervals = detectorIntervalsById(detectors);
        Map<String, Boolean> detectorUsesCommunityAggregation = detectorUsesCommunityAggregationById(detectors);
        return buildThresholdGraph(anomalies, delta, kappa, minSimilarity, minOverlap, temporalMode, lam, tauContain, rhoMax, containmentRelax, detectorIntervals, detectorUsesCommunityAggregation);
    }

    /**
     * Sweep-line graph build over precomputed detector metadata.
     *
     * <p>Nodes are visited in ascending dilated start. The active set holds earlier nodes ordered by dilated
     * end. An active node whose end precedes {@code current.start + minOverlap} cannot overlap the current node
     * by {@code minOverlap}. It cannot overlap any later node either, since later nodes start no earlier, so it
     * is evicted. Every survivor therefore overlaps the current node by at least {@code minOverlap} and is
     * scored.
     */
    private static List<List<Integer>> buildThresholdGraph(List<Anomaly> anomalies, Duration delta, Duration kappa, double minSimilarity, Duration minOverlap, TemporalMode temporalMode, double lam, double tauContain, double rhoMax, double containmentRelax, Map<String, Duration> detectorIntervals, Map<String, Boolean> detectorUsesCommunityAggregation) {
        Objects.requireNonNull(anomalies, "anomalies");
        Objects.requireNonNull(delta, "delta");
        Objects.requireNonNull(minOverlap, "minOverlap");
        int n = anomalies.size();
        List<List<Integer>> adj = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            adj.add(new ArrayList<>());
        }
        if (n <= 1) {
            return adj;
        }
        List<DilatedAnomaly> nodes = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            nodes.add(new DilatedAnomaly(i, anomalies.get(i), delta, detectorIntervals, detectorUsesCommunityAggregation));
        }
        nodes.sort(BY_START_THEN_END_THEN_IDX);
        TreeSet<DilatedAnomaly> active = new TreeSet<>(BY_END_THEN_START_THEN_IDX);
        for (DilatedAnomaly current : nodes) {
            Instant requiredEnd = current.start.plus(minOverlap);
            while (!active.isEmpty() && active.first().end.isBefore(requiredEnd)) {
                active.pollFirst();
            }
            // A window shorter than minOverlap can never overlap anything by minOverlap.
            if (current.end.isBefore(requiredEnd)) {
                continue;
            }
            for (DilatedAnomaly prev : active) {
                double s = similarity(prev.anomaly, current.anomaly, prev.window, current.window, kappa, temporalMode, lam, tauContain, rhoMax, containmentRelax, detectorIntervals);
                // Written as s >= threshold (not s < threshold -> skip) so a NaN score adds no edge.
                if (s >= minSimilarity) {
                    adj.get(prev.idx).add(current.idx);
                    adj.get(current.idx).add(prev.idx);
                }
            }
            active.add(current);
        }
        for (List<Integer> neighbors : adj) {
            Collections.sort(neighbors);
        }
        return adj;
    }

    /**
     * Finds the connected components of an undirected graph using an iterative depth-first search.
     *
     * @param adj adjacency lists; every neighbor index must lie in {@code [0, adj.size())}
     * @return the components, in order of their smallest node index; each component is sorted ascending
     */
    public static List<List<Integer>> connectedComponents(List<List<Integer>> adj) {
        int n = adj.size();
        boolean[] seen = new boolean[n];
        List<List<Integer>> comps = new ArrayList<>();
        for (int v = 0; v < n; ++v) {
            if (seen[v]) {
                continue;
            }
            List<Integer> comp = new ArrayList<>();
            ArrayDeque<Integer> stack = new ArrayDeque<>();
            stack.push(v);
            seen[v] = true;
            while (!stack.isEmpty()) {
                int x = stack.pop();
                comp.add(x);
                for (int y : adj.get(x)) {
                    if (!seen[y]) {
                        seen[y] = true;
                        stack.push(y);
                    }
                }
            }
            Collections.sort(comp);
            comps.add(comp);
        }
        return comps;
    }

    /**
     * Clusters anomalies with explicit tuning parameters.
     *
     * <p>Duplicate anomalies (per {@code Anomaly.equals}) are removed first, keeping first occurrences. The
     * graph is built as in {@link #buildThresholdGraph}, and each connected component becomes one cluster.
     * Members are ordered by model id; ties keep their input order.
     *
     * @param includeSingletons whether components with a single anomaly are returned
     * @return clusters, in order of their earliest (deduplicated) input position
     */
    public static List<List<Anomaly>> cluster(List<Anomaly> anomalies, List<AnomalyDetector> detectors, Duration delta, Duration kappa, double minSimilarity, Duration minOverlap, TemporalMode temporalMode, double lam, double tauContain, double rhoMax, double containmentRelax, boolean includeSingletons) {
        List<Anomaly> dedupedAnomalies = dedupe(anomalies);
        Map<String, Duration> detectorIntervals = detectorIntervalsById(detectors);
        Map<String, Boolean> detectorUsesCommunityAggregation = detectorUsesCommunityAggregationById(detectors);
        List<List<Integer>> adj = buildThresholdGraph(dedupedAnomalies, delta, kappa, minSimilarity, minOverlap, temporalMode, lam, tauContain, rhoMax, containmentRelax, detectorIntervals, detectorUsesCommunityAggregation);
        List<List<Integer>> comps = connectedComponents(adj);
        List<List<Anomaly>> out = new ArrayList<>(comps.size());
        for (List<Integer> comp : comps) {
            if (!includeSingletons && comp.size() == 1) {
                continue;
            }
            List<Anomaly> members = new ArrayList<>(comp.size());
            for (int idx : comp) {
                members.add(dedupedAnomalies.get(idx));
            }
            members.sort(CLUSTER_MEMBER_ORDER);
            out.add(members);
        }
        return out;
    }

    /** Same as the full {@code cluster} overload with singletons included. */
    public static List<List<Anomaly>> cluster(List<Anomaly> anomalies, List<AnomalyDetector> detectors, Duration delta, Duration kappa, double minSimilarity, Duration minOverlap, TemporalMode temporalMode, double lam, double tauContain, double rhoMax, double containmentRelax) {
        return cluster(anomalies, detectors, delta, kappa, minSimilarity, minOverlap, temporalMode, lam, tauContain, rhoMax, containmentRelax, true);
    }

    /** Removes duplicates (per {@code equals}/{@code hashCode}) while keeping first-occurrence order. */
    private static List<Anomaly> dedupe(List<Anomaly> anomalies) {
        Objects.requireNonNull(anomalies, "anomalies");
        return new ArrayList<>(new LinkedHashSet<>(anomalies));
    }

    /**
     * Spans the earliest data start to the latest data end over the cluster's (undilated) anomaly windows.
     *
     * @throws NullPointerException if the cluster or any member is null
     * @throws IllegalArgumentException if the cluster is empty
     */
    private static EventWindow eventWindowForCluster(List<Anomaly> cluster) {
        Objects.requireNonNull(cluster, "cluster");
        if (cluster.isEmpty()) {
            throw new IllegalArgumentException("cluster must not be empty");
        }
        Instant minStart = null;
        Instant maxEnd = null;
        for (Anomaly anomaly : cluster) {
            Objects.requireNonNull(anomaly, "anomaly");
            Instant start = anomaly.getDataStartTime();
            Instant end = anomaly.getDataEndTime();
            if (minStart == null || start.isBefore(minStart)) {
                minStart = start;
            }
            if (maxEnd == null || end.isAfter(maxEnd)) {
                maxEnd = end;
            }
        }
        return new EventWindow(minStart, maxEnd);
    }

    /**
     * Computes one {@link EventWindow} per cluster, in the same order.
     *
     * @throws NullPointerException if {@code clusters}, any cluster, or any member is null
     * @throws IllegalArgumentException if any cluster is empty
     */
    public static List<EventWindow> clusterEventWindows(List<List<Anomaly>> clusters) {
        Objects.requireNonNull(clusters, "clusters");
        List<EventWindow> windows = new ArrayList<>(clusters.size());
        for (List<Anomaly> cluster : clusters) {
            windows.add(eventWindowForCluster(cluster));
        }
        return windows;
    }

    /** Clusters with the default parameters (singletons included) and attaches each cluster's event window. */
    public static List<Cluster> clusterWithEventWindows(List<Anomaly> anomalies, List<AnomalyDetector> detectors) {
        return clusterWithEventWindows(anomalies, detectors, true);
    }

    /** Clusters with the default parameters and attaches each cluster's event window. */
    public static List<Cluster> clusterWithEventWindows(List<Anomaly> anomalies, List<AnomalyDetector> detectors, boolean includeSingletons) {
        List<List<Anomaly>> clusters = cluster(anomalies, detectors, includeSingletons);
        List<Cluster> out = new ArrayList<>(clusters.size());
        for (List<Anomaly> cluster : clusters) {
            out.add(new Cluster(eventWindowForCluster(cluster), cluster));
        }
        return out;
    }

    /** Clusters with the default parameters, singletons included. */
    public static List<List<Anomaly>> cluster(List<Anomaly> anomalies, List<AnomalyDetector> detectors) {
        return cluster(anomalies, detectors, true);
    }

    /** Clusters with the default tuning constants declared at the top of this class. */
    public static List<List<Anomaly>> cluster(List<Anomaly> anomalies, List<AnomalyDetector> detectors, boolean includeSingletons) {
        return cluster(anomalies, detectors, DELTA_TOL, KAPPA, ALPHA, MIN_OVERLAP, MODE, LAM, TAU_CONTAIN, RHO_MAX, CONTAINMENT_RELAX, includeSingletons);
    }

    /** Closed time interval with non-null endpoints (ordering of the endpoints is not checked). */
    private static final class Interval {
        final Instant start;
        final Instant end;

        Interval(Instant start, Instant end) {
            this.start = Objects.requireNonNull(start, "start");
            this.end = Objects.requireNonNull(end, "end");
        }
    }

    /** Temporal overlap measure used by the similarity score. */
    public enum TemporalMode {
        /** Intersection over union of the dilated windows. */
        IOU,
        /** Overlap divided by the shorter dilated window. */
        OVL,
        /** IOU, blended with OVL only for strongly-contained pairs. */
        HYBRID;
    }

    /** An anomaly together with its input index and its dilated window, cached for the sweep. */
    private static final class DilatedAnomaly {
        final int idx;
        final Anomaly anomaly;
        final Interval window;
        final Instant start;
        final Instant end;

        DilatedAnomaly(int idx, Anomaly anomaly, Duration delta, Map<String, Duration> detectorIntervals, Map<String, Boolean> detectorUsesCommunityAggregation) {
            this.idx = idx;
            this.anomaly = anomaly;
            this.window = AnomalyCorrelation.dilatedWindow(anomaly, delta, detectorIntervals, detectorUsesCommunityAggregation);
            this.start = this.window.start;
            this.end = this.window.end;
        }
    }

    /** Immutable time window {@code [start, end]} with {@code start <= end}. */
    public static final class EventWindow {
        private final Instant start;
        private final Instant end;

        /**
         * @throws NullPointerException if either endpoint is null
         * @throws IllegalArgumentException if {@code end} is before {@code start}
         */
        public EventWindow(Instant start, Instant end) {
            this.start = Objects.requireNonNull(start, "start");
            this.end = Objects.requireNonNull(end, "end");
            if (end.isBefore(start)) {
                throw new IllegalArgumentException("end must be on or after start");
            }
        }

        public Instant getStart() {
            return this.start;
        }

        public Instant getEnd() {
            return this.end;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            EventWindow that = (EventWindow) o;
            return this.start.equals(that.start) && this.end.equals(that.end);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.start, this.end);
        }

        @Override
        public String toString() {
            return "EventWindow{start=" + String.valueOf(this.start) + ", end=" + String.valueOf(this.end) + "}";
        }
    }

    /**
     * A cluster of correlated anomalies and the window that spans them.
     *
     * <p>The anomaly list is stored and returned as given (not copied).
     */
    public static final class Cluster {
        private final EventWindow eventWindow;
        private final List<Anomaly> anomalies;

        /**
         * @throws NullPointerException if either argument is null
         * @throws IllegalArgumentException if {@code anomalies} is empty
         */
        public Cluster(EventWindow eventWindow, List<Anomaly> anomalies) {
            this.eventWindow = Objects.requireNonNull(eventWindow, "eventWindow");
            this.anomalies = Objects.requireNonNull(anomalies, "anomalies");
            if (anomalies.isEmpty()) {
                throw new IllegalArgumentException("anomalies must not be empty");
            }
        }

        public EventWindow getEventWindow() {
            return this.eventWindow;
        }

        public List<Anomaly> getAnomalies() {
            return this.anomalies;
        }

        @Override
        public boolean equals(Object o) {
            if (this == o) {
                return true;
            }
            if (o == null || this.getClass() != o.getClass()) {
                return false;
            }
            Cluster cluster = (Cluster) o;
            return this.eventWindow.equals(cluster.eventWindow) && this.anomalies.equals(cluster.anomalies);
        }

        @Override
        public int hashCode() {
            return Objects.hash(this.eventWindow, this.anomalies);
        }

        @Override
        public String toString() {
            return "Cluster{eventWindow=" + String.valueOf(this.eventWindow) + ", anomalies=" + String.valueOf(this.anomalies) + "}";
        }
    }
}

