Norman James | 8b2b722 | 2015-10-08 07:01:38 -0500 | [diff] [blame] | 1 | """This module implements all contexts for state handling during uploads and |
| 2 | downloads, the main interface to which is the TftpContext base class. |
| 3 | |
| 4 | The concept is simple. Each context object represents a single upload or |
| 5 | download, and the state object in the context object represents the current |
| 6 | state of that transfer. The state object has a handle() method that expects |
| 7 | the next packet in the transfer, and returns a state object until the transfer |
| 8 | is complete, at which point it returns None. That is, unless there is a fatal |
| 9 | error, in which case a TftpException is raised instead.""" |
| 10 | |
| 11 | from TftpShared import * |
| 12 | from TftpPacketTypes import * |
| 13 | from TftpPacketFactory import TftpPacketFactory |
| 14 | from TftpStates import * |
| 15 | import socket, time, sys |
| 16 | |
| 17 | ############################################################################### |
| 18 | # Utility classes |
| 19 | ############################################################################### |
| 20 | |
class TftpMetrics(object):
    """A class representing metrics of the transfer.

    The raw counters (bytes, resent_bytes, dups, start_time, end_time,
    errors) are filled in while the transfer runs. compute() derives
    duration, bps, kbps and dupcount from them."""
    def __init__(self):
        # Bytes transferred
        self.bytes = 0
        # Bytes re-sent
        self.resent_bytes = 0
        # Duplicate packets received, keyed by str(packet) -> count
        self.dups = {}
        self.dupcount = 0
        # Times
        self.start_time = 0
        self.end_time = 0
        self.duration = 0
        # Rates
        self.bps = 0
        self.kbps = 0
        # Generic errors
        self.errors = 0

    def compute(self):
        """Derive duration, bit rates and the total duplicate count.

        Every derived value is recomputed from the raw counters rather than
        accumulated. Calling this more than once is therefore harmless; a
        context's end() can run twice, once explicitly and once from its
        destructor."""
        # Compute transfer time
        self.duration = self.end_time - self.start_time
        if self.duration == 0:
            # Avoid a division by zero for transfers faster than the clock
            # resolution.
            self.duration = 1
        log.debug("TftpMetrics.compute: duration is %s", self.duration)
        self.bps = (self.bytes * 8.0) / self.duration
        self.kbps = self.bps / 1024.0
        log.debug("TftpMetrics.compute: kbps is %s", self.kbps)
        # Assign rather than add: adding would double-count on repeat calls.
        self.dupcount = sum(self.dups.values())

    def add_dup(self, pkt):
        """This method adds a dup for a packet to the metrics.

        Duplicates are counted per str(pkt). Each call is followed by a
        tftpassert() check that the count for this packet is still below
        MAX_DUPS. Both names come from TftpShared; presumably a failed check
        raises."""
        log.debug("Recording a dup of %s", pkt)
        s = str(pkt)
        self.dups[s] = self.dups.get(s, 0) + 1
        tftpassert(self.dups[s] < MAX_DUPS, "Max duplicates reached")
| 62 | |
| 63 | ############################################################################### |
| 64 | # Context classes |
| 65 | ############################################################################### |
| 66 | |
class TftpContext(object):
    """The base class of the contexts.

    A context owns the UDP socket for one transfer, the current state
    object, and the bookkeeping (block numbers, retries, metrics) that the
    state objects read and update."""

    def __init__(self, host, port, timeout, localip = ""):
        """Constructor for the base context, setting shared instance
        variables.

        host    -- remote host; resolved at once through the host property,
                   which also sets self.address.
        port    -- remote port for the initial request.
        timeout -- socket timeout in seconds, also used by checkTimeout().
        localip -- optional local address to bind the socket to. An empty
                   value (the default) or None leaves binding to the OS."""
        self.file_to_transfer = None
        self.fileobj = None
        self.options = None
        self.packethook = None
        self.timeout = timeout
        self.state = None
        self.next_block = 0
        self.factory = TftpPacketFactory()
        # Metrics are created before anything that can fail (name
        # resolution, bind). Subclass end() methods finalize metrics, and
        # end() may run from __del__ on a partially constructed context.
        self.metrics = TftpMetrics()
        # Flag when the transfer is pending completion.
        self.pending_complete = False
        # Time when this context last received any traffic.
        # FIXME: does this belong in metrics?
        self.last_update = 0
        # The last packet we sent, if applicable, to make resending easy.
        self.last_pkt = None
        # Count the number of retry attempts.
        self.retry_count = 0
        # The port associated with the TID
        self.tidport = None
        # Note, setting the host will also set self.address, as it's a property.
        # It is resolved before the socket is created, so a lookup failure
        # does not leave an open socket behind.
        self.host = host
        self.port = port
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        if localip:
            self.sock.bind((localip, 0))
        self.sock.settimeout(timeout)

    def getBlocksize(self):
        """Fetch the current blocksize for this session.

        Returns the 'blksize' option as an int, or 512 when no such option
        (or no options dict at all) has been set."""
        return int((self.options or {}).get('blksize', 512))

    def __del__(self):
        """Simple destructor to try to call housekeeping in the end method if
        not called explicitly. Leaking file descriptors is not a good
        thing."""
        # If end() already ran, calling it again would re-stamp the end time
        # that subclasses record in their metrics. Without a socket,
        # __init__ failed early and there is nothing to release.
        if getattr(self, '_ended', False) or not hasattr(self, 'sock'):
            return
        self.end()

    def checkTimeout(self, now):
        """Compare current time with last_update time, and raise TftpTimeout
        if we're over the timeout time."""
        log.debug("checking for timeout on session %s", self)
        if now - self.last_update > self.timeout:
            raise TftpTimeout("Timeout waiting for traffic")

    def start(self):
        raise NotImplementedError("Abstract method")

    def end(self):
        """Perform session cleanup: close the socket and, if still open, the
        file object. Since the end method should always be called
        explicitly by the calling code, this works better than the
        destructor."""
        log.debug("in TftpContext.end")
        self.sock.close()
        # Duck-typed file objects may not expose a 'closed' attribute.
        if self.fileobj is not None and not getattr(self.fileobj, 'closed', False):
            log.debug("self.fileobj is open - closing")
            self.fileobj.close()
        # Record that cleanup ran so __del__ does not repeat it.
        self._ended = True

    def gethost(self):
        "Simple getter method for use in a property."
        return self.__host

    def sethost(self, host):
        """Setter method that also sets the address property as a result
        of the host that is set."""
        self.__host = host
        self.address = socket.gethostbyname(host)

    host = property(gethost, sethost)

    def setNextBlock(self, block):
        """Set the expected block number. Values of 2**16 or more roll
        over to 0."""
        if block >= 2 ** 16:
            log.debug("Block number rollover to 0 again")
            block = 0
        self.__eblock = block

    def getNextBlock(self):
        return self.__eblock

    next_block = property(getNextBlock, setNextBlock)

    def cycle(self):
        """Here we wait for a response from the server after sending it
        something, and dispatch appropriate action to that response.

        Datagrams from a host other than self.address, or from a port other
        than the established tidport, are logged and discarded without
        being parsed or handed to the state machine.

        Raises TftpTimeout if nothing arrives within the socket timeout."""
        try:
            (data, (raddress, rport)) = self.sock.recvfrom(MAX_BLKSIZE)
        except socket.timeout:
            log.warn("Timeout waiting for traffic, retrying...")
            raise TftpTimeout("Timed-out waiting for traffic")

        # Ok, we've received a packet. Log it.
        log.debug("Received %d bytes from %s:%s",
                  len(data), raddress, rport)

        # Check for known "connection" before touching the payload. Stray
        # datagrams must not abort the transfer with a parse error, reach the
        # packethook or state machine, or keep the session alive (RFC 1350,
        # section 4).
        if raddress != self.address:
            log.warn("Received traffic from %s, expected host %s. Discarding"
                     % (raddress, self.host))
            return

        if self.tidport and self.tidport != rport:
            log.warn("Received traffic from %s:%s but we're "
                     "connected to %s:%s. Discarding."
                     % (raddress, rport,
                        self.host, self.tidport))
            return

        # And update our last updated time; only traffic from our peer counts.
        self.last_update = time.time()

        # Decode it.
        recvpkt = self.factory.parse(data)

        # If there is a packethook defined, call it. We unconditionally
        # pass all packets, it's up to the client to screen out different
        # kinds of packets. This way, the client is privy to things like
        # negotiated options.
        if self.packethook:
            self.packethook(recvpkt)

        # And handle it, possibly changing state.
        self.state = self.state.handle(recvpkt, raddress, rport)
        # If we didn't throw any exceptions here, reset the retry_count to
        # zero.
        self.retry_count = 0
| 196 | |
class TftpContextServer(TftpContext):
    """Context for a single server-side session.

    Whether the session is a download or an upload is unknown until the
    first packet is handled. The initial TftpStateServerStart state makes
    that decision when start() is called."""
    def __init__(self, host, port, timeout, root, dyn_file_func=None):
        TftpContext.__init__(self, host, port, timeout)
        # Direction is decided later by the start state.
        self.state = TftpStateServerStart(self)

        self.root = root
        self.dyn_file_func = dyn_file_func

    def __str__(self):
        return "%s:%s %s" % (self.host, self.port, self.state)

    def start(self, buffer):
        """Begin the session with the already-received first packet.

        Unlike the client contexts, the server context does not loop on
        cycle() here; the TftpServer object drives that. This method only
        records timing, parses `buffer`, and lets the current state handle
        it once."""
        log.debug("In TftpContextServer.start")
        metrics = self.metrics
        metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", metrics.start_time)
        # Treat the initial packet as traffic for timeout purposes.
        self.last_update = time.time()

        first_pkt = self.factory.parse(buffer)
        log.debug("TftpContextServer.start() - factory returned a %s", first_pkt)

        # A single handle() call moves us into the download or upload state.
        current = self.state
        self.state = current.handle(first_pkt, self.host, self.port)

    def end(self):
        """Run the base cleanup, then stamp the end time and compute the
        transfer metrics."""
        TftpContext.end(self)
        metrics = self.metrics
        metrics.end_time = time.time()
        log.debug("Set metrics.end_time to %s", metrics.end_time)
        metrics.compute()
| 241 | |
class TftpContextClientUpload(TftpContext):
    """The upload context for the client during an upload.

    Note: If input is a hyphen, then we will use stdin. An input object that
    already has a read() method is used directly as a file-like object.
    Stdin and caller-supplied objects belong to the caller, so end() leaves
    them open. Only a file opened here by name is closed."""
    def __init__(self,
                 host,
                 port,
                 filename,
                 input,
                 options,
                 packethook,
                 timeout,
                 localip = ""):
        TftpContext.__init__(self,
                             host,
                             port,
                             timeout,
                             localip)
        self.file_to_transfer = filename
        self.options = options
        self.packethook = packethook
        # True only when this context opened fileobj itself and is therefore
        # responsible for closing it in end().
        self._owns_fileobj = False
        # If the input object has a read() function,
        # assume it is file-like.
        if hasattr(input, 'read'):
            self.fileobj = input
        elif input == '-':
            # Data is read as bytes (named files are opened "rb" below). On
            # Python 3 the byte stream behind stdin is sys.stdin.buffer;
            # Python 2's sys.stdin has no such attribute and yields bytes.
            self.fileobj = getattr(sys.stdin, 'buffer', sys.stdin)
        else:
            self.fileobj = open(input, "rb")
            self._owns_fileobj = True

        log.debug("TftpContextClientUpload.__init__()")
        log.debug("file_to_transfer = %s, options = %s",
                  self.file_to_transfer, self.options)

    def __str__(self):
        return "%s:%s %s" % (self.host, self.port, self.state)

    def start(self):
        """Initiate the upload.

        Sends the WRQ, then drives the state machine via cycle() until the
        state becomes falsy, resending the last packet after each timeout.
        Re-raises TftpTimeout once retry_count reaches TIMEOUT_RETRIES
        (cycle() resets the count after every handled packet). Any other
        exception from the states propagates unchanged."""
        log.info("Sending tftp upload request to %s", self.host)
        log.info(" filename -> %s", self.file_to_transfer)
        log.info(" options -> %s", self.options)

        self.metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", self.metrics.start_time)

        # FIXME: put this in a sendWRQ method?
        pkt = TftpPacketWRQ()
        pkt.filename = self.file_to_transfer
        pkt.mode = "octet" # FIXME - shouldn't hardcode this
        pkt.options = self.options
        self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
        self.next_block = 1
        # Kept so a timeout can resend the request.
        self.last_pkt = pkt
        # FIXME: should we centralize sendto operations so we can refactor all
        # saving of the packet to the last_pkt field?

        self.state = TftpStateSentWRQ(self)

        while self.state:
            try:
                log.debug("State is %s", self.state)
                self.cycle()
            except TftpTimeout as err:
                log.error(str(err))
                self.retry_count += 1
                if self.retry_count >= TIMEOUT_RETRIES:
                    log.debug("hit max retries, giving up")
                    raise
                else:
                    log.warn("resending last packet")
                    self.state.resendLast()

    def end(self):
        """Finish up the context.

        Releases the socket and closes the input file only if this context
        opened it. Then stamps the end time and computes the metrics."""
        # Default True: if __init__ failed before the flag was set, fall back
        # to the plain base cleanup.
        if getattr(self, '_owns_fileobj', True):
            TftpContext.end(self)
        else:
            # Hide the caller-owned object from the base cleanup, which would
            # otherwise close it, and restore it afterwards.
            fileobj, self.fileobj = self.fileobj, None
            try:
                TftpContext.end(self)
            finally:
                self.fileobj = fileobj
        self.metrics.end_time = time.time()
        log.debug("Set metrics.end_time to %s", self.metrics.end_time)
        self.metrics.compute()
| 319 | |
class TftpContextClientDownload(TftpContext):
    """The download context for the client during a download.

    Note: If output is a hyphen, then the output will be sent to stdout. An
    output object that already has a write() method is used as-is. Stdout
    and caller-supplied objects belong to the caller: end() flushes them
    but does not close them. Only a file opened here by name is closed."""
    def __init__(self,
                 host,
                 port,
                 filename,
                 output,
                 options,
                 packethook,
                 timeout,
                 localip = ""):
        TftpContext.__init__(self,
                             host,
                             port,
                             timeout,
                             localip)
        # FIXME: should we refactor setting of these params?
        self.file_to_transfer = filename
        self.options = options
        self.packethook = packethook
        # True only when this context opened fileobj itself and is therefore
        # responsible for closing it in end().
        self._owns_fileobj = False
        # If the output object has a write() function,
        # assume it is file-like.
        if hasattr(output, 'write'):
            self.fileobj = output
        # If the output filename is -, then use stdout
        elif output == '-':
            # Data is written as bytes (named files are opened "wb" below).
            # On Python 3 the byte stream behind stdout is sys.stdout.buffer;
            # Python 2's sys.stdout has no such attribute and accepts bytes.
            self.fileobj = getattr(sys.stdout, 'buffer', sys.stdout)
        else:
            self.fileobj = open(output, "wb")
            self._owns_fileobj = True

        log.debug("TftpContextClientDownload.__init__()")
        log.debug("file_to_transfer = %s, options = %s",
                  self.file_to_transfer, self.options)

    def __str__(self):
        return "%s:%s %s" % (self.host, self.port, self.state)

    def start(self):
        """Initiate the download.

        Sends the RRQ, then drives the state machine via cycle() until the
        state becomes falsy, resending the last packet after each timeout.
        Re-raises TftpTimeout once retry_count reaches TIMEOUT_RETRIES
        (cycle() resets the count after every handled packet). Any other
        exception from the states propagates unchanged."""
        log.info("Sending tftp download request to %s", self.host)
        log.info(" filename -> %s", self.file_to_transfer)
        log.info(" options -> %s", self.options)

        self.metrics.start_time = time.time()
        log.debug("Set metrics.start_time to %s", self.metrics.start_time)

        # FIXME: put this in a sendRRQ method?
        pkt = TftpPacketRRQ()
        pkt.filename = self.file_to_transfer
        pkt.mode = "octet" # FIXME - shouldn't hardcode this
        pkt.options = self.options
        self.sock.sendto(pkt.encode().buffer, (self.host, self.port))
        self.next_block = 1
        # Kept so a timeout can resend the request.
        self.last_pkt = pkt

        self.state = TftpStateSentRRQ(self)

        while self.state:
            try:
                log.debug("State is %s", self.state)
                self.cycle()
            except TftpTimeout as err:
                log.error(str(err))
                self.retry_count += 1
                if self.retry_count >= TIMEOUT_RETRIES:
                    log.debug("hit max retries, giving up")
                    raise
                else:
                    log.warn("resending last packet")
                    self.state.resendLast()

    def end(self):
        """Finish up the context.

        Releases the socket and closes the output file if this context
        opened it; caller-owned output is flushed instead. Then stamps the
        end time and computes the metrics."""
        # Default True: if __init__ failed before the flag was set, fall back
        # to the plain base cleanup.
        if getattr(self, '_owns_fileobj', True):
            TftpContext.end(self)
        else:
            # Hide the caller-owned object from the base cleanup, which would
            # otherwise close it (losing e.g. an in-memory buffer's contents),
            # and restore it afterwards.
            fileobj, self.fileobj = self.fileobj, None
            try:
                TftpContext.end(self)
            finally:
                self.fileobj = fileobj
            # Closing used to flush buffered output implicitly; flush
            # explicitly so all received data reaches the caller's object.
            flush = getattr(fileobj, 'flush', None)
            if flush is not None and not getattr(fileobj, 'closed', False):
                flush()
        self.metrics.end_time = time.time()
        log.debug("Set metrics.end_time to %s", self.metrics.end_time)
        self.metrics.compute()