# Source code for src.config.schema

"""Configuration schema definitions using Pydantic."""

from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator


class Granularity(str, Enum):
    """Time granularity options.

    Members are ``str`` subclasses, so each one compares equal to its
    lowercase string value (e.g. ``Granularity.DAILY == "daily"``).
    """

    DAILY = "daily"
    WEEKLY = "weekly"
    MONTHLY = "monthly"
class Theme(str, Enum):
    """Visualization theme options.

    Members are ``str`` subclasses, so each one compares equal to its
    lowercase string value (e.g. ``Theme.DARK == "dark"``).
    """

    PROFESSIONAL = "professional"
    MINIMAL = "minimal"
    DARK = "dark"
class ChartType(str, Enum):
    """Available chart types.

    Members are ``str`` subclasses, so each one compares equal to its
    snake_case string value (e.g. ``ChartType.LINE_SMOOTHED == "line_smoothed"``).
    """

    LINE_SMOOTHED = "line_smoothed"
    AREA_STACKED = "area_stacked"
    BAR_AVERAGE = "bar_average"
    HEATMAP_MONTHLY = "heatmap_monthly"
    TREND_COMPARISON = "trend_comparison"
class ReportFormat(str, Enum):
    """Report output formats.

    Members are ``str`` subclasses, so each one compares equal to its
    lowercase string value (e.g. ``ReportFormat.PDF == "pdf"``).
    """

    HTML = "html"
    PDF = "pdf"
    EXCEL = "excel"
class ProjectConfig(BaseModel):
    """Project metadata configuration.

    Attributes:
        name: Project name (required).
        description: Project description; defaults to an empty string.
        output_dir: Output directory; defaults to ``./outputs``.
    """

    name: str = Field(..., description="Project name")
    description: str = Field("", description="Project description")
    output_dir: Path = Field(Path("./outputs"), description="Output directory")
class DataSourceConfig(BaseModel):
    """Data source configuration.

    Attributes:
        provider: Data provider identifier; defaults to ``"serpapi"``.
        api_key_env: Environment variable for the API key; defaults to
            ``"SERPAPI_KEY"``.
    """

    provider: str = Field("serpapi", description="Data provider")
    api_key_env: str = Field(
        "SERPAPI_KEY",
        description="Environment variable for API key",
    )
class QueryConfig(BaseModel):
    """Individual query configuration.

    Attributes:
        query: Search query (required).
        label: Display label (required).
        color: Optional custom chart color, described as a hex value;
            defaults to ``None``. The format itself is not validated here.
    """

    query: str = Field(..., description="Search query")
    label: str = Field(..., description="Display label")
    color: Optional[str] = Field(None, description="Custom chart color (hex)")
class ParametersConfig(BaseModel):
    """Search parameters configuration.

    Attributes:
        geo: Geographic region code; defaults to ``"GB"``.
        date_range: Date range expression; defaults to ``"today 12-m"``.
        granularity: Time granularity; defaults to ``Granularity.MONTHLY``.
        smoothing_period: Smoothing window size; must be at least 1,
            defaults to 3.
    """

    geo: str = Field("GB", description="Geographic region code")
    date_range: str = Field("today 12-m", description="Date range")
    granularity: Granularity = Field(
        Granularity.MONTHLY,
        description="Time granularity",
    )
    smoothing_period: int = Field(3, ge=1, description="Smoothing window size")

    @field_validator('smoothing_period')
    @classmethod
    def validate_smoothing_period(cls, v: int) -> int:
        """Reject smoothing windows smaller than 1.

        Duplicates the ``ge=1`` constraint declared on the field; kept as an
        explicit backstop with a domain-specific error message.

        Raises:
            ValueError: If ``v`` is less than 1.
        """
        if v < 1:
            raise ValueError("smoothing_period must be at least 1")
        return v
class MetricsConfig(BaseModel):
    """Metrics to calculate.

    Attributes:
        metrics: Names of the metrics to calculate. Defaults to all five
            listed below; the names are free-form strings and are not
            validated against a fixed set in this model.
    """

    metrics: List[str] = Field(
        default=[
            "share_of_search",
            "trend_direction",
            "volatility",
            "momentum",
            "seasonality",
        ],
        description="List of metrics to calculate",
    )
class InsightsConfig(BaseModel):
    """AI insights configuration.

    Attributes:
        enabled: Enable AI insights; defaults to ``True``.
        auto_detect_anomalies: Automatically detect anomalies; defaults to
            ``True``.
        confidence_threshold: Confidence threshold in the closed interval
            ``[0.0, 1.0]``; defaults to 0.8.
    """

    enabled: bool = Field(True, description="Enable AI insights")
    auto_detect_anomalies: bool = Field(
        True,
        description="Automatically detect anomalies",
    )
    confidence_threshold: float = Field(
        0.8,
        ge=0.0,
        le=1.0,
        description="Confidence threshold",
    )

    @field_validator('confidence_threshold')
    @classmethod
    def validate_threshold(cls, v: float) -> float:
        """Reject thresholds outside ``[0.0, 1.0]``.

        Duplicates the ``ge=0.0`` / ``le=1.0`` constraints declared on the
        field; kept as an explicit backstop with a domain-specific message.

        Raises:
            ValueError: If ``v`` is below 0.0 or above 1.0.
        """
        if not 0.0 <= v <= 1.0:
            raise ValueError("confidence_threshold must be between 0.0 and 1.0")
        return v
class AnalysisConfig(BaseModel):
    """Analysis configuration.

    Attributes:
        metrics: Metrics settings; a default ``MetricsConfig`` is built
            when omitted.
        insights: Insights settings; a default ``InsightsConfig`` is built
            when omitted.
    """

    metrics: MetricsConfig = Field(default_factory=MetricsConfig)
    insights: InsightsConfig = Field(default_factory=InsightsConfig)
class ChartConfig(BaseModel):
    """Individual chart configuration.

    Attributes:
        type: Chart type (required); one of the ``ChartType`` members.
        title: Chart title (required).
    """

    # The field name ``type`` shadows the builtin inside the class body; it
    # is part of the configuration schema, so it is kept as-is.
    type: ChartType = Field(..., description="Chart type")
    title: str = Field(..., description="Chart title")
class VisualizationConfig(BaseModel):
    """Visualization configuration.

    Attributes:
        theme: Visual theme; defaults to ``Theme.PROFESSIONAL``.
        dpi: Chart resolution; must be at least 72, defaults to 300.
        charts: Charts to generate (required).
    """

    theme: Theme = Field(Theme.PROFESSIONAL, description="Visual theme")
    dpi: int = Field(300, ge=72, description="Chart resolution")
    charts: List[ChartConfig] = Field(..., description="Charts to generate")

    @field_validator('dpi')
    @classmethod
    def validate_dpi(cls, v: int) -> int:
        """Reject resolutions below 72 DPI.

        Duplicates the ``ge=72`` constraint declared on the field; kept as
        an explicit backstop with a domain-specific error message.

        Raises:
            ValueError: If ``v`` is less than 72.
        """
        if v < 72:
            raise ValueError("dpi must be at least 72")
        return v
class BrandingConfig(BaseModel):
    """Branding configuration.

    Attributes:
        logo_path: Optional path to the company logo; defaults to ``None``.
            The path is not checked for existence in this model.
    """

    logo_path: Optional[Path] = Field(None, description="Path to company logo")
class ReportingConfig(BaseModel):
    """Reporting configuration.

    Attributes:
        formats: Output formats; defaults to HTML, PDF and Excel.
        include_sections: Report sections to include; defaults to all seven
            sections listed below. Section names are free-form strings and
            are not validated against a fixed set in this model.
        branding: Branding settings; a default ``BrandingConfig`` is built
            when omitted.
    """

    formats: List[ReportFormat] = Field(
        default=[ReportFormat.HTML, ReportFormat.PDF, ReportFormat.EXCEL],
        description="Output formats",
    )
    include_sections: List[str] = Field(
        default=[
            "executive_summary",
            "methodology",
            "detailed_metrics",
            "visualizations",
            "data_tables",
            "insights",
            "appendix",
        ],
        description="Report sections to include",
    )
    branding: BrandingConfig = Field(default_factory=BrandingConfig)
class OptionsConfig(BaseModel):
    """Advanced options configuration.

    Attributes:
        cache_enabled: Enable response caching; defaults to ``True``.
        cache_ttl: Cache time-to-live in seconds; non-negative, defaults
            to 3600.
        retry_attempts: API retry attempts; non-negative, defaults to 3.
        retry_delay: Delay between retries in seconds; at least 1,
            defaults to 5.
        log_level: Logging level name; accepted case-insensitively and
            stored upper-cased. Defaults to ``"INFO"``.
        parallel_requests: Enable parallel API requests; defaults to
            ``False``.
    """

    cache_enabled: bool = Field(True, description="Enable response caching")
    cache_ttl: int = Field(3600, ge=0, description="Cache time-to-live (seconds)")
    retry_attempts: int = Field(3, ge=0, description="API retry attempts")
    retry_delay: int = Field(5, ge=1, description="Delay between retries (seconds)")
    log_level: str = Field("INFO", description="Logging level")
    parallel_requests: bool = Field(False, description="Enable parallel API requests")

    @field_validator('log_level')
    @classmethod
    def validate_log_level(cls, v: str) -> str:
        """Normalize the logging level to upper case and check it is known.

        Returns:
            The upper-cased level name.

        Raises:
            ValueError: If the upper-cased value is not one of DEBUG, INFO,
                WARNING, ERROR or CRITICAL.
        """
        # Kept as a list so the error message renders exactly as before.
        valid_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
        level = v.upper()
        if level not in valid_levels:
            raise ValueError(f"log_level must be one of {valid_levels}")
        return level
class Config(BaseModel):
    """Complete configuration schema.

    Attributes:
        project: Project metadata (required).
        data_source: Data source settings; defaults to ``DataSourceConfig()``.
        queries: Between 1 and 5 query definitions (required).
        parameters: Search parameters; defaults to ``ParametersConfig()``.
        analysis: Analysis settings; defaults to ``AnalysisConfig()``.
        visualization: Visualization settings (required).
        reporting: Reporting settings; defaults to ``ReportingConfig()``.
        options: Advanced options; defaults to ``OptionsConfig()``.
    """

    # Pydantic v2 configuration. This replaces the v1-style nested
    # ``class Config`` (deprecated in v2) with the same two options.
    # NOTE(review): this applies to this model only; the enum-typed fields
    # live on nested models, so confirm whether those models also need
    # ``use_enum_values``.
    model_config = ConfigDict(use_enum_values=True, validate_assignment=True)

    project: ProjectConfig
    data_source: DataSourceConfig = Field(default_factory=DataSourceConfig)
    queries: List[QueryConfig] = Field(..., min_length=1, max_length=5)
    parameters: ParametersConfig = Field(default_factory=ParametersConfig)
    analysis: AnalysisConfig = Field(default_factory=AnalysisConfig)
    visualization: VisualizationConfig
    reporting: ReportingConfig = Field(default_factory=ReportingConfig)
    options: OptionsConfig = Field(default_factory=OptionsConfig)

    @field_validator('queries')
    @classmethod
    def validate_queries(cls, v):
        """Require between 1 and 5 queries.

        Duplicates the ``min_length=1`` / ``max_length=5`` constraints
        declared on the field; kept as an explicit backstop with a
        domain-specific error message.

        Raises:
            ValueError: If the number of queries is outside 1..5.
        """
        if not 1 <= len(v) <= 5:
            raise ValueError("Must provide between 1 and 5 queries")
        return v