#!/usr/bin/env python3
"""
Descărcare indicatori morbiditate spitalizată DRG la nivel de județ.
Sursă: http://www.drg.ro

Utilizare:
    # Descărcare perioadă completă
    python3 scripts/download_drg.py --start 01-2018 --end 04-2025

    # Update lunar (pentru cron)
    python3 scripts/download_drg.py --update

    # Forțare re-descărcare
    python3 scripts/download_drg.py --start 01-2024 --end 12-2024 --force
"""

import argparse
import calendar
import json
import logging
import os
import shutil
import ssl
import sys
import tempfile
import time
from datetime import datetime
from pathlib import Path
from urllib.request import urlopen, Request
from urllib.error import URLError, HTTPError

import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# All paths are derived from this script's own location: the project root is
# taken to be the parent of the directory that contains this file.
SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_DIR = SCRIPT_DIR.parent
# Output directory: per-DRG Parquet files, catalog.json and bookkeeping files.
DATA_DIR = PROJECT_DIR / "data" / "health" / "drg"
# Downloaded XLS files are stored below this directory as <year>/<mm>/.
RAW_DIR = DATA_DIR / "raw"
# JSON file holding the lists of already downloaded / processed months.
STATE_FILE = DATA_DIR / "_state.json"

# URL template; the placeholders are filled in by build_url().
BASE_URL = "http://drg.ro/inc/{year}/{mm}_{year}/DRG/02_Judet_pacient/IM_DRG___{county}___{start}_{end}.xls"

# NOTE: this configures the root logger as a side effect of importing the module.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
log = logging.getLogger("drg")

# ---------------------------------------------------------------------------
# County mapping: DRG name → natcode (numeric ID used in TEMPO/geometries)
# ---------------------------------------------------------------------------

# Keys are the county names exactly as they are interpolated into the download
# URL and the local file name (upper case, no diacritics); values are the
# natcode written to the "natcode" column of the Parquet outputs.
# Values are kept as strings so they can be used directly as a join key.
COUNTY_MAP = {
    "ALBA": "3064",
    "ARAD": "3065",
    "ARGES": "3066",
    "BACAU": "3067",
    "BIHOR": "3068",
    "BISTRITA-NASAUD": "3069",
    "BOTOSANI": "3070",
    "BRAILA": "3072",
    "BRASOV": "3071",
    "BUCURESTI": "3104",
    "BUZAU": "3073",
    "CALARASI": "3105",
    "CARAS-SEVERIN": "3074",
    "CLUJ": "3075",
    "CONSTANTA": "3076",
    "COVASNA": "3077",
    "DIMBOVITA": "3078",
    "DOLJ": "3079",
    "GALATI": "3080",
    "GIURGIU": "3106",
    "GORJ": "3081",
    "HARGHITA": "3082",
    "HUNEDOARA": "3083",
    "IALOMITA": "3084",
    "IASI": "3085",
    "ILFOV": "3086",
    "MARAMURES": "3087",
    "MEHEDINTI": "3088",
    "MURES": "3089",
    "NEAMT": "3090",
    "OLT": "3091",
    "PRAHOVA": "3092",
    "SALAJ": "3094",
    "SATU MARE": "3093",
    "SIBIU": "3095",
    "SUCEAVA": "3096",
    "TELEORMAN": "3097",
    "TIMIS": "3098",
    "TULCEA": "3099",
    "VASLUI": "3100",
    "VILCEA": "3101",
    "VRANCEA": "3102",
}

# Logical field names extracted from every XLS sheet, in output order.
# parse_xls() looks each of these up in the column map returned by
# _detect_columns(); a field without a detected position is filled with NA.
OUTPUT_FIELDS = [
    "cod", "denumire", "tip", "vr",
    "cazuri_total", "cazuri_pct", "cazuri_acuti", "cazuri_cronici",
    "zile_total", "zile_acuti", "zile_cronici",
    "dms_acuti", "dms_cronici",
]


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def format_date(year: int, month: int, day: int) -> str:
    """Render a date as ``day.month.year`` for use in the download URL.

    Years up to and including 2020 use unpadded day/month numbers; from 2021
    onwards day and month are zero-padded to two digits.
    """
    zero_padded = year > 2020
    template = "{d:02d}.{m:02d}.{y}" if zero_padded else "{d}.{m}.{y}"
    return template.format(d=day, m=month, y=year)


def build_url(year: int, month: int, county: str) -> str:
    """Return the XLS download URL for one county and one calendar month.

    The reporting interval embedded in the URL runs from the first to the
    last day of the month; spaces in the county name are percent-encoded.
    """
    _, days_in_month = calendar.monthrange(year, month)
    first_day, final_day = (
        format_date(year, month, day) for day in (1, days_in_month)
    )
    encoded_county = "%20".join(county.split(" "))
    return BASE_URL.format(
        year=year,
        mm="%02d" % month,
        county=encoded_county,
        start=first_day,
        end=final_day,
    )


def download_file(url: str, dest: Path, retries: int = 3) -> bool:
    """Download ``url`` to ``dest``, retrying transient failures.

    Args:
        url: Address of the file to fetch.
        dest: Final location of the file; parent directories are created.
        retries: Maximum number of attempts.

    Returns:
        True when the file was fetched and moved into place; False when the
        payload was smaller than 500 bytes, the server answered with a
        non-retryable 4xx status, or every attempt failed.
    """
    # Local import: only this function needs it.
    from http.client import HTTPException

    # Disable SSL verification (drg.ro has certificate issues)
    ctx = ssl.create_default_context()
    ctx.check_hostname = False
    ctx.verify_mode = ssl.CERT_NONE

    for attempt in range(1, retries + 1):
        try:
            req = Request(url, headers={"User-Agent": "Mozilla/5.0"})
            with urlopen(req, timeout=30, context=ctx) as resp:
                data = resp.read()
            if len(data) < 500:
                log.warning("File too small (%d bytes), may be invalid: %s", len(data), url)
                return False
            dest.parent.mkdir(parents=True, exist_ok=True)
            # Write to a temp file in the target directory, then move it into
            # place, so a partially written file never appears under `dest`.
            tmp_fd, tmp_path = tempfile.mkstemp(
                dir=str(dest.parent), suffix=".tmp"
            )
            try:
                # A buffered file object writes the whole payload; a bare
                # os.write() may write fewer bytes than requested.
                with os.fdopen(tmp_fd, "wb") as tmp_file:
                    tmp_file.write(data)
                shutil.move(tmp_path, str(dest))
            except BaseException:
                # Also covers KeyboardInterrupt so no stray *.tmp is left.
                if os.path.exists(tmp_path):
                    os.unlink(tmp_path)
                raise
            return True
        except (URLError, HTTPError, OSError, HTTPException) as e:
            # HTTPException covers errors raised while reading the body
            # (e.g. IncompleteRead), which are not OSError subclasses.
            log.warning("Attempt %d/%d failed for %s: %s", attempt, retries, url, e)
            if (
                isinstance(e, HTTPError)
                and 400 <= e.code < 500
                and e.code not in (408, 429)
            ):
                # Client errors such as 404 will not change on retry.
                return False
            if attempt < retries:
                time.sleep(2 * attempt)
    return False


def _detect_columns(df_raw: pd.DataFrame) -> dict:
    """
    Auto-detect column positions by finding the 'Nr.crt' header cell.
    The XLS layout shifted between years but the relative order is stable:
      nr_crt, [gap], cod, denumire, tip, vr, cazuri_total, cazuri_pct,
      cazuri_acuti, cazuri_cronici, zile_total, [gap?], zile_acuti,
      zile_cronici, dms_acuti, dms_cronici

    We locate Nr.crt, then find Cod (which contains DRG codes like E3061),
    and build the map from there.

    Args:
        df_raw: Sheet contents read without a header row, so columns are
            addressed purely by position.

    Returns:
        Mapping of logical field name to zero-based column position, or an
        empty dict when the 'Nr.crt' header or the first data row cannot be
        found. 'zile_acuti' / 'zile_cronici' are absent from the mapping when
        the sheet has no column at position cod + 9.
    """
    # Step 1: Find the Nr.crt cell (only the first 15 rows x 10 columns are scanned)
    nr_crt_col = None
    nr_crt_row = None
    for r in range(min(15, len(df_raw))):
        for c in range(min(10, df_raw.shape[1])):
            v = df_raw.iloc[r, c]
            if isinstance(v, str) and "Nr.crt" in v:
                nr_crt_col = c
                nr_crt_row = r
                break
        if nr_crt_col is not None:
            break

    if nr_crt_col is None:
        log.warning("Cannot find 'Nr.crt' header")
        return {}

    # Step 2: Find the Cod column — a cell whose stripped text is exactly
    # 'Cod', searched in the header row and the two rows below it, within the
    # five columns to the right of Nr.crt
    cod_col = None
    for r in range(nr_crt_row, min(nr_crt_row + 3, len(df_raw))):
        for c in range(nr_crt_col + 1, min(nr_crt_col + 6, df_raw.shape[1])):
            v = df_raw.iloc[r, c]
            if isinstance(v, str) and v.strip() == "Cod":
                cod_col = c
                break
        if cod_col is not None:
            break

    if cod_col is None:
        # NOTE(review): this fallback is not bounds-checked; on a sheet with
        # fewer than nr_crt_col + 3 columns the .iloc lookup below would raise
        # IndexError.
        cod_col = nr_crt_col + 2  # fallback

    # Step 3: Validate by checking the first data row has a DRG code
    # Data starts right after the header block: the first row (within the
    # next 9 rows) whose Nr.crt cell holds a positive whole number
    data_start = None
    for r in range(nr_crt_row + 1, min(nr_crt_row + 10, len(df_raw))):
        v = df_raw.iloc[r, nr_crt_col]
        if isinstance(v, (int, float)) and pd.notna(v) and v == int(v) and int(v) > 0:
            data_start = r
            break

    if data_start is None:
        log.warning("Cannot find first data row")
        return {}

    # Verify cod_col has a DRG code (a string longer than 2 characters)
    test_cod = df_raw.iloc[data_start, cod_col]
    if not (isinstance(test_cod, str) and len(test_cod) > 2):
        # Try adjacent columns, nearest first; if none qualifies, cod_col is
        # left unchanged and no warning is emitted
        for offset in [-1, 1, -2, 2]:
            tc = cod_col + offset
            if 0 <= tc < df_raw.shape[1]:
                v = df_raw.iloc[data_start, tc]
                if isinstance(v, str) and len(v) > 2 and v[0].isalpha():
                    cod_col = tc
                    break

    # Step 4: Build column map. After Cod the order is always:
    # denumire, tip, vr, cazuri_total, cazuri_pct, cazuri_acuti, cazuri_cronici,
    # zile_total, [maybe gap], zile_acuti, zile_cronici, dms_acuti, dms_cronici
    #
    # DMS is always the last 2 columns.
    col_map = {
        "nr_crt": nr_crt_col,
        "cod": cod_col,
        "denumire": cod_col + 1,
        "tip": cod_col + 2,
        "vr": cod_col + 3,
        "cazuri_total": cod_col + 4,
        "cazuri_pct": cod_col + 5,
        "cazuri_acuti": cod_col + 6,
        "cazuri_cronici": cod_col + 7,
        "zile_total": cod_col + 8,
    }

    # Check if there's a gap column between zile_total and zile_acuti
    # by testing if col (cod+9) has data in the first data row
    # NOTE(review): a genuinely empty zile_acuti value in the first data row
    # would be indistinguishable from a gap column here — confirm against
    # real files.
    zile_next = cod_col + 9
    if zile_next < df_raw.shape[1]:
        v = df_raw.iloc[data_start, zile_next]
        if pd.isna(v):
            # Gap exists — skip one column
            col_map["zile_acuti"] = cod_col + 10
            col_map["zile_cronici"] = cod_col + 11
        else:
            col_map["zile_acuti"] = cod_col + 9
            col_map["zile_cronici"] = cod_col + 10

    # DMS is always the last two meaningful columns
    # NOTE(review): this uses the last two columns of the frame as read; it
    # presumes the sheet has no trailing empty columns — TODO confirm.
    col_map["dms_acuti"] = df_raw.shape[1] - 2
    col_map["dms_cronici"] = df_raw.shape[1] - 1

    log.debug("Detected columns from data row %d: %s", data_start, col_map)
    return col_map


def parse_xls(path: Path) -> pd.DataFrame:
    """Read one DRG XLS sheet and return its indicator rows as a DataFrame.

    The column layout is detected per file. Rows are kept when their row
    number cell holds a positive whole number and their code starts with an
    upper-case letter followed by digits. An empty DataFrame is returned when
    the layout cannot be detected or no data rows are found.
    """
    sheet = pd.read_excel(path, header=None, engine="xlrd")

    layout = _detect_columns(sheet)
    if not layout:
        log.warning("Cannot detect column layout in %s", path)
        return pd.DataFrame()

    def _is_row_number(cell) -> bool:
        # Data rows carry a positive whole number in the Nr.crt column.
        if not isinstance(cell, (int, float)):
            return False
        if not pd.notna(cell):
            return False
        return float(cell) == int(cell) and int(cell) > 0

    row_filter = sheet.iloc[:, layout["nr_crt"]].apply(_is_row_number)
    rows = sheet.loc[row_filter].copy()
    if rows.empty:
        return pd.DataFrame()

    # Pull each logical field from its detected position; fields without a
    # usable position become an all-NA column.
    available = rows.shape[1]
    table = pd.DataFrame()
    for name in OUTPUT_FIELDS:
        position = layout.get(name)
        if position is None or position >= available:
            table[name] = pd.NA
            continue
        table[name] = rows.iloc[:, position].values

    # Text columns: stringify and trim.
    for text_col in ("cod", "denumire", "tip"):
        table[text_col] = table[text_col].astype(str).str.strip()

    # Indicator columns: coerce to numbers, unparsable cells become NaN.
    for num_col in (
        "vr", "cazuri_total", "cazuri_pct", "cazuri_acuti", "cazuri_cronici",
        "zile_total", "zile_acuti", "zile_cronici", "dms_acuti", "dms_cronici",
    ):
        table[num_col] = pd.to_numeric(table[num_col], errors="coerce")

    # Keep only rows whose code looks like a DRG code (letter + digits).
    has_drg_code = table["cod"].str.match(r"^[A-Z]\d+", na=False)
    return table[has_drg_code].copy()


def load_state() -> dict:
    """Load the resume state from STATE_FILE.

    Returns:
        The stored state dict, or a fresh ``{"downloaded": [], "processed": []}``
        when the file is missing, unreadable, not valid JSON, or does not
        contain a JSON object. A warning is logged in the unreadable/invalid
        cases, because falling back to an empty state makes every month look
        unprocessed.
    """
    fresh = {"downloaded": [], "processed": []}
    if not STATE_FILE.exists():
        return fresh
    try:
        with open(STATE_FILE, "r", encoding="utf-8") as f:
            state = json.load(f)
    except (ValueError, OSError) as e:
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        log.warning("Could not read state file %s, starting fresh: %s", STATE_FILE, e)
        return fresh
    if not isinstance(state, dict):
        # Callers use dict methods (.get / .setdefault) on the result.
        log.warning("State file %s does not contain a JSON object, starting fresh", STATE_FILE)
        return fresh
    return state


def save_state(state: dict):
    """Persist the resume state to STATE_FILE.

    The JSON is written to a sibling ``.tmp`` file first and then moved over
    the real file, so readers never see a half-written state.
    """
    state_dir = STATE_FILE.parent
    state_dir.mkdir(parents=True, exist_ok=True)

    scratch = STATE_FILE.with_suffix(".tmp")
    with scratch.open("w") as handle:
        json.dump(state, handle, indent=2)

    shutil.move(str(scratch), str(STATE_FILE))


def month_key(year: int, month: int) -> str:
    """Return the ``YYYY_MM`` key for a month (month zero-padded to 2 digits)."""
    return "{}_{:02d}".format(year, month)


def generate_month_range(start_month: int, start_year: int,
                         end_month: int, end_year: int) -> list:
    """List every month from start to end, inclusive, as (year, month) tuples.

    The result is empty when the start lies after the end.
    """
    months = []
    current = (start_year, start_month)
    final = (end_year, end_month)
    while current <= final:
        months.append(current)
        year, month = current
        # Roll over to January of the next year after December.
        current = (year + 1, 1) if month >= 12 else (year, month + 1)
    return months


def write_parquet_safe(df: pd.DataFrame, path: Path):
    """Write ``df`` to ``path`` as Parquet via a temporary sibling file.

    The frame is first written to ``<name>.parquet.tmp`` and then moved onto
    the final path (intended to be safe on SSHFS mounts).
    """
    out_dir = path.parent
    out_dir.mkdir(parents=True, exist_ok=True)

    staging = path.with_suffix(".parquet.tmp")
    df.to_parquet(str(staging), engine="pyarrow", index=False)

    shutil.move(str(staging), str(path))


# ---------------------------------------------------------------------------
# Main logic
# ---------------------------------------------------------------------------


def download_month(year: int, month: int, force: bool = False) -> dict:
    """Fetch the XLS file of every county in COUNTY_MAP for one month.

    Files already on disk (larger than 500 bytes) are reused unless ``force``
    is set. Counties whose download fails are logged and left out.

    Returns:
        dict mapping county name to the Path of its local XLS file.
    """
    key = month_key(year, month)
    target_dir = RAW_DIR / str(year) / f"{month:02d}"
    available = {}

    for county in COUNTY_MAP:
        dest = target_dir / f"IM_DRG_{county}_{key}.xls"

        already_there = dest.exists() and dest.stat().st_size > 500
        if already_there and not force:
            log.debug("Already exists: %s", dest)
            available[county] = dest
            continue

        url = build_url(year, month, county)
        log.info("Downloading %s %s-%02d ...", county, year, month)

        fetched = download_file(url, dest)
        if fetched:
            available[county] = dest
        else:
            log.error("Failed: %s", url)

        # Short pause between requests to go easy on the server.
        time.sleep(0.3)

    return available


def process_month(year: int, month: int, downloaded: dict) -> dict:
    """Parse one month's XLS files and group the rows by DRG code.

    Args:
        year, month: The month being processed; used to suffix indicator keys.
        downloaded: county name -> Path of the XLS file for that county.

    Returns:
        dict keyed by DRG code. Each value holds the ``denumire`` and ``tip``
        seen the first time the code appeared, plus ``records``: one dict per
        parsed row with ``natcode``, ``county`` and every indicator stored
        under ``<indicator>_<YYYY_MM>``. Files that fail to parse or contain
        no data are logged and skipped.
    """
    suffix = month_key(year, month)
    indicators = (
        "vr",
        "cazuri_total", "cazuri_pct", "cazuri_acuti", "cazuri_cronici",
        "zile_total", "zile_acuti", "zile_cronici",
        "dms_acuti", "dms_cronici",
    )
    by_code = {}  # drg_code -> {"denumire", "tip", "records"}

    for county, xls_path in downloaded.items():
        natcode = COUNTY_MAP[county]
        try:
            frame = parse_xls(xls_path)
        except Exception as e:
            log.error("Parse error %s: %s", xls_path, e)
            continue

        if frame.empty:
            log.warning("No data in %s", xls_path)
            continue

        for _, row in frame.iterrows():
            code = row["cod"]
            bucket = by_code.get(code)
            if bucket is None:
                bucket = {
                    "denumire": row["denumire"],
                    "tip": row["tip"],
                    "records": [],
                }
                by_code[code] = bucket

            record = {"natcode": natcode, "county": county}
            record.update(
                (f"{name}_{suffix}", row[name]) for name in indicators
            )
            bucket["records"].append(record)

    return by_code


def _read_parquet_safe(path: Path) -> pd.DataFrame:
    """Load a Parquet file by reading all of its bytes into memory first.

    Parsing from an in-memory buffer rather than the file itself avoids
    SSHFS fcntl issues.
    """
    import io

    with open(path, "rb") as source:
        raw_bytes = source.read()
    in_memory = io.BytesIO(raw_bytes)
    return pd.read_parquet(in_memory, engine="pyarrow")


def merge_into_parquet(month_data: dict, month_keys_processed: list):
    """Fold one month of parsed data into the per-DRG Parquet files.

    For every DRG code a file ``<code>_county.parquet`` in DATA_DIR is created
    or updated: new month columns are added, values for existing columns are
    replaced where the new data has a value, and counties not yet present are
    appended as rows. If an existing file cannot be read or merged it is
    overwritten with the new month only (a warning is logged).

    Args:
        month_data: Output of process_month().
        month_keys_processed: Not used by this function.

    Returns:
        Number of Parquet files written.
    """
    written = 0

    for cod, info in month_data.items():
        target = DATA_DIR / f"{cod}_county.parquet"

        month_df = pd.DataFrame(info["records"])
        if month_df.empty:
            continue

        # One row per county; for duplicates keep the first value per column.
        month_df = month_df.groupby("natcode", as_index=False).first()

        result = None
        if target.exists():
            try:
                existing = _read_parquet_safe(target)
                lookup = month_df.set_index("natcode")

                # Bring each indicator column over, aligned on natcode; new
                # values win, old values fill the holes.
                for col in month_df.columns:
                    if col in ("natcode", "county"):
                        continue
                    incoming = existing["natcode"].map(lookup[col])
                    if col in existing.columns:
                        incoming = incoming.combine_first(existing[col])
                    existing[col] = incoming

                # Append rows for counties the file has never seen.
                unseen = set(month_df["natcode"]) - set(existing["natcode"])
                if unseen:
                    additions = month_df[month_df["natcode"].isin(unseen)].drop(
                        columns=["county"], errors="ignore"
                    )
                    existing = pd.concat([existing, additions], ignore_index=True)

                result = existing
            except Exception as e:
                log.warning("Could not read existing %s, overwriting: %s", target, e)

        if result is None:
            result = month_df.drop(columns=["county"], errors="ignore")

        # natcode is always stored as a string.
        result["natcode"] = result["natcode"].astype(str)

        # Column order: natcode first, everything else alphabetically.
        other_cols = sorted(c for c in result.columns if c != "natcode")
        result = result[["natcode"] + other_cols]

        write_parquet_safe(result, target)
        written += 1

    log.info("Written/updated %d Parquet files", written)
    return written


def generate_catalog():
    """Generate catalog.json from the Parquet files found in DATA_DIR.

    One dataset entry is produced per ``*_county.parquet`` file. The list of
    available months is derived from the ``cazuri_total_<YYYY_MM>`` column
    names. Files that cannot be read are logged and skipped. The catalog is
    written to a temporary file and then moved into place.
    """
    month_prefix = "cazuri_total_"
    # Computed once so every entry of this run carries the same date.
    acquisition = datetime.now().strftime("%d-%m-%Y")

    catalog = {"datasets": []}
    parquet_files = sorted(DATA_DIR.glob("*_county.parquet"))

    for pf in parquet_files:
        code = pf.stem.replace("_county", "")
        try:
            df = _read_parquet_safe(pf)
        except Exception as e:
            log.warning("Cannot read %s for catalog: %s", pf, e)
            continue

        # Column "cazuri_total_2024_01" -> temporal value "2024_01".
        temporal_values = sorted({
            c[len(month_prefix):]
            for c in df.columns
            if c.startswith(month_prefix)
        })

        entry = {
            "id": code,
            "name": f"DRG {code}",
            "description": "Indicatori morbiditate spitalizată - grupa de diagnostic",
            "source": "DRG Romania - CCEASS",
            "source_url": "http://www.drg.ro",
            "license": "Open Data",
            "license_url": "",
            "acquisition": acquisition,
            "geometry_levels": ["county"],
            "temporal": {
                "type": "month",
                "values": temporal_values,
            },
            "join_key": "natcode",
            "columns": [
                {"id": "vr", "name": "Valoare relativă", "type": "float"},
                {"id": "cazuri_total", "name": "Total cazuri", "type": "int"},
                {"id": "cazuri_pct", "name": "Procent cazuri", "type": "float", "unit": "%"},
                {"id": "cazuri_acuti", "name": "Cazuri secții acuți", "type": "int"},
                {"id": "cazuri_cronici", "name": "Cazuri secții cronici", "type": "int"},
                {"id": "zile_total", "name": "Total zile spitalizare", "type": "int"},
                {"id": "zile_acuti", "name": "Zile secții acuți", "type": "int"},
                {"id": "zile_cronici", "name": "Zile secții cronici", "type": "int"},
                {"id": "dms_acuti", "name": "DMS secții acuți", "type": "float", "unit": "zile"},
                {"id": "dms_cronici", "name": "DMS secții cronici", "type": "float", "unit": "zile"},
            ],
            "category": "SANATATE",
            "subcategory": "Morbiditate spitalizată DRG",
            "periodicity": "Lunară",
        }
        catalog["datasets"].append(entry)

    catalog_path = DATA_DIR / "catalog.json"
    tmp = catalog_path.with_suffix(".tmp")
    with open(tmp, "w", encoding="utf-8") as f:
        json.dump(catalog, f, indent=2, ensure_ascii=False)
    shutil.move(str(tmp), str(catalog_path))

    log.info("Catalog written with %d datasets", len(catalog["datasets"]))


def build_drg_names_map():
    """Load the persisted DRG code -> name mapping.

    Returns:
        The contents of ``_drg_names.json`` in DATA_DIR, or an empty dict when
        that file does not exist yet.
    """
    names_file = DATA_DIR / "_drg_names.json"
    if not names_file.exists():
        return {}
    # The file is written as UTF-8 with non-ASCII characters kept verbatim,
    # so it must be read back as UTF-8 rather than the locale default.
    with open(names_file, encoding="utf-8") as f:
        return json.load(f)


def save_drg_names(names: dict):
    """Write the DRG code -> name mapping to ``_drg_names.json`` in DATA_DIR.

    The JSON is written as UTF-8 to a sibling ``.tmp`` file and then moved
    over the final file.
    """
    target = DATA_DIR / "_drg_names.json"
    scratch = target.with_suffix(".tmp")

    with scratch.open("w", encoding="utf-8") as handle:
        json.dump(names, handle, indent=2, ensure_ascii=False)

    shutil.move(str(scratch), str(target))


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------


def parse_month_arg(s: str) -> tuple:
    """Convert an ``MM-YYYY`` command-line value into ``(month, year)``.

    Raises:
        argparse.ArgumentTypeError: if the value is not two integers joined
            by a dash, the month is outside 1..12, or the year is below 2012.
    """
    bad_format = f"Format invalid: {s}. Folosește MM-YYYY"

    pieces = s.split("-")
    if len(pieces) != 2:
        raise argparse.ArgumentTypeError(bad_format)

    try:
        month, year = (int(piece) for piece in pieces)
    except ValueError:
        raise argparse.ArgumentTypeError(bad_format)

    if not 1 <= month <= 12:
        raise argparse.ArgumentTypeError(f"Luna invalidă: {month}")
    if year < 2012:
        raise argparse.ArgumentTypeError("Anul minim suportat: 2012")

    return (month, year)


def main():
    """Command-line entry point.

    Resolves the list of months from ``--update`` or ``--start``/``--end``,
    then for each month: downloads the county XLS files, parses them, merges
    the result into the per-DRG Parquet files and records progress in the
    state file. Unless ``--download-only`` is given, the DRG names file and
    catalog.json are regenerated at the end.
    """
    parser = argparse.ArgumentParser(
        description="Descărcare date DRG morbiditate spitalizată la nivel de județ"
    )
    parser.add_argument(
        "--start", type=parse_month_arg, metavar="MM-YYYY",
        help="Luna de start (ex: 01-2018)"
    )
    parser.add_argument(
        "--end", type=parse_month_arg, metavar="MM-YYYY",
        help="Luna de final (ex: 04-2025)"
    )
    parser.add_argument(
        "--update", action="store_true",
        help="Descarcă automat ultima lună disponibilă (pentru cron)"
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Forțează re-descărcarea chiar dacă fișierele există"
    )
    parser.add_argument(
        "--download-only", action="store_true",
        help="Doar descarcă XLS-urile, fără procesare Parquet"
    )
    parser.add_argument(
        "--verbose", "-v", action="store_true",
        help="Logging detaliat"
    )

    args = parser.parse_args()

    if args.verbose:
        log.setLevel(logging.DEBUG)

    # Determine month range
    if args.update:
        now = datetime.now()
        # Update mode targets the calendar month before the current one.
        if now.month == 1:
            end_year, end_month = now.year - 1, 12
        else:
            end_year, end_month = now.year, now.month - 1
        months = [(end_year, end_month)]
        log.info("Update mode: processing %d-%02d", end_year, end_month)
    elif args.start and args.end:
        sm, sy = args.start
        em, ey = args.end
        months = generate_month_range(sm, sy, em, ey)
        log.info("Range: %02d-%d to %02d-%d (%d months)",
                 sm, sy, em, ey, len(months))
        if not months:
            # A reversed range would otherwise silently do nothing.
            log.warning("Empty range: start %02d-%d is after end %02d-%d",
                        sm, sy, em, ey)
    else:
        # parser.error() prints the message and exits the process.
        parser.error("Specifică --start și --end, sau --update")

    # Load state for resume
    state = load_state()
    drg_names = build_drg_names_map()

    DATA_DIR.mkdir(parents=True, exist_ok=True)

    total_months = len(months)
    processed_now = 0
    for i, (year, month) in enumerate(months, 1):
        mk = month_key(year, month)
        log.info("=== [%d/%d] Processing %s ===", i, total_months, mk)

        # Check resume
        if mk in state.get("processed", []) and not args.force:
            log.info("Already processed %s, skipping (use --force to reprocess)", mk)
            continue

        # Download
        downloaded = download_month(year, month, force=args.force)
        log.info("Downloaded %d/%d counties for %s",
                 len(downloaded), len(COUNTY_MAP), mk)

        if not downloaded:
            log.error("No files downloaded for %s, skipping", mk)
            continue

        # Mark as downloaded
        if mk not in state.get("downloaded", []):
            state.setdefault("downloaded", []).append(mk)
            save_state(state)

        if args.download_only:
            continue

        # Process
        month_data = process_month(year, month, downloaded)
        log.info("Parsed %d DRG codes for %s", len(month_data), mk)

        # Collect names (the first name seen for a code is kept)
        for cod, info in month_data.items():
            if cod not in drg_names:
                drg_names[cod] = info.get("denumire", cod)

        # Merge into Parquet
        merge_into_parquet(month_data, [mk])

        # Mark as processed; guard against duplicates when --force
        # reprocesses a month that is already recorded.
        # NOTE: the month is marked processed even if only some counties
        # were downloaded; those counties are retried only with --force.
        processed = state.setdefault("processed", [])
        if mk not in processed:
            processed.append(mk)
        save_state(state)
        processed_now += 1

    # Save names and generate catalog
    if not args.download_only:
        save_drg_names(drg_names)

        # Update catalog with proper names
        generate_catalog()

        # Patch catalog with real names
        catalog_path = DATA_DIR / "catalog.json"
        if catalog_path.exists() and drg_names:
            # The catalog is written as UTF-8 with non-ASCII text, so it must
            # be read back as UTF-8 regardless of the locale default.
            with open(catalog_path, encoding="utf-8") as f:
                cat = json.load(f)
            for ds in cat["datasets"]:
                if ds["id"] in drg_names:
                    ds["name"] = drg_names[ds["id"]]
            # Write via a temp file and move, like the other writers here.
            tmp = catalog_path.with_suffix(".tmp")
            with open(tmp, "w", encoding="utf-8") as f:
                json.dump(cat, f, indent=2, ensure_ascii=False)
            shutil.move(str(tmp), str(catalog_path))

    log.info("Done! Processed %d of %d months.", processed_now, total_months)

# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
