sqlglot.transforms
1from __future__ import annotations 2 3import typing as t 4 5from sqlglot import expressions as exp 6from sqlglot.helper import find_new_name, name_sequence 7 8if t.TYPE_CHECKING: 9 from sqlglot.generator import Generator 10 11 12def preprocess( 13 transforms: t.List[t.Callable[[exp.Expression], exp.Expression]], 14) -> t.Callable[[Generator, exp.Expression], str]: 15 """ 16 Creates a new transform by chaining a sequence of transformations and converts the resulting 17 expression to SQL, using either the "_sql" method corresponding to the resulting expression, 18 or the appropriate `Generator.TRANSFORMS` function (when applicable -- see below). 19 20 Args: 21 transforms: sequence of transform functions. These will be called in order. 22 23 Returns: 24 Function that can be used as a generator transform. 25 """ 26 27 def _to_sql(self, expression: exp.Expression) -> str: 28 expression_type = type(expression) 29 30 expression = transforms[0](expression) 31 for transform in transforms[1:]: 32 expression = transform(expression) 33 34 _sql_handler = getattr(self, expression.key + "_sql", None) 35 if _sql_handler: 36 return _sql_handler(expression) 37 38 transforms_handler = self.TRANSFORMS.get(type(expression)) 39 if transforms_handler: 40 if expression_type is type(expression): 41 if isinstance(expression, exp.Func): 42 return self.function_fallback_sql(expression) 43 44 # Ensures we don't enter an infinite loop. This can happen when the original expression 45 # has the same type as the final expression and there's no _sql method available for it, 46 # because then it'd re-enter _to_sql. 47 raise ValueError( 48 f"Expression type {expression.__class__.__name__} requires a _sql method in order to be transformed." 49 ) 50 51 return transforms_handler(self, expression) 52 53 raise ValueError(f"Unsupported expression type {expression.__class__.__name__}.") 54 55 return _to_sql 56 57 58def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression: 59 if isinstance(expression, exp.Select): 60 count = 0 61 recursive_ctes = [] 62 63 for unnest in expression.find_all(exp.Unnest): 64 if ( 65 not isinstance(unnest.parent, (exp.From, exp.Join)) 66 or len(unnest.expressions) != 1 67 or not isinstance(unnest.expressions[0], exp.GenerateDateArray) 68 ): 69 continue 70 71 generate_date_array = unnest.expressions[0] 72 start = generate_date_array.args.get("start") 73 end = generate_date_array.args.get("end") 74 step = generate_date_array.args.get("step") 75 76 if not start or not end or not isinstance(step, exp.Interval): 77 continue 78 79 alias = unnest.args.get("alias") 80 column_name = alias.columns[0] if isinstance(alias, exp.TableAlias) else "date_value" 81 82 start = exp.cast(start, "date") 83 date_add = exp.func( 84 "date_add", column_name, exp.Literal.number(step.name), step.args.get("unit") 85 ) 86 cast_date_add = exp.cast(date_add, "date") 87 88 cte_name = "_generated_dates" + (f"_{count}" if count else "") 89 90 base_query = exp.select(start.as_(column_name)) 91 recursive_query = ( 92 exp.select(cast_date_add) 93 .from_(cte_name) 94 .where(cast_date_add <= exp.cast(end, "date")) 95 ) 96 cte_query = base_query.union(recursive_query, distinct=False) 97 98 generate_dates_query = exp.select(column_name).from_(cte_name) 99 unnest.replace(generate_dates_query.subquery(cte_name)) 100 101 recursive_ctes.append( 102 exp.alias_(exp.CTE(this=cte_query), cte_name, table=[column_name]) 103 ) 104 count += 1 105 106 if recursive_ctes: 107 with_expression = 
expression.args.get("with") or exp.With() 108 with_expression.set("recursive", True) 109 with_expression.set("expressions", [*recursive_ctes, *with_expression.expressions]) 110 expression.set("with", with_expression) 111 112 return expression 113 114 115def unnest_generate_series(expression: exp.Expression) -> exp.Expression: 116 """Unnests GENERATE_SERIES or SEQUENCE table references.""" 117 this = expression.this 118 if isinstance(expression, exp.Table) and isinstance(this, exp.GenerateSeries): 119 unnest = exp.Unnest(expressions=[this]) 120 if expression.alias: 121 return exp.alias_(unnest, alias="_u", table=[expression.alias], copy=False) 122 123 return unnest 124 125 return expression 126 127 128def unalias_group(expression: exp.Expression) -> exp.Expression: 129 """ 130 Replace references to select aliases in GROUP BY clauses. 131 132 Example: 133 >>> import sqlglot 134 >>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql() 135 'SELECT a AS b FROM x GROUP BY 1' 136 137 Args: 138 expression: the expression that will be transformed. 139 140 Returns: 141 The transformed expression. 142 """ 143 if isinstance(expression, exp.Group) and isinstance(expression.parent, exp.Select): 144 aliased_selects = { 145 e.alias: i 146 for i, e in enumerate(expression.parent.expressions, start=1) 147 if isinstance(e, exp.Alias) 148 } 149 150 for group_by in expression.expressions: 151 if ( 152 isinstance(group_by, exp.Column) 153 and not group_by.table 154 and group_by.name in aliased_selects 155 ): 156 group_by.replace(exp.Literal.number(aliased_selects.get(group_by.name))) 157 158 return expression 159 160 161def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression: 162 """ 163 Convert SELECT DISTINCT ON statements to a subquery with a window function. 164 165 This is useful for dialects that don't support SELECT DISTINCT ON but support window functions. 166 167 Args: 168 expression: the expression that will be transformed. 169 170 Returns: 171 The transformed expression. 172 """ 173 if ( 174 isinstance(expression, exp.Select) 175 and expression.args.get("distinct") 176 and expression.args["distinct"].args.get("on") 177 and isinstance(expression.args["distinct"].args["on"], exp.Tuple) 178 ): 179 distinct_cols = expression.args["distinct"].pop().args["on"].expressions 180 outer_selects = expression.selects 181 row_number = find_new_name(expression.named_selects, "_row_number") 182 window = exp.Window(this=exp.RowNumber(), partition_by=distinct_cols) 183 order = expression.args.get("order") 184 185 if order: 186 window.set("order", order.pop()) 187 else: 188 window.set("order", exp.Order(expressions=[c.copy() for c in distinct_cols])) 189 190 window = exp.alias_(window, row_number) 191 expression.select(window, copy=False) 192 193 return ( 194 exp.select(*outer_selects, copy=False) 195 .from_(expression.subquery("_t", copy=False), copy=False) 196 .where(exp.column(row_number).eq(1), copy=False) 197 ) 198 199 return expression 200 201 202def eliminate_qualify(expression: exp.Expression) -> exp.Expression: 203 """ 204 Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently. 
205 206 The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: 207 https://docs.snowflake.com/en/sql-reference/constructs/qualify 208 209 Some dialects don't support window functions in the WHERE clause, so we need to include them as 210 projections in the subquery, in order to refer to them in the outer filter using aliases. Also, 211 if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, 212 otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a 213 newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the 214 corresponding expression to avoid creating invalid column references. 215 """ 216 if isinstance(expression, exp.Select) and expression.args.get("qualify"): 217 taken = set(expression.named_selects) 218 for select in expression.selects: 219 if not select.alias_or_name: 220 alias = find_new_name(taken, "_c") 221 select.replace(exp.alias_(select, alias)) 222 taken.add(alias) 223 224 def _select_alias_or_name(select: exp.Expression) -> str | exp.Column: 225 alias_or_name = select.alias_or_name 226 identifier = select.args.get("alias") or select.this 227 if isinstance(identifier, exp.Identifier): 228 return exp.column(alias_or_name, quoted=identifier.args.get("quoted")) 229 return alias_or_name 230 231 outer_selects = exp.select(*list(map(_select_alias_or_name, expression.selects))) 232 qualify_filters = expression.args["qualify"].pop().this 233 expression_by_alias = { 234 select.alias: select.this 235 for select in expression.selects 236 if isinstance(select, exp.Alias) 237 } 238 239 select_candidates = exp.Window if expression.is_star else (exp.Window, exp.Column) 240 for select_candidate in qualify_filters.find_all(select_candidates): 241 if isinstance(select_candidate, exp.Window): 242 if expression_by_alias: 243 for column in select_candidate.find_all(exp.Column): 244 expr = expression_by_alias.get(column.name) 245 if expr: 246 column.replace(expr) 247 248 alias = find_new_name(expression.named_selects, "_w") 249 expression.select(exp.alias_(select_candidate, alias), copy=False) 250 column = exp.column(alias) 251 252 if isinstance(select_candidate.parent, exp.Qualify): 253 qualify_filters = column 254 else: 255 select_candidate.replace(column) 256 elif select_candidate.name not in expression.named_selects: 257 expression.select(select_candidate.copy(), copy=False) 258 259 return outer_selects.from_(expression.subquery(alias="_t", copy=False), copy=False).where( 260 qualify_filters, copy=False 261 ) 262 263 return expression 264 265 266def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression: 267 """ 268 Some dialects only allow the precision for parameterized types to be defined in the DDL and not in 269 other expressions. This transforms removes the precision from parameterized types in expressions. 
270 """ 271 for node in expression.find_all(exp.DataType): 272 node.set( 273 "expressions", [e for e in node.expressions if not isinstance(e, exp.DataTypeParam)] 274 ) 275 276 return expression 277 278 279def unqualify_unnest(expression: exp.Expression) -> exp.Expression: 280 """Remove references to unnest table aliases, added by the optimizer's qualify_columns step.""" 281 from sqlglot.optimizer.scope import find_all_in_scope 282 283 if isinstance(expression, exp.Select): 284 unnest_aliases = { 285 unnest.alias 286 for unnest in find_all_in_scope(expression, exp.Unnest) 287 if isinstance(unnest.parent, (exp.From, exp.Join)) 288 } 289 if unnest_aliases: 290 for column in expression.find_all(exp.Column): 291 if column.table in unnest_aliases: 292 column.set("table", None) 293 elif column.db in unnest_aliases: 294 column.set("db", None) 295 296 return expression 297 298 299def unnest_to_explode(expression: exp.Expression) -> exp.Expression: 300 """Convert cross join unnest into lateral view explode.""" 301 if isinstance(expression, exp.Select): 302 from_ = expression.args.get("from") 303 304 if from_ and isinstance(from_.this, exp.Unnest): 305 unnest = from_.this 306 alias = unnest.args.get("alias") 307 udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode 308 this, *expressions = unnest.expressions 309 unnest.replace( 310 exp.Table( 311 this=udtf( 312 this=this, 313 expressions=expressions, 314 ), 315 alias=exp.TableAlias(this=alias.this, columns=alias.columns) if alias else None, 316 ) 317 ) 318 319 for join in expression.args.get("joins") or []: 320 unnest = join.this 321 322 if isinstance(unnest, exp.Unnest): 323 alias = unnest.args.get("alias") 324 udtf = exp.Posexplode if unnest.args.get("offset") else exp.Explode 325 326 expression.args["joins"].remove(join) 327 328 for e, column in zip(unnest.expressions, alias.columns if alias else []): 329 expression.append( 330 "laterals", 331 exp.Lateral( 332 this=udtf(this=e), 333 view=True, 334 alias=exp.TableAlias(this=alias.this, columns=[column]), # type: ignore 335 ), 336 ) 337 338 return expression 339 340 341def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]: 342 """Convert explode/posexplode into unnest.""" 343 344 def _explode_to_unnest(expression: exp.Expression) -> exp.Expression: 345 if isinstance(expression, exp.Select): 346 from sqlglot.optimizer.scope import Scope 347 348 taken_select_names = set(expression.named_selects) 349 taken_source_names = {name for name, _ in Scope(expression).references} 350 351 def new_name(names: t.Set[str], name: str) -> str: 352 name = find_new_name(names, name) 353 names.add(name) 354 return name 355 356 arrays: t.List[exp.Condition] = [] 357 series_alias = new_name(taken_select_names, "pos") 358 series = exp.alias_( 359 exp.Unnest( 360 expressions=[exp.GenerateSeries(start=exp.Literal.number(index_offset))] 361 ), 362 new_name(taken_source_names, "_u"), 363 table=[series_alias], 364 ) 365 366 # we use list here because expression.selects is mutated inside the loop 367 for select in list(expression.selects): 368 explode = select.find(exp.Explode) 369 370 if explode: 371 pos_alias = "" 372 explode_alias = "" 373 374 if isinstance(select, exp.Alias): 375 explode_alias = select.args["alias"] 376 alias = select 377 elif isinstance(select, exp.Aliases): 378 pos_alias = select.aliases[0] 379 explode_alias = select.aliases[1] 380 alias = select.replace(exp.alias_(select.this, "", copy=False)) 381 else: 382 alias = select.replace(exp.alias_(select, "")) 
383 explode = alias.find(exp.Explode) 384 assert explode 385 386 is_posexplode = isinstance(explode, exp.Posexplode) 387 explode_arg = explode.this 388 389 if isinstance(explode, exp.ExplodeOuter): 390 bracket = explode_arg[0] 391 bracket.set("safe", True) 392 bracket.set("offset", True) 393 explode_arg = exp.func( 394 "IF", 395 exp.func( 396 "ARRAY_SIZE", exp.func("COALESCE", explode_arg, exp.Array()) 397 ).eq(0), 398 exp.array(bracket, copy=False), 399 explode_arg, 400 ) 401 402 # This ensures that we won't use [POS]EXPLODE's argument as a new selection 403 if isinstance(explode_arg, exp.Column): 404 taken_select_names.add(explode_arg.output_name) 405 406 unnest_source_alias = new_name(taken_source_names, "_u") 407 408 if not explode_alias: 409 explode_alias = new_name(taken_select_names, "col") 410 411 if is_posexplode: 412 pos_alias = new_name(taken_select_names, "pos") 413 414 if not pos_alias: 415 pos_alias = new_name(taken_select_names, "pos") 416 417 alias.set("alias", exp.to_identifier(explode_alias)) 418 419 series_table_alias = series.args["alias"].this 420 column = exp.If( 421 this=exp.column(series_alias, table=series_table_alias).eq( 422 exp.column(pos_alias, table=unnest_source_alias) 423 ), 424 true=exp.column(explode_alias, table=unnest_source_alias), 425 ) 426 427 explode.replace(column) 428 429 if is_posexplode: 430 expressions = expression.expressions 431 expressions.insert( 432 expressions.index(alias) + 1, 433 exp.If( 434 this=exp.column(series_alias, table=series_table_alias).eq( 435 exp.column(pos_alias, table=unnest_source_alias) 436 ), 437 true=exp.column(pos_alias, table=unnest_source_alias), 438 ).as_(pos_alias), 439 ) 440 expression.set("expressions", expressions) 441 442 if not arrays: 443 if expression.args.get("from"): 444 expression.join(series, copy=False, join_type="CROSS") 445 else: 446 expression.from_(series, copy=False) 447 448 size: exp.Condition = exp.ArraySize(this=explode_arg.copy()) 449 arrays.append(size) 450 451 # trino doesn't support left join unnest with on conditions 452 # if it did, this would be much simpler 453 expression.join( 454 exp.alias_( 455 exp.Unnest( 456 expressions=[explode_arg.copy()], 457 offset=exp.to_identifier(pos_alias), 458 ), 459 unnest_source_alias, 460 table=[explode_alias], 461 ), 462 join_type="CROSS", 463 copy=False, 464 ) 465 466 if index_offset != 1: 467 size = size - 1 468 469 expression.where( 470 exp.column(series_alias, table=series_table_alias) 471 .eq(exp.column(pos_alias, table=unnest_source_alias)) 472 .or_( 473 (exp.column(series_alias, table=series_table_alias) > size).and_( 474 exp.column(pos_alias, table=unnest_source_alias).eq(size) 475 ) 476 ), 477 copy=False, 478 ) 479 480 if arrays: 481 end: exp.Condition = exp.Greatest(this=arrays[0], expressions=arrays[1:]) 482 483 if index_offset != 1: 484 end = end - (1 - index_offset) 485 series.expressions[0].set("end", end) 486 487 return expression 488 489 return _explode_to_unnest 490 491 492def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 493 """Transforms percentiles by adding a WITHIN GROUP clause to them.""" 494 if ( 495 isinstance(expression, exp.PERCENTILES) 496 and not isinstance(expression.parent, exp.WithinGroup) 497 and expression.expression 498 ): 499 column = expression.this.pop() 500 expression.set("this", expression.expression.pop()) 501 order = exp.Order(expressions=[exp.Ordered(this=column)]) 502 expression = exp.WithinGroup(this=expression, expression=order) 503 504 return expression 505 506 507def 
remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression: 508 """Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.""" 509 if ( 510 isinstance(expression, exp.WithinGroup) 511 and isinstance(expression.this, exp.PERCENTILES) 512 and isinstance(expression.expression, exp.Order) 513 ): 514 quantile = expression.this.this 515 input_value = t.cast(exp.Ordered, expression.find(exp.Ordered)).this 516 return expression.replace(exp.ApproxQuantile(this=input_value, quantile=quantile)) 517 518 return expression 519 520 521def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression: 522 """Uses projection output names in recursive CTE definitions to define the CTEs' columns.""" 523 if isinstance(expression, exp.With) and expression.recursive: 524 next_name = name_sequence("_c_") 525 526 for cte in expression.expressions: 527 if not cte.args["alias"].columns: 528 query = cte.this 529 if isinstance(query, exp.SetOperation): 530 query = query.this 531 532 cte.args["alias"].set( 533 "columns", 534 [exp.to_identifier(s.alias_or_name or next_name()) for s in query.selects], 535 ) 536 537 return expression 538 539 540def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression: 541 """Replace 'epoch' in casts by the equivalent date literal.""" 542 if ( 543 isinstance(expression, (exp.Cast, exp.TryCast)) 544 and expression.name.lower() == "epoch" 545 and expression.to.this in exp.DataType.TEMPORAL_TYPES 546 ): 547 expression.this.replace(exp.Literal.string("1970-01-01 00:00:00")) 548 549 return expression 550 551 552def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression: 553 """Convert SEMI and ANTI joins into equivalent forms that use EXIST instead.""" 554 if isinstance(expression, exp.Select): 555 for join in expression.args.get("joins") or []: 556 on = join.args.get("on") 557 if on and join.kind in ("SEMI", "ANTI"): 558 subquery = exp.select("1").from_(join.this).where(on) 559 exists = exp.Exists(this=subquery) 560 if join.kind == "ANTI": 561 exists = exists.not_(copy=False) 562 563 join.pop() 564 expression.where(exists, copy=False) 565 566 return expression 567 568 569def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression: 570 """ 571 Converts a query with a FULL OUTER join to a union of identical queries that 572 use LEFT/RIGHT OUTER joins instead. This transformation currently only works 573 for queries that have a single FULL OUTER join. 574 """ 575 if isinstance(expression, exp.Select): 576 full_outer_joins = [ 577 (index, join) 578 for index, join in enumerate(expression.args.get("joins") or []) 579 if join.side == "FULL" 580 ] 581 582 if len(full_outer_joins) == 1: 583 expression_copy = expression.copy() 584 expression.set("limit", None) 585 index, full_outer_join = full_outer_joins[0] 586 full_outer_join.set("side", "left") 587 expression_copy.args["joins"][index].set("side", "right") 588 expression_copy.args.pop("with", None) # remove CTEs from RIGHT side 589 590 return exp.union(expression, expression_copy, copy=False) 591 592 return expression 593 594 595def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression: 596 """ 597 Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be 598 defined at the top-level, so for example queries like: 599 600 SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq 601 602 are invalid in those dialects. 
This transformation can be used to ensure all CTEs are 603 moved to the top level so that the final SQL code is valid from a syntax standpoint. 604 605 TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly). 606 """ 607 top_level_with = expression.args.get("with") 608 for inner_with in expression.find_all(exp.With): 609 if inner_with.parent is expression: 610 continue 611 612 if not top_level_with: 613 top_level_with = inner_with.pop() 614 expression.set("with", top_level_with) 615 else: 616 if inner_with.recursive: 617 top_level_with.set("recursive", True) 618 619 parent_cte = inner_with.find_ancestor(exp.CTE) 620 inner_with.pop() 621 622 if parent_cte: 623 i = top_level_with.expressions.index(parent_cte) 624 top_level_with.expressions[i:i] = inner_with.expressions 625 top_level_with.set("expressions", top_level_with.expressions) 626 else: 627 top_level_with.set( 628 "expressions", top_level_with.expressions + inner_with.expressions 629 ) 630 631 return expression 632 633 634def ensure_bools(expression: exp.Expression) -> exp.Expression: 635 """Converts numeric values used in conditions into explicit boolean expressions.""" 636 from sqlglot.optimizer.canonicalize import ensure_bools 637 638 def _ensure_bool(node: exp.Expression) -> None: 639 if ( 640 node.is_number 641 or ( 642 not isinstance(node, exp.SubqueryPredicate) 643 and node.is_type(exp.DataType.Type.UNKNOWN, *exp.DataType.NUMERIC_TYPES) 644 ) 645 or (isinstance(node, exp.Column) and not node.type) 646 ): 647 node.replace(node.neq(0)) 648 649 for node in expression.walk(): 650 ensure_bools(node, _ensure_bool) 651 652 return expression 653 654 655def unqualify_columns(expression: exp.Expression) -> exp.Expression: 656 for column in expression.find_all(exp.Column): 657 # We only wanna pop off the table, db, catalog args 658 for part in column.parts[:-1]: 659 part.pop() 660 661 return expression 662 663 664def remove_unique_constraints(expression: exp.Expression) -> exp.Expression: 665 assert isinstance(expression, exp.Create) 666 for constraint in expression.find_all(exp.UniqueColumnConstraint): 667 if constraint.parent: 668 constraint.parent.pop() 669 670 return expression 671 672 673def ctas_with_tmp_tables_to_create_tmp_view( 674 expression: exp.Expression, 675 tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e, 676) -> exp.Expression: 677 assert isinstance(expression, exp.Create) 678 properties = expression.args.get("properties") 679 temporary = any( 680 isinstance(prop, exp.TemporaryProperty) 681 for prop in (properties.expressions if properties else []) 682 ) 683 684 # CTAS with temp tables map to CREATE TEMPORARY VIEW 685 if expression.kind == "TABLE" and temporary: 686 if expression.expression: 687 return exp.Create( 688 kind="TEMPORARY VIEW", 689 this=expression.this, 690 expression=expression.expression, 691 ) 692 return tmp_storage_provider(expression) 693 694 return expression 695 696 697def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression: 698 """ 699 In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the 700 PARTITIONED BY value is an array of column names, they are transformed into a schema. 701 The corresponding columns are removed from the create statement. 
702 """ 703 assert isinstance(expression, exp.Create) 704 has_schema = isinstance(expression.this, exp.Schema) 705 is_partitionable = expression.kind in {"TABLE", "VIEW"} 706 707 if has_schema and is_partitionable: 708 prop = expression.find(exp.PartitionedByProperty) 709 if prop and prop.this and not isinstance(prop.this, exp.Schema): 710 schema = expression.this 711 columns = {v.name.upper() for v in prop.this.expressions} 712 partitions = [col for col in schema.expressions if col.name.upper() in columns] 713 schema.set("expressions", [e for e in schema.expressions if e not in partitions]) 714 prop.replace(exp.PartitionedByProperty(this=exp.Schema(expressions=partitions))) 715 expression.set("this", schema) 716 717 return expression 718 719 720def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression: 721 """ 722 Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE. 723 724 Currently, SQLGlot uses the DATASOURCE format for Spark 3. 725 """ 726 assert isinstance(expression, exp.Create) 727 prop = expression.find(exp.PartitionedByProperty) 728 if ( 729 prop 730 and prop.this 731 and isinstance(prop.this, exp.Schema) 732 and all(isinstance(e, exp.ColumnDef) and e.kind for e in prop.this.expressions) 733 ): 734 prop_this = exp.Tuple( 735 expressions=[exp.to_identifier(e.this) for e in prop.this.expressions] 736 ) 737 schema = expression.this 738 for e in prop.this.expressions: 739 schema.append("expressions", e) 740 prop.set("this", prop_this) 741 742 return expression 743 744 745def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression: 746 """Converts struct arguments to aliases, e.g. STRUCT(1 AS y).""" 747 if isinstance(expression, exp.Struct): 748 expression.set( 749 "expressions", 750 [ 751 exp.alias_(e.expression, e.this) if isinstance(e, exp.PropertyEQ) else e 752 for e in expression.expressions 753 ], 754 ) 755 756 return expression 757 758 759def eliminate_join_marks(expression: exp.Expression) -> exp.Expression: 760 """ 761 Remove join marks from an AST. This rule assumes that all marked columns are qualified. 762 If this does not hold for a query, consider running `sqlglot.optimizer.qualify` first. 763 764 For example, 765 SELECT * FROM a, b WHERE a.id = b.id(+) -- ... is converted to 766 SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this 767 768 Args: 769 expression: The AST to remove join marks from. 770 771 Returns: 772 The AST with join marks removed. 
773 """ 774 from sqlglot.optimizer.scope import traverse_scope 775 776 for scope in traverse_scope(expression): 777 query = scope.expression 778 779 where = query.args.get("where") 780 joins = query.args.get("joins") 781 782 if not where or not joins: 783 continue 784 785 query_from = query.args["from"] 786 787 # These keep track of the joins to be replaced 788 new_joins: t.Dict[str, exp.Join] = {} 789 old_joins = {join.alias_or_name: join for join in joins} 790 791 for column in scope.columns: 792 if not column.args.get("join_mark"): 793 continue 794 795 predicate = column.find_ancestor(exp.Predicate, exp.Select) 796 assert isinstance( 797 predicate, exp.Binary 798 ), "Columns can only be marked with (+) when involved in a binary operation" 799 800 predicate_parent = predicate.parent 801 join_predicate = predicate.pop() 802 803 left_columns = [ 804 c for c in join_predicate.left.find_all(exp.Column) if c.args.get("join_mark") 805 ] 806 right_columns = [ 807 c for c in join_predicate.right.find_all(exp.Column) if c.args.get("join_mark") 808 ] 809 810 assert not ( 811 left_columns and right_columns 812 ), "The (+) marker cannot appear in both sides of a binary predicate" 813 814 marked_column_tables = set() 815 for col in left_columns or right_columns: 816 table = col.table 817 assert table, f"Column {col} needs to be qualified with a table" 818 819 col.set("join_mark", False) 820 marked_column_tables.add(table) 821 822 assert ( 823 len(marked_column_tables) == 1 824 ), "Columns of only a single table can be marked with (+) in a given binary predicate" 825 826 join_this = old_joins.get(col.table, query_from).this 827 new_join = exp.Join(this=join_this, on=join_predicate, kind="LEFT") 828 829 # Upsert new_join into new_joins dictionary 830 new_join_alias_or_name = new_join.alias_or_name 831 existing_join = new_joins.get(new_join_alias_or_name) 832 if existing_join: 833 existing_join.set("on", exp.and_(existing_join.args.get("on"), new_join.args["on"])) 834 else: 835 new_joins[new_join_alias_or_name] = new_join 836 837 # If the parent of the target predicate is a binary node, then it now has only one child 838 if isinstance(predicate_parent, exp.Binary): 839 if predicate_parent.left is None: 840 predicate_parent.replace(predicate_parent.right) 841 else: 842 predicate_parent.replace(predicate_parent.left) 843 844 if query_from.alias_or_name in new_joins: 845 only_old_joins = old_joins.keys() - new_joins.keys() 846 assert ( 847 len(only_old_joins) >= 1 848 ), "Cannot determine which table to use in the new FROM clause" 849 850 new_from_name = list(only_old_joins)[0] 851 query.set("from", exp.From(this=old_joins[new_from_name].this)) 852 853 query.set("joins", list(new_joins.values())) 854 855 if not where.this: 856 where.pop() 857 858 return expression
def preprocess(
    transforms: t.List[t.Callable[[exp.Expression], exp.Expression]],
) -> t.Callable[[Generator, exp.Expression], str]:
Creates a new transform by chaining a sequence of transformations and converts the resulting expression to SQL, using either the "_sql" method corresponding to the resulting expression, or the appropriate Generator.TRANSFORMS function (when applicable -- see below).
Arguments:
- transforms: sequence of transform functions. These will be called in order.
Returns:
Function that can be used as a generator transform.
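The returned function is typically used as a value in a Generator's TRANSFORMS mapping. A minimal sketch, assuming a hypothetical dialect subclass (the base dialect and the chained transforms chosen here are purely illustrative):

    import sqlglot
    from sqlglot import exp, transforms
    from sqlglot.dialects.duckdb import DuckDB

    class MyDuckDB(DuckDB):
        class Generator(DuckDB.Generator):
            TRANSFORMS = {
                **DuckDB.Generator.TRANSFORMS,
                # Rewrite every SELECT with the chained transforms before it is generated.
                exp.Select: transforms.preprocess(
                    [
                        transforms.eliminate_distinct_on,
                        transforms.eliminate_semi_and_anti_joins,
                    ]
                ),
            }

    # DISTINCT ON is rewritten to a ROW_NUMBER() subquery on the way out.
    print(sqlglot.transpile("SELECT DISTINCT ON (a) a, b FROM t ORDER BY a", read="postgres", write=MyDuckDB)[0])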
def unnest_generate_date_array_using_recursive_cte(expression: exp.Expression) -> exp.Expression:
def unnest_generate_series(expression: exp.Expression) -> exp.Expression:
Unnests GENERATE_SERIES or SEQUENCE table references.
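A small sketch of applying the transform through sqlglot's public API (the alias name is made up, and the exact output formatting may differ):

    import sqlglot
    from sqlglot import transforms

    sql = "SELECT * FROM GENERATE_SERIES(1, 3) AS dates"
    # The table function is wrapped in an UNNEST, keeping the original alias as the column name,
    # roughly: SELECT * FROM UNNEST(GENERATE_SERIES(1, 3)) AS _u(dates)
    print(sqlglot.parse_one(sql).transform(transforms.unnest_generate_series).sql())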
def unalias_group(expression: exp.Expression) -> exp.Expression:
Replace references to select aliases in GROUP BY clauses.
Example:
>>> import sqlglot
>>> sqlglot.parse_one("SELECT a AS b FROM x GROUP BY b").transform(unalias_group).sql()
'SELECT a AS b FROM x GROUP BY 1'
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
def eliminate_distinct_on(expression: exp.Expression) -> exp.Expression:
Convert SELECT DISTINCT ON statements to a subquery with a window function.
This is useful for dialects that don't support SELECT DISTINCT ON but support window functions.
Arguments:
- expression: the expression that will be transformed.
Returns:
The transformed expression.
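For illustration, applying the transform directly to a made-up query (the exact output formatting may differ):

    import sqlglot
    from sqlglot import transforms

    sql = "SELECT DISTINCT ON (a) a, b FROM t ORDER BY a, b DESC"
    # Roughly: SELECT a, b FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY a, b DESC)
    #          AS _row_number FROM t) AS _t WHERE _row_number = 1
    print(sqlglot.parse_one(sql, read="postgres").transform(transforms.eliminate_distinct_on).sql())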
def eliminate_qualify(expression: exp.Expression) -> exp.Expression:
Convert SELECT statements that contain the QUALIFY clause into subqueries, filtered equivalently.
The idea behind this transformation can be seen in Snowflake's documentation for QUALIFY: https://docs.snowflake.com/en/sql-reference/constructs/qualify
Some dialects don't support window functions in the WHERE clause, so we need to include them as projections in the subquery, in order to refer to them in the outer filter using aliases. Also, if a column is referenced in the QUALIFY clause but is not selected, we need to include it too, otherwise we won't be able to refer to it in the outer query's WHERE clause. Finally, if a newly aliased projection is referenced in the QUALIFY clause, it will be replaced by the corresponding expression to avoid creating invalid column references.
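A minimal sketch of the rewrite, using a made-up query (the generated aliases and formatting may differ):

    import sqlglot
    from sqlglot import transforms

    sql = "SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS rn FROM t QUALIFY rn = 1"
    # Roughly: SELECT a, b, rn FROM (SELECT a, b, ROW_NUMBER() OVER (PARTITION BY a ORDER BY b) AS rn FROM t) AS _t WHERE rn = 1
    print(sqlglot.parse_one(sql, read="snowflake").transform(transforms.eliminate_qualify).sql())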
def remove_precision_parameterized_types(expression: exp.Expression) -> exp.Expression:
Some dialects only allow the precision for parameterized types to be defined in the DDL and not in other expressions. This transform removes the precision from parameterized types in expressions.
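For example (the table and column names are made up):

    import sqlglot
    from sqlglot import transforms

    sql = "SELECT CAST(price AS DECIMAL(10, 2)) FROM orders"
    # Roughly: SELECT CAST(price AS DECIMAL) FROM orders
    print(sqlglot.parse_one(sql).transform(transforms.remove_precision_parameterized_types).sql())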
def unqualify_unnest(expression: exp.Expression) -> exp.Expression:
Remove references to unnest table aliases, added by the optimizer's qualify_columns step.
def unnest_to_explode(expression: exp.Expression) -> exp.Expression:
Convert cross join unnest into lateral view explode.
def explode_to_unnest(index_offset: int = 0) -> t.Callable[[exp.Expression], exp.Expression]:
Convert explode/posexplode into unnest.
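Note that this is a factory: calling it with an index offset returns the actual transform. A sketch, assuming a Spark-style input and a Trino/Presto-style target whose arrays are 1-indexed (the result is a fairly involved CROSS JOIN UNNEST query, so it is not reproduced here):

    import sqlglot
    from sqlglot import transforms

    sql = "SELECT EXPLODE(numbers) FROM tbl"
    rewritten = sqlglot.parse_one(sql, read="spark").transform(transforms.explode_to_unnest(1))
    print(rewritten.sql(dialect="presto"))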
def add_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
Transforms percentiles by adding a WITHIN GROUP clause to them.
def remove_within_group_for_percentiles(expression: exp.Expression) -> exp.Expression:
Transforms percentiles by getting rid of their corresponding WITHIN GROUP clause.
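For example (the column and table names are made up; the percentile becomes an APPROX_QUANTILE call):

    import sqlglot
    from sqlglot import transforms

    sql = "SELECT PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t"
    # Roughly: SELECT APPROX_QUANTILE(x, 0.5) FROM t
    print(sqlglot.parse_one(sql).transform(transforms.remove_within_group_for_percentiles).sql())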
def add_recursive_cte_column_names(expression: exp.Expression) -> exp.Expression:
Uses projection output names in recursive CTE definitions to define the CTEs' columns.
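A small sketch (the CTE and column names are made up):

    import sqlglot
    from sqlglot import transforms

    sql = "WITH RECURSIVE t AS (SELECT 1 AS c UNION ALL SELECT c + 1 FROM t WHERE c < 3) SELECT * FROM t"
    # Roughly: WITH RECURSIVE t(c) AS (...) SELECT * FROM t
    print(sqlglot.parse_one(sql).transform(transforms.add_recursive_cte_column_names).sql())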
def epoch_cast_to_ts(expression: exp.Expression) -> exp.Expression:
Replace 'epoch' in casts with the equivalent '1970-01-01 00:00:00' timestamp literal.
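For example:

    import sqlglot
    from sqlglot import transforms

    sql = "SELECT CAST('epoch' AS TIMESTAMP) AS ts"
    # Roughly: SELECT CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS ts
    print(sqlglot.parse_one(sql).transform(transforms.epoch_cast_to_ts).sql())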
def eliminate_semi_and_anti_joins(expression: exp.Expression) -> exp.Expression:
Convert SEMI and ANTI joins into equivalent forms that use EXISTS instead.
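A sketch using a Spark-style LEFT SEMI JOIN (table and column names are made up):

    import sqlglot
    from sqlglot import transforms

    sql = "SELECT * FROM a LEFT SEMI JOIN b ON a.id = b.id"
    # Roughly: SELECT * FROM a WHERE EXISTS(SELECT 1 FROM b WHERE a.id = b.id)
    print(sqlglot.parse_one(sql, read="spark").transform(transforms.eliminate_semi_and_anti_joins).sql())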
def eliminate_full_outer_join(expression: exp.Expression) -> exp.Expression:
Converts a query with a FULL OUTER join to a union of identical queries that use LEFT/RIGHT OUTER joins instead. This transformation currently only works for queries that have a single FULL OUTER join.
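For example (exact formatting of the output may differ):

    import sqlglot
    from sqlglot import transforms

    sql = "SELECT * FROM a FULL OUTER JOIN b ON a.id = b.id"
    # Roughly: SELECT * FROM a LEFT OUTER JOIN b ON a.id = b.id
    #          UNION SELECT * FROM a RIGHT OUTER JOIN b ON a.id = b.id
    print(sqlglot.parse_one(sql).transform(transforms.eliminate_full_outer_join).sql())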
def move_ctes_to_top_level(expression: exp.Expression) -> exp.Expression:
Some dialects (e.g. Hive, T-SQL, Spark prior to version 3) only allow CTEs to be defined at the top-level, so for example queries like:
SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT * FROM t) AS subq
are invalid in those dialects. This transformation can be used to ensure all CTEs are moved to the top level so that the final SQL code is valid from a syntax standpoint.
TODO: handle name clashes whilst moving CTEs (it can get quite tricky & costly).
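A sketch of the rewrite, applying the transform directly to the parsed statement:

    import sqlglot
    from sqlglot import transforms

    sql = "SELECT * FROM (WITH t(c) AS (SELECT 1) SELECT c FROM t) AS subq"
    # Roughly: WITH t(c) AS (SELECT 1) SELECT * FROM (SELECT c FROM t) AS subq
    print(transforms.move_ctes_to_top_level(sqlglot.parse_one(sql)).sql())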
def ensure_bools(expression: exp.Expression) -> exp.Expression:
Converts numeric values used in conditions into explicit boolean expressions.
def ctas_with_tmp_tables_to_create_tmp_view(
    expression: exp.Expression,
    tmp_storage_provider: t.Callable[[exp.Expression], exp.Expression] = lambda e: e,
) -> exp.Expression:
def move_schema_columns_to_partitioned_by(expression: exp.Expression) -> exp.Expression:
In Hive, the PARTITIONED BY property acts as an extension of a table's schema. When the PARTITIONED BY value is an array of column names, they are transformed into a schema. The corresponding columns are removed from the create statement.
def move_partitioned_by_to_schema_columns(expression: exp.Expression) -> exp.Expression:
Spark 3 supports both "HIVEFORMAT" and "DATASOURCE" formats for CREATE TABLE.
Currently, SQLGlot uses the DATASOURCE format for Spark 3.
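A sketch, assuming Hive-style DDL as input. The function expects the root CREATE expression itself (it asserts on the type), so it is applied directly rather than through Expression.transform:

    import sqlglot
    from sqlglot import transforms

    ddl = sqlglot.parse_one("CREATE TABLE t (a INT) PARTITIONED BY (b INT)", read="hive")
    # Roughly: CREATE TABLE t (a INT, b INT) PARTITIONED BY (b)
    print(transforms.move_partitioned_by_to_schema_columns(ddl).sql(dialect="spark"))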
def struct_kv_to_alias(expression: exp.Expression) -> exp.Expression:
Converts struct arguments to aliases, e.g. STRUCT(1 AS y).
def eliminate_join_marks(expression: exp.Expression) -> exp.Expression:
Remove join marks from an AST. This rule assumes that all marked columns are qualified.
If this does not hold for a query, consider running sqlglot.optimizer.qualify first.
For example,
    SELECT * FROM a, b WHERE a.id = b.id(+)    -- ... is converted to
    SELECT * FROM a LEFT JOIN b ON a.id = b.id -- this
Arguments:
- expression: The AST to remove join marks from.
Returns:
The AST with join marks removed.
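A sketch reproducing the example above; the function walks query scopes itself, so it is applied directly to the parsed statement:

    import sqlglot
    from sqlglot import transforms

    sql = "SELECT * FROM a, b WHERE a.id = b.id(+)"
    # Roughly: SELECT * FROM a LEFT JOIN b ON a.id = b.id
    print(transforms.eliminate_join_marks(sqlglot.parse_one(sql, read="oracle")).sql(dialect="oracle"))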