blob: 4c41916ad547c574ef647aa5edd5de004d3749e6 [file] [log] [blame]
"""This module implements all contexts for state handling during uploads and
downloads; the main interface to them is the TftpContext base class.

The concept is simple. Each context object represents a single upload or
download, and the state object in the context object represents the current
state of that transfer. The state object has a handle() method that expects
the next packet in the transfer and returns a state object until the transfer
is complete, at which point it returns None. The exception is a fatal error,
in which case a TftpException is raised instead."""
10
11from TftpShared import *
12from TftpPacketTypes import *
13from TftpPacketFactory import TftpPacketFactory
14from TftpStates import *
15import socket, time, sys
16
17###############################################################################
18# Utility classes
19###############################################################################
20
class TftpMetrics(object):
    """A class representing metrics of the transfer.

    The owning context updates the counters while the transfer runs and
    sets start_time/end_time; compute() then derives the duration, the
    transfer rates and the total number of duplicate packets.
    """

    def __init__(self):
        # Bytes transferred
        self.bytes = 0
        # Bytes re-sent
        self.resent_bytes = 0
        # Duplicate packets received: str(packet) -> number of duplicates,
        # plus the overall total, which compute() fills in.
        self.dups = {}
        self.dupcount = 0
        # Times
        self.start_time = 0
        self.end_time = 0
        self.duration = 0
        # Rates: bits per second, and bps / 1024
        self.bps = 0
        self.kbps = 0
        # Generic errors
        self.errors = 0

    def compute(self):
        """Derive duration, bps, kbps and dupcount from the raw counters.

        Safe to call more than once: dupcount is recomputed from self.dups
        on every call instead of being added to, so a second call (end()
        may run explicitly and then again from a context's destructor) does
        not inflate it.
        """
        # Compute transfer time
        self.duration = self.end_time - self.start_time
        if self.duration == 0:
            # Guard the division below; a zero duration is reported as 1.
            self.duration = 1
        log.debug("TftpMetrics.compute: duration is %s", self.duration)
        self.bps = (self.bytes * 8.0) / self.duration
        self.kbps = self.bps / 1024.0
        log.debug("TftpMetrics.compute: kbps is %s", self.kbps)
        self.dupcount = sum(self.dups.values())

    def add_dup(self, pkt):
        """This method adds a dup for a packet to the metrics.

        Duplicates are counted per str(pkt). Once the count for a single
        packet reaches MAX_DUPS, tftpassert receives a false condition with
        the message "Max duplicates reached" (presumably raising; see
        tftpassert in TftpShared).
        """
        log.debug("Recording a dup of %s", pkt)
        key = str(pkt)
        # dict.get instead of dict.has_key, which no longer exists in Python 3.
        self.dups[key] = self.dups.get(key, 0) + 1
        tftpassert(self.dups[key] < MAX_DUPS, "Max duplicates reached")
62
63###############################################################################
64# Context classes
65###############################################################################
66
class TftpContext(object):
    """The base class of the contexts.

    A context owns the UDP socket, the file object being transferred, the
    current state object and the transfer metrics. Subclasses implement
    start(); cycle() receives one packet and hands it to the current state.
    """

    def __init__(self, host, port, timeout, localip=""):
        """Constructor for the base context, setting shared instance
        variables.

        :param host: peer host name or address; assigning it also resolves
            and stores self.address (see the host property).
        :param port: peer port.
        :param timeout: socket timeout in seconds, also used by
            checkTimeout().
        :param localip: if non-empty, bind the socket to this local address
            on an OS-chosen port.
        """
        self.file_to_transfer = None
        self.fileobj = None
        self.options = None
        self.packethook = None
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        if localip != "":
            self.sock.bind((localip, 0))
        self.sock.settimeout(timeout)
        self.timeout = timeout
        self.state = None
        self.next_block = 0
        self.factory = TftpPacketFactory()
        # Note, setting the host will also set self.address, as it's a property.
        self.host = host
        self.port = port
        # The port associated with the TID
        self.tidport = None
        # Metrics
        self.metrics = TftpMetrics()
        # Flag when the transfer is pending completion.
        self.pending_complete = False
        # Time when this context last accepted traffic from its peer.
        # FIXME: does this belong in metrics?
        self.last_update = 0
        # The last packet we sent, if applicable, to make resending easy.
        self.last_pkt = None
        # Count the number of retry attempts.
        self.retry_count = 0

    def getBlocksize(self):
        """Fetch the current blocksize for this session.

        Returns the 'blksize' option as an int, or 512 when that option is
        absent or no options have been set at all (options is None).
        """
        return int((self.options or {}).get('blksize', 512))

    def __del__(self):
        """Simple destructor to try to call housekeeping in the end method if
        not called explicitly. Leaking file descriptors is not a good
        thing."""
        self.end()

    def checkTimeout(self, now):
        """Compare current time with last_update time, and raise TftpTimeout
        if more than self.timeout seconds have passed."""
        log.debug("checking for timeout on session %s", self)
        if now - self.last_update > self.timeout:
            raise TftpTimeout("Timeout waiting for traffic")

    def start(self):
        """Begin the transfer; must be implemented by subclasses."""
        raise NotImplementedError("Abstract method")

    def end(self):
        """Perform session cleanup: close the socket and, if still open, the
        file object. Calling code should always call end() explicitly, which
        works better than relying on the destructor."""
        log.debug("in TftpContext.end")
        self.sock.close()
        if self.fileobj is not None and not self.fileobj.closed:
            log.debug("self.fileobj is open - closing")
            self.fileobj.close()

    def gethost(self):
        "Simple getter method for use in a property."
        return self.__host

    def sethost(self, host):
        """Setter method that also sets the address attribute to the
        resolved IPv4 address of the host being set."""
        self.__host = host
        self.address = socket.gethostbyname(host)

    host = property(gethost, sethost)

    def setNextBlock(self, block):
        """Set the next expected block number; any value >= 2**16 is
        replaced by 0 (block numbers are 16 bits and roll over)."""
        if block >= 2 ** 16:
            log.debug("Block number rollover to 0 again")
            block = 0
        self.__eblock = block

    def getNextBlock(self):
        """Return the next expected block number."""
        return self.__eblock

    next_block = property(getNextBlock, setNextBlock)

    def cycle(self):
        """Here we wait for a response from the peer after sending it
        something, and dispatch appropriate action to that response.

        A packet whose source port differs from the established transfer
        port (self.tidport, once set) is logged and discarded without being
        parsed, passed to the packethook, or handed to the state, so stray
        traffic cannot disturb or keep alive the transfer.

        Raises TftpTimeout if nothing arrives within the socket timeout.
        """
        try:
            (data, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
        except socket.timeout:
            log.warn("Timeout waiting for traffic, retrying...")
            raise TftpTimeout("Timed-out waiting for traffic")

        # Ok, we've received a packet. Log it.
        log.debug("Received %d bytes from %s:%s",
                  len(data), raddress, rport)

        # Check for known "connection" before doing anything with the packet.
        if raddress != self.address:
            # Deliberately lenient: a peer reachable via several addresses
            # (multi-homed, loopback aliases) may reply from a source address
            # other than the one self.host resolved to.
            log.warn("Received traffic from %s, expected host %s. "
                     "Processing it anyway.", raddress, self.host)

        if self.tidport and self.tidport != rport:
            # RFC 1350: a packet from an unknown TID must be discarded
            # without disturbing the transfer.
            # TODO: the RFC also asks for an error packet to the sender.
            log.warn("Received traffic from %s:%s but we're "
                     "connected to %s:%s. Discarding.",
                     raddress, rport, self.host, self.tidport)
            return

        # And update our last updated time.
        self.last_update = time.time()

        # Decode it.
        recvpkt = self.factory.parse(data)

        # If there is a packethook defined, call it. We unconditionally
        # pass all accepted packets, it's up to the client to screen out
        # different kinds of packets. This way, the client is privy to
        # things like negotiated options.
        if self.packethook:
            self.packethook(recvpkt)

        # And handle it, possibly changing state.
        self.state = self.state.handle(recvpkt, raddress, rport)
        # If we didn't throw any exceptions here, reset the retry_count to
        # zero.
        self.retry_count = 0
196
class TftpContextServer(TftpContext):
    """Per-transfer context used on the server side.

    When the context is created it is not yet known whether the peer wants
    to upload or download; TftpStateServerStart decides that once start()
    hands it the first packet. The owning TftpServer object is expected to
    drive cycle() for the packets that follow.
    """

    def __init__(self, host, port, timeout, root, dyn_file_func=None):
        """Create the context for the peer at host:port.

        root and dyn_file_func are only stored here for the state objects to
        consult (presumably the served directory and an optional callback
        that produces file contents on demand; confirm in TftpStates).
        """
        super(TftpContextServer, self).__init__(host, port, timeout)
        # Direction is unknown until the first packet is seen, so begin in
        # the generic server start state.
        self.state = TftpStateServerStart(self)

        self.root = root
        self.dyn_file_func = dyn_file_func

    def __str__(self):
        parts = (self.host, self.port, self.state)
        return "%s:%s %s" % parts

    def start(self, buffer):
        """Kick off the transfer from the initial packet the server received.

        Unlike the client contexts, this does not loop on cycle(): it parses
        buffer, lets the start state handle it (which should move the
        context into the download or upload state), and returns.
        """
        log.debug("In TftpContextServer.start")
        metrics = self.metrics
        metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", metrics.start_time)
        # Receiving the initial packet counts as traffic for timeout checks.
        self.last_update = time.time()

        initial = self.factory.parse(buffer)
        log.debug("TftpContextServer.start() - factory returned a %s", initial)

        # The start state picks the download or upload state from here.
        self.state = self.state.handle(initial, self.host, self.port)

    def end(self):
        """Release the context's resources, then stamp and compute metrics."""
        super(TftpContextServer, self).end()
        stats = self.metrics
        stats.end_time = time.time()
        log.debug("Set metrics.end_time to %s", stats.end_time)
        stats.compute()
241
class TftpContextClientUpload(TftpContext):
    """The upload context for the client during an upload.
    Note: If input is a hyphen, then we will use stdin."""
    def __init__(self,
                 host,
                 port,
                 filename,
                 input,
                 options,
                 packethook,
                 timeout,
                 localip=""):
        """Prepare an upload of `input` to the remote file `filename`.

        :param filename: remote file name placed in the WRQ.
        :param input: a file-like object (anything with read()), '-' for
            sys.stdin, or a local path that is opened in binary mode.
        :param options: TFTP options sent with the WRQ, stored as given.
        :param packethook: optional callable invoked with every packet that
            cycle() accepts.
        :param timeout: socket timeout in seconds.
        :param localip: optional local address to bind to.
        """
        TftpContext.__init__(self,
                             host,
                             port,
                             timeout,
                             localip)
        self.file_to_transfer = filename
        self.options = options
        self.packethook = packethook
        # If the input object has a read() function,
        # assume it is file-like.
        if hasattr(input, 'read'):
            self.fileobj = input
        elif input == '-':
            self.fileobj = sys.stdin
        else:
            self.fileobj = open(input, "rb")

        log.debug("TftpContextClientUpload.__init__()")
        log.debug("file_to_transfer = %s, options = %s",
                  self.file_to_transfer, self.options)

    def __str__(self):
        return "%s:%s %s" % (self.host, self.port, self.state)

    def start(self):
        """Send the write request (WRQ) and run the transfer to completion.

        Loops on cycle() until the state machine returns None. On a
        TftpTimeout the last packet is resent; once retry_count reaches
        TIMEOUT_RETRIES (cycle() resets it whenever a packet is handled) the
        TftpTimeout is re-raised.
        """
        log.info("Sending tftp upload request to %s", self.host)
        log.info(" filename -> %s", self.file_to_transfer)
        log.info(" options -> %s", self.options)

        self.metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", self.metrics.start_time)

        # FIXME: put this in a sendWRQ method?
        pkt = TftpPacketWRQ()
        pkt.filename = self.file_to_transfer
        pkt.mode = "octet" # FIXME - shouldn't hardcode this
        pkt.options = self.options
        # Send to the address resolved when self.host was set: cycle()
        # checks replies against self.address, and passing the host name
        # here would trigger a second, possibly different, DNS resolution.
        self.sock.sendto(pkt.encode().buffer, (self.address, self.port))
        self.next_block = 1
        self.last_pkt = pkt
        # FIXME: should we centralize sendto operations so we can refactor all
        # saving of the packet to the last_pkt field?

        self.state = TftpStateSentWRQ(self)

        while self.state:
            try:
                log.debug("State is %s", self.state)
                self.cycle()
            except TftpTimeout as err:
                log.error(str(err))
                self.retry_count += 1
                if self.retry_count >= TIMEOUT_RETRIES:
                    log.debug("hit max retries, giving up")
                    raise
                else:
                    log.warn("resending last packet")
                    self.state.resendLast()

    def end(self):
        """Finish up the context: close resources, then compute metrics."""
        TftpContext.end(self)
        self.metrics.end_time = time.time()
        log.debug("Set metrics.end_time to %s", self.metrics.end_time)
        self.metrics.compute()
319
class TftpContextClientDownload(TftpContext):
    """The download context for the client during a download.
    Note: If output is a hyphen, then the output will be sent to stdout."""
    def __init__(self,
                 host,
                 port,
                 filename,
                 output,
                 options,
                 packethook,
                 timeout,
                 localip=""):
        """Prepare a download of the remote file `filename` into `output`.

        :param filename: remote file name placed in the RRQ.
        :param output: a file-like object (anything with write()), '-' for
            sys.stdout, or a local path that is opened for binary writing.
        :param options: TFTP options sent with the RRQ, stored as given.
        :param packethook: optional callable invoked with every packet that
            cycle() accepts.
        :param timeout: socket timeout in seconds.
        :param localip: optional local address to bind to.
        """
        TftpContext.__init__(self,
                             host,
                             port,
                             timeout,
                             localip)
        # FIXME: should we refactor setting of these params?
        self.file_to_transfer = filename
        self.options = options
        self.packethook = packethook
        # If the output object has a write() function,
        # assume it is file-like.
        if hasattr(output, 'write'):
            self.fileobj = output
        # If the output filename is -, then use stdout
        elif output == '-':
            self.fileobj = sys.stdout
        else:
            self.fileobj = open(output, "wb")

        log.debug("TftpContextClientDownload.__init__()")
        log.debug("file_to_transfer = %s, options = %s",
                  self.file_to_transfer, self.options)

    def __str__(self):
        return "%s:%s %s" % (self.host, self.port, self.state)

    def start(self):
        """Initiate the download and run it to completion.

        Sends the read request (RRQ), then loops on cycle() until the state
        machine returns None. On a TftpTimeout the last packet is resent;
        once retry_count reaches TIMEOUT_RETRIES (cycle() resets it whenever
        a packet is handled) the TftpTimeout is re-raised.
        """
        log.info("Sending tftp download request to %s", self.host)
        log.info(" filename -> %s", self.file_to_transfer)
        log.info(" options -> %s", self.options)

        self.metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", self.metrics.start_time)

        # FIXME: put this in a sendRRQ method?
        pkt = TftpPacketRRQ()
        pkt.filename = self.file_to_transfer
        pkt.mode = "octet" # FIXME - shouldn't hardcode this
        pkt.options = self.options
        # Send to the address resolved when self.host was set: cycle()
        # checks replies against self.address, and passing the host name
        # here would trigger a second, possibly different, DNS resolution.
        self.sock.sendto(pkt.encode().buffer, (self.address, self.port))
        self.next_block = 1
        self.last_pkt = pkt

        self.state = TftpStateSentRRQ(self)

        while self.state:
            try:
                log.debug("State is %s", self.state)
                self.cycle()
            except TftpTimeout as err:
                log.error(str(err))
                self.retry_count += 1
                if self.retry_count >= TIMEOUT_RETRIES:
                    log.debug("hit max retries, giving up")
                    raise
                else:
                    log.warn("resending last packet")
                    self.state.resendLast()

    def end(self):
        """Finish up the context: close resources, then compute metrics."""
        TftpContext.end(self)
        self.metrics.end_time = time.time()
        log.debug("Set metrics.end_time to %s", self.metrics.end_time)
        self.metrics.compute()