#!/usr/bin/env python3
"""Simulate event-time tumbling windows with watermarks for Chapter 9.

This is intentionally small and dependency-free. It is not a replacement for
Kafka or Flink; it is a deterministic teaching tool that makes watermark,
lateness, and replay behavior visible from a JSONL event file.
"""

from __future__ import annotations

import argparse
import csv
from collections import defaultdict
from dataclasses import dataclass
from datetime import datetime, timezone
from decimal import Decimal
import json
from pathlib import Path


@dataclass
class MetricState:
    events: int = 0
    users: set[str] | None = None
    page_views: int = 0
    product_views: int = 0
    add_to_carts: int = 0
    purchases: int = 0
    revenue: Decimal = Decimal("0.00")

    def __post_init__(self) -> None:
        if self.users is None:
            self.users = set()


def parse_args() -> argparse.Namespace:
    """Parse the simulator's command-line options from ``sys.argv``.

    Returns:
        A namespace with ``input``, ``window_seconds``,
        ``allowed_lateness_seconds``, ``output`` and ``late_output``.
        ``late_output`` is an empty string when the flag is omitted.
    """
    # Declared as data so the flag list reads as a table; registration order
    # is kept because it determines the order shown in ``--help``.
    option_specs = (
        ("--input", {"required": True, "help": "Path to JSONL clickstream events."}),
        ("--window-seconds", {"type": int, "default": 60, "help": "Tumbling window width in seconds."}),
        ("--allowed-lateness-seconds", {"type": int, "default": 20, "help": "Allowed lateness behind max observed event time."}),
        ("--output", {"required": True, "help": "CSV output path for window metrics."}),
        ("--late-output", {"default": "", "help": "Optional CSV output path for late events."}),
    )
    cli = argparse.ArgumentParser(description="Simulate event-time clickstream windows.")
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


def parse_ts(value: str) -> datetime:
    """Parse an ISO 8601 timestamp into a timezone-aware datetime.

    A trailing ``Z`` (UTC designator) is accepted alongside numeric offsets
    such as ``+02:00``. A timestamp that carries no offset at all is taken to
    be UTC, so the result never depends on the local timezone of the machine
    running the simulation.

    Args:
        value: Timestamp text, e.g. ``"2024-01-01T00:00:00Z"``.

    Returns:
        A timezone-aware ``datetime``.

    Raises:
        ValueError: If ``value`` is not a timestamp ``datetime.fromisoformat``
            can parse.
    """
    text = value
    # Older ``datetime.fromisoformat`` implementations reject the "Z" suffix,
    # so spell it as an explicit zero offset before parsing.
    if text.endswith("Z"):
        text = text[:-1] + "+00:00"
    parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is None:
        # A naive value would be interpreted as local time by ``timestamp()``
        # and ``astimezone()`` and cannot be compared with aware values.
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def format_ts(ts: datetime) -> str:
    """Render ``ts`` as an ISO 8601 string in UTC with a ``Z`` suffix.

    The value is first converted to UTC, so inputs carrying another offset
    are shifted rather than relabelled.
    """
    as_utc = ts.astimezone(timezone.utc)
    iso_text = as_utc.isoformat()
    # isoformat() spells UTC as "+00:00"; the output files use "Z" instead.
    return iso_text.replace("+00:00", "Z")


def floor_window(ts: datetime, window_seconds: int) -> datetime:
    """Return the start of the tumbling window that contains ``ts``.

    Windows are aligned to the Unix epoch, so every instant in
    ``[start, start + window_seconds)`` maps to the same ``start``.

    Args:
        ts: Event timestamp; converted to epoch seconds with
            ``datetime.timestamp()``.
        window_seconds: Window width in whole seconds; must be positive.

    Returns:
        The window start as a timezone-aware UTC datetime.

    Raises:
        ValueError: If ``window_seconds`` is zero or negative.
    """
    if window_seconds <= 0:
        raise ValueError(f"window_seconds must be positive, got {window_seconds}")
    # Floor division rounds toward negative infinity, so a pre-1970 timestamp
    # with fractional seconds still lands in the window at or before ``ts``;
    # truncating with int() would push it into the following window.
    window_index = int(ts.timestamp() // window_seconds)
    return datetime.fromtimestamp(window_index * window_seconds, tz=timezone.utc)


def read_events(path: Path) -> list[dict[str, str]]:
    """Load events from a UTF-8 JSON Lines file.

    Each non-blank line is decoded with ``json.loads``; lines that are empty
    or contain only whitespace are skipped. Events are returned in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    events: list[dict[str, str]] = []
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            events.append(json.loads(raw_line))
    return events


def simulate(events: list[dict[str, str]], window_seconds: int, allowed_lateness_seconds: int) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
    """Replay events in arrival order and aggregate them into tumbling windows.

    Events are processed ordered by their parsed ``arrival_time`` (ties broken
    by ``event_id``). After each event the watermark is the largest
    ``event_time`` seen so far, including the current event, minus
    ``allowed_lateness_seconds``. An event whose ``event_time`` is strictly
    older than the watermark is set aside as late; every other event is
    counted in the window that contains its ``event_time``.

    Args:
        events: Decoded events. Each needs ``arrival_time``, ``event_id``,
            ``event_time``, ``user_id`` and ``event_type``; purchases also
            need ``amount``.
        window_seconds: Tumbling window width in seconds.
        allowed_lateness_seconds: How far behind the maximum observed event
            time an event may be and still be accepted.

    Returns:
        ``(rows, late_events)``. ``rows`` holds one dict of string values per
        non-empty window, sorted by window start. ``late_events`` holds copies
        of the rejected events, each with an added ``watermark_time`` field,
        in processing order. The input dicts are not modified.

    Raises:
        ValueError: If an accepted event has an unknown ``event_type`` or a
            timestamp cannot be parsed.
        KeyError: If a required field is missing from an event.
    """

    def arrival_order(event: dict[str, str]) -> tuple[float, str]:
        # Order by the parsed instant rather than the raw text: string
        # comparison misorders timestamps that mix offsets or fractional
        # second precision (e.g. "...10.500Z" sorts before "...10Z").
        return parse_ts(event["arrival_time"]).timestamp(), event["event_id"]

    states: dict[datetime, MetricState] = defaultdict(MetricState)
    late_events: list[dict[str, str]] = []
    # Tracked as epoch seconds so naive and aware timestamps never have to be
    # compared with each other directly.
    max_event_epoch: float | None = None
    for event in sorted(events, key=arrival_order):
        event_time = parse_ts(event["event_time"])
        event_epoch = event_time.timestamp()
        max_event_epoch = event_epoch if max_event_epoch is None else max(max_event_epoch, event_epoch)
        watermark_epoch = max_event_epoch - allowed_lateness_seconds
        if event_epoch < watermark_epoch:
            # Copy so the caller's event dict is left untouched.
            late = dict(event)
            late["watermark_time"] = format_ts(datetime.fromtimestamp(watermark_epoch, tz=timezone.utc))
            late_events.append(late)
            continue
        state = states[floor_window(event_time, window_seconds)]
        state.events += 1
        state.users.add(event["user_id"])
        event_type = event["event_type"]
        if event_type == "page_view":
            state.page_views += 1
        elif event_type == "product_view":
            state.product_views += 1
        elif event_type == "add_to_cart":
            state.add_to_carts += 1
        elif event_type == "purchase":
            state.purchases += 1
            amount = event["amount"]
            # A JSON number decodes to float; going through str() makes
            # Decimal see "19.99" instead of the float's binary expansion.
            state.revenue += Decimal(str(amount)) if isinstance(amount, float) else Decimal(amount)
        else:
            raise ValueError(f"Unknown event_type: {event_type}")
    rows: list[dict[str, str]] = []
    for window_start in sorted(states):
        state = states[window_start]
        window_end = datetime.fromtimestamp(window_start.timestamp() + window_seconds, tz=timezone.utc)
        rows.append({
            "window_start": format_ts(window_start),
            "window_end": format_ts(window_end),
            "event_count": str(state.events),
            "unique_users": str(len(state.users or set())),
            "page_views": str(state.page_views),
            "product_views": str(state.product_views),
            "add_to_carts": str(state.add_to_carts),
            "purchases": str(state.purchases),
            "revenue": f"{state.revenue:.2f}",
        })
    return rows, late_events


def write_csv(path: Path, rows: list[dict[str, str]], fieldnames: list[str]) -> None:
    """Write ``rows`` to ``path`` as UTF-8 CSV with a header line.

    Missing parent directories are created first. The header is written even
    when ``rows`` is empty, and an existing file at ``path`` is overwritten.

    Args:
        path: Destination file.
        rows: One dict per CSV line, keyed by column name.
        fieldnames: Column names, in output order.
    """
    output_dir = path.parent
    output_dir.mkdir(parents=True, exist_ok=True)
    # newline="" leaves line-ending handling to the csv module.
    with path.open("w", encoding="utf-8", newline="") as handle:
        table = csv.DictWriter(handle, fieldnames=fieldnames)
        table.writeheader()
        for record in rows:
            table.writerow(record)


def main() -> None:
    """Run the simulation from the command line and print a short summary.

    Reads the input events, writes the per-window metrics CSV, writes the
    late-event CSV when ``--late-output`` was given, then prints the counts
    and output paths to stdout.
    """
    metric_fields = ["window_start", "window_end", "event_count", "unique_users", "page_views", "product_views", "add_to_carts", "purchases", "revenue"]
    late_fields = ["event_id", "arrival_time", "event_time", "user_id", "session_id", "event_type", "page", "product_id", "amount", "watermark_time"]

    args = parse_args()
    late_path = args.late_output
    events = read_events(Path(args.input))
    metrics, late_events = simulate(events, args.window_seconds, args.allowed_lateness_seconds)

    write_csv(Path(args.output), metrics, metric_fields)
    if late_path:
        write_csv(Path(late_path), late_events, late_fields)

    # Accepted events are recovered from the per-window counts.
    accepted = sum(int(row["event_count"]) for row in metrics)
    summary = [
        f"Processed events: {len(events)}",
        f"Accepted events: {accepted}",
        f"Late events: {len(late_events)}",
        f"Window rows: {len(metrics)}",
        f"Wrote metrics: {args.output}",
    ]
    if late_path:
        summary.append(f"Wrote late events: {late_path}")
    for line in summary:
        print(line)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
