00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 '''
00038 MySQL Protocol Packet Objects
00039 '''
00040
00041 import struct
00042 import unittest
00043
class PacketException(Exception):
    '''Raised when a packet header holds a value that cannot be encoded.'''
    pass


class Packet(object):
    '''This class represents a packet header.

    On the wire the header is 4 bytes: a 3-byte little-endian payload size
    followed by a 1-byte sequence number.
    '''

    # Largest payload size that fits in the 3-byte size field.
    MAX_SIZE = 16777215

    def __init__(self, packed=None, size=0, sequence=0):
        '''Build a header from keyword values or from 4 packed bytes.

        packed: 4-byte encoded header; when given, size and sequence are
            decoded from it and the keyword arguments are ignored.
        size: payload size, used only when packed is None.
        sequence: sequence number, used only when packed is None.

        Raises PacketException if the size is out of range, and struct.error
        if packed is not exactly 4 bytes long.
        '''
        if packed is None:
            self.size = size
            self.sequence = sequence
        else:
            data = struct.unpack('4B', packed)
            self.size = data[0] | (data[1] << 8) | (data[2] << 16)
            self.sequence = data[3]

        self.verify()

    def pack(self):
        '''Return the 4-byte encoding of this header.

        The sequence number is reduced modulo 256 so it always fits in the
        single sequence byte. Raises PacketException if size is out of range.
        '''
        self.verify()
        return struct.pack('4B',
                           self.size & 0xFF,
                           (self.size >> 8) & 0xFF,
                           (self.size >> 16) & 0xFF,
                           self.sequence % 256)

    def verify(self):
        '''Raise PacketException unless 0 <= size <= MAX_SIZE.'''
        # A negative size would otherwise be masked into a bogus large value
        # by pack() instead of being reported.
        if self.size < 0:
            raise PacketException('Packet size cannot be negative (%d)' %
                                  self.size)
        if self.size > self.MAX_SIZE:
            raise PacketException('Packet size cannot exceed 16777215 bytes (%d)' %
                                  self.size)

    def __str__(self):
        return 'Packet\nsize = %s\nsequence = %s\n' % (self.size, self.sequence)
00079
class TestPacket(unittest.TestCase):
    '''Unit tests for building, encoding and decoding Packet headers.'''

    def _round_trip(self, size, sequence=0):
        '''Encode a header with the given fields and decode it again.'''
        return Packet(Packet(size=size, sequence=sequence).pack())

    def testDefaultInit(self):
        fresh = Packet()
        self.assertEqual(fresh.size, 0)
        self.assertEqual(fresh.sequence, 0)
        str(fresh)

    def testKeywordInit(self):
        built = Packet(size=1234, sequence=5)
        self.assertEqual(built.size, 1234)
        self.assertEqual(built.sequence, 5)
        str(built)

    def testUnpackInit(self):
        raw = struct.pack('4B', 210, 4, 0, 5)
        decoded = Packet(raw)
        self.assertEqual(decoded.size, 1234)
        self.assertEqual(decoded.sequence, 5)

    def testPack(self):
        self.assertEqual(self._round_trip(1234, 5).size, 1234)

    def testPackRange(self):
        # Every small value, then a sparse sweep up to the 3-byte limit.
        candidates = list(range(0, 300)) + list(range(300, 16777216, 997))
        for value in candidates:
            decoded = self._round_trip(value, value)
            self.assertEqual(decoded.size, value)
            self.assertEqual(decoded.sequence, value % 256)

        largest = self._round_trip(16777215)
        self.assertEqual(largest.size, 16777215)
        self.assertEqual(largest.sequence, 0)

        oversized = (16777216, 16777217, 4294967295, 4294967296, 4294967297)
        for too_big in oversized:
            self.assertRaises(PacketException, Packet, size=too_big)
00124
def parse_row(count, data):
    '''Split the start of a row payload into `count` column values.

    Each column is either a single 251 (0xFB) byte, returned as None, or a
    length-encoded size (decoded by parse_encoded_size) followed by that many
    bytes, returned as a slice of data.

    count: number of columns to read; a value <= 0 reads nothing.
    data: row payload (bytes on Python 3, a byte string on Python 2).

    Returns a list with one entry per column read. Bytes after the last
    column are ignored. Raises IndexError if data is exhausted at the start
    of a column.
    '''
    row = []
    # Offset of the next unread byte; advancing it avoids re-copying the
    # remaining payload for every column.
    pos = 0
    while count > 0:
        count -= 1
        marker = data[pos]
        # Indexing bytes yields an int on Python 3 but a 1-char str on
        # Python 2; normalise both to an int.
        if not isinstance(marker, int):
            marker = ord(marker)
        if marker == 251:
            row.append(None)
            pos += 1
        else:
            # A length-encoded size occupies at most 9 bytes, so pass a short
            # window rather than the whole remaining payload.
            (size, packed_size) = parse_encoded_size(data[pos:pos + 9])
            start = pos + packed_size
            row.append(data[start:start + size])
            pos = start + size
    return row
00137
class BadSize(Exception):
    '''Raised by parse_encoded_size when the leading byte is 255.'''
00140
def parse_encoded_size(data):
    '''Decode a length-encoded integer from the start of data.

    The first byte selects the encoding:
      0-251: the byte itself is the value; 1 byte consumed. (251 is not
             special-cased here.)
      252:   a 2-byte little-endian value follows; 3 bytes consumed.
      253:   a 3-byte little-endian value follows; 4 bytes consumed.
      254:   an 8-byte little-endian value follows; 9 bytes consumed.
      255:   rejected with BadSize.

    data: bytes on Python 3, or a byte string on Python 2.

    Returns (size, packed_size), where packed_size is the number of bytes of
    data used by the encoding, including the leading byte.
    Raises struct.error if data is too short for the selected encoding.
    '''
    first = data[0]
    # Indexing bytes yields an int on Python 3 but a 1-char str on Python 2.
    if not isinstance(first, int):
        first = ord(first)

    if first == 252:
        return (struct.unpack('<H', data[1:3])[0], 3)
    if first == 253:
        low, high = struct.unpack('<HB', data[1:4])
        return (low | (high << 16), 4)
    if first == 254:
        # Leading byte plus 8 value bytes: 9 bytes in total (was wrongly 8,
        # which made callers treat the last length byte as value data).
        return (struct.unpack('<Q', data[1:9])[0], 9)
    if first == 255:
        raise BadSize(str(first))

    return (first, 1)
00159
if __name__ == '__main__':
    # When this module is run directly, run the unittest test cases
    # defined in it.
    unittest.main()