"""Configuration schema definitions using Pydantic."""
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
[docs]
class Granularity(str, Enum):
    """Supported time-bucket sizes.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value.
    """

    DAILY = "daily"      # one bucket per day
    WEEKLY = "weekly"    # one bucket per week
    MONTHLY = "monthly"  # one bucket per month
[docs]
class Theme(str, Enum):
    """Named visual themes selectable for visualizations.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase string value.
    """

    PROFESSIONAL = "professional"
    MINIMAL = "minimal"
    DARK = "dark"
[docs]
class ChartType(str, Enum):
    """Identifiers for the chart kinds that can be requested.

    Members are also ``str`` instances, so each one compares equal to its
    snake_case string value.
    """

    LINE_SMOOTHED = "line_smoothed"
    AREA_STACKED = "area_stacked"
    BAR_AVERAGE = "bar_average"
    HEATMAP_MONTHLY = "heatmap_monthly"
    TREND_COMPARISON = "trend_comparison"
[docs]
class ProjectConfig(BaseModel):
    """Metadata describing the project.

    ``name`` is required; ``description`` defaults to an empty string and
    ``output_dir`` defaults to ``./outputs``.
    """

    name: str = Field(default=..., description="Project name")
    description: str = Field(default="", description="Project description")
    output_dir: Path = Field(
        default=Path("./outputs"),
        description="Output directory",
    )
[docs]
class DataSourceConfig(BaseModel):
    """Settings identifying the data provider and where its API key lives.

    Both fields are optional: ``provider`` defaults to ``"serpapi"`` and
    ``api_key_env`` to ``"SERPAPI_KEY"``.
    """

    provider: str = Field(default="serpapi", description="Data provider")
    # Holds the *name* of an environment variable, not the key itself.
    api_key_env: str = Field(
        default="SERPAPI_KEY",
        description="Environment variable for API key",
    )
[docs]
class QueryConfig(BaseModel):
    """A single search query and how it is presented.

    ``query`` and ``label`` are required; ``color`` is optional and defaults
    to ``None``.
    """

    query: str = Field(default=..., description="Search query")
    label: str = Field(default=..., description="Display label")
    color: Optional[str] = Field(
        default=None,
        description="Custom chart color (hex)",
    )
[docs]
class ParametersConfig(BaseModel):
    """Parameters applied to the search: region, date range and bucketing.

    Every field has a default, so the model can be built with no arguments.
    """

    geo: str = Field(default="GB", description="Geographic region code")
    date_range: str = Field(default="today 12-m", description="Date range")
    granularity: Granularity = Field(
        default=Granularity.MONTHLY,
        description="Time granularity",
    )
    smoothing_period: int = Field(
        default=3,
        ge=1,
        description="Smoothing window size",
    )

    @field_validator("smoothing_period")
    @classmethod
    def validate_smoothing_period(cls, v):
        """Return ``v`` unchanged, raising ``ValueError`` when it is below 1."""
        too_small = v < 1
        if too_small:
            raise ValueError("smoothing_period must be at least 1")
        return v
[docs]
class MetricsConfig(BaseModel):
    """Selection of metrics to compute.

    ``metrics`` is a list of metric names; when omitted, the five names
    listed below are used.
    """

    metrics: List[str] = Field(
        default=[
            "share_of_search",
            "trend_direction",
            "volatility",
            "momentum",
            "seasonality",
        ],
        description="List of metrics to calculate",
    )
[docs]
class InsightsConfig(BaseModel):
    """Switches and threshold controlling AI insights.

    Every field has a default, so the model can be built with no arguments.
    """

    enabled: bool = Field(default=True, description="Enable AI insights")
    auto_detect_anomalies: bool = Field(
        default=True,
        description="Automatically detect anomalies",
    )
    confidence_threshold: float = Field(
        default=0.8,
        ge=0.0,
        le=1.0,
        description="Confidence threshold",
    )

    @field_validator("confidence_threshold")
    @classmethod
    def validate_threshold(cls, v):
        """Return ``v`` when it lies in [0.0, 1.0]; otherwise raise ``ValueError``."""
        within_bounds = 0.0 <= v <= 1.0
        if within_bounds:
            return v
        raise ValueError("confidence_threshold must be between 0.0 and 1.0")
[docs]
class AnalysisConfig(BaseModel):
    """Groups the metric and insight settings for analysis.

    Each sub-section is built from its own defaults when not supplied.
    """

    metrics: MetricsConfig = Field(default_factory=MetricsConfig)
    insights: InsightsConfig = Field(default_factory=InsightsConfig)
[docs]
class ChartConfig(BaseModel):
    """Describes one chart: its kind and its title. Both fields are required."""

    type: ChartType = Field(default=..., description="Chart type")
    title: str = Field(default=..., description="Chart title")
[docs]
class VisualizationConfig(BaseModel):
    """Visual settings plus the list of charts to produce.

    ``charts`` is required; ``theme`` defaults to ``Theme.PROFESSIONAL`` and
    ``dpi`` to 300.
    """

    theme: Theme = Field(default=Theme.PROFESSIONAL, description="Visual theme")
    dpi: int = Field(default=300, ge=72, description="Chart resolution")
    charts: List[ChartConfig] = Field(
        default=...,
        description="Charts to generate",
    )

    @field_validator("dpi")
    @classmethod
    def validate_dpi(cls, v):
        """Return ``v`` unchanged, raising ``ValueError`` when it is below 72."""
        too_low = v < 72
        if too_low:
            raise ValueError("dpi must be at least 72")
        return v
[docs]
class BrandingConfig(BaseModel):
    """Branding options; currently just an optional logo path."""

    # Defaults to None when no logo is supplied.
    logo_path: Optional[Path] = Field(
        default=None,
        description="Path to company logo",
    )
[docs]
class ReportFormat(str, Enum):
    """Report output formats.

    Defined here because ``ReportingConfig`` references ``ReportFormat`` and
    no definition or import of it is visible in this module; without it the
    module raises ``NameError`` at import time.
    """

    # NOTE(review): the member names come from ReportingConfig's defaults; the
    # string values follow the lowercase convention of the sibling enums
    # (Granularity, Theme, ChartType) -- confirm against any existing
    # definition before merging.
    HTML = "html"
    PDF = "pdf"
    EXCEL = "excel"


class ReportingConfig(BaseModel):
    """Reporting configuration.

    Every field has a default: all three report formats, the seven section
    names listed below, and a default-constructed ``BrandingConfig``.
    """

    formats: List[ReportFormat] = Field(
        default=[ReportFormat.HTML, ReportFormat.PDF, ReportFormat.EXCEL],
        description="Output formats",
    )
    include_sections: List[str] = Field(
        default=[
            "executive_summary",
            "methodology",
            "detailed_metrics",
            "visualizations",
            "data_tables",
            "insights",
            "appendix",
        ],
        description="Report sections to include",
    )
    branding: BrandingConfig = Field(default_factory=BrandingConfig)
[docs]
class OptionsConfig(BaseModel):
    """Advanced runtime options: caching, retries, logging and parallelism.

    Every field has a default, so the model can be built with no arguments.
    """

    cache_enabled: bool = Field(default=True, description="Enable response caching")
    cache_ttl: int = Field(
        default=3600,
        ge=0,
        description="Cache time-to-live (seconds)",
    )
    retry_attempts: int = Field(default=3, ge=0, description="API retry attempts")
    retry_delay: int = Field(
        default=5,
        ge=1,
        description="Delay between retries (seconds)",
    )
    log_level: str = Field(default="INFO", description="Logging level")
    parallel_requests: bool = Field(
        default=False,
        description="Enable parallel API requests",
    )

    @field_validator("log_level")
    @classmethod
    def validate_log_level(cls, v):
        """Return ``v`` upper-cased, raising ``ValueError`` for unknown levels."""
        # Kept as a list: its repr is embedded in the error message below.
        valid_levels = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
        normalized = v.upper()
        if normalized in valid_levels:
            return normalized
        raise ValueError(f"log_level must be one of {valid_levels}")
[docs]
class Config(BaseModel):
    """Complete configuration schema.

    ``project``, ``queries`` and ``visualization`` must be supplied; every
    other section is built from its own defaults when omitted. ``queries``
    must contain between 1 and 5 entries.
    """

    # Pydantic v2 configuration. This replaces the nested ``class Config``
    # (the v1 style, deprecated in v2) with the same two settings:
    # enum fields are stored as their values, and attribute assignment is
    # validated.
    model_config = ConfigDict(use_enum_values=True, validate_assignment=True)

    project: ProjectConfig
    data_source: DataSourceConfig = Field(default_factory=DataSourceConfig)
    queries: List[QueryConfig] = Field(..., min_length=1, max_length=5)
    parameters: ParametersConfig = Field(default_factory=ParametersConfig)
    analysis: AnalysisConfig = Field(default_factory=AnalysisConfig)
    visualization: VisualizationConfig
    reporting: ReportingConfig = Field(default_factory=ReportingConfig)
    options: OptionsConfig = Field(default_factory=OptionsConfig)

    @field_validator('queries')
    @classmethod
    def validate_queries(cls, v):
        """Return ``v`` unchanged, raising ``ValueError`` unless it has 1-5 items."""
        if not 1 <= len(v) <= 5:
            raise ValueError("Must provide between 1 and 5 queries")
        return v