blob: 162a88fb789eaa66bbf09abccac3d6f708ddb2c6 [file] [log] [blame]
Patrick Williamsc124f4f2015-09-15 14:41:29 -05001# Copyright (C) 2013 Intel Corporation
2#
3# Released under the MIT license (see COPYING.MIT)
4
# Some custom decorators that can be used by unittests.
# The most useful is skipUnlessPassed, which can be used to
# create dependencies between two test methods.
8
9import os
10import logging
11import sys
12import unittest
13import threading
14import signal
15from functools import wraps
16
17#get the "result" object from one of the upper frames provided that one of these upper frames is a unittest.case frame
18class getResults(object):
19 def __init__(self):
20 #dynamically determine the unittest.case frame and use it to get the name of the test method
21 ident = threading.current_thread().ident
22 upperf = sys._current_frames()[ident]
23 while (upperf.f_globals['__name__'] != 'unittest.case'):
24 upperf = upperf.f_back
25
26 def handleList(items):
27 ret = []
28 # items is a list of tuples, (test, failure) or (_ErrorHandler(), Exception())
29 for i in items:
30 s = i[0].id()
31 #Handle the _ErrorHolder objects from skipModule failures
32 if "setUpModule (" in s:
33 ret.append(s.replace("setUpModule (", "").replace(")",""))
34 else:
35 ret.append(s)
36 return ret
37 self.faillist = handleList(upperf.f_locals['result'].failures)
38 self.errorlist = handleList(upperf.f_locals['result'].errors)
39 self.skiplist = handleList(upperf.f_locals['result'].skipped)
40
41 def getFailList(self):
42 return self.faillist
43
44 def getErrorList(self):
45 return self.errorlist
46
47 def getSkipList(self):
48 return self.skiplist
49
class skipIfFailure(object):
    """Skip the decorated test if a prerequisite test failed or errored.

    :param testcase: id of the prerequisite test, as returned by
        ``TestCase.id()`` and collected by ``getResults``.
    """

    def __init__(self, testcase):
        self.testcase = testcase

    def __call__(self, f):
        # wraps() keeps __name__ plus any attributes other decorators set on f
        # (e.g. test_case, tag__*), which LogResults/gettag read later.
        @wraps(f)
        def wrapped_f(*args, **kwargs):
            res = getResults()
            # Test each list separately: ``x in (a or b)`` only ever
            # searches one of them, so errors were ignored whenever any
            # failure had been recorded.
            if (self.testcase in res.getFailList()
                    or self.testcase in res.getErrorList()):
                raise unittest.SkipTest("Testcase dependency not met: %s" % self.testcase)
            return f(*args, **kwargs)
        return wrapped_f
63
class skipIfSkipped(object):
    """Skip the decorated test if a prerequisite test was skipped.

    :param testcase: id of the prerequisite test, as returned by
        ``TestCase.id()`` and collected by ``getResults``.
    """

    def __init__(self, testcase):
        self.testcase = testcase

    def __call__(self, f):
        # wraps() keeps __name__ plus any attributes other decorators set on f
        # (e.g. test_case, tag__*), which LogResults/gettag read later.
        @wraps(f)
        def wrapped_f(*args, **kwargs):
            res = getResults()
            if self.testcase in res.getSkipList():
                raise unittest.SkipTest("Testcase dependency not met: %s" % self.testcase)
            return f(*args, **kwargs)
        return wrapped_f
77
class skipUnlessPassed(object):
    """Skip the decorated test unless a prerequisite test passed.

    The test is skipped when ``testcase`` appears among the skipped, failed
    or errored tests recorded so far. A prerequisite that has not run yet is
    in none of those lists, so it does not cause a skip.

    :param testcase: id of the prerequisite test, as returned by
        ``TestCase.id()`` and collected by ``getResults``.
    """

    def __init__(self, testcase):
        self.testcase = testcase

    def __call__(self, f):
        # wraps() keeps __name__ plus any attributes other decorators set on f
        # (e.g. test_case, tag__*), which LogResults/gettag read later.
        @wraps(f)
        def wrapped_f(*args, **kwargs):
            res = getResults()
            if (self.testcase in res.getSkipList()
                    or self.testcase in res.getFailList()
                    or self.testcase in res.getErrorList()):
                raise unittest.SkipTest("Testcase dependency not met: %s" % self.testcase)
            return f(*args, **kwargs)
        # NOTE(review): not read anywhere in this file; presumably consumed by
        # test ordering/loading code elsewhere — confirm.
        wrapped_f._depends_on = self.testcase
        return wrapped_f
94
class testcase(object):
    """Attach an external test-case identifier to a test method.

    The identifier is stored as the ``test_case`` attribute of the returned
    function. ``LogResults`` logs it in place of the method name.

    :param test_case: the identifier to attach (any value; logged via ``str``).
    """

    def __init__(self, test_case):
        self.test_case = test_case

    def __call__(self, func):
        # wraps() keeps __name__/__doc__ and any attributes already on func
        # (e.g. tag__* from @tag), then test_case is added on top.
        @wraps(func)
        def wrapped_f(*args, **kwargs):
            return func(*args, **kwargs)
        wrapped_f.test_case = self.test_case
        return wrapped_f
106
class NoParsingFilter(logging.Filter):
    """Logging filter that passes only records logged at level 100.

    Level 100 is the custom "RESULTS" level that LogResults registers, so a
    handler carrying this filter emits nothing but test-result lines.
    """

    def filter(self, record):
        results_level = 100
        return results_level == record.levelno
110
def LogResults(original_class):
    """Class decorator that logs the outcome of every test in ``original_class``.

    ``original_class.run`` is replaced by a wrapper that runs the test
    normally. It then looks the test up in the ``TestResult``'s ``errors``,
    ``failures`` and ``skipped`` lists.

    It writes "Testcase <id>: ERROR/FAILED/SKIPPED/PASSED" lines, plus the
    recorded text for errors and failures, at the custom ``RESULTS`` level
    (100). The destination is ``results-<basename of sys.argv[0]>.log`` in
    the current directory. ``<id>`` is the ``test_case`` attribute set by the
    ``testcase`` decorator, or the test method name when that is absent.

    :param original_class: a ``unittest.TestCase`` subclass; modified in place.
    :returns: ``original_class``.
    """
    orig_method = original_class.run
    # Custom logging level used for filtering the results log.
    custom_log_level = 100

    def _results_logger():
        """Register the RESULTS level, configure handlers, return the script's logger."""
        logging.addLevelName(custom_log_level, 'RESULTS')
        caller = os.path.basename(sys.argv[0])

        def results(self, message, *args, **kws):
            if self.isEnabledFor(custom_log_level):
                self.log(custom_log_level, message, *args, **kws)
        logging.Logger.results = results

        # basicConfig() does nothing once the root logger has handlers, so
        # only the first call actually opens the results file.
        logging.basicConfig(filename=os.path.join(os.getcwd(), 'results-' + caller + '.log'),
                            filemode='w',
                            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                            datefmt='%H:%M:%S',
                            level=custom_log_level)
        for handler in logging.root.handlers:
            # This runs once per test; add the filter only if the handler
            # does not already have one instead of stacking a new copy each time.
            if not any(isinstance(f, NoParsingFilter) for f in handler.filters):
                handler.addFilter(NoParsingFilter())
        return logging.getLogger(caller)

    def _matches(test, method_name, class_name):
        """True if ``test`` (a TestResult list entry) is ``class_name``'s ``method_name``."""
        # str(test) has the form "method_name (module.Class...)"; partition
        # avoids an IndexError when there is no space.
        first, _, rest = str(test).partition(' ')
        return first == method_name and class_name in rest

    # Replace the run method of unittest.TestCase to add testcase logging.
    def run(self, result=None, *args, **kws):
        ret = orig_method(self, result, *args, **kws)
        if result is None:
            # Called without a result object: fall back to the one
            # TestCase.run() created and returned (Python 3). Nothing to
            # inspect if there is none.
            result = ret
            if result is None:
                return ret

        testMethod = getattr(self, self._testMethodName)
        # If the test case is decorated, use its number; else use its name.
        test_case = getattr(testMethod, 'test_case', self._testMethodName)
        # "module.Class". type(self) is the class the Python-2-only
        # bound-method attribute im_class referred to, and works on Python 3.
        class_name = str(type(self)).split("'")[1]

        local_log = _results_logger()

        # Check the status of the test and record it: errors, then
        # failures, then skips.
        passed = True
        outcomes = ((result.errors, "ERROR", True),
                    (result.failures, "FAILED", True),
                    (result.skipped, "SKIPPED", False))
        for entries, label, log_detail in outcomes:
            for (name, msg) in entries:
                if _matches(name, self._testMethodName, class_name):
                    local_log.results("Testcase " + str(test_case) + ": " + label)
                    if log_detail:
                        local_log.results("Testcase " + str(test_case) + ":\n" + msg)
                    passed = False
        if passed:
            local_log.results("Testcase " + str(test_case) + ": PASSED")
        return ret

    original_class.run = run
    return original_class
167
class TimeOut(BaseException):
    """Raised by the ``timeout`` decorator when the wrapped call runs too long.

    NOTE(review): derives from BaseException rather than Exception,
    presumably so that ``except Exception`` in test code cannot swallow it.
    """
170
def timeout(seconds):
    """Decorator factory: abort the wrapped call with ``TimeOut`` after ``seconds``.

    Implemented with ``SIGALRM`` and ``signal.alarm``. On platforms whose
    ``signal`` module has no ``alarm``, the function is returned undecorated.
    When the call finishes, the alarm is cancelled and the previous SIGALRM
    handler is restored.

    NOTE(review): ``signal.signal`` only works in the main thread, and
    ``alarm(0)`` also cancels any alarm an outer caller had scheduled.
    """
    def decorator(fn):
        if not hasattr(signal, 'alarm'):
            # No SIGALRM support: leave the function untouched.
            return fn

        @wraps(fn)
        def wrapped_f(*args, **kw):
            own_frame = sys._getframe()

            def on_alarm(signum, interrupted_frame):
                # Ignore an alarm that lands while this wrapper's own frame
                # is executing (e.g. during the cleanup in ``finally``).
                if interrupted_frame is own_frame:
                    return
                raise TimeOut('%s seconds' % seconds)

            saved_handler = signal.signal(signal.SIGALRM, on_alarm)
            try:
                signal.alarm(seconds)
                return fn(*args, **kw)
            finally:
                signal.alarm(0)
                signal.signal(signal.SIGALRM, saved_handler)
        return wrapped_f
    return decorator
191
__tag_prefix = "tag__"


def tag(*args, **kwargs):
    """Decorator that adds attributes to classes or functions
    for use with the Attribute (-a) plugin.

    Each positional ``name`` sets attribute ``tag__<name>`` to ``True``.
    Each keyword ``name=value`` sets ``tag__<name>`` to ``value``. The
    decorated object itself is modified and returned (it is not wrapped).
    """
    def wrap_ob(ob):
        for name in args:
            setattr(ob, __tag_prefix + name, True)
        # items() instead of the Python-2-only iteritems(), so this runs on
        # both Python 2 and Python 3.
        for name, value in kwargs.items():
            setattr(ob, __tag_prefix + name, value)
        return ob
    return wrap_ob
204
def gettag(obj, key, default=None):
    """Return the value of tag ``key`` on ``obj``, or ``default`` if it is unset.

    For a ``unittest.TestCase`` instance, a tag on the test method being run
    takes precedence over the same tag on the instance (or its class).
    """
    attr_name = __tag_prefix + key
    if isinstance(obj, unittest.TestCase):
        method = getattr(obj, obj._testMethodName)
        fallback = getattr(obj, attr_name, default)
        return getattr(method, attr_name, fallback)
    return getattr(obj, attr_name, default)
212
def getAllTags(obj):
    """Return a dict of every tag on ``obj``, keyed by tag name without the prefix.

    For a ``unittest.TestCase`` instance, the instance's (and class's) tags
    are merged with those of the test method being run. Method tags win on
    conflicts.
    """
    prefix_len = len(__tag_prefix)

    def _collect(target):
        found = {}
        for attr_name in dir(target):
            if attr_name.startswith(__tag_prefix):
                found[attr_name[prefix_len:]] = getattr(target, attr_name)
        return found

    if not isinstance(obj, unittest.TestCase):
        return _collect(obj)
    method = getattr(obj, obj._testMethodName)
    merged = _collect(obj)
    merged.update(_collect(method))
    return merged