"""
Validare seturi de date — verifică consistența între date statistice,
geometrii și catalog.
"""
import os
from pathlib import Path

import pandas as pd
import geopandas as gpd
from rich.console import Console
from rich.table import Table

from datahub_ingest.catalog import load_catalog

console = Console()


def validate_dataset(
    dataset_id: str,
    level: str | None = None,
    catalog_path: str = "../data/catalog.json",
    stats_dir: str = "../data/statistics",
    geom_dir: str = "../data/geometries",
):
    """
    Validează un set de date:
    1. Verifică existența în catalog
    2. Verifică existența fișierelor Parquet
    3. Verifică join keys — match cu geometriile
    4. Verifică tipuri de date
    5. Raportează completitudinea
    """
    catalog = load_catalog(catalog_path)
    dataset = next((d for d in catalog["datasets"] if d["id"] == dataset_id), None)

    if not dataset:
        console.print(f"[red]✕ Setul de date '{dataset_id}' nu există în catalog![/red]")
        return

    console.print(f"\n[bold cyan]═══ Validare: {dataset['name']} ═══[/bold cyan]")

    levels_to_check = [level] if level else dataset.get("geometry_levels", [])
    join_key = dataset["join_key"]
    all_ok = True

    for lev in levels_to_check:
        console.print(f"\n[bold]Nivel: {lev}[/bold]")

        # Check stats file
        stats_path = os.path.join(stats_dir, f"{dataset_id}_{lev}.parquet")
        if not os.path.exists(stats_path):
            console.print(f"  [red]✕ Fișier statistic lipsă:[/red] {stats_path}")
            all_ok = False
            continue

        stats_df = pd.read_parquet(stats_path)
        console.print(f"  → Date statistice: {len(stats_df)} rânduri, {len(stats_df.columns)} coloane")

        # Check join key in stats
        if join_key not in stats_df.columns:
            console.print(f"  [red]✕ Coloana de join '{join_key}' lipsă din date statistice![/red]")
            all_ok = False
            continue

        # Check geometry file
        geom_path = os.path.join(geom_dir, f"{lev}.parquet")
        if not os.path.exists(geom_path):
            console.print(f"  [yellow]⚠ Geometrie lipsă:[/yellow] {geom_path}")
            console.print("    → Rulați: datahub-ingest geometry prepare-all")
            all_ok = False
            continue

        geom_df = gpd.read_parquet(geom_path)
        console.print(f"  → Geometrii: {len(geom_df)} entități")

        # Check join key in geometry
        if join_key not in geom_df.columns:
            console.print(f"  [red]✕ Coloana de join '{join_key}' lipsă din geometrii![/red]")
            all_ok = False
            continue

        # Compare join keys
        stats_keys = set(stats_df[join_key].astype(str))
        geom_keys = set(geom_df[join_key].astype(str))

        matched = stats_keys & geom_keys
        stats_only = stats_keys - geom_keys
        geom_only = geom_keys - stats_keys

        console.print(f"  [green]✓ Join match:[/green] {len(matched)}/{len(geom_keys)} geometrii acoperite")

        if stats_only:
            console.print(
                f"  [yellow]⚠ {len(stats_only)} chei în statistici fără geometrie:[/yellow] "
                f"{', '.join(sorted(stats_only)[:5])}"
                f"{'...' if len(stats_only) > 5 else ''}"
            )

        if geom_only:
            console.print(
                f"  [yellow]⚠ {len(geom_only)} geometrii fără date statistice:[/yellow] "
                f"{', '.join(sorted(geom_only)[:5])}"
                f"{'...' if len(geom_only) > 5 else ''}"
            )

        # Check for nulls in data columns
        data_cols = [c for c in stats_df.columns if c != join_key]
        null_report = []
        for col in data_cols:
            null_count = stats_df[col].isna().sum()
            if null_count > 0:
                null_report.append((col, null_count, len(stats_df)))

        if null_report:
            console.print(f"  [yellow]⚠ Valori lipsă:[/yellow]")
            for col, nulls, total in null_report:
                pct = (nulls / total) * 100
                console.print(f"    {col}: {nulls}/{total} ({pct:.1f}%)")

        # Check column types match catalog
        catalog_cols = {c["id"]: c["type"] for c in dataset.get("columns", [])}
        for col in data_cols:
            if col in catalog_cols:
                expected = catalog_cols[col]
                actual = stats_df[col].dtype
                if expected == "integer" and not pd.api.types.is_integer_dtype(actual):
                    console.print(f"  [yellow]⚠ Tip incorect pentru '{col}': așteptat {expected}, găsit {actual}[/yellow]")
                elif expected == "float" and not pd.api.types.is_float_dtype(actual):
                    console.print(f"  [yellow]⚠ Tip incorect pentru '{col}': așteptat {expected}, găsit {actual}[/yellow]")

    if all_ok:
        console.print(f"\n[bold green]✓ Validare completă — totul este OK![/bold green]")
    else:
        console.print(f"\n[bold yellow]⚠ Validarea a detectat probleme.[/bold yellow]")
