#!/usr/bin/env python3
"""
Prophet-based anomaly detector for VictoriaMetrics.
Reads metrics from VictoriaMetrics, fits a Prophet model per query, flags
points outside the prediction interval, and writes the results back.
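
Example invocation (illustrative; the URL, query, and script name are
placeholders, not values taken from this repository):

    VM_READ_URL=http://vmsingle:8428 \
    QUERIES='errors=sum(rate(http_requests_total{status=~"5.."}[5m]))' \
    python3 prophet_detector.py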
"""

import os
import sys
import logging
from datetime import datetime, timedelta, timezone
from typing import Optional

import pandas as pd
import requests
from prophet import Prophet

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Suppress Prophet's verbose logging
logging.getLogger('prophet').setLevel(logging.WARNING)
logging.getLogger('cmdstanpy').setLevel(logging.WARNING)


class VMClient:
    """VictoriaMetrics client for reading and writing metrics."""

    def __init__(self, read_url: str, write_url: Optional[str] = None):
        self.read_url = read_url.rstrip('/')
        self.write_url = (write_url or read_url).rstrip('/')

    def query_range(self, query: str, start: datetime, end: datetime, step: str = "5m") -> pd.DataFrame:
        """Query VictoriaMetrics for time series data."""
        url = f"{self.read_url}/api/v1/query_range"
        params = {
            "query": query,
            "start": int(start.timestamp()),
            "end": int(end.timestamp()),
            "step": step
        }

        response = requests.get(url, params=params, timeout=60)
        response.raise_for_status()
        data = response.json()
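        # Prometheus-compatible response shape (each value is a
        # [unix_seconds, "string_value"] pair):
        #   {"status": "success", "data": {"resultType": "matrix",
        #     "result": [{"metric": {...}, "values": [[1700000000, "1.23"], ...]}]}}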

        if data["status"] != "success":
            raise ValueError(f"Query failed: {data.get('error', 'Unknown error')}")

        results = data["data"]["result"]
        if not results:
            return pd.DataFrame()

        # Parse the first series only; warn if the query matched several,
        # since the extras are silently ignored
        if len(results) > 1:
            logger.warning(f"Query matched {len(results)} series; using only the first")
        values = results[0]["values"]
        df = pd.DataFrame(values, columns=["timestamp", "value"])
        # Epoch seconds -> naive UTC timestamps (Prophet rejects tz-aware "ds")
        df["ds"] = pd.to_datetime(df["timestamp"].astype(float), unit="s")
        df["y"] = df["value"].astype(float)

        return df[["ds", "y"]]

    def write_metrics(self, lines: list[str]):
        """Write metrics to VictoriaMetrics in Prometheus format."""
        if not lines:
            return

        url = f"{self.write_url}/api/v1/import/prometheus"
        data = "\n".join(lines)

        response = requests.post(url, data=data, timeout=30)
        response.raise_for_status()
        logger.info(f"Wrote {len(lines)} metrics to VictoriaMetrics")


class ProphetAnomalyDetector:
    """Prophet-based anomaly detector."""

    def __init__(self, interval_width: float = 0.95, seasonality_mode: str = "multiplicative"):
        self.interval_width = interval_width
        self.seasonality_mode = seasonality_mode

    def detect(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Detect anomalies using Prophet.

        Returns DataFrame with columns:
        - ds: timestamp
        - y: actual value
        - yhat: predicted value
        - yhat_lower: lower bound
        - yhat_upper: upper bound
        - anomaly: 1 if anomaly, 0 otherwise
        - anomaly_score: how far outside bounds (0 if not anomaly)
        """
        if len(df) < 10:
            logger.warning(f"Not enough data points for Prophet (got {len(df)}, need at least 10)")
            return df.assign(yhat=df["y"], yhat_lower=df["y"], yhat_upper=df["y"], anomaly=0, anomaly_score=0.0)

        # Train Prophet model
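        # With interval_width=0.95, roughly 5% of well-behaved points fall
        # outside the uncertainty band by construction; widen the interval if
        # the detector is too noisy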
        model = Prophet(
            interval_width=self.interval_width,
            seasonality_mode=self.seasonality_mode,
            daily_seasonality=True,
            weekly_seasonality=True,
            yearly_seasonality=False,
        )

        # Fit model
        model.fit(df)

        # Predict on same data (for anomaly detection, not forecasting)
        forecast = model.predict(df)

        # Merge with original data
        result = df.copy()
        result["yhat"] = forecast["yhat"].values
        result["yhat_lower"] = forecast["yhat_lower"].values
        result["yhat_upper"] = forecast["yhat_upper"].values

        # Detect anomalies (outside confidence interval)
        result["anomaly"] = ((result["y"] < result["yhat_lower"]) |
                             (result["y"] > result["yhat_upper"])).astype(int)

        # Calculate anomaly score (distance from bounds, normalized)
        result["anomaly_score"] = 0.0
        mask_low = result["y"] < result["yhat_lower"]
        mask_high = result["y"] > result["yhat_upper"]

        range_size = result["yhat_upper"] - result["yhat_lower"]
        range_size = range_size.replace(0, 1)  # Avoid division by zero

        result.loc[mask_low, "anomaly_score"] = (
            (result.loc[mask_low, "yhat_lower"] - result.loc[mask_low, "y"]) / range_size[mask_low]
        )
        result.loc[mask_high, "anomaly_score"] = (
            (result.loc[mask_high, "y"] - result.loc[mask_high, "yhat_upper"]) / range_size[mask_high]
        )

        return result
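
# Example (illustrative, not part of the pipeline): exercising the detector on
# a synthetic series with one injected spike.
#
#   import numpy as np
#   idx = pd.date_range("2024-01-01", periods=500, freq="5min")
#   demo = pd.DataFrame({"ds": idx, "y": np.random.default_rng(0).normal(10, 1, 500)})
#   demo.loc[250, "y"] = 50  # injected anomaly
#   out = ProphetAnomalyDetector().detect(demo)
#   print(out[out["anomaly"] == 1][["ds", "y", "anomaly_score"]])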


def generate_metrics(metric_name: str, labels: dict, result: pd.DataFrame) -> list[str]:
    """Generate Prometheus-format metrics from detection results."""
    lines = []
    # Merge caller labels with the metric name up front; this avoids emitting a
    # dangling comma when the labels dict is empty
    all_labels = {**labels, "metric_name": metric_name}
    label_str = ",".join(f'{k}="{v}"' for k, v in all_labels.items())

    # Only write points from the last hour; "ds" is naive UTC (see
    # VMClient.query_range), so compare against naive UTC here as well
    cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=1)
    recent = result[result["ds"] >= cutoff]
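    # Each emitted line below uses the Prometheus exposition format with an
    # explicit millisecond timestamp, e.g. (values illustrative):
    #   prophet_anomaly{source="prophet",metric_name="cpu_usage"} 1 1700000000000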

    for _, row in recent.iterrows():
        ts_ms = int(row["ds"].timestamp() * 1000)

        # Prediction metrics
        lines.append(f'prophet_prediction{{{label_str}}} {row["yhat"]:.6f} {ts_ms}')
        lines.append(f'prophet_lower_bound{{{label_str}}} {row["yhat_lower"]:.6f} {ts_ms}')
        lines.append(f'prophet_upper_bound{{{label_str}}} {row["yhat_upper"]:.6f} {ts_ms}')

        # Anomaly indicator (always written, 0 or 1) and score
        lines.append(f'prophet_anomaly{{{label_str}}} {row["anomaly"]} {ts_ms}')
        lines.append(f'prophet_anomaly_score{{{label_str}}} {row["anomaly_score"]:.6f} {ts_ms}')

    return lines


def main():
    # Configuration from environment
    vm_read_url = os.getenv("VM_READ_URL", "http://vmsingle-vm-k8s-stack.kube-monitoring:8428")
    vm_write_url = os.getenv("VM_WRITE_URL", vm_read_url)
    lookback_days = int(os.getenv("LOOKBACK_DAYS", "7"))
    step = os.getenv("STEP", "5m")
    interval_width = float(os.getenv("INTERVAL_WIDTH", "0.95"))

    # Queries to analyze: semicolon-separated "name=query" pairs
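    # e.g. QUERIES='errors=sum(rate(errors_total[5m]));conns=sum(active_connections)'
    # (metric names in this example are illustrative)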
    queries_raw = os.getenv("QUERIES", "")

    if not queries_raw:
        # Default queries for common K8s metrics
        queries = {
            "cpu_usage": 'sum(rate(container_cpu_usage_seconds_total{container!=""}[5m])) by (namespace)',
            "memory_usage": 'sum(container_memory_working_set_bytes{container!=""}) by (namespace)',
            "http_requests": 'sum(rate(http_requests_total[5m]))',
        }
    else:
        queries = {}
        for item in queries_raw.split(";"):
            if "=" in item:
                name, query = item.split("=", 1)
                queries[name.strip()] = query.strip()

    logger.info(f"Prophet Anomaly Detector starting")
    logger.info(f"VM Read URL: {vm_read_url}")
    logger.info(f"VM Write URL: {vm_write_url}")
    logger.info(f"Lookback: {lookback_days} days, Step: {step}")
    logger.info(f"Queries to analyze: {list(queries.keys())}")

    # Initialize clients
    vm_client = VMClient(vm_read_url, vm_write_url)
    detector = ProphetAnomalyDetector(interval_width=interval_width)

    # Time range: timezone-aware UTC, so .timestamp() is unambiguous regardless
    # of the host's local timezone
    end = datetime.now(timezone.utc)
    start = end - timedelta(days=lookback_days)

    all_metrics = []

    for name, query in queries.items():
        logger.info(f"Processing metric: {name}")

        try:
            # Query data
            df = vm_client.query_range(query, start, end, step)

            if df.empty:
                logger.warning(f"No data for query: {name}")
                continue

            logger.info(f"Got {len(df)} data points for {name}")

            # Detect anomalies
            result = detector.detect(df)

            # Count anomalies
            anomaly_count = result["anomaly"].sum()
            logger.info(f"Detected {anomaly_count} anomalies for {name}")

            # Generate metrics
            metrics = generate_metrics(name, {"source": "prophet"}, result)
            all_metrics.extend(metrics)

        except Exception as e:
            logger.error(f"Error processing {name}: {e}")
            continue

    # Write all metrics
    if all_metrics:
        try:
            vm_client.write_metrics(all_metrics)
        except Exception as e:
            logger.error(f"Error writing metrics: {e}")
            sys.exit(1)

    logger.info("Prophet Anomaly Detector completed successfully")


if __name__ == "__main__":
    main()
