blob: d5f3427170eb5740a6448aeadd5aec8d43c668e6 [file] [log] [blame]
Patrick Williamsc124f4f2015-09-15 14:41:29 -05001# coding: utf-8
2'''
Implements conditional aggregates: each aggregate accepts an optional ``only``
Q object that restricts which rows contribute to the aggregated value.

This code was based on the work of others found on the internet:

1. http://web.archive.org/web/20101115170804/http://www.voteruniverse.com/Members/jlantz/blog/conditional-aggregates-in-django
2. https://code.djangoproject.com/ticket/11305
3. https://groups.google.com/forum/?fromgroups=#!topic/django-users/cjzloTUwmS0
4. https://groups.google.com/forum/?fromgroups=#!topic/django-users/vVprMpsAnPo
11'''
12from __future__ import unicode_literals
13from django.utils import six
14import django
15from django.db.models.aggregates import Aggregate as DjangoAggregate
16from django.db.models.sql.aggregates import Aggregate as DjangoSqlAggregate
17
18
# Leading (major, minor) components of django.VERSION; compared against
# tuples such as (1, 5) and (1, 7) to pick version-specific code paths.
VERSION = django.VERSION[0:2]
20
21
class SqlAggregate(DjangoSqlAggregate):
    '''
    SQL-level aggregate that can optionally be restricted by a condition.

    When a condition queryset is attached, the aggregated field is wrapped in
    ``CASE WHEN <condition> THEN <field> ELSE null END`` so that rows failing
    the condition contribute NULL to the aggregate function.
    '''
    conditional_template = '%(function)s(CASE WHEN %(condition)s THEN %(field)s ELSE null END)'

    def __init__(self, col, source=None, is_summary=False, condition=None, **extra):
        super(SqlAggregate, self).__init__(col, source, is_summary, **extra)
        # Queryset whose WHERE clause becomes the CASE WHEN condition, or None.
        self.condition = condition

    def relabel_aliases(self, change_map):
        '''
        Rename table aliases according to ``change_map``, including the
        aliases used inside the condition's query.
        '''
        if VERSION < (1, 7):
            # The base implementation is only invoked on Django < 1.7.
            # NOTE(review): presumably later releases rely on relabeled_clone
            # instead — confirm.
            super(SqlAggregate, self).relabel_aliases(change_map)
        if not self.has_condition:
            return
        # Only forward the renames that concern aliases the condition knows.
        known_aliases = self.condition.query.alias_map
        relevant_renames = dict(
            (old_alias, new_alias)
            for old_alias, new_alias in change_map.items()
            if old_alias in known_aliases
        )
        self.condition.query.change_aliases(relevant_renames)

    def relabeled_clone(self, change_map):
        '''
        Apply ``relabel_aliases`` to this instance (mutating it), then return
        the base class's relabeled clone.
        '''
        self.relabel_aliases(change_map)
        clone = super(SqlAggregate, self).relabeled_clone(change_map)
        return clone

    def as_sql(self, qn, connection):
        '''
        Render the aggregate, switching to the CASE WHEN template and filling
        the ``condition`` placeholder when a condition is attached.
        '''
        if not self.has_condition:
            return super(SqlAggregate, self).as_sql(qn, connection)
        self.sql_template = self.conditional_template
        self.extra['condition'] = self._condition_as_sql(qn, connection)
        return super(SqlAggregate, self).as_sql(qn, connection)

    @property
    def has_condition(self):
        '''True when a condition queryset was supplied.'''
        # Compare against None explicitly: bool(QuerySet) would hit the database.
        return self.condition is not None

    def _condition_as_sql(self, qn, connection):
        '''
        Return the condition's WHERE clause as SQL text with its parameters
        inlined as literals.
        '''
        def as_literal(param):
            # Booleans become the text '1'/'0' and are then quoted like any
            # other string below.
            if isinstance(param, bool):
                param = str(int(param))
            if not isinstance(param, six.string_types):
                return param
            # Double '%' (e.g. LIKE wildcards) so it survives a later %-style
            # substitution — presumably the driver's parameter step; confirm.
            # Double single quotes per SQL string-literal rules.
            param = param.replace('%', '%%').replace("'", "''")
            return "'" + param + "'"

        where_sql, where_params = self.condition.query.where.as_sql(qn, connection)
        # NOTE(review): parameters are inlined into the SQL text rather than
        # passed to the driver, and only quotes are escaped. On backends that
        # treat backslash as an escape character (e.g. MySQL by default) this
        # may allow SQL injection if filter values come from untrusted input —
        # confirm before relying on it with user data.
        return where_sql % tuple(as_literal(p) for p in where_params)
76
77
class SqlSum(SqlAggregate):
    '''SQL ``SUM`` aggregate that honours an optional condition.'''
    sql_function = 'SUM'
80
81
class SqlCount(SqlAggregate):
    '''
    SQL ``COUNT`` aggregate with optional ``DISTINCT`` that honours an
    optional condition.
    '''
    is_ordinal = True
    sql_function = 'COUNT'
    sql_template = '%(function)s(%(distinct)s%(field)s)'
    conditional_template = '%(function)s(%(distinct)sCASE WHEN %(condition)s THEN %(field)s ELSE null END)'

    def __init__(self, col, distinct=False, **extra):
        '''
        :param col: column to count.
        :param distinct: when truthy, emit ``DISTINCT `` before the counted
            expression.
        '''
        # Passed along with the other extras and substituted into the
        # templates' %(distinct)s placeholder.
        distinct_sql = 'DISTINCT ' if distinct else ''
        super(SqlCount, self).__init__(col, distinct=distinct_sql, **extra)
90
91
class SqlAvg(SqlAggregate):
    '''SQL ``AVG`` aggregate that honours an optional condition.'''
    sql_function = 'AVG'
    is_computed = True
95
96
class SqlMax(SqlAggregate):
    '''SQL ``MAX`` aggregate that honours an optional condition.'''
    sql_function = 'MAX'
99
100
class SqlMin(SqlAggregate):
    '''SQL ``MIN`` aggregate that honours an optional condition.'''
    sql_function = 'MIN'
103
104
class Aggregate(DjangoAggregate):
    '''
    Base for conditional aggregates.

    Accepts an optional ``only`` Q object; when given, only rows matching it
    contribute to the aggregate. Subclasses set ``name`` and ``sql_klass``
    (the SQL-level aggregate class to instantiate).
    '''
    def __init__(self, lookup, only=None, **extra):
        super(Aggregate, self).__init__(lookup, **extra)
        # Q object restricting aggregated rows; falsy means "no condition".
        self.only = only
        # Queryset built from ``only`` in add_to_query(); None until then.
        self.condition = None

    def _get_fields_from_Q(self, q):
        '''
        Return the leaf children of ``q``, flattening nested Q objects
        depth-first. Leaves are indexed with [0] as a lookup string such as
        ``'field__gte'``.
        '''
        fields = []
        for child in q.children:
            if hasattr(child, 'children'):
                fields.extend(self._get_fields_from_Q(child))
            else:
                fields.append(child)
        return fields

    def _promote_condition_joins(self, query):
        '''
        Set up and promote, on ``query``, the joins needed by every lookup in
        ``self.only``.
        '''
        # NOTE(review): promotion presumably turns these joins into LEFT OUTER
        # joins so the condition cannot drop rows from the outer query — confirm.
        # The promotion API is version-dependent but loop-invariant, so resolve
        # it once instead of once per field.
        if VERSION < (1, 5):
            promote = query.promote_alias_chain
        else:
            promote = query.promote_joins
        # Django < 1.7 takes an extra positional True after the join list.
        extra_args = (True,) if VERSION < (1, 7) else ()

        for child in self._get_fields_from_Q(self.only):
            field_list = child[0].split('__')
            # Pop off the last field if it's a query term ('gte', 'contains', 'isnull', etc.)
            if field_list[-1] in query.query_terms:
                field_list.pop()
            # setup_joins have different returns in Django 1.5 and 1.6, but the order of what we need remains.
            result = query.setup_joins(field_list, query.model._meta, query.get_initial_alias(), None)
            join_list = result[3]
            promote(join_list, *extra_args)

    def add_to_query(self, query, alias, col, source, is_summary):
        '''
        Register the SQL aggregate for this aggregate on ``query`` under
        ``alias``.

        When ``only`` is set, a condition queryset is built from it and the
        joins its lookups require are set up and promoted on ``query``.
        '''
        if self.only:
            # NOTE(review): the model's default manager is used, so any
            # filtering it applies also becomes part of the condition —
            # confirm that is intended.
            self.condition = query.model._default_manager.filter(self.only)
            self._promote_condition_joins(query)

        aggregate = self.sql_klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra)
        query.aggregates[alias] = aggregate
140
141
class Sum(Aggregate):
    '''Sum of a field, optionally restricted to rows matching ``only``.'''
    sql_klass = SqlSum
    name = 'Sum'
144 sql_klass = SqlSum
145
146
class Count(Aggregate):
    '''
    Count of a field, optionally restricted to rows matching ``only``.
    A ``distinct`` keyword is forwarded through ``extra`` to ``SqlCount``.
    '''
    sql_klass = SqlCount
    name = 'Count'
149 sql_klass = SqlCount
150
151
class Avg(Aggregate):
    '''Average of a field, optionally restricted to rows matching ``only``.'''
    sql_klass = SqlAvg
    name = 'Avg'
154 sql_klass = SqlAvg
155
156
class Max(Aggregate):
    '''Maximum of a field, optionally restricted to rows matching ``only``.'''
    sql_klass = SqlMax
    name = 'Max'
159 sql_klass = SqlMax
160
161
class Min(Aggregate):
    '''Minimum of a field, optionally restricted to rows matching ``only``.'''
    sql_klass = SqlMin
    name = 'Min'
164 sql_klass = SqlMin