first comit
This commit is contained in:
+5
@@ -0,0 +1,5 @@
|
||||
from .array import * # NOQA
|
||||
from .citext import * # NOQA
|
||||
from .hstore import * # NOQA
|
||||
from .jsonb import * # NOQA
|
||||
from .ranges import * # NOQA
|
||||
BIN
Binary file not shown.
Executable
BIN
Binary file not shown.
Executable
BIN
Binary file not shown.
Executable
BIN
Binary file not shown.
Executable
BIN
Binary file not shown.
Executable
BIN
Binary file not shown.
Executable
BIN
Binary file not shown.
@@ -0,0 +1,365 @@
|
||||
import json
|
||||
|
||||
from django.contrib.postgres import lookups
|
||||
from django.contrib.postgres.forms import SimpleArrayField
|
||||
from django.contrib.postgres.validators import ArrayMaxLengthValidator
|
||||
from django.core import checks, exceptions
|
||||
from django.db.models import Field, Func, IntegerField, Transform, Value
|
||||
from django.db.models.fields.mixins import CheckFieldDefaultMixin
|
||||
from django.db.models.lookups import Exact, In
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from ..utils import prefix_validation_error
|
||||
from .utils import AttributeSetter
|
||||
|
||||
__all__ = ["ArrayField"]
|
||||
|
||||
|
||||
class ArrayField(CheckFieldDefaultMixin, Field):
|
||||
empty_strings_allowed = False
|
||||
default_error_messages = {
|
||||
"item_invalid": _("Item %(nth)s in the array did not validate:"),
|
||||
"nested_array_mismatch": _("Nested arrays must have the same length."),
|
||||
}
|
||||
_default_hint = ("list", "[]")
|
||||
|
||||
def __init__(self, base_field, size=None, **kwargs):
|
||||
self.base_field = base_field
|
||||
self.db_collation = getattr(self.base_field, "db_collation", None)
|
||||
self.size = size
|
||||
if self.size:
|
||||
self.default_validators = [
|
||||
*self.default_validators,
|
||||
ArrayMaxLengthValidator(self.size),
|
||||
]
|
||||
# For performance, only add a from_db_value() method if the base field
|
||||
# implements it.
|
||||
if hasattr(self.base_field, "from_db_value"):
|
||||
self.from_db_value = self._from_db_value
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
try:
|
||||
return self.__dict__["model"]
|
||||
except KeyError:
|
||||
raise AttributeError(
|
||||
"'%s' object has no attribute 'model'" % self.__class__.__name__
|
||||
)
|
||||
|
||||
@model.setter
|
||||
def model(self, model):
|
||||
self.__dict__["model"] = model
|
||||
self.base_field.model = model
|
||||
|
||||
@classmethod
|
||||
def _choices_is_value(cls, value):
|
||||
return isinstance(value, (list, tuple)) or super()._choices_is_value(value)
|
||||
|
||||
def check(self, **kwargs):
|
||||
errors = super().check(**kwargs)
|
||||
if self.base_field.remote_field:
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"Base field for array cannot be a related field.",
|
||||
obj=self,
|
||||
id="postgres.E002",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Remove the field name checks as they are not needed here.
|
||||
base_checks = self.base_field.check()
|
||||
if base_checks:
|
||||
error_messages = "\n ".join(
|
||||
"%s (%s)" % (base_check.msg, base_check.id)
|
||||
for base_check in base_checks
|
||||
if isinstance(base_check, checks.Error)
|
||||
)
|
||||
if error_messages:
|
||||
errors.append(
|
||||
checks.Error(
|
||||
"Base field for array has errors:\n %s" % error_messages,
|
||||
obj=self,
|
||||
id="postgres.E001",
|
||||
)
|
||||
)
|
||||
warning_messages = "\n ".join(
|
||||
"%s (%s)" % (base_check.msg, base_check.id)
|
||||
for base_check in base_checks
|
||||
if isinstance(base_check, checks.Warning)
|
||||
)
|
||||
if warning_messages:
|
||||
errors.append(
|
||||
checks.Warning(
|
||||
"Base field for array has warnings:\n %s"
|
||||
% warning_messages,
|
||||
obj=self,
|
||||
id="postgres.W004",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
def set_attributes_from_name(self, name):
|
||||
super().set_attributes_from_name(name)
|
||||
self.base_field.set_attributes_from_name(name)
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Array of %s" % self.base_field.description
|
||||
|
||||
def db_type(self, connection):
|
||||
size = self.size or ""
|
||||
return "%s[%s]" % (self.base_field.db_type(connection), size)
|
||||
|
||||
def cast_db_type(self, connection):
|
||||
size = self.size or ""
|
||||
return "%s[%s]" % (self.base_field.cast_db_type(connection), size)
|
||||
|
||||
def db_parameters(self, connection):
|
||||
db_params = super().db_parameters(connection)
|
||||
db_params["collation"] = self.db_collation
|
||||
return db_params
|
||||
|
||||
def get_placeholder(self, value, compiler, connection):
|
||||
return "%s::{}".format(self.db_type(connection))
|
||||
|
||||
def get_db_prep_value(self, value, connection, prepared=False):
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [
|
||||
self.base_field.get_db_prep_value(i, connection, prepared=False)
|
||||
for i in value
|
||||
]
|
||||
return value
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if path == "django.contrib.postgres.fields.array.ArrayField":
|
||||
path = "django.contrib.postgres.fields.ArrayField"
|
||||
kwargs.update(
|
||||
{
|
||||
"base_field": self.base_field.clone(),
|
||||
"size": self.size,
|
||||
}
|
||||
)
|
||||
return name, path, args, kwargs
|
||||
|
||||
def to_python(self, value):
|
||||
if isinstance(value, str):
|
||||
# Assume we're deserializing
|
||||
vals = json.loads(value)
|
||||
value = [self.base_field.to_python(val) for val in vals]
|
||||
return value
|
||||
|
||||
def _from_db_value(self, value, expression, connection):
|
||||
if value is None:
|
||||
return value
|
||||
return [
|
||||
self.base_field.from_db_value(item, expression, connection)
|
||||
for item in value
|
||||
]
|
||||
|
||||
def value_to_string(self, obj):
|
||||
values = []
|
||||
vals = self.value_from_object(obj)
|
||||
base_field = self.base_field
|
||||
|
||||
for val in vals:
|
||||
if val is None:
|
||||
values.append(None)
|
||||
else:
|
||||
obj = AttributeSetter(base_field.attname, val)
|
||||
values.append(base_field.value_to_string(obj))
|
||||
return json.dumps(values)
|
||||
|
||||
def get_transform(self, name):
|
||||
transform = super().get_transform(name)
|
||||
if transform:
|
||||
return transform
|
||||
if "_" not in name:
|
||||
try:
|
||||
index = int(name)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
index += 1 # postgres uses 1-indexing
|
||||
return IndexTransformFactory(index, self.base_field)
|
||||
try:
|
||||
start, end = name.split("_")
|
||||
start = int(start) + 1
|
||||
end = int(end) # don't add one here because postgres slices are weird
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
return SliceTransformFactory(start, end)
|
||||
|
||||
def validate(self, value, model_instance):
|
||||
super().validate(value, model_instance)
|
||||
for index, part in enumerate(value):
|
||||
try:
|
||||
self.base_field.validate(part, model_instance)
|
||||
except exceptions.ValidationError as error:
|
||||
raise prefix_validation_error(
|
||||
error,
|
||||
prefix=self.error_messages["item_invalid"],
|
||||
code="item_invalid",
|
||||
params={"nth": index + 1},
|
||||
)
|
||||
if isinstance(self.base_field, ArrayField):
|
||||
if len({len(i) for i in value}) > 1:
|
||||
raise exceptions.ValidationError(
|
||||
self.error_messages["nested_array_mismatch"],
|
||||
code="nested_array_mismatch",
|
||||
)
|
||||
|
||||
def run_validators(self, value):
|
||||
super().run_validators(value)
|
||||
for index, part in enumerate(value):
|
||||
try:
|
||||
self.base_field.run_validators(part)
|
||||
except exceptions.ValidationError as error:
|
||||
raise prefix_validation_error(
|
||||
error,
|
||||
prefix=self.error_messages["item_invalid"],
|
||||
code="item_invalid",
|
||||
params={"nth": index + 1},
|
||||
)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": SimpleArrayField,
|
||||
"base_field": self.base_field.formfield(),
|
||||
"max_length": self.size,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class ArrayRHSMixin:
|
||||
def __init__(self, lhs, rhs):
|
||||
# Don't wrap arrays that contains only None values, psycopg doesn't
|
||||
# allow this.
|
||||
if isinstance(rhs, (tuple, list)) and any(self._rhs_not_none_values(rhs)):
|
||||
expressions = []
|
||||
for value in rhs:
|
||||
if not hasattr(value, "resolve_expression"):
|
||||
field = lhs.output_field
|
||||
value = Value(field.base_field.get_prep_value(value))
|
||||
expressions.append(value)
|
||||
rhs = Func(
|
||||
*expressions,
|
||||
function="ARRAY",
|
||||
template="%(function)s[%(expressions)s]",
|
||||
)
|
||||
super().__init__(lhs, rhs)
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
cast_type = self.lhs.output_field.cast_db_type(connection)
|
||||
return "%s::%s" % (rhs, cast_type), rhs_params
|
||||
|
||||
def _rhs_not_none_values(self, rhs):
|
||||
for x in rhs:
|
||||
if isinstance(x, (list, tuple)):
|
||||
yield from self._rhs_not_none_values(x)
|
||||
elif x is not None:
|
||||
yield True
|
||||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayContains(ArrayRHSMixin, lookups.DataContains):
|
||||
pass
|
||||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy):
|
||||
pass
|
||||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayExact(ArrayRHSMixin, Exact):
|
||||
pass
|
||||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayOverlap(ArrayRHSMixin, lookups.Overlap):
|
||||
pass
|
||||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayLenTransform(Transform):
|
||||
lookup_name = "len"
|
||||
output_field = IntegerField()
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
# Distinguish NULL and empty arrays
|
||||
return (
|
||||
"CASE WHEN %(lhs)s IS NULL THEN NULL ELSE "
|
||||
"coalesce(array_length(%(lhs)s, 1), 0) END"
|
||||
) % {"lhs": lhs}, params * 2
|
||||
|
||||
|
||||
@ArrayField.register_lookup
|
||||
class ArrayInLookup(In):
|
||||
def get_prep_lookup(self):
|
||||
values = super().get_prep_lookup()
|
||||
if hasattr(values, "resolve_expression"):
|
||||
return values
|
||||
# In.process_rhs() expects values to be hashable, so convert lists
|
||||
# to tuples.
|
||||
prepared_values = []
|
||||
for value in values:
|
||||
if hasattr(value, "resolve_expression"):
|
||||
prepared_values.append(value)
|
||||
else:
|
||||
prepared_values.append(tuple(value))
|
||||
return prepared_values
|
||||
|
||||
|
||||
class IndexTransform(Transform):
|
||||
def __init__(self, index, base_field, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.index = index
|
||||
self.base_field = base_field
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
if not lhs.endswith("]"):
|
||||
lhs = "(%s)" % lhs
|
||||
return "%s[%%s]" % lhs, (*params, self.index)
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
return self.base_field
|
||||
|
||||
|
||||
class IndexTransformFactory:
|
||||
def __init__(self, index, base_field):
|
||||
self.index = index
|
||||
self.base_field = base_field
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return IndexTransform(self.index, self.base_field, *args, **kwargs)
|
||||
|
||||
|
||||
class SliceTransform(Transform):
|
||||
def __init__(self, start, end, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.start = start
|
||||
self.end = end
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
if not lhs.endswith("]"):
|
||||
lhs = "(%s)" % lhs
|
||||
return "%s[%%s:%%s]" % lhs, (*params, self.start, self.end)
|
||||
|
||||
|
||||
class SliceTransformFactory:
|
||||
def __init__(self, start, end):
|
||||
self.start = start
|
||||
self.end = end
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return SliceTransform(self.start, self.end, *args, **kwargs)
|
||||
@@ -0,0 +1,78 @@
|
||||
import warnings
|
||||
|
||||
from django.db.models import CharField, EmailField, TextField
|
||||
from django.test.utils import ignore_warnings
|
||||
from django.utils.deprecation import RemovedInDjango51Warning
|
||||
|
||||
__all__ = ["CICharField", "CIEmailField", "CIText", "CITextField"]
|
||||
|
||||
|
||||
# RemovedInDjango51Warning.
|
||||
class CIText:
|
||||
def __init__(self, *args, **kwargs):
|
||||
warnings.warn(
|
||||
"django.contrib.postgres.fields.CIText mixin is deprecated.",
|
||||
RemovedInDjango51Warning,
|
||||
stacklevel=2,
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_internal_type(self):
|
||||
return "CI" + super().get_internal_type()
|
||||
|
||||
def db_type(self, connection):
|
||||
return "citext"
|
||||
|
||||
|
||||
class CICharField(CIText, CharField):
|
||||
system_check_deprecated_details = {
|
||||
"msg": (
|
||||
"django.contrib.postgres.fields.CICharField is deprecated. Support for it "
|
||||
"(except in historical migrations) will be removed in Django 5.1."
|
||||
),
|
||||
"hint": (
|
||||
'Use CharField(db_collation="…") with a case-insensitive non-deterministic '
|
||||
"collation instead."
|
||||
),
|
||||
"id": "fields.W905",
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
with ignore_warnings(category=RemovedInDjango51Warning):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class CIEmailField(CIText, EmailField):
|
||||
system_check_deprecated_details = {
|
||||
"msg": (
|
||||
"django.contrib.postgres.fields.CIEmailField is deprecated. Support for it "
|
||||
"(except in historical migrations) will be removed in Django 5.1."
|
||||
),
|
||||
"hint": (
|
||||
'Use EmailField(db_collation="…") with a case-insensitive '
|
||||
"non-deterministic collation instead."
|
||||
),
|
||||
"id": "fields.W906",
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
with ignore_warnings(category=RemovedInDjango51Warning):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class CITextField(CIText, TextField):
|
||||
system_check_deprecated_details = {
|
||||
"msg": (
|
||||
"django.contrib.postgres.fields.CITextField is deprecated. Support for it "
|
||||
"(except in historical migrations) will be removed in Django 5.1."
|
||||
),
|
||||
"hint": (
|
||||
'Use TextField(db_collation="…") with a case-insensitive non-deterministic '
|
||||
"collation instead."
|
||||
),
|
||||
"id": "fields.W907",
|
||||
}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
with ignore_warnings(category=RemovedInDjango51Warning):
|
||||
super().__init__(*args, **kwargs)
|
||||
+112
@@ -0,0 +1,112 @@
|
||||
import json
|
||||
|
||||
from django.contrib.postgres import forms, lookups
|
||||
from django.contrib.postgres.fields.array import ArrayField
|
||||
from django.core import exceptions
|
||||
from django.db.models import Field, TextField, Transform
|
||||
from django.db.models.fields.mixins import CheckFieldDefaultMixin
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
__all__ = ["HStoreField"]
|
||||
|
||||
|
||||
class HStoreField(CheckFieldDefaultMixin, Field):
|
||||
empty_strings_allowed = False
|
||||
description = _("Map of strings to strings/nulls")
|
||||
default_error_messages = {
|
||||
"not_a_string": _("The value of “%(key)s” is not a string or null."),
|
||||
}
|
||||
_default_hint = ("dict", "{}")
|
||||
|
||||
def db_type(self, connection):
|
||||
return "hstore"
|
||||
|
||||
def get_transform(self, name):
|
||||
transform = super().get_transform(name)
|
||||
if transform:
|
||||
return transform
|
||||
return KeyTransformFactory(name)
|
||||
|
||||
def validate(self, value, model_instance):
|
||||
super().validate(value, model_instance)
|
||||
for key, val in value.items():
|
||||
if not isinstance(val, str) and val is not None:
|
||||
raise exceptions.ValidationError(
|
||||
self.error_messages["not_a_string"],
|
||||
code="not_a_string",
|
||||
params={"key": key},
|
||||
)
|
||||
|
||||
def to_python(self, value):
|
||||
if isinstance(value, str):
|
||||
value = json.loads(value)
|
||||
return value
|
||||
|
||||
def value_to_string(self, obj):
|
||||
return json.dumps(self.value_from_object(obj))
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
return super().formfield(
|
||||
**{
|
||||
"form_class": forms.HStoreField,
|
||||
**kwargs,
|
||||
}
|
||||
)
|
||||
|
||||
def get_prep_value(self, value):
|
||||
value = super().get_prep_value(value)
|
||||
|
||||
if isinstance(value, dict):
|
||||
prep_value = {}
|
||||
for key, val in value.items():
|
||||
key = str(key)
|
||||
if val is not None:
|
||||
val = str(val)
|
||||
prep_value[key] = val
|
||||
value = prep_value
|
||||
|
||||
if isinstance(value, list):
|
||||
value = [str(item) for item in value]
|
||||
|
||||
return value
|
||||
|
||||
|
||||
HStoreField.register_lookup(lookups.DataContains)
|
||||
HStoreField.register_lookup(lookups.ContainedBy)
|
||||
HStoreField.register_lookup(lookups.HasKey)
|
||||
HStoreField.register_lookup(lookups.HasKeys)
|
||||
HStoreField.register_lookup(lookups.HasAnyKeys)
|
||||
|
||||
|
||||
class KeyTransform(Transform):
|
||||
output_field = TextField()
|
||||
|
||||
def __init__(self, key_name, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.key_name = key_name
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
lhs, params = compiler.compile(self.lhs)
|
||||
return "(%s -> %%s)" % lhs, tuple(params) + (self.key_name,)
|
||||
|
||||
|
||||
class KeyTransformFactory:
|
||||
def __init__(self, key_name):
|
||||
self.key_name = key_name
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return KeyTransform(self.key_name, *args, **kwargs)
|
||||
|
||||
|
||||
@HStoreField.register_lookup
|
||||
class KeysTransform(Transform):
|
||||
lookup_name = "keys"
|
||||
function = "akeys"
|
||||
output_field = ArrayField(TextField())
|
||||
|
||||
|
||||
@HStoreField.register_lookup
|
||||
class ValuesTransform(Transform):
|
||||
lookup_name = "values"
|
||||
function = "avals"
|
||||
output_field = ArrayField(TextField())
|
||||
@@ -0,0 +1,14 @@
|
||||
from django.db.models import JSONField as BuiltinJSONField
|
||||
|
||||
__all__ = ["JSONField"]
|
||||
|
||||
|
||||
class JSONField(BuiltinJSONField):
|
||||
system_check_removed_details = {
|
||||
"msg": (
|
||||
"django.contrib.postgres.fields.JSONField is removed except for "
|
||||
"support in historical migrations."
|
||||
),
|
||||
"hint": "Use django.db.models.JSONField instead.",
|
||||
"id": "fields.E904",
|
||||
}
|
||||
+383
@@ -0,0 +1,383 @@
|
||||
import datetime
|
||||
import json
|
||||
|
||||
from django.contrib.postgres import forms, lookups
|
||||
from django.db import models
|
||||
from django.db.backends.postgresql.psycopg_any import (
|
||||
DateRange,
|
||||
DateTimeTZRange,
|
||||
NumericRange,
|
||||
Range,
|
||||
)
|
||||
from django.db.models.functions import Cast
|
||||
from django.db.models.lookups import PostgresOperatorLookup
|
||||
|
||||
from .utils import AttributeSetter
|
||||
|
||||
__all__ = [
|
||||
"RangeField",
|
||||
"IntegerRangeField",
|
||||
"BigIntegerRangeField",
|
||||
"DecimalRangeField",
|
||||
"DateTimeRangeField",
|
||||
"DateRangeField",
|
||||
"RangeBoundary",
|
||||
"RangeOperators",
|
||||
]
|
||||
|
||||
|
||||
class RangeBoundary(models.Expression):
|
||||
"""A class that represents range boundaries."""
|
||||
|
||||
def __init__(self, inclusive_lower=True, inclusive_upper=False):
|
||||
self.lower = "[" if inclusive_lower else "("
|
||||
self.upper = "]" if inclusive_upper else ")"
|
||||
|
||||
def as_sql(self, compiler, connection):
|
||||
return "'%s%s'" % (self.lower, self.upper), []
|
||||
|
||||
|
||||
class RangeOperators:
|
||||
# https://www.postgresql.org/docs/current/functions-range.html#RANGE-OPERATORS-TABLE
|
||||
EQUAL = "="
|
||||
NOT_EQUAL = "<>"
|
||||
CONTAINS = "@>"
|
||||
CONTAINED_BY = "<@"
|
||||
OVERLAPS = "&&"
|
||||
FULLY_LT = "<<"
|
||||
FULLY_GT = ">>"
|
||||
NOT_LT = "&>"
|
||||
NOT_GT = "&<"
|
||||
ADJACENT_TO = "-|-"
|
||||
|
||||
|
||||
class RangeField(models.Field):
|
||||
empty_strings_allowed = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
if "default_bounds" in kwargs:
|
||||
raise TypeError(
|
||||
f"Cannot use 'default_bounds' with {self.__class__.__name__}."
|
||||
)
|
||||
# Initializing base_field here ensures that its model matches the model
|
||||
# for self.
|
||||
if hasattr(self, "base_field"):
|
||||
self.base_field = self.base_field()
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
try:
|
||||
return self.__dict__["model"]
|
||||
except KeyError:
|
||||
raise AttributeError(
|
||||
"'%s' object has no attribute 'model'" % self.__class__.__name__
|
||||
)
|
||||
|
||||
@model.setter
|
||||
def model(self, model):
|
||||
self.__dict__["model"] = model
|
||||
self.base_field.model = model
|
||||
|
||||
@classmethod
|
||||
def _choices_is_value(cls, value):
|
||||
return isinstance(value, (list, tuple)) or super()._choices_is_value(value)
|
||||
|
||||
def get_placeholder(self, value, compiler, connection):
|
||||
return "%s::{}".format(self.db_type(connection))
|
||||
|
||||
def get_prep_value(self, value):
|
||||
if value is None:
|
||||
return None
|
||||
elif isinstance(value, Range):
|
||||
return value
|
||||
elif isinstance(value, (list, tuple)):
|
||||
return self.range_type(value[0], value[1])
|
||||
return value
|
||||
|
||||
def to_python(self, value):
|
||||
if isinstance(value, str):
|
||||
# Assume we're deserializing
|
||||
vals = json.loads(value)
|
||||
for end in ("lower", "upper"):
|
||||
if end in vals:
|
||||
vals[end] = self.base_field.to_python(vals[end])
|
||||
value = self.range_type(**vals)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
value = self.range_type(value[0], value[1])
|
||||
return value
|
||||
|
||||
def set_attributes_from_name(self, name):
|
||||
super().set_attributes_from_name(name)
|
||||
self.base_field.set_attributes_from_name(name)
|
||||
|
||||
def value_to_string(self, obj):
|
||||
value = self.value_from_object(obj)
|
||||
if value is None:
|
||||
return None
|
||||
if value.isempty:
|
||||
return json.dumps({"empty": True})
|
||||
base_field = self.base_field
|
||||
result = {"bounds": value._bounds}
|
||||
for end in ("lower", "upper"):
|
||||
val = getattr(value, end)
|
||||
if val is None:
|
||||
result[end] = None
|
||||
else:
|
||||
obj = AttributeSetter(base_field.attname, val)
|
||||
result[end] = base_field.value_to_string(obj)
|
||||
return json.dumps(result)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
kwargs.setdefault("form_class", self.form_field)
|
||||
return super().formfield(**kwargs)
|
||||
|
||||
|
||||
CANONICAL_RANGE_BOUNDS = "[)"
|
||||
|
||||
|
||||
class ContinuousRangeField(RangeField):
|
||||
"""
|
||||
Continuous range field. It allows specifying default bounds for list and
|
||||
tuple inputs.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, default_bounds=CANONICAL_RANGE_BOUNDS, **kwargs):
|
||||
if default_bounds not in ("[)", "(]", "()", "[]"):
|
||||
raise ValueError("default_bounds must be one of '[)', '(]', '()', or '[]'.")
|
||||
self.default_bounds = default_bounds
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def get_prep_value(self, value):
|
||||
if isinstance(value, (list, tuple)):
|
||||
return self.range_type(value[0], value[1], self.default_bounds)
|
||||
return super().get_prep_value(value)
|
||||
|
||||
def formfield(self, **kwargs):
|
||||
kwargs.setdefault("default_bounds", self.default_bounds)
|
||||
return super().formfield(**kwargs)
|
||||
|
||||
def deconstruct(self):
|
||||
name, path, args, kwargs = super().deconstruct()
|
||||
if self.default_bounds and self.default_bounds != CANONICAL_RANGE_BOUNDS:
|
||||
kwargs["default_bounds"] = self.default_bounds
|
||||
return name, path, args, kwargs
|
||||
|
||||
|
||||
class IntegerRangeField(RangeField):
|
||||
base_field = models.IntegerField
|
||||
range_type = NumericRange
|
||||
form_field = forms.IntegerRangeField
|
||||
|
||||
def db_type(self, connection):
|
||||
return "int4range"
|
||||
|
||||
|
||||
class BigIntegerRangeField(RangeField):
|
||||
base_field = models.BigIntegerField
|
||||
range_type = NumericRange
|
||||
form_field = forms.IntegerRangeField
|
||||
|
||||
def db_type(self, connection):
|
||||
return "int8range"
|
||||
|
||||
|
||||
class DecimalRangeField(ContinuousRangeField):
|
||||
base_field = models.DecimalField
|
||||
range_type = NumericRange
|
||||
form_field = forms.DecimalRangeField
|
||||
|
||||
def db_type(self, connection):
|
||||
return "numrange"
|
||||
|
||||
|
||||
class DateTimeRangeField(ContinuousRangeField):
|
||||
base_field = models.DateTimeField
|
||||
range_type = DateTimeTZRange
|
||||
form_field = forms.DateTimeRangeField
|
||||
|
||||
def db_type(self, connection):
|
||||
return "tstzrange"
|
||||
|
||||
|
||||
class DateRangeField(RangeField):
|
||||
base_field = models.DateField
|
||||
range_type = DateRange
|
||||
form_field = forms.DateRangeField
|
||||
|
||||
def db_type(self, connection):
|
||||
return "daterange"
|
||||
|
||||
|
||||
class RangeContains(lookups.DataContains):
|
||||
def get_prep_lookup(self):
|
||||
if not isinstance(self.rhs, (list, tuple, Range)):
|
||||
return Cast(self.rhs, self.lhs.field.base_field)
|
||||
return super().get_prep_lookup()
|
||||
|
||||
|
||||
RangeField.register_lookup(RangeContains)
|
||||
RangeField.register_lookup(lookups.ContainedBy)
|
||||
RangeField.register_lookup(lookups.Overlap)
|
||||
|
||||
|
||||
class DateTimeRangeContains(PostgresOperatorLookup):
|
||||
"""
|
||||
Lookup for Date/DateTimeRange containment to cast the rhs to the correct
|
||||
type.
|
||||
"""
|
||||
|
||||
lookup_name = "contains"
|
||||
postgres_operator = RangeOperators.CONTAINS
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
# Transform rhs value for db lookup.
|
||||
if isinstance(self.rhs, datetime.date):
|
||||
value = models.Value(self.rhs)
|
||||
self.rhs = value.resolve_expression(compiler.query)
|
||||
return super().process_rhs(compiler, connection)
|
||||
|
||||
def as_postgresql(self, compiler, connection):
|
||||
sql, params = super().as_postgresql(compiler, connection)
|
||||
# Cast the rhs if needed.
|
||||
cast_sql = ""
|
||||
if (
|
||||
isinstance(self.rhs, models.Expression)
|
||||
and self.rhs._output_field_or_none
|
||||
and
|
||||
# Skip cast if rhs has a matching range type.
|
||||
not isinstance(
|
||||
self.rhs._output_field_or_none, self.lhs.output_field.__class__
|
||||
)
|
||||
):
|
||||
cast_internal_type = self.lhs.output_field.base_field.get_internal_type()
|
||||
cast_sql = "::{}".format(connection.data_types.get(cast_internal_type))
|
||||
return "%s%s" % (sql, cast_sql), params
|
||||
|
||||
|
||||
DateRangeField.register_lookup(DateTimeRangeContains)
|
||||
DateTimeRangeField.register_lookup(DateTimeRangeContains)
|
||||
|
||||
|
||||
class RangeContainedBy(PostgresOperatorLookup):
|
||||
lookup_name = "contained_by"
|
||||
type_mapping = {
|
||||
"smallint": "int4range",
|
||||
"integer": "int4range",
|
||||
"bigint": "int8range",
|
||||
"double precision": "numrange",
|
||||
"numeric": "numrange",
|
||||
"date": "daterange",
|
||||
"timestamp with time zone": "tstzrange",
|
||||
}
|
||||
postgres_operator = RangeOperators.CONTAINED_BY
|
||||
|
||||
def process_rhs(self, compiler, connection):
|
||||
rhs, rhs_params = super().process_rhs(compiler, connection)
|
||||
# Ignore precision for DecimalFields.
|
||||
db_type = self.lhs.output_field.cast_db_type(connection).split("(")[0]
|
||||
cast_type = self.type_mapping[db_type]
|
||||
return "%s::%s" % (rhs, cast_type), rhs_params
|
||||
|
||||
def process_lhs(self, compiler, connection):
|
||||
lhs, lhs_params = super().process_lhs(compiler, connection)
|
||||
if isinstance(self.lhs.output_field, models.FloatField):
|
||||
lhs = "%s::numeric" % lhs
|
||||
elif isinstance(self.lhs.output_field, models.SmallIntegerField):
|
||||
lhs = "%s::integer" % lhs
|
||||
return lhs, lhs_params
|
||||
|
||||
def get_prep_lookup(self):
|
||||
return RangeField().get_prep_value(self.rhs)
|
||||
|
||||
|
||||
models.DateField.register_lookup(RangeContainedBy)
|
||||
models.DateTimeField.register_lookup(RangeContainedBy)
|
||||
models.IntegerField.register_lookup(RangeContainedBy)
|
||||
models.FloatField.register_lookup(RangeContainedBy)
|
||||
models.DecimalField.register_lookup(RangeContainedBy)
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class FullyLessThan(PostgresOperatorLookup):
|
||||
lookup_name = "fully_lt"
|
||||
postgres_operator = RangeOperators.FULLY_LT
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class FullGreaterThan(PostgresOperatorLookup):
|
||||
lookup_name = "fully_gt"
|
||||
postgres_operator = RangeOperators.FULLY_GT
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class NotLessThan(PostgresOperatorLookup):
|
||||
lookup_name = "not_lt"
|
||||
postgres_operator = RangeOperators.NOT_LT
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class NotGreaterThan(PostgresOperatorLookup):
|
||||
lookup_name = "not_gt"
|
||||
postgres_operator = RangeOperators.NOT_GT
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class AdjacentToLookup(PostgresOperatorLookup):
|
||||
lookup_name = "adjacent_to"
|
||||
postgres_operator = RangeOperators.ADJACENT_TO
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class RangeStartsWith(models.Transform):
|
||||
lookup_name = "startswith"
|
||||
function = "lower"
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
return self.lhs.output_field.base_field
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class RangeEndsWith(models.Transform):
|
||||
lookup_name = "endswith"
|
||||
function = "upper"
|
||||
|
||||
@property
|
||||
def output_field(self):
|
||||
return self.lhs.output_field.base_field
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class IsEmpty(models.Transform):
|
||||
lookup_name = "isempty"
|
||||
function = "isempty"
|
||||
output_field = models.BooleanField()
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class LowerInclusive(models.Transform):
|
||||
lookup_name = "lower_inc"
|
||||
function = "LOWER_INC"
|
||||
output_field = models.BooleanField()
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class LowerInfinite(models.Transform):
|
||||
lookup_name = "lower_inf"
|
||||
function = "LOWER_INF"
|
||||
output_field = models.BooleanField()
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class UpperInclusive(models.Transform):
|
||||
lookup_name = "upper_inc"
|
||||
function = "UPPER_INC"
|
||||
output_field = models.BooleanField()
|
||||
|
||||
|
||||
@RangeField.register_lookup
|
||||
class UpperInfinite(models.Transform):
|
||||
lookup_name = "upper_inf"
|
||||
function = "UPPER_INF"
|
||||
output_field = models.BooleanField()
|
||||
@@ -0,0 +1,3 @@
|
||||
class AttributeSetter:
|
||||
def __init__(self, name, value):
|
||||
setattr(self, name, value)
|
||||
Reference in New Issue
Block a user