diff --git a/README.md b/README.md index aab3ac0b..02c02049 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ also find us in `#tools-data-diff` in the [Locally Optimistic Slack.][slack]** **data-diff** is a command-line tool and Python library to efficiently diff rows across two different databases. -* ⇄ Verifies across [many different databases][dbs] (e.g. Postgres -> Snowflake) +* ⇄ Verifies across [many different databases][dbs] (e.g. PostgreSQL -> Snowflake) * 🔍 Outputs [diff of rows](#example-command-and-output) in detail * 🚨 Simple CLI/API to create monitoring and alerts * 🔥 Verify 25M+ rows in <10s, and 1B+ rows in ~5min. @@ -24,6 +24,13 @@ there are few/no changes, but is able to output each differing row! By pushing the compute into the databases, it's _much_ faster than querying for and comparing every row. +![Performance for 100M rows](https://user-images.githubusercontent.com/97400/174860361-35158d2b-0cad-4089-be66-8bf467058387.png) + +**†:** The implementation for downloading all rows that `data-diff` and +`count(*)` is compared to is not optimal. It is a single Python multi-threaded +process. The performance is fairly driver-specific, e.g. PostgreSQL's performs 10x +better than MySQL. + ## Table of Contents - [Common use-cases](#common-use-cases) @@ -38,7 +45,7 @@ comparing every row. ## Common use-cases * **Verify data migrations.** Verify that all data was copied when doing a - critical data migration. For example, migrating from Heroku Postgres to Amazon RDS. + critical data migration. For example, migrating from Heroku PostgreSQL to Amazon RDS. * **Verifying data pipelines.** Moving data from a relational database to a warehouse/data lake with Fivetran, Airbyte, Debezium, or some other pipeline. * **Alerting and maintaining data integrity SLOs.** You can create and monitor @@ -56,13 +63,13 @@ comparing every row. ## Example Command and Output -Below we run a comparison with the CLI for 25M rows in Postgres where the +Below we run a comparison with the CLI for 25M rows in PostgreSQL where the right-hand table is missing single row with `id=12500048`: ``` $ data-diff \ - postgres://postgres:password@localhost/postgres rating \ - postgres://postgres:password@localhost/postgres rating_del1 \ + postgresql://user:password@localhost/database rating \ + postgresql://user:password@localhost/database rating_del1 \ --bisection-threshold 100000 \ # for readability, try default first --bisection-factor 6 \ # for readability, try default first --update-column timestamp \ @@ -104,7 +111,7 @@ $ data-diff \ | Database | Connection string | Status | |---------------|-----------------------------------------------------------------------------------------|--------| -| Postgres | `postgres://user:password@hostname:5432/database` | 💚 | +| PostgreSQL | `postgresql://user:password@hostname:5432/database` | 💚 | | MySQL | `mysql://user:password@hostname:5432/database` | 💚 | | Snowflake | `snowflake://user:password@account/database/SCHEMA?warehouse=WAREHOUSE&role=role` | 💚 | | Oracle | `oracle://username:password@hostname/database` | 💛 | @@ -133,9 +140,28 @@ Requires Python 3.7+ with pip. ```pip install data-diff``` -or when you need extras like mysql and postgres +## Install drivers + +To connect to a database, we need to have its driver installed, in the form of a Python library. 
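Once a driver is installed, a quick sanity check is to open a connection from Python. A minimal sketch, assuming a local PostgreSQL instance and the `rating` table from the example above (the URI, credentials, and table name are placeholders, not part of this patch):

```python
# Minimal sketch: confirm the PostgreSQL driver loads after `pip install 'data-diff[postgresql]'`.
# The connection string, credentials and table name are illustrative assumptions.
from data_diff import connect_to_table, diff_tables

table = connect_to_table("postgresql://user:password@localhost/database", "rating", "id")

# Diffing a table against itself should yield no differences.
assert list(diff_tables(table, table)) == []
```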
+ +While you may install them manually, we offer an easy way to install them along with data-diff: + +- `pip install 'data-diff[mysql]'` + +- `pip install 'data-diff[postgresql]'` + +- `pip install 'data-diff[snowflake]'` -```pip install "data-diff[mysql,pgsql]"``` +- `pip install 'data-diff[presto]'` + +- `pip install 'data-diff[oracle]'` + +- For BigQuery, see: https://pypi.org/project/google-cloud-bigquery/ + + +Users can also install several drivers at once: + +```pip install 'data-diff[mysql,postgresql,snowflake]'``` # How to use @@ -157,6 +183,7 @@ Options: - `-d` or `--debug` - Print debug info - `-v` or `--verbose` - Print extra info - `-i` or `--interactive` - Confirm queries, implies `--debug` + - `--json` - Print JSONL output for machine readability - `--min-age` - Considers only rows older than specified. Example: `--min-age=5min` ignores rows from the last 5 minutes. Valid units: `d, days, h, hours, min, minutes, mon, months, s, seconds, w, weeks, y, years` @@ -167,6 +194,10 @@ Options: ## How to use from Python +API reference: [https://data-diff.readthedocs.io/en/latest/](https://data-diff.readthedocs.io/en/latest/) + +Example: + ```python # Optional: Set logging to display the progress of the diff import logging @@ -174,7 +205,7 @@ logging.basicConfig(level=logging.INFO) from data_diff import connect_to_table, diff_tables -table1 = connect_to_table("postgres:///", "table_name", "id") +table1 = connect_to_table("postgresql:///", "table_name", "id") table2 = connect_to_table("mysql:///", "table_name", "id") for different_row in diff_tables(table1, table2): @@ -182,7 +213,7 @@ for different_row in diff_tables(table1, table2): print(plus_or_minus, columns) ``` -Run `help(diff_tables)` or read the docs [ADD LINK] to learn about the different options. +Run `help(diff_tables)` or [read the docs](https://data-diff.readthedocs.io/en/latest/) to learn about the different options. # Technical Explanation @@ -190,11 +221,11 @@ In this section we'll be doing a walk-through of exactly how **data-diff** works, and how to tune `--bisection-factor` and `--bisection-threshold`. Let's consider a scenario with an `orders` table with 1M rows. Fivetran is -replicating it contionously from Postgres to Snowflake: +replicating it contionously from PostgreSQL to Snowflake: ``` ┌─────────────┐ ┌─────────────┐ -│ Postgres │ │ Snowflake │ +│ PostgreSQL │ │ Snowflake │ ├─────────────┤ ├─────────────┤ │ │ │ │ │ │ │ │ @@ -222,7 +253,7 @@ of the table. Then it splits the table into `--bisection-factor=10` segments of ``` ┌──────────────────────┐ ┌──────────────────────┐ -│ Postgres │ │ Snowflake │ +│ PostgreSQL │ │ Snowflake │ ├──────────────────────┤ ├──────────────────────┤ │ id=1..100k │ │ id=1..100k │ ├──────────────────────┤ ├──────────────────────┤ @@ -270,7 +301,7 @@ are the same except `id=100k..200k`: ``` ┌──────────────────────┐ ┌──────────────────────┐ -│ Postgres │ │ Snowflake │ +│ PostgreSQL │ │ Snowflake │ ├──────────────────────┤ ├──────────────────────┤ │ checksum=0102 │ │ checksum=0102 │ ├──────────────────────┤ mismatch! ├──────────────────────┤ @@ -295,7 +326,7 @@ and compare them in memory in **data-diff**. ``` ┌──────────────────────┐ ┌──────────────────────┐ -│ Postgres │ │ Snowflake │ +│ PostgreSQL │ │ Snowflake │ ├──────────────────────┤ ├──────────────────────┤ │ id=100k..110k │ │ id=100k..110k │ ├──────────────────────┤ ├──────────────────────┤ @@ -326,7 +357,7 @@ If you pass `--stats` you'll see e.g. what % of rows were different. queries. 
* Consider increasing the number of simultaneous threads executing queries per database with `--threads`. For databases that limit concurrency - per query, e.g. Postgres/MySQL, this can improve performance dramatically. + per query, e.g. PostgreSQL/MySQL, this can improve performance dramatically. * If you are only interested in _whether_ something changed, pass `--limit 1`. This can be useful if changes are very rare. This is often faster than doing a `count(*)`, for the reason mentioned above. @@ -408,7 +439,7 @@ Now you can insert it into the testing database(s): ```shell-session # It's optional to seed more than one to run data-diff(1) against. $ poetry run preql -f dev/prepare_db.pql mysql://mysql:Password1@127.0.0.1:3306/mysql -$ poetry run preql -f dev/prepare_db.pql postgres://postgres:Password1@127.0.0.1:5432/postgres +$ poetry run preql -f dev/prepare_db.pql postgresql://postgres:Password1@127.0.0.1:5432/postgres # Cloud databases $ poetry run preql -f dev/prepare_db.pql snowflake:// @@ -419,7 +450,7 @@ $ poetry run preql -f dev/prepare_db.pql bigquery:/// **5. Run **data-diff** against seeded database** ```bash -poetry run python3 -m data_diff postgres://postgres:Password1@localhost/postgres rating postgres://postgres:Password1@localhost/postgres rating_del1 --verbose +poetry run python3 -m data_diff postgresql://postgres:Password1@localhost/postgres rating postgresql://postgres:Password1@localhost/postgres rating_del1 --verbose ``` # License diff --git a/data_diff/__init__.py b/data_diff/__init__.py index 2688308e..4bc73733 100644 --- a/data_diff/__init__.py +++ b/data_diff/__init__.py @@ -56,7 +56,7 @@ def diff_tables( """Efficiently finds the diff between table1 and table2. Example: - >>> table1 = connect_to_table('postgres:///', 'Rating', 'id') + >>> table1 = connect_to_table('postgresql:///', 'Rating', 'id') >>> list(diff_tables(table1, table1)) [] diff --git a/data_diff/__main__.py b/data_diff/__main__.py index dcdd7f27..6ee6992c 100644 --- a/data_diff/__main__.py +++ b/data_diff/__main__.py @@ -1,6 +1,6 @@ -from multiprocessing.sharedctypes import Value import sys import time +import json import logging from itertools import islice @@ -51,6 +51,7 @@ @click.option("--max-age", default=None, help="Considers only rows younger than specified. 
See --min-age.") @click.option("-s", "--stats", is_flag=True, help="Print stats instead of a detailed diff") @click.option("-d", "--debug", is_flag=True, help="Print debug info") +@click.option("--json", 'json_output', is_flag=True, help="Print JSONL output for machine readability") @click.option("-v", "--verbose", is_flag=True, help="Print extra info") @click.option("-i", "--interactive", is_flag=True, help="Confirm queries, implies --debug") @click.option("--keep-column-case", is_flag=True, help="Don't use the schema to fix the case of given column names.") @@ -81,6 +82,7 @@ def main( interactive, threads, keep_column_case, + json_output, ): if limit and stats: print("Error: cannot specify a limit when using the -s/--stats switch") @@ -145,17 +147,35 @@ def main( if stats: diff = list(diff_iter) unique_diff_count = len({i[0] for _, i in diff}) - table1_count = differ.stats.get("table1_count") - percent = 100 * unique_diff_count / (table1_count or 1) - print(f"Diff-Total: {len(diff)} changed rows out of {table1_count}") - print(f"Diff-Percent: {percent:.4f}%") + max_table_count = max(differ.stats["table1_count"], differ.stats["table2_count"]) + percent = 100 * unique_diff_count / (max_table_count or 1) plus = len([1 for op, _ in diff if op == "+"]) minus = len([1 for op, _ in diff if op == "-"]) - print(f"Diff-Split: +{plus} -{minus}") + + if json_output: + json_output = { + "different_rows": len(diff), + "different_percent": percent, + "different_+": plus, + "different_-": minus, + "total": max_table_count, + } + print(json.dumps(json_output)) + else: + print(f"Diff-Total: {len(diff)} changed rows out of {max_table_count}") + print(f"Diff-Percent: {percent:.14f}%") + print(f"Diff-Split: +{plus} -{minus}") else: - for op, key in diff_iter: + for op, columns in diff_iter: color = COLOR_SCHEME[op] - rich.print(f"[{color}]{op} {key!r}[/{color}]") + + if json_output: + jsonl = json.dumps([op, list(columns)]) + rich.print(f"[{color}]{jsonl}[/{color}]") + else: + text = f"{op} {', '.join(columns)}" + rich.print(f"[{color}]{text}[/{color}]") + sys.stdout.flush() end = time.time() diff --git a/data_diff/database.py b/data_diff/database.py index 5e39e8a3..b89c2106 100644 --- a/data_diff/database.py +++ b/data_diff/database.py @@ -1,10 +1,11 @@ -from functools import lru_cache +import math +from functools import lru_cache, wraps from itertools import zip_longest import re from abc import ABC, abstractmethod from runtype import dataclass import logging -from typing import Tuple, Optional, List +from typing import Sequence, Tuple, Optional, List from concurrent.futures import ThreadPoolExecutor import threading from typing import Dict @@ -22,7 +23,25 @@ def parse_table_name(t): return tuple(t.split(".")) -def import_postgres(): +def import_helper(package: str = None, text=""): + def dec(f): + @wraps(f) + def _inner(): + try: + return f() + except ModuleNotFoundError as e: + s = text + if package: + s += f"You can install it using 'pip install data-diff[{package}]'." 
+ raise ModuleNotFoundError(f"{e}\n\n{s}\n") + + return _inner + + return dec + + +@import_helper("postgresql") +def import_postgresql(): import psycopg2 import psycopg2.extras @@ -30,12 +49,14 @@ def import_postgres(): return psycopg2 +@import_helper("mysql") def import_mysql(): import mysql.connector return mysql.connector +@import_helper("snowflake") def import_snowflake(): import snowflake.connector @@ -54,15 +75,24 @@ def import_oracle(): return cx_Oracle +@import_helper("presto") def import_presto(): import prestodb return prestodb +@import_helper(text="Please install BigQuery and configure your google-cloud access.") +def import_bigquery(): + from google.cloud import bigquery + + return bigquery + + class ConnectError(Exception): pass + class QueryError(Exception): pass @@ -105,6 +135,26 @@ class Datetime(TemporalType): pass +@dataclass +class NumericType(ColType): + # 'precision' signifies how many fractional digits (after the dot) we want to compare + precision: int + + +class Float(NumericType): + pass + + +class Decimal(NumericType): + pass + + +@dataclass +class Integer(Decimal): + def __post_init__(self): + assert self.precision == 0 + + @dataclass class UnknownColType(ColType): text: str @@ -137,7 +187,7 @@ def select_table_schema(self, path: DbPath) -> str: ... @abstractmethod - def query_table_schema(self, path: DbPath) -> Dict[str, ColType]: + def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]: "Query the table for its schema for table in 'path', and return {column: type}" ... @@ -152,18 +202,55 @@ def close(self): ... @abstractmethod - def normalize_value_by_type(value: str, coltype: ColType) -> str: + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + """Creates an SQL expression, that converts 'value' to a normalized timestamp. + + The returned expression must accept any SQL datetime/timestamp, and return a string. + + Date format: "YYYY-MM-DD HH:mm:SS.FFFFFF" + + Precision of dates should be rounded up/down according to coltype.rounds + """ + ... + + @abstractmethod + def normalize_number(self, value: str, coltype: ColType) -> str: + """Creates an SQL expression, that converts 'value' to a normalized number. + + The returned expression must accept any SQL int/numeric/float, and return a string. + + - Floats/Decimals are expected in the format + "I.P" + + Where I is the integer part of the number (as many digits as necessary), + and must be at least one digit (0). + P is the fractional digits, the amount of which is specified with + coltype.precision. Trailing zeroes may be necessary. + If P is 0, the dot is omitted. + + Note: This precision is different than the one used by databases. For decimals, + it's the same as ``numeric_scale``, and for floats, who use binary precision, + it can be calculated as ``log10(2**numeric_precision)``. + """ + ... + + def normalize_value_by_type(self, value: str, coltype: ColType) -> str: """Creates an SQL expression, that converts 'value' to a normalized representation. The returned expression must accept any SQL value, and return a string. - - Dates are expected in the format: - "YYYY-MM-DD HH:mm:SS.FFFFFF" + The default implementation dispatches to a method according to ``coltype``: - Rounded up/down according to coltype.rounds + TemporalType -> normalize_timestamp() + NumericType -> normalize_number() + -else- -> to_string() """ - ... 
+ if isinstance(coltype, TemporalType): + return self.normalize_timestamp(value, coltype) + elif isinstance(coltype, NumericType): + return self.normalize_number(value, coltype) + return self.to_string(f"{value}") class Database(AbstractDatabase): @@ -174,8 +261,12 @@ class Database(AbstractDatabase): Instanciated using :meth:`~data_diff.connect_to_uri` """ - DATETIME_TYPES = NotImplemented - default_schema = NotImplemented + DATETIME_TYPES = {} + default_schema = None + + @property + def name(self): + return type(self).__name__ def query(self, sql_ast: SqlOrStr, res_type: type): "Query the given SQL code/AST, and attempt to convert the result to type 'res_type'" @@ -212,33 +303,69 @@ def query(self, sql_ast: SqlOrStr, res_type: type): def enable_interactive(self): self._interactive = True - def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType: + def _convert_db_precision_to_digits(self, p: int) -> int: + """Convert from binary precision, used by floats, to decimal precision.""" + # See: https://en.wikipedia.org/wiki/Single-precision_floating-point_format + return math.floor(math.log(2**p, 10)) + + def _parse_type( + self, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: """ """ cls = self.DATETIME_TYPES.get(type_repr) if cls: return cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, + precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, rounds=self.ROUNDS_ON_PREC_LOSS, ) + cls = self.NUMERIC_TYPES.get(type_repr) + if cls: + if issubclass(cls, Integer): + # Some DBs have a constant numeric_scale, so they don't report it. + # We fill in the constant, so we need to ignore it for integers. 
+ return cls(precision=0) + + elif issubclass(cls, Decimal): + if numeric_scale is None: + raise ValueError(f"{self.name}: Unexpected numeric_scale is NULL, for column {col_name} of type {type_repr}.") + return cls(precision=numeric_scale) + + assert issubclass(cls, Float) + # assert numeric_scale is None + return cls( + precision=self._convert_db_precision_to_digits( + numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION + ) + ) + return UnknownColType(type_repr) def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) return ( - "SELECT column_name, data_type, datetime_precision, numeric_precision FROM information_schema.columns " + "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns " f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) - def query_table_schema(self, path: DbPath) -> Dict[str, ColType]: + def query_table_schema(self, path: DbPath, filter_columns: Optional[Sequence[str]] = None) -> Dict[str, ColType]: rows = self.query(self.select_table_schema(path), list) if not rows: - raise RuntimeError(f"{self.__class__.__name__}: Table '{'.'.join(path)}' does not exist, or has no columns") + raise RuntimeError(f"{self.name}: Table '{'.'.join(path)}' does not exist, or has no columns") - # Return a dict of form {name: type} after canonizaation - return {row[0]: self._parse_type(*row[1:]) for row in rows} + if filter_columns is not None: + accept = {i.lower() for i in filter_columns} + rows = [r for r in rows if r[0].lower() in accept] + + # Return a dict of form {name: type} after normalization + return {row[0]: self._parse_type(*row) for row in rows} # @lru_cache() # def get_table_schema(self, path: DbPath) -> Dict[str, ColType]: @@ -246,11 +373,12 @@ def query_table_schema(self, path: DbPath) -> Dict[str, ColType]: def _normalize_table_path(self, path: DbPath) -> DbPath: if len(path) == 1: - return self.default_schema, path[0] - elif len(path) == 2: - return path + if self.default_schema: + return self.default_schema, path[0] + elif len(path) != 2: + raise ValueError(f"{self.name}: Bad table path for {self}: '{'.'.join(path)}'. Expected form: schema.table") - raise ValueError(f"{self.__class__.__name__}: Bad table path for {self}: '{'.'.join(path)}'. 
Expected form: schema.table") + return path def parse_table_name(self, name: str) -> DbPath: return parse_table_name(name) @@ -263,12 +391,16 @@ class ThreadedDatabase(Database): """ def __init__(self, thread_count=1): + self._init_error = None self._queue = ThreadPoolExecutor(thread_count, initializer=self.set_conn) self.thread_local = threading.local() def set_conn(self): assert not hasattr(self.thread_local, "conn") - self.thread_local.conn = self.create_connection() + try: + self.thread_local.conn = self.create_connection() + except ModuleNotFoundError as e: + self._init_error = e def _query(self, sql_code: str): r = self._queue.submit(self._query_in_worker, sql_code) @@ -276,6 +408,8 @@ def _query(self, sql_code: str): def _query_in_worker(self, sql_code: str): "This method runs in a worker thread" + if self._init_error: + raise self._init_error return _query_conn(self.thread_local.conn, sql_code) def close(self): @@ -295,18 +429,27 @@ def close(self): _CHECKSUM_BITSIZE = CHECKSUM_HEXDIGITS << 2 CHECKSUM_MASK = (2**_CHECKSUM_BITSIZE) - 1 -DEFAULT_PRECISION = 6 +DEFAULT_DATETIME_PRECISION = 6 +DEFAULT_NUMERIC_PRECISION = 24 TIMESTAMP_PRECISION_POS = 20 # len("2022-06-03 12:24:35.") == 20 -class Postgres(ThreadedDatabase): +class PostgreSQL(ThreadedDatabase): DATETIME_TYPES = { "timestamp with time zone": TimestampTZ, "timestamp without time zone": Timestamp, "timestamp": Timestamp, # "datetime": Datetime, } + NUMERIC_TYPES = { + "double precision": Float, + "real": Float, + "decimal": Decimal, + "integer": Integer, + "numeric": Decimal, + "bigint": Integer, + } ROUNDS_ON_PREC_LOSS = True default_schema = "public" @@ -316,13 +459,17 @@ def __init__(self, host, port, user, password, *, database, thread_count, **kw): super().__init__(thread_count=thread_count) + def _convert_db_precision_to_digits(self, p: int) -> int: + # Subtracting 2 due to wierd precision issues in PostgreSQL + return super()._convert_db_precision_to_digits(p) - 2 + def create_connection(self): - postgres = import_postgres() + pg = import_postgresql() try: - c = postgres.connect(**self.args) + c = pg.connect(**self.args) # c.cursor().execute("SET TIME ZONE 'UTC'") return c - except postgres.OperationalError as e: + except pg.OperationalError as e: raise ConnectError(*e.args) from e def quote(self, s: str): @@ -334,24 +481,17 @@ def md5_to_int(self, s: str) -> str: def to_string(self, s: str): return f"{s}::varchar" - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, TemporalType): - # if coltype.precision == 0: - # return f"to_char({value}::timestamp(0), 'YYYY-mm-dd HH24:MI:SS')" - # if coltype.precision == 3: - # return f"to_char({value}, 'YYYY-mm-dd HH24:MI:SS.US')" - # elif coltype.precision == 6: - # return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" - # else: - # # Postgres/Redshift doesn't support arbitrary precision - # raise TypeError(f"Bad precision for {type(self).__name__}: {coltype})") - if coltype.rounds: - return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" - else: - timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + if coltype.rounds: + return f"to_char({value}::timestamp({coltype.precision}), 'YYYY-mm-dd HH24:MI:SS.US')" - return self.to_string(f"{value}") + timestamp6 = 
f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: ColType) -> str: + return self.to_string(f"{value}::decimal(38, {coltype.precision})") class Presto(Database): @@ -362,6 +502,11 @@ class Presto(Database): "timestamp": Timestamp, # "datetime": Datetime, } + NUMERIC_TYPES = { + "integer": Integer, + "real": Float, + "double": Float, + } ROUNDS_ON_PREC_LOSS = True def __init__(self, host, port, user, password, *, catalog, schema=None, **kw): @@ -386,22 +531,17 @@ def _query(self, sql_code: str) -> list: def close(self): self._conn.close() - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, TemporalType): - if coltype.rounds: - if coltype.precision > 3: - pass - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - else: - s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - # datetime = f"date_format(cast({value} as timestamp(6), '%Y-%m-%d %H:%i:%S.%f'))" - # datetime = self.to_string(f"cast({value} as datetime(6))") + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + # TODO + if coltype.rounds: + s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" + else: + s = f"date_format(cast({value} as timestamp(6)), '%Y-%m-%d %H:%i:%S.%f')" - return ( - f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" - ) + return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" - return self.to_string(value) + def normalize_number(self, value: str, coltype: ColType) -> str: + return self.to_string(f"cast({value} as decimal(38,{coltype.precision}))") def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) @@ -411,8 +551,9 @@ def select_table_schema(self, path: DbPath) -> str: f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) - def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType: - """ """ + def _parse_type( + self, col_name: str, type_repr: str, datetime_precision: int = None, numeric_precision: int = None + ) -> ColType: regexps = { r"timestamp\((\d)\)": Timestamp, r"timestamp\((\d)\) with time zone": TimestampTZ, @@ -422,9 +563,30 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr if m: datetime_precision = int(m.group(1)) return cls( - precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, rounds=False + precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, + rounds=False, ) + regexps = {r"decimal\((\d+),(\d+)\)": Decimal} + for regexp, cls in regexps.items(): + m = re.match(regexp + "$", type_repr) + if m: + prec, scale = map(int, m.groups()) + return cls(scale) + + cls = self.NUMERIC_TYPES.get(type_repr) + if cls: + if issubclass(cls, Integer): + assert numeric_precision is not None + return cls(0) + + assert issubclass(cls, Float) + return cls( + precision=self._convert_db_precision_to_digits( + numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION + ) + ) + return UnknownColType(type_repr) @@ -433,6 +595,12 @@ class MySQL(ThreadedDatabase): "datetime": Datetime, "timestamp": Timestamp, } + NUMERIC_TYPES = { + "double": Float, + "float": 
Float, + "decimal": Decimal, + "int": Integer, + } ROUNDS_ON_PREC_LOSS = True def __init__(self, host, port, user, password, *, database, thread_count, **kw): @@ -464,15 +632,15 @@ def md5_to_int(self, s: str) -> str: def to_string(self, s: str): return f"cast({s} as char)" - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, TemporalType): - if coltype.rounds: - return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") - else: - s = self.to_string(f"cast({value} as datetime(6))") - return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + if coltype.rounds: + return self.to_string(f"cast( cast({value} as datetime({coltype.precision})) as datetime(6))") - return self.to_string(f"{value}") + s = self.to_string(f"cast({value} as datetime(6))") + return f"RPAD(RPAD({s}, {TIMESTAMP_PRECISION_POS+coltype.precision}, '.'), {TIMESTAMP_PRECISION_POS+6}, '0')" + + def normalize_number(self, value: str, coltype: ColType) -> str: + return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") class Oracle(ThreadedDatabase): @@ -513,16 +681,28 @@ def select_table_schema(self, path: DbPath) -> str: (table,) = path return ( - f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision" + f"SELECT column_name, data_type, 6 as datetime_precision, data_precision as numeric_precision, data_scale as numeric_scale" f" FROM USER_TAB_COLUMNS WHERE table_name = '{table.upper()}'" ) - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, PrecisionType): - return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" - return self.to_string(f"{value}") + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + return f"to_char(cast({value} as timestamp({coltype.precision})), 'YYYY-MM-DD HH24:MI:SS.FF6')" + + def normalize_number(self, value: str, coltype: ColType) -> str: + # FM999.9990 + format_str = "FM" + "9" * (38 - coltype.precision) + if coltype.precision: + format_str += "0." 
+ "9" * (coltype.precision - 1) + "0" + return f"to_char({value}, '{format_str}')" - def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_precision: int = None) -> ColType: + def _parse_type( + self, + col_name: str, + type_repr: str, + datetime_precision: int = None, + numeric_precision: int = None, + numeric_scale: int = None, + ) -> ColType: """ """ regexps = { r"TIMESTAMP\((\d)\) WITH LOCAL TIME ZONE": Timestamp, @@ -532,35 +712,73 @@ def _parse_type(self, type_repr: str, datetime_precision: int = None, numeric_pr m = re.match(regexp + "$", type_repr) if m: datetime_precision = int(m.group(1)) - return cls(precision=datetime_precision if datetime_precision is not None else DEFAULT_PRECISION, - rounds=self.ROUNDS_ON_PREC_LOSS + return cls( + precision=datetime_precision if datetime_precision is not None else DEFAULT_DATETIME_PRECISION, + rounds=self.ROUNDS_ON_PREC_LOSS, ) + cls = { + "NUMBER": Decimal, + "FLOAT": Float, + }.get(type_repr, None) + if cls: + if issubclass(cls, Decimal): + assert numeric_scale is not None, (type_repr, numeric_precision, numeric_scale) + return cls(precision=numeric_scale) + + assert issubclass(cls, Float) + return cls( + precision=self._convert_db_precision_to_digits( + numeric_precision if numeric_precision is not None else DEFAULT_NUMERIC_PRECISION + ) + ) + return UnknownColType(type_repr) -class Redshift(Postgres): +class Redshift(PostgreSQL): + NUMERIC_TYPES = { + **PostgreSQL.NUMERIC_TYPES, + "double": Float, + "real": Float, + } + + # def _convert_db_precision_to_digits(self, p: int) -> int: + # return super()._convert_db_precision_to_digits(p // 2) + def md5_to_int(self, s: str) -> str: return f"strtol(substring(md5({s}), {1+MD5_HEXDIGITS-CHECKSUM_HEXDIGITS}), 16)::decimal(38)" - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, TemporalType): - if coltype.rounds: - timestamp = f"{value}::timestamp(6)" - # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. - secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)" - # Get the milliseconds from timestamp. - ms = f"extract(ms from {timestamp})" - # Get the microseconds from timestamp, without the milliseconds! - us = f"extract(us from {timestamp})" - # epoch = Total time since epoch in microseconds. - epoch = f"{secs}*1000000 + {ms}*1000 + {us}" - timestamp6 = f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" - else: - timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + if coltype.rounds: + timestamp = f"{value}::timestamp(6)" + # Get seconds since epoch. Redshift doesn't support milli- or micro-seconds. + secs = f"timestamp 'epoch' + round(extract(epoch from {timestamp})::decimal(38)" + # Get the milliseconds from timestamp. + ms = f"extract(ms from {timestamp})" + # Get the microseconds from timestamp, without the milliseconds! + us = f"extract(us from {timestamp})" + # epoch = Total time since epoch in microseconds. 
+ epoch = f"{secs}*1000000 + {ms}*1000 + {us}" + timestamp6 = ( + f"to_char({epoch}, -6+{coltype.precision}) * interval '0.000001 seconds', 'YYYY-mm-dd HH24:MI:SS.US')" + ) + else: + timestamp6 = f"to_char({value}::timestamp(6), 'YYYY-mm-dd HH24:MI:SS.US')" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) - return self.to_string(f"{value}") + def normalize_number(self, value: str, coltype: ColType) -> str: + return self.to_string(f"{value}::decimal(38,{coltype.precision})") + + def select_table_schema(self, path: DbPath) -> str: + schema, table = self._normalize_table_path(path) + + return ( + "SELECT column_name, data_type, datetime_precision, numeric_precision, numeric_scale FROM information_schema.columns " + f"WHERE table_name = '{table.lower()}' AND table_schema = '{schema.lower()}'" + ) class MsSQL(ThreadedDatabase): @@ -595,10 +813,18 @@ class BigQuery(Database): "TIMESTAMP": Timestamp, "DATETIME": Datetime, } + NUMERIC_TYPES = { + "INT64": Integer, + "INT32": Integer, + "NUMERIC": Decimal, + "BIGNUMERIC": Decimal, + "FLOAT64": Float, + "FLOAT32": Float, + } ROUNDS_ON_PREC_LOSS = False # Technically BigQuery doesn't allow implicit rounding or truncation def __init__(self, project, *, dataset, **kw): - from google.cloud import bigquery + bigquery = import_bigquery() self._client = bigquery.Client(project, **kw) self.project = project @@ -640,25 +866,29 @@ def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) return ( - f"SELECT column_name, data_type, 6 as datetime_precision, 6 as numeric_precision FROM {schema}.INFORMATION_SCHEMA.COLUMNS " + f"SELECT column_name, data_type, 6 as datetime_precision, 38 as numeric_precision, 9 as numeric_scale FROM {schema}.INFORMATION_SCHEMA.COLUMNS " f"WHERE table_name = '{table}' AND table_schema = '{schema}'" ) - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, PrecisionType): - if coltype.rounds: - timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" - else: - if coltype.precision == 0: - return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000, {value})" - elif coltype.precision == 6: - return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + if coltype.rounds: + timestamp = f"timestamp_micros(cast(round(unix_micros(cast({value} as timestamp))/1000000, {coltype.precision})*1000000 as int))" + return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {timestamp})" - timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - return f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + if coltype.precision == 0: + return f"FORMAT_TIMESTAMP('%F %H:%M:%S.000000, {value})" + elif coltype.precision == 6: + return f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" - return self.to_string(f"{value}") + timestamp6 = f"FORMAT_TIMESTAMP('%F %H:%M:%E6S', {value})" + return ( + f"RPAD(LEFT({timestamp6}, {TIMESTAMP_PRECISION_POS+coltype.precision}), {TIMESTAMP_PRECISION_POS+6}, '0')" + ) + + def normalize_number(self, value: str, coltype: ColType) -> str: + if isinstance(coltype, Integer): + return self.to_string(value) + return f"format('%.{coltype.precision}f', {value})" def parse_table_name(self, name: str) -> DbPath: path = parse_table_name(name) @@ -671,6 +901,10 
@@ class Snowflake(Database): "TIMESTAMP_LTZ": Timestamp, "TIMESTAMP_TZ": TimestampTZ, } + NUMERIC_TYPES = { + "NUMBER": Decimal, + "FLOAT": Float, + } ROUNDS_ON_PREC_LOSS = False def __init__( @@ -728,16 +962,16 @@ def select_table_schema(self, path: DbPath) -> str: schema, table = self._normalize_table_path(path) return super().select_table_schema((schema, table)) - def normalize_value_by_type(self, value: str, coltype: ColType) -> str: - if isinstance(coltype, PrecisionType): - if coltype.rounds: - timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" - else: - timestamp = f"cast({value} as timestamp({coltype.precision}))" + def normalize_timestamp(self, value: str, coltype: ColType) -> str: + if coltype.rounds: + timestamp = f"to_timestamp(round(date_part(epoch_nanosecond, {value}::timestamp(9))/1000000000, {coltype.precision}))" + else: + timestamp = f"cast({value} as timestamp({coltype.precision}))" - return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" + return f"to_char({timestamp}, 'YYYY-MM-DD HH24:MI:SS.FF6')" - return self.to_string(f"{value}") + def normalize_number(self, value: str, coltype: ColType) -> str: + return self.to_string(f"cast({value} as decimal(38, {coltype.precision}))") @dataclass @@ -790,7 +1024,7 @@ def match_path(self, dsn): MATCH_URI_PATH = { - "postgres": MatchUriPath(Postgres, ["database?"], help_str="postgres://:@/"), + "postgresql": MatchUriPath(PostgreSQL, ["database?"], help_str="postgresql://:@/"), "mysql": MatchUriPath(MySQL, ["database?"], help_str="mysql://:@/"), "oracle": MatchUriPath(Oracle, ["database?"], help_str="oracle://:@/"), "mssql": MatchUriPath(MsSQL, ["database?"], help_str="mssql://:@/"), @@ -819,7 +1053,7 @@ def connect_to_uri(db_uri: str, thread_count: Optional[int] = 1) -> Database: Note: For non-cloud databases, a low thread-pool size may be a performance bottleneck. Supported schemes: - - postgres + - postgresql - mysql - mssql - oracle diff --git a/data_diff/diff_tables.py b/data_diff/diff_tables.py index 4087b49d..aee65cca 100644 --- a/data_diff/diff_tables.py +++ b/data_diff/diff_tables.py @@ -12,7 +12,7 @@ from runtype import dataclass from .sql import Select, Checksum, Compare, DbPath, DbKey, DbTime, Count, TableName, Time, Min, Max -from .database import Database, PrecisionType, ColType +from .database import Database, NumericType, PrecisionType, ColType, UnknownColType logger = logging.getLogger("diff_tables") @@ -142,15 +142,16 @@ def with_schema(self) -> "TableSegment": "Queries the table schema from the database, and returns a new instance of TableSegmentWithSchema." 
if self._schema: return self - schema = self.database.query_table_schema(self.table_path) + + schema = self.database.query_table_schema(self.table_path, self._relevant_columns) if self.case_sensitive: schema = Schema_CaseSensitive(schema) else: if len({k.lower() for k in schema}) < len(schema): - logger.warn( + logger.warning( f'Ambiguous schema for {self.database}:{".".join(self.table_path)} | Columns = {", ".join(list(schema))}' ) - logger.warn("We recommend to disable case-insensitivity (remove --any-case).") + logger.warning("We recommend to disable case-insensitivity (remove --any-case).") schema = Schema_CaseInsensitive(schema) return self.new(_schema=schema) @@ -241,7 +242,7 @@ def count_and_checksum(self) -> Tuple[int, int]: ) duration = time.time() - start if duration > RECOMMENDED_CHECKSUM_DURATION: - logger.warn( + logger.warning( f"Checksum is taking longer than expected ({duration:.2f}s). " "We recommend increasing --bisection-factor or decreasing --threads." ) @@ -364,11 +365,32 @@ def _validate_and_adjust_columns(self, table1, table2): lowest = min(col1, col2, key=attrgetter("precision")) if col1.precision != col2.precision: - logger.warn(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}") + logger.warning(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}") table1._schema[c] = col1.replace(precision=lowest.precision, rounds=lowest.rounds) table2._schema[c] = col2.replace(precision=lowest.precision, rounds=lowest.rounds) + elif isinstance(col1, NumericType): + if not isinstance(col2, NumericType): + raise TypeError(f"Incompatible types for column {c}: {col1} <-> {col2}") + + lowest = min(col1, col2, key=attrgetter("precision")) + + if col1.precision != col2.precision: + logger.warning(f"Using reduced precision {lowest} for column '{c}'. Types={col1}, {col2}") + + table1._schema[c] = col1.replace(precision=lowest.precision) + table2._schema[c] = col2.replace(precision=lowest.precision) + + for t in [table1, table2]: + for c in t._relevant_columns: + ctype = t._schema[c] + if isinstance(ctype, UnknownColType): + logger.warn( + f"[{t.database.name}] Column '{c}' of type '{ctype.text}' has no compatibility handling. " + "If encoding/formatting differs between databases, it may result in false positives." + ) + def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None): assert table1.is_bounded and table2.is_bounded @@ -381,6 +403,14 @@ def _bisect_and_diff_tables(self, table1, table2, level=0, max_rows=None): if max_rows < self.bisection_threshold: rows1, rows2 = self._threaded_call("get_values", [table1, table2]) diff = list(diff_sets(rows1, rows2)) + + # Initial bisection_threshold larger than count. Normally we always + # checksum and count segments, even if we get the values. At the + # first level, however, that won't be true. + if level == 0: + self.stats["table1_count"] = len(rows1) + self.stats["table2_count"] = len(rows2) + logger.info(". " * level + f"Diff found {len(diff)} different rows.") self.stats["rows_downloaded"] = self.stats.get("rows_downloaded", 0) + max(len(rows1), len(rows2)) yield from diff @@ -412,7 +442,7 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun (count1, checksum1), (count2, checksum2) = self._threaded_call("count_and_checksum", [table1, table2]) if count1 == 0 and count2 == 0: - logger.warn( + logger.warning( "Uneven distribution of keys detected. (big gaps in the key column). 
" "For better performance, we recommend to increase the bisection-threshold." ) @@ -421,6 +451,7 @@ def _diff_tables(self, table1, table2, level=0, segment_index=None, segment_coun if level == 1: self.stats["table1_count"] = self.stats.get("table1_count", 0) + count1 + self.stats["table2_count"] = self.stats.get("table2_count", 0) + count2 if checksum1 != checksum2: yield from self._bisect_and_diff_tables(table1, table2, level=level, max_rows=max(count1, count2)) diff --git a/data_diff/sql.py b/data_diff/sql.py index 1c19aef1..a81839eb 100644 --- a/data_diff/sql.py +++ b/data_diff/sql.py @@ -46,7 +46,8 @@ class TableName(Sql): name: DbPath def compile(self, c: Compiler): - return ".".join(map(c.quote, self.name)) + path = c.database._normalize_table_path(self.name) + return ".".join(map(c.quote, path)) @dataclass diff --git a/docs/index.rst b/docs/index.rst index 6caab45d..372de44c 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -5,11 +5,13 @@ python-api +Introduction +------------ -**data-diff** is a command-line tool and Python library to efficiently diff +**Data-diff** is a command-line tool and Python library to efficiently diff rows across two different databases. -⇄ Verifies across many different databases (e.g. Postgres -> Snowflake) ! +⇄ Verifies across many different databases (e.g. *PostgreSQL* -> *Snowflake*) ! 🔍 Outputs diff of rows in detail @@ -30,11 +32,11 @@ Requires Python 3.7+ with pip. pip install data-diff -or when you need extras like mysql and postgres: +or when you need extras like mysql and postgresql: :: - pip install "data-diff[mysql,pgsql]" + pip install "data-diff[mysql,postgresql]" How to use from Python @@ -48,12 +50,16 @@ How to use from Python from data_diff import connect_to_table, diff_tables - table1 = connect_to_table("postgres:///", "table_name", "id") + table1 = connect_to_table("postgresql:///", "table_name", "id") table2 = connect_to_table("mysql:///", "table_name", "id") - for different_row in diff_tables(table1, table2): - plus_or_minus, columns = different_row - print(plus_or_minus, columns) + for sign, columns in diff_tables(table1, table2): + print(sign, columns) + + # Example output: + + ('4775622148347', '2022-06-05 16:57:32.000000') + - ('4775622312187', '2022-06-05 16:57:32.000000') + - ('4777375432955', '2022-06-07 16:57:36.000000') Resources diff --git a/poetry.lock b/poetry.lock index 8841f2c3..7172ba56 100644 --- a/poetry.lock +++ b/poetry.lock @@ -20,7 +20,7 @@ python-versions = "*" [[package]] name = "certifi" -version = "2022.5.18.1" +version = "2022.6.15" description = "Python package for providing Mozilla's CA Bundle." category = "main" optional = false @@ -62,7 +62,7 @@ importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} [[package]] name = "colorama" -version = "0.4.4" +version = "0.4.5" description = "Cross-platform colored terminal text." 
category = "main" optional = false @@ -359,7 +359,7 @@ jupyter = ["ipywidgets (>=7.5.1,<8.0.0)"] [[package]] name = "runtype" -version = "0.2.4" +version = "0.2.6" description = "Type dispatch and validation for run-time Python" category = "main" optional = false @@ -443,7 +443,8 @@ testing = ["pytest (>=6)", "pytest-checkdocs (>=2.4)", "pytest-flake8", "pytest- [extras] mysql = ["mysql-connector-python"] -pgsql = ["psycopg2"] +oracle = [] +postgresql = ["psycopg2"] preql = ["preql"] presto = [] snowflake = ["snowflake-connector-python"] @@ -451,7 +452,7 @@ snowflake = ["snowflake-connector-python"] [metadata] lock-version = "1.1" python-versions = "^3.7" -content-hash = "cd595c78ae0024cb9d980d4a2d83d8011f82947fe557537eea0280057bcbb535" +content-hash = "e1b2b05a166d2d6d81bec8e15e562480998b6e578592a4a0ed04b6fb6a2e046c" [metadata.files] arrow = [ @@ -463,8 +464,8 @@ asn1crypto = [ {file = "asn1crypto-1.5.1.tar.gz", hash = "sha256:13ae38502be632115abf8a24cbe5f4da52e3b5231990aff31123c805306ccb9c"}, ] certifi = [ - {file = "certifi-2022.5.18.1-py3-none-any.whl", hash = "sha256:f1d53542ee8cbedbe2118b5686372fb33c297fcd6379b050cca0ef13a597382a"}, - {file = "certifi-2022.5.18.1.tar.gz", hash = "sha256:9c5705e395cd70084351dd8ad5c41e65655e08ce46f2ec9cf6c2c08390f71eb7"}, + {file = "certifi-2022.6.15-py3-none-any.whl", hash = "sha256:fe86415d55e84719d75f8b69414f6438ac3547d2078ab91b67e779ef69378412"}, + {file = "certifi-2022.6.15.tar.gz", hash = "sha256:84c85a9078b11105f04f3036a9482ae10e4621616db313fe045dd24743a0820d"}, ] cffi = [ {file = "cffi-1.15.0-cp27-cp27m-macosx_10_9_x86_64.whl", hash = "sha256:c2502a1a03b6312837279c8c1bd3ebedf6c12c4228ddbad40912d671ccc8a962"}, @@ -527,8 +528,8 @@ click = [ {file = "click-8.1.3.tar.gz", hash = "sha256:7682dc8afb30297001674575ea00d1814d808d6a36af415a82bd481d37ba7b8e"}, ] colorama = [ - {file = "colorama-0.4.4-py2.py3-none-any.whl", hash = "sha256:9f47eda37229f68eee03b24b9748937c7dc3868f906e8ba69fbcbdd3bc5dc3e2"}, - {file = "colorama-0.4.4.tar.gz", hash = "sha256:5941b2b48a20143d2267e95b1c2a7603ce057ee39fd88e7329b0c292aa16869b"}, + {file = "colorama-0.4.5-py2.py3-none-any.whl", hash = "sha256:854bf444933e37f5824ae7bfc1e98d5bce2ebe4160d46b5edf346a89358e99da"}, + {file = "colorama-0.4.5.tar.gz", hash = "sha256:e6c6b4334fc50988a639d9b98aa429a0b57da6e17b9a44f0451f930b6967b7a4"}, ] commonmark = [ {file = "commonmark-0.9.1-py2.py3-none-any.whl", hash = "sha256:da2f38c92590f83de410ba1a3cbceafbc74fee9def35f9251ba9a971d6d66fd9"}, @@ -701,8 +702,8 @@ rich = [ {file = "rich-10.16.2.tar.gz", hash = "sha256:720974689960e06c2efdb54327f8bf0cdbdf4eae4ad73b6c94213cad405c371b"}, ] runtype = [ - {file = "runtype-0.2.4-py3-none-any.whl", hash = "sha256:1adab62f867199536820898ce04df22586ba2a52084448385004faa532b19e97"}, - {file = "runtype-0.2.4.tar.gz", hash = "sha256:642f747b199fd872deb79d361d47ea83a1a0db49986fbeaa0c375d2bd9805e00"}, + {file = "runtype-0.2.6-py3-none-any.whl", hash = "sha256:1739136f46551240a9f68807d167b5acbe4c18512de08ebdcc2fa0648d97c834"}, + {file = "runtype-0.2.6.tar.gz", hash = "sha256:31818c1991c8b5e01ec2e54a53ad44af104dcf1bb4e82efd1aa7eba1047dc6dd"}, ] six = [ {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, diff --git a/pyproject.toml b/pyproject.toml index a1d60be4..11c697b1 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "data-diff" -version = "0.0.8" +version = "0.2.0" description = "Command-line tool and Python library to efficiently diff rows 
across two different databases." authors = ["Erez Shinnan ", "Simon Eskildsen "] license = "MIT" @@ -47,9 +47,10 @@ parameterized = "*" # When adding, update also: README + dev deps just above preql = ["preql"] mysql = ["mysql-connector-python"] -pgsql = ["psycopg2"] +postgresql = ["psycopg2"] snowflake = ["snowflake-connector-python"] presto = ["presto-python-client"] +oracle = ["cx_Oracle"] [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/common.py b/tests/common.py index 33281861..1fd610a0 100644 --- a/tests/common.py +++ b/tests/common.py @@ -6,7 +6,7 @@ logging.basicConfig(level=logging.INFO) TEST_MYSQL_CONN_STRING: str = "mysql://mysql:Password1@localhost/mysql" -TEST_POSTGRES_CONN_STRING: str = None +TEST_POSTGRESQL_CONN_STRING: str = None TEST_SNOWFLAKE_CONN_STRING: str = None TEST_BIGQUERY_CONN_STRING: str = None TEST_REDSHIFT_CONN_STRING: str = None @@ -19,10 +19,14 @@ except ImportError: pass # No local settings +if TEST_BIGQUERY_CONN_STRING and TEST_SNOWFLAKE_CONN_STRING: + # TODO Fix this. Seems to have something to do with pyarrow + raise RuntimeError("Using BigQuery at the same time as Snowflake causes an error!!") + CONN_STRINGS = { - # db.BigQuery: TEST_BIGQUERY_CONN_STRING, # TODO BigQuery before/after Snowflake causes an error! + db.BigQuery: TEST_BIGQUERY_CONN_STRING, db.MySQL: TEST_MYSQL_CONN_STRING, - db.Postgres: TEST_POSTGRES_CONN_STRING, + db.PostgreSQL: TEST_POSTGRESQL_CONN_STRING, db.Snowflake: TEST_SNOWFLAKE_CONN_STRING, db.Redshift: TEST_REDSHIFT_CONN_STRING, db.Oracle: TEST_ORACLE_CONN_STRING, diff --git a/tests/test_api.py b/tests/test_api.py index cd5b9c19..2a532edd 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -15,6 +15,7 @@ def setUpClass(cls): cls.preql = preql.Preql(TEST_MYSQL_CONN_STRING) def setUp(self) -> None: + self.preql = preql.Preql(TEST_MYSQL_CONN_STRING) self.preql( r""" table test_api { diff --git a/tests/test_database.py b/tests/test_database.py index eabed2f7..924925c2 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -22,7 +22,7 @@ def test_md5_to_int(self): class TestConnect(unittest.TestCase): def test_bad_uris(self): self.assertRaises(ValueError, connect_to_uri, "p") - self.assertRaises(ValueError, connect_to_uri, "postgres:///bla/foo") + self.assertRaises(ValueError, connect_to_uri, "postgresql:///bla/foo") self.assertRaises(ValueError, connect_to_uri, "snowflake://erez:erez27Snow@bya42734/xdiffdev/TEST1") self.assertRaises( ValueError, connect_to_uri, "snowflake://erez:erez27Snow@bya42734/xdiffdev/TEST1?warehouse=ha&schema=dup" diff --git a/tests/test_database_types.py b/tests/test_database_types.py index 6b95d310..2f665618 100644 --- a/tests/test_database_types.py +++ b/tests/test_database_types.py @@ -1,13 +1,20 @@ from contextlib import suppress import unittest import time +import logging +from decimal import Decimal + +from parameterized import parameterized, parameterized_class +import preql + from data_diff import database as db from data_diff.diff_tables import TableDiffer, TableSegment from parameterized import parameterized, parameterized_class from .common import CONN_STRINGS import logging -logging.getLogger("diff_tables").setLevel(logging.WARN) + +logging.getLogger("diff_tables").setLevel(logging.ERROR) logging.getLogger("database").setLevel(logging.WARN) CONNS = {k: db.connect_to_uri(v) for k, v in CONN_STRINGS.items()} @@ -24,11 +31,28 @@ "2022-05-01 15:10:03.003030", "2022-06-01 15:10:05.009900", ], - "float": [0.0, 0.1, 0.10, 10.0, 100.98], + "float": [ + 0.0, + 0.1, + 
0.00188, + 0.99999, + 0.091919, + 0.10, + 10.0, + 100.98, + 0.001201923076923077, + 1 / 3, + 1 / 5, + 1 / 109, + 1 / 109489, + 1 / 1094893892389, + 1 / 10948938923893289, + 3.141592653589793, + ], } DATABASE_TYPES = { - db.Postgres: { + db.PostgreSQL: { # https://www.postgresql.org/docs/current/datatype-numeric.html#DATATYPE-INT "int": [ # "smallint", # 2 bytes @@ -43,9 +67,10 @@ ], # https://www.postgresql.org/docs/current/datatype-numeric.html "float": [ - # "real", - # "double precision", - # "numeric(6,3)", + "real", + "float", + "double precision", + "numeric(6,3)", ], }, db.MySQL: { @@ -58,12 +83,19 @@ # "bigint", # 8 bytes ], # https://dev.mysql.com/doc/refman/8.0/en/datetime.html - "datetime_no_timezone": ["timestamp(6)", "timestamp(3)", "timestamp(0)", "timestamp", "datetime(6)"], + "datetime_no_timezone": [ + "timestamp(6)", + "timestamp(3)", + "timestamp(0)", + "timestamp", + "datetime(6)", + ], # https://dev.mysql.com/doc/refman/8.0/en/numeric-types.html "float": [ - # "float", - # "double", - # "numeric", + "float", + "double", + "numeric", + "numeric(65, 10)", ], }, db.BigQuery: { @@ -71,6 +103,11 @@ "timestamp", # "datetime", ], + "float": [ + "numeric", + "float64", + "bignumeric", + ], }, db.Snowflake: { # https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#int-integer-bigint-smallint-tinyint-byteint @@ -92,8 +129,8 @@ ], # https://docs.snowflake.com/en/sql-reference/data-types-numeric.html#decimal-numeric "float": [ - # "float" - # "numeric", + "float", + "numeric", ], }, db.Redshift: { @@ -105,9 +142,9 @@ ], # https://docs.aws.amazon.com/redshift/latest/dg/r_Numeric_types201.html#r_Numeric_types201-floating-point-types "float": [ - # "float4", - # "float8", - # "numeric", + "float4", + "float8", + "numeric", ], }, db.Oracle: { @@ -120,8 +157,8 @@ "timestamp(9) with local time zone", ], "float": [ - # "float", - # "numeric", + "float", + "numeric", ], }, db.Presto: { @@ -132,11 +169,18 @@ # "int", # 4 bytes # "bigint", # 8 bytes ], - "datetime_no_timezone": ["timestamp(6)", "timestamp(3)", "timestamp(0)", "timestamp", "datetime(6)"], + "datetime_no_timezone": [ + "timestamp(6)", + "timestamp(3)", + "timestamp(0)", + "timestamp", + "datetime(6)", + ], "float": [ - # "float", - # "double", - # "numeric", + "real", + "double", + "decimal(10,2)", + "decimal(30,6)", ], }, } @@ -150,7 +194,10 @@ # target_type: (int, bigint) } for source_db, source_type_categories in DATABASE_TYPES.items(): for target_db, target_type_categories in DATABASE_TYPES.items(): - for type_category, source_types in source_type_categories.items(): # int, datetime, .. + for ( + type_category, + source_types, + ) in source_type_categories.items(): # int, datetime, .. 
for source_type in source_types: for target_type in target_type_categories[type_category]: if CONNS.get(source_db, False) and CONNS.get(target_db, False): @@ -184,25 +231,38 @@ def _insert_to_table(conn, table, values): if isinstance(conn, db.Oracle): selects = [] for j, sample in values: - selects.append( f"SELECT {j}, timestamp '{sample}' FROM dual" ) - insertion_query += ' UNION ALL '.join(selects) + if isinstance(sample, (float, Decimal, int)): + value = str(sample) + else: + value = f"timestamp '{sample}'" + selects.append(f"SELECT {j}, {value} FROM dual") + insertion_query += " UNION ALL ".join(selects) else: - insertion_query += ' VALUES ' + insertion_query += " VALUES " for j, sample in values: - insertion_query += f"({j}, '{sample}')," + if isinstance(sample, (float, Decimal)): + value = str(sample) + else: + value = f"'{sample}'" + insertion_query += f"({j}, {value})," + insertion_query = insertion_query[0:-1] conn.query(insertion_query, None) + if not isinstance(conn, db.BigQuery): conn.query("COMMIT", None) + def _drop_table_if_exists(conn, table): with suppress(db.QueryError): if isinstance(conn, db.Oracle): conn.query(f"DROP TABLE {table}", None) + conn.query(f"DROP TABLE {table}", None) else: conn.query(f"DROP TABLE IF EXISTS {table}", None) + class TestDiffCrossDatabaseTables(unittest.TestCase): @parameterized.expand(type_pairs, name_func=expand_params) def test_types(self, source_db, target_db, source_type, target_type, type_category): @@ -214,8 +274,12 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego self.connections = [self.src_conn, self.dst_conn] sample_values = TYPE_SAMPLES[type_category] - src_table_path = src_conn.parse_table_name("src") - dst_table_path = dst_conn.parse_table_name("dst") + # Limit in MySQL is 64 + src_table_name = f"src_{self._testMethodName[:60]}" + dst_table_name = f"dst_{self._testMethodName[:60]}" + + src_table_path = src_conn.parse_table_name(src_table_name) + dst_table_path = dst_conn.parse_table_name(dst_table_name) src_table = src_conn.quote(".".join(src_table_path)) dst_table = dst_conn.quote(".".join(dst_table_path)) @@ -250,4 +314,3 @@ def test_types(self, source_db, target_db, source_type, target_type, type_catego duration = time.time() - start # print(f"source_db={source_db.__name__} target_db={target_db.__name__} source_type={source_type} target_type={target_type} duration={round(duration * 1000, 2)}ms") - diff --git a/tests/test_diff_tables.py b/tests/test_diff_tables.py index a457081d..3cd97212 100644 --- a/tests/test_diff_tables.py +++ b/tests/test_diff_tables.py @@ -32,9 +32,24 @@ def tearDownClass(cls): cls.preql.close() cls.connection.close() + # Fallback for test runners that doesn't support setUpClass/tearDownClass + def setUp(self) -> None: + if not hasattr(self, "connection"): + self.setUpClass.__func__(self) + self.private_connection = True + + return super().setUp() + + def tearDown(self) -> None: + if hasattr(self, "private_connection"): + self.tearDownClass.__func__(self) + + return super().tearDown() + class TestDates(TestWithConnection): def setUp(self): + super().setUp() self.connection.query("DROP TABLE IF EXISTS a", None) self.connection.query("DROP TABLE IF EXISTS b", None) self.preql( @@ -110,6 +125,7 @@ def test_offset(self): class TestDiffTables(TestWithConnection): def setUp(self): + super().setUp() self.connection.query("DROP TABLE IF EXISTS ratings_test", None) self.connection.query("DROP TABLE IF EXISTS ratings_test2", None) self.preql.load("./tests/setup.pql") @@ -155,6 
+171,8 @@ def test_diff_small_tables(self): diff = list(self.differ.diff_tables(self.table, self.table2)) expected = [("-", ("2", time + ".000000"))] self.assertEqual(expected, diff) + self.assertEqual(2, self.differ.stats["table1_count"]) + self.assertEqual(1, self.differ.stats["table2_count"]) def test_diff_table_above_bisection_threshold(self): time = "2022-01-01 00:00:00" @@ -176,6 +194,8 @@ def test_diff_table_above_bisection_threshold(self): diff = list(self.differ.diff_tables(self.table, self.table2)) expected = [("-", ("5", time + ".000000"))] self.assertEqual(expected, diff) + self.assertEqual(5, self.differ.stats["table1_count"]) + self.assertEqual(4, self.differ.stats["table2_count"]) def test_return_empty_array_when_same(self): time = "2022-01-01 00:00:00" @@ -221,9 +241,9 @@ def test_diff_sorted_by_key(self): class TestTableSegment(TestWithConnection): def setUp(self) -> None: + super().setUp() self.table = TableSegment(self.connection, ("ratings_test",), "id", "timestamp") self.table2 = TableSegment(self.connection, ("ratings_test2",), "id", "timestamp") - return super().setUp() def test_table_segment(self): early = datetime.datetime(2021, 1, 1, 0, 0) diff --git a/tests/test_normalize_fields.py b/tests/test_normalize_fields.py index 2953f8ad..7893022f 100644 --- a/tests/test_normalize_fields.py +++ b/tests/test_normalize_fields.py @@ -5,7 +5,7 @@ import preql -from data_diff.database import BigQuery, MySQL, Snowflake, connect_to_uri, Oracle, DEFAULT_PRECISION +from data_diff.database import BigQuery, MySQL, Snowflake, connect_to_uri, Oracle from data_diff.sql import Select from data_diff import database as db @@ -14,7 +14,7 @@ logger = logging.getLogger() DATE_TYPES = { - db.Postgres: ["timestamp({p}) with time zone", "timestamp({p}) without time zone"], + db.PostgreSQL: ["timestamp({p}) with time zone", "timestamp({p}) without time zone"], db.MySQL: ["datetime({p})", "timestamp({p})"], db.Snowflake: ["timestamp({p})", "timestamp_tz({p})", "timestamp_ntz({p})"], db.BigQuery: ["timestamp", "datetime"],
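The `--json` flag introduced in `data_diff/__main__.py` above makes the CLI output machine-readable. As a rough sketch of how a monitoring job might consume it, assuming the key names from the `json_output` dict in that diff (the connection strings, table names, and alerting logic below are illustrative assumptions):

```python
# Hedged sketch: run `data-diff --stats --json` and alert on the parsed summary.
# Key names (different_rows, different_percent, total) follow the json_output
# dict added in data_diff/__main__.py; everything else here is an assumption.
import json
import subprocess

result = subprocess.run(
    [
        "data-diff",
        "postgresql://user:password@localhost/database", "rating",
        "postgresql://user:password@localhost/database", "rating_del1",
        "--stats", "--json",
    ],
    capture_output=True,
    text=True,
)

# With --stats --json, the summary is printed as a single JSON line on stdout.
stats = json.loads(result.stdout.splitlines()[-1])
if stats["different_rows"] > 0:
    print(f"ALERT: {stats['different_percent']:.4f}% of {stats['total']} rows differ")
```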