logo
RustyBT Documentation
Custom Data Adapter
Initializing search
    jerryinyang/rustybt
    • Home
    • Getting Started
    • User Guides
    • Migration Guides
    • Examples & Tutorials
    • API Reference
    • About
    jerryinyang/rustybt
    • Home
      • Installation
      • Quick Start
      • Configuration
      • Decimal Precision
        • Backtest Output Organization
        • Strategy Code Capture
        • Cash Validation
        • CSV Data Import
        • Data Ingestion
        • Databento Data Import
        • Data Validation
        • Creating Data Adapters
        • Migrating to Unified Data
        • Caching System
        • Caching Guide
        • Broker Setup
        • Testnet Setup
        • Live vs Backtest Data
        • WebSocket Streaming
        • Type Hinting
        • Exception Handling
        • Execution Methods
        • Pipeline API
        • Advanced Pipeline Techniques
        • Multi-Strategy Portfolio
        • Deployment Guide
        • Production Checklist
        • Audit Logging
        • Troubleshooting
      • Cash Validation Migration
      • Overview
        • Overview
          • Crypto Backtesting with CCXT Data Adapter
          • Equity Backtesting with YFinance Data Adapter
          • Getting Started with RustyBT
          • Data Ingestion with RustyBT
          • Strategy Development with RustyBT
          • Performance Analysis
          • Strategy Optimization
          • Walk-Forward Optimization
          • Risk Analytics
          • Portfolio Construction (Single-Strategy Multi-Asset)
          • 09. Multi-Strategy Portfolio
          • Live Paper Trading
          • Complete Workflow: Data → Backtest → Analysis → Optimization
          • CCXT Data Ingestion
          • YFinance Data Ingestion
          • Custom Data Adapter
          • Backtest with Cache
          • Full Validation (Backtest & Paper)
          • Cache Warming
          • Generate Backtest Report
          • Live Trading (Simple)
          • Live Trading (Advanced)
          • Paper Trading (Simple)
          • Paper Trading Validation
          • Shadow Trading (Simple)
          • Shadow Trading Dashboard
          • Portfolio Allocator Tutorial
          • Allocation Algorithms
          • Attribution Analysis
          • Slippage Models
          • Borrow Costs
          • Overnight Financing
          • High-Frequency Custom Triggers
          • Latency Simulation
          • Pipeline API
          • WebSocket Streaming
          • Custom Broker Adapter
          • Grid Search MA Crossover
          • Random Search vs Grid
          • Bayesian Optimization (5 Params)
          • Parallel Optimization
          • Walk-Forward Analysis
      • Overview & Interactive Docs
        • Overview
        • Asset Finder
          • Overview
          • Selection Guide
          • Base Adapter
          • CCXT
          • YFinance
          • CSV
          • Polygon
          • Alpaca
          • AlphaVantage
          • Overview
          • Architecture
          • Catalog API
          • Bundle System
          • Metadata Tracking
          • Migration Guide
          • Overview
          • Data Portal
          • Polars Data Portal
          • Bar Reader
          • Daily Bars
          • Overview
          • Providers
          • Storage
          • Converters
          • FX & Caching
          • Caching
          • Optimization
          • Troubleshooting
          • Overview
          • Computation API
        • Overview
          • Overview
          • Types Reference
          • Blotter
          • Blotter System
          • Decimal Blotter
          • Execution Pipeline
          • Latency Models
          • Partial Fills
          • Order Status Tracking
          • Slippage
          • Slippage Models
          • Commissions
          • Commission Models
          • Borrow Costs & Financing
          • Order Lifecycle
          • Examples
        • Overview
          • Allocation Algorithms
          • Multi-Strategy Allocation
          • Portfolio Allocator
          • Allocators
          • Risk Management
          • Risk Metrics
          • Position Limits
          • Performance Tracking
          • Metrics
          • Order Aggregation
          • Analytics Suite
        • Overview
          • Parameter Spaces
          • Objective Functions
          • Grid Search
          • Random Search
          • Bayesian
          • Genetic
          • Overview
          • Monte Carlo
          • Noise Infusion
          • Sensitivity Analysis
        • Overview
        • Artifact Manager
        • Code Capture
        • Overview
        • Reports
        • Visualization
          • Overview
          • Overview
          • Metrics
          • VaR & CVaR
          • Drawdown
          • Overview
        • Overview
        • Production Deployment
          • Circuit Breakers
        • Overview
        • Datasource API
        • Optimization API
        • Analytics API
      • License
      • Contributing
      • Changelog
    In [ ]:
    Copied!
    #
    # Copyright 2025 RustyBT Contributors
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    """
    Example: Creating a Custom Data Adapter
    
    This example demonstrates how to create a custom data adapter for RustyBT
    that fetches data from a proprietary or custom data source.
    
    We'll implement a simple REST API adapter that could be adapted for any
    custom data provider.
    
    Key Concepts Demonstrated:
    - Inheriting from DataSource base class
    - Implementing required abstract methods
    - Error handling and retry logic
    - Data validation
    - Type conversion (to Decimal)
    - Registering adapter with DataSourceRegistry
    
    Usage:
        python examples/custom_data_adapter.py
    """
    
    # # Copyright 2025 RustyBT Contributors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """ Example: Creating a Custom Data Adapter This example demonstrates how to create a custom data adapter for RustyBT that fetches data from a proprietary or custom data source. We'll implement a simple REST API adapter that could be adapted for any custom data provider. Key Concepts Demonstrated: - Inheriting from DataSource base class - Implementing required abstract methods - Error handling and retry logic - Data validation - Type conversion (to Decimal) - Registering adapter with DataSourceRegistry Usage: python examples/custom_data_adapter.py """
    In [ ]:
    Copied!
    import asyncio
    import random
    from decimal import Decimal, getcontext
    from pathlib import Path
    
    import asyncio import random from decimal import Decimal, getcontext from pathlib import Path
    In [ ]:
    Copied!
    import httpx
    import pandas as pd
    import polars as pl
    import structlog
    
    import httpx import pandas as pd import polars as pl import structlog
    In [ ]:
    Copied!
    from rustybt.data.sources.base import (
        DataSource,
        DataSourceMetadata,
        DataValidationError,
        NetworkError,
        RateLimitError,
        with_retry,
    )
    
    from rustybt.data.sources.base import ( DataSource, DataSourceMetadata, DataValidationError, NetworkError, RateLimitError, with_retry, )
    In [ ]:
    Copied!
    # Set decimal precision
    getcontext().prec = 28
    
    # Set decimal precision getcontext().prec = 28
    In [ ]:
    Copied!
    logger = structlog.get_logger()
    
    logger = structlog.get_logger()
    In [ ]:
    Copied!
    class CustomAPIDataSource(DataSource):
        """Custom data source adapter for a hypothetical REST API.
    
        This adapter demonstrates the complete implementation of a DataSource
        for fetching OHLCV data from a custom API endpoint.
    
        Adapt this template for your own data sources by:
        1. Changing the API endpoint and authentication
        2. Modifying the data parsing logic
        3. Adjusting error handling for your API
        4. Updating metadata and capabilities
    
        Example:
            >>> source = CustomAPIDataSource(
            ...     api_url="https://api.example.com/v1",
            ...     api_key="your_api_key"
            ... )
            >>> data = await source.fetch(
            ...     symbols=["AAPL", "MSFT"],
            ...     start=pd.Timestamp("2023-01-01"),
            ...     end=pd.Timestamp("2023-12-31"),
            ...     frequency="1d"
            ... )
        """
    
        def __init__(
            self, api_url: str, api_key: str, rate_limit_per_second: int = 5, timeout: int = 30
        ):
            """Initialize custom data source.
    
            Args:
                api_url: Base URL for API (e.g., "https://api.example.com/v1")
                api_key: API authentication key
                rate_limit_per_second: Maximum requests per second
                timeout: Request timeout in seconds
            """
            self.api_url = api_url.rstrip("/")
            self.api_key = api_key
            self.timeout = timeout
    
            # Initialize HTTP client
            self.client = httpx.AsyncClient(timeout=timeout)
    
            logger.info(
                "custom_adapter_initialized", api_url=self.api_url, rate_limit=rate_limit_per_second
            )
    
        @with_retry(max_retries=3, initial_delay=1.0, backoff_factor=2.0)
        async def fetch(
            self, symbols: list[str], start: pd.Timestamp, end: pd.Timestamp, frequency: str
        ) -> pl.DataFrame:
            """Fetch OHLCV data from custom API.
    
            Args:
                symbols: List of ticker symbols
                start: Start timestamp (inclusive)
                end: End timestamp (inclusive)
                frequency: Data frequency ("1d", "1h", "1m", etc.)
    
            Returns:
                Polars DataFrame with columns:
                - symbol: str
                - date: date (for daily) or timestamp: datetime (for intraday)
                - open: Decimal
                - high: Decimal
                - low: Decimal
                - close: Decimal
                - volume: Decimal
    
            Raises:
                NetworkError: If API request fails
                RateLimitError: If rate limit exceeded
                DataValidationError: If data format is invalid
            """
            logger.info(
                "fetching_data",
                symbols=symbols,
                start=start.isoformat(),
                end=end.isoformat(),
                frequency=frequency,
            )
    
            all_data = []
    
            for symbol in symbols:
                # Fetch data for each symbol
                try:
                    symbol_data = await self._fetch_symbol(symbol, start, end, frequency)
                    all_data.append(symbol_data)
    
                except httpx.HTTPStatusError as e:
                    if e.response.status_code == 429:
                        # Rate limit exceeded
                        raise RateLimitError(
                            f"Rate limit exceeded for {symbol}", adapter="custom_api", reset_after=60.0
                        )
                    elif e.response.status_code >= 500:
                        # Server error - retry
                        raise NetworkError(f"Server error for {symbol}: {e}", adapter="custom_api")
                    else:
                        # Client error - don't retry
                        logger.error("api_error", symbol=symbol, status=e.response.status_code)
                        raise
    
                except httpx.TimeoutException as e:
                    raise NetworkError(f"Timeout fetching {symbol}: {e}", adapter="custom_api")
    
            # Combine all symbols into single DataFrame
            if not all_data:
                return pl.DataFrame()
    
            df = pl.concat(all_data)
    
            # Validate data
            self._validate_ohlcv_data(df)
    
            logger.info("fetch_complete", rows=len(df), symbols=len(symbols))
    
            return df
    
        async def _fetch_symbol(
            self, symbol: str, start: pd.Timestamp, end: pd.Timestamp, frequency: str
        ) -> pl.DataFrame:
            """Fetch data for a single symbol.
    
            This is where you implement your API-specific logic.
    
            Args:
                symbol: Ticker symbol
                start: Start timestamp
                end: End timestamp
                frequency: Data frequency
    
            Returns:
                Polars DataFrame with OHLCV data for the symbol
            """
            # Build API request
            url = f"{self.api_url}/ohlcv/{symbol}"
            params = {
                "start": start.isoformat(),
                "end": end.isoformat(),
                "frequency": frequency,
                "format": "json",
            }
            headers = {"Authorization": f"Bearer {self.api_key}", "Accept": "application/json"}
    
            # Make API request
            response = await self.client.get(url, params=params, headers=headers)
            response.raise_for_status()
    
            # Parse response
            data = response.json()
    
            # Convert to DataFrame
            # NOTE: Adapt this to your API's response format
            if not data or "data" not in data:
                return pl.DataFrame()
    
            rows = []
            for record in data["data"]:
                rows.append(
                    {
                        "symbol": symbol,
                        "date": (
                            pd.Timestamp(record["timestamp"]).date()
                            if frequency.endswith("d")
                            else pd.Timestamp(record["timestamp"])
                        ),
                        "open": Decimal(str(record["open"])),
                        "high": Decimal(str(record["high"])),
                        "low": Decimal(str(record["low"])),
                        "close": Decimal(str(record["close"])),
                        "volume": Decimal(str(record["volume"])),
                    }
                )
    
            if not rows:
                return pl.DataFrame()
    
            # Create DataFrame with proper types
            if frequency.endswith("d"):
                # Daily data - use date
                df = pl.DataFrame(rows).with_columns(
                    [
                        pl.col("date").cast(pl.Date),
                        pl.col("open").cast(pl.Decimal(18, 8)),
                        pl.col("high").cast(pl.Decimal(18, 8)),
                        pl.col("low").cast(pl.Decimal(18, 8)),
                        pl.col("close").cast(pl.Decimal(18, 8)),
                        pl.col("volume").cast(pl.Decimal(18, 8)),
                    ]
                )
            else:
                # Intraday data - use timestamp
                df = (
                    pl.DataFrame(rows)
                    .rename({"date": "timestamp"})
                    .with_columns(
                        [
                            pl.col("timestamp").cast(pl.Datetime),
                            pl.col("open").cast(pl.Decimal(18, 8)),
                            pl.col("high").cast(pl.Decimal(18, 8)),
                            pl.col("low").cast(pl.Decimal(18, 8)),
                            pl.col("close").cast(pl.Decimal(18, 8)),
                            pl.col("volume").cast(pl.Decimal(18, 8)),
                        ]
                    )
                )
    
            return df
    
        def _validate_ohlcv_data(self, df: pl.DataFrame) -> None:
            """Validate OHLCV data integrity.
    
            Args:
                df: DataFrame to validate
    
            Raises:
                DataValidationError: If validation fails
            """
            if df.is_empty():
                return
    
            # Check required columns
            required_cols = {"symbol", "open", "high", "low", "close", "volume"}
            actual_cols = set(df.columns)
    
            if not required_cols.issubset(actual_cols):
                missing = required_cols - actual_cols
                raise DataValidationError(f"Missing required columns: {missing}", adapter="custom_api")
    
            # Validate OHLCV relationships (high >= low, etc.)
            invalid_rows = df.filter(
                (pl.col("high") < pl.col("low"))
                | (pl.col("close") > pl.col("high"))
                | (pl.col("close") < pl.col("low"))
                | (pl.col("open") > pl.col("high"))
                | (pl.col("open") < pl.col("low"))
            )
    
            if len(invalid_rows) > 0:
                raise DataValidationError(
                    f"Found {len(invalid_rows)} rows with invalid OHLCV relationships",
                    adapter="custom_api",
                )
    
        def ingest_to_bundle(
            self,
            bundle_name: str,
            symbols: list[str],
            start: pd.Timestamp,
            end: pd.Timestamp,
            frequency: str,
            **kwargs,
        ) -> Path:
            """Ingest data and create bundle.
    
            This uses the default bundle creation logic from adapter_bundles.
    
            Args:
                bundle_name: Name for the bundle
                symbols: Symbols to ingest
                start: Start date
                end: End date
                frequency: Data frequency
                **kwargs: Additional options
    
            Returns:
                Path to created bundle directory
            """
            from rustybt.data.bundles.adapter_bundles import ingest_from_datasource
    
            return ingest_from_datasource(
                datasource=self,
                bundle_name=bundle_name,
                symbols=symbols,
                start=start,
                end=end,
                frequency=frequency,
                **kwargs,
            )
    
        def get_metadata(self) -> DataSourceMetadata:
            """Get data source metadata.
    
            Returns:
                Metadata describing this data source
            """
            return DataSourceMetadata(
                source_type="custom_api",
                source_url=self.api_url,
                api_version="v1",
                supports_live=False,  # Set True if supports real-time streaming
                supported_frequencies=["1d", "1h", "5m", "1m"],
                rate_limit=300,  # Requests per minute
                requires_auth=True,
            )
    
        def supports_live(self) -> bool:
            """Check if source supports live streaming.
    
            Returns:
                True if live streaming supported, False otherwise
            """
            return False  # Change to True if you implement WebSocket streaming
    
        async def close(self) -> None:
            """Close HTTP client and cleanup resources."""
            await self.client.aclose()
            logger.info("custom_adapter_closed")
    
    class CustomAPIDataSource(DataSource): """Custom data source adapter for a hypothetical REST API. This adapter demonstrates the complete implementation of a DataSource for fetching OHLCV data from a custom API endpoint. Adapt this template for your own data sources by: 1. Changing the API endpoint and authentication 2. Modifying the data parsing logic 3. Adjusting error handling for your API 4. Updating metadata and capabilities Example: >>> source = CustomAPIDataSource( ... api_url="https://api.example.com/v1", ... api_key="your_api_key" ... ) >>> data = await source.fetch( ... symbols=["AAPL", "MSFT"], ... start=pd.Timestamp("2023-01-01"), ... end=pd.Timestamp("2023-12-31"), ... frequency="1d" ... ) """ def __init__( self, api_url: str, api_key: str, rate_limit_per_second: int = 5, timeout: int = 30 ): """Initialize custom data source. Args: api_url: Base URL for API (e.g., "https://api.example.com/v1") api_key: API authentication key rate_limit_per_second: Maximum requests per second timeout: Request timeout in seconds """ self.api_url = api_url.rstrip("/") self.api_key = api_key self.timeout = timeout # Initialize HTTP client self.client = httpx.AsyncClient(timeout=timeout) logger.info( "custom_adapter_initialized", api_url=self.api_url, rate_limit=rate_limit_per_second ) @with_retry(max_retries=3, initial_delay=1.0, backoff_factor=2.0) async def fetch( self, symbols: list[str], start: pd.Timestamp, end: pd.Timestamp, frequency: str ) -> pl.DataFrame: """Fetch OHLCV data from custom API. Args: symbols: List of ticker symbols start: Start timestamp (inclusive) end: End timestamp (inclusive) frequency: Data frequency ("1d", "1h", "1m", etc.) Returns: Polars DataFrame with columns: - symbol: str - date: date (for daily) or timestamp: datetime (for intraday) - open: Decimal - high: Decimal - low: Decimal - close: Decimal - volume: Decimal Raises: NetworkError: If API request fails RateLimitError: If rate limit exceeded DataValidationError: If data format is invalid """ logger.info( "fetching_data", symbols=symbols, start=start.isoformat(), end=end.isoformat(), frequency=frequency, ) all_data = [] for symbol in symbols: # Fetch data for each symbol try: symbol_data = await self._fetch_symbol(symbol, start, end, frequency) all_data.append(symbol_data) except httpx.HTTPStatusError as e: if e.response.status_code == 429: # Rate limit exceeded raise RateLimitError( f"Rate limit exceeded for {symbol}", adapter="custom_api", reset_after=60.0 ) elif e.response.status_code >= 500: # Server error - retry raise NetworkError(f"Server error for {symbol}: {e}", adapter="custom_api") else: # Client error - don't retry logger.error("api_error", symbol=symbol, status=e.response.status_code) raise except httpx.TimeoutException as e: raise NetworkError(f"Timeout fetching {symbol}: {e}", adapter="custom_api") # Combine all symbols into single DataFrame if not all_data: return pl.DataFrame() df = pl.concat(all_data) # Validate data self._validate_ohlcv_data(df) logger.info("fetch_complete", rows=len(df), symbols=len(symbols)) return df async def _fetch_symbol( self, symbol: str, start: pd.Timestamp, end: pd.Timestamp, frequency: str ) -> pl.DataFrame: """Fetch data for a single symbol. This is where you implement your API-specific logic. Args: symbol: Ticker symbol start: Start timestamp end: End timestamp frequency: Data frequency Returns: Polars DataFrame with OHLCV data for the symbol """ # Build API request url = f"{self.api_url}/ohlcv/{symbol}" params = { "start": start.isoformat(), "end": end.isoformat(), "frequency": frequency, "format": "json", } headers = {"Authorization": f"Bearer {self.api_key}", "Accept": "application/json"} # Make API request response = await self.client.get(url, params=params, headers=headers) response.raise_for_status() # Parse response data = response.json() # Convert to DataFrame # NOTE: Adapt this to your API's response format if not data or "data" not in data: return pl.DataFrame() rows = [] for record in data["data"]: rows.append( { "symbol": symbol, "date": ( pd.Timestamp(record["timestamp"]).date() if frequency.endswith("d") else pd.Timestamp(record["timestamp"]) ), "open": Decimal(str(record["open"])), "high": Decimal(str(record["high"])), "low": Decimal(str(record["low"])), "close": Decimal(str(record["close"])), "volume": Decimal(str(record["volume"])), } ) if not rows: return pl.DataFrame() # Create DataFrame with proper types if frequency.endswith("d"): # Daily data - use date df = pl.DataFrame(rows).with_columns( [ pl.col("date").cast(pl.Date), pl.col("open").cast(pl.Decimal(18, 8)), pl.col("high").cast(pl.Decimal(18, 8)), pl.col("low").cast(pl.Decimal(18, 8)), pl.col("close").cast(pl.Decimal(18, 8)), pl.col("volume").cast(pl.Decimal(18, 8)), ] ) else: # Intraday data - use timestamp df = ( pl.DataFrame(rows) .rename({"date": "timestamp"}) .with_columns( [ pl.col("timestamp").cast(pl.Datetime), pl.col("open").cast(pl.Decimal(18, 8)), pl.col("high").cast(pl.Decimal(18, 8)), pl.col("low").cast(pl.Decimal(18, 8)), pl.col("close").cast(pl.Decimal(18, 8)), pl.col("volume").cast(pl.Decimal(18, 8)), ] ) ) return df def _validate_ohlcv_data(self, df: pl.DataFrame) -> None: """Validate OHLCV data integrity. Args: df: DataFrame to validate Raises: DataValidationError: If validation fails """ if df.is_empty(): return # Check required columns required_cols = {"symbol", "open", "high", "low", "close", "volume"} actual_cols = set(df.columns) if not required_cols.issubset(actual_cols): missing = required_cols - actual_cols raise DataValidationError(f"Missing required columns: {missing}", adapter="custom_api") # Validate OHLCV relationships (high >= low, etc.) invalid_rows = df.filter( (pl.col("high") < pl.col("low")) | (pl.col("close") > pl.col("high")) | (pl.col("close") < pl.col("low")) | (pl.col("open") > pl.col("high")) | (pl.col("open") < pl.col("low")) ) if len(invalid_rows) > 0: raise DataValidationError( f"Found {len(invalid_rows)} rows with invalid OHLCV relationships", adapter="custom_api", ) def ingest_to_bundle( self, bundle_name: str, symbols: list[str], start: pd.Timestamp, end: pd.Timestamp, frequency: str, **kwargs, ) -> Path: """Ingest data and create bundle. This uses the default bundle creation logic from adapter_bundles. Args: bundle_name: Name for the bundle symbols: Symbols to ingest start: Start date end: End date frequency: Data frequency **kwargs: Additional options Returns: Path to created bundle directory """ from rustybt.data.bundles.adapter_bundles import ingest_from_datasource return ingest_from_datasource( datasource=self, bundle_name=bundle_name, symbols=symbols, start=start, end=end, frequency=frequency, **kwargs, ) def get_metadata(self) -> DataSourceMetadata: """Get data source metadata. Returns: Metadata describing this data source """ return DataSourceMetadata( source_type="custom_api", source_url=self.api_url, api_version="v1", supports_live=False, # Set True if supports real-time streaming supported_frequencies=["1d", "1h", "5m", "1m"], rate_limit=300, # Requests per minute requires_auth=True, ) def supports_live(self) -> bool: """Check if source supports live streaming. Returns: True if live streaming supported, False otherwise """ return False # Change to True if you implement WebSocket streaming async def close(self) -> None: """Close HTTP client and cleanup resources.""" await self.client.aclose() logger.info("custom_adapter_closed")

    ============================================================================ Example Usage¶

    In [ ]:
    Copied!
    class MockAPIDataSource(CustomAPIDataSource):
        """Mock version that generates fake data for demonstration.
    
        This removes the need for a real API endpoint to test the example.
        """
    
        async def _fetch_symbol(
            self, symbol: str, start: pd.Timestamp, end: pd.Timestamp, frequency: str
        ) -> pl.DataFrame:
            """Generate mock data instead of fetching from API."""
            # Generate date range
            if frequency == "1d":
                dates = pd.date_range(start=start, end=end, freq="D")
            elif frequency == "1h":
                dates = pd.date_range(start=start, end=end, freq="H")
            else:
                dates = pd.date_range(start=start, end=end, freq="5min")
    
            # Generate fake OHLCV data
            base_price = Decimal("100")
            rows = []
    
            for date in dates:
                # Random walk
                open_price = base_price + Decimal(str(random.uniform(-2, 2)))  # noqa: S311
                high_price = open_price + Decimal(str(random.uniform(0, 3)))  # noqa: S311
                low_price = open_price - Decimal(str(random.uniform(0, 3)))  # noqa: S311
                close_price = open_price + Decimal(str(random.uniform(-2, 2)))  # noqa: S311
                close_price = min(max(close_price, low_price), high_price)
                volume = Decimal(str(random.randint(100000, 1000000)))  # noqa: S311
    
                rows.append(
                    {
                        "symbol": symbol,
                        "date": date.date() if frequency == "1d" else date,
                        "open": open_price,
                        "high": high_price,
                        "low": low_price,
                        "close": close_price,
                        "volume": volume,
                    }
                )
    
                base_price = close_price
    
            # Create DataFrame
            if frequency == "1d":
                df = pl.DataFrame(rows).with_columns(
                    [
                        pl.col("date").cast(pl.Date),
                        pl.col("open").cast(pl.Decimal(18, 8)),
                        pl.col("high").cast(pl.Decimal(18, 8)),
                        pl.col("low").cast(pl.Decimal(18, 8)),
                        pl.col("close").cast(pl.Decimal(18, 8)),
                        pl.col("volume").cast(pl.Decimal(18, 8)),
                    ]
                )
            else:
                df = (
                    pl.DataFrame(rows)
                    .rename({"date": "timestamp"})
                    .with_columns(
                        [
                            pl.col("timestamp").cast(pl.Datetime),
                            pl.col("open").cast(pl.Decimal(18, 8)),
                            pl.col("high").cast(pl.Decimal(18, 8)),
                            pl.col("low").cast(pl.Decimal(18, 8)),
                            pl.col("close").cast(pl.Decimal(18, 8)),
                            pl.col("volume").cast(pl.Decimal(18, 8)),
                        ]
                    )
                )
    
            return df
    
    class MockAPIDataSource(CustomAPIDataSource): """Mock version that generates fake data for demonstration. This removes the need for a real API endpoint to test the example. """ async def _fetch_symbol( self, symbol: str, start: pd.Timestamp, end: pd.Timestamp, frequency: str ) -> pl.DataFrame: """Generate mock data instead of fetching from API.""" # Generate date range if frequency == "1d": dates = pd.date_range(start=start, end=end, freq="D") elif frequency == "1h": dates = pd.date_range(start=start, end=end, freq="H") else: dates = pd.date_range(start=start, end=end, freq="5min") # Generate fake OHLCV data base_price = Decimal("100") rows = [] for date in dates: # Random walk open_price = base_price + Decimal(str(random.uniform(-2, 2))) # noqa: S311 high_price = open_price + Decimal(str(random.uniform(0, 3))) # noqa: S311 low_price = open_price - Decimal(str(random.uniform(0, 3))) # noqa: S311 close_price = open_price + Decimal(str(random.uniform(-2, 2))) # noqa: S311 close_price = min(max(close_price, low_price), high_price) volume = Decimal(str(random.randint(100000, 1000000))) # noqa: S311 rows.append( { "symbol": symbol, "date": date.date() if frequency == "1d" else date, "open": open_price, "high": high_price, "low": low_price, "close": close_price, "volume": volume, } ) base_price = close_price # Create DataFrame if frequency == "1d": df = pl.DataFrame(rows).with_columns( [ pl.col("date").cast(pl.Date), pl.col("open").cast(pl.Decimal(18, 8)), pl.col("high").cast(pl.Decimal(18, 8)), pl.col("low").cast(pl.Decimal(18, 8)), pl.col("close").cast(pl.Decimal(18, 8)), pl.col("volume").cast(pl.Decimal(18, 8)), ] ) else: df = ( pl.DataFrame(rows) .rename({"date": "timestamp"}) .with_columns( [ pl.col("timestamp").cast(pl.Datetime), pl.col("open").cast(pl.Decimal(18, 8)), pl.col("high").cast(pl.Decimal(18, 8)), pl.col("low").cast(pl.Decimal(18, 8)), pl.col("close").cast(pl.Decimal(18, 8)), pl.col("volume").cast(pl.Decimal(18, 8)), ] ) ) return df
    In [ ]:
    Copied!
    async def main():
        """Demonstrate custom data adapter usage."""
        print("=" * 70)
        print("Custom Data Adapter Example")
        print("=" * 70)
    
        # Create custom adapter (using mock version for demo)
        print("\n[1/4] Initializing custom data source...")
        source = MockAPIDataSource(
            api_url="https://api.example.com/v1", api_key="mock_api_key", rate_limit_per_second=5
        )
        print("✓ Data source initialized")
    
        # Fetch data
        print("\n[2/4] Fetching data...")
        print("  Symbols: AAPL, MSFT")
        print("  Period: 2023-01-01 to 2023-03-31")
        print("  Frequency: 1d")
    
        data = await source.fetch(
            symbols=["AAPL", "MSFT"],
            start=pd.Timestamp("2023-01-01"),
            end=pd.Timestamp("2023-03-31"),
            frequency="1d",
        )
    
        print(f"✓ Fetched {len(data)} rows")
        print("\nSample data:")
        print(data.head(10))
    
        # Display metadata
        print("\n[3/4] Data source metadata:")
        metadata = source.get_metadata()
        print(f"  Type: {metadata.source_type}")
        print(f"  URL: {metadata.source_url}")
        print(f"  Supports live: {metadata.supports_live}")
        print(f"  Frequencies: {metadata.supported_frequencies}")
        print(f"  Rate limit: {metadata.rate_limit} req/min")
    
        # Create bundle
        print("\n[4/4] Creating bundle...")
        bundle_path = source.ingest_to_bundle(
            bundle_name="custom-example",
            symbols=["AAPL", "MSFT"],
            start=pd.Timestamp("2023-01-01"),
            end=pd.Timestamp("2023-03-31"),
            frequency="1d",
        )
        print(f"✓ Bundle created: {bundle_path}")
    
        # Cleanup
        await source.close()
    
        print("\n" + "=" * 70)
        print("✓ Example complete!")
        print("=" * 70)
        print("\nNext steps:")
        print("  1. Adapt CustomAPIDataSource for your API")
        print("  2. Implement authentication for your provider")
        print("  3. Add WebSocket support for live data (optional)")
        print("  4. Register adapter with DataSourceRegistry")
    
    async def main(): """Demonstrate custom data adapter usage.""" print("=" * 70) print("Custom Data Adapter Example") print("=" * 70) # Create custom adapter (using mock version for demo) print("\n[1/4] Initializing custom data source...") source = MockAPIDataSource( api_url="https://api.example.com/v1", api_key="mock_api_key", rate_limit_per_second=5 ) print("✓ Data source initialized") # Fetch data print("\n[2/4] Fetching data...") print(" Symbols: AAPL, MSFT") print(" Period: 2023-01-01 to 2023-03-31") print(" Frequency: 1d") data = await source.fetch( symbols=["AAPL", "MSFT"], start=pd.Timestamp("2023-01-01"), end=pd.Timestamp("2023-03-31"), frequency="1d", ) print(f"✓ Fetched {len(data)} rows") print("\nSample data:") print(data.head(10)) # Display metadata print("\n[3/4] Data source metadata:") metadata = source.get_metadata() print(f" Type: {metadata.source_type}") print(f" URL: {metadata.source_url}") print(f" Supports live: {metadata.supports_live}") print(f" Frequencies: {metadata.supported_frequencies}") print(f" Rate limit: {metadata.rate_limit} req/min") # Create bundle print("\n[4/4] Creating bundle...") bundle_path = source.ingest_to_bundle( bundle_name="custom-example", symbols=["AAPL", "MSFT"], start=pd.Timestamp("2023-01-01"), end=pd.Timestamp("2023-03-31"), frequency="1d", ) print(f"✓ Bundle created: {bundle_path}") # Cleanup await source.close() print("\n" + "=" * 70) print("✓ Example complete!") print("=" * 70) print("\nNext steps:") print(" 1. Adapt CustomAPIDataSource for your API") print(" 2. Implement authentication for your provider") print(" 3. Add WebSocket support for live data (optional)") print(" 4. Register adapter with DataSourceRegistry")
    In [ ]:
    Copied!
    if __name__ == "__main__":
        asyncio.run(main())
    
    if __name__ == "__main__": asyncio.run(main())
    Previous
    YFinance Data Ingestion
    Next
    Backtest with Cache
    Made with Material for MkDocs