diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py
index 338381549b..af830162bd 100644
--- a/sqlmesh/core/engine_adapter/athena.py
+++ b/sqlmesh/core/engine_adapter/athena.py
@@ -78,10 +78,9 @@ def s3_warehouse_location_or_raise(self) -> str:
 
     @property
     def catalog_support(self) -> CatalogSupport:
-        # Athena has the concept of catalogs but the current catalog is set in the connection parameters with no way to query or change it after that
-        # It also cant create new catalogs, you have to configure them in AWS. Typically, catalogs that are not "awsdatacatalog"
-        # are pointers to the "awsdatacatalog" of other AWS accounts
-        return CatalogSupport.SINGLE_CATALOG_ONLY
+        # Athena supports querying and writing to multiple catalogs (e.g. awsdatacatalog and s3tablescatalog)
+        # without needing a SET CATALOG command.
+        return CatalogSupport.FULL_SUPPORT
 
     def create_state_table(
         self,
@@ -105,6 +104,9 @@ def _get_data_objects(
         """
         schema_name = to_schema(schema_name)
         schema = schema_name.db
+
+        info_schema_tables = exp.table_("tables", db="information_schema", catalog=schema_name.catalog, alias="t")
+
         query = (
             exp.select(
                 exp.column("table_catalog").as_("catalog"),
@@ -118,7 +120,7 @@ def _get_data_objects(
                 .else_(exp.column("table_type", table="t"))
                 .as_("type"),
             )
-            .from_(exp.to_table("information_schema.tables", alias="t"))
+            .from_(info_schema_tables)
             .where(exp.column("table_schema", table="t").eq(schema))
         )
         if object_names:
@@ -136,22 +138,115 @@ def _get_data_objects(
             for row in df.itertuples()
         ]
 
+    def table_exists(self, table_name: TableName) -> bool:
+        from sqlmesh.core.engine_adapter.base import _get_data_object_cache_key
+        table = exp.to_table(table_name)
+        data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name)
+
+        if data_object_cache_key in self._data_object_cache:
+            logger.debug("Table existence cache hit: %s", data_object_cache_key)
+            return self._data_object_cache[data_object_cache_key] is not None
+
+        try:
+            # We don't use DESCRIBE because it fails with "Unsupported ddl with 2 catalogs"
+            # for cross-catalog queries in Athena.
+            # And since table_exists isn't run with the set_catalog decorator (which sets QueryExecutionContext),
+            # we must fall back to a query that works with fully qualified names or
+            # uses information_schema. A LIMIT 0 SELECT works with fully qualified names in Athena.
+            self.execute(exp.select("1").from_(table).limit(0))
+            return True
+        except Exception:
+            return False
+
     def columns(
         self, table_name: TableName, include_pseudo_columns: bool = False
    ) -> t.Dict[str, exp.DataType]:
         table = exp.to_table(table_name)
 
         # note: the data_type column contains the full parameterized type, eg 'varchar(10)'
+
+        info_schema_columns = exp.table_("columns", db="information_schema", catalog=table.catalog)
+
         query = (
             exp.select("column_name", "data_type")
-            .from_("information_schema.columns")
+            .from_(info_schema_columns)
             .where(exp.column("table_schema").eq(table.db), exp.column("table_name").eq(table.name))
             .order_by("ordinal_position")
         )
-        result = self.fetchdf(query, quote_identifiers=True)
-        return {
-            str(r.column_name): exp.DataType.build(str(r.data_type))
-            for r in result.itertuples(index=False)
-        }
+
+        try:
+            result = self.fetchdf(query, quote_identifiers=True)
+            return {
+                str(r.column_name): exp.DataType.build(str(r.data_type))
+                for r in result.itertuples(index=False)
+            }
+        except Exception:
+            # If the information_schema query fails, we fall back to DESCRIBE.
+            # But DESCRIBE with multiple catalogs fails in Athena, so we strip the catalog here
+            # and rely on the set_current_catalog mechanism (applied at the EngineAdapter method level)
+            # to set the catalog in the execution context.
+            describe_table = table.copy()
+            catalog = describe_table.catalog
+            current_catalog = self.get_current_catalog()
+
+            if catalog and catalog != self._default_catalog:
+                describe_table.set("catalog", None)
+                if catalog != current_catalog:
+                    self.set_current_catalog(catalog)
+
+            try:
+                self.execute(exp.Describe(this=describe_table, kind="TABLE"))
+
+                from sqlmesh.core.engine_adapter.base import _decoded_str
+                import itertools
+                describe_output = self.cursor.fetchall()
+                return {
+                    # Note: some drivers return the column type as bytes, hence _decoded_str.
+                    column_name: exp.DataType.build(_decoded_str(column_type), dialect=self.dialect)
+                    for column_name, column_type, *_ in itertools.takewhile(
+                        lambda t: not t[0].startswith("#"),
+                        describe_output,
+                    )
+                    if column_name and column_name.strip() and column_type and column_type.strip()
+                }
+            finally:
+                if catalog and catalog != self._default_catalog and current_catalog != catalog:
+                    if current_catalog is not None:
+                        self.set_current_catalog(current_catalog)
+
+    def _drop_object(
+        self,
+        name: TableName | SchemaName,
+        exists: bool = True,
+        kind: str = "TABLE",
+        cascade: bool = False,
+        **drop_args: t.Any,
+    ) -> None:
+        if cascade and kind.upper() in self.SUPPORTED_DROP_CASCADE_OBJECT_KINDS:
+            drop_args["cascade"] = cascade
+
+        target_table = exp.to_table(name).copy()
+        is_schema = kind.upper() == "SCHEMA"
+        catalog = target_table.db if is_schema else target_table.catalog
+
+        if catalog and catalog != self._default_catalog:
+            if is_schema:
+                target_table.set("db", None)
+            else:
+                target_table.set("catalog", None)
+
+            current_catalog = self.get_current_catalog()
+            if current_catalog != catalog:
+                self.set_current_catalog(catalog)
+
+            try:
+                self.execute(exp.Drop(this=target_table, kind=kind, exists=exists, **drop_args))
+            finally:
+                if current_catalog is not None and current_catalog != catalog:
+                    self.set_current_catalog(current_catalog)
+        else:
+            self.execute(exp.Drop(this=target_table, kind=kind, exists=exists, **drop_args))
+
+        self._clear_data_object_cache(name)
 
     def _create_schema(
         self,
@@ -161,11 +256,40 @@ def _create_schema(
         properties: t.List[exp.Expr],
         kind: str,
     ) -> None:
+        schema = to_schema(schema_name)
+
         if location := self._table_location(table_properties=None, table=exp.to_table(schema_name)):
             # don't add extra LocationProperty's if one already exists
             if not any(p for p in properties if isinstance(p, exp.LocationProperty)):
                 properties.append(location)
 
+        if schema.catalog and schema.catalog != self._default_catalog:
+            target_schema = schema.copy()
+            catalog = target_schema.catalog
+            target_schema.set("catalog", None)
+
+            current_catalog = self.get_current_catalog()
+            if current_catalog != catalog:
+                self.set_current_catalog(catalog)
+
+            try:
+                self.execute(
+                    exp.Create(
+                        this=target_schema,
+                        kind=kind,
+                        exists=ignore_if_exists,
+                        properties=exp.Properties(expressions=properties),
+                    )
+                )
+            except Exception as e:
+                if not warn_on_error:
+                    raise
+                logger.warning("Failed to create %s '%s': %s", kind.lower(), schema_name, e)
+            finally:
+                if current_catalog is not None and current_catalog != catalog:
+                    self.set_current_catalog(current_catalog)
+            return
+
         return super()._create_schema(
             schema_name=schema_name,
             ignore_if_exists=ignore_if_exists,
@@ -174,6 +298,76 @@ def _create_schema(
             kind=kind,
         )
 
+    def _get_temp_table(
+        self, table: TableName, table_only: bool = False, quoted: bool = True
+    ) -> exp.Table:
+        """
+        Returns the name of the temp table that should be used for the given table name.
+        """
+        from sqlmesh.utils import random_id
+
+        table = t.cast(exp.Table, exp.to_table(table).copy())
+
+        # AWS S3 Tables (and Athena generally) prefer or require table names to start with a letter.
+        # S3 Tables specifically fail with: "The specified table name is not valid" if it starts with __temp_
+        table.set(
+            "this", exp.to_identifier(f"temp_{table.name}_{random_id(short=True)}", quoted=quoted)
+        )
+
+        if table_only:
+            table.set("db", None)
+            table.set("catalog", None)
+
+        return table
+
+    def _create_table(
+        self,
+        table_name_or_schema: t.Union[exp.Schema, TableName],
+        expression: t.Optional[exp.Expr],
+        exists: bool = True,
+        replace: bool = False,
+        target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None,
+        table_description: t.Optional[str] = None,
+        column_descriptions: t.Optional[t.Dict[str, str]] = None,
+        table_kind: t.Optional[str] = None,
+        track_rows_processed: bool = True,
+        **kwargs: t.Any,
+    ) -> None:
+        table: exp.Table
+        if isinstance(table_name_or_schema, str):
+            table = exp.to_table(table_name_or_schema)
+        elif isinstance(table_name_or_schema, exp.Schema):
+            table = table_name_or_schema.this
+        else:
+            table = table_name_or_schema
+
+        catalog = table.catalog
+        current_catalog = self.get_current_catalog()
+
+        # For non-CTAS CREATE TABLE in a non-default catalog, the catalog is stripped by _build_create_table_exp.
+        # We need to set the query execution context here.
+        if not expression and catalog and catalog != self._default_catalog:
+            if current_catalog != catalog:
+                self.set_current_catalog(catalog)
+
+        try:
+            super()._create_table(
+                table_name_or_schema=table_name_or_schema,
+                expression=expression,
+                exists=exists,
+                replace=replace,
+                target_columns_to_types=target_columns_to_types,
+                table_description=table_description,
+                column_descriptions=column_descriptions,
+                table_kind=table_kind,
+                track_rows_processed=track_rows_processed,
+                **kwargs,
+            )
+        finally:
+            if not expression and catalog and catalog != self._default_catalog:
+                if current_catalog is not None and current_catalog != catalog:
+                    self.set_current_catalog(current_catalog)
+
     def _build_create_table_exp(
         self,
         table_name_or_schema: t.Union[exp.Schema, TableName],
@@ -197,6 +391,11 @@ def _build_create_table_exp(
         else:
             table = table_name_or_schema
 
+        table_format = kwargs.pop("table_format", None)
+        if not table_format and table_properties and "table_format" in table_properties:
+            tf = table_properties.get("table_format")
+            table_format = tf.name if isinstance(tf, exp.Literal) else str(tf)
+
         properties = self._build_table_properties_exp(
             table=table,
             expression=expression,
@@ -205,10 +404,11 @@ def _build_create_table_exp(
             table_properties=table_properties,
             table_description=table_description,
             table_kind=table_kind,
+            table_format=table_format,
             **kwargs,
         )
 
-        is_hive = self._table_type(kwargs.get("table_format", None)) == "hive"
+        is_hive = self._table_type(table_format) == "hive"
 
         # Filter any PARTITIONED BY properties from the main column list since they cant be specified in both places
         # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html
@@ -221,8 +421,20 @@ def _build_create_table_exp(
             ]
             table_name_or_schema.args["expressions"] = filtered_expressions
 
+        create_table = table_name_or_schema.copy()
+
+        # When creating a table without AS SELECT, Athena fails with "Unsupported ddl with 2 catalogs"
+        # if a custom catalog like s3tablescatalog/supply is provided in the CREATE TABLE statement.
+        # It requires the catalog to be provided via QueryExecutionContext instead.
+        # The set_catalog decorator (which calls set_current_catalog) passes it to the QueryExecutionContext.
+        # But we also need to strip it from the generated CREATE TABLE statement.
+        # Note: We must strip the catalog from the table in the schema if table_name_or_schema is a schema.
+        target_table = create_table.this if isinstance(create_table, exp.Schema) else create_table
+        if not expression and target_table.catalog and target_table.catalog != self._default_catalog:
+            target_table.set("catalog", None)
+
         return exp.Create(
-            this=table_name_or_schema,
+            this=create_table,
             kind=table_kind or "TABLE",
             replace=replace,
             exists=exists,
@@ -247,17 +459,36 @@ def _build_table_properties_exp(
         **kwargs: t.Any,
     ) -> t.Optional[exp.Properties]:
         properties: t.List[exp.Expr] = []
-        table_properties = table_properties or {}
+        table_properties = table_properties.copy() if table_properties else {}
+
+        s3_table_prop = table_properties.pop("s3_table", None)
+        is_s3_table = False
+        if s3_table_prop is not None:
+            if isinstance(s3_table_prop, exp.Boolean):
+                is_s3_table = s3_table_prop.this
+            elif isinstance(s3_table_prop, exp.Literal):
+                is_s3_table = s3_table_prop.name.lower() in ("true", "1")
+            else:
+                is_s3_table = str(s3_table_prop).lower() in ("true", "1")
+        elif table and table.catalog and table.catalog.startswith("s3tablescatalog/"):
+            is_s3_table = True
+
+        tf = table_properties.pop("table_format", None)
+        if not table_format and tf:
+            table_format = tf.name if isinstance(tf, exp.Literal) else str(tf)
 
         is_hive = self._table_type(table_format) == "hive"
         is_iceberg = not is_hive
 
+        if is_s3_table and is_hive:
+            raise SQLMeshError("Amazon S3 Tables only support the Iceberg format")
+
         if is_hive and not expression:
             # Hive tables are CREATE EXTERNAL TABLE, Iceberg tables are CREATE TABLE
             # Unless it's a CTAS, those are always CREATE TABLE
             properties.append(exp.ExternalProperty())
 
-        if table_format:
+        if table_format and not is_s3_table:
             properties.append(
                 exp.Property(this=exp.var("table_type"), value=exp.Literal.string(table_format))
             )
@@ -279,9 +510,30 @@ def _build_table_properties_exp(
             else:
                 schema_expressions = partitioned_by
 
-            properties.append(
-                exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions))
-            )
+            if is_hive:
+                properties.append(
+                    exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions))
+                )
+            else:
+                if is_s3_table:
+                    array_exprs = []
+                    for e in schema_expressions:
+                        e_copy = e.copy()
+                        e_copy.transform(
+                            lambda n: n.name if isinstance(n, exp.Identifier) else n, copy=False
+                        )
+                        expr_sql = e_copy.sql(dialect="athena")
+                        array_exprs.append(exp.Literal.string(expr_sql))
+
+                    properties.append(
+                        exp.Property(
+                            this=exp.var("partitioning"), value=exp.Array(expressions=array_exprs)
+                        )
+                    )
+                else:
+                    properties.append(
+                        exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions))
+                    )
 
         if clustered_by:
             # Athena itself supports CLUSTERED BY, via the syntax CLUSTERED BY (col) INTO BUCKETS
@@ -293,13 +545,16 @@ def _build_table_properties_exp(
 
         if storage_format:
             if is_iceberg:
-                # TBLPROPERTIES('format'='parquet')
-                table_properties["format"] = exp.Literal.string(storage_format)
+                if not is_s3_table or storage_format.lower() == "parquet":
+                    # TBLPROPERTIES('format'='parquet')
+                    table_properties["format"] = exp.Literal.string(storage_format)
+                else:
+                    raise SQLMeshError("Amazon S3 Tables only support the PARQUET storage format")
             else:
                 # STORED AS PARQUET
                 properties.append(exp.FileFormatProperty(this=storage_format))
 
-        if table and (location := self._table_location_or_raise(table_properties, table)):
+        if table and not is_s3_table and (location := self._table_location_or_raise(table_properties, table)):
             properties.append(location)
 
         if is_iceberg and expression:
@@ -308,8 +563,28 @@ def _build_table_properties_exp(
             # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties
             properties.append(exp.Property(this=exp.var("is_external"), value="false"))
 
-        for name, value in table_properties.items():
-            properties.append(exp.Property(this=exp.var(name), value=value))
+        if not is_s3_table:
+            for name, value in table_properties.items():
+                properties.append(exp.Property(this=exp.var(name), value=value))
+        else:
+            # According to AWS documentation for S3 Tables CTAS queries:
+            # "The `table_type` property defaults to `ICEBERG`, so you don't need to explicitly specify it"
+            # "If you don't specify a format, the system automatically uses `PARQUET`"
+            # We omit all other TBLPROPERTIES because Athena doesn't support them during CTAS
+            if expression:
+                # the only property allowed in CTAS for S3 Tables is 'format' (which we captured above)
+                format_val = table_properties.pop("format", exp.Literal.string("PARQUET"))
+                # Ensure it's uppercase PARQUET for S3 Tables just to be safe as per AWS examples
+                if isinstance(format_val, exp.Literal) and format_val.name.lower() == "parquet":
+                    format_val = exp.Literal.string("PARQUET")
+                properties.append(exp.Property(this=exp.var("format"), value=format_val))
+
+                if table_properties:
+                    logger.warning(
+                        "Ignoring unsupported table properties for S3 Table CTAS: %s",
+                        list(table_properties.keys()),
+                    )
+            else:
+                # Standard CREATE TABLE for S3 Tables allows properties
+                for name, value in table_properties.items():
+                    properties.append(exp.Property(this=exp.var(name), value=value))
 
         if properties:
             return exp.Properties(expressions=properties)
@@ -364,11 +639,29 @@ def _query_table_type_or_raise(self, table: exp.Table) -> TableType:
         """
         # Note: SHOW TBLPROPERTIES gets parsed by SQLGlot as an exp.Command anyway so we just use a string here
         # This also means we need to use dialect="hive" instead of dialect="athena" so that the identifiers get the correct quoting (backticks)
-        for row in self.fetchall(f"SHOW TBLPROPERTIES {table.sql(dialect='hive', identify=True)}"):
-            # This query returns a single column with values like 'EXTERNAL\tTRUE'
-            row_lower = row[0].lower()
-            if "external" in row_lower and "true" in row_lower:
-                return "hive"
+        target_table = table.copy()
+        if target_table.catalog and target_table.catalog != self._default_catalog:
+            catalog = target_table.catalog
+            target_table.set("catalog", None)
+
+            current_catalog = self.get_current_catalog()
+            if current_catalog != catalog:
+                self.set_current_catalog(catalog)
+
+            try:
+                for row in self.fetchall(f"SHOW TBLPROPERTIES {target_table.sql(dialect='hive', identify=True)}"):
+                    row_lower = row[0].lower()
+                    if "external" in row_lower and "true" in row_lower:
+                        return "hive"
+            finally:
+                if current_catalog is not None and current_catalog != catalog:
+                    self.set_current_catalog(current_catalog)
+        else:
+            for row in self.fetchall(f"SHOW TBLPROPERTIES {target_table.sql(dialect='hive', identify=True)}"):
+                # This query returns a single column with values like 'EXTERNAL\tTRUE'
+                row_lower = row[0].lower()
+                if "external" in row_lower and "true" in row_lower:
+                    return "hive"
 
         return "iceberg"
 
     def _is_hive_partitioned_table(self, table: exp.Table) -> bool:
@@ -618,5 +911,10 @@ def _boto3_client(self, name: str) -> t.Any:
             **conn._client_kwargs,
         )  # type: ignore
 
+    def set_current_catalog(self, catalog: str) -> None:
+        self.connection.catalog_name = catalog
+        # The active cursor keeps its own catalog reference, so update it too
+        # (hasattr-guarded since _catalog_name is a private pyathena attribute)
+        if hasattr(self.cursor, "_catalog_name"):
+            self.cursor._catalog_name = catalog
+
     def get_current_catalog(self) -> t.Optional[str]:
         return self.connection.catalog_name