// Command test-proto is a minimal PostgreSQL client that speaks the version
// 3.0 frontend/backend protocol directly: it connects to a server,
// authenticates (trust, clear-text password or MD5 password), runs a single
// simple query, prints the returned rows and disconnects.
//
// Protocol reference:
// http://www.postgresql.org/docs/current/interactive/protocol.html
//
// Usage:
//
//	test-proto [-port 5432] [-user name] [-password pw] [-db name] [-query sql] host
package main

import (
	"bufio"
	"bytes"
	"crypto/md5"
	"encoding/binary"
	"encoding/hex"
	"errors"
	"flag"
	"fmt"
	"io"
	"net"
	"os"
	"strings"
)

// debug controls the verbosity of progress messages; higher is chattier.
const (
	debug = 3
)

// protocolVersion is the StartupMessage version field: major version 3 in
// the high 16 bits, minor version 0 in the low 16 bits.
const protocolVersion = 3 << 16

// errTruncatedDataRow reports a DataRow whose column lengths run past the
// end of the message.
var errTruncatedDataRow = errors.New("truncated DataRow message")

// stringlength returns how many bytes s occupies on the wire as a protocol
// String: its bytes plus the terminating null byte.
func stringlength(s string) uint32 {
	return uint32(len(s)) + 1 // +1 for the terminating null byte
}

// readString scans buffer for a null-terminated protocol String. It returns
// the number of bytes before the first null byte and those bytes (without the
// terminator). If buffer contains no null byte, the whole buffer is returned.
func readString(buffer []byte) (uint32, []byte) {
	if i := bytes.IndexByte(buffer, 0); i >= 0 {
		return uint32(i), buffer[:i]
	}
	return uint32(len(buffer)), buffer
}

// appendString appends s to buf as a protocol String (s plus a null byte).
func appendString(buf, s []byte) []byte {
	return append(append(buf, s...), 0)
}

// frame wraps body in a protocol message: the one-byte message type (omitted
// when msgType is 0, as for the untyped StartupMessage) followed by an Int32
// length that counts itself and body, but not the type byte.
func frame(msgType byte, body []byte) []byte {
	msg := make([]byte, 0, 5+len(body))
	if msgType != 0 {
		msg = append(msg, msgType)
	}
	var length [4]byte
	binary.BigEndian.PutUint32(length[:], uint32(4+len(body)))
	msg = append(msg, length[:]...)
	return append(msg, body...)
}

// startupMessage builds the StartupMessage announcing protocol 3.0 with the
// "user" and "database" parameters.
func startupMessage(user, database []byte) []byte {
	body := make([]byte, 4, 4+stringlength("user")+uint32(len(user))+1+
		stringlength("database")+uint32(len(database))+1+1)
	binary.BigEndian.PutUint32(body, protocolVersion)
	body = appendString(body, []byte("user"))
	body = appendString(body, user)
	body = appendString(body, []byte("database"))
	body = appendString(body, database)
	body = append(body, 0) // an empty parameter name terminates the list
	return frame(0, body)
}

// md5Hex returns the lower-case hexadecimal MD5 digest of the concatenation
// of parts.
func md5Hex(parts ...[]byte) string {
	h := md5.New()
	for _, p := range parts {
		h.Write(p) // hash.Hash.Write never returns an error
	}
	return hex.EncodeToString(h.Sum(nil))
}

// md5Password builds the answer to AuthenticationMD5Password:
// "md5" + hex(md5(hex(md5(password + user)) + salt)).
// The hexadecimal intermediate step is what the server computes too.
func md5Password(user, password, salt []byte) []byte {
	return []byte("md5" + md5Hex([]byte(md5Hex(password, user)), salt))
}

// readMessage reads exactly one backend message from r, using the length
// field for framing, and returns its type byte and payload (the bytes that
// follow the Int32 length).
func readMessage(r io.Reader) (byte, []byte, error) {
	var header [5]byte
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return 0, nil, err
	}
	length := binary.BigEndian.Uint32(header[1:])
	if length < 4 {
		return 0, nil, fmt.Errorf("message %q has invalid length %d", header[0], length)
	}
	payload := make([]byte, length-4)
	if _, err := io.ReadFull(r, payload); err != nil {
		return 0, nil, err
	}
	return header[0], payload, nil
}

// parseDataRow decodes a DataRow payload into its column values. A NULL
// column is returned as a nil slice; an empty value is a non-nil empty slice.
// The returned slices alias payload.
func parseDataRow(payload []byte) ([][]byte, error) {
	if len(payload) < 2 {
		return nil, errTruncatedDataRow
	}
	ncols := int(binary.BigEndian.Uint16(payload))
	cols := make([][]byte, 0, ncols)
	pos := 2
	for i := 0; i < ncols; i++ {
		if len(payload)-pos < 4 {
			return nil, errTruncatedDataRow
		}
		// The column length is a signed Int32; -1 marks NULL (no bytes follow).
		n := int32(binary.BigEndian.Uint32(payload[pos:]))
		pos += 4
		if n < 0 {
			cols = append(cols, nil)
			continue
		}
		if int(n) > len(payload)-pos {
			return nil, errTruncatedDataRow
		}
		cols = append(cols, payload[pos:pos+int(n)])
		pos += int(n)
	}
	return cols, nil
}

// errorField is one (code, value) pair of an ErrorResponse or NoticeResponse.
type errorField struct {
	code  byte
	value []byte
}

// parseErrorFields decodes an ErrorResponse/NoticeResponse payload: a
// sequence of (Byte1 code, String value) pairs ended by a zero code byte.
// Parsing stops quietly at a missing terminator.
func parseErrorFields(payload []byte) []errorField {
	var fields []errorField
	for len(payload) > 0 && payload[0] != 0 {
		n, value := readString(payload[1:])
		fields = append(fields, errorField{code: payload[0], value: value})
		next := 1 + int(n) + 1 // code byte + value + null terminator
		if next > len(payload) {
			break
		}
		payload = payload[next:]
	}
	return fields
}

// serverError turns an ErrorResponse payload into an error listing every
// field, one per line.
func serverError(payload []byte) error {
	var b strings.Builder
	b.WriteString("server error:")
	for _, f := range parseErrorFields(payload) {
		fmt.Fprintf(&b, "\n\t%c: %s", f.code, f.value)
	}
	return errors.New(b.String())
}

// authenticate sends the StartupMessage on w and answers authentication
// requests read from r until the server reports AuthenticationOk. Supported
// methods are trust (0), clear-text password (3) and MD5 password (5).
// Progress messages go to out.
func authenticate(w io.Writer, r io.Reader, out io.Writer, user, password, database []byte) error {
	if _, err := w.Write(startupMessage(user, database)); err != nil {
		return fmt.Errorf("sending startup message: %w", err)
	}
	for round := 0; ; round++ {
		msgType, payload, err := readMessage(r)
		if err != nil {
			return fmt.Errorf("reading authentication request: %w", err)
		}
		if msgType == 'E' {
			return serverError(payload)
		}
		if msgType != 'R' || len(payload) < 4 {
			return fmt.Errorf("unexpected message %q during authentication", msgType)
		}
		var response []byte
		switch auth := binary.BigEndian.Uint32(payload); auth {
		case 0: // AuthenticationOk
			if debug > 1 && round == 0 {
				fmt.Fprintf(out, "Immediately accepted\n")
			}
			return nil
		case 3: // AuthenticationCleartextPassword
			if debug > 2 {
				fmt.Fprintf(out, "Sending back a clear-text password\n")
			}
			response = frame('p', appendString(nil, password))
		case 5: // AuthenticationMD5Password: a 4-byte salt follows the code
			if len(payload) < 8 {
				return errors.New("truncated MD5 authentication request")
			}
			if debug > 2 {
				fmt.Fprintf(out, "Sending back a MD5-hashed password\n")
			}
			response = frame('p', appendString(nil, md5Password(user, password, payload[4:8])))
		default:
			return fmt.Errorf("unsupported authentication method %d", auth)
		}
		if _, err := w.Write(response); err != nil {
			return fmt.Errorf("sending password: %w", err)
		}
	}
}

// waitForReady consumes what the server sends after AuthenticationOk
// (ParameterStatus, BackendKeyData, notices) up to ReadyForQuery.
func waitForReady(r io.Reader) error {
	for {
		msgType, payload, err := readMessage(r)
		if err != nil {
			return fmt.Errorf("waiting for ReadyForQuery: %w", err)
		}
		switch msgType {
		case 'Z': // ReadyForQuery
			return nil
		case 'E':
			return serverError(payload)
		case 'S', 'K', 'N':
			// ParameterStatus, BackendKeyData and NoticeResponse are not used.
		default:
			return fmt.Errorf("unexpected message %q before ReadyForQuery", msgType)
		}
	}
}

// readQueryResults reads the response to a simple Query up to ReadyForQuery,
// printing each row to out as tab-prefixed columns (NULL shown as "NULL").
// A server ErrorResponse is returned only after ReadyForQuery arrives, so the
// connection stays in sync.
func readQueryResults(r io.Reader, out io.Writer) error {
	var queryErr error
	for {
		msgType, payload, err := readMessage(r)
		if err != nil {
			return fmt.Errorf("reading query response: %w", err)
		}
		switch msgType {
		case 'T': // RowDescription, currently ignored
		case 'D': // DataRow
			cols, rowErr := parseDataRow(payload)
			if rowErr != nil {
				return rowErr
			}
			for _, col := range cols {
				if col == nil {
					fmt.Fprint(out, "\tNULL")
				} else {
					fmt.Fprintf(out, "\t%s", col)
				}
			}
			fmt.Fprintln(out)
		case 'C': // CommandComplete
			if debug > 2 {
				fmt.Fprintf(out, "Query OK\n")
			}
		case 'I', 'N': // EmptyQueryResponse, NoticeResponse
		case 'E':
			queryErr = serverError(payload)
		case 'Z': // ReadyForQuery: the query cycle is over
			return queryErr
		default:
			return fmt.Errorf("unexpected message %q in query response", msgType)
		}
	}
}

// run performs a whole session over conn: startup and authentication, one
// simple query whose rows are printed to out, then Terminate.
func run(conn io.ReadWriter, out io.Writer, user, password, database, query []byte) error {
	// One buffered reader for the whole session: it may hold bytes of the
	// next message, so it must never be replaced.
	r := bufio.NewReader(conn)
	if err := authenticate(conn, r, out, user, password, database); err != nil {
		return err
	}
	fmt.Fprintf(out, "User %s logged in\n", user)
	if err := waitForReady(r); err != nil {
		return err
	}
	if _, err := conn.Write(frame('Q', appendString(nil, query))); err != nil {
		return fmt.Errorf("sending query: %w", err)
	}
	queryErr := readQueryResults(r, out)
	// Send Terminate even after a failed query; a write error only matters
	// when nothing else has gone wrong.
	if _, err := conn.Write(frame('X', nil)); err != nil && queryErr == nil {
		return fmt.Errorf("sending terminate: %w", err)
	}
	return queryErr
}

func main() {
	port := flag.String("port", "5432", "server port")
	user := flag.String("user", "gadjin", "user name")
	password := flag.String("password", "tropdiff", "password for clear-text or MD5 authentication")
	db := flag.String("db", "essais", "database name")
	query := flag.String("query", "SELECT * FROM Foobar;", "SQL query to run")
	flag.Parse()
	if flag.NArg() != 1 {
		fmt.Printf("Usage: test-proto [flags] host\n")
		flag.PrintDefaults()
		os.Exit(1)
	}

	host := flag.Arg(0)
	conn, err := net.Dial("tcp", net.JoinHostPort(host, *port))
	if err != nil {
		fmt.Fprintf(os.Stderr, "Cannot connect to %s: %s\n", host, err)
		os.Exit(1)
	}
	err = run(conn, os.Stdout, []byte(*user), []byte(*password), []byte(*db), []byte(*query))
	conn.Close() // closed explicitly: os.Exit below would skip a deferred Close
	if err != nil {
		fmt.Fprintf(os.Stderr, "test-proto: %v\n", err)
		os.Exit(1)
	}
}