Patrick Williams | c124f4f | 2015-09-15 14:41:29 -0500 | [diff] [blame] | 1 | # coding: utf-8 |
| 2 | ''' |
| 3 | Implements conditional aggregates. |
| 4 | |
| 5 | This code was based on the work of others found on the internet: |
| 6 | |
| 7 | 1. http://web.archive.org/web/20101115170804/http://www.voteruniverse.com/Members/jlantz/blog/conditional-aggregates-in-django |
| 8 | 2. https://code.djangoproject.com/ticket/11305 |
| 9 | 3. https://groups.google.com/forum/?fromgroups=#!topic/django-users/cjzloTUwmS0 |
| 10 | 4. https://groups.google.com/forum/?fromgroups=#!topic/django-users/vVprMpsAnPo |
| 11 | ''' |
| 12 | from __future__ import unicode_literals |
| 13 | from django.utils import six |
| 14 | import django |
| 15 | from django.db.models.aggregates import Aggregate as DjangoAggregate |
| 16 | from django.db.models.sql.aggregates import Aggregate as DjangoSqlAggregate |
| 17 | |
| 18 | |
# (major, minor) of the installed Django; drives the version-specific branches below.
VERSION = django.VERSION[0:2]
| 20 | |
| 21 | |
class SqlAggregate(DjangoSqlAggregate):
    """
    SQL aggregate that can be restricted to rows matching a condition.

    ``condition`` is a QuerySet whose WHERE clause is placed inside a
    ``CASE WHEN ... THEN field ELSE null END`` expression. Rows that do not
    match yield NULL, which SQL aggregate functions skip.
    """
    conditional_template = '%(function)s(CASE WHEN %(condition)s THEN %(field)s ELSE null END)'

    def __init__(self, col, source=None, is_summary=False, condition=None, **extra):
        super(SqlAggregate, self).__init__(col, source, is_summary, **extra)
        # QuerySet supplying the CASE WHEN condition, or None for a plain aggregate.
        self.condition = condition

    def relabel_aliases(self, change_map):
        """
        Apply ``change_map`` (old alias -> new alias) to this aggregate.

        The base-class implementation is only called before Django 1.7. In
        every version, the condition's query is relabelled using only the
        aliases that query knows about.
        """
        if VERSION < (1, 7):
            super(SqlAggregate, self).relabel_aliases(change_map)
        if self.has_condition:
            condition_aliases = self.condition.query.alias_map
            condition_change_map = dict(
                (old, new) for old, new in change_map.items()
                if old in condition_aliases
            )
            self.condition.query.change_aliases(condition_change_map)

    def relabeled_clone(self, change_map):
        """
        Return a clone of this aggregate with aliases relabelled.

        NOTE(review): relabel_aliases() is applied to ``self`` before cloning,
        so any in-place changes it makes (e.g. via change_aliases on the
        condition's query) also affect this instance.
        """
        self.relabel_aliases(change_map)
        return super(SqlAggregate, self).relabeled_clone(change_map)

    def as_sql(self, qn, connection):
        """
        Render the aggregate, wrapping the field in CASE WHEN if conditional.

        If the base class returns ``(sql, params)``, the condition keeps its
        ``%s`` placeholders and its values are returned as query parameters,
        so the database driver quotes them. If the base class returns a bare
        SQL string, there is no way to pass parameters, so the values are
        inlined by ``_condition_as_sql`` instead.
        """
        if not self.has_condition:
            return super(SqlAggregate, self).as_sql(qn, connection)

        self.sql_template = self.conditional_template
        condition_sql, condition_params = self.condition.query.where.as_sql(qn, connection)
        self.extra['condition'] = condition_sql
        result = super(SqlAggregate, self).as_sql(qn, connection)

        if isinstance(result, tuple):
            sql, params = result
            # The condition precedes the field in the template, so its
            # placeholders (and therefore its params) come first. A list is
            # returned so callers can concatenate it with their own param lists.
            return sql, list(condition_params) + list(params)

        # The base class only renders a SQL string: fall back to inlining values.
        self.extra['condition'] = self._condition_as_sql(qn, connection)
        return super(SqlAggregate, self).as_sql(qn, connection)

    @property
    def has_condition(self):
        # Compare against None explicitly: bool(QuerySet) would hit the database.
        return self.condition is not None

    def _condition_as_sql(self, qn, connection):
        '''
        Return the condition's SQL with its parameters inlined as literals.

        Only used when the base class cannot return query parameters.
        Booleans become '0'/'1', None becomes NULL, and strings are quoted
        with ``'`` doubled. Literal ``%`` in strings is doubled, presumably
        because the final SQL is %-formatted again by the database driver.

        SECURITY NOTE(review): this escaping is not driver-grade (e.g. MySQL
        backslash escapes are not handled). Untrusted values in the condition
        are only safe on the parameterized path in as_sql().
        '''
        def escape(value):
            if value is None:
                return 'NULL'
            if isinstance(value, bool):
                value = str(int(value))
            if isinstance(value, six.string_types):
                value = "'" + value.replace('%', '%%').replace("'", "''") + "'"
            return value

        sql, params = self.condition.query.where.as_sql(qn, connection)
        return sql % tuple(escape(param) for param in params)
| 76 | |
| 77 | |
class SqlSum(SqlAggregate):
    """SQL ``SUM`` aggregate with optional CASE WHEN condition support."""
    sql_function = 'SUM'
| 80 | |
| 81 | |
class SqlCount(SqlAggregate):
    """
    SQL ``COUNT`` aggregate with optional CASE WHEN condition support.

    ``distinct=True`` renders ``COUNT(DISTINCT ...)``. The DISTINCT keyword is
    placed before the CASE expression when a condition is present.
    """
    is_ordinal = True
    sql_function = 'COUNT'
    sql_template = '%(function)s(%(distinct)s%(field)s)'
    conditional_template = '%(function)s(%(distinct)sCASE WHEN %(condition)s THEN %(field)s ELSE null END)'

    def __init__(self, col, distinct=False, **extra):
        # Both templates interpolate %(distinct)s verbatim, so pass the SQL
        # keyword (with trailing space) or an empty string.
        super(SqlCount, self).__init__(col, distinct='DISTINCT ' if distinct else '', **extra)
| 90 | |
| 91 | |
class SqlAvg(SqlAggregate):
    """SQL ``AVG`` aggregate with optional CASE WHEN condition support."""
    sql_function = 'AVG'
    is_computed = True
| 95 | |
| 96 | |
class SqlMax(SqlAggregate):
    """SQL ``MAX`` aggregate with optional CASE WHEN condition support."""
    sql_function = 'MAX'
| 99 | |
| 100 | |
class SqlMin(SqlAggregate):
    """SQL ``MIN`` aggregate with optional CASE WHEN condition support."""
    sql_function = 'MIN'
| 103 | |
| 104 | |
class Aggregate(DjangoAggregate):
    """
    Aggregate whose rows can be restricted with an ``only`` Q object.

    Subclasses set ``sql_klass`` to the SQL aggregate class that is
    instantiated in add_to_query().
    """
    def __init__(self, lookup, only=None, **extra):
        super(Aggregate, self).__init__(lookup, **extra)
        # Q object restricting which rows are aggregated, or None.
        self.only = only
        # QuerySet built from ``only`` in add_to_query(); None until then.
        self.condition = None

    def _get_fields_from_Q(self, q):
        """
        Return the leaf children of ``q``, flattening nested Q nodes.

        Leaves are indexed as ``child[0]`` by add_to_query(), so they are
        presumably ``(lookup, value)`` tuples.
        """
        fields = []
        for child in q.children:
            if hasattr(child, 'children'):
                fields.extend(self._get_fields_from_Q(child))
            else:
                fields.append(child)
        return fields

    def add_to_query(self, query, alias, col, source, is_summary):
        """
        Register the SQL aggregate for this aggregate on ``query`` under ``alias``.

        If ``only`` is set, it is turned into a filtered QuerySet (the
        condition). The joins for every field it references are set up on
        ``query`` and promoted, presumably so the join does not drop rows that
        fail the condition.
        """
        if self.only:
            self.condition = query.model._default_manager.filter(self.only)

            # The join-promotion API differs across Django versions; resolve it
            # once rather than on every iteration.
            fname = 'promote_alias_chain' if VERSION < (1, 5) else 'promote_joins'
            promote = getattr(query, fname)
            promote_extra_args = (True,) if VERSION < (1, 7) else ()

            for child in self._get_fields_from_Q(self.only):
                field_list = child[0].split('__')
                # Pop off the last field if it's a query term ('gte', 'contains', 'isnull', etc.)
                if field_list[-1] in query.query_terms:
                    field_list.pop()
                # setup_joins returns differ across Django versions, but the join
                # list is at index 3 in all the versions handled here.
                result = query.setup_joins(field_list, query.model._meta, query.get_initial_alias(), None)
                join_list = result[3]
                promote(join_list, *promote_extra_args)

        aggregate = self.sql_klass(col, source=source, is_summary=is_summary, condition=self.condition, **self.extra)
        query.aggregates[alias] = aggregate
| 140 | |
| 141 | |
class Sum(Aggregate):
    """Sum aggregate; pass ``only=Q(...)`` to sum matching rows only."""
    sql_klass = SqlSum
    name = 'Sum'
| 145 | |
| 146 | |
class Count(Aggregate):
    """Count aggregate; pass ``only=Q(...)`` to count matching rows only."""
    sql_klass = SqlCount
    name = 'Count'
| 150 | |
| 151 | |
class Avg(Aggregate):
    """Average aggregate; pass ``only=Q(...)`` to average matching rows only."""
    sql_klass = SqlAvg
    name = 'Avg'
| 155 | |
| 156 | |
class Max(Aggregate):
    """Maximum aggregate; pass ``only=Q(...)`` to consider matching rows only."""
    sql_klass = SqlMax
    name = 'Max'
| 160 | |
| 161 | |
class Min(Aggregate):
    """Minimum aggregate; pass ``only=Q(...)`` to consider matching rows only."""
    sql_klass = SqlMin
    name = 'Min'