sqlglot.dialects.dialect
from __future__ import annotations

import logging
import typing as t
from enum import Enum, auto
from functools import reduce

from sqlglot import exp
from sqlglot.errors import ParseError
from sqlglot.generator import Generator
from sqlglot.helper import AutoName, flatten, is_int, seq_get, subclasses
from sqlglot.jsonpath import JSONPathTokenizer, parse as parse_json_path
from sqlglot.parser import Parser
from sqlglot.time import TIMEZONES, format_time, subsecond_precision
from sqlglot.tokens import Token, Tokenizer, TokenType
from sqlglot.trie import new_trie

DATE_ADD_OR_DIFF = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateDiff, exp.TsOrDsDiff]
DATE_ADD_OR_SUB = t.Union[exp.DateAdd, exp.TsOrDsAdd, exp.DateSub]
JSON_EXTRACT_TYPE = t.Union[exp.JSONExtract, exp.JSONExtractScalar]


if t.TYPE_CHECKING:
    from sqlglot._typing import B, E, F

    from sqlglot.optimizer.annotate_types import TypeAnnotator

    AnnotatorsType = t.Dict[t.Type[E], t.Callable[[TypeAnnotator, E], E]]

logger = logging.getLogger("sqlglot")

UNESCAPED_SEQUENCES = {
    "\\a": "\a",
    "\\b": "\b",
    "\\f": "\f",
    "\\n": "\n",
    "\\r": "\r",
    "\\t": "\t",
    "\\v": "\v",
    "\\\\": "\\",
}


def _annotate_with_type_lambda(data_type: exp.DataType.Type) -> t.Callable[[TypeAnnotator, E], E]:
    return lambda self, e: self._annotate_with_type(e, data_type)


class Dialects(str, Enum):
    """Dialects supported by SQLGlot."""

    DIALECT = ""

    ATHENA = "athena"
    BIGQUERY = "bigquery"
    CLICKHOUSE = "clickhouse"
    DATABRICKS = "databricks"
    DORIS = "doris"
    DRILL = "drill"
    DUCKDB = "duckdb"
    HIVE = "hive"
    MATERIALIZE = "materialize"
    MYSQL = "mysql"
    ORACLE = "oracle"
    POSTGRES = "postgres"
    PRESTO = "presto"
    PRQL = "prql"
    REDSHIFT = "redshift"
    RISINGWAVE = "risingwave"
    SNOWFLAKE = "snowflake"
    SPARK = "spark"
    SPARK2 = "spark2"
    SQLITE = "sqlite"
    STARROCKS = "starrocks"
    TABLEAU = "tableau"
    TERADATA = "teradata"
    TRINO = "trino"
    TSQL = "tsql"


class NormalizationStrategy(str, AutoName):
    """Specifies the strategy according to which identifiers should be normalized."""

    LOWERCASE = auto()
    """Unquoted identifiers are lowercased."""

    UPPERCASE = auto()
    """Unquoted identifiers are uppercased."""

    CASE_SENSITIVE = auto()
    """Always case-sensitive, regardless of quotes."""

    CASE_INSENSITIVE = auto()
    """Always case-insensitive, regardless of quotes."""


class _Dialect(type):
    classes: t.Dict[str, t.Type[Dialect]] = {}

    def __eq__(cls, other: t.Any) -> bool:
        if cls is other:
            return True
        if isinstance(other, str):
            return cls is cls.get(other)
        if isinstance(other, Dialect):
            return cls is type(other)

        return False

    def __hash__(cls) -> int:
        return hash(cls.__name__.lower())

    @classmethod
    def __getitem__(cls, key: str) -> t.Type[Dialect]:
        return cls.classes[key]

    @classmethod
    def get(
        cls, key: str, default: t.Optional[t.Type[Dialect]] = None
    ) -> t.Optional[t.Type[Dialect]]:
        return cls.classes.get(key, default)

    def __new__(cls, clsname, bases, attrs):
        klass = super().__new__(cls, clsname, bases, attrs)
        enum = Dialects.__members__.get(clsname.upper())
        cls.classes[enum.value if enum is not None else clsname.lower()] = klass

        klass.TIME_TRIE = new_trie(klass.TIME_MAPPING)
        klass.FORMAT_TRIE = (
            new_trie(klass.FORMAT_MAPPING) if klass.FORMAT_MAPPING else klass.TIME_TRIE
        )
        klass.INVERSE_TIME_MAPPING = {v: k for k, v in klass.TIME_MAPPING.items()}
        klass.INVERSE_TIME_TRIE = new_trie(klass.INVERSE_TIME_MAPPING)
        klass.INVERSE_FORMAT_MAPPING = {v: k for k, v in klass.FORMAT_MAPPING.items()}
        klass.INVERSE_FORMAT_TRIE = new_trie(klass.INVERSE_FORMAT_MAPPING)

        klass.INVERSE_CREATABLE_KIND_MAPPING = {
            v: k for k, v in klass.CREATABLE_KIND_MAPPING.items()
        }

        base = seq_get(bases, 0)
        base_tokenizer = (getattr(base, "tokenizer_class", Tokenizer),)
        base_jsonpath_tokenizer = (getattr(base, "jsonpath_tokenizer_class", JSONPathTokenizer),)
        base_parser = (getattr(base, "parser_class", Parser),)
        base_generator = (getattr(base, "generator_class", Generator),)

        klass.tokenizer_class = klass.__dict__.get(
            "Tokenizer", type("Tokenizer", base_tokenizer, {})
        )
        klass.jsonpath_tokenizer_class = klass.__dict__.get(
            "JSONPathTokenizer", type("JSONPathTokenizer", base_jsonpath_tokenizer, {})
        )
        klass.parser_class = klass.__dict__.get("Parser", type("Parser", base_parser, {}))
        klass.generator_class = klass.__dict__.get(
            "Generator", type("Generator", base_generator, {})
        )

        klass.QUOTE_START, klass.QUOTE_END = list(klass.tokenizer_class._QUOTES.items())[0]
        klass.IDENTIFIER_START, klass.IDENTIFIER_END = list(
            klass.tokenizer_class._IDENTIFIERS.items()
        )[0]

        def get_start_end(token_type: TokenType) -> t.Tuple[t.Optional[str], t.Optional[str]]:
            return next(
                (
                    (s, e)
                    for s, (e, t) in klass.tokenizer_class._FORMAT_STRINGS.items()
                    if t == token_type
                ),
                (None, None),
            )

        klass.BIT_START, klass.BIT_END = get_start_end(TokenType.BIT_STRING)
        klass.HEX_START, klass.HEX_END = get_start_end(TokenType.HEX_STRING)
        klass.BYTE_START, klass.BYTE_END = get_start_end(TokenType.BYTE_STRING)
        klass.UNICODE_START, klass.UNICODE_END = get_start_end(TokenType.UNICODE_STRING)

        if "\\" in klass.tokenizer_class.STRING_ESCAPES:
            klass.UNESCAPED_SEQUENCES = {
                **UNESCAPED_SEQUENCES,
                **klass.UNESCAPED_SEQUENCES,
            }

        klass.ESCAPED_SEQUENCES = {v: k for k, v in klass.UNESCAPED_SEQUENCES.items()}

        klass.SUPPORTS_COLUMN_JOIN_MARKS = "(+)" in klass.tokenizer_class.KEYWORDS

        if enum not in ("", "bigquery"):
            klass.generator_class.SELECT_KINDS = ()

        if enum not in ("", "clickhouse"):
            klass.generator_class.SUPPORTS_NULLABLE_TYPES = False

        if enum not in ("", "athena", "presto", "trino"):
            klass.generator_class.TRY_SUPPORTED = False
            klass.generator_class.SUPPORTS_UESCAPE = False

        if enum not in ("", "databricks", "hive", "spark", "spark2"):
            modifier_transforms = klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS.copy()
            for modifier in ("cluster", "distribute", "sort"):
                modifier_transforms.pop(modifier, None)

            klass.generator_class.AFTER_HAVING_MODIFIER_TRANSFORMS = modifier_transforms

        if enum not in ("", "doris", "mysql"):
            klass.parser_class.ID_VAR_TOKENS = klass.parser_class.ID_VAR_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.STRAIGHT_JOIN,
            }

        if not klass.SUPPORTS_SEMI_ANTI_JOIN:
            klass.parser_class.TABLE_ALIAS_TOKENS = klass.parser_class.TABLE_ALIAS_TOKENS | {
                TokenType.ANTI,
                TokenType.SEMI,
            }

        return klass
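
# Usage sketch: the metaclass keeps a string-keyed registry of dialect classes,
# and makes them comparable to their names. This assumes the bundled dialect
# modules have been imported, which is what registers them:
#
#     >>> from sqlglot.dialects.duckdb import DuckDB
#     >>> Dialect.get("duckdb") is DuckDB
#     True
#     >>> DuckDB == "duckdb"
#     True
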
class Dialect(metaclass=_Dialect):
    INDEX_OFFSET = 0
    """The base index offset for arrays."""

    WEEK_OFFSET = 0
    """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday."""

    UNNEST_COLUMN_ONLY = False
    """Whether `UNNEST` table aliases are treated as column aliases."""

    ALIAS_POST_TABLESAMPLE = False
    """Whether the table alias comes after tablesample."""

    TABLESAMPLE_SIZE_IS_PERCENT = False
    """Whether a size in the table sample clause represents percentage."""

    NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE
    """Specifies the strategy according to which identifiers should be normalized."""

    IDENTIFIERS_CAN_START_WITH_DIGIT = False
    """Whether an unquoted identifier can start with a digit."""

    DPIPE_IS_STRING_CONCAT = True
    """Whether the DPIPE token (`||`) is a string concatenation operator."""

    STRICT_STRING_CONCAT = False
    """Whether `CONCAT`'s arguments must be strings."""

    SUPPORTS_USER_DEFINED_TYPES = True
    """Whether user-defined data types are supported."""

    SUPPORTS_SEMI_ANTI_JOIN = True
    """Whether `SEMI` or `ANTI` joins are supported."""

    SUPPORTS_COLUMN_JOIN_MARKS = False
    """Whether the old-style outer join (+) syntax is supported."""

    COPY_PARAMS_ARE_CSV = True
    """Whether COPY statement parameters are separated by comma or whitespace."""

    NORMALIZE_FUNCTIONS: bool | str = "upper"
    """
    Determines how function names are going to be normalized.

    Possible values:
        "upper" or True: Convert names to uppercase.
        "lower": Convert names to lowercase.
        False: Disables function name normalization.
    """

    LOG_BASE_FIRST: t.Optional[bool] = True
    """
    Whether the base comes first in the `LOG` function.
    Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`)
    """

    NULL_ORDERING = "nulls_are_small"
    """
    Default `NULL` ordering method to use if not explicitly set.
    Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`
    """

    TYPED_DIVISION = False
    """
    Whether the behavior of `a / b` depends on the types of `a` and `b`.
    False means `a / b` is always float division.
    True means `a / b` is integer division if both `a` and `b` are integers.
    """

    SAFE_DIVISION = False
    """Whether division by zero throws an error (`False`) or returns NULL (`True`)."""

    CONCAT_COALESCE = False
    """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string."""

    HEX_LOWERCASE = False
    """Whether the `HEX` function returns a lowercase hexadecimal string."""

    DATE_FORMAT = "'%Y-%m-%d'"
    DATEINT_FORMAT = "'%Y%m%d'"
    TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'"

    TIME_MAPPING: t.Dict[str, str] = {}
    """Associates this dialect's time formats with their equivalent Python `strftime` formats."""

    # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time
    # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE
    FORMAT_MAPPING: t.Dict[str, str] = {}
    """
    Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`.
    If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
    """

    UNESCAPED_SEQUENCES: t.Dict[str, str] = {}
    """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`)."""

    PSEUDOCOLUMNS: t.Set[str] = set()
    """
    Columns that are auto-generated by the engine corresponding to this dialect.
    For example, such columns may be excluded from `SELECT *` queries.
    """

    PREFER_CTE_ALIAS_COLUMN = False
    """
    Some dialects, such as Snowflake, allow you to reference a CTE column alias in the
    HAVING clause of the CTE. This flag will cause the CTE alias columns to override
    any projection aliases in the subquery.

    For example,
        WITH y(c) AS (
            SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
        ) SELECT c FROM y;

    will be rewritten as

        WITH y(c) AS (
            SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
        ) SELECT c FROM y;
    """

    FORCE_EARLY_ALIAS_REF_EXPANSION = False
    """
    Whether alias reference expansion (_expand_alias_refs()) should run before column
    qualification (_qualify_columns()).

    For example:
        WITH data AS (
            SELECT
                1 AS id,
                2 AS my_id
        )
        SELECT
            id AS my_id
        FROM
            data
        WHERE
            my_id = 1
        GROUP BY
            my_id
        HAVING
            my_id = 1

    In most dialects, "my_id" would refer to "data.my_id" (which is done in
    _qualify_columns()) across the query, except:
        - BigQuery, which will forward the alias to GROUP BY + HAVING clauses, i.e. it
          resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
        - Clickhouse, which will forward the alias across the query, i.e. it resolves
          to "WHERE id = 1 GROUP BY id HAVING id = 1"
    """

    EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False
    """Whether alias reference expansion before qualification should only happen for the GROUP BY clause."""

    SUPPORTS_ORDER_BY_ALL = False
    """
    Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
    """

    HAS_DISTINCT_ARRAY_CONSTRUCTORS = False
    """
    Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3),
    as the former is of type INT[] vs the latter which is SUPER
    """

    SUPPORTS_FIXED_SIZE_ARRAYS = False
    """
    Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts, e.g. in DuckDB.
    In dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a
    subscript/index operator.
    """

    CREATABLE_KIND_MAPPING: dict[str, str] = {}
    """
    Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse
    equivalent of CREATE SCHEMA is CREATE DATABASE.
    """

    # --- Autofilled ---

    tokenizer_class = Tokenizer
    jsonpath_tokenizer_class = JSONPathTokenizer
    parser_class = Parser
    generator_class = Generator

    # A trie of the time_mapping keys
    TIME_TRIE: t.Dict = {}
    FORMAT_TRIE: t.Dict = {}

    INVERSE_TIME_MAPPING: t.Dict[str, str] = {}
    INVERSE_TIME_TRIE: t.Dict = {}
    INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {}
    INVERSE_FORMAT_TRIE: t.Dict = {}

    INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {}

    ESCAPED_SEQUENCES: t.Dict[str, str] = {}

    # Delimiters for string literals and identifiers
    QUOTE_START = "'"
    QUOTE_END = "'"
    IDENTIFIER_START = '"'
    IDENTIFIER_END = '"'

    # Delimiters for bit, hex, byte and unicode literals
    BIT_START: t.Optional[str] = None
    BIT_END: t.Optional[str] = None
    HEX_START: t.Optional[str] = None
    HEX_END: t.Optional[str] = None
    BYTE_START: t.Optional[str] = None
    BYTE_END: t.Optional[str] = None
    UNICODE_START: t.Optional[str] = None
    UNICODE_END: t.Optional[str] = None

    DATE_PART_MAPPING = {
        "Y": "YEAR",
        "YY": "YEAR",
        "YYY": "YEAR",
        "YYYY": "YEAR",
        "YR": "YEAR",
        "YEARS": "YEAR",
        "YRS": "YEAR",
        "MM": "MONTH",
        "MON": "MONTH",
        "MONS": "MONTH",
        "MONTHS": "MONTH",
        "D": "DAY",
        "DD": "DAY",
        "DAYS": "DAY",
        "DAYOFMONTH": "DAY",
        "DAY OF WEEK": "DAYOFWEEK",
        "WEEKDAY": "DAYOFWEEK",
        "DOW": "DAYOFWEEK",
        "DW": "DAYOFWEEK",
        "WEEKDAY_ISO": "DAYOFWEEKISO",
        "DOW_ISO": "DAYOFWEEKISO",
        "DW_ISO": "DAYOFWEEKISO",
        "DAY OF YEAR": "DAYOFYEAR",
        "DOY": "DAYOFYEAR",
        "DY": "DAYOFYEAR",
        "W": "WEEK",
        "WK": "WEEK",
        "WEEKOFYEAR": "WEEK",
        "WOY": "WEEK",
        "WY": "WEEK",
        "WEEK_ISO": "WEEKISO",
        "WEEKOFYEARISO": "WEEKISO",
        "WEEKOFYEAR_ISO": "WEEKISO",
        "Q": "QUARTER",
        "QTR": "QUARTER",
        "QTRS": "QUARTER",
        "QUARTERS": "QUARTER",
        "H": "HOUR",
        "HH": "HOUR",
        "HR": "HOUR",
        "HOURS": "HOUR",
        "HRS": "HOUR",
        "M": "MINUTE",
        "MI": "MINUTE",
        "MIN": "MINUTE",
        "MINUTES": "MINUTE",
        "MINS": "MINUTE",
        "S": "SECOND",
        "SEC": "SECOND",
        "SECONDS": "SECOND",
        "SECS": "SECOND",
        "MS": "MILLISECOND",
        "MSEC": "MILLISECOND",
        "MSECS": "MILLISECOND",
        "MSECOND": "MILLISECOND",
        "MSECONDS": "MILLISECOND",
        "MILLISEC": "MILLISECOND",
        "MILLISECS": "MILLISECOND",
        "MILLISECON": "MILLISECOND",
        "MILLISECONDS": "MILLISECOND",
        "US": "MICROSECOND",
        "USEC": "MICROSECOND",
        "USECS": "MICROSECOND",
        "MICROSEC": "MICROSECOND",
        "MICROSECS": "MICROSECOND",
        "USECOND": "MICROSECOND",
        "USECONDS": "MICROSECOND",
        "MICROSECONDS": "MICROSECOND",
        "NS": "NANOSECOND",
        "NSEC": "NANOSECOND",
        "NANOSEC": "NANOSECOND",
        "NSECOND": "NANOSECOND",
        "NSECONDS": "NANOSECOND",
        "NANOSECS": "NANOSECOND",
        "EPOCH_SECOND": "EPOCH",
        "EPOCH_SECONDS": "EPOCH",
        "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND",
        "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND",
        "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND",
        "TZH": "TIMEZONE_HOUR",
        "TZM": "TIMEZONE_MINUTE",
        "DEC": "DECADE",
        "DECS": "DECADE",
        "DECADES": "DECADE",
        "MIL": "MILLENIUM",
        "MILS": "MILLENIUM",
        "MILLENIA": "MILLENIUM",
        "C": "CENTURY",
        "CENT": "CENTURY",
        "CENTS": "CENTURY",
        "CENTURIES": "CENTURY",
    }

    TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = {
        exp.DataType.Type.BIGINT: {
            exp.ApproxDistinct,
            exp.ArraySize,
            exp.Count,
            exp.Length,
        },
        exp.DataType.Type.BOOLEAN: {
            exp.Between,
            exp.Boolean,
            exp.In,
            exp.RegexpLike,
        },
        exp.DataType.Type.DATE: {
            exp.CurrentDate,
            exp.Date,
            exp.DateFromParts,
            exp.DateStrToDate,
            exp.DiToDate,
            exp.StrToDate,
            exp.TimeStrToDate,
            exp.TsOrDsToDate,
        },
        exp.DataType.Type.DATETIME: {
            exp.CurrentDatetime,
            exp.Datetime,
            exp.DatetimeAdd,
            exp.DatetimeSub,
        },
        exp.DataType.Type.DOUBLE: {
            exp.ApproxQuantile,
            exp.Avg,
            exp.Div,
            exp.Exp,
            exp.Ln,
            exp.Log,
            exp.Pow,
            exp.Quantile,
            exp.Round,
            exp.SafeDivide,
            exp.Sqrt,
            exp.Stddev,
            exp.StddevPop,
            exp.StddevSamp,
            exp.Variance,
            exp.VariancePop,
        },
        exp.DataType.Type.INT: {
            exp.Ceil,
            exp.DatetimeDiff,
            exp.DateDiff,
            exp.TimestampDiff,
            exp.TimeDiff,
            exp.DateToDi,
            exp.Levenshtein,
            exp.Sign,
            exp.StrPosition,
            exp.TsOrDiToDi,
        },
        exp.DataType.Type.JSON: {
            exp.ParseJSON,
        },
        exp.DataType.Type.TIME: {
            exp.Time,
        },
        exp.DataType.Type.TIMESTAMP: {
            exp.CurrentTime,
            exp.CurrentTimestamp,
            exp.StrToTime,
            exp.TimeAdd,
            exp.TimeStrToTime,
            exp.TimeSub,
            exp.TimestampAdd,
            exp.TimestampSub,
            exp.UnixToTime,
        },
        exp.DataType.Type.TINYINT: {
            exp.Day,
            exp.Month,
            exp.Week,
            exp.Year,
            exp.Quarter,
        },
        exp.DataType.Type.VARCHAR: {
            exp.ArrayConcat,
            exp.Concat,
            exp.ConcatWs,
            exp.DateToDateStr,
            exp.GroupConcat,
            exp.Initcap,
            exp.Lower,
            exp.Substring,
            exp.TimeToStr,
            exp.TimeToTimeStr,
            exp.Trim,
            exp.TsOrDsToDateStr,
            exp.UnixToStr,
            exp.UnixToTimeStr,
            exp.Upper,
        },
    }

    ANNOTATORS: AnnotatorsType = {
        **{
            expr_type: lambda self, e: self._annotate_unary(e)
            for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias))
        },
        **{
            expr_type: lambda self, e: self._annotate_binary(e)
            for expr_type in subclasses(exp.__name__, exp.Binary)
        },
        **{
            expr_type: _annotate_with_type_lambda(data_type)
            for data_type, expressions in TYPE_TO_EXPRESSIONS.items()
            for expr_type in expressions
        },
        exp.Abs: lambda self, e: self._annotate_by_args(e, "this"),
        exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True),
        exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True),
        exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Bracket: lambda self, e: self._annotate_bracket(e),
        exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"),
        exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()),
        exp.DateAdd: lambda self, e: self._annotate_timeunit(e),
        exp.DateSub: lambda self, e: self._annotate_timeunit(e),
        exp.DateTrunc: lambda self, e: self._annotate_timeunit(e),
        exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Div: lambda self, e: self._annotate_div(e),
        exp.Dot: lambda self, e: self._annotate_dot(e),
        exp.Explode: lambda self, e: self._annotate_explode(e),
        exp.Extract: lambda self, e: self._annotate_extract(e),
        exp.Filter: lambda self, e: self._annotate_by_args(e, "this"),
        exp.GenerateDateArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<DATE>")
        ),
        exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type(
            e, exp.DataType.build("ARRAY<TIMESTAMP>")
        ),
        exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"),
        exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL),
        exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"),
        exp.Literal: lambda self, e: self._annotate_literal(e),
        exp.Map: lambda self, e: self._annotate_map(e),
        exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"),
        exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL),
        exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"),
        exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"),
        exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN),
        exp.Struct: lambda self, e: self._annotate_struct(e),
        exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True),
        exp.Timestamp: lambda self, e: self._annotate_with_type(
            e,
            exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP,
        ),
        exp.ToMap: lambda self, e: self._annotate_to_map(e),
        exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]),
        exp.Unnest: lambda self, e: self._annotate_unnest(e),
        exp.VarMap: lambda self, e: self._annotate_map(e),
    }
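
    # Usage sketch: ANNOTATORS drives sqlglot.optimizer.annotate_types.annotate_types,
    # which tags every node with a type. Since exp.Lower is registered under VARCHAR
    # in TYPE_TO_EXPRESSIONS, with the default dialect this should hold:
    #
    #     >>> from sqlglot import parse_one
    #     >>> from sqlglot.optimizer.annotate_types import annotate_types
    #     >>> annotate_types(parse_one("SELECT LOWER('x') AS y")).expressions[0].type.sql()
    #     'VARCHAR'
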
    @classmethod
    def get_or_raise(cls, dialect: DialectType) -> Dialect:
        """
        Look up a dialect in the global dialect registry and return it if it exists.

        Args:
            dialect: The target dialect. If this is a string, it can be optionally followed by
                additional key-value pairs that are separated by commas and are used to specify
                dialect settings, such as whether the dialect's identifiers are case-sensitive.

        Example:
            >>> dialect = Dialect.get_or_raise("duckdb")
            >>> dialect = Dialect.get_or_raise("mysql, normalization_strategy = case_sensitive")

        Returns:
            The corresponding Dialect instance.
        """

        if not dialect:
            return cls()
        if isinstance(dialect, _Dialect):
            return dialect()
        if isinstance(dialect, Dialect):
            return dialect
        if isinstance(dialect, str):
            try:
                dialect_name, *kv_strings = dialect.split(",")
                kv_pairs = (kv.split("=") for kv in kv_strings)
                kwargs = {}
                for pair in kv_pairs:
                    key = pair[0].strip()
                    value: t.Optional[t.Union[bool, str]] = None

                    if len(pair) == 1:
                        # Default initialize standalone settings to True
                        value = True
                    elif len(pair) == 2:
                        value = pair[1].strip()

                        # Coerce the value to boolean if it matches the truthy/falsy values below
                        value_lower = value.lower()
                        if value_lower in ("true", "1"):
                            value = True
                        elif value_lower in ("false", "0"):
                            value = False

                    kwargs[key] = value

            except ValueError:
                raise ValueError(
                    f"Invalid dialect format: '{dialect}'. "
                    "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'."
                )

            result = cls.get(dialect_name.strip())
            if not result:
                from difflib import get_close_matches

                similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or ""
                if similar:
                    similar = f" Did you mean {similar}?"

                raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}")

            return result(**kwargs)

        raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")

    @classmethod
    def format_time(
        cls, expression: t.Optional[str | exp.Expression]
    ) -> t.Optional[exp.Expression]:
        """Converts a time format in this dialect to its equivalent Python `strftime` format."""
        if isinstance(expression, str):
            return exp.Literal.string(
                # the time formats are quoted
                format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE)
            )

        if expression and expression.is_string:
            return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE))

        return expression

    def __init__(self, **kwargs) -> None:
        normalization_strategy = kwargs.pop("normalization_strategy", None)

        if normalization_strategy is None:
            self.normalization_strategy = self.NORMALIZATION_STRATEGY
        else:
            self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper())

        self.settings = kwargs

    def __eq__(self, other: t.Any) -> bool:
        # Does not currently take dialect state into account
        return type(self) == other

    def __hash__(self) -> int:
        # Does not currently take dialect state into account
        return hash(type(self))

    def normalize_identifier(self, expression: E) -> E:
        """
        Transforms an identifier in a way that resembles how it'd be resolved by this dialect.

        For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it
        lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so
        it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive,
        and so any normalization would be prohibited in order to avoid "breaking" the identifier.

        There are also dialects like Spark, which are case-insensitive even when quotes are
        present, and dialects like MySQL, whose resolution rules match those employed by the
        underlying operating system, for example they may always be case-sensitive in Linux.

        Finally, the normalization behavior of some engines can even be controlled through flags,
        like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.

        SQLGlot aims to understand and handle all of these different behaviors gracefully, so
        that it can analyze queries in the optimizer and successfully capture their semantics.
        """
        if (
            isinstance(expression, exp.Identifier)
            and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE
            and (
                not expression.quoted
                or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE
            )
        ):
            expression.set(
                "this",
                (
                    expression.this.upper()
                    if self.normalization_strategy is NormalizationStrategy.UPPERCASE
                    else expression.this.lower()
                ),
            )

        return expression

    def case_sensitive(self, text: str) -> bool:
        """Checks if text contains any case-sensitive characters, based on the dialect's rules."""
        if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE:
            return False

        unsafe = (
            str.islower
            if self.normalization_strategy is NormalizationStrategy.UPPERCASE
            else str.isupper
        )
        return any(unsafe(char) for char in text)

    def can_identify(self, text: str, identify: str | bool = "safe") -> bool:
        """Checks if text can be identified given an identify option.

        Args:
            text: The text to check.
            identify:
                `"always"` or `True`: Always returns `True`.
                `"safe"`: Only returns `True` if the identifier is case-insensitive.

        Returns:
            Whether the given text can be identified.
        """
        if identify is True or identify == "always":
            return True

        if identify == "safe":
            return not self.case_sensitive(text)

        return False

    def quote_identifier(self, expression: E, identify: bool = True) -> E:
        """
        Adds quotes to a given identifier.

        Args:
            expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
            identify: If set to `False`, the quotes will only be added if the identifier is deemed
                "unsafe", with respect to its characters and this dialect's normalization strategy.
        """
        if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func):
            name = expression.this
            expression.set(
                "quoted",
                identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name),
            )

        return expression

    def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]:
        if isinstance(path, exp.Literal):
            path_text = path.name
            if path.is_number:
                path_text = f"[{path_text}]"
            try:
                return parse_json_path(path_text, self)
            except ParseError as e:
                logger.warning(f"Invalid JSON path syntax. {str(e)}")

        return path

    def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse(self.tokenize(sql), sql)

    def parse_into(
        self, expression_type: exp.IntoType, sql: str, **opts
    ) -> t.List[t.Optional[exp.Expression]]:
        return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql)

    def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str:
        return self.generator(**opts).generate(expression, copy=copy)

    def transpile(self, sql: str, **opts) -> t.List[str]:
        return [
            self.generate(expression, copy=False, **opts) if expression else ""
            for expression in self.parse(sql)
        ]

    def tokenize(self, sql: str) -> t.List[Token]:
        return self.tokenizer.tokenize(sql)

    @property
    def tokenizer(self) -> Tokenizer:
        return self.tokenizer_class(dialect=self)

    @property
    def jsonpath_tokenizer(self) -> JSONPathTokenizer:
        return self.jsonpath_tokenizer_class(dialect=self)

    def parser(self, **opts) -> Parser:
        return self.parser_class(dialect=self, **opts)

    def generator(self, **opts) -> Generator:
        return self.generator_class(dialect=self, **opts)
{str(e)}") 890 891 return path 892 893 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 894 return self.parser(**opts).parse(self.tokenize(sql), sql) 895 896 def parse_into( 897 self, expression_type: exp.IntoType, sql: str, **opts 898 ) -> t.List[t.Optional[exp.Expression]]: 899 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 900 901 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 902 return self.generator(**opts).generate(expression, copy=copy) 903 904 def transpile(self, sql: str, **opts) -> t.List[str]: 905 return [ 906 self.generate(expression, copy=False, **opts) if expression else "" 907 for expression in self.parse(sql) 908 ] 909 910 def tokenize(self, sql: str) -> t.List[Token]: 911 return self.tokenizer.tokenize(sql) 912 913 @property 914 def tokenizer(self) -> Tokenizer: 915 return self.tokenizer_class(dialect=self) 916 917 @property 918 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 919 return self.jsonpath_tokenizer_class(dialect=self) 920 921 def parser(self, **opts) -> Parser: 922 return self.parser_class(dialect=self, **opts) 923 924 def generator(self, **opts) -> Generator: 925 return self.generator_class(dialect=self, **opts) 926 927 928DialectType = t.Union[str, Dialect, t.Type[Dialect], None] 929 930 931def rename_func(name: str) -> t.Callable[[Generator, exp.Expression], str]: 932 return lambda self, expression: self.func(name, *flatten(expression.args.values())) 933 934 935def approx_count_distinct_sql(self: Generator, expression: exp.ApproxDistinct) -> str: 936 if expression.args.get("accuracy"): 937 self.unsupported("APPROX_COUNT_DISTINCT does not support accuracy") 938 return self.func("APPROX_COUNT_DISTINCT", expression.this) 939 940 941def if_sql( 942 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 943) -> t.Callable[[Generator, exp.If], str]: 944 def _if_sql(self: Generator, expression: exp.If) -> str: 945 return self.func( 946 name, 947 expression.this, 948 expression.args.get("true"), 949 expression.args.get("false") or false_value, 950 ) 951 952 return _if_sql 953 954 955def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 956 this = expression.this 957 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 958 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 959 960 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>") 961 962 963def inline_array_sql(self: Generator, expression: exp.Array) -> str: 964 return f"[{self.expressions(expression, dynamic=True, new_line=True, skip_first=True, skip_last=True)}]" 965 966 967def inline_array_unless_query(self: Generator, expression: exp.Array) -> str: 968 elem = seq_get(expression.expressions, 0) 969 if isinstance(elem, exp.Expression) and elem.find(exp.Query): 970 return self.func("ARRAY", elem) 971 return inline_array_sql(self, expression) 972 973 974def no_ilike_sql(self: Generator, expression: exp.ILike) -> str: 975 return self.like_sql( 976 exp.Like( 977 this=exp.Lower(this=expression.this), expression=exp.Lower(this=expression.expression) 978 ) 979 ) 980 981 982def no_paren_current_date_sql(self: Generator, expression: exp.CurrentDate) -> str: 983 zone = self.sql(expression, "this") 984 return f"CURRENT_DATE AT TIME ZONE {zone}" if zone else "CURRENT_DATE" 985 986 987def no_recursive_cte_sql(self: Generator, expression: exp.With) -> str: 988 if expression.args.get("recursive"): 989 
self.unsupported("Recursive CTEs are unsupported") 990 expression.args["recursive"] = False 991 return self.with_sql(expression) 992 993 994def no_safe_divide_sql(self: Generator, expression: exp.SafeDivide) -> str: 995 n = self.sql(expression, "this") 996 d = self.sql(expression, "expression") 997 return f"IF(({d}) <> 0, ({n}) / ({d}), NULL)" 998 999 1000def no_tablesample_sql(self: Generator, expression: exp.TableSample) -> str: 1001 self.unsupported("TABLESAMPLE unsupported") 1002 return self.sql(expression.this) 1003 1004 1005def no_pivot_sql(self: Generator, expression: exp.Pivot) -> str: 1006 self.unsupported("PIVOT unsupported") 1007 return "" 1008 1009 1010def no_trycast_sql(self: Generator, expression: exp.TryCast) -> str: 1011 return self.cast_sql(expression) 1012 1013 1014def no_comment_column_constraint_sql( 1015 self: Generator, expression: exp.CommentColumnConstraint 1016) -> str: 1017 self.unsupported("CommentColumnConstraint unsupported") 1018 return "" 1019 1020 1021def no_map_from_entries_sql(self: Generator, expression: exp.MapFromEntries) -> str: 1022 self.unsupported("MAP_FROM_ENTRIES unsupported") 1023 return "" 1024 1025 1026def str_position_sql( 1027 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 1028) -> str: 1029 this = self.sql(expression, "this") 1030 substr = self.sql(expression, "substr") 1031 position = self.sql(expression, "position") 1032 instance = expression.args.get("instance") if generate_instance else None 1033 position_offset = "" 1034 1035 if position: 1036 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1037 this = self.func("SUBSTR", this, position) 1038 position_offset = f" + {position} - 1" 1039 1040 return self.func("STRPOS", this, substr, instance) + position_offset 1041 1042 1043def struct_extract_sql(self: Generator, expression: exp.StructExtract) -> str: 1044 return ( 1045 f"{self.sql(expression, 'this')}.{self.sql(exp.to_identifier(expression.expression.name))}" 1046 ) 1047 1048 1049def var_map_sql( 1050 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1051) -> str: 1052 keys = expression.args["keys"] 1053 values = expression.args["values"] 1054 1055 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1056 self.unsupported("Cannot convert array columns into map.") 1057 return self.func(map_func_name, keys, values) 1058 1059 args = [] 1060 for key, value in zip(keys.expressions, values.expressions): 1061 args.append(self.sql(key)) 1062 args.append(self.sql(value)) 1063 1064 return self.func(map_func_name, *args) 1065 1066 1067def build_formatted_time( 1068 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1069) -> t.Callable[[t.List], E]: 1070 """Helper used for time expressions. 1071 1072 Args: 1073 exp_class: the expression class to instantiate. 1074 dialect: target sql dialect. 1075 default: the default format, True being time. 1076 1077 Returns: 1078 A callable that can be used to return the appropriately formatted time expression. 
1079 """ 1080 1081 def _builder(args: t.List): 1082 return exp_class( 1083 this=seq_get(args, 0), 1084 format=Dialect[dialect].format_time( 1085 seq_get(args, 1) 1086 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1087 ), 1088 ) 1089 1090 return _builder 1091 1092 1093def time_format( 1094 dialect: DialectType = None, 1095) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1096 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1097 """ 1098 Returns the time format for a given expression, unless it's equivalent 1099 to the default time format of the dialect of interest. 1100 """ 1101 time_format = self.format_time(expression) 1102 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1103 1104 return _time_format 1105 1106 1107def build_date_delta( 1108 exp_class: t.Type[E], 1109 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1110 default_unit: t.Optional[str] = "DAY", 1111) -> t.Callable[[t.List], E]: 1112 def _builder(args: t.List) -> E: 1113 unit_based = len(args) == 3 1114 this = args[2] if unit_based else seq_get(args, 0) 1115 unit = None 1116 if unit_based or default_unit: 1117 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1118 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1119 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1120 1121 return _builder 1122 1123 1124def build_date_delta_with_interval( 1125 expression_class: t.Type[E], 1126) -> t.Callable[[t.List], t.Optional[E]]: 1127 def _builder(args: t.List) -> t.Optional[E]: 1128 if len(args) < 2: 1129 return None 1130 1131 interval = args[1] 1132 1133 if not isinstance(interval, exp.Interval): 1134 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1135 1136 expression = interval.this 1137 if expression and expression.is_string: 1138 expression = exp.Literal.number(expression.this) 1139 1140 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 1141 1142 return _builder 1143 1144 1145def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1146 unit = seq_get(args, 0) 1147 this = seq_get(args, 1) 1148 1149 if isinstance(this, exp.Cast) and this.is_type("date"): 1150 return exp.DateTrunc(unit=unit, this=this) 1151 return exp.TimestampTrunc(this=this, unit=unit) 1152 1153 1154def date_add_interval_sql( 1155 data_type: str, kind: str 1156) -> t.Callable[[Generator, exp.Expression], str]: 1157 def func(self: Generator, expression: exp.Expression) -> str: 1158 this = self.sql(expression, "this") 1159 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1160 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1161 1162 return func 1163 1164 1165def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1166 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1167 args = [unit_to_str(expression), expression.this] 1168 if zone: 1169 args.append(expression.args.get("zone")) 1170 return self.func("DATE_TRUNC", *args) 1171 1172 return _timestamptrunc_sql 1173 1174 1175def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1176 zone = expression.args.get("zone") 1177 if not zone: 1178 from sqlglot.optimizer.annotate_types import annotate_types 1179 1180 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1181 
def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc:
    unit = seq_get(args, 0)
    this = seq_get(args, 1)

    if isinstance(this, exp.Cast) and this.is_type("date"):
        return exp.DateTrunc(unit=unit, this=this)
    return exp.TimestampTrunc(this=this, unit=unit)


def date_add_interval_sql(
    data_type: str, kind: str
) -> t.Callable[[Generator, exp.Expression], str]:
    def func(self: Generator, expression: exp.Expression) -> str:
        this = self.sql(expression, "this")
        interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression))
        return f"{data_type}_{kind}({this}, {self.sql(interval)})"

    return func


def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]:
    def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str:
        args = [unit_to_str(expression), expression.this]
        if zone:
            args.append(expression.args.get("zone"))
        return self.func("DATE_TRUNC", *args)

    return _timestamptrunc_sql


def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str:
    zone = expression.args.get("zone")
    if not zone:
        from sqlglot.optimizer.annotate_types import annotate_types

        target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP
        return self.sql(exp.cast(expression.this, target_type))
    if zone.name.lower() in TIMEZONES:
        return self.sql(
            exp.AtTimeZone(
                this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP),
                zone=zone,
            )
        )
    return self.func("TIMESTAMP", expression.this, zone)


def no_time_sql(self: Generator, expression: exp.Time) -> str:
    # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME)
    this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ)
    expr = exp.cast(
        exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME
    )
    return self.sql(expr)


def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str:
    this = expression.this
    expr = expression.expression

    if expr.name.lower() in TIMEZONES:
        # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP)
        this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ)
        this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP)
        return self.sql(this)

    this = exp.cast(this, exp.DataType.Type.DATE)
    expr = exp.cast(expr, exp.DataType.Type.TIME)

    return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))


def locate_to_strposition(args: t.List) -> exp.Expression:
    return exp.StrPosition(
        this=seq_get(args, 1), substr=seq_get(args, 0), position=seq_get(args, 2)
    )


def strposition_to_locate_sql(self: Generator, expression: exp.StrPosition) -> str:
    return self.func(
        "LOCATE", expression.args.get("substr"), expression.this, expression.args.get("position")
    )


def left_to_substring_sql(self: Generator, expression: exp.Left) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this, start=exp.Literal.number(1), length=expression.expression
        )
    )


def right_to_substring_sql(self: Generator, expression: exp.Right) -> str:
    return self.sql(
        exp.Substring(
            this=expression.this,
            start=exp.Length(this=expression.this) - exp.paren(expression.expression - 1),
        )
    )


def timestrtotime_sql(
    self: Generator,
    expression: exp.TimeStrToTime,
    include_precision: bool = False,
) -> str:
    datatype = exp.DataType.build(
        exp.DataType.Type.TIMESTAMPTZ
        if expression.args.get("zone")
        else exp.DataType.Type.TIMESTAMP
    )

    if isinstance(expression.this, exp.Literal) and include_precision:
        precision = subsecond_precision(expression.this.name)
        if precision > 0:
            datatype = exp.DataType.build(
                datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))]
            )

    return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))


def datestrtodate_sql(self: Generator, expression: exp.DateStrToDate) -> str:
    return self.sql(exp.cast(expression.this, exp.DataType.Type.DATE))


# Used for Presto and Duckdb which use functions that don't support charset, and assume utf-8
def encode_decode_sql(
    self: Generator, expression: exp.Expression, name: str, replace: bool = True
) -> str:
    charset = expression.args.get("charset")
    if charset and charset.name.lower() != "utf-8":
        self.unsupported(f"Expected utf-8 character set, got {charset}.")

    return self.func(name, expression.this, expression.args.get("replace") if replace else None)


def min_or_least(self: Generator, expression: exp.Min) -> str:
    name = "LEAST" if expression.expressions else "MIN"
    return rename_func(name)(self, expression)


def max_or_greatest(self: Generator, expression: exp.Max) -> str:
    name = "GREATEST" if expression.expressions else "MAX"
    return rename_func(name)(self, expression)


def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str:
    cond = expression.this

    if isinstance(expression.this, exp.Distinct):
        cond = expression.this.expressions[0]
        self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM")

    return self.func("sum", exp.func("if", cond, 1, 0))


def trim_sql(self: Generator, expression: exp.Trim) -> str:
    target = self.sql(expression, "this")
    trim_type = self.sql(expression, "position")
    remove_chars = self.sql(expression, "expression")
    collation = self.sql(expression, "collation")

    # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific
    if not remove_chars:
        return self.trim_sql(expression)

    trim_type = f"{trim_type} " if trim_type else ""
    remove_chars = f"{remove_chars} " if remove_chars else ""
    from_part = "FROM " if trim_type or remove_chars else ""
    collation = f" COLLATE {collation}" if collation else ""
    return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"


def str_to_time_sql(self: Generator, expression: exp.Expression) -> str:
    return self.func("STRPTIME", expression.this, self.format_time(expression))


def concat_to_dpipe_sql(self: Generator, expression: exp.Concat) -> str:
    return self.sql(reduce(lambda x, y: exp.DPipe(this=x, expression=y), expression.expressions))


def concat_ws_to_dpipe_sql(self: Generator, expression: exp.ConcatWs) -> str:
    delim, *rest_args = expression.expressions
    return self.sql(
        reduce(
            lambda x, y: exp.DPipe(this=x, expression=exp.DPipe(this=delim, expression=y)),
            rest_args,
        )
    )


def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters")))
    if bad_args:
        self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group")
    )


def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str:
    bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers")))
    if bad_args:
        self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}")

    return self.func(
        "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"]
    )


def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]:
    names = []
    for agg in aggregations:
        if isinstance(agg, exp.Alias):
            names.append(agg.alias)
        else:
            """
            This case corresponds to aggregations without aliases being used as suffixes
            (e.g. col_avg(foo)). We need to unquote identifiers because they're going to
            be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`.
            Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes).
            """
            agg_all_unquoted = agg.transform(
                lambda node: (
                    exp.Identifier(this=node.name, quoted=False)
                    if isinstance(node, exp.Identifier)
                    else node
                )
            )
            names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower"))

    return names


def binary_from_function(expr_type: t.Type[B]) -> t.Callable[[t.List], B]:
    return lambda args: expr_type(this=seq_get(args, 0), expression=seq_get(args, 1))


# Used to represent DATE_TRUNC in Doris, Postgres and Starrocks dialects
def build_timestamp_trunc(args: t.List) -> exp.TimestampTrunc:
    return exp.TimestampTrunc(this=seq_get(args, 1), unit=seq_get(args, 0))


def any_value_to_max_sql(self: Generator, expression: exp.AnyValue) -> str:
    return self.func("MAX", expression.this)


def bool_xor_sql(self: Generator, expression: exp.Xor) -> str:
    a = self.sql(expression.left)
    b = self.sql(expression.right)
    return f"({a} AND (NOT {b})) OR ((NOT {a}) AND {b})"


def is_parse_json(expression: exp.Expression) -> bool:
    return isinstance(expression, exp.ParseJSON) or (
        isinstance(expression, exp.Cast) and expression.is_type("json")
    )


def isnull_to_is_null(args: t.List) -> exp.Expression:
    return exp.Paren(this=exp.Is(this=seq_get(args, 0), expression=exp.null()))


def generatedasidentitycolumnconstraint_sql(
    self: Generator, expression: exp.GeneratedAsIdentityColumnConstraint
) -> str:
    start = self.sql(expression, "start") or "1"
    increment = self.sql(expression, "increment") or "1"
    return f"IDENTITY({start}, {increment})"


def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]:
    def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str:
        if expression.args.get("count"):
            self.unsupported(f"Only two arguments are supported in function {name}.")

        return self.func(name, expression.this, expression.expression)

    return _arg_max_or_min_sql


def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd:
    this = expression.this.copy()

    return_type = expression.return_type
    if return_type.is_type(exp.DataType.Type.DATE):
        # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we
        # can truncate timestamp strings, because some dialects can't cast them to DATE
        this = exp.cast(this, exp.DataType.Type.TIMESTAMP)

    expression.this.replace(exp.cast(this, return_type))
    return expression


def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]:
    def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str:
        if cast and isinstance(expression, exp.TsOrDsAdd):
            expression = ts_or_ds_add_cast(expression)

        return self.func(
            name,
            unit_to_var(expression),
            expression.expression,
            expression.this,
        )

    return _delta_sql


def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, exp.Placeholder):
        return unit
    if unit:
        return exp.Literal.string(unit.name)
    return exp.Literal.string(default) if default else None


def unit_to_var(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]:
    unit = expression.args.get("unit")

    if isinstance(unit, (exp.Var, exp.Placeholder)):
        return unit
    return exp.Var(this=default) if default else None


@t.overload
def map_date_part(part: exp.Expression, dialect: DialectType = Dialect) -> exp.Var:
    pass


@t.overload
def map_date_part(
    part: t.Optional[exp.Expression], dialect: DialectType = Dialect
) -> t.Optional[exp.Expression]:
    pass


def map_date_part(part, dialect: DialectType = Dialect):
    mapped = (
        Dialect.get_or_raise(dialect).DATE_PART_MAPPING.get(part.name.upper()) if part else None
    )
    return exp.var(mapped) if mapped else part
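
# Usage sketch: map_date_part canonicalizes date part aliases through
# DATE_PART_MAPPING, leaving unmapped parts untouched:
#
#     >>> map_date_part(exp.var("dow")).name
#     'DAYOFWEEK'
#     >>> map_date_part(exp.var("quarter")).name
#     'quarter'
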
def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str:
    trunc_curr_date = exp.func("date_trunc", "month", expression.this)
    plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month")
    minus_one_day = exp.func("date_sub", plus_one_month, 1, "day")

    return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))


def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str:
    """Remove table refs from columns in when statements."""
    alias = expression.this.args.get("alias")

    def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
        return self.dialect.normalize_identifier(identifier).name if identifier else None

    targets = {normalize(expression.this.this)}

    if alias:
        targets.add(normalize(alias.this))

    for when in expression.expressions:
        # only remove the target names from the THEN clause
        # they're still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED
        # ref: https://github.com/TobikoData/sqlmesh/issues/2934
        then = when.args.get("then")
        if then:
            then.transform(
                lambda node: (
                    exp.column(node.this)
                    if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets
                    else node
                ),
                copy=False,
            )

    return self.merge_sql(expression)


def build_json_extract_path(
    expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False
) -> t.Callable[[t.List], F]:
    def _builder(args: t.List) -> F:
        segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()]
        for arg in args[1:]:
            if not isinstance(arg, exp.Literal):
                # We use the fallback parser because we can't really transpile non-literals safely
                return expr_type.from_arg_list(args)

            text = arg.name
            if is_int(text):
                index = int(text)
                segments.append(
                    exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1)
                )
            else:
                segments.append(exp.JSONPathKey(this=text))

        # This is done to avoid failing in the expression validator due to the arg count
        del args[2:]
        return expr_type(
            this=seq_get(args, 0),
            expression=exp.JSONPath(expressions=segments),
            only_json_types=arrow_req_json_type,
        )

    return _builder
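
# Usage sketch: build_json_extract_path folds literal path arguments into a
# structured exp.JSONPath (non-literal arguments fall back to a plain function
# call). The path built below should be equivalent to '$.a[0]':
#
#     >>> build = build_json_extract_path(exp.JSONExtract)
#     >>> node = build([exp.column("c"), exp.Literal.string("a"), exp.Literal.number(0)])
#     >>> [type(part).__name__ for part in node.expression.expressions]
#     ['JSONPathRoot', 'JSONPathKey', 'JSONPathSubscript']
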
def json_extract_segments(
    name: str, quoted_index: bool = True, op: t.Optional[str] = None
) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]:
    def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str:
        path = expression.expression
        if not isinstance(path, exp.JSONPath):
            return rename_func(name)(self, expression)

        segments = []
        for segment in path.expressions:
            path = self.sql(segment)
            if path:
                if isinstance(segment, exp.JSONPathPart) and (
                    quoted_index or not isinstance(segment, exp.JSONPathSubscript)
                ):
                    path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}"

                segments.append(path)

        if op:
            return f" {op} ".join([self.sql(expression.this), *segments])
        return self.func(name, expression.this, *segments)

    return _json_extract_segments


def json_path_key_only_name(self: Generator, expression: exp.JSONPathKey) -> str:
    if isinstance(expression.this, exp.JSONPathWildcard):
        self.unsupported("Unsupported wildcard in JSONPathKey expression")

    return expression.name


def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str:
    cond = expression.expression
    if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1:
        alias = cond.expressions[0]
        cond = cond.this
    elif isinstance(cond, exp.Predicate):
        alias = "_u"
    else:
        self.unsupported("Unsupported filter condition")
        return ""

    unnest = exp.Unnest(expressions=[expression.this])
    filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond)
    return self.sql(exp.Array(expressions=[filtered]))


def to_number_with_nls_param(self: Generator, expression: exp.ToNumber) -> str:
    return self.func(
        "TO_NUMBER",
        expression.this,
        expression.args.get("format"),
        expression.args.get("nlsparam"),
    )


def build_default_decimal_type(
    precision: t.Optional[int] = None, scale: t.Optional[int] = None
) -> t.Callable[[exp.DataType], exp.DataType]:
    def _builder(dtype: exp.DataType) -> exp.DataType:
        if dtype.expressions or precision is None:
            return dtype

        params = f"{precision}{f', {scale}' if scale is not None else ''}"
        return exp.DataType.build(f"DECIMAL({params})")

    return _builder


def build_timestamp_from_parts(args: t.List) -> exp.Func:
    if len(args) == 2:
        # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept,
        # so we parse this into Anonymous for now instead of introducing complexity
        return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args)

    return exp.TimestampFromParts.from_arg_list(args)


def sha256_sql(self: Generator, expression: exp.SHA2) -> str:
    return self.func(f"SHA{expression.text('length') or '256'}", expression.this)


def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str:
    start = expression.args.get("start")
    end = expression.args.get("end")
    step = expression.args.get("step")

    if isinstance(start, exp.Cast):
        target_type = start.to
    elif isinstance(end, exp.Cast):
        target_type = end.to
    else:
        target_type = None

    if start and end and target_type and target_type.is_type("date", "timestamp"):
        if isinstance(start, exp.Cast) and target_type is start.to:
            end = exp.cast(end, target_type)
        else:
            start = exp.cast(start, target_type)

    return self.func("SEQUENCE", start, end, step)
49class Dialects(str, Enum): 50 """Dialects supported by SQLGLot.""" 51 52 DIALECT = "" 53 54 ATHENA = "athena" 55 BIGQUERY = "bigquery" 56 CLICKHOUSE = "clickhouse" 57 DATABRICKS = "databricks" 58 DORIS = "doris" 59 DRILL = "drill" 60 DUCKDB = "duckdb" 61 HIVE = "hive" 62 MATERIALIZE = "materialize" 63 MYSQL = "mysql" 64 ORACLE = "oracle" 65 POSTGRES = "postgres" 66 PRESTO = "presto" 67 PRQL = "prql" 68 REDSHIFT = "redshift" 69 RISINGWAVE = "risingwave" 70 SNOWFLAKE = "snowflake" 71 SPARK = "spark" 72 SPARK2 = "spark2" 73 SQLITE = "sqlite" 74 STARROCKS = "starrocks" 75 TABLEAU = "tableau" 76 TERADATA = "teradata" 77 TRINO = "trino" 78 TSQL = "tsql"
Dialects supported by SQLGlot.
81class NormalizationStrategy(str, AutoName): 82 """Specifies the strategy according to which identifiers should be normalized.""" 83 84 LOWERCASE = auto() 85 """Unquoted identifiers are lowercased.""" 86 87 UPPERCASE = auto() 88 """Unquoted identifiers are uppercased.""" 89 90 CASE_SENSITIVE = auto() 91 """Always case-sensitive, regardless of quotes.""" 92 93 CASE_INSENSITIVE = auto() 94 """Always case-insensitive, regardless of quotes."""
Specifies the strategy according to which identifiers should be normalized.
- LOWERCASE: Unquoted identifiers are lowercased.
- UPPERCASE: Unquoted identifiers are uppercased.
- CASE_SENSITIVE: Always case-sensitive, regardless of quotes.
- CASE_INSENSITIVE: Always case-insensitive, regardless of quotes.
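A quick sketch of how the strategy surfaces in practice (assuming Snowflake's default strategy remains UPPERCASE at the time of writing):

    from sqlglot.dialects.dialect import Dialect, NormalizationStrategy

    # Snowflake uppercases unquoted identifiers by default
    assert Dialect.get_or_raise("snowflake").normalization_strategy is NormalizationStrategy.UPPERCASE

    # The strategy can also be overridden per instance via the settings string
    d = Dialect.get_or_raise("postgres, normalization_strategy = case_sensitive")
    assert d.normalization_strategy is NormalizationStrategy.CASE_SENSITIVE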
222class Dialect(metaclass=_Dialect): 223 INDEX_OFFSET = 0 224 """The base index offset for arrays.""" 225 226 WEEK_OFFSET = 0 227 """First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.""" 228 229 UNNEST_COLUMN_ONLY = False 230 """Whether `UNNEST` table aliases are treated as column aliases.""" 231 232 ALIAS_POST_TABLESAMPLE = False 233 """Whether the table alias comes after tablesample.""" 234 235 TABLESAMPLE_SIZE_IS_PERCENT = False 236 """Whether a size in the table sample clause represents percentage.""" 237 238 NORMALIZATION_STRATEGY = NormalizationStrategy.LOWERCASE 239 """Specifies the strategy according to which identifiers should be normalized.""" 240 241 IDENTIFIERS_CAN_START_WITH_DIGIT = False 242 """Whether an unquoted identifier can start with a digit.""" 243 244 DPIPE_IS_STRING_CONCAT = True 245 """Whether the DPIPE token (`||`) is a string concatenation operator.""" 246 247 STRICT_STRING_CONCAT = False 248 """Whether `CONCAT`'s arguments must be strings.""" 249 250 SUPPORTS_USER_DEFINED_TYPES = True 251 """Whether user-defined data types are supported.""" 252 253 SUPPORTS_SEMI_ANTI_JOIN = True 254 """Whether `SEMI` or `ANTI` joins are supported.""" 255 256 SUPPORTS_COLUMN_JOIN_MARKS = False 257 """Whether the old-style outer join (+) syntax is supported.""" 258 259 COPY_PARAMS_ARE_CSV = True 260 """Separator of COPY statement parameters.""" 261 262 NORMALIZE_FUNCTIONS: bool | str = "upper" 263 """ 264 Determines how function names are going to be normalized. 265 Possible values: 266 "upper" or True: Convert names to uppercase. 267 "lower": Convert names to lowercase. 268 False: Disables function name normalization. 269 """ 270 271 LOG_BASE_FIRST: t.Optional[bool] = True 272 """ 273 Whether the base comes first in the `LOG` function. 274 Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`) 275 """ 276 277 NULL_ORDERING = "nulls_are_small" 278 """ 279 Default `NULL` ordering method to use if not explicitly set. 280 Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"` 281 """ 282 283 TYPED_DIVISION = False 284 """ 285 Whether the behavior of `a / b` depends on the types of `a` and `b`. 286 False means `a / b` is always float division. 287 True means `a / b` is integer division if both `a` and `b` are integers. 288 """ 289 290 SAFE_DIVISION = False 291 """Whether division by zero throws an error (`False`) or returns NULL (`True`).""" 292 293 CONCAT_COALESCE = False 294 """A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.""" 295 296 HEX_LOWERCASE = False 297 """Whether the `HEX` function returns a lowercase hexadecimal string.""" 298 299 DATE_FORMAT = "'%Y-%m-%d'" 300 DATEINT_FORMAT = "'%Y%m%d'" 301 TIME_FORMAT = "'%Y-%m-%d %H:%M:%S'" 302 303 TIME_MAPPING: t.Dict[str, str] = {} 304 """Associates this dialect's time formats with their equivalent Python `strftime` formats.""" 305 306 # https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#format_model_rules_date_time 307 # https://docs.teradata.com/r/Teradata-Database-SQL-Functions-Operators-Expressions-and-Predicates/March-2017/Data-Type-Conversions/Character-to-DATE-Conversion/Forcing-a-FORMAT-on-CAST-for-Converting-Character-to-DATE 308 FORMAT_MAPPING: t.Dict[str, str] = {} 309 """ 310 Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. 311 If empty, the corresponding trie will be constructed off of `TIME_MAPPING`. 
312 """ 313 314 UNESCAPED_SEQUENCES: t.Dict[str, str] = {} 315 """Mapping of an escaped sequence (`\\n`) to its unescaped version (`\n`).""" 316 317 PSEUDOCOLUMNS: t.Set[str] = set() 318 """ 319 Columns that are auto-generated by the engine corresponding to this dialect. 320 For example, such columns may be excluded from `SELECT *` queries. 321 """ 322 323 PREFER_CTE_ALIAS_COLUMN = False 324 """ 325 Some dialects, such as Snowflake, allow you to reference a CTE column alias in the 326 HAVING clause of the CTE. This flag will cause the CTE alias columns to override 327 any projection aliases in the subquery. 328 329 For example, 330 WITH y(c) AS ( 331 SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0 332 ) SELECT c FROM y; 333 334 will be rewritten as 335 336 WITH y(c) AS ( 337 SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0 338 ) SELECT c FROM y; 339 """ 340 341 COPY_PARAMS_ARE_CSV = True 342 """ 343 Whether COPY statement parameters are separated by comma or whitespace 344 """ 345 346 FORCE_EARLY_ALIAS_REF_EXPANSION = False 347 """ 348 Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()). 349 350 For example: 351 WITH data AS ( 352 SELECT 353 1 AS id, 354 2 AS my_id 355 ) 356 SELECT 357 id AS my_id 358 FROM 359 data 360 WHERE 361 my_id = 1 362 GROUP BY 363 my_id, 364 HAVING 365 my_id = 1 366 367 In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except: 368 - BigQuery, which will forward the alias to GROUP BY + HAVING clauses i.e it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1" 369 - Clickhouse, which will forward the alias across the query i.e it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1" 370 """ 371 372 EXPAND_ALIAS_REFS_EARLY_ONLY_IN_GROUP_BY = False 373 """Whether alias reference expansion before qualification should only happen for the GROUP BY clause.""" 374 375 SUPPORTS_ORDER_BY_ALL = False 376 """ 377 Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks 378 """ 379 380 HAS_DISTINCT_ARRAY_CONSTRUCTORS = False 381 """ 382 Whether the ARRAY constructor is context-sensitive, i.e in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3) 383 as the former is of type INT[] vs the latter which is SUPER 384 """ 385 386 SUPPORTS_FIXED_SIZE_ARRAYS = False 387 """ 388 Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In 389 dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator 390 """ 391 392 CREATABLE_KIND_MAPPING: dict[str, str] = {} 393 """ 394 Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse 395 equivalent of CREATE SCHEMA is CREATE DATABASE. 
396 """ 397 398 # --- Autofilled --- 399 400 tokenizer_class = Tokenizer 401 jsonpath_tokenizer_class = JSONPathTokenizer 402 parser_class = Parser 403 generator_class = Generator 404 405 # A trie of the time_mapping keys 406 TIME_TRIE: t.Dict = {} 407 FORMAT_TRIE: t.Dict = {} 408 409 INVERSE_TIME_MAPPING: t.Dict[str, str] = {} 410 INVERSE_TIME_TRIE: t.Dict = {} 411 INVERSE_FORMAT_MAPPING: t.Dict[str, str] = {} 412 INVERSE_FORMAT_TRIE: t.Dict = {} 413 414 INVERSE_CREATABLE_KIND_MAPPING: dict[str, str] = {} 415 416 ESCAPED_SEQUENCES: t.Dict[str, str] = {} 417 418 # Delimiters for string literals and identifiers 419 QUOTE_START = "'" 420 QUOTE_END = "'" 421 IDENTIFIER_START = '"' 422 IDENTIFIER_END = '"' 423 424 # Delimiters for bit, hex, byte and unicode literals 425 BIT_START: t.Optional[str] = None 426 BIT_END: t.Optional[str] = None 427 HEX_START: t.Optional[str] = None 428 HEX_END: t.Optional[str] = None 429 BYTE_START: t.Optional[str] = None 430 BYTE_END: t.Optional[str] = None 431 UNICODE_START: t.Optional[str] = None 432 UNICODE_END: t.Optional[str] = None 433 434 DATE_PART_MAPPING = { 435 "Y": "YEAR", 436 "YY": "YEAR", 437 "YYY": "YEAR", 438 "YYYY": "YEAR", 439 "YR": "YEAR", 440 "YEARS": "YEAR", 441 "YRS": "YEAR", 442 "MM": "MONTH", 443 "MON": "MONTH", 444 "MONS": "MONTH", 445 "MONTHS": "MONTH", 446 "D": "DAY", 447 "DD": "DAY", 448 "DAYS": "DAY", 449 "DAYOFMONTH": "DAY", 450 "DAY OF WEEK": "DAYOFWEEK", 451 "WEEKDAY": "DAYOFWEEK", 452 "DOW": "DAYOFWEEK", 453 "DW": "DAYOFWEEK", 454 "WEEKDAY_ISO": "DAYOFWEEKISO", 455 "DOW_ISO": "DAYOFWEEKISO", 456 "DW_ISO": "DAYOFWEEKISO", 457 "DAY OF YEAR": "DAYOFYEAR", 458 "DOY": "DAYOFYEAR", 459 "DY": "DAYOFYEAR", 460 "W": "WEEK", 461 "WK": "WEEK", 462 "WEEKOFYEAR": "WEEK", 463 "WOY": "WEEK", 464 "WY": "WEEK", 465 "WEEK_ISO": "WEEKISO", 466 "WEEKOFYEARISO": "WEEKISO", 467 "WEEKOFYEAR_ISO": "WEEKISO", 468 "Q": "QUARTER", 469 "QTR": "QUARTER", 470 "QTRS": "QUARTER", 471 "QUARTERS": "QUARTER", 472 "H": "HOUR", 473 "HH": "HOUR", 474 "HR": "HOUR", 475 "HOURS": "HOUR", 476 "HRS": "HOUR", 477 "M": "MINUTE", 478 "MI": "MINUTE", 479 "MIN": "MINUTE", 480 "MINUTES": "MINUTE", 481 "MINS": "MINUTE", 482 "S": "SECOND", 483 "SEC": "SECOND", 484 "SECONDS": "SECOND", 485 "SECS": "SECOND", 486 "MS": "MILLISECOND", 487 "MSEC": "MILLISECOND", 488 "MSECS": "MILLISECOND", 489 "MSECOND": "MILLISECOND", 490 "MSECONDS": "MILLISECOND", 491 "MILLISEC": "MILLISECOND", 492 "MILLISECS": "MILLISECOND", 493 "MILLISECON": "MILLISECOND", 494 "MILLISECONDS": "MILLISECOND", 495 "US": "MICROSECOND", 496 "USEC": "MICROSECOND", 497 "USECS": "MICROSECOND", 498 "MICROSEC": "MICROSECOND", 499 "MICROSECS": "MICROSECOND", 500 "USECOND": "MICROSECOND", 501 "USECONDS": "MICROSECOND", 502 "MICROSECONDS": "MICROSECOND", 503 "NS": "NANOSECOND", 504 "NSEC": "NANOSECOND", 505 "NANOSEC": "NANOSECOND", 506 "NSECOND": "NANOSECOND", 507 "NSECONDS": "NANOSECOND", 508 "NANOSECS": "NANOSECOND", 509 "EPOCH_SECOND": "EPOCH", 510 "EPOCH_SECONDS": "EPOCH", 511 "EPOCH_MILLISECONDS": "EPOCH_MILLISECOND", 512 "EPOCH_MICROSECONDS": "EPOCH_MICROSECOND", 513 "EPOCH_NANOSECONDS": "EPOCH_NANOSECOND", 514 "TZH": "TIMEZONE_HOUR", 515 "TZM": "TIMEZONE_MINUTE", 516 "DEC": "DECADE", 517 "DECS": "DECADE", 518 "DECADES": "DECADE", 519 "MIL": "MILLENIUM", 520 "MILS": "MILLENIUM", 521 "MILLENIA": "MILLENIUM", 522 "C": "CENTURY", 523 "CENT": "CENTURY", 524 "CENTS": "CENTURY", 525 "CENTURIES": "CENTURY", 526 } 527 528 TYPE_TO_EXPRESSIONS: t.Dict[exp.DataType.Type, t.Set[t.Type[exp.Expression]]] = { 529 exp.DataType.Type.BIGINT: 
{ 530 exp.ApproxDistinct, 531 exp.ArraySize, 532 exp.Count, 533 exp.Length, 534 }, 535 exp.DataType.Type.BOOLEAN: { 536 exp.Between, 537 exp.Boolean, 538 exp.In, 539 exp.RegexpLike, 540 }, 541 exp.DataType.Type.DATE: { 542 exp.CurrentDate, 543 exp.Date, 544 exp.DateFromParts, 545 exp.DateStrToDate, 546 exp.DiToDate, 547 exp.StrToDate, 548 exp.TimeStrToDate, 549 exp.TsOrDsToDate, 550 }, 551 exp.DataType.Type.DATETIME: { 552 exp.CurrentDatetime, 553 exp.Datetime, 554 exp.DatetimeAdd, 555 exp.DatetimeSub, 556 }, 557 exp.DataType.Type.DOUBLE: { 558 exp.ApproxQuantile, 559 exp.Avg, 560 exp.Div, 561 exp.Exp, 562 exp.Ln, 563 exp.Log, 564 exp.Pow, 565 exp.Quantile, 566 exp.Round, 567 exp.SafeDivide, 568 exp.Sqrt, 569 exp.Stddev, 570 exp.StddevPop, 571 exp.StddevSamp, 572 exp.Variance, 573 exp.VariancePop, 574 }, 575 exp.DataType.Type.INT: { 576 exp.Ceil, 577 exp.DatetimeDiff, 578 exp.DateDiff, 579 exp.TimestampDiff, 580 exp.TimeDiff, 581 exp.DateToDi, 582 exp.Levenshtein, 583 exp.Sign, 584 exp.StrPosition, 585 exp.TsOrDiToDi, 586 }, 587 exp.DataType.Type.JSON: { 588 exp.ParseJSON, 589 }, 590 exp.DataType.Type.TIME: { 591 exp.Time, 592 }, 593 exp.DataType.Type.TIMESTAMP: { 594 exp.CurrentTime, 595 exp.CurrentTimestamp, 596 exp.StrToTime, 597 exp.TimeAdd, 598 exp.TimeStrToTime, 599 exp.TimeSub, 600 exp.TimestampAdd, 601 exp.TimestampSub, 602 exp.UnixToTime, 603 }, 604 exp.DataType.Type.TINYINT: { 605 exp.Day, 606 exp.Month, 607 exp.Week, 608 exp.Year, 609 exp.Quarter, 610 }, 611 exp.DataType.Type.VARCHAR: { 612 exp.ArrayConcat, 613 exp.Concat, 614 exp.ConcatWs, 615 exp.DateToDateStr, 616 exp.GroupConcat, 617 exp.Initcap, 618 exp.Lower, 619 exp.Substring, 620 exp.TimeToStr, 621 exp.TimeToTimeStr, 622 exp.Trim, 623 exp.TsOrDsToDateStr, 624 exp.UnixToStr, 625 exp.UnixToTimeStr, 626 exp.Upper, 627 }, 628 } 629 630 ANNOTATORS: AnnotatorsType = { 631 **{ 632 expr_type: lambda self, e: self._annotate_unary(e) 633 for expr_type in subclasses(exp.__name__, (exp.Unary, exp.Alias)) 634 }, 635 **{ 636 expr_type: lambda self, e: self._annotate_binary(e) 637 for expr_type in subclasses(exp.__name__, exp.Binary) 638 }, 639 **{ 640 expr_type: _annotate_with_type_lambda(data_type) 641 for data_type, expressions in TYPE_TO_EXPRESSIONS.items() 642 for expr_type in expressions 643 }, 644 exp.Abs: lambda self, e: self._annotate_by_args(e, "this"), 645 exp.Anonymous: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 646 exp.Array: lambda self, e: self._annotate_by_args(e, "expressions", array=True), 647 exp.ArrayAgg: lambda self, e: self._annotate_by_args(e, "this", array=True), 648 exp.ArrayConcat: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 649 exp.Bracket: lambda self, e: self._annotate_bracket(e), 650 exp.Cast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 651 exp.Case: lambda self, e: self._annotate_by_args(e, "default", "ifs"), 652 exp.Coalesce: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 653 exp.DataType: lambda self, e: self._annotate_with_type(e, e.copy()), 654 exp.DateAdd: lambda self, e: self._annotate_timeunit(e), 655 exp.DateSub: lambda self, e: self._annotate_timeunit(e), 656 exp.DateTrunc: lambda self, e: self._annotate_timeunit(e), 657 exp.Distinct: lambda self, e: self._annotate_by_args(e, "expressions"), 658 exp.Div: lambda self, e: self._annotate_div(e), 659 exp.Dot: lambda self, e: self._annotate_dot(e), 660 exp.Explode: lambda self, e: self._annotate_explode(e), 661 exp.Extract: lambda self, e: self._annotate_extract(e), 
662 exp.Filter: lambda self, e: self._annotate_by_args(e, "this"), 663 exp.GenerateDateArray: lambda self, e: self._annotate_with_type( 664 e, exp.DataType.build("ARRAY<DATE>") 665 ), 666 exp.GenerateTimestampArray: lambda self, e: self._annotate_with_type( 667 e, exp.DataType.build("ARRAY<TIMESTAMP>") 668 ), 669 exp.If: lambda self, e: self._annotate_by_args(e, "true", "false"), 670 exp.Interval: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.INTERVAL), 671 exp.Least: lambda self, e: self._annotate_by_args(e, "expressions"), 672 exp.Literal: lambda self, e: self._annotate_literal(e), 673 exp.Map: lambda self, e: self._annotate_map(e), 674 exp.Max: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 675 exp.Min: lambda self, e: self._annotate_by_args(e, "this", "expressions"), 676 exp.Null: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.NULL), 677 exp.Nullif: lambda self, e: self._annotate_by_args(e, "this", "expression"), 678 exp.PropertyEQ: lambda self, e: self._annotate_by_args(e, "expression"), 679 exp.Slice: lambda self, e: self._annotate_with_type(e, exp.DataType.Type.UNKNOWN), 680 exp.Struct: lambda self, e: self._annotate_struct(e), 681 exp.Sum: lambda self, e: self._annotate_by_args(e, "this", "expressions", promote=True), 682 exp.Timestamp: lambda self, e: self._annotate_with_type( 683 e, 684 exp.DataType.Type.TIMESTAMPTZ if e.args.get("with_tz") else exp.DataType.Type.TIMESTAMP, 685 ), 686 exp.ToMap: lambda self, e: self._annotate_to_map(e), 687 exp.TryCast: lambda self, e: self._annotate_with_type(e, e.args["to"]), 688 exp.Unnest: lambda self, e: self._annotate_unnest(e), 689 exp.VarMap: lambda self, e: self._annotate_map(e), 690 } 691 692 @classmethod 693 def get_or_raise(cls, dialect: DialectType) -> Dialect: 694 """ 695 Look up a dialect in the global dialect registry and return it if it exists. 696 697 Args: 698 dialect: The target dialect. If this is a string, it can be optionally followed by 699 additional key-value pairs that are separated by commas and are used to specify 700 dialect settings, such as whether the dialect's identifiers are case-sensitive. 701 702 Example: 703 >>> dialect = dialect_class = get_or_raise("duckdb") 704 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 705 706 Returns: 707 The corresponding Dialect instance. 708 """ 709 710 if not dialect: 711 return cls() 712 if isinstance(dialect, _Dialect): 713 return dialect() 714 if isinstance(dialect, Dialect): 715 return dialect 716 if isinstance(dialect, str): 717 try: 718 dialect_name, *kv_strings = dialect.split(",") 719 kv_pairs = (kv.split("=") for kv in kv_strings) 720 kwargs = {} 721 for pair in kv_pairs: 722 key = pair[0].strip() 723 value: t.Union[bool | str | None] = None 724 725 if len(pair) == 1: 726 # Default initialize standalone settings to True 727 value = True 728 elif len(pair) == 2: 729 value = pair[1].strip() 730 731 # Coerce the value to boolean if it matches to the truthy/falsy values below 732 value_lower = value.lower() 733 if value_lower in ("true", "1"): 734 value = True 735 elif value_lower in ("false", "0"): 736 value = False 737 738 kwargs[key] = value 739 740 except ValueError: 741 raise ValueError( 742 f"Invalid dialect format: '{dialect}'. " 743 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 
744 ) 745 746 result = cls.get(dialect_name.strip()) 747 if not result: 748 from difflib import get_close_matches 749 750 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 751 if similar: 752 similar = f" Did you mean {similar}?" 753 754 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 755 756 return result(**kwargs) 757 758 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.") 759 760 @classmethod 761 def format_time( 762 cls, expression: t.Optional[str | exp.Expression] 763 ) -> t.Optional[exp.Expression]: 764 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 765 if isinstance(expression, str): 766 return exp.Literal.string( 767 # the time formats are quoted 768 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 769 ) 770 771 if expression and expression.is_string: 772 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 773 774 return expression 775 776 def __init__(self, **kwargs) -> None: 777 normalization_strategy = kwargs.pop("normalization_strategy", None) 778 779 if normalization_strategy is None: 780 self.normalization_strategy = self.NORMALIZATION_STRATEGY 781 else: 782 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 783 784 self.settings = kwargs 785 786 def __eq__(self, other: t.Any) -> bool: 787 # Does not currently take dialect state into account 788 return type(self) == other 789 790 def __hash__(self) -> int: 791 # Does not currently take dialect state into account 792 return hash(type(self)) 793 794 def normalize_identifier(self, expression: E) -> E: 795 """ 796 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 797 798 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 799 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 800 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 801 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 802 803 There are also dialects like Spark, which are case-insensitive even when quotes are 804 present, and dialects like MySQL, whose resolution rules match those employed by the 805 underlying operating system, for example they may always be case-sensitive in Linux. 806 807 Finally, the normalization behavior of some engines can even be controlled through flags, 808 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 809 810 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 811 that it can analyze queries in the optimizer and successfully capture their semantics. 
812 """ 813 if ( 814 isinstance(expression, exp.Identifier) 815 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 816 and ( 817 not expression.quoted 818 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 819 ) 820 ): 821 expression.set( 822 "this", 823 ( 824 expression.this.upper() 825 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 826 else expression.this.lower() 827 ), 828 ) 829 830 return expression 831 832 def case_sensitive(self, text: str) -> bool: 833 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 834 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 835 return False 836 837 unsafe = ( 838 str.islower 839 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 840 else str.isupper 841 ) 842 return any(unsafe(char) for char in text) 843 844 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 845 """Checks if text can be identified given an identify option. 846 847 Args: 848 text: The text to check. 849 identify: 850 `"always"` or `True`: Always returns `True`. 851 `"safe"`: Only returns `True` if the identifier is case-insensitive. 852 853 Returns: 854 Whether the given text can be identified. 855 """ 856 if identify is True or identify == "always": 857 return True 858 859 if identify == "safe": 860 return not self.case_sensitive(text) 861 862 return False 863 864 def quote_identifier(self, expression: E, identify: bool = True) -> E: 865 """ 866 Adds quotes to a given identifier. 867 868 Args: 869 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 870 identify: If set to `False`, the quotes will only be added if the identifier is deemed 871 "unsafe", with respect to its characters and this dialect's normalization strategy. 872 """ 873 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 874 name = expression.this 875 expression.set( 876 "quoted", 877 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 878 ) 879 880 return expression 881 882 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 883 if isinstance(path, exp.Literal): 884 path_text = path.name 885 if path.is_number: 886 path_text = f"[{path_text}]" 887 try: 888 return parse_json_path(path_text, self) 889 except ParseError as e: 890 logger.warning(f"Invalid JSON path syntax. 
{str(e)}") 891 892 return path 893 894 def parse(self, sql: str, **opts) -> t.List[t.Optional[exp.Expression]]: 895 return self.parser(**opts).parse(self.tokenize(sql), sql) 896 897 def parse_into( 898 self, expression_type: exp.IntoType, sql: str, **opts 899 ) -> t.List[t.Optional[exp.Expression]]: 900 return self.parser(**opts).parse_into(expression_type, self.tokenize(sql), sql) 901 902 def generate(self, expression: exp.Expression, copy: bool = True, **opts) -> str: 903 return self.generator(**opts).generate(expression, copy=copy) 904 905 def transpile(self, sql: str, **opts) -> t.List[str]: 906 return [ 907 self.generate(expression, copy=False, **opts) if expression else "" 908 for expression in self.parse(sql) 909 ] 910 911 def tokenize(self, sql: str) -> t.List[Token]: 912 return self.tokenizer.tokenize(sql) 913 914 @property 915 def tokenizer(self) -> Tokenizer: 916 return self.tokenizer_class(dialect=self) 917 918 @property 919 def jsonpath_tokenizer(self) -> JSONPathTokenizer: 920 return self.jsonpath_tokenizer_class(dialect=self) 921 922 def parser(self, **opts) -> Parser: 923 return self.parser_class(dialect=self, **opts) 924 925 def generator(self, **opts) -> Generator: 926 return self.generator_class(dialect=self, **opts)
776 def __init__(self, **kwargs) -> None: 777 normalization_strategy = kwargs.pop("normalization_strategy", None) 778 779 if normalization_strategy is None: 780 self.normalization_strategy = self.NORMALIZATION_STRATEGY 781 else: 782 self.normalization_strategy = NormalizationStrategy(normalization_strategy.upper()) 783 784 self.settings = kwargs
First day of the week in DATE_TRUNC(week). Defaults to 0 (Monday). -1 would be Sunday.
Whether a size in the table sample clause represents percentage.
Specifies the strategy according to which identifiers should be normalized.
Determines how function names are going to be normalized. Possible values:
- "upper" or True: Convert names to uppercase.
- "lower": Convert names to lowercase.
- False: Disables function name normalization.
Whether the base comes first in the `LOG` function. Possible values: `True`, `False`, `None` (two arguments are not supported by `LOG`).
Default `NULL` ordering method to use if not explicitly set. Possible values: `"nulls_are_small"`, `"nulls_are_large"`, `"nulls_are_last"`.
Whether the behavior of `a / b` depends on the types of `a` and `b`. False means `a / b` is always float division. True means `a / b` is integer division if both `a` and `b` are integers.
A `NULL` arg in `CONCAT` yields `NULL` by default, but in some dialects it yields an empty string.
Associates this dialect's time formats with their equivalent Python `strftime` formats.
Helper which is used for parsing the special syntax `CAST(x AS DATE FORMAT 'yyyy')`. If empty, the corresponding trie will be constructed off of `TIME_MAPPING`.
Mapping of an escaped sequence (e.g. `\\n`) to its unescaped version (a literal newline character).
Columns that are auto-generated by the engine corresponding to this dialect.
For example, such columns may be excluded from `SELECT *` queries.
Some dialects, such as Snowflake, allow you to reference a CTE column alias in the HAVING clause of the CTE. This flag will cause the CTE alias columns to override any projection aliases in the subquery.
For example,
WITH y(c) AS (
    SELECT SUM(a) FROM (SELECT 1 a) AS x HAVING c > 0
) SELECT c FROM y;
will be rewritten as
WITH y(c) AS (
    SELECT SUM(a) AS c FROM (SELECT 1 AS a) AS x HAVING c > 0
) SELECT c FROM y;
Whether alias reference expansion (_expand_alias_refs()) should run before column qualification (_qualify_columns()).
For example:
WITH data AS (
    SELECT
        1 AS id,
        2 AS my_id
)
SELECT
    id AS my_id
FROM
    data
WHERE
    my_id = 1
GROUP BY
    my_id
HAVING
    my_id = 1
In most dialects "my_id" would refer to "data.my_id" (which is done in _qualify_columns()) across the query, except:
- BigQuery, which will forward the alias to the GROUP BY and HAVING clauses, i.e. it resolves to "WHERE my_id = 1 GROUP BY id HAVING id = 1"
- ClickHouse, which will forward the alias across the query, i.e. it resolves to "WHERE id = 1 GROUP BY id HAVING id = 1"
Whether alias reference expansion before qualification should only happen for the GROUP BY clause.
Whether ORDER BY ALL is supported (expands to all the selected columns) as in DuckDB, Spark3/Databricks
Whether the ARRAY constructor is context-sensitive, i.e. in Redshift ARRAY[1, 2, 3] != ARRAY(1, 2, 3), as the former is of type INT[] whereas the latter is SUPER.
Whether expressions such as x::INT[5] should be parsed as fixed-size array defs/casts e.g. in DuckDB. In dialects which don't support fixed size arrays such as Snowflake, this should be interpreted as a subscript/index operator
Helper for dialects that use a different name for the same creatable kind. For example, the Clickhouse equivalent of CREATE SCHEMA is CREATE DATABASE.
692 @classmethod 693 def get_or_raise(cls, dialect: DialectType) -> Dialect: 694 """ 695 Look up a dialect in the global dialect registry and return it if it exists. 696 697 Args: 698 dialect: The target dialect. If this is a string, it can be optionally followed by 699 additional key-value pairs that are separated by commas and are used to specify 700 dialect settings, such as whether the dialect's identifiers are case-sensitive. 701 702 Example: 703 >>> dialect = dialect_class = get_or_raise("duckdb") 704 >>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive") 705 706 Returns: 707 The corresponding Dialect instance. 708 """ 709 710 if not dialect: 711 return cls() 712 if isinstance(dialect, _Dialect): 713 return dialect() 714 if isinstance(dialect, Dialect): 715 return dialect 716 if isinstance(dialect, str): 717 try: 718 dialect_name, *kv_strings = dialect.split(",") 719 kv_pairs = (kv.split("=") for kv in kv_strings) 720 kwargs = {} 721 for pair in kv_pairs: 722 key = pair[0].strip() 723 value: t.Union[bool | str | None] = None 724 725 if len(pair) == 1: 726 # Default initialize standalone settings to True 727 value = True 728 elif len(pair) == 2: 729 value = pair[1].strip() 730 731 # Coerce the value to boolean if it matches to the truthy/falsy values below 732 value_lower = value.lower() 733 if value_lower in ("true", "1"): 734 value = True 735 elif value_lower in ("false", "0"): 736 value = False 737 738 kwargs[key] = value 739 740 except ValueError: 741 raise ValueError( 742 f"Invalid dialect format: '{dialect}'. " 743 "Please use the correct format: 'dialect [, k1 = v2 [, ...]]'." 744 ) 745 746 result = cls.get(dialect_name.strip()) 747 if not result: 748 from difflib import get_close_matches 749 750 similar = seq_get(get_close_matches(dialect_name, cls.classes, n=1), 0) or "" 751 if similar: 752 similar = f" Did you mean {similar}?" 753 754 raise ValueError(f"Unknown dialect '{dialect_name}'.{similar}") 755 756 return result(**kwargs) 757 758 raise ValueError(f"Invalid dialect type for '{dialect}': '{type(dialect)}'.")
Look up a dialect in the global dialect registry and return it if it exists.
Arguments:
- dialect: The target dialect. If this is a string, it can be optionally followed by additional key-value pairs that are separated by commas and are used to specify dialect settings, such as whether the dialect's identifiers are case-sensitive.
Example:
>>> dialect = dialect_class = get_or_raise("duckdb")
>>> dialect = get_or_raise("mysql, normalization_strategy = case_sensitive")
Returns:
The corresponding Dialect instance.
760 @classmethod 761 def format_time( 762 cls, expression: t.Optional[str | exp.Expression] 763 ) -> t.Optional[exp.Expression]: 764 """Converts a time format in this dialect to its equivalent Python `strftime` format.""" 765 if isinstance(expression, str): 766 return exp.Literal.string( 767 # the time formats are quoted 768 format_time(expression[1:-1], cls.TIME_MAPPING, cls.TIME_TRIE) 769 ) 770 771 if expression and expression.is_string: 772 return exp.Literal.string(format_time(expression.this, cls.TIME_MAPPING, cls.TIME_TRIE)) 773 774 return expression
Converts a time format in this dialect to its equivalent Python `strftime` format.
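For instance, a minimal sketch using MySQL's time mapping (which translates MySQL's `%i` minute token to Python's `%M`); note that string inputs are expected to be quoted, as the comment in the source points out:

    from sqlglot.dialects.dialect import Dialect

    fmt = Dialect.get_or_raise("mysql").format_time("'%Y-%m-%d %i'")
    print(fmt.sql())  # should print roughly '%Y-%m-%d %M'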
794 def normalize_identifier(self, expression: E) -> E: 795 """ 796 Transforms an identifier in a way that resembles how it'd be resolved by this dialect. 797 798 For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it 799 lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so 800 it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, 801 and so any normalization would be prohibited in order to avoid "breaking" the identifier. 802 803 There are also dialects like Spark, which are case-insensitive even when quotes are 804 present, and dialects like MySQL, whose resolution rules match those employed by the 805 underlying operating system, for example they may always be case-sensitive in Linux. 806 807 Finally, the normalization behavior of some engines can even be controlled through flags, 808 like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier. 809 810 SQLGlot aims to understand and handle all of these different behaviors gracefully, so 811 that it can analyze queries in the optimizer and successfully capture their semantics. 812 """ 813 if ( 814 isinstance(expression, exp.Identifier) 815 and self.normalization_strategy is not NormalizationStrategy.CASE_SENSITIVE 816 and ( 817 not expression.quoted 818 or self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE 819 ) 820 ): 821 expression.set( 822 "this", 823 ( 824 expression.this.upper() 825 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 826 else expression.this.lower() 827 ), 828 ) 829 830 return expression
Transforms an identifier in a way that resembles how it'd be resolved by this dialect.
For example, an identifier like `FoO` would be resolved as `foo` in Postgres, because it lowercases all unquoted identifiers. On the other hand, Snowflake uppercases them, so it would resolve it as `FOO`. If it was quoted, it'd need to be treated as case-sensitive, and so any normalization would be prohibited in order to avoid "breaking" the identifier.
There are also dialects like Spark, which are case-insensitive even when quotes are present, and dialects like MySQL, whose resolution rules match those employed by the underlying operating system, for example they may always be case-sensitive in Linux.
Finally, the normalization behavior of some engines can even be controlled through flags, like in Redshift's case, where users can explicitly set enable_case_sensitive_identifier.
SQLGlot aims to understand and handle all of these different behaviors gracefully, so that it can analyze queries in the optimizer and successfully capture their semantics.
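A minimal sketch of the behavior described above, using the lowercasing and uppercasing dialects as examples:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO")  # unquoted, mixed case
    print(Dialect.get_or_raise("postgres").normalize_identifier(ident.copy()).name)   # foo
    print(Dialect.get_or_raise("snowflake").normalize_identifier(ident.copy()).name)  # FOO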
832 def case_sensitive(self, text: str) -> bool: 833 """Checks if text contains any case sensitive characters, based on the dialect's rules.""" 834 if self.normalization_strategy is NormalizationStrategy.CASE_INSENSITIVE: 835 return False 836 837 unsafe = ( 838 str.islower 839 if self.normalization_strategy is NormalizationStrategy.UPPERCASE 840 else str.isupper 841 ) 842 return any(unsafe(char) for char in text)
Checks if text contains any case-sensitive characters, based on the dialect's rules.
844 def can_identify(self, text: str, identify: str | bool = "safe") -> bool: 845 """Checks if text can be identified given an identify option. 846 847 Args: 848 text: The text to check. 849 identify: 850 `"always"` or `True`: Always returns `True`. 851 `"safe"`: Only returns `True` if the identifier is case-insensitive. 852 853 Returns: 854 Whether the given text can be identified. 855 """ 856 if identify is True or identify == "always": 857 return True 858 859 if identify == "safe": 860 return not self.case_sensitive(text) 861 862 return False
Checks if text can be identified given an identify option.
Arguments:
- text: The text to check.
- identify:
  `"always"` or `True`: Always returns `True`.
  `"safe"`: Only returns `True` if the identifier is case-insensitive.
Returns:
Whether the given text can be identified.
864 def quote_identifier(self, expression: E, identify: bool = True) -> E: 865 """ 866 Adds quotes to a given identifier. 867 868 Args: 869 expression: The expression of interest. If it's not an `Identifier`, this method is a no-op. 870 identify: If set to `False`, the quotes will only be added if the identifier is deemed 871 "unsafe", with respect to its characters and this dialect's normalization strategy. 872 """ 873 if isinstance(expression, exp.Identifier) and not isinstance(expression.parent, exp.Func): 874 name = expression.this 875 expression.set( 876 "quoted", 877 identify or self.case_sensitive(name) or not exp.SAFE_IDENTIFIER_RE.match(name), 878 ) 879 880 return expression
Adds quotes to a given identifier.
Arguments:
- expression: The expression of interest. If it's not an `Identifier`, this method is a no-op.
- identify: If set to `False`, the quotes will only be added if the identifier is deemed "unsafe", with respect to its characters and this dialect's normalization strategy.
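For example, under Postgres' LOWERCASE strategy a mixed-case identifier is deemed unsafe and gets quoted even with `identify=False` (a sketch, not normative output):

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    ident = exp.to_identifier("FoO")  # matches SAFE_IDENTIFIER_RE, so initially unquoted
    quoted = Dialect.get_or_raise("postgres").quote_identifier(ident, identify=False)
    print(quoted.sql(dialect="postgres"))  # "FoO"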
882 def to_json_path(self, path: t.Optional[exp.Expression]) -> t.Optional[exp.Expression]: 883 if isinstance(path, exp.Literal): 884 path_text = path.name 885 if path.is_number: 886 path_text = f"[{path_text}]" 887 try: 888 return parse_json_path(path_text, self) 889 except ParseError as e: 890 logger.warning(f"Invalid JSON path syntax. {str(e)}") 891 892 return path
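A short sketch of the round trip: a string literal is parsed into an `exp.JSONPath`, while invalid syntax merely logs a warning and returns the original literal:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect

    path = Dialect.get_or_raise("duckdb").to_json_path(exp.Literal.string("$.a[0].b"))
    print(type(path).__name__, path.sql())  # JSONPath $.a[0].b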
942def if_sql( 943 name: str = "IF", false_value: t.Optional[exp.Expression | str] = None 944) -> t.Callable[[Generator, exp.If], str]: 945 def _if_sql(self: Generator, expression: exp.If) -> str: 946 return self.func( 947 name, 948 expression.this, 949 expression.args.get("true"), 950 expression.args.get("false") or false_value, 951 ) 952 953 return _if_sql
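Helpers like this are factories meant to be plugged into a generator's TRANSFORMS mapping. A hypothetical custom dialect (`MyDialect` is illustrative, not part of the library) that renders `exp.If` as `IIF(...)` might look like:

    from sqlglot import exp, generator
    from sqlglot.dialects.dialect import Dialect, if_sql

    class MyDialect(Dialect):
        class Generator(generator.Generator):
            TRANSFORMS = {
                **generator.Generator.TRANSFORMS,
                exp.If: if_sql(name="IIF"),
            }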
956def arrow_json_extract_sql(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 957 this = expression.this 958 if self.JSON_TYPE_REQUIRED_FOR_EXTRACTION and isinstance(this, exp.Literal) and this.is_string: 959 this.replace(exp.cast(this, exp.DataType.Type.JSON)) 960 961 return self.binary(expression, "->" if isinstance(expression, exp.JSONExtract) else "->>")
1027def str_position_sql( 1028 self: Generator, expression: exp.StrPosition, generate_instance: bool = False 1029) -> str: 1030 this = self.sql(expression, "this") 1031 substr = self.sql(expression, "substr") 1032 position = self.sql(expression, "position") 1033 instance = expression.args.get("instance") if generate_instance else None 1034 position_offset = "" 1035 1036 if position: 1037 # Normalize third 'pos' argument into 'SUBSTR(..) + offset' across dialects 1038 this = self.func("SUBSTR", this, position) 1039 position_offset = f" + {position} - 1" 1040 1041 return self.func("STRPOS", this, substr, instance) + position_offset
1050def var_map_sql( 1051 self: Generator, expression: exp.Map | exp.VarMap, map_func_name: str = "MAP" 1052) -> str: 1053 keys = expression.args["keys"] 1054 values = expression.args["values"] 1055 1056 if not isinstance(keys, exp.Array) or not isinstance(values, exp.Array): 1057 self.unsupported("Cannot convert array columns into map.") 1058 return self.func(map_func_name, keys, values) 1059 1060 args = [] 1061 for key, value in zip(keys.expressions, values.expressions): 1062 args.append(self.sql(key)) 1063 args.append(self.sql(value)) 1064 1065 return self.func(map_func_name, *args)
1068def build_formatted_time( 1069 exp_class: t.Type[E], dialect: str, default: t.Optional[bool | str] = None 1070) -> t.Callable[[t.List], E]: 1071 """Helper used for time expressions. 1072 1073 Args: 1074 exp_class: the expression class to instantiate. 1075 dialect: target sql dialect. 1076 default: the default format, True being time. 1077 1078 Returns: 1079 A callable that can be used to return the appropriately formatted time expression. 1080 """ 1081 1082 def _builder(args: t.List): 1083 return exp_class( 1084 this=seq_get(args, 0), 1085 format=Dialect[dialect].format_time( 1086 seq_get(args, 1) 1087 or (Dialect[dialect].TIME_FORMAT if default is True else default or None) 1088 ), 1089 ) 1090 1091 return _builder
Helper used for time expressions.
Arguments:
- exp_class: the expression class to instantiate.
- dialect: target sql dialect.
- default: the default format, True being time.
Returns:
A callable that can be used to return the appropriately formatted time expression.
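A sketch of how such a builder is typically used: a parser maps a function name to the builder, which converts the dialect's format string into the Python `strftime` equivalent via `format_time`:

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_formatted_time

    # e.g. mapping a MySQL-style STR_TO_DATE onto exp.StrToDate
    builder = build_formatted_time(exp.StrToDate, "mysql")
    node = builder([exp.column("dt"), exp.Literal.string("%Y-%m-%d")])
    print(node.args["format"].sql())  # '%Y-%m-%d' (already in strftime terms here)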
1094def time_format( 1095 dialect: DialectType = None, 1096) -> t.Callable[[Generator, exp.UnixToStr | exp.StrToUnix], t.Optional[str]]: 1097 def _time_format(self: Generator, expression: exp.UnixToStr | exp.StrToUnix) -> t.Optional[str]: 1098 """ 1099 Returns the time format for a given expression, unless it's equivalent 1100 to the default time format of the dialect of interest. 1101 """ 1102 time_format = self.format_time(expression) 1103 return time_format if time_format != Dialect.get_or_raise(dialect).TIME_FORMAT else None 1104 1105 return _time_format
1108def build_date_delta( 1109 exp_class: t.Type[E], 1110 unit_mapping: t.Optional[t.Dict[str, str]] = None, 1111 default_unit: t.Optional[str] = "DAY", 1112) -> t.Callable[[t.List], E]: 1113 def _builder(args: t.List) -> E: 1114 unit_based = len(args) == 3 1115 this = args[2] if unit_based else seq_get(args, 0) 1116 unit = None 1117 if unit_based or default_unit: 1118 unit = args[0] if unit_based else exp.Literal.string(default_unit) 1119 unit = exp.var(unit_mapping.get(unit.name.lower(), unit.name)) if unit_mapping else unit 1120 return exp_class(this=this, expression=seq_get(args, 1), unit=unit) 1121 1122 return _builder
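A sketch of the two argument shapes the builder accepts, using `exp.DateAdd` (the three-argument form follows the DATEADD(unit, value, this) convention):

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_date_delta

    builder = build_date_delta(exp.DateAdd)

    # Three-argument, unit-based form: DATEADD(day, 1, x)
    node = builder([exp.var("day"), exp.Literal.number(1), exp.column("x")])

    # Two-argument form falls back to the default unit ("DAY")
    node = builder([exp.column("x"), exp.Literal.number(1)])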
1125def build_date_delta_with_interval( 1126 expression_class: t.Type[E], 1127) -> t.Callable[[t.List], t.Optional[E]]: 1128 def _builder(args: t.List) -> t.Optional[E]: 1129 if len(args) < 2: 1130 return None 1131 1132 interval = args[1] 1133 1134 if not isinstance(interval, exp.Interval): 1135 raise ParseError(f"INTERVAL expression expected but got '{interval}'") 1136 1137 expression = interval.this 1138 if expression and expression.is_string: 1139 expression = exp.Literal.number(expression.this) 1140 1141 return expression_class(this=args[0], expression=expression, unit=unit_to_str(interval)) 1142 1143 return _builder
1146def date_trunc_to_time(args: t.List) -> exp.DateTrunc | exp.TimestampTrunc: 1147 unit = seq_get(args, 0) 1148 this = seq_get(args, 1) 1149 1150 if isinstance(this, exp.Cast) and this.is_type("date"): 1151 return exp.DateTrunc(unit=unit, this=this) 1152 return exp.TimestampTrunc(this=this, unit=unit)
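In other words, the builder keeps date-typed truncation as `exp.DateTrunc` and defaults everything else to `exp.TimestampTrunc`; a minimal check:

    from sqlglot import exp
    from sqlglot.dialects.dialect import date_trunc_to_time

    args = [exp.Literal.string("day"), exp.cast(exp.column("x"), "date")]
    assert isinstance(date_trunc_to_time(args), exp.DateTrunc)

    args = [exp.Literal.string("day"), exp.column("x")]
    assert isinstance(date_trunc_to_time(args), exp.TimestampTrunc)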
1155def date_add_interval_sql( 1156 data_type: str, kind: str 1157) -> t.Callable[[Generator, exp.Expression], str]: 1158 def func(self: Generator, expression: exp.Expression) -> str: 1159 this = self.sql(expression, "this") 1160 interval = exp.Interval(this=expression.expression, unit=unit_to_var(expression)) 1161 return f"{data_type}_{kind}({this}, {self.sql(interval)})" 1162 1163 return func
1166def timestamptrunc_sql(zone: bool = False) -> t.Callable[[Generator, exp.TimestampTrunc], str]: 1167 def _timestamptrunc_sql(self: Generator, expression: exp.TimestampTrunc) -> str: 1168 args = [unit_to_str(expression), expression.this] 1169 if zone: 1170 args.append(expression.args.get("zone")) 1171 return self.func("DATE_TRUNC", *args) 1172 1173 return _timestamptrunc_sql
1176def no_timestamp_sql(self: Generator, expression: exp.Timestamp) -> str: 1177 zone = expression.args.get("zone") 1178 if not zone: 1179 from sqlglot.optimizer.annotate_types import annotate_types 1180 1181 target_type = annotate_types(expression).type or exp.DataType.Type.TIMESTAMP 1182 return self.sql(exp.cast(expression.this, target_type)) 1183 if zone.name.lower() in TIMEZONES: 1184 return self.sql( 1185 exp.AtTimeZone( 1186 this=exp.cast(expression.this, exp.DataType.Type.TIMESTAMP), 1187 zone=zone, 1188 ) 1189 ) 1190 return self.func("TIMESTAMP", expression.this, zone)
1193def no_time_sql(self: Generator, expression: exp.Time) -> str: 1194 # Transpile BQ's TIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIME) 1195 this = exp.cast(expression.this, exp.DataType.Type.TIMESTAMPTZ) 1196 expr = exp.cast( 1197 exp.AtTimeZone(this=this, zone=expression.args.get("zone")), exp.DataType.Type.TIME 1198 ) 1199 return self.sql(expr)
1202def no_datetime_sql(self: Generator, expression: exp.Datetime) -> str: 1203 this = expression.this 1204 expr = expression.expression 1205 1206 if expr.name.lower() in TIMEZONES: 1207 # Transpile BQ's DATETIME(timestamp, zone) to CAST(TIMESTAMPTZ <timestamp> AT TIME ZONE <zone> AS TIMESTAMP) 1208 this = exp.cast(this, exp.DataType.Type.TIMESTAMPTZ) 1209 this = exp.cast(exp.AtTimeZone(this=this, zone=expr), exp.DataType.Type.TIMESTAMP) 1210 return self.sql(this) 1211 1212 this = exp.cast(this, exp.DataType.Type.DATE) 1213 expr = exp.cast(expr, exp.DataType.Type.TIME) 1214 1215 return self.sql(exp.cast(exp.Add(this=this, expression=expr), exp.DataType.Type.TIMESTAMP))
1247def timestrtotime_sql( 1248 self: Generator, 1249 expression: exp.TimeStrToTime, 1250 include_precision: bool = False, 1251) -> str: 1252 datatype = exp.DataType.build( 1253 exp.DataType.Type.TIMESTAMPTZ 1254 if expression.args.get("zone") 1255 else exp.DataType.Type.TIMESTAMP 1256 ) 1257 1258 if isinstance(expression.this, exp.Literal) and include_precision: 1259 precision = subsecond_precision(expression.this.name) 1260 if precision > 0: 1261 datatype = exp.DataType.build( 1262 datatype.this, expressions=[exp.DataTypeParam(this=exp.Literal.number(precision))] 1263 ) 1264 1265 return self.sql(exp.cast(expression.this, datatype, dialect=self.dialect))
1273def encode_decode_sql( 1274 self: Generator, expression: exp.Expression, name: str, replace: bool = True 1275) -> str: 1276 charset = expression.args.get("charset") 1277 if charset and charset.name.lower() != "utf-8": 1278 self.unsupported(f"Expected utf-8 character set, got {charset}.") 1279 1280 return self.func(name, expression.this, expression.args.get("replace") if replace else None)
1293def count_if_to_sum(self: Generator, expression: exp.CountIf) -> str: 1294 cond = expression.this 1295 1296 if isinstance(expression.this, exp.Distinct): 1297 cond = expression.this.expressions[0] 1298 self.unsupported("DISTINCT is not supported when converting COUNT_IF to SUM") 1299 1300 return self.func("sum", exp.func("if", cond, 1, 0))
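A sketch of the rewrite for dialects lacking COUNT_IF, using the default dialect's parser and generator (output shown as expected, not guaranteed verbatim):

    from sqlglot import parse_one
    from sqlglot.dialects.dialect import Dialect, count_if_to_sum

    gen = Dialect().generator()
    print(count_if_to_sum(gen, parse_one("COUNT_IF(x > 0)")))  # roughly SUM(IF(x > 0, 1, 0))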
1303def trim_sql(self: Generator, expression: exp.Trim) -> str: 1304 target = self.sql(expression, "this") 1305 trim_type = self.sql(expression, "position") 1306 remove_chars = self.sql(expression, "expression") 1307 collation = self.sql(expression, "collation") 1308 1309 # Use TRIM/LTRIM/RTRIM syntax if the expression isn't database-specific 1310 if not remove_chars: 1311 return self.trim_sql(expression) 1312 1313 trim_type = f"{trim_type} " if trim_type else "" 1314 remove_chars = f"{remove_chars} " if remove_chars else "" 1315 from_part = "FROM " if trim_type or remove_chars else "" 1316 collation = f" COLLATE {collation}" if collation else "" 1317 return f"TRIM({trim_type}{remove_chars}{from_part}{target}{collation})"
1338def regexp_extract_sql(self: Generator, expression: exp.RegexpExtract) -> str: 1339 bad_args = list(filter(expression.args.get, ("position", "occurrence", "parameters"))) 1340 if bad_args: 1341 self.unsupported(f"REGEXP_EXTRACT does not support the following arg(s): {bad_args}") 1342 1343 return self.func( 1344 "REGEXP_EXTRACT", expression.this, expression.expression, expression.args.get("group") 1345 )
1348def regexp_replace_sql(self: Generator, expression: exp.RegexpReplace) -> str: 1349 bad_args = list(filter(expression.args.get, ("position", "occurrence", "modifiers"))) 1350 if bad_args: 1351 self.unsupported(f"REGEXP_REPLACE does not support the following arg(s): {bad_args}") 1352 1353 return self.func( 1354 "REGEXP_REPLACE", expression.this, expression.expression, expression.args["replacement"] 1355 )
1358def pivot_column_names(aggregations: t.List[exp.Expression], dialect: DialectType) -> t.List[str]: 1359 names = [] 1360 for agg in aggregations: 1361 if isinstance(agg, exp.Alias): 1362 names.append(agg.alias) 1363 else: 1364 """ 1365 This case corresponds to aggregations without aliases being used as suffixes 1366 (e.g. col_avg(foo)). We need to unquote identifiers because they're going to 1367 be quoted in the base parser's `_parse_pivot` method, due to `to_identifier`. 1368 Otherwise, we'd end up with `col_avg(`foo`)` (notice the double quotes). 1369 """ 1370 agg_all_unquoted = agg.transform( 1371 lambda node: ( 1372 exp.Identifier(this=node.name, quoted=False) 1373 if isinstance(node, exp.Identifier) 1374 else node 1375 ) 1376 ) 1377 names.append(agg_all_unquoted.sql(dialect=dialect, normalize_functions="lower")) 1378 1379 return names
1419def arg_max_or_min_no_count(name: str) -> t.Callable[[Generator, exp.ArgMax | exp.ArgMin], str]: 1420 def _arg_max_or_min_sql(self: Generator, expression: exp.ArgMax | exp.ArgMin) -> str: 1421 if expression.args.get("count"): 1422 self.unsupported(f"Only two arguments are supported in function {name}.") 1423 1424 return self.func(name, expression.this, expression.expression) 1425 1426 return _arg_max_or_min_sql
1429def ts_or_ds_add_cast(expression: exp.TsOrDsAdd) -> exp.TsOrDsAdd: 1430 this = expression.this.copy() 1431 1432 return_type = expression.return_type 1433 if return_type.is_type(exp.DataType.Type.DATE): 1434 # If we need to cast to a DATE, we cast to TIMESTAMP first to make sure we 1435 # can truncate timestamp strings, because some dialects can't cast them to DATE 1436 this = exp.cast(this, exp.DataType.Type.TIMESTAMP) 1437 1438 expression.this.replace(exp.cast(this, return_type)) 1439 return expression
1442def date_delta_sql(name: str, cast: bool = False) -> t.Callable[[Generator, DATE_ADD_OR_DIFF], str]: 1443 def _delta_sql(self: Generator, expression: DATE_ADD_OR_DIFF) -> str: 1444 if cast and isinstance(expression, exp.TsOrDsAdd): 1445 expression = ts_or_ds_add_cast(expression) 1446 1447 return self.func( 1448 name, 1449 unit_to_var(expression), 1450 expression.expression, 1451 expression.this, 1452 ) 1453 1454 return _delta_sql
1457def unit_to_str(expression: exp.Expression, default: str = "DAY") -> t.Optional[exp.Expression]: 1458 unit = expression.args.get("unit") 1459 1460 if isinstance(unit, exp.Placeholder): 1461 return unit 1462 if unit: 1463 return exp.Literal.string(unit.name) 1464 return exp.Literal.string(default) if default else None
1494def no_last_day_sql(self: Generator, expression: exp.LastDay) -> str: 1495 trunc_curr_date = exp.func("date_trunc", "month", expression.this) 1496 plus_one_month = exp.func("date_add", trunc_curr_date, 1, "month") 1497 minus_one_day = exp.func("date_sub", plus_one_month, 1, "day") 1498 1499 return self.sql(exp.cast(minus_one_day, exp.DataType.Type.DATE))
1502def merge_without_target_sql(self: Generator, expression: exp.Merge) -> str: 1503 """Remove table refs from columns in when statements.""" 1504 alias = expression.this.args.get("alias") 1505 1506 def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]: 1507 return self.dialect.normalize_identifier(identifier).name if identifier else None 1508 1509 targets = {normalize(expression.this.this)} 1510 1511 if alias: 1512 targets.add(normalize(alias.this)) 1513 1514 for when in expression.expressions: 1515 # only remove the target names from the THEN clause 1516 # theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED 1517 # ref: https://github.com/TobikoData/sqlmesh/issues/2934 1518 then = when.args.get("then") 1519 if then: 1520 then.transform( 1521 lambda node: ( 1522 exp.column(node.this) 1523 if isinstance(node, exp.Column) and normalize(node.args.get("table")) in targets 1524 else node 1525 ), 1526 copy=False, 1527 ) 1528 1529 return self.merge_sql(expression)
Remove table refs from columns in when statements.
1532def build_json_extract_path( 1533 expr_type: t.Type[F], zero_based_indexing: bool = True, arrow_req_json_type: bool = False 1534) -> t.Callable[[t.List], F]: 1535 def _builder(args: t.List) -> F: 1536 segments: t.List[exp.JSONPathPart] = [exp.JSONPathRoot()] 1537 for arg in args[1:]: 1538 if not isinstance(arg, exp.Literal): 1539 # We use the fallback parser because we can't really transpile non-literals safely 1540 return expr_type.from_arg_list(args) 1541 1542 text = arg.name 1543 if is_int(text): 1544 index = int(text) 1545 segments.append( 1546 exp.JSONPathSubscript(this=index if zero_based_indexing else index - 1) 1547 ) 1548 else: 1549 segments.append(exp.JSONPathKey(this=text)) 1550 1551 # This is done to avoid failing in the expression validator due to the arg count 1552 del args[2:] 1553 return expr_type( 1554 this=seq_get(args, 0), 1555 expression=exp.JSONPath(expressions=segments), 1556 only_json_types=arrow_req_json_type, 1557 ) 1558 1559 return _builder
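A sketch of the builder turning trailing literal arguments into a structured JSON path (string literals become keys, integer literals become subscripts):

    from sqlglot import exp
    from sqlglot.dialects.dialect import build_json_extract_path

    builder = build_json_extract_path(exp.JSONExtract)
    node = builder([exp.column("doc"), exp.Literal.string("a"), exp.Literal.number(0)])
    print(node.expression.sql())  # roughly $.a[0]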
1562def json_extract_segments( 1563 name: str, quoted_index: bool = True, op: t.Optional[str] = None 1564) -> t.Callable[[Generator, JSON_EXTRACT_TYPE], str]: 1565 def _json_extract_segments(self: Generator, expression: JSON_EXTRACT_TYPE) -> str: 1566 path = expression.expression 1567 if not isinstance(path, exp.JSONPath): 1568 return rename_func(name)(self, expression) 1569 1570 segments = [] 1571 for segment in path.expressions: 1572 path = self.sql(segment) 1573 if path: 1574 if isinstance(segment, exp.JSONPathPart) and ( 1575 quoted_index or not isinstance(segment, exp.JSONPathSubscript) 1576 ): 1577 path = f"{self.dialect.QUOTE_START}{path}{self.dialect.QUOTE_END}" 1578 1579 segments.append(path) 1580 1581 if op: 1582 return f" {op} ".join([self.sql(expression.this), *segments]) 1583 return self.func(name, expression.this, *segments) 1584 1585 return _json_extract_segments
1595def filter_array_using_unnest(self: Generator, expression: exp.ArrayFilter) -> str: 1596 cond = expression.expression 1597 if isinstance(cond, exp.Lambda) and len(cond.expressions) == 1: 1598 alias = cond.expressions[0] 1599 cond = cond.this 1600 elif isinstance(cond, exp.Predicate): 1601 alias = "_u" 1602 else: 1603 self.unsupported("Unsupported filter condition") 1604 return "" 1605 1606 unnest = exp.Unnest(expressions=[expression.this]) 1607 filtered = exp.select(alias).from_(exp.alias_(unnest, None, table=[alias])).where(cond) 1608 return self.sql(exp.Array(expressions=[filtered]))
1620def build_default_decimal_type( 1621 precision: t.Optional[int] = None, scale: t.Optional[int] = None 1622) -> t.Callable[[exp.DataType], exp.DataType]: 1623 def _builder(dtype: exp.DataType) -> exp.DataType: 1624 if dtype.expressions or precision is None: 1625 return dtype 1626 1627 params = f"{precision}{f', {scale}' if scale is not None else ''}" 1628 return exp.DataType.build(f"DECIMAL({params})") 1629 1630 return _builder
1633def build_timestamp_from_parts(args: t.List) -> exp.Func: 1634 if len(args) == 2: 1635 # Other dialects don't have the TIMESTAMP_FROM_PARTS(date, time) concept, 1636 # so we parse this into Anonymous for now instead of introducing complexity 1637 return exp.Anonymous(this="TIMESTAMP_FROM_PARTS", expressions=args) 1638 1639 return exp.TimestampFromParts.from_arg_list(args)
1646def sequence_sql(self: Generator, expression: exp.GenerateSeries | exp.GenerateDateArray) -> str: 1647 start = expression.args.get("start") 1648 end = expression.args.get("end") 1649 step = expression.args.get("step") 1650 1651 if isinstance(start, exp.Cast): 1652 target_type = start.to 1653 elif isinstance(end, exp.Cast): 1654 target_type = end.to 1655 else: 1656 target_type = None 1657 1658 if start and end and target_type and target_type.is_type("date", "timestamp"): 1659 if isinstance(start, exp.Cast) and target_type is start.to: 1660 end = exp.cast(end, target_type) 1661 else: 1662 start = exp.cast(start, target_type) 1663 1664 return self.func("SEQUENCE", start, end, step)
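A sketch of the cast propagation using Presto's generator: when one endpoint carries a DATE/TIMESTAMP cast, the other endpoint is cast to the same type so SEQUENCE sees consistent argument types:

    from sqlglot import exp
    from sqlglot.dialects.dialect import Dialect, sequence_sql

    gen = Dialect.get_or_raise("presto").generator()
    series = exp.GenerateSeries(
        start=exp.cast(exp.Literal.string("2024-01-01"), "date"),
        end=exp.Literal.string("2024-01-31"),
    )
    print(sequence_sql(gen, series))
    # roughly SEQUENCE(CAST('2024-01-01' AS DATE), CAST('2024-01-31' AS DATE))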