Update vendored dependencies
This commit is contained in:
parent
8f3f403ae0
commit
df761eb66c
|
@ -10,14 +10,14 @@
|
|||
[[projects]]
|
||||
name = "github.com/go-sql-driver/mysql"
|
||||
packages = ["."]
|
||||
revision = "a0583e0143b1624142adab07e0e97fe106d99561"
|
||||
version = "v1.3"
|
||||
revision = "d523deb1b23d913de5bdada721a6071e71283618"
|
||||
version = "v1.4.0"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/jessevdk/go-flags"
|
||||
packages = ["."]
|
||||
revision = "96dc06278ce32a0e9d957d590bb987c81ee66407"
|
||||
version = "v1.3.0"
|
||||
revision = "c6ca198ec95c841fdb89fc0de7496fed11ab854e"
|
||||
version = "v1.4.0"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/pkg/errors"
|
||||
|
@ -25,6 +25,12 @@
|
|||
revision = "645ef00459ed84a119197bfb8d8205042c6df63d"
|
||||
version = "v0.8.0"
|
||||
|
||||
[[projects]]
|
||||
name = "google.golang.org/appengine"
|
||||
packages = ["cloudsql"]
|
||||
revision = "150dc57a1b433e64154302bdc40b6bb8aefa313a"
|
||||
version = "v1.0.0"
|
||||
|
||||
[solve-meta]
|
||||
analyzer-name = "dep"
|
||||
analyzer-version = 1
|
||||
|
|
|
@ -20,7 +20,3 @@
|
|||
# name = "github.com/x/y"
|
||||
# version = "2.4.0"
|
||||
|
||||
|
||||
[[constraint]]
|
||||
branch = "master"
|
||||
name = "golang.org/x/oauth2"
|
||||
|
|
|
@ -6,3 +6,4 @@
|
|||
Icon?
|
||||
ehthumbs.db
|
||||
Thumbs.db
|
||||
.idea
|
||||
|
|
|
@ -1,13 +1,107 @@
|
|||
sudo: false
|
||||
language: go
|
||||
go:
|
||||
- 1.2
|
||||
- 1.3
|
||||
- 1.4
|
||||
- 1.5
|
||||
- 1.6
|
||||
- 1.7
|
||||
- tip
|
||||
- 1.7.x
|
||||
- 1.8.x
|
||||
- 1.9.x
|
||||
- 1.10.x
|
||||
- master
|
||||
|
||||
before_install:
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/mattn/goveralls
|
||||
|
||||
before_script:
|
||||
- echo -e "[server]\ninnodb_log_file_size=256MB\ninnodb_buffer_pool_size=512MB\nmax_allowed_packet=16MB" | sudo tee -a /etc/mysql/my.cnf
|
||||
- sudo service mysql restart
|
||||
- .travis/wait_mysql.sh
|
||||
- mysql -e 'create database gotest;'
|
||||
|
||||
matrix:
|
||||
include:
|
||||
- env: DB=MYSQL8
|
||||
sudo: required
|
||||
dist: trusty
|
||||
go: 1.10.x
|
||||
services:
|
||||
- docker
|
||||
before_install:
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/mattn/goveralls
|
||||
- docker pull mysql:8.0
|
||||
- docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret
|
||||
mysql:8.0 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1
|
||||
- cp .travis/docker.cnf ~/.my.cnf
|
||||
- .travis/wait_mysql.sh
|
||||
before_script:
|
||||
- export MYSQL_TEST_USER=gotest
|
||||
- export MYSQL_TEST_PASS=secret
|
||||
- export MYSQL_TEST_ADDR=127.0.0.1:3307
|
||||
- export MYSQL_TEST_CONCURRENT=1
|
||||
|
||||
- env: DB=MYSQL57
|
||||
sudo: required
|
||||
dist: trusty
|
||||
go: 1.10.x
|
||||
services:
|
||||
- docker
|
||||
before_install:
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/mattn/goveralls
|
||||
- docker pull mysql:5.7
|
||||
- docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret
|
||||
mysql:5.7 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1
|
||||
- cp .travis/docker.cnf ~/.my.cnf
|
||||
- .travis/wait_mysql.sh
|
||||
before_script:
|
||||
- export MYSQL_TEST_USER=gotest
|
||||
- export MYSQL_TEST_PASS=secret
|
||||
- export MYSQL_TEST_ADDR=127.0.0.1:3307
|
||||
- export MYSQL_TEST_CONCURRENT=1
|
||||
|
||||
- env: DB=MARIA55
|
||||
sudo: required
|
||||
dist: trusty
|
||||
go: 1.10.x
|
||||
services:
|
||||
- docker
|
||||
before_install:
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/mattn/goveralls
|
||||
- docker pull mariadb:5.5
|
||||
- docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret
|
||||
mariadb:5.5 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1
|
||||
- cp .travis/docker.cnf ~/.my.cnf
|
||||
- .travis/wait_mysql.sh
|
||||
before_script:
|
||||
- export MYSQL_TEST_USER=gotest
|
||||
- export MYSQL_TEST_PASS=secret
|
||||
- export MYSQL_TEST_ADDR=127.0.0.1:3307
|
||||
- export MYSQL_TEST_CONCURRENT=1
|
||||
|
||||
- env: DB=MARIA10_1
|
||||
sudo: required
|
||||
dist: trusty
|
||||
go: 1.10.x
|
||||
services:
|
||||
- docker
|
||||
before_install:
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
- go get github.com/mattn/goveralls
|
||||
- docker pull mariadb:10.1
|
||||
- docker run -d -p 127.0.0.1:3307:3306 --name mysqld -e MYSQL_DATABASE=gotest -e MYSQL_USER=gotest -e MYSQL_PASSWORD=secret -e MYSQL_ROOT_PASSWORD=verysecret
|
||||
mariadb:10.1 --innodb_log_file_size=256MB --innodb_buffer_pool_size=512MB --max_allowed_packet=16MB --local-infile=1
|
||||
- cp .travis/docker.cnf ~/.my.cnf
|
||||
- .travis/wait_mysql.sh
|
||||
before_script:
|
||||
- export MYSQL_TEST_USER=gotest
|
||||
- export MYSQL_TEST_PASS=secret
|
||||
- export MYSQL_TEST_ADDR=127.0.0.1:3307
|
||||
- export MYSQL_TEST_CONCURRENT=1
|
||||
|
||||
script:
|
||||
- go test -v -covermode=count -coverprofile=coverage.out
|
||||
- go vet ./...
|
||||
- .travis/gofmt.sh
|
||||
after_script:
|
||||
- $HOME/gopath/bin/goveralls -coverprofile=coverage.out -service=travis-ci
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
[client]
|
||||
user = gotest
|
||||
password = secret
|
||||
host = 127.0.0.1
|
||||
port = 3307
|
|
@ -12,35 +12,63 @@
|
|||
# Individual Persons
|
||||
|
||||
Aaron Hopkins <go-sql-driver at die.net>
|
||||
Achille Roussel <achille.roussel at gmail.com>
|
||||
Alexey Palazhchenko <alexey.palazhchenko at gmail.com>
|
||||
Andrew Reid <andrew.reid at tixtrack.com>
|
||||
Arne Hormann <arnehormann at gmail.com>
|
||||
Asta Xie <xiemengjun at gmail.com>
|
||||
Bulat Gaifullin <gaifullinbf at gmail.com>
|
||||
Carlos Nieto <jose.carlos at menteslibres.net>
|
||||
Chris Moos <chris at tech9computers.com>
|
||||
Craig Wilson <craiggwilson at gmail.com>
|
||||
Daniel Montoya <dsmontoyam at gmail.com>
|
||||
Daniel Nichter <nil at codenode.com>
|
||||
Daniël van Eeden <git at myname.nl>
|
||||
Dave Protasowski <dprotaso at gmail.com>
|
||||
DisposaBoy <disposaboy at dby.me>
|
||||
Egor Smolyakov <egorsmkv at gmail.com>
|
||||
Evan Shaw <evan at vendhq.com>
|
||||
Frederick Mayle <frederickmayle at gmail.com>
|
||||
Gustavo Kristic <gkristic at gmail.com>
|
||||
Hajime Nakagami <nakagami at gmail.com>
|
||||
Hanno Braun <mail at hannobraun.com>
|
||||
Henri Yandell <flamefew at gmail.com>
|
||||
Hirotaka Yamamoto <ymmt2005 at gmail.com>
|
||||
ICHINOSE Shogo <shogo82148 at gmail.com>
|
||||
INADA Naoki <songofacandy at gmail.com>
|
||||
Jacek Szwec <szwec.jacek at gmail.com>
|
||||
James Harr <james.harr at gmail.com>
|
||||
Jeff Hodges <jeff at somethingsimilar.com>
|
||||
Jeffrey Charles <jeffreycharles at gmail.com>
|
||||
Jian Zhen <zhenjl at gmail.com>
|
||||
Joshua Prunier <joshua.prunier at gmail.com>
|
||||
Julien Lefevre <julien.lefevr at gmail.com>
|
||||
Julien Schmidt <go-sql-driver at julienschmidt.com>
|
||||
Justin Li <jli at j-li.net>
|
||||
Justin Nuß <nuss.justin at gmail.com>
|
||||
Kamil Dziedzic <kamil at klecza.pl>
|
||||
Kevin Malachowski <kevin at chowski.com>
|
||||
Kieron Woodhouse <kieron.woodhouse at infosum.com>
|
||||
Lennart Rudolph <lrudolph at hmc.edu>
|
||||
Leonardo YongUk Kim <dalinaum at gmail.com>
|
||||
Linh Tran Tuan <linhduonggnu at gmail.com>
|
||||
Lion Yang <lion at aosc.xyz>
|
||||
Luca Looz <luca.looz92 at gmail.com>
|
||||
Lucas Liu <extrafliu at gmail.com>
|
||||
Luke Scott <luke at webconnex.com>
|
||||
Maciej Zimnoch <maciej.zimnoch at codilime.com>
|
||||
Michael Woolnough <michael.woolnough at gmail.com>
|
||||
Nicola Peduzzi <thenikso at gmail.com>
|
||||
Olivier Mengué <dolmen at cpan.org>
|
||||
oscarzhao <oscarzhaosl at gmail.com>
|
||||
Paul Bonser <misterpib at gmail.com>
|
||||
Peter Schultz <peter.schultz at classmarkets.com>
|
||||
Rebecca Chin <rchin at pivotal.io>
|
||||
Reed Allman <rdallman10 at gmail.com>
|
||||
Richard Wilkes <wilkes at me.com>
|
||||
Robert Russell <robert at rrbrussell.com>
|
||||
Runrioter Wung <runrioter at gmail.com>
|
||||
Shuode Li <elemount at qq.com>
|
||||
Soroush Pour <me at soroushjp.com>
|
||||
Stan Putrya <root.vagner at gmail.com>
|
||||
Stanley Gunawan <gunawan.stanley at gmail.com>
|
||||
|
@ -52,5 +80,10 @@ Zhenye Xie <xiezhenye at gmail.com>
|
|||
# Organizations
|
||||
|
||||
Barracuda Networks, Inc.
|
||||
Counting Ltd.
|
||||
Google Inc.
|
||||
InfoSum Ltd.
|
||||
Keybase Inc.
|
||||
Percona LLC
|
||||
Pivotal Inc.
|
||||
Stripe Inc.
|
||||
|
|
|
@ -1,3 +1,51 @@
|
|||
## Version 1.4 (2018-06-03)
|
||||
|
||||
Changes:
|
||||
|
||||
- Documentation fixes (#530, #535, #567)
|
||||
- Refactoring (#575, #579, #580, #581, #603, #615, #704)
|
||||
- Cache column names (#444)
|
||||
- Sort the DSN parameters in DSNs generated from a config (#637)
|
||||
- Allow native password authentication by default (#644)
|
||||
- Use the default port if it is missing in the DSN (#668)
|
||||
- Removed the `strict` mode (#676)
|
||||
- Do not query `max_allowed_packet` by default (#680)
|
||||
- Dropped support Go 1.6 and lower (#696)
|
||||
- Updated `ConvertValue()` to match the database/sql/driver implementation (#760)
|
||||
- Document the usage of `0000-00-00T00:00:00` as the time.Time zero value (#783)
|
||||
- Improved the compatibility of the authentication system (#807)
|
||||
|
||||
New Features:
|
||||
|
||||
- Multi-Results support (#537)
|
||||
- `rejectReadOnly` DSN option (#604)
|
||||
- `context.Context` support (#608, #612, #627, #761)
|
||||
- Transaction isolation level support (#619, #744)
|
||||
- Read-Only transactions support (#618, #634)
|
||||
- `NewConfig` function which initializes a config with default values (#679)
|
||||
- Implemented the `ColumnType` interfaces (#667, #724)
|
||||
- Support for custom string types in `ConvertValue` (#623)
|
||||
- Implemented `NamedValueChecker`, improving support for uint64 with high bit set (#690, #709, #710)
|
||||
- `caching_sha2_password` authentication plugin support (#794, #800, #801, #802)
|
||||
- Implemented `driver.SessionResetter` (#779)
|
||||
- `sha256_password` authentication plugin support (#808)
|
||||
|
||||
Bugfixes:
|
||||
|
||||
- Use the DSN hostname as TLS default ServerName if `tls=true` (#564, #718)
|
||||
- Fixed LOAD LOCAL DATA INFILE for empty files (#590)
|
||||
- Removed columns definition cache since it sometimes cached invalid data (#592)
|
||||
- Don't mutate registered TLS configs (#600)
|
||||
- Make RegisterTLSConfig concurrency-safe (#613)
|
||||
- Handle missing auth data in the handshake packet correctly (#646)
|
||||
- Do not retry queries when data was written to avoid data corruption (#302, #736)
|
||||
- Cache the connection pointer for error handling before invalidating it (#678)
|
||||
- Fixed imports for appengine/cloudsql (#700)
|
||||
- Fix sending STMT_LONG_DATA for 0 byte data (#734)
|
||||
- Set correct capacity for []bytes read from length-encoded strings (#766)
|
||||
- Make RegisterDial concurrency-safe (#773)
|
||||
|
||||
|
||||
## Version 1.3 (2016-12-01)
|
||||
|
||||
Changes:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Go-MySQL-Driver
|
||||
|
||||
A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) package
|
||||
A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) package
|
||||
|
||||
![Go-MySQL-Driver logo](https://raw.github.com/wiki/go-sql-driver/mysql/gomysql_m.png "Golang Gopher holding the MySQL Dolphin")
|
||||
|
||||
|
@ -15,6 +15,9 @@ A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) packa
|
|||
* [Address](#address)
|
||||
* [Parameters](#parameters)
|
||||
* [Examples](#examples)
|
||||
* [Connection pool and timeouts](#connection-pool-and-timeouts)
|
||||
* [context.Context Support](#contextcontext-support)
|
||||
* [ColumnType Support](#columntype-support)
|
||||
* [LOAD DATA LOCAL INFILE support](#load-data-local-infile-support)
|
||||
* [time.Time support](#timetime-support)
|
||||
* [Unicode support](#unicode-support)
|
||||
|
@ -26,31 +29,31 @@ A MySQL-Driver for Go's [database/sql](http://golang.org/pkg/database/sql) packa
|
|||
## Features
|
||||
* Lightweight and [fast](https://github.com/go-sql-driver/sql-benchmark "golang MySQL-Driver performance")
|
||||
* Native Go implementation. No C-bindings, just pure Go
|
||||
* Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](http://godoc.org/github.com/go-sql-driver/mysql#DialFunc)
|
||||
* Connections over TCP/IPv4, TCP/IPv6, Unix domain sockets or [custom protocols](https://godoc.org/github.com/go-sql-driver/mysql#DialFunc)
|
||||
* Automatic handling of broken connections
|
||||
* Automatic Connection Pooling *(by database/sql package)*
|
||||
* Supports queries larger than 16MB
|
||||
* Full [`sql.RawBytes`](http://golang.org/pkg/database/sql/#RawBytes) support.
|
||||
* Full [`sql.RawBytes`](https://golang.org/pkg/database/sql/#RawBytes) support.
|
||||
* Intelligent `LONG DATA` handling in prepared statements
|
||||
* Secure `LOAD DATA LOCAL INFILE` support with file Whitelisting and `io.Reader` support
|
||||
* Optional `time.Time` parsing
|
||||
* Optional placeholder interpolation
|
||||
|
||||
## Requirements
|
||||
* Go 1.2 or higher
|
||||
* Go 1.7 or higher. We aim to support the 3 latest versions of Go.
|
||||
* MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+)
|
||||
|
||||
---------------------------------------
|
||||
|
||||
## Installation
|
||||
Simple install the package to your [$GOPATH](http://code.google.com/p/go-wiki/wiki/GOPATH "GOPATH") with the [go tool](http://golang.org/cmd/go/ "go command") from shell:
|
||||
Simple install the package to your [$GOPATH](https://github.com/golang/go/wiki/GOPATH "GOPATH") with the [go tool](https://golang.org/cmd/go/ "go command") from shell:
|
||||
```bash
|
||||
$ go get github.com/go-sql-driver/mysql
|
||||
$ go get -u github.com/go-sql-driver/mysql
|
||||
```
|
||||
Make sure [Git is installed](http://git-scm.com/downloads) on your machine and in your system's `PATH`.
|
||||
Make sure [Git is installed](https://git-scm.com/downloads) on your machine and in your system's `PATH`.
|
||||
|
||||
## Usage
|
||||
_Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](http://golang.org/pkg/database/sql) API then.
|
||||
_Go MySQL Driver_ is an implementation of Go's `database/sql/driver` interface. You only need to import the driver and can use the full [`database/sql`](https://golang.org/pkg/database/sql/) API then.
|
||||
|
||||
Use `mysql` as `driverName` and a valid [DSN](#dsn-data-source-name) as `dataSourceName`:
|
||||
```go
|
||||
|
@ -95,13 +98,14 @@ Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mys
|
|||
Passwords can consist of any character. Escaping is **not** necessary.
|
||||
|
||||
#### Protocol
|
||||
See [net.Dial](http://golang.org/pkg/net/#Dial) for more information which networks are available.
|
||||
See [net.Dial](https://golang.org/pkg/net/#Dial) for more information which networks are available.
|
||||
In general you should use an Unix domain socket if available and TCP otherwise for best performance.
|
||||
|
||||
#### Address
|
||||
For TCP and UDP networks, addresses have the form `host:port`.
|
||||
For TCP and UDP networks, addresses have the form `host[:port]`.
|
||||
If `port` is omitted, the default port will be used.
|
||||
If `host` is a literal IPv6 address, it must be enclosed in square brackets.
|
||||
The functions [net.JoinHostPort](http://golang.org/pkg/net/#JoinHostPort) and [net.SplitHostPort](http://golang.org/pkg/net/#SplitHostPort) manipulate addresses in this form.
|
||||
The functions [net.JoinHostPort](https://golang.org/pkg/net/#JoinHostPort) and [net.SplitHostPort](https://golang.org/pkg/net/#SplitHostPort) manipulate addresses in this form.
|
||||
|
||||
For Unix domain sockets the address is the absolute path to the MySQL-Server-socket, e.g. `/var/run/mysqld/mysqld.sock` or `/tmp/mysql.sock`.
|
||||
|
||||
|
@ -136,9 +140,9 @@ Default: false
|
|||
```
|
||||
Type: bool
|
||||
Valid Values: true, false
|
||||
Default: false
|
||||
Default: true
|
||||
```
|
||||
`allowNativePasswords=true` allows the usage of the mysql native password method.
|
||||
`allowNativePasswords=false` disallows the usage of MySQL native password method.
|
||||
|
||||
##### `allowOldPasswords`
|
||||
|
||||
|
@ -220,19 +224,19 @@ Valid Values: <escaped name>
|
|||
Default: UTC
|
||||
```
|
||||
|
||||
Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. See [time.LoadLocation](http://golang.org/pkg/time/#LoadLocation) for details.
|
||||
Sets the location for time.Time values (when using `parseTime=true`). *"Local"* sets the system's location. See [time.LoadLocation](https://golang.org/pkg/time/#LoadLocation) for details.
|
||||
|
||||
Note that this sets the location for time.Time values but does not change MySQL's [time_zone setting](https://dev.mysql.com/doc/refman/5.5/en/time-zone-support.html). For that see the [time_zone system variable](#system-variables), which can also be set as a DSN parameter.
|
||||
|
||||
Please keep in mind, that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.
|
||||
Please keep in mind, that param values must be [url.QueryEscape](https://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.
|
||||
|
||||
##### `maxAllowedPacket`
|
||||
```
|
||||
Type: decimal number
|
||||
Default: 0
|
||||
Default: 4194304
|
||||
```
|
||||
|
||||
Max packet size allowed in bytes. Use `maxAllowedPacket=0` to automatically fetch the `max_allowed_packet` variable from server.
|
||||
Max packet size allowed in bytes. The default value is 4 MiB and should be adjusted to match the server settings. `maxAllowedPacket=0` can be used to automatically fetch the `max_allowed_packet` variable from server *on every connection*.
|
||||
|
||||
##### `multiStatements`
|
||||
|
||||
|
@ -255,18 +259,19 @@ Default: false
|
|||
```
|
||||
|
||||
`parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
|
||||
The date or datetime like `0000-00-00 00:00:00` is converted into zero value of `time.Time`.
|
||||
|
||||
|
||||
##### `readTimeout`
|
||||
|
||||
```
|
||||
Type: decimal number
|
||||
Type: duration
|
||||
Default: 0
|
||||
```
|
||||
|
||||
I/O read timeout. The value must be a decimal number with an unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*.
|
||||
I/O read timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
|
||||
|
||||
##### `strict`
|
||||
##### `rejectReadOnly`
|
||||
|
||||
```
|
||||
Type: bool
|
||||
|
@ -274,20 +279,50 @@ Valid Values: true, false
|
|||
Default: false
|
||||
```
|
||||
|
||||
`strict=true` enables a driver-side strict mode in which MySQL warnings are treated as errors. This mode should not be used in production as it may lead to data corruption in certain situations.
|
||||
|
||||
A server-side strict mode, which is safe for production use, can be set via the [`sql_mode`](https://dev.mysql.com/doc/refman/5.7/en/sql-mode.html) system variable.
|
||||
`rejectReadOnly=true` causes the driver to reject read-only connections. This
|
||||
is for a possible race condition during an automatic failover, where the mysql
|
||||
client gets connected to a read-only replica after the failover.
|
||||
|
||||
Note that this should be a fairly rare case, as an automatic failover normally
|
||||
happens when the primary is down, and the race condition shouldn't happen
|
||||
unless it comes back up online as soon as the failover is kicked off. On the
|
||||
other hand, when this happens, a MySQL application can get stuck on a
|
||||
read-only connection until restarted. It is however fairly easy to reproduce,
|
||||
for example, using a manual failover on AWS Aurora's MySQL-compatible cluster.
|
||||
|
||||
If you are not relying on read-only transactions to reject writes that aren't
|
||||
supposed to happen, setting this on some MySQL providers (such as AWS Aurora)
|
||||
is safer for failovers.
|
||||
|
||||
Note that ERROR 1290 can be returned for a `read-only` server and this option will
|
||||
cause a retry for that error. However the same error number is used for some
|
||||
other cases. You should ensure your application will never cause an ERROR 1290
|
||||
except for `read-only` mode when enabling this option.
|
||||
|
||||
|
||||
##### `serverPubKey`
|
||||
|
||||
```
|
||||
Type: string
|
||||
Valid Values: <name>
|
||||
Default: none
|
||||
```
|
||||
|
||||
Server public keys can be registered with [`mysql.RegisterServerPubKey`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterServerPubKey), which can then be used by the assigned name in the DSN.
|
||||
Public keys are used to transmit encrypted data, e.g. for authentication.
|
||||
If the server's public key is known, it should be set manually to avoid expensive and potentially insecure transmissions of the public key from the server to the client each time it is required.
|
||||
|
||||
By default MySQL also treats notes as warnings. Use [`sql_notes=false`](http://dev.mysql.com/doc/refman/5.7/en/server-system-variables.html#sysvar_sql_notes) to ignore notes.
|
||||
|
||||
##### `timeout`
|
||||
|
||||
```
|
||||
Type: decimal number
|
||||
Type: duration
|
||||
Default: OS default
|
||||
```
|
||||
|
||||
*Driver* side connection timeout. The value must be a decimal number with an unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
|
||||
Timeout for establishing connections, aka dial timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
|
||||
|
||||
|
||||
##### `tls`
|
||||
|
||||
|
@ -297,16 +332,17 @@ Valid Values: true, false, skip-verify, <name>
|
|||
Default: false
|
||||
```
|
||||
|
||||
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](http://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
|
||||
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
|
||||
|
||||
|
||||
##### `writeTimeout`
|
||||
|
||||
```
|
||||
Type: decimal number
|
||||
Type: duration
|
||||
Default: 0
|
||||
```
|
||||
|
||||
I/O write timeout. The value must be a decimal number with an unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*.
|
||||
I/O write timeout. The value must be a decimal number with a unit suffix (*"ms"*, *"s"*, *"m"*, *"h"*), such as *"30s"*, *"0.5m"* or *"1m30s"*.
|
||||
|
||||
|
||||
##### System Variables
|
||||
|
@ -317,9 +353,9 @@ Any other parameters are interpreted as system variables:
|
|||
* `<string_var>=%27<value>%27`: `SET <string_var>='<value>'`
|
||||
|
||||
Rules:
|
||||
* The values for string variables must be quoted with '
|
||||
* The values for string variables must be quoted with `'`.
|
||||
* The values must also be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed!
|
||||
(which implies values of string variables must be wrapped with `%27`)
|
||||
(which implies values of string variables must be wrapped with `%27`).
|
||||
|
||||
Examples:
|
||||
* `autocommit=1`: `SET autocommit=1`
|
||||
|
@ -380,6 +416,18 @@ No Database preselected:
|
|||
user:password@/
|
||||
```
|
||||
|
||||
|
||||
### Connection pool and timeouts
|
||||
The connection pool is managed by Go's database/sql package. For details on how to configure the size of the pool and how long connections stay in the pool see `*DB.SetMaxOpenConns`, `*DB.SetMaxIdleConns`, and `*DB.SetConnMaxLifetime` in the [database/sql documentation](https://golang.org/pkg/database/sql/). The read, write, and dial timeouts for each individual connection are configured with the DSN parameters [`readTimeout`](#readtimeout), [`writeTimeout`](#writetimeout), and [`timeout`](#timeout), respectively.
|
||||
|
||||
## `ColumnType` Support
|
||||
This driver supports the [`ColumnType` interface](https://golang.org/pkg/database/sql/#ColumnType) introduced in Go 1.8, with the exception of [`ColumnType.Length()`](https://golang.org/pkg/database/sql/#ColumnType.Length), which is currently not supported.
|
||||
|
||||
## `context.Context` Support
|
||||
Go 1.8 added `database/sql` support for `context.Context`. This driver supports query timeouts and cancellation via contexts.
|
||||
See [context support in the database/sql package](https://golang.org/doc/go1.8#database_sql) for more details.
|
||||
|
||||
|
||||
### `LOAD DATA LOCAL INFILE` support
|
||||
For this feature you need direct access to the package. Therefore you must change the import path (no `_`):
|
||||
```go
|
||||
|
@ -390,17 +438,17 @@ Files must be whitelisted by registering them with `mysql.RegisterLocalFile(file
|
|||
|
||||
To use a `io.Reader` a handler function must be registered with `mysql.RegisterReaderHandler(name, handler)` which returns a `io.Reader` or `io.ReadCloser`. The Reader is available with the filepath `Reader::<name>` then. Choose different names for different handlers and `DeregisterReaderHandler` when you don't need it anymore.
|
||||
|
||||
See the [godoc of Go-MySQL-Driver](http://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details.
|
||||
See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/mysql "golang mysql driver documentation") for details.
|
||||
|
||||
|
||||
### `time.Time` support
|
||||
The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your programm.
|
||||
The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program.
|
||||
|
||||
However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](http://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
|
||||
However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
|
||||
|
||||
**Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).
|
||||
|
||||
Alternatively you can use the [`NullTime`](http://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`.
|
||||
Alternatively you can use the [`NullTime`](https://godoc.org/github.com/go-sql-driver/mysql#NullTime) type as the scan destination, which works with both `time.Time` and `string` / `[]byte`.
|
||||
|
||||
|
||||
### Unicode support
|
||||
|
@ -412,7 +460,6 @@ Version 1.0 of the driver recommended adding `&charset=utf8` (alias for `SET NAM
|
|||
|
||||
See http://dev.mysql.com/doc/refman/5.7/en/charset-unicode.html for more details on MySQL's Unicode support.
|
||||
|
||||
|
||||
## Testing / Development
|
||||
To run the driver tests you may need to adjust the configuration. See the [Testing Wiki-Page](https://github.com/go-sql-driver/mysql/wiki/Testing "Testing") for details.
|
||||
|
||||
|
@ -431,13 +478,13 @@ Mozilla summarizes the license scope as follows:
|
|||
|
||||
|
||||
That means:
|
||||
* You can **use** the **unchanged** source code both in private and commercially
|
||||
* When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0)
|
||||
* You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged**
|
||||
* You can **use** the **unchanged** source code both in private and commercially.
|
||||
* When distributing, you **must publish** the source code of any **changed files** licensed under the MPL 2.0 under a) the MPL 2.0 itself or b) a compatible license (e.g. GPL 3.0 or Apache License 2.0).
|
||||
* You **needn't publish** the source code of your library as long as the files licensed under the MPL 2.0 are **unchanged**.
|
||||
|
||||
Please read the [MPL 2.0 FAQ](http://www.mozilla.org/MPL/2.0/FAQ.html) if you have further questions regarding the license.
|
||||
Please read the [MPL 2.0 FAQ](https://www.mozilla.org/en-US/MPL/2.0/FAQ/) if you have further questions regarding the license.
|
||||
|
||||
You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE)
|
||||
You can read the full terms here: [LICENSE](https://raw.github.com/go-sql-driver/mysql/master/LICENSE).
|
||||
|
||||
![Go Gopher and MySQL Dolphin](https://raw.github.com/wiki/go-sql-driver/mysql/go-mysql-driver_m.jpg "Golang Gopher transporting the MySQL Dolphin in a wheelbarrow")
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"appengine/cloudsql"
|
||||
"google.golang.org/appengine/cloudsql"
|
||||
)
|
||||
|
||||
func init() {
|
||||
|
|
|
@ -0,0 +1,420 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// server pub keys registry
|
||||
var (
|
||||
serverPubKeyLock sync.RWMutex
|
||||
serverPubKeyRegistry map[string]*rsa.PublicKey
|
||||
)
|
||||
|
||||
// RegisterServerPubKey registers a server RSA public key which can be used to
|
||||
// send data in a secure manner to the server without receiving the public key
|
||||
// in a potentially insecure way from the server first.
|
||||
// Registered keys can afterwards be used adding serverPubKey=<name> to the DSN.
|
||||
//
|
||||
// Note: The provided rsa.PublicKey instance is exclusively owned by the driver
|
||||
// after registering it and may not be modified.
|
||||
//
|
||||
// data, err := ioutil.ReadFile("mykey.pem")
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// block, _ := pem.Decode(data)
|
||||
// if block == nil || block.Type != "PUBLIC KEY" {
|
||||
// log.Fatal("failed to decode PEM block containing public key")
|
||||
// }
|
||||
//
|
||||
// pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
// if err != nil {
|
||||
// log.Fatal(err)
|
||||
// }
|
||||
//
|
||||
// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok {
|
||||
// mysql.RegisterServerPubKey("mykey", rsaPubKey)
|
||||
// } else {
|
||||
// log.Fatal("not a RSA public key")
|
||||
// }
|
||||
//
|
||||
func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) {
|
||||
serverPubKeyLock.Lock()
|
||||
if serverPubKeyRegistry == nil {
|
||||
serverPubKeyRegistry = make(map[string]*rsa.PublicKey)
|
||||
}
|
||||
|
||||
serverPubKeyRegistry[name] = pubKey
|
||||
serverPubKeyLock.Unlock()
|
||||
}
|
||||
|
||||
// DeregisterServerPubKey removes the public key registered with the given name.
|
||||
func DeregisterServerPubKey(name string) {
|
||||
serverPubKeyLock.Lock()
|
||||
if serverPubKeyRegistry != nil {
|
||||
delete(serverPubKeyRegistry, name)
|
||||
}
|
||||
serverPubKeyLock.Unlock()
|
||||
}
|
||||
|
||||
func getServerPubKey(name string) (pubKey *rsa.PublicKey) {
|
||||
serverPubKeyLock.RLock()
|
||||
if v, ok := serverPubKeyRegistry[name]; ok {
|
||||
pubKey = v
|
||||
}
|
||||
serverPubKeyLock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Hash password using pre 4.1 (old password) method
|
||||
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
|
||||
type myRnd struct {
|
||||
seed1, seed2 uint32
|
||||
}
|
||||
|
||||
const myRndMaxVal = 0x3FFFFFFF
|
||||
|
||||
// Pseudo random number generator
|
||||
func newMyRnd(seed1, seed2 uint32) *myRnd {
|
||||
return &myRnd{
|
||||
seed1: seed1 % myRndMaxVal,
|
||||
seed2: seed2 % myRndMaxVal,
|
||||
}
|
||||
}
|
||||
|
||||
// Tested to be equivalent to MariaDB's floating point variant
|
||||
// http://play.golang.org/p/QHvhd4qved
|
||||
// http://play.golang.org/p/RG0q4ElWDx
|
||||
func (r *myRnd) NextByte() byte {
|
||||
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
|
||||
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
|
||||
|
||||
return byte(uint64(r.seed1) * 31 / myRndMaxVal)
|
||||
}
|
||||
|
||||
// Generate binary hash from byte string using insecure pre 4.1 method
|
||||
func pwHash(password []byte) (result [2]uint32) {
|
||||
var add uint32 = 7
|
||||
var tmp uint32
|
||||
|
||||
result[0] = 1345345333
|
||||
result[1] = 0x12345671
|
||||
|
||||
for _, c := range password {
|
||||
// skip spaces and tabs in password
|
||||
if c == ' ' || c == '\t' {
|
||||
continue
|
||||
}
|
||||
|
||||
tmp = uint32(c)
|
||||
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
|
||||
result[1] += (result[1] << 8) ^ result[0]
|
||||
add += tmp
|
||||
}
|
||||
|
||||
// Remove sign bit (1<<31)-1)
|
||||
result[0] &= 0x7FFFFFFF
|
||||
result[1] &= 0x7FFFFFFF
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Hash password using insecure pre 4.1 method
|
||||
func scrambleOldPassword(scramble []byte, password string) []byte {
|
||||
if len(password) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
scramble = scramble[:8]
|
||||
|
||||
hashPw := pwHash([]byte(password))
|
||||
hashSc := pwHash(scramble)
|
||||
|
||||
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
|
||||
|
||||
var out [8]byte
|
||||
for i := range out {
|
||||
out[i] = r.NextByte() + 64
|
||||
}
|
||||
|
||||
mask := r.NextByte()
|
||||
for i := range out {
|
||||
out[i] ^= mask
|
||||
}
|
||||
|
||||
return out[:]
|
||||
}
|
||||
|
||||
// Hash password using 4.1+ method (SHA1)
|
||||
func scramblePassword(scramble []byte, password string) []byte {
|
||||
if len(password) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// stage1Hash = SHA1(password)
|
||||
crypt := sha1.New()
|
||||
crypt.Write([]byte(password))
|
||||
stage1 := crypt.Sum(nil)
|
||||
|
||||
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
|
||||
// inner Hash
|
||||
crypt.Reset()
|
||||
crypt.Write(stage1)
|
||||
hash := crypt.Sum(nil)
|
||||
|
||||
// outer Hash
|
||||
crypt.Reset()
|
||||
crypt.Write(scramble)
|
||||
crypt.Write(hash)
|
||||
scramble = crypt.Sum(nil)
|
||||
|
||||
// token = scrambleHash XOR stage1Hash
|
||||
for i := range scramble {
|
||||
scramble[i] ^= stage1[i]
|
||||
}
|
||||
return scramble
|
||||
}
|
||||
|
||||
// Hash password using MySQL 8+ method (SHA256)
|
||||
func scrambleSHA256Password(scramble []byte, password string) []byte {
|
||||
if len(password) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble))
|
||||
|
||||
crypt := sha256.New()
|
||||
crypt.Write([]byte(password))
|
||||
message1 := crypt.Sum(nil)
|
||||
|
||||
crypt.Reset()
|
||||
crypt.Write(message1)
|
||||
message1Hash := crypt.Sum(nil)
|
||||
|
||||
crypt.Reset()
|
||||
crypt.Write(message1Hash)
|
||||
crypt.Write(scramble)
|
||||
message2 := crypt.Sum(nil)
|
||||
|
||||
for i := range message1 {
|
||||
message1[i] ^= message2[i]
|
||||
}
|
||||
|
||||
return message1
|
||||
}
|
||||
|
||||
func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) {
|
||||
plain := make([]byte, len(password)+1)
|
||||
copy(plain, password)
|
||||
for i := range plain {
|
||||
j := i % len(seed)
|
||||
plain[i] ^= seed[j]
|
||||
}
|
||||
sha1 := sha1.New()
|
||||
return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error {
|
||||
enc, err := encryptPassword(mc.cfg.Passwd, seed, pub)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return mc.writeAuthSwitchPacket(enc, false)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) {
|
||||
switch plugin {
|
||||
case "caching_sha2_password":
|
||||
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
|
||||
return authResp, (authResp == nil), nil
|
||||
|
||||
case "mysql_old_password":
|
||||
if !mc.cfg.AllowOldPasswords {
|
||||
return nil, false, ErrOldPassword
|
||||
}
|
||||
// Note: there are edge cases where this should work but doesn't;
|
||||
// this is currently "wontfix":
|
||||
// https://github.com/go-sql-driver/mysql/issues/184
|
||||
authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd)
|
||||
return authResp, true, nil
|
||||
|
||||
case "mysql_clear_password":
|
||||
if !mc.cfg.AllowCleartextPasswords {
|
||||
return nil, false, ErrCleartextPassword
|
||||
}
|
||||
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
|
||||
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
|
||||
return []byte(mc.cfg.Passwd), true, nil
|
||||
|
||||
case "mysql_native_password":
|
||||
if !mc.cfg.AllowNativePasswords {
|
||||
return nil, false, ErrNativePassword
|
||||
}
|
||||
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
|
||||
// Native password authentication only need and will need 20-byte challenge.
|
||||
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
|
||||
return authResp, false, nil
|
||||
|
||||
case "sha256_password":
|
||||
if len(mc.cfg.Passwd) == 0 {
|
||||
return nil, true, nil
|
||||
}
|
||||
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
|
||||
// write cleartext auth packet
|
||||
return []byte(mc.cfg.Passwd), true, nil
|
||||
}
|
||||
|
||||
pubKey := mc.cfg.pubKey
|
||||
if pubKey == nil {
|
||||
// request public key from server
|
||||
return []byte{1}, false, nil
|
||||
}
|
||||
|
||||
// encrypted password
|
||||
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
|
||||
return enc, false, err
|
||||
|
||||
default:
|
||||
errLog.Print("unknown auth plugin:", plugin)
|
||||
return nil, false, ErrUnknownPlugin
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
|
||||
// Read Result Packet
|
||||
authData, newPlugin, err := mc.readAuthResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// handle auth plugin switch, if requested
|
||||
if newPlugin != "" {
|
||||
// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
|
||||
// sent and we have to keep using the cipher sent in the init packet.
|
||||
if authData == nil {
|
||||
authData = oldAuthData
|
||||
} else {
|
||||
// copy data from read buffer to owned slice
|
||||
copy(oldAuthData, authData)
|
||||
}
|
||||
|
||||
plugin = newPlugin
|
||||
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read Result Packet
|
||||
authData, newPlugin, err = mc.readAuthResult()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Do not allow to change the auth plugin more than once
|
||||
if newPlugin != "" {
|
||||
return ErrMalformPkt
|
||||
}
|
||||
}
|
||||
|
||||
switch plugin {
|
||||
|
||||
// https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/
|
||||
case "caching_sha2_password":
|
||||
switch len(authData) {
|
||||
case 0:
|
||||
return nil // auth successful
|
||||
case 1:
|
||||
switch authData[0] {
|
||||
case cachingSha2PasswordFastAuthSuccess:
|
||||
if err = mc.readResultOK(); err == nil {
|
||||
return nil // auth successful
|
||||
}
|
||||
|
||||
case cachingSha2PasswordPerformFullAuthentication:
|
||||
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
|
||||
// write cleartext auth packet
|
||||
err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
pubKey := mc.cfg.pubKey
|
||||
if pubKey == nil {
|
||||
// request public key from server
|
||||
data := mc.buf.takeSmallBuffer(4 + 1)
|
||||
data[4] = cachingSha2PasswordRequestPublicKey
|
||||
mc.writePacket(data)
|
||||
|
||||
// parse public key
|
||||
data, err := mc.readPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
block, _ := pem.Decode(data[1:])
|
||||
pkix, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pubKey = pkix.(*rsa.PublicKey)
|
||||
}
|
||||
|
||||
// send encrypted password
|
||||
err = mc.sendEncryptedPassword(oldAuthData, pubKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return mc.readResultOK()
|
||||
|
||||
default:
|
||||
return ErrMalformPkt
|
||||
}
|
||||
default:
|
||||
return ErrMalformPkt
|
||||
}
|
||||
|
||||
case "sha256_password":
|
||||
switch len(authData) {
|
||||
case 0:
|
||||
return nil // auth successful
|
||||
default:
|
||||
block, _ := pem.Decode(authData)
|
||||
pub, err := x509.ParsePKIXPublicKey(block.Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// send encrypted password
|
||||
err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return mc.readResultOK()
|
||||
}
|
||||
|
||||
default:
|
||||
return nil // auth successful
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,93 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build go1.8
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
|
||||
|
||||
tb := (*TB)(b)
|
||||
stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?"))
|
||||
defer stmt.Close()
|
||||
|
||||
b.SetParallelism(p)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
var got string
|
||||
for pb.Next() {
|
||||
tb.check(stmt.QueryRow(1).Scan(&got))
|
||||
if got != "one" {
|
||||
b.Fatalf("query = %q; want one", got)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkQueryContext(b *testing.B) {
|
||||
db := initDB(b,
|
||||
"DROP TABLE IF EXISTS foo",
|
||||
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
|
||||
`INSERT INTO foo VALUES (1, "one")`,
|
||||
`INSERT INTO foo VALUES (2, "two")`,
|
||||
)
|
||||
defer db.Close()
|
||||
for _, p := range []int{1, 2, 3, 4} {
|
||||
b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
|
||||
benchmarkQueryContext(b, db, p)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func benchmarkExecContext(b *testing.B, db *sql.DB, p int) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
|
||||
|
||||
tb := (*TB)(b)
|
||||
stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1"))
|
||||
defer stmt.Close()
|
||||
|
||||
b.SetParallelism(p)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
if _, err := stmt.ExecContext(ctx); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkExecContext(b *testing.B) {
|
||||
db := initDB(b,
|
||||
"DROP TABLE IF EXISTS foo",
|
||||
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
|
||||
`INSERT INTO foo VALUES (1, "one")`,
|
||||
`INSERT INTO foo VALUES (2, "two")`,
|
||||
)
|
||||
defer db.Close()
|
||||
for _, p := range []int{1, 2, 3, 4} {
|
||||
b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
|
||||
benchmarkQueryContext(b, db, p)
|
||||
})
|
||||
}
|
||||
}
|
|
@ -48,11 +48,7 @@ func initDB(b *testing.B, queries ...string) *sql.DB {
|
|||
db := tb.checkDB(sql.Open("mysql", dsn))
|
||||
for _, query := range queries {
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
if w, ok := err.(MySQLWarnings); ok {
|
||||
b.Logf("warning on %q: %v", query, w)
|
||||
} else {
|
||||
b.Fatalf("error on %q: %v", query, err)
|
||||
}
|
||||
b.Fatalf("error on %q: %v", query, err)
|
||||
}
|
||||
}
|
||||
return db
|
||||
|
|
|
@ -130,18 +130,18 @@ func (b *buffer) takeBuffer(length int) []byte {
|
|||
// smaller than defaultBufSize
|
||||
// Only one buffer (total) can be used at a time.
|
||||
func (b *buffer) takeSmallBuffer(length int) []byte {
|
||||
if b.length == 0 {
|
||||
return b.buf[:length]
|
||||
if b.length > 0 {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
return b.buf[:length]
|
||||
}
|
||||
|
||||
// takeCompleteBuffer returns the complete existing buffer.
|
||||
// This can be used if the necessary buffer size is unknown.
|
||||
// Only one buffer (total) can be used at a time.
|
||||
func (b *buffer) takeCompleteBuffer() []byte {
|
||||
if b.length == 0 {
|
||||
return b.buf
|
||||
if b.length > 0 {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
return b.buf
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
package mysql
|
||||
|
||||
const defaultCollation = "utf8_general_ci"
|
||||
const binaryCollation = "binary"
|
||||
|
||||
// A list of available collations mapped to the internal ID.
|
||||
// To update this map use the following MySQL query:
|
||||
|
|
|
@ -10,12 +10,23 @@ package mysql
|
|||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// a copy of context.Context for Go 1.7 and earlier
|
||||
type mysqlContext interface {
|
||||
Done() <-chan struct{}
|
||||
Err() error
|
||||
|
||||
// defined in context.Context, but not used in this driver:
|
||||
// Deadline() (deadline time.Time, ok bool)
|
||||
// Value(key interface{}) interface{}
|
||||
}
|
||||
|
||||
type mysqlConn struct {
|
||||
buf buffer
|
||||
netConn net.Conn
|
||||
|
@ -29,7 +40,14 @@ type mysqlConn struct {
|
|||
status statusFlag
|
||||
sequence uint8
|
||||
parseTime bool
|
||||
strict bool
|
||||
|
||||
// for context support (Go 1.8+)
|
||||
watching bool
|
||||
watcher chan<- mysqlContext
|
||||
closech chan struct{}
|
||||
finished chan<- struct{}
|
||||
canceled atomicError // set non-nil if conn is canceled
|
||||
closed atomicBool // set when conn is closed, before closech is closed
|
||||
}
|
||||
|
||||
// Handles parameters set in DSN after the connection is established
|
||||
|
@ -62,22 +80,41 @@ func (mc *mysqlConn) handleParams() (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) markBadConn(err error) error {
|
||||
if mc == nil {
|
||||
return err
|
||||
}
|
||||
if err != errBadConnNoWrite {
|
||||
return err
|
||||
}
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Begin() (driver.Tx, error) {
|
||||
if mc.netConn == nil {
|
||||
return mc.begin(false)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
err := mc.exec("START TRANSACTION")
|
||||
var q string
|
||||
if readOnly {
|
||||
q = "START TRANSACTION READ ONLY"
|
||||
} else {
|
||||
q = "START TRANSACTION"
|
||||
}
|
||||
err := mc.exec(q)
|
||||
if err == nil {
|
||||
return &mysqlTx{mc}, err
|
||||
}
|
||||
|
||||
return nil, err
|
||||
return nil, mc.markBadConn(err)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Close() (err error) {
|
||||
// Makes Close idempotent
|
||||
if mc.netConn != nil {
|
||||
if !mc.closed.IsSet() {
|
||||
err = mc.writeCommandPacket(comQuit)
|
||||
}
|
||||
|
||||
|
@ -91,26 +128,39 @@ func (mc *mysqlConn) Close() (err error) {
|
|||
// is called before auth or on auth failure because MySQL will have already
|
||||
// closed the network connection.
|
||||
func (mc *mysqlConn) cleanup() {
|
||||
// Makes cleanup idempotent
|
||||
if mc.netConn != nil {
|
||||
if err := mc.netConn.Close(); err != nil {
|
||||
errLog.Print(err)
|
||||
}
|
||||
mc.netConn = nil
|
||||
if !mc.closed.TrySet(true) {
|
||||
return
|
||||
}
|
||||
mc.cfg = nil
|
||||
mc.buf.nc = nil
|
||||
|
||||
// Makes cleanup idempotent
|
||||
close(mc.closech)
|
||||
if mc.netConn == nil {
|
||||
return
|
||||
}
|
||||
if err := mc.netConn.Close(); err != nil {
|
||||
errLog.Print(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) error() error {
|
||||
if mc.closed.IsSet() {
|
||||
if err := mc.canceled.Value(); err != nil {
|
||||
return err
|
||||
}
|
||||
return ErrInvalidConn
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
|
||||
if mc.netConn == nil {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
// Send command
|
||||
err := mc.writeCommandPacketStr(comStmtPrepare, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, mc.markBadConn(err)
|
||||
}
|
||||
|
||||
stmt := &mysqlStmt{
|
||||
|
@ -144,7 +194,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
|
|||
if buf == nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
return "", driver.ErrBadConn
|
||||
return "", ErrInvalidConn
|
||||
}
|
||||
buf = buf[:0]
|
||||
argPos := 0
|
||||
|
@ -257,7 +307,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
|
|||
}
|
||||
|
||||
func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) {
|
||||
if mc.netConn == nil {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
@ -271,7 +321,6 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
|
|||
return nil, err
|
||||
}
|
||||
query = prepared
|
||||
args = nil
|
||||
}
|
||||
mc.affectedRows = 0
|
||||
mc.insertId = 0
|
||||
|
@ -283,32 +332,43 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
|
|||
insertId: int64(mc.insertId),
|
||||
}, err
|
||||
}
|
||||
return nil, err
|
||||
return nil, mc.markBadConn(err)
|
||||
}
|
||||
|
||||
// Internal function to execute commands
|
||||
func (mc *mysqlConn) exec(query string) error {
|
||||
// Send command
|
||||
err := mc.writeCommandPacketStr(comQuery, query)
|
||||
if err != nil {
|
||||
return err
|
||||
if err := mc.writeCommandPacketStr(comQuery, query); err != nil {
|
||||
return mc.markBadConn(err)
|
||||
}
|
||||
|
||||
// Read Result
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
if err == nil && resLen > 0 {
|
||||
if err = mc.readUntilEOF(); err != nil {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resLen > 0 {
|
||||
// columns
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mc.readUntilEOF()
|
||||
// rows
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
return mc.discardResults()
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
|
||||
if mc.netConn == nil {
|
||||
return mc.query(query, args)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
|
@ -322,7 +382,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
|
|||
return nil, err
|
||||
}
|
||||
query = prepared
|
||||
args = nil
|
||||
}
|
||||
// Send command
|
||||
err := mc.writeCommandPacketStr(comQuery, query)
|
||||
|
@ -335,15 +394,22 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
|
|||
rows.mc = mc
|
||||
|
||||
if resLen == 0 {
|
||||
// no columns, no more data
|
||||
return emptyRows{}, nil
|
||||
rows.rs.done = true
|
||||
|
||||
switch err := rows.NextResultSet(); err {
|
||||
case nil, io.EOF:
|
||||
return rows, nil
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Columns
|
||||
rows.columns, err = mc.readColumns(resLen)
|
||||
rows.rs.columns, err = mc.readColumns(resLen)
|
||||
return rows, err
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
return nil, mc.markBadConn(err)
|
||||
}
|
||||
|
||||
// Gets the value of the given MySQL System Variable
|
||||
|
@ -359,7 +425,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
|
|||
if err == nil {
|
||||
rows := new(textRows)
|
||||
rows.mc = mc
|
||||
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
|
||||
rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
|
||||
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
|
@ -375,3 +441,21 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
|
|||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// finish is called when the query has canceled.
|
||||
func (mc *mysqlConn) cancel(err error) {
|
||||
mc.canceled.Set(err)
|
||||
mc.cleanup()
|
||||
}
|
||||
|
||||
// finish is called when the query has succeeded.
|
||||
func (mc *mysqlConn) finish() {
|
||||
if !mc.watching || mc.finished == nil {
|
||||
return
|
||||
}
|
||||
select {
|
||||
case mc.finished <- struct{}{}:
|
||||
mc.watching = false
|
||||
case <-mc.closech:
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,208 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build go1.8
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
)
|
||||
|
||||
// Ping implements driver.Pinger interface
|
||||
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
|
||||
if mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
if err = mc.watchCancel(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
defer mc.finish()
|
||||
|
||||
if err = mc.writeCommandPacket(comPing); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
return mc.readResultOK()
|
||||
}
|
||||
|
||||
// BeginTx implements driver.ConnBeginTx interface
|
||||
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mc.finish()
|
||||
|
||||
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
|
||||
level, err := mapIsolationLevel(opts.Isolation)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return mc.begin(opts.ReadOnly)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
|
||||
dargs, err := namedValueToValue(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := mc.query(query, dargs)
|
||||
if err != nil {
|
||||
mc.finish()
|
||||
return nil, err
|
||||
}
|
||||
rows.finish = mc.finish
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
|
||||
dargs, err := namedValueToValue(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer mc.finish()
|
||||
|
||||
return mc.Exec(query, dargs)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
|
||||
if err := mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stmt, err := mc.Prepare(query)
|
||||
mc.finish()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
stmt.Close()
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
return stmt, nil
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
|
||||
dargs, err := namedValueToValue(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := stmt.mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := stmt.query(dargs)
|
||||
if err != nil {
|
||||
stmt.mc.finish()
|
||||
return nil, err
|
||||
}
|
||||
rows.finish = stmt.mc.finish
|
||||
return rows, err
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
|
||||
dargs, err := namedValueToValue(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := stmt.mc.watchCancel(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer stmt.mc.finish()
|
||||
|
||||
return stmt.Exec(dargs)
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
|
||||
if mc.watching {
|
||||
// Reach here if canceled,
|
||||
// so the connection is already invalid
|
||||
mc.cleanup()
|
||||
return nil
|
||||
}
|
||||
if ctx.Done() == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mc.watching = true
|
||||
select {
|
||||
default:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
if mc.watcher == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
mc.watcher <- ctx
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) startWatcher() {
|
||||
watcher := make(chan mysqlContext, 1)
|
||||
mc.watcher = watcher
|
||||
finished := make(chan struct{})
|
||||
mc.finished = finished
|
||||
go func() {
|
||||
for {
|
||||
var ctx mysqlContext
|
||||
select {
|
||||
case ctx = <-watcher:
|
||||
case <-mc.closech:
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
mc.cancel(ctx.Err())
|
||||
case <-finished:
|
||||
case <-mc.closech:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
|
||||
nv.Value, err = converter{}.ConvertValue(nv.Value)
|
||||
return
|
||||
}
|
||||
|
||||
// ResetSession implements driver.SessionResetter.
|
||||
// (From Go 1.10)
|
||||
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
|
||||
if mc.closed.IsSet() {
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build go1.8
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCheckNamedValue(t *testing.T) {
|
||||
value := driver.NamedValue{Value: ^uint64(0)}
|
||||
x := &mysqlConn{}
|
||||
err := x.CheckNamedValue(&value)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal("uint64 high-bit not convertible", err)
|
||||
}
|
||||
|
||||
if value.Value != "18446744073709551615" {
|
||||
t.Fatalf("uint64 high-bit not converted, got %#v %T", value.Value, value.Value)
|
||||
}
|
||||
}
|
|
@ -9,7 +9,9 @@
|
|||
package mysql
|
||||
|
||||
const (
|
||||
minProtocolVersion byte = 10
|
||||
defaultAuthPlugin = "mysql_native_password"
|
||||
defaultMaxAllowedPacket = 4 << 20 // 4 MiB
|
||||
minProtocolVersion = 10
|
||||
maxPacketSize = 1<<24 - 1
|
||||
timeFormat = "2006-01-02 15:04:05.999999"
|
||||
)
|
||||
|
@ -18,10 +20,11 @@ const (
|
|||
// http://dev.mysql.com/doc/internals/en/client-server-protocol.html
|
||||
|
||||
const (
|
||||
iOK byte = 0x00
|
||||
iLocalInFile byte = 0xfb
|
||||
iEOF byte = 0xfe
|
||||
iERR byte = 0xff
|
||||
iOK byte = 0x00
|
||||
iAuthMoreData byte = 0x01
|
||||
iLocalInFile byte = 0xfb
|
||||
iEOF byte = 0xfe
|
||||
iERR byte = 0xff
|
||||
)
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
|
||||
|
@ -87,8 +90,10 @@ const (
|
|||
)
|
||||
|
||||
// https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType
|
||||
type fieldType byte
|
||||
|
||||
const (
|
||||
fieldTypeDecimal byte = iota
|
||||
fieldTypeDecimal fieldType = iota
|
||||
fieldTypeTiny
|
||||
fieldTypeShort
|
||||
fieldTypeLong
|
||||
|
@ -107,7 +112,7 @@ const (
|
|||
fieldTypeBit
|
||||
)
|
||||
const (
|
||||
fieldTypeJSON byte = iota + 0xf5
|
||||
fieldTypeJSON fieldType = iota + 0xf5
|
||||
fieldTypeNewDecimal
|
||||
fieldTypeEnum
|
||||
fieldTypeSet
|
||||
|
@ -161,3 +166,9 @@ const (
|
|||
statusInTransReadonly
|
||||
statusSessionStateChanged
|
||||
)
|
||||
|
||||
const (
|
||||
cachingSha2PasswordRequestPublicKey = 2
|
||||
cachingSha2PasswordFastAuthSuccess = 3
|
||||
cachingSha2PasswordPerformFullAuthentication = 4
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// Package mysql provides a MySQL driver for Go's database/sql package
|
||||
// Package mysql provides a MySQL driver for Go's database/sql package.
|
||||
//
|
||||
// The driver should be used via the database/sql package:
|
||||
//
|
||||
|
@ -20,8 +20,14 @@ import (
|
|||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// watcher interface is used for context support (From Go 1.8)
|
||||
type watcher interface {
|
||||
startWatcher()
|
||||
}
|
||||
|
||||
// MySQLDriver is exported to make the driver directly accessible.
|
||||
// In general the driver is used via the database/sql package.
|
||||
type MySQLDriver struct{}
|
||||
|
@ -30,12 +36,17 @@ type MySQLDriver struct{}
|
|||
// Custom dial functions must be registered with RegisterDial
|
||||
type DialFunc func(addr string) (net.Conn, error)
|
||||
|
||||
var dials map[string]DialFunc
|
||||
var (
|
||||
dialsLock sync.RWMutex
|
||||
dials map[string]DialFunc
|
||||
)
|
||||
|
||||
// RegisterDial registers a custom dial function. It can then be used by the
|
||||
// network address mynet(addr), where mynet is the registered new network.
|
||||
// addr is passed as a parameter to the dial function.
|
||||
func RegisterDial(net string, dial DialFunc) {
|
||||
dialsLock.Lock()
|
||||
defer dialsLock.Unlock()
|
||||
if dials == nil {
|
||||
dials = make(map[string]DialFunc)
|
||||
}
|
||||
|
@ -52,16 +63,19 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
|||
mc := &mysqlConn{
|
||||
maxAllowedPacket: maxPacketSize,
|
||||
maxWriteSize: maxPacketSize - 1,
|
||||
closech: make(chan struct{}),
|
||||
}
|
||||
mc.cfg, err = ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mc.parseTime = mc.cfg.ParseTime
|
||||
mc.strict = mc.cfg.Strict
|
||||
|
||||
// Connect to Server
|
||||
if dial, ok := dials[mc.cfg.Net]; ok {
|
||||
dialsLock.RLock()
|
||||
dial, ok := dials[mc.cfg.Net]
|
||||
dialsLock.RUnlock()
|
||||
if ok {
|
||||
mc.netConn, err = dial(mc.cfg.Addr)
|
||||
} else {
|
||||
nd := net.Dialer{Timeout: mc.cfg.Timeout}
|
||||
|
@ -81,6 +95,11 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
|||
}
|
||||
}
|
||||
|
||||
// Call startWatcher for context support (From Go 1.8)
|
||||
if s, ok := interface{}(mc).(watcher); ok {
|
||||
s.startWatcher()
|
||||
}
|
||||
|
||||
mc.buf = newBuffer(mc.netConn)
|
||||
|
||||
// Set I/O timeouts
|
||||
|
@ -88,20 +107,31 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
|||
mc.writeTimeout = mc.cfg.WriteTimeout
|
||||
|
||||
// Reading Handshake Initialization Packet
|
||||
cipher, err := mc.readInitPacket()
|
||||
authData, plugin, err := mc.readHandshakePacket()
|
||||
if err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Send Client Authentication Packet
|
||||
if err = mc.writeAuthPacket(cipher); err != nil {
|
||||
authResp, addNUL, err := mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
// try the default auth plugin, if using the requested plugin failed
|
||||
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
|
||||
plugin = defaultAuthPlugin
|
||||
authResp, addNUL, err = mc.auth(authData, plugin)
|
||||
if err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
|
||||
mc.cleanup()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Handle response to auth packet, switch methods if possible
|
||||
if err = handleAuthResult(mc, cipher); err != nil {
|
||||
if err = mc.handleAuthResult(authData, plugin); err != nil {
|
||||
// Authentication failed and MySQL has already closed the connection
|
||||
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
|
||||
// Do not send COM_QUIT, just cleanup and return the error.
|
||||
|
@ -134,50 +164,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
|
|||
return mc, nil
|
||||
}
|
||||
|
||||
func handleAuthResult(mc *mysqlConn, oldCipher []byte) error {
|
||||
// Read Result Packet
|
||||
cipher, err := mc.readResultOK()
|
||||
if err == nil {
|
||||
return nil // auth successful
|
||||
}
|
||||
|
||||
if mc.cfg == nil {
|
||||
return err // auth failed and retry not possible
|
||||
}
|
||||
|
||||
// Retry auth if configured to do so.
|
||||
if mc.cfg.AllowOldPasswords && err == ErrOldPassword {
|
||||
// Retry with old authentication method. Note: there are edge cases
|
||||
// where this should work but doesn't; this is currently "wontfix":
|
||||
// https://github.com/go-sql-driver/mysql/issues/184
|
||||
|
||||
// If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is
|
||||
// sent and we have to keep using the cipher sent in the init packet.
|
||||
if cipher == nil {
|
||||
cipher = oldCipher
|
||||
}
|
||||
|
||||
if err = mc.writeOldAuthPacket(cipher); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = mc.readResultOK()
|
||||
} else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword {
|
||||
// Retry with clear text password for
|
||||
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
|
||||
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
|
||||
if err = mc.writeClearAuthPacket(); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = mc.readResultOK()
|
||||
} else if mc.cfg.AllowNativePasswords && err == ErrNativePassword {
|
||||
if err = mc.writeNativeAuthPacket(cipher); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = mc.readResultOK()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func init() {
|
||||
sql.Register("mysql", &MySQLDriver{})
|
||||
}
|
||||
|
|
|
@ -0,0 +1,806 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build go1.8
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// static interface implementation checks of mysqlConn
|
||||
var (
|
||||
_ driver.ConnBeginTx = &mysqlConn{}
|
||||
_ driver.ConnPrepareContext = &mysqlConn{}
|
||||
_ driver.ExecerContext = &mysqlConn{}
|
||||
_ driver.Pinger = &mysqlConn{}
|
||||
_ driver.QueryerContext = &mysqlConn{}
|
||||
)
|
||||
|
||||
// static interface implementation checks of mysqlStmt
|
||||
var (
|
||||
_ driver.StmtExecContext = &mysqlStmt{}
|
||||
_ driver.StmtQueryContext = &mysqlStmt{}
|
||||
)
|
||||
|
||||
// Ensure that all the driver interfaces are implemented
|
||||
var (
|
||||
// _ driver.RowsColumnTypeLength = &binaryRows{}
|
||||
// _ driver.RowsColumnTypeLength = &textRows{}
|
||||
_ driver.RowsColumnTypeDatabaseTypeName = &binaryRows{}
|
||||
_ driver.RowsColumnTypeDatabaseTypeName = &textRows{}
|
||||
_ driver.RowsColumnTypeNullable = &binaryRows{}
|
||||
_ driver.RowsColumnTypeNullable = &textRows{}
|
||||
_ driver.RowsColumnTypePrecisionScale = &binaryRows{}
|
||||
_ driver.RowsColumnTypePrecisionScale = &textRows{}
|
||||
_ driver.RowsColumnTypeScanType = &binaryRows{}
|
||||
_ driver.RowsColumnTypeScanType = &textRows{}
|
||||
_ driver.RowsNextResultSet = &binaryRows{}
|
||||
_ driver.RowsNextResultSet = &textRows{}
|
||||
)
|
||||
|
||||
func TestMultiResultSet(t *testing.T) {
|
||||
type result struct {
|
||||
values [][]int
|
||||
columns []string
|
||||
}
|
||||
|
||||
// checkRows is a helper test function to validate rows containing 3 result
|
||||
// sets with specific values and columns. The basic query would look like this:
|
||||
//
|
||||
// SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
|
||||
// SELECT 0 UNION SELECT 1;
|
||||
// SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
|
||||
//
|
||||
// to distinguish test cases the first string argument is put in front of
|
||||
// every error or fatal message.
|
||||
checkRows := func(desc string, rows *sql.Rows, dbt *DBTest) {
|
||||
expected := []result{
|
||||
{
|
||||
values: [][]int{{1, 2}, {3, 4}},
|
||||
columns: []string{"col1", "col2"},
|
||||
},
|
||||
{
|
||||
values: [][]int{{1, 2, 3}, {4, 5, 6}},
|
||||
columns: []string{"col1", "col2", "col3"},
|
||||
},
|
||||
}
|
||||
|
||||
var res1 result
|
||||
for rows.Next() {
|
||||
var res [2]int
|
||||
if err := rows.Scan(&res[0], &res[1]); err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
res1.values = append(res1.values, res[:])
|
||||
}
|
||||
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
dbt.Fatal(desc, err)
|
||||
}
|
||||
res1.columns = cols
|
||||
|
||||
if !reflect.DeepEqual(expected[0], res1) {
|
||||
dbt.Error(desc, "want =", expected[0], "got =", res1)
|
||||
}
|
||||
|
||||
if !rows.NextResultSet() {
|
||||
dbt.Fatal(desc, "expected next result set")
|
||||
}
|
||||
|
||||
// ignoring one result set
|
||||
|
||||
if !rows.NextResultSet() {
|
||||
dbt.Fatal(desc, "expected next result set")
|
||||
}
|
||||
|
||||
var res2 result
|
||||
cols, err = rows.Columns()
|
||||
if err != nil {
|
||||
dbt.Fatal(desc, err)
|
||||
}
|
||||
res2.columns = cols
|
||||
|
||||
for rows.Next() {
|
||||
var res [3]int
|
||||
if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil {
|
||||
dbt.Fatal(desc, err)
|
||||
}
|
||||
res2.values = append(res2.values, res[:])
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected[1], res2) {
|
||||
dbt.Error(desc, "want =", expected[1], "got =", res2)
|
||||
}
|
||||
|
||||
if rows.NextResultSet() {
|
||||
dbt.Error(desc, "unexpected next result set")
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
dbt.Error(desc, err)
|
||||
}
|
||||
}
|
||||
|
||||
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
|
||||
rows := dbt.mustQuery(`DO 1;
|
||||
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
|
||||
DO 1;
|
||||
SELECT 0 UNION SELECT 1;
|
||||
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`)
|
||||
defer rows.Close()
|
||||
checkRows("query: ", rows, dbt)
|
||||
})
|
||||
|
||||
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
|
||||
queries := []string{
|
||||
`
|
||||
DROP PROCEDURE IF EXISTS test_mrss;
|
||||
CREATE PROCEDURE test_mrss()
|
||||
BEGIN
|
||||
DO 1;
|
||||
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
|
||||
DO 1;
|
||||
SELECT 0 UNION SELECT 1;
|
||||
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
|
||||
END
|
||||
`,
|
||||
`
|
||||
DROP PROCEDURE IF EXISTS test_mrss;
|
||||
CREATE PROCEDURE test_mrss()
|
||||
BEGIN
|
||||
SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4;
|
||||
SELECT 0 UNION SELECT 1;
|
||||
SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;
|
||||
END
|
||||
`,
|
||||
}
|
||||
|
||||
defer dbt.mustExec("DROP PROCEDURE IF EXISTS test_mrss")
|
||||
|
||||
for i, query := range queries {
|
||||
dbt.mustExec(query)
|
||||
|
||||
stmt, err := dbt.db.Prepare("CALL test_mrss()")
|
||||
if err != nil {
|
||||
dbt.Fatalf("%v (i=%d)", err, i)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
for j := 0; j < 2; j++ {
|
||||
rows, err := stmt.Query()
|
||||
if err != nil {
|
||||
dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j)
|
||||
}
|
||||
checkRows(fmt.Sprintf("prepared stmt query (i=%d) (j=%d): ", i, j), rows, dbt)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMultiResultSetNoSelect(t *testing.T) {
|
||||
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
|
||||
rows := dbt.mustQuery("DO 1; DO 2;")
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
dbt.Error("unexpected row")
|
||||
}
|
||||
|
||||
if rows.NextResultSet() {
|
||||
dbt.Error("unexpected next result set")
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
dbt.Error("expected nil; got ", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// tests if rows are set in a proper state if some results were ignored before
|
||||
// calling rows.NextResultSet.
|
||||
func TestSkipResults(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
rows := dbt.mustQuery("SELECT 1, 2")
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
dbt.Error("expected row")
|
||||
}
|
||||
|
||||
if rows.NextResultSet() {
|
||||
dbt.Error("unexpected next result set")
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
dbt.Error("expected nil; got ", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestPingContext(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
if err := dbt.db.PingContext(ctx); err != context.Canceled {
|
||||
dbt.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextCancelExec(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
dbt.mustExec("CREATE TABLE test (v INTEGER)")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Delay execution for just a bit until db.ExecContext has begun.
|
||||
defer time.AfterFunc(250*time.Millisecond, cancel).Stop()
|
||||
|
||||
// This query will be canceled.
|
||||
startTime := time.Now()
|
||||
if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled {
|
||||
dbt.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
if d := time.Since(startTime); d > 500*time.Millisecond {
|
||||
dbt.Errorf("too long execution time: %s", d)
|
||||
}
|
||||
|
||||
// Wait for the INSERT query to be done.
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Check how many times the query is executed.
|
||||
var v int
|
||||
if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
|
||||
dbt.Fatalf("%s", err.Error())
|
||||
}
|
||||
if v != 1 { // TODO: need to kill the query, and v should be 0.
|
||||
dbt.Skipf("[WARN] expected val to be 1, got %d", v)
|
||||
}
|
||||
|
||||
// Context is already canceled, so error should come before execution.
|
||||
if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil {
|
||||
dbt.Error("expected error")
|
||||
} else if err.Error() != "context canceled" {
|
||||
dbt.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
|
||||
// The second insert query will fail, so the table has no changes.
|
||||
if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
|
||||
dbt.Fatalf("%s", err.Error())
|
||||
}
|
||||
if v != 1 {
|
||||
dbt.Skipf("[WARN] expected val to be 1, got %d", v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextCancelQuery(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
dbt.mustExec("CREATE TABLE test (v INTEGER)")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
// Delay execution for just a bit until db.ExecContext has begun.
|
||||
defer time.AfterFunc(250*time.Millisecond, cancel).Stop()
|
||||
|
||||
// This query will be canceled.
|
||||
startTime := time.Now()
|
||||
if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled {
|
||||
dbt.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
if d := time.Since(startTime); d > 500*time.Millisecond {
|
||||
dbt.Errorf("too long execution time: %s", d)
|
||||
}
|
||||
|
||||
// Wait for the INSERT query to be done.
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Check how many times the query is executed.
|
||||
var v int
|
||||
if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
|
||||
dbt.Fatalf("%s", err.Error())
|
||||
}
|
||||
if v != 1 { // TODO: need to kill the query, and v should be 0.
|
||||
dbt.Skipf("[WARN] expected val to be 1, got %d", v)
|
||||
}
|
||||
|
||||
// Context is already canceled, so error should come before execution.
|
||||
if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled {
|
||||
dbt.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
|
||||
// The second insert query will fail, so the table has no changes.
|
||||
if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
|
||||
dbt.Fatalf("%s", err.Error())
|
||||
}
|
||||
if v != 1 {
|
||||
dbt.Skipf("[WARN] expected val to be 1, got %d", v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextCancelQueryRow(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
dbt.mustExec("CREATE TABLE test (v INTEGER)")
|
||||
dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM test")
|
||||
if err != nil {
|
||||
dbt.Fatalf("%s", err.Error())
|
||||
}
|
||||
|
||||
// the first row will be succeed.
|
||||
var v int
|
||||
if !rows.Next() {
|
||||
dbt.Fatalf("unexpected end")
|
||||
}
|
||||
if err := rows.Scan(&v); err != nil {
|
||||
dbt.Fatalf("%s", err.Error())
|
||||
}
|
||||
|
||||
cancel()
|
||||
// make sure the driver receives the cancel request.
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
if rows.Next() {
|
||||
dbt.Errorf("expected end, but not")
|
||||
}
|
||||
if err := rows.Err(); err != context.Canceled {
|
||||
dbt.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextCancelPrepare(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled {
|
||||
dbt.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextCancelStmtExec(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
dbt.mustExec("CREATE TABLE test (v INTEGER)")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))")
|
||||
if err != nil {
|
||||
dbt.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Delay execution for just a bit until db.ExecContext has begun.
|
||||
defer time.AfterFunc(250*time.Millisecond, cancel).Stop()
|
||||
|
||||
// This query will be canceled.
|
||||
startTime := time.Now()
|
||||
if _, err := stmt.ExecContext(ctx); err != context.Canceled {
|
||||
dbt.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
if d := time.Since(startTime); d > 500*time.Millisecond {
|
||||
dbt.Errorf("too long execution time: %s", d)
|
||||
}
|
||||
|
||||
// Wait for the INSERT query to be done.
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Check how many times the query is executed.
|
||||
var v int
|
||||
if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
|
||||
dbt.Fatalf("%s", err.Error())
|
||||
}
|
||||
if v != 1 { // TODO: need to kill the query, and v should be 0.
|
||||
dbt.Skipf("[WARN] expected val to be 1, got %d", v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextCancelStmtQuery(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
dbt.mustExec("CREATE TABLE test (v INTEGER)")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))")
|
||||
if err != nil {
|
||||
dbt.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Delay execution for just a bit until db.ExecContext has begun.
|
||||
defer time.AfterFunc(250*time.Millisecond, cancel).Stop()
|
||||
|
||||
// This query will be canceled.
|
||||
startTime := time.Now()
|
||||
if _, err := stmt.QueryContext(ctx); err != context.Canceled {
|
||||
dbt.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
if d := time.Since(startTime); d > 500*time.Millisecond {
|
||||
dbt.Errorf("too long execution time: %s", d)
|
||||
}
|
||||
|
||||
// Wait for the INSERT query has done.
|
||||
time.Sleep(time.Second)
|
||||
|
||||
// Check how many times the query is executed.
|
||||
var v int
|
||||
if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil {
|
||||
dbt.Fatalf("%s", err.Error())
|
||||
}
|
||||
if v != 1 { // TODO: need to kill the query, and v should be 0.
|
||||
dbt.Skipf("[WARN] expected val to be 1, got %d", v)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextCancelBegin(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
dbt.mustExec("CREATE TABLE test (v INTEGER)")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
tx, err := dbt.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
|
||||
// Delay execution for just a bit until db.ExecContext has begun.
|
||||
defer time.AfterFunc(100*time.Millisecond, cancel).Stop()
|
||||
|
||||
// This query will be canceled.
|
||||
startTime := time.Now()
|
||||
if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled {
|
||||
dbt.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
if d := time.Since(startTime); d > 500*time.Millisecond {
|
||||
dbt.Errorf("too long execution time: %s", d)
|
||||
}
|
||||
|
||||
// Transaction is canceled, so expect an error.
|
||||
switch err := tx.Commit(); err {
|
||||
case sql.ErrTxDone:
|
||||
// because the transaction has already been rollbacked.
|
||||
// the database/sql package watches ctx
|
||||
// and rollbacks when ctx is canceled.
|
||||
case context.Canceled:
|
||||
// the database/sql package rollbacks on another goroutine,
|
||||
// so the transaction may not be rollbacked depending on goroutine scheduling.
|
||||
default:
|
||||
dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err)
|
||||
}
|
||||
|
||||
// Context is canceled, so cannot begin a transaction.
|
||||
if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled {
|
||||
dbt.Errorf("expected context.Canceled, got %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextBeginIsolationLevel(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
dbt.mustExec("CREATE TABLE test (v INTEGER)")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tx1, err := dbt.db.BeginTx(ctx, &sql.TxOptions{
|
||||
Isolation: sql.LevelRepeatableRead,
|
||||
})
|
||||
if err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
|
||||
tx2, err := dbt.db.BeginTx(ctx, &sql.TxOptions{
|
||||
Isolation: sql.LevelReadCommitted,
|
||||
})
|
||||
if err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = tx1.ExecContext(ctx, "INSERT INTO test VALUES (1)")
|
||||
if err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
|
||||
var v int
|
||||
row := tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test")
|
||||
if err := row.Scan(&v); err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
// Because writer transaction wasn't commited yet, it should be available
|
||||
if v != 0 {
|
||||
dbt.Errorf("expected val to be 0, got %d", v)
|
||||
}
|
||||
|
||||
err = tx1.Commit()
|
||||
if err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
|
||||
row = tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test")
|
||||
if err := row.Scan(&v); err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
// Data written by writer transaction is already commited, it should be selectable
|
||||
if v != 1 {
|
||||
dbt.Errorf("expected val to be 1, got %d", v)
|
||||
}
|
||||
tx2.Commit()
|
||||
})
|
||||
}
|
||||
|
||||
func TestContextBeginReadOnly(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
dbt.mustExec("CREATE TABLE test (v INTEGER)")
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
tx, err := dbt.db.BeginTx(ctx, &sql.TxOptions{
|
||||
ReadOnly: true,
|
||||
})
|
||||
if _, ok := err.(*MySQLError); ok {
|
||||
dbt.Skip("It seems that your MySQL does not support READ ONLY transactions")
|
||||
return
|
||||
} else if err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
|
||||
// INSERT queries fail in a READ ONLY transaction.
|
||||
_, err = tx.ExecContext(ctx, "INSERT INTO test VALUES (1)")
|
||||
if _, ok := err.(*MySQLError); !ok {
|
||||
dbt.Errorf("expected MySQLError, got %v", err)
|
||||
}
|
||||
|
||||
// SELECT queries can be executed.
|
||||
var v int
|
||||
row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM test")
|
||||
if err := row.Scan(&v); err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
if v != 0 {
|
||||
dbt.Errorf("expected val to be 0, got %d", v)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestRowsColumnTypes(t *testing.T) {
|
||||
niNULL := sql.NullInt64{Int64: 0, Valid: false}
|
||||
ni0 := sql.NullInt64{Int64: 0, Valid: true}
|
||||
ni1 := sql.NullInt64{Int64: 1, Valid: true}
|
||||
ni42 := sql.NullInt64{Int64: 42, Valid: true}
|
||||
nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false}
|
||||
nf0 := sql.NullFloat64{Float64: 0.0, Valid: true}
|
||||
nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true}
|
||||
nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true}
|
||||
nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true}
|
||||
nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true}
|
||||
nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true}
|
||||
nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true}
|
||||
nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true}
|
||||
ndNULL := NullTime{Time: time.Time{}, Valid: false}
|
||||
rbNULL := sql.RawBytes(nil)
|
||||
rb0 := sql.RawBytes("0")
|
||||
rb42 := sql.RawBytes("42")
|
||||
rbTest := sql.RawBytes("Test")
|
||||
rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00
|
||||
rbx0 := sql.RawBytes("\x00")
|
||||
rbx42 := sql.RawBytes("\x42")
|
||||
|
||||
var columns = []struct {
|
||||
name string
|
||||
fieldType string // type used when creating table schema
|
||||
databaseTypeName string // actual type used by MySQL
|
||||
scanType reflect.Type
|
||||
nullable bool
|
||||
precision int64 // 0 if not ok
|
||||
scale int64
|
||||
valuesIn [3]string
|
||||
valuesOut [3]interface{}
|
||||
}{
|
||||
{"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}},
|
||||
{"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}},
|
||||
{"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}},
|
||||
{"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
|
||||
{"smallint", "SMALLINT NOT NULL", "SMALLINT", scanTypeInt16, false, 0, 0, [3]string{"0", "-32768", "32767"}, [3]interface{}{int16(0), int16(-32768), int16(32767)}},
|
||||
{"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
|
||||
{"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}},
|
||||
{"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}},
|
||||
{"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}},
|
||||
{"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}},
|
||||
{"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}},
|
||||
{"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}},
|
||||
{"smalluint", "SMALLINT UNSIGNED NOT NULL", "SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}},
|
||||
{"biguint", "BIGINT UNSIGNED NOT NULL", "BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}},
|
||||
{"uint13", "INT(13) UNSIGNED NOT NULL", "INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}},
|
||||
{"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}},
|
||||
{"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}},
|
||||
{"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}},
|
||||
{"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}},
|
||||
{"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}},
|
||||
{"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), sql.RawBytes("13.370000"), sql.RawBytes("1234.123456")}},
|
||||
{"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeRawBytes, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), rbNULL, sql.RawBytes("1234.123456")}},
|
||||
{"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), sql.RawBytes("13.3700"), sql.RawBytes("1234.1235")}},
|
||||
{"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeRawBytes, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), rbNULL, sql.RawBytes("1234.1235")}},
|
||||
{"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{rb0, sql.RawBytes("13"), sql.RawBytes("-12345")}},
|
||||
{"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}},
|
||||
{"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
|
||||
{"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
|
||||
{"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}},
|
||||
{"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
|
||||
{"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
|
||||
{"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
|
||||
{"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
|
||||
{"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}},
|
||||
{"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
|
||||
{"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
|
||||
{"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
|
||||
{"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}},
|
||||
{"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}},
|
||||
{"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}},
|
||||
{"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}},
|
||||
{"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}},
|
||||
{"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}},
|
||||
}
|
||||
|
||||
schema := ""
|
||||
values1 := ""
|
||||
values2 := ""
|
||||
values3 := ""
|
||||
for _, column := range columns {
|
||||
schema += fmt.Sprintf("`%s` %s, ", column.name, column.fieldType)
|
||||
values1 += column.valuesIn[0] + ", "
|
||||
values2 += column.valuesIn[1] + ", "
|
||||
values3 += column.valuesIn[2] + ", "
|
||||
}
|
||||
schema = schema[:len(schema)-2]
|
||||
values1 = values1[:len(values1)-2]
|
||||
values2 = values2[:len(values2)-2]
|
||||
values3 = values3[:len(values3)-2]
|
||||
|
||||
dsns := []string{
|
||||
dsn + "&parseTime=true",
|
||||
dsn + "&parseTime=false",
|
||||
}
|
||||
for _, testdsn := range dsns {
|
||||
runTests(t, testdsn, func(dbt *DBTest) {
|
||||
dbt.mustExec("CREATE TABLE test (" + schema + ")")
|
||||
dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")")
|
||||
|
||||
rows, err := dbt.db.Query("SELECT * FROM test")
|
||||
if err != nil {
|
||||
t.Fatalf("Query: %v", err)
|
||||
}
|
||||
|
||||
tt, err := rows.ColumnTypes()
|
||||
if err != nil {
|
||||
t.Fatalf("ColumnTypes: %v", err)
|
||||
}
|
||||
|
||||
if len(tt) != len(columns) {
|
||||
t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt))
|
||||
}
|
||||
|
||||
types := make([]reflect.Type, len(tt))
|
||||
for i, tp := range tt {
|
||||
column := columns[i]
|
||||
|
||||
// Name
|
||||
name := tp.Name()
|
||||
if name != column.name {
|
||||
t.Errorf("column name mismatch %s != %s", name, column.name)
|
||||
continue
|
||||
}
|
||||
|
||||
// DatabaseTypeName
|
||||
databaseTypeName := tp.DatabaseTypeName()
|
||||
if databaseTypeName != column.databaseTypeName {
|
||||
t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName)
|
||||
continue
|
||||
}
|
||||
|
||||
// ScanType
|
||||
scanType := tp.ScanType()
|
||||
if scanType != column.scanType {
|
||||
if scanType == nil {
|
||||
t.Errorf("scantype is null for column %q", name)
|
||||
} else {
|
||||
t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name())
|
||||
}
|
||||
continue
|
||||
}
|
||||
types[i] = scanType
|
||||
|
||||
// Nullable
|
||||
nullable, ok := tp.Nullable()
|
||||
if !ok {
|
||||
t.Errorf("nullable not ok %q", name)
|
||||
continue
|
||||
}
|
||||
if nullable != column.nullable {
|
||||
t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable)
|
||||
}
|
||||
|
||||
// Length
|
||||
// length, ok := tp.Length()
|
||||
// if length != column.length {
|
||||
// if !ok {
|
||||
// t.Errorf("length not ok for column %q", name)
|
||||
// } else {
|
||||
// t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length)
|
||||
// }
|
||||
// continue
|
||||
// }
|
||||
|
||||
// Precision and Scale
|
||||
precision, scale, ok := tp.DecimalSize()
|
||||
if precision != column.precision {
|
||||
if !ok {
|
||||
t.Errorf("precision not ok for column %q", name)
|
||||
} else {
|
||||
t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if scale != column.scale {
|
||||
if !ok {
|
||||
t.Errorf("scale not ok for column %q", name)
|
||||
} else {
|
||||
t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale)
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
values := make([]interface{}, len(tt))
|
||||
for i := range values {
|
||||
values[i] = reflect.New(types[i]).Interface()
|
||||
}
|
||||
i := 0
|
||||
for rows.Next() {
|
||||
err = rows.Scan(values...)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to scan values in %v", err)
|
||||
}
|
||||
for j := range values {
|
||||
value := reflect.ValueOf(values[j]).Elem().Interface()
|
||||
if !reflect.DeepEqual(value, columns[j].valuesOut[i]) {
|
||||
if columns[j].scanType == scanTypeRawBytes {
|
||||
t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes)))
|
||||
} else {
|
||||
t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
i++
|
||||
}
|
||||
if i != 3 {
|
||||
t.Errorf("expected 3 rows, got %d", i)
|
||||
}
|
||||
|
||||
if err := rows.Close(); err != nil {
|
||||
t.Errorf("error closing rows: %s", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValuerWithValueReceiverGivenNilValue(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
dbt.mustExec("CREATE TABLE test (value VARCHAR(255))")
|
||||
dbt.db.Exec("INSERT INTO test VALUES (?)", (*testValuer)(nil))
|
||||
// This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value()
|
||||
})
|
||||
}
|
|
@ -27,6 +27,12 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
// Ensure that all the driver interfaces are implemented
|
||||
var (
|
||||
_ driver.Rows = &binaryRows{}
|
||||
_ driver.Rows = &textRows{}
|
||||
)
|
||||
|
||||
var (
|
||||
user string
|
||||
pass string
|
||||
|
@ -63,7 +69,7 @@ func init() {
|
|||
addr = env("MYSQL_TEST_ADDR", "localhost:3306")
|
||||
dbname = env("MYSQL_TEST_DBNAME", "gotest")
|
||||
netAddr = fmt.Sprintf("%s(%s)", prot, addr)
|
||||
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s&strict=true", user, pass, netAddr, dbname)
|
||||
dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, dbname)
|
||||
c, err := net.Dial(prot, addr)
|
||||
if err == nil {
|
||||
available = true
|
||||
|
@ -171,6 +177,17 @@ func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows)
|
|||
return rows
|
||||
}
|
||||
|
||||
func maybeSkip(t *testing.T, err error, skipErrno uint16) {
|
||||
mySQLErr, ok := err.(*MySQLError)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if mySQLErr.Number == skipErrno {
|
||||
t.Skipf("skipping test for error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmptyQuery(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
// just a comment, no query
|
||||
|
@ -482,6 +499,113 @@ func TestString(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestRawBytes(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
v1 := []byte("aaa")
|
||||
v2 := []byte("bbb")
|
||||
rows := dbt.mustQuery("SELECT ?, ?", v1, v2)
|
||||
if rows.Next() {
|
||||
var o1, o2 sql.RawBytes
|
||||
if err := rows.Scan(&o1, &o2); err != nil {
|
||||
dbt.Errorf("Got error: %v", err)
|
||||
}
|
||||
if !bytes.Equal(v1, o1) {
|
||||
dbt.Errorf("expected %v, got %v", v1, o1)
|
||||
}
|
||||
if !bytes.Equal(v2, o2) {
|
||||
dbt.Errorf("expected %v, got %v", v2, o2)
|
||||
}
|
||||
// https://github.com/go-sql-driver/mysql/issues/765
|
||||
// Appending to RawBytes shouldn't overwrite next RawBytes.
|
||||
o1 = append(o1, "xyzzy"...)
|
||||
if !bytes.Equal(v2, o2) {
|
||||
dbt.Errorf("expected %v, got %v", v2, o2)
|
||||
}
|
||||
} else {
|
||||
dbt.Errorf("no data")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
type testValuer struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (tv testValuer) Value() (driver.Value, error) {
|
||||
return tv.value, nil
|
||||
}
|
||||
|
||||
func TestValuer(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
in := testValuer{"a_value"}
|
||||
var out string
|
||||
var rows *sql.Rows
|
||||
|
||||
dbt.mustExec("CREATE TABLE test (value VARCHAR(255)) CHARACTER SET utf8")
|
||||
dbt.mustExec("INSERT INTO test VALUES (?)", in)
|
||||
rows = dbt.mustQuery("SELECT value FROM test")
|
||||
if rows.Next() {
|
||||
rows.Scan(&out)
|
||||
if in.value != out {
|
||||
dbt.Errorf("Valuer: %v != %s", in, out)
|
||||
}
|
||||
} else {
|
||||
dbt.Errorf("Valuer: no data")
|
||||
}
|
||||
|
||||
dbt.mustExec("DROP TABLE IF EXISTS test")
|
||||
})
|
||||
}
|
||||
|
||||
type testValuerWithValidation struct {
|
||||
value string
|
||||
}
|
||||
|
||||
func (tv testValuerWithValidation) Value() (driver.Value, error) {
|
||||
if len(tv.value) == 0 {
|
||||
return nil, fmt.Errorf("Invalid string valuer. Value must not be empty")
|
||||
}
|
||||
|
||||
return tv.value, nil
|
||||
}
|
||||
|
||||
func TestValuerWithValidation(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
in := testValuerWithValidation{"a_value"}
|
||||
var out string
|
||||
var rows *sql.Rows
|
||||
|
||||
dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8")
|
||||
dbt.mustExec("INSERT INTO testValuer VALUES (?)", in)
|
||||
|
||||
rows = dbt.mustQuery("SELECT value FROM testValuer")
|
||||
defer rows.Close()
|
||||
|
||||
if rows.Next() {
|
||||
rows.Scan(&out)
|
||||
if in.value != out {
|
||||
dbt.Errorf("Valuer: %v != %s", in, out)
|
||||
}
|
||||
} else {
|
||||
dbt.Errorf("Valuer: no data")
|
||||
}
|
||||
|
||||
if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", testValuerWithValidation{""}); err == nil {
|
||||
dbt.Errorf("Failed to check valuer error")
|
||||
}
|
||||
|
||||
if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", nil); err != nil {
|
||||
dbt.Errorf("Failed to check nil")
|
||||
}
|
||||
|
||||
if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil {
|
||||
dbt.Errorf("Failed to check not valuer")
|
||||
}
|
||||
|
||||
dbt.mustExec("DROP TABLE IF EXISTS testValuer")
|
||||
})
|
||||
}
|
||||
|
||||
type timeTests struct {
|
||||
dbtype string
|
||||
tlayout string
|
||||
|
@ -684,7 +808,7 @@ func TestDateTime(t *testing.T) {
|
|||
for _, setup := range setups.tests {
|
||||
allowBinTime := true
|
||||
if setup.s == "" {
|
||||
// fill time string whereever Go can reliable produce it
|
||||
// fill time string wherever Go can reliable produce it
|
||||
setup.s = setup.t.Format(setups.tlayout)
|
||||
} else if setup.s[0] == '!' {
|
||||
// skip tests using setup.t as source in queries
|
||||
|
@ -856,14 +980,14 @@ func TestNULL(t *testing.T) {
|
|||
dbt.Fatal(err)
|
||||
}
|
||||
if b != nil {
|
||||
dbt.Error("non-nil []byte wich should be nil")
|
||||
dbt.Error("non-nil []byte which should be nil")
|
||||
}
|
||||
// Read non-nil
|
||||
if err = nonNullStmt.QueryRow().Scan(&b); err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
if b == nil {
|
||||
dbt.Error("nil []byte wich should be non-nil")
|
||||
dbt.Error("nil []byte which should be non-nil")
|
||||
}
|
||||
// Insert nil
|
||||
b = nil
|
||||
|
@ -953,7 +1077,7 @@ func TestUint64(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestLongData(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
runTests(t, dsn+"&maxAllowedPacket=0", func(dbt *DBTest) {
|
||||
var maxAllowedPacketSize int
|
||||
err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize)
|
||||
if err != nil {
|
||||
|
@ -1054,22 +1178,36 @@ func TestLoadData(t *testing.T) {
|
|||
dbt.Fatalf("rows count mismatch. Got %d, want 4", i)
|
||||
}
|
||||
}
|
||||
file, err := ioutil.TempFile("", "gotest")
|
||||
defer os.Remove(file.Name())
|
||||
if err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
|
||||
file.Close()
|
||||
|
||||
dbt.db.Exec("DROP TABLE IF EXISTS test")
|
||||
dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8")
|
||||
|
||||
// Local File
|
||||
file, err := ioutil.TempFile("", "gotest")
|
||||
defer os.Remove(file.Name())
|
||||
if err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
RegisterLocalFile(file.Name())
|
||||
|
||||
// Try first with empty file
|
||||
dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name()))
|
||||
var count int
|
||||
err = dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&count)
|
||||
if err != nil {
|
||||
dbt.Fatal(err.Error())
|
||||
}
|
||||
if count != 0 {
|
||||
dbt.Fatalf("unexpected row count: got %d, want 0", count)
|
||||
}
|
||||
|
||||
// Then fille File with data and try to load it
|
||||
file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n")
|
||||
file.Close()
|
||||
dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name()))
|
||||
verifyLoadDataResult()
|
||||
// negative test
|
||||
|
||||
// Try with non-existing file
|
||||
_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test")
|
||||
if err == nil {
|
||||
dbt.Fatal("load non-existent file didn't fail")
|
||||
|
@ -1145,84 +1283,6 @@ func TestFoundRows(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestStrict(t *testing.T) {
|
||||
// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
|
||||
relaxedDsn := dsn + "&sql_mode='ALLOW_INVALID_DATES,NO_AUTO_CREATE_USER'"
|
||||
// make sure the MySQL version is recent enough with a separate connection
|
||||
// before running the test
|
||||
conn, err := MySQLDriver{}.Open(relaxedDsn)
|
||||
if conn != nil {
|
||||
conn.Close()
|
||||
}
|
||||
if me, ok := err.(*MySQLError); ok && me.Number == 1231 {
|
||||
// Error 1231: Variable 'sql_mode' can't be set to the value of 'ALLOW_INVALID_DATES'
|
||||
// => skip test, MySQL server version is too old
|
||||
return
|
||||
}
|
||||
runTests(t, relaxedDsn, func(dbt *DBTest) {
|
||||
dbt.mustExec("CREATE TABLE test (a TINYINT NOT NULL, b CHAR(4))")
|
||||
|
||||
var queries = [...]struct {
|
||||
in string
|
||||
codes []string
|
||||
}{
|
||||
{"DROP TABLE IF EXISTS no_such_table", []string{"1051"}},
|
||||
{"INSERT INTO test VALUES(10,'mysql'),(NULL,'test'),(300,'Open Source')", []string{"1265", "1048", "1264", "1265"}},
|
||||
}
|
||||
var err error
|
||||
|
||||
var checkWarnings = func(err error, mode string, idx int) {
|
||||
if err == nil {
|
||||
dbt.Errorf("expected STRICT error on query [%s] %s", mode, queries[idx].in)
|
||||
}
|
||||
|
||||
if warnings, ok := err.(MySQLWarnings); ok {
|
||||
var codes = make([]string, len(warnings))
|
||||
for i := range warnings {
|
||||
codes[i] = warnings[i].Code
|
||||
}
|
||||
if len(codes) != len(queries[idx].codes) {
|
||||
dbt.Errorf("unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
|
||||
}
|
||||
|
||||
for i := range warnings {
|
||||
if codes[i] != queries[idx].codes[i] {
|
||||
dbt.Errorf("unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
} else {
|
||||
dbt.Errorf("unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// text protocol
|
||||
for i := range queries {
|
||||
_, err = dbt.db.Exec(queries[i].in)
|
||||
checkWarnings(err, "text", i)
|
||||
}
|
||||
|
||||
var stmt *sql.Stmt
|
||||
|
||||
// binary protocol
|
||||
for i := range queries {
|
||||
stmt, err = dbt.db.Prepare(queries[i].in)
|
||||
if err != nil {
|
||||
dbt.Errorf("error on preparing query %s: %s", queries[i].in, err.Error())
|
||||
}
|
||||
|
||||
_, err = stmt.Exec()
|
||||
checkWarnings(err, "binary", i)
|
||||
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
dbt.Errorf("error on closing stmt for query %s: %s", queries[i].in, err.Error())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTLS(t *testing.T) {
|
||||
tlsTest := func(dbt *DBTest) {
|
||||
if err := dbt.db.Ping(); err != nil {
|
||||
|
@ -1422,7 +1482,6 @@ func TestTimezoneConversion(t *testing.T) {
|
|||
|
||||
// Regression test for timezone handling
|
||||
tzTest := func(dbt *DBTest) {
|
||||
|
||||
// Create table
|
||||
dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)")
|
||||
|
||||
|
@ -1638,8 +1697,9 @@ func TestStmtMultiRows(t *testing.T) {
|
|||
// Regression test for
|
||||
// * more than 32 NULL parameters (issue 209)
|
||||
// * more parameters than fit into the buffer (issue 201)
|
||||
// * parameters * 64 > max_allowed_packet (issue 734)
|
||||
func TestPreparedManyCols(t *testing.T) {
|
||||
const numParams = defaultBufSize
|
||||
numParams := 65535
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
query := "SELECT ?" + strings.Repeat(",?", numParams-1)
|
||||
stmt, err := dbt.db.Prepare(query)
|
||||
|
@ -1647,15 +1707,25 @@ func TestPreparedManyCols(t *testing.T) {
|
|||
dbt.Fatal(err)
|
||||
}
|
||||
defer stmt.Close()
|
||||
|
||||
// create more parameters than fit into the buffer
|
||||
// which will take nil-values
|
||||
params := make([]interface{}, numParams)
|
||||
rows, err := stmt.Query(params...)
|
||||
if err != nil {
|
||||
stmt.Close()
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
defer rows.Close()
|
||||
rows.Close()
|
||||
|
||||
// Create 0byte string which we can't send via STMT_LONG_DATA.
|
||||
for i := 0; i < numParams; i++ {
|
||||
params[i] = ""
|
||||
}
|
||||
rows, err = stmt.Query(params...)
|
||||
if err != nil {
|
||||
dbt.Fatal(err)
|
||||
}
|
||||
rows.Close()
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -1739,7 +1809,7 @@ func TestCustomDial(t *testing.T) {
|
|||
return net.Dial(prot, addr)
|
||||
})
|
||||
|
||||
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname))
|
||||
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname))
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting: %s", err.Error())
|
||||
}
|
||||
|
@ -1772,7 +1842,7 @@ func TestSQLInjection(t *testing.T) {
|
|||
|
||||
dsns := []string{
|
||||
dsn,
|
||||
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
|
||||
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
|
||||
}
|
||||
for _, testdsn := range dsns {
|
||||
runTests(t, testdsn, createTest("1 OR 1=1"))
|
||||
|
@ -1802,7 +1872,7 @@ func TestInsertRetrieveEscapedData(t *testing.T) {
|
|||
|
||||
dsns := []string{
|
||||
dsn,
|
||||
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
|
||||
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'",
|
||||
}
|
||||
for _, testdsn := range dsns {
|
||||
runTests(t, testdsn, testData)
|
||||
|
@ -1836,7 +1906,7 @@ func TestUnixSocketAuthFail(t *testing.T) {
|
|||
}
|
||||
}
|
||||
t.Logf("socket: %s", socket)
|
||||
badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s&strict=true", user, badPass, socket, dbname)
|
||||
badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname)
|
||||
db, err := sql.Open("mysql", badDSN)
|
||||
if err != nil {
|
||||
t.Fatalf("error connecting: %s", err.Error())
|
||||
|
@ -1902,3 +1972,104 @@ func TestInterruptBySignal(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestColumnsReusesSlice(t *testing.T) {
|
||||
rows := mysqlRows{
|
||||
rs: resultSet{
|
||||
columns: []mysqlField{
|
||||
{
|
||||
tableName: "test",
|
||||
name: "A",
|
||||
},
|
||||
{
|
||||
tableName: "test",
|
||||
name: "B",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
allocs := testing.AllocsPerRun(1, func() {
|
||||
cols := rows.Columns()
|
||||
|
||||
if len(cols) != 2 {
|
||||
t.Fatalf("expected 2 columns, got %d", len(cols))
|
||||
}
|
||||
})
|
||||
|
||||
if allocs != 0 {
|
||||
t.Fatalf("expected 0 allocations, got %d", int(allocs))
|
||||
}
|
||||
|
||||
if rows.rs.columnNames == nil {
|
||||
t.Fatalf("expected columnNames to be set, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRejectReadOnly(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
// Create Table
|
||||
dbt.mustExec("CREATE TABLE test (value BOOL)")
|
||||
// Set the session to read-only. We didn't set the `rejectReadOnly`
|
||||
// option, so any writes after this should fail.
|
||||
_, err := dbt.db.Exec("SET SESSION TRANSACTION READ ONLY")
|
||||
// Error 1193: Unknown system variable 'TRANSACTION' => skip test,
|
||||
// MySQL server version is too old
|
||||
maybeSkip(t, err, 1193)
|
||||
if _, err := dbt.db.Exec("DROP TABLE test"); err == nil {
|
||||
t.Fatalf("writing to DB in read-only session without " +
|
||||
"rejectReadOnly did not error")
|
||||
}
|
||||
// Set the session back to read-write so runTests() can properly clean
|
||||
// up the table `test`.
|
||||
dbt.mustExec("SET SESSION TRANSACTION READ WRITE")
|
||||
})
|
||||
|
||||
// Enable the `rejectReadOnly` option.
|
||||
runTests(t, dsn+"&rejectReadOnly=true", func(dbt *DBTest) {
|
||||
// Create Table
|
||||
dbt.mustExec("CREATE TABLE test (value BOOL)")
|
||||
// Set the session to read only. Any writes after this should error on
|
||||
// a driver.ErrBadConn, and cause `database/sql` to initiate a new
|
||||
// connection.
|
||||
dbt.mustExec("SET SESSION TRANSACTION READ ONLY")
|
||||
// This would error, but `database/sql` should automatically retry on a
|
||||
// new connection which is not read-only, and eventually succeed.
|
||||
dbt.mustExec("DROP TABLE test")
|
||||
})
|
||||
}
|
||||
|
||||
func TestPing(t *testing.T) {
|
||||
runTests(t, dsn, func(dbt *DBTest) {
|
||||
if err := dbt.db.Ping(); err != nil {
|
||||
dbt.fail("Ping", "Ping", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// See Issue #799
|
||||
func TestEmptyPassword(t *testing.T) {
|
||||
if !available {
|
||||
t.Skipf("MySQL server not running on %s", netAddr)
|
||||
}
|
||||
|
||||
dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname)
|
||||
db, err := sql.Open("mysql", dsn)
|
||||
if err == nil {
|
||||
defer db.Close()
|
||||
err = db.Ping()
|
||||
}
|
||||
|
||||
if pass == "" {
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
t.Fatal("expected authentication error")
|
||||
}
|
||||
if !strings.HasPrefix(err.Error(), "Error 1045") {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -10,11 +10,13 @@ package mysql
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
@ -27,7 +29,9 @@ var (
|
|||
errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
|
||||
)
|
||||
|
||||
// Config is a configuration parsed from a DSN string
|
||||
// Config is a configuration parsed from a DSN string.
|
||||
// If a new Config is created instead of being parsed from a DSN string,
|
||||
// the NewConfig function should be used, which sets default values.
|
||||
type Config struct {
|
||||
User string // Username
|
||||
Passwd string // Password (requires User)
|
||||
|
@ -38,6 +42,8 @@ type Config struct {
|
|||
Collation string // Connection collation
|
||||
Loc *time.Location // Location for time.Time values
|
||||
MaxAllowedPacket int // Max packet size allowed
|
||||
ServerPubKey string // Server public key name
|
||||
pubKey *rsa.PublicKey // Server public key
|
||||
TLSConfig string // TLS configuration name
|
||||
tls *tls.Config // TLS configuration
|
||||
Timeout time.Duration // Dial timeout
|
||||
|
@ -53,7 +59,54 @@ type Config struct {
|
|||
InterpolateParams bool // Interpolate placeholders into query string
|
||||
MultiStatements bool // Allow multiple statements in one query
|
||||
ParseTime bool // Parse time values to time.Time
|
||||
Strict bool // Return warnings as errors
|
||||
RejectReadOnly bool // Reject read-only connections
|
||||
}
|
||||
|
||||
// NewConfig creates a new Config and sets default values.
|
||||
func NewConfig() *Config {
|
||||
return &Config{
|
||||
Collation: defaultCollation,
|
||||
Loc: time.UTC,
|
||||
MaxAllowedPacket: defaultMaxAllowedPacket,
|
||||
AllowNativePasswords: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (cfg *Config) normalize() error {
|
||||
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
|
||||
return errInvalidDSNUnsafeCollation
|
||||
}
|
||||
|
||||
// Set default network if empty
|
||||
if cfg.Net == "" {
|
||||
cfg.Net = "tcp"
|
||||
}
|
||||
|
||||
// Set default address if empty
|
||||
if cfg.Addr == "" {
|
||||
switch cfg.Net {
|
||||
case "tcp":
|
||||
cfg.Addr = "127.0.0.1:3306"
|
||||
case "unix":
|
||||
cfg.Addr = "/tmp/mysql.sock"
|
||||
default:
|
||||
return errors.New("default addr for network '" + cfg.Net + "' unknown")
|
||||
}
|
||||
|
||||
} else if cfg.Net == "tcp" {
|
||||
cfg.Addr = ensureHavePort(cfg.Addr)
|
||||
}
|
||||
|
||||
if cfg.tls != nil {
|
||||
if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
|
||||
host, _, err := net.SplitHostPort(cfg.Addr)
|
||||
if err == nil {
|
||||
cfg.tls.ServerName = host
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FormatDSN formats the given Config into a DSN string which can be passed to
|
||||
|
@ -102,12 +155,12 @@ func (cfg *Config) FormatDSN() string {
|
|||
}
|
||||
}
|
||||
|
||||
if cfg.AllowNativePasswords {
|
||||
if !cfg.AllowNativePasswords {
|
||||
if hasParam {
|
||||
buf.WriteString("&allowNativePasswords=true")
|
||||
buf.WriteString("&allowNativePasswords=false")
|
||||
} else {
|
||||
hasParam = true
|
||||
buf.WriteString("?allowNativePasswords=true")
|
||||
buf.WriteString("?allowNativePasswords=false")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -195,15 +248,25 @@ func (cfg *Config) FormatDSN() string {
|
|||
buf.WriteString(cfg.ReadTimeout.String())
|
||||
}
|
||||
|
||||
if cfg.Strict {
|
||||
if cfg.RejectReadOnly {
|
||||
if hasParam {
|
||||
buf.WriteString("&strict=true")
|
||||
buf.WriteString("&rejectReadOnly=true")
|
||||
} else {
|
||||
hasParam = true
|
||||
buf.WriteString("?strict=true")
|
||||
buf.WriteString("?rejectReadOnly=true")
|
||||
}
|
||||
}
|
||||
|
||||
if len(cfg.ServerPubKey) > 0 {
|
||||
if hasParam {
|
||||
buf.WriteString("&serverPubKey=")
|
||||
} else {
|
||||
hasParam = true
|
||||
buf.WriteString("?serverPubKey=")
|
||||
}
|
||||
buf.WriteString(url.QueryEscape(cfg.ServerPubKey))
|
||||
}
|
||||
|
||||
if cfg.Timeout > 0 {
|
||||
if hasParam {
|
||||
buf.WriteString("&timeout=")
|
||||
|
@ -234,7 +297,7 @@ func (cfg *Config) FormatDSN() string {
|
|||
buf.WriteString(cfg.WriteTimeout.String())
|
||||
}
|
||||
|
||||
if cfg.MaxAllowedPacket > 0 {
|
||||
if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
|
||||
if hasParam {
|
||||
buf.WriteString("&maxAllowedPacket=")
|
||||
} else {
|
||||
|
@ -247,7 +310,12 @@ func (cfg *Config) FormatDSN() string {
|
|||
|
||||
// other params
|
||||
if cfg.Params != nil {
|
||||
for param, value := range cfg.Params {
|
||||
var params []string
|
||||
for param := range cfg.Params {
|
||||
params = append(params, param)
|
||||
}
|
||||
sort.Strings(params)
|
||||
for _, param := range params {
|
||||
if hasParam {
|
||||
buf.WriteByte('&')
|
||||
} else {
|
||||
|
@ -257,7 +325,7 @@ func (cfg *Config) FormatDSN() string {
|
|||
|
||||
buf.WriteString(param)
|
||||
buf.WriteByte('=')
|
||||
buf.WriteString(url.QueryEscape(value))
|
||||
buf.WriteString(url.QueryEscape(cfg.Params[param]))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -267,10 +335,7 @@ func (cfg *Config) FormatDSN() string {
|
|||
// ParseDSN parses the DSN string to a Config
|
||||
func ParseDSN(dsn string) (cfg *Config, err error) {
|
||||
// New config with some default values
|
||||
cfg = &Config{
|
||||
Loc: time.UTC,
|
||||
Collation: defaultCollation,
|
||||
}
|
||||
cfg = NewConfig()
|
||||
|
||||
// [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN]
|
||||
// Find the last '/' (since the password or the net addr might contain a '/')
|
||||
|
@ -338,28 +403,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) {
|
|||
return nil, errInvalidDSNNoSlash
|
||||
}
|
||||
|
||||
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
|
||||
return nil, errInvalidDSNUnsafeCollation
|
||||
if err = cfg.normalize(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set default network if empty
|
||||
if cfg.Net == "" {
|
||||
cfg.Net = "tcp"
|
||||
}
|
||||
|
||||
// Set default address if empty
|
||||
if cfg.Addr == "" {
|
||||
switch cfg.Net {
|
||||
case "tcp":
|
||||
cfg.Addr = "127.0.0.1:3306"
|
||||
case "unix":
|
||||
cfg.Addr = "/tmp/mysql.sock"
|
||||
default:
|
||||
return nil, errors.New("default addr for network '" + cfg.Net + "' unknown")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -374,7 +420,6 @@ func parseDSNParams(cfg *Config, params string) (err error) {
|
|||
|
||||
// cfg params
|
||||
switch value := param[1]; param[0] {
|
||||
|
||||
// Disable INFILE whitelist / enable all files
|
||||
case "allowAllFiles":
|
||||
var isBool bool
|
||||
|
@ -472,14 +517,32 @@ func parseDSNParams(cfg *Config, params string) (err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// Strict mode
|
||||
case "strict":
|
||||
// Reject read-only connections
|
||||
case "rejectReadOnly":
|
||||
var isBool bool
|
||||
cfg.Strict, isBool = readBool(value)
|
||||
cfg.RejectReadOnly, isBool = readBool(value)
|
||||
if !isBool {
|
||||
return errors.New("invalid bool value: " + value)
|
||||
}
|
||||
|
||||
// Server public key
|
||||
case "serverPubKey":
|
||||
name, err := url.QueryUnescape(value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid value for server pub key name: %v", err)
|
||||
}
|
||||
|
||||
if pubKey := getServerPubKey(name); pubKey != nil {
|
||||
cfg.ServerPubKey = name
|
||||
cfg.pubKey = pubKey
|
||||
} else {
|
||||
return errors.New("invalid value / unknown server pub key name: " + name)
|
||||
}
|
||||
|
||||
// Strict mode
|
||||
case "strict":
|
||||
panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
|
||||
|
||||
// Dial Timeout
|
||||
case "timeout":
|
||||
cfg.Timeout, err = time.ParseDuration(value)
|
||||
|
@ -506,14 +569,7 @@ func parseDSNParams(cfg *Config, params string) (err error) {
|
|||
return fmt.Errorf("invalid value for TLS config name: %v", err)
|
||||
}
|
||||
|
||||
if tlsConfig, ok := tlsConfigRegister[name]; ok {
|
||||
if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
|
||||
host, _, err := net.SplitHostPort(cfg.Addr)
|
||||
if err == nil {
|
||||
tlsConfig.ServerName = host
|
||||
}
|
||||
}
|
||||
|
||||
if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
|
||||
cfg.TLSConfig = name
|
||||
cfg.tls = tlsConfig
|
||||
} else {
|
||||
|
@ -546,3 +602,10 @@ func parseDSNParams(cfg *Config, params string) (err error) {
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
func ensureHavePort(addr string) string {
|
||||
if _, _, err := net.SplitHostPort(addr); err != nil {
|
||||
return net.JoinHostPort(addr, "3306")
|
||||
}
|
||||
return addr
|
||||
}
|
||||
|
|
|
@ -22,47 +22,57 @@ var testDSNs = []struct {
|
|||
out *Config
|
||||
}{{
|
||||
"username:password@protocol(address)/dbname?param=value",
|
||||
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC},
|
||||
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true",
|
||||
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, ColumnsWithAlias: true},
|
||||
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true},
|
||||
}, {
|
||||
"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true",
|
||||
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, ColumnsWithAlias: true, MultiStatements: true},
|
||||
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true, MultiStatements: true},
|
||||
}, {
|
||||
"user@unix(/path/to/socket)/dbname?charset=utf8",
|
||||
&Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC},
|
||||
&Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true",
|
||||
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, TLSConfig: "true"},
|
||||
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true"},
|
||||
}, {
|
||||
"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
|
||||
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, TLSConfig: "skip-verify"},
|
||||
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"},
|
||||
}, {
|
||||
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216",
|
||||
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216},
|
||||
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216},
|
||||
}, {
|
||||
"user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0",
|
||||
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false},
|
||||
}, {
|
||||
"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
|
||||
&Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.Local},
|
||||
&Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"/dbname",
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC},
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"@/",
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC},
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"/",
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC},
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"",
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC},
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"user:p@/ssword@/",
|
||||
&Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC},
|
||||
&Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"unix/?arg=%2Fsome%2Fpath.ext",
|
||||
&Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8_general_ci", Loc: time.UTC},
|
||||
}}
|
||||
&Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"tcp(127.0.0.1)/dbname",
|
||||
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
}, {
|
||||
"tcp(de:ad:be:ef::ca:fe)/dbname",
|
||||
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
|
||||
},
|
||||
}
|
||||
|
||||
func TestDSNParser(t *testing.T) {
|
||||
for i, tst := range testDSNs {
|
||||
|
@ -88,6 +98,7 @@ func TestDSNParserInvalid(t *testing.T) {
|
|||
"(/", // no closing brace
|
||||
"net(addr)//", // unescaped
|
||||
"User:pass@tcp(1.2.3.4:3306)", // no trailing slash
|
||||
"net()/", // unknown default addr
|
||||
//"/dbname?arg=/some/unescaped/path",
|
||||
}
|
||||
|
||||
|
@ -124,11 +135,56 @@ func TestDSNReformat(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDSNServerPubKey(t *testing.T) {
|
||||
baseDSN := "User:password@tcp(localhost:5555)/dbname?serverPubKey="
|
||||
|
||||
RegisterServerPubKey("testKey", testPubKeyRSA)
|
||||
defer DeregisterServerPubKey("testKey")
|
||||
|
||||
tst := baseDSN + "testKey"
|
||||
cfg, err := ParseDSN(tst)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
|
||||
if cfg.ServerPubKey != "testKey" {
|
||||
t.Errorf("unexpected cfg.ServerPubKey value: %v", cfg.ServerPubKey)
|
||||
}
|
||||
if cfg.pubKey != testPubKeyRSA {
|
||||
t.Error("pub key pointer doesn't match")
|
||||
}
|
||||
|
||||
// Key is missing
|
||||
tst = baseDSN + "invalid_name"
|
||||
cfg, err = ParseDSN(tst)
|
||||
if err == nil {
|
||||
t.Errorf("invalid name in DSN (%s) but did not error. Got config: %#v", tst, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSNServerPubKeyQueryEscape(t *testing.T) {
|
||||
const name = "&%!:"
|
||||
dsn := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + url.QueryEscape(name)
|
||||
|
||||
RegisterServerPubKey(name, testPubKeyRSA)
|
||||
defer DeregisterServerPubKey(name)
|
||||
|
||||
cfg, err := ParseDSN(dsn)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
|
||||
if cfg.pubKey != testPubKeyRSA {
|
||||
t.Error("pub key pointer doesn't match")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSNWithCustomTLS(t *testing.T) {
|
||||
baseDSN := "User:password@tcp(localhost:5555)/dbname?tls="
|
||||
tlsCfg := tls.Config{}
|
||||
|
||||
RegisterTLSConfig("utils_test", &tlsCfg)
|
||||
defer DeregisterTLSConfig("utils_test")
|
||||
|
||||
// Custom TLS is missing
|
||||
tst := baseDSN + "invalid_tls"
|
||||
|
@ -159,9 +215,37 @@ func TestDSNWithCustomTLS(t *testing.T) {
|
|||
t.Error(err.Error())
|
||||
} else if cfg.tls.ServerName != name {
|
||||
t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
|
||||
} else if tlsCfg.ServerName != "" {
|
||||
t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSNTLSConfig(t *testing.T) {
|
||||
expectedServerName := "example.com"
|
||||
dsn := "tcp(example.com:1234)/?tls=true"
|
||||
|
||||
cfg, err := ParseDSN(dsn)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
if cfg.tls == nil {
|
||||
t.Error("cfg.tls should not be nil")
|
||||
}
|
||||
if cfg.tls.ServerName != expectedServerName {
|
||||
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
|
||||
}
|
||||
|
||||
DeregisterTLSConfig("utils_test")
|
||||
dsn = "tcp(example.com)/?tls=true"
|
||||
cfg, err = ParseDSN(dsn)
|
||||
if err != nil {
|
||||
t.Error(err.Error())
|
||||
}
|
||||
if cfg.tls == nil {
|
||||
t.Error("cfg.tls should not be nil")
|
||||
}
|
||||
if cfg.tls.ServerName != expectedServerName {
|
||||
t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
|
||||
|
@ -171,6 +255,7 @@ func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
|
|||
tlsCfg := tls.Config{ServerName: name}
|
||||
|
||||
RegisterTLSConfig(configKey, &tlsCfg)
|
||||
defer DeregisterTLSConfig(configKey)
|
||||
|
||||
cfg, err := ParseDSN(dsn)
|
||||
|
||||
|
@ -218,6 +303,21 @@ func TestDSNUnsafeCollation(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestParamsAreSorted(t *testing.T) {
|
||||
expected := "/dbname?interpolateParams=true&foobar=baz&quux=loo"
|
||||
cfg := NewConfig()
|
||||
cfg.DBName = "dbname"
|
||||
cfg.InterpolateParams = true
|
||||
cfg.Params = map[string]string{
|
||||
"quux": "loo",
|
||||
"foobar": "baz",
|
||||
}
|
||||
actual := cfg.FormatDSN()
|
||||
if actual != expected {
|
||||
t.Errorf("generic Config.Params were not sorted: want %#v, got %#v", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseDSN(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
|
|
|
@ -9,10 +9,8 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
)
|
||||
|
@ -31,6 +29,12 @@ var (
|
|||
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
|
||||
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
|
||||
ErrBusyBuffer = errors.New("busy buffer")
|
||||
|
||||
// errBadConnNoWrite is used for connection errors where nothing was sent to the database yet.
|
||||
// If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn
|
||||
// to trigger a resend.
|
||||
// See https://github.com/go-sql-driver/mysql/pull/302
|
||||
errBadConnNoWrite = errors.New("bad connection")
|
||||
)
|
||||
|
||||
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
|
||||
|
@ -59,74 +63,3 @@ type MySQLError struct {
|
|||
func (me *MySQLError) Error() string {
|
||||
return fmt.Sprintf("Error %d: %s", me.Number, me.Message)
|
||||
}
|
||||
|
||||
// MySQLWarnings is an error type which represents a group of one or more MySQL
|
||||
// warnings
|
||||
type MySQLWarnings []MySQLWarning
|
||||
|
||||
func (mws MySQLWarnings) Error() string {
|
||||
var msg string
|
||||
for i, warning := range mws {
|
||||
if i > 0 {
|
||||
msg += "\r\n"
|
||||
}
|
||||
msg += fmt.Sprintf(
|
||||
"%s %s: %s",
|
||||
warning.Level,
|
||||
warning.Code,
|
||||
warning.Message,
|
||||
)
|
||||
}
|
||||
return msg
|
||||
}
|
||||
|
||||
// MySQLWarning is an error type which represents a single MySQL warning.
|
||||
// Warnings are returned in groups only. See MySQLWarnings
|
||||
type MySQLWarning struct {
|
||||
Level string
|
||||
Code string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (mc *mysqlConn) getWarnings() (err error) {
|
||||
rows, err := mc.Query("SHOW WARNINGS", nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var warnings = MySQLWarnings{}
|
||||
var values = make([]driver.Value, 3)
|
||||
|
||||
for {
|
||||
err = rows.Next(values)
|
||||
switch err {
|
||||
case nil:
|
||||
warning := MySQLWarning{}
|
||||
|
||||
if raw, ok := values[0].([]byte); ok {
|
||||
warning.Level = string(raw)
|
||||
} else {
|
||||
warning.Level = fmt.Sprintf("%s", values[0])
|
||||
}
|
||||
if raw, ok := values[1].([]byte); ok {
|
||||
warning.Code = string(raw)
|
||||
} else {
|
||||
warning.Code = fmt.Sprintf("%s", values[1])
|
||||
}
|
||||
if raw, ok := values[2].([]byte); ok {
|
||||
warning.Message = string(raw)
|
||||
} else {
|
||||
warning.Message = fmt.Sprintf("%s", values[0])
|
||||
}
|
||||
|
||||
warnings = append(warnings, warning)
|
||||
|
||||
case io.EOF:
|
||||
return warnings
|
||||
|
||||
default:
|
||||
rows.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,194 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func (mf *mysqlField) typeDatabaseName() string {
|
||||
switch mf.fieldType {
|
||||
case fieldTypeBit:
|
||||
return "BIT"
|
||||
case fieldTypeBLOB:
|
||||
if mf.charSet != collations[binaryCollation] {
|
||||
return "TEXT"
|
||||
}
|
||||
return "BLOB"
|
||||
case fieldTypeDate:
|
||||
return "DATE"
|
||||
case fieldTypeDateTime:
|
||||
return "DATETIME"
|
||||
case fieldTypeDecimal:
|
||||
return "DECIMAL"
|
||||
case fieldTypeDouble:
|
||||
return "DOUBLE"
|
||||
case fieldTypeEnum:
|
||||
return "ENUM"
|
||||
case fieldTypeFloat:
|
||||
return "FLOAT"
|
||||
case fieldTypeGeometry:
|
||||
return "GEOMETRY"
|
||||
case fieldTypeInt24:
|
||||
return "MEDIUMINT"
|
||||
case fieldTypeJSON:
|
||||
return "JSON"
|
||||
case fieldTypeLong:
|
||||
return "INT"
|
||||
case fieldTypeLongBLOB:
|
||||
if mf.charSet != collations[binaryCollation] {
|
||||
return "LONGTEXT"
|
||||
}
|
||||
return "LONGBLOB"
|
||||
case fieldTypeLongLong:
|
||||
return "BIGINT"
|
||||
case fieldTypeMediumBLOB:
|
||||
if mf.charSet != collations[binaryCollation] {
|
||||
return "MEDIUMTEXT"
|
||||
}
|
||||
return "MEDIUMBLOB"
|
||||
case fieldTypeNewDate:
|
||||
return "DATE"
|
||||
case fieldTypeNewDecimal:
|
||||
return "DECIMAL"
|
||||
case fieldTypeNULL:
|
||||
return "NULL"
|
||||
case fieldTypeSet:
|
||||
return "SET"
|
||||
case fieldTypeShort:
|
||||
return "SMALLINT"
|
||||
case fieldTypeString:
|
||||
if mf.charSet == collations[binaryCollation] {
|
||||
return "BINARY"
|
||||
}
|
||||
return "CHAR"
|
||||
case fieldTypeTime:
|
||||
return "TIME"
|
||||
case fieldTypeTimestamp:
|
||||
return "TIMESTAMP"
|
||||
case fieldTypeTiny:
|
||||
return "TINYINT"
|
||||
case fieldTypeTinyBLOB:
|
||||
if mf.charSet != collations[binaryCollation] {
|
||||
return "TINYTEXT"
|
||||
}
|
||||
return "TINYBLOB"
|
||||
case fieldTypeVarChar:
|
||||
if mf.charSet == collations[binaryCollation] {
|
||||
return "VARBINARY"
|
||||
}
|
||||
return "VARCHAR"
|
||||
case fieldTypeVarString:
|
||||
if mf.charSet == collations[binaryCollation] {
|
||||
return "VARBINARY"
|
||||
}
|
||||
return "VARCHAR"
|
||||
case fieldTypeYear:
|
||||
return "YEAR"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
scanTypeFloat32 = reflect.TypeOf(float32(0))
|
||||
scanTypeFloat64 = reflect.TypeOf(float64(0))
|
||||
scanTypeInt8 = reflect.TypeOf(int8(0))
|
||||
scanTypeInt16 = reflect.TypeOf(int16(0))
|
||||
scanTypeInt32 = reflect.TypeOf(int32(0))
|
||||
scanTypeInt64 = reflect.TypeOf(int64(0))
|
||||
scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{})
|
||||
scanTypeNullInt = reflect.TypeOf(sql.NullInt64{})
|
||||
scanTypeNullTime = reflect.TypeOf(NullTime{})
|
||||
scanTypeUint8 = reflect.TypeOf(uint8(0))
|
||||
scanTypeUint16 = reflect.TypeOf(uint16(0))
|
||||
scanTypeUint32 = reflect.TypeOf(uint32(0))
|
||||
scanTypeUint64 = reflect.TypeOf(uint64(0))
|
||||
scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{})
|
||||
scanTypeUnknown = reflect.TypeOf(new(interface{}))
|
||||
)
|
||||
|
||||
type mysqlField struct {
|
||||
tableName string
|
||||
name string
|
||||
length uint32
|
||||
flags fieldFlag
|
||||
fieldType fieldType
|
||||
decimals byte
|
||||
charSet uint8
|
||||
}
|
||||
|
||||
func (mf *mysqlField) scanType() reflect.Type {
|
||||
switch mf.fieldType {
|
||||
case fieldTypeTiny:
|
||||
if mf.flags&flagNotNULL != 0 {
|
||||
if mf.flags&flagUnsigned != 0 {
|
||||
return scanTypeUint8
|
||||
}
|
||||
return scanTypeInt8
|
||||
}
|
||||
return scanTypeNullInt
|
||||
|
||||
case fieldTypeShort, fieldTypeYear:
|
||||
if mf.flags&flagNotNULL != 0 {
|
||||
if mf.flags&flagUnsigned != 0 {
|
||||
return scanTypeUint16
|
||||
}
|
||||
return scanTypeInt16
|
||||
}
|
||||
return scanTypeNullInt
|
||||
|
||||
case fieldTypeInt24, fieldTypeLong:
|
||||
if mf.flags&flagNotNULL != 0 {
|
||||
if mf.flags&flagUnsigned != 0 {
|
||||
return scanTypeUint32
|
||||
}
|
||||
return scanTypeInt32
|
||||
}
|
||||
return scanTypeNullInt
|
||||
|
||||
case fieldTypeLongLong:
|
||||
if mf.flags&flagNotNULL != 0 {
|
||||
if mf.flags&flagUnsigned != 0 {
|
||||
return scanTypeUint64
|
||||
}
|
||||
return scanTypeInt64
|
||||
}
|
||||
return scanTypeNullInt
|
||||
|
||||
case fieldTypeFloat:
|
||||
if mf.flags&flagNotNULL != 0 {
|
||||
return scanTypeFloat32
|
||||
}
|
||||
return scanTypeNullFloat
|
||||
|
||||
case fieldTypeDouble:
|
||||
if mf.flags&flagNotNULL != 0 {
|
||||
return scanTypeFloat64
|
||||
}
|
||||
return scanTypeNullFloat
|
||||
|
||||
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
|
||||
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
|
||||
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
|
||||
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON,
|
||||
fieldTypeTime:
|
||||
return scanTypeRawBytes
|
||||
|
||||
case fieldTypeDate, fieldTypeNewDate,
|
||||
fieldTypeTimestamp, fieldTypeDateTime:
|
||||
// NullTime is always returned for more consistent behavior as it can
|
||||
// handle both cases of parseTime regardless if the field is nullable.
|
||||
return scanTypeNullTime
|
||||
|
||||
default:
|
||||
return scanTypeUnknown
|
||||
}
|
||||
}
|
|
@ -147,7 +147,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
|||
}
|
||||
|
||||
// send content packets
|
||||
if err == nil {
|
||||
// if packetSize == 0, the Reader contains no data
|
||||
if err == nil && packetSize > 0 {
|
||||
data := make([]byte, 4+packetSize)
|
||||
var n int
|
||||
for err == nil {
|
||||
|
@ -173,8 +174,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
|
|||
|
||||
// read OK packet
|
||||
if err == nil {
|
||||
_, err = mc.readResultOK()
|
||||
return err
|
||||
return mc.readResultOK()
|
||||
}
|
||||
|
||||
mc.readPacket()
|
||||
|
|
|
@ -30,9 +30,12 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
|
|||
// read packet header
|
||||
data, err := mc.buf.readNext(4)
|
||||
if err != nil {
|
||||
if cerr := mc.canceled.Value(); cerr != nil {
|
||||
return nil, cerr
|
||||
}
|
||||
errLog.Print(err)
|
||||
mc.Close()
|
||||
return nil, driver.ErrBadConn
|
||||
return nil, ErrInvalidConn
|
||||
}
|
||||
|
||||
// packet length [24 bit]
|
||||
|
@ -54,7 +57,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
|
|||
if prevData == nil {
|
||||
errLog.Print(ErrMalformPkt)
|
||||
mc.Close()
|
||||
return nil, driver.ErrBadConn
|
||||
return nil, ErrInvalidConn
|
||||
}
|
||||
|
||||
return prevData, nil
|
||||
|
@ -63,9 +66,12 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
|
|||
// read packet body [pktLen bytes]
|
||||
data, err = mc.buf.readNext(pktLen)
|
||||
if err != nil {
|
||||
if cerr := mc.canceled.Value(); cerr != nil {
|
||||
return nil, cerr
|
||||
}
|
||||
errLog.Print(err)
|
||||
mc.Close()
|
||||
return nil, driver.ErrBadConn
|
||||
return nil, ErrInvalidConn
|
||||
}
|
||||
|
||||
// return data if this was the last packet
|
||||
|
@ -125,33 +131,47 @@ func (mc *mysqlConn) writePacket(data []byte) error {
|
|||
|
||||
// Handle error
|
||||
if err == nil { // n != len(data)
|
||||
mc.cleanup()
|
||||
errLog.Print(ErrMalformPkt)
|
||||
} else {
|
||||
if cerr := mc.canceled.Value(); cerr != nil {
|
||||
return cerr
|
||||
}
|
||||
if n == 0 && pktLen == len(data)-4 {
|
||||
// only for the first loop iteration when nothing was written yet
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
mc.cleanup()
|
||||
errLog.Print(err)
|
||||
}
|
||||
return driver.ErrBadConn
|
||||
return ErrInvalidConn
|
||||
}
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* Initialisation Process *
|
||||
* Initialization Process *
|
||||
******************************************************************************/
|
||||
|
||||
// Handshake Initialization Packet
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
|
||||
func (mc *mysqlConn) readInitPacket() ([]byte, error) {
|
||||
func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {
|
||||
data, err := mc.readPacket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
|
||||
// in connection initialization we don't risk retrying non-idempotent actions.
|
||||
if err == ErrInvalidConn {
|
||||
return nil, "", driver.ErrBadConn
|
||||
}
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if data[0] == iERR {
|
||||
return nil, mc.handleErrorPacket(data)
|
||||
return nil, "", mc.handleErrorPacket(data)
|
||||
}
|
||||
|
||||
// protocol version [1 byte]
|
||||
if data[0] < minProtocolVersion {
|
||||
return nil, fmt.Errorf(
|
||||
return nil, "", fmt.Errorf(
|
||||
"unsupported protocol version %d. Version %d or higher is required",
|
||||
data[0],
|
||||
minProtocolVersion,
|
||||
|
@ -163,7 +183,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
|
|||
pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4
|
||||
|
||||
// first part of the password cipher [8 bytes]
|
||||
cipher := data[pos : pos+8]
|
||||
authData := data[pos : pos+8]
|
||||
|
||||
// (filler) always 0x00 [1 byte]
|
||||
pos += 8 + 1
|
||||
|
@ -171,13 +191,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
|
|||
// capability flags (lower 2 bytes) [2 bytes]
|
||||
mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
||||
if mc.flags&clientProtocol41 == 0 {
|
||||
return nil, ErrOldProtocol
|
||||
return nil, "", ErrOldProtocol
|
||||
}
|
||||
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
|
||||
return nil, ErrNoTLS
|
||||
return nil, "", ErrNoTLS
|
||||
}
|
||||
pos += 2
|
||||
|
||||
plugin := ""
|
||||
if len(data) > pos {
|
||||
// character set [1 byte]
|
||||
// status flags [2 bytes]
|
||||
|
@ -198,32 +219,34 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
|
|||
//
|
||||
// The official Python library uses the fixed length 12
|
||||
// which seems to work but technically could have a hidden bug.
|
||||
cipher = append(cipher, data[pos:pos+12]...)
|
||||
authData = append(authData, data[pos:pos+12]...)
|
||||
pos += 13
|
||||
|
||||
// TODO: Verify string termination
|
||||
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
|
||||
// \NUL otherwise
|
||||
//
|
||||
//if data[len(data)-1] == 0 {
|
||||
// return
|
||||
//}
|
||||
//return ErrMalformPkt
|
||||
if end := bytes.IndexByte(data[pos:], 0x00); end != -1 {
|
||||
plugin = string(data[pos : pos+end])
|
||||
} else {
|
||||
plugin = string(data[pos:])
|
||||
}
|
||||
|
||||
// make a memory safe copy of the cipher slice
|
||||
var b [20]byte
|
||||
copy(b[:], cipher)
|
||||
return b[:], nil
|
||||
copy(b[:], authData)
|
||||
return b[:], plugin, nil
|
||||
}
|
||||
|
||||
plugin = defaultAuthPlugin
|
||||
|
||||
// make a memory safe copy of the cipher slice
|
||||
var b [8]byte
|
||||
copy(b[:], cipher)
|
||||
return b[:], nil
|
||||
copy(b[:], authData)
|
||||
return b[:], plugin, nil
|
||||
}
|
||||
|
||||
// Client Authentication Packet
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
||||
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
||||
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
|
||||
// Adjust client flags based on server support
|
||||
clientFlags := clientProtocol41 |
|
||||
clientSecureConn |
|
||||
|
@ -247,10 +270,19 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
|||
clientFlags |= clientMultiStatements
|
||||
}
|
||||
|
||||
// User Password
|
||||
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
|
||||
// encode length of the auth plugin data
|
||||
var authRespLEIBuf [9]byte
|
||||
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp)))
|
||||
if len(authRespLEI) > 1 {
|
||||
// if the length can not be written in 1 byte, it must be written as a
|
||||
// length encoded integer
|
||||
clientFlags |= clientPluginAuthLenEncClientData
|
||||
}
|
||||
|
||||
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
|
||||
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
|
||||
if addNUL {
|
||||
pktLen++
|
||||
}
|
||||
|
||||
// To specify a db name
|
||||
if n := len(mc.cfg.DBName); n > 0 {
|
||||
|
@ -261,9 +293,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
|||
// Calculate packet length and get buffer with that size
|
||||
data := mc.buf.takeSmallBuffer(pktLen + 4)
|
||||
if data == nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
return driver.ErrBadConn
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
// ClientFlags [32 bit]
|
||||
|
@ -318,9 +350,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
|||
data[pos] = 0x00
|
||||
pos++
|
||||
|
||||
// ScrambleBuffer [length encoded integer]
|
||||
data[pos] = byte(len(scrambleBuff))
|
||||
pos += 1 + copy(data[pos+1:], scrambleBuff)
|
||||
// Auth Data [length encoded integer]
|
||||
pos += copy(data[pos:], authRespLEI)
|
||||
pos += copy(data[pos:], authResp)
|
||||
if addNUL {
|
||||
data[pos] = 0x00
|
||||
pos++
|
||||
}
|
||||
|
||||
// Databasename [null terminated string]
|
||||
if len(mc.cfg.DBName) > 0 {
|
||||
|
@ -329,72 +365,32 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
|
|||
pos++
|
||||
}
|
||||
|
||||
// Assume native client during response
|
||||
pos += copy(data[pos:], "mysql_native_password")
|
||||
pos += copy(data[pos:], plugin)
|
||||
data[pos] = 0x00
|
||||
|
||||
// Send Auth packet
|
||||
return mc.writePacket(data)
|
||||
}
|
||||
|
||||
// Client old authentication packet
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
||||
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
|
||||
// User password
|
||||
scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd))
|
||||
|
||||
// Calculate the packet length and add a tailing 0
|
||||
pktLen := len(scrambleBuff) + 1
|
||||
data := mc.buf.takeSmallBuffer(4 + pktLen)
|
||||
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
|
||||
pktLen := 4 + len(authData)
|
||||
if addNUL {
|
||||
pktLen++
|
||||
}
|
||||
data := mc.buf.takeSmallBuffer(pktLen)
|
||||
if data == nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
return driver.ErrBadConn
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
// Add the scrambled password [null terminated string]
|
||||
copy(data[4:], scrambleBuff)
|
||||
data[4+pktLen-1] = 0x00
|
||||
|
||||
return mc.writePacket(data)
|
||||
}
|
||||
|
||||
// Client clear text authentication packet
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
||||
func (mc *mysqlConn) writeClearAuthPacket() error {
|
||||
// Calculate the packet length and add a tailing 0
|
||||
pktLen := len(mc.cfg.Passwd) + 1
|
||||
data := mc.buf.takeSmallBuffer(4 + pktLen)
|
||||
if data == nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
return driver.ErrBadConn
|
||||
// Add the auth data [EOF]
|
||||
copy(data[4:], authData)
|
||||
if addNUL {
|
||||
data[pktLen-1] = 0x00
|
||||
}
|
||||
|
||||
// Add the clear password [null terminated string]
|
||||
copy(data[4:], mc.cfg.Passwd)
|
||||
data[4+pktLen-1] = 0x00
|
||||
|
||||
return mc.writePacket(data)
|
||||
}
|
||||
|
||||
// Native password authentication method
|
||||
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
|
||||
func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error {
|
||||
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
|
||||
|
||||
// Calculate the packet length and add a tailing 0
|
||||
pktLen := len(scrambleBuff)
|
||||
data := mc.buf.takeSmallBuffer(4 + pktLen)
|
||||
if data == nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
// Add the scramble
|
||||
copy(data[4:], scrambleBuff)
|
||||
|
||||
return mc.writePacket(data)
|
||||
}
|
||||
|
||||
|
@ -408,9 +404,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
|
|||
|
||||
data := mc.buf.takeSmallBuffer(4 + 1)
|
||||
if data == nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
return driver.ErrBadConn
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
// Add command byte
|
||||
|
@ -427,9 +423,9 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
|
|||
pktLen := 1 + len(arg)
|
||||
data := mc.buf.takeBuffer(pktLen + 4)
|
||||
if data == nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
return driver.ErrBadConn
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
// Add command byte
|
||||
|
@ -448,9 +444,9 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
|
|||
|
||||
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
|
||||
if data == nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
return driver.ErrBadConn
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
// Add command byte
|
||||
|
@ -470,44 +466,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
|
|||
* Result Packets *
|
||||
******************************************************************************/
|
||||
|
||||
// Returns error if Packet is not an 'Result OK'-Packet
|
||||
func (mc *mysqlConn) readResultOK() ([]byte, error) {
|
||||
func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
|
||||
data, err := mc.readPacket()
|
||||
if err == nil {
|
||||
// packet indicator
|
||||
switch data[0] {
|
||||
|
||||
case iOK:
|
||||
return nil, mc.handleOkPacket(data)
|
||||
|
||||
case iEOF:
|
||||
if len(data) > 1 {
|
||||
pluginEndIndex := bytes.IndexByte(data, 0x00)
|
||||
plugin := string(data[1:pluginEndIndex])
|
||||
cipher := data[pluginEndIndex+1 : len(data)-1]
|
||||
|
||||
if plugin == "mysql_old_password" {
|
||||
// using old_passwords
|
||||
return cipher, ErrOldPassword
|
||||
} else if plugin == "mysql_clear_password" {
|
||||
// using clear text password
|
||||
return cipher, ErrCleartextPassword
|
||||
} else if plugin == "mysql_native_password" {
|
||||
// using mysql default authentication method
|
||||
return cipher, ErrNativePassword
|
||||
} else {
|
||||
return cipher, ErrUnknownPlugin
|
||||
}
|
||||
} else {
|
||||
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
|
||||
return nil, ErrOldPassword
|
||||
}
|
||||
|
||||
default: // Error otherwise
|
||||
return nil, mc.handleErrorPacket(data)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return nil, err
|
||||
|
||||
// packet indicator
|
||||
switch data[0] {
|
||||
|
||||
case iOK:
|
||||
return nil, "", mc.handleOkPacket(data)
|
||||
|
||||
case iAuthMoreData:
|
||||
return data[1:], "", err
|
||||
|
||||
case iEOF:
|
||||
if len(data) < 1 {
|
||||
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
|
||||
return nil, "mysql_old_password", nil
|
||||
}
|
||||
pluginEndIndex := bytes.IndexByte(data, 0x00)
|
||||
if pluginEndIndex < 0 {
|
||||
return nil, "", ErrMalformPkt
|
||||
}
|
||||
plugin := string(data[1:pluginEndIndex])
|
||||
authData := data[pluginEndIndex+1:]
|
||||
return authData, plugin, nil
|
||||
|
||||
default: // Error otherwise
|
||||
return nil, "", mc.handleErrorPacket(data)
|
||||
}
|
||||
}
|
||||
|
||||
// Returns error if Packet is not an 'Result OK'-Packet
|
||||
func (mc *mysqlConn) readResultOK() error {
|
||||
data, err := mc.readPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if data[0] == iOK {
|
||||
return mc.handleOkPacket(data)
|
||||
}
|
||||
return mc.handleErrorPacket(data)
|
||||
}
|
||||
|
||||
// Result Set Header Packet
|
||||
|
@ -550,6 +552,22 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
|
|||
// Error Number [16 bit uint]
|
||||
errno := binary.LittleEndian.Uint16(data[1:3])
|
||||
|
||||
// 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION
|
||||
// 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover)
|
||||
if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly {
|
||||
// Oops; we are connected to a read-only connection, and won't be able
|
||||
// to issue any write statements. Since RejectReadOnly is configured,
|
||||
// we throw away this connection hoping this one would have write
|
||||
// permission. This is specifically for a possible race condition
|
||||
// during failover (e.g. on AWS Aurora). See README.md for more.
|
||||
//
|
||||
// We explicitly close the connection before returning
|
||||
// driver.ErrBadConn to ensure that `database/sql` purges this
|
||||
// connection and initiates a new one for next statement next time.
|
||||
mc.Close()
|
||||
return driver.ErrBadConn
|
||||
}
|
||||
|
||||
pos := 3
|
||||
|
||||
// SQL State [optional: # + 5bytes string]
|
||||
|
@ -584,19 +602,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
|
|||
|
||||
// server_status [2 bytes]
|
||||
mc.status = readStatus(data[1+n+m : 1+n+m+2])
|
||||
if err := mc.discardResults(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// warning count [2 bytes]
|
||||
if !mc.strict {
|
||||
if mc.status&statusMoreResultsExists != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
pos := 1 + n + m + 2
|
||||
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
|
||||
return mc.getWarnings()
|
||||
}
|
||||
// warning count [2 bytes]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -668,14 +679,21 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pos += n
|
||||
|
||||
// Filler [uint8]
|
||||
pos++
|
||||
|
||||
// Charset [charset, collation uint8]
|
||||
columns[i].charSet = data[pos]
|
||||
pos += 2
|
||||
|
||||
// Length [uint32]
|
||||
pos += n + 1 + 2 + 4
|
||||
columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4])
|
||||
pos += 4
|
||||
|
||||
// Field type [uint8]
|
||||
columns[i].fieldType = data[pos]
|
||||
columns[i].fieldType = fieldType(data[pos])
|
||||
pos++
|
||||
|
||||
// Flags [uint16]
|
||||
|
@ -698,6 +716,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
|
|||
func (rows *textRows) readRow(dest []driver.Value) error {
|
||||
mc := rows.mc
|
||||
|
||||
if rows.rs.done {
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
data, err := mc.readPacket()
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -707,15 +729,11 @@ func (rows *textRows) readRow(dest []driver.Value) error {
|
|||
if data[0] == iEOF && len(data) == 5 {
|
||||
// server_status [2 bytes]
|
||||
rows.mc.status = readStatus(data[3:])
|
||||
err = rows.mc.discardResults()
|
||||
if err == nil {
|
||||
err = io.EOF
|
||||
} else {
|
||||
// connection unusable
|
||||
rows.mc.Close()
|
||||
rows.rs.done = true
|
||||
if !rows.HasNextResultSet() {
|
||||
rows.mc = nil
|
||||
}
|
||||
rows.mc = nil
|
||||
return err
|
||||
return io.EOF
|
||||
}
|
||||
if data[0] == iERR {
|
||||
rows.mc = nil
|
||||
|
@ -736,7 +754,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
|
|||
if !mc.parseTime {
|
||||
continue
|
||||
} else {
|
||||
switch rows.columns[i].fieldType {
|
||||
switch rows.rs.columns[i].fieldType {
|
||||
case fieldTypeTimestamp, fieldTypeDateTime,
|
||||
fieldTypeDate, fieldTypeNewDate:
|
||||
dest[i], err = parseDateTime(
|
||||
|
@ -808,14 +826,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
|
|||
// Reserved [8 bit]
|
||||
|
||||
// Warning count [16 bit uint]
|
||||
if !stmt.mc.strict {
|
||||
return columnCount, nil
|
||||
}
|
||||
|
||||
// Check for warnings count > 0, only available in MySQL > 4.1
|
||||
if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
|
||||
return columnCount, stmt.mc.getWarnings()
|
||||
}
|
||||
return columnCount, nil
|
||||
}
|
||||
return 0, err
|
||||
|
@ -832,7 +843,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
|
|||
// 2 bytes paramID
|
||||
const dataOffset = 1 + 4 + 2
|
||||
|
||||
// Can not use the write buffer since
|
||||
// Cannot use the write buffer since
|
||||
// a) the buffer is too small
|
||||
// b) it is in use
|
||||
data := make([]byte, 4+1+4+2+len(arg))
|
||||
|
@ -887,6 +898,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|||
const minPktLen = 4 + 1 + 4 + 1 + 4
|
||||
mc := stmt.mc
|
||||
|
||||
// Determine threshould dynamically to avoid packet size shortage.
|
||||
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
|
||||
if longDataSize < 64 {
|
||||
longDataSize = 64
|
||||
}
|
||||
|
||||
// Reset packet-sequence
|
||||
mc.sequence = 0
|
||||
|
||||
|
@ -898,9 +915,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|||
data = mc.buf.takeCompleteBuffer()
|
||||
}
|
||||
if data == nil {
|
||||
// can not take the buffer. Something must be wrong with the connection
|
||||
// cannot take the buffer. Something must be wrong with the connection
|
||||
errLog.Print(ErrBusyBuffer)
|
||||
return driver.ErrBadConn
|
||||
return errBadConnNoWrite
|
||||
}
|
||||
|
||||
// command [1 byte]
|
||||
|
@ -959,7 +976,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|||
// build NULL-bitmap
|
||||
if arg == nil {
|
||||
nullMask[i/8] |= 1 << (uint(i) & 7)
|
||||
paramTypes[i+i] = fieldTypeNULL
|
||||
paramTypes[i+i] = byte(fieldTypeNULL)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
continue
|
||||
}
|
||||
|
@ -967,7 +984,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|||
// cache types and values
|
||||
switch v := arg.(type) {
|
||||
case int64:
|
||||
paramTypes[i+i] = fieldTypeLongLong
|
||||
paramTypes[i+i] = byte(fieldTypeLongLong)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
||||
|
@ -983,7 +1000,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|||
}
|
||||
|
||||
case float64:
|
||||
paramTypes[i+i] = fieldTypeDouble
|
||||
paramTypes[i+i] = byte(fieldTypeDouble)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
if cap(paramValues)-len(paramValues)-8 >= 0 {
|
||||
|
@ -999,7 +1016,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|||
}
|
||||
|
||||
case bool:
|
||||
paramTypes[i+i] = fieldTypeTiny
|
||||
paramTypes[i+i] = byte(fieldTypeTiny)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
if v {
|
||||
|
@ -1011,10 +1028,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|||
case []byte:
|
||||
// Common case (non-nil value) first
|
||||
if v != nil {
|
||||
paramTypes[i+i] = fieldTypeString
|
||||
paramTypes[i+i] = byte(fieldTypeString)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
|
||||
if len(v) < longDataSize {
|
||||
paramValues = appendLengthEncodedInteger(paramValues,
|
||||
uint64(len(v)),
|
||||
)
|
||||
|
@ -1029,14 +1046,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|||
|
||||
// Handle []byte(nil) as a NULL value
|
||||
nullMask[i/8] |= 1 << (uint(i) & 7)
|
||||
paramTypes[i+i] = fieldTypeNULL
|
||||
paramTypes[i+i] = byte(fieldTypeNULL)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
case string:
|
||||
paramTypes[i+i] = fieldTypeString
|
||||
paramTypes[i+i] = byte(fieldTypeString)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 {
|
||||
if len(v) < longDataSize {
|
||||
paramValues = appendLengthEncodedInteger(paramValues,
|
||||
uint64(len(v)),
|
||||
)
|
||||
|
@ -1048,23 +1065,25 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
|
|||
}
|
||||
|
||||
case time.Time:
|
||||
paramTypes[i+i] = fieldTypeString
|
||||
paramTypes[i+i] = byte(fieldTypeString)
|
||||
paramTypes[i+i+1] = 0x00
|
||||
|
||||
var val []byte
|
||||
var a [64]byte
|
||||
var b = a[:0]
|
||||
|
||||
if v.IsZero() {
|
||||
val = []byte("0000-00-00")
|
||||
b = append(b, "0000-00-00"...)
|
||||
} else {
|
||||
val = []byte(v.In(mc.cfg.Loc).Format(timeFormat))
|
||||
b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat)
|
||||
}
|
||||
|
||||
paramValues = appendLengthEncodedInteger(paramValues,
|
||||
uint64(len(val)),
|
||||
uint64(len(b)),
|
||||
)
|
||||
paramValues = append(paramValues, val...)
|
||||
paramValues = append(paramValues, b...)
|
||||
|
||||
default:
|
||||
return fmt.Errorf("can not convert type: %T", arg)
|
||||
return fmt.Errorf("cannot convert type: %T", arg)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1097,8 +1116,6 @@ func (mc *mysqlConn) discardResults() error {
|
|||
if err := mc.readUntilEOF(); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
mc.status &^= statusMoreResultsExists
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -1116,20 +1133,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|||
// EOF Packet
|
||||
if data[0] == iEOF && len(data) == 5 {
|
||||
rows.mc.status = readStatus(data[3:])
|
||||
err = rows.mc.discardResults()
|
||||
if err == nil {
|
||||
err = io.EOF
|
||||
} else {
|
||||
// connection unusable
|
||||
rows.mc.Close()
|
||||
rows.rs.done = true
|
||||
if !rows.HasNextResultSet() {
|
||||
rows.mc = nil
|
||||
}
|
||||
rows.mc = nil
|
||||
return err
|
||||
return io.EOF
|
||||
}
|
||||
mc := rows.mc
|
||||
rows.mc = nil
|
||||
|
||||
// Error otherwise
|
||||
return rows.mc.handleErrorPacket(data)
|
||||
return mc.handleErrorPacket(data)
|
||||
}
|
||||
|
||||
// NULL-bitmap, [(column-count + 7 + 2) / 8 bytes]
|
||||
|
@ -1145,14 +1159,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|||
}
|
||||
|
||||
// Convert to byte-coded string
|
||||
switch rows.columns[i].fieldType {
|
||||
switch rows.rs.columns[i].fieldType {
|
||||
case fieldTypeNULL:
|
||||
dest[i] = nil
|
||||
continue
|
||||
|
||||
// Numeric Types
|
||||
case fieldTypeTiny:
|
||||
if rows.columns[i].flags&flagUnsigned != 0 {
|
||||
if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
||||
dest[i] = int64(data[pos])
|
||||
} else {
|
||||
dest[i] = int64(int8(data[pos]))
|
||||
|
@ -1161,7 +1175,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|||
continue
|
||||
|
||||
case fieldTypeShort, fieldTypeYear:
|
||||
if rows.columns[i].flags&flagUnsigned != 0 {
|
||||
if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
||||
dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2]))
|
||||
} else {
|
||||
dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2])))
|
||||
|
@ -1170,7 +1184,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|||
continue
|
||||
|
||||
case fieldTypeInt24, fieldTypeLong:
|
||||
if rows.columns[i].flags&flagUnsigned != 0 {
|
||||
if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
||||
dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
||||
} else {
|
||||
dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4])))
|
||||
|
@ -1179,7 +1193,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|||
continue
|
||||
|
||||
case fieldTypeLongLong:
|
||||
if rows.columns[i].flags&flagUnsigned != 0 {
|
||||
if rows.rs.columns[i].flags&flagUnsigned != 0 {
|
||||
val := binary.LittleEndian.Uint64(data[pos : pos+8])
|
||||
if val > math.MaxInt64 {
|
||||
dest[i] = uint64ToString(val)
|
||||
|
@ -1193,7 +1207,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|||
continue
|
||||
|
||||
case fieldTypeFloat:
|
||||
dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
|
||||
dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))
|
||||
pos += 4
|
||||
continue
|
||||
|
||||
|
@ -1233,10 +1247,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|||
case isNull:
|
||||
dest[i] = nil
|
||||
continue
|
||||
case rows.columns[i].fieldType == fieldTypeTime:
|
||||
case rows.rs.columns[i].fieldType == fieldTypeTime:
|
||||
// database/sql does not support an equivalent to TIME, return a string
|
||||
var dstlen uint8
|
||||
switch decimals := rows.columns[i].decimals; decimals {
|
||||
switch decimals := rows.rs.columns[i].decimals; decimals {
|
||||
case 0x00, 0x1f:
|
||||
dstlen = 8
|
||||
case 1, 2, 3, 4, 5, 6:
|
||||
|
@ -1244,7 +1258,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|||
default:
|
||||
return fmt.Errorf(
|
||||
"protocol error, illegal decimals value %d",
|
||||
rows.columns[i].decimals,
|
||||
rows.rs.columns[i].decimals,
|
||||
)
|
||||
}
|
||||
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
|
||||
|
@ -1252,10 +1266,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|||
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
|
||||
default:
|
||||
var dstlen uint8
|
||||
if rows.columns[i].fieldType == fieldTypeDate {
|
||||
if rows.rs.columns[i].fieldType == fieldTypeDate {
|
||||
dstlen = 10
|
||||
} else {
|
||||
switch decimals := rows.columns[i].decimals; decimals {
|
||||
switch decimals := rows.rs.columns[i].decimals; decimals {
|
||||
case 0x00, 0x1f:
|
||||
dstlen = 19
|
||||
case 1, 2, 3, 4, 5, 6:
|
||||
|
@ -1263,7 +1277,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|||
default:
|
||||
return fmt.Errorf(
|
||||
"protocol error, illegal decimals value %d",
|
||||
rows.columns[i].decimals,
|
||||
rows.rs.columns[i].decimals,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
@ -1279,7 +1293,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
|
|||
|
||||
// Please report if this happens!
|
||||
default:
|
||||
return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
|
||||
return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"bytes"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
@ -24,16 +24,17 @@ var (
|
|||
|
||||
// struct to mock a net.Conn for testing purposes
|
||||
type mockConn struct {
|
||||
laddr net.Addr
|
||||
raddr net.Addr
|
||||
data []byte
|
||||
closed bool
|
||||
read int
|
||||
written int
|
||||
reads int
|
||||
writes int
|
||||
maxReads int
|
||||
maxWrites int
|
||||
laddr net.Addr
|
||||
raddr net.Addr
|
||||
data []byte
|
||||
written []byte
|
||||
queuedReplies [][]byte
|
||||
closed bool
|
||||
read int
|
||||
reads int
|
||||
writes int
|
||||
maxReads int
|
||||
maxWrites int
|
||||
}
|
||||
|
||||
func (m *mockConn) Read(b []byte) (n int, err error) {
|
||||
|
@ -62,7 +63,12 @@ func (m *mockConn) Write(b []byte) (n int, err error) {
|
|||
}
|
||||
|
||||
n = len(b)
|
||||
m.written += n
|
||||
m.written = append(m.written, b...)
|
||||
|
||||
if n > 0 && len(m.queuedReplies) > 0 {
|
||||
m.data = m.queuedReplies[0]
|
||||
m.queuedReplies = m.queuedReplies[1:]
|
||||
}
|
||||
return
|
||||
}
|
||||
func (m *mockConn) Close() error {
|
||||
|
@ -88,6 +94,19 @@ func (m *mockConn) SetWriteDeadline(t time.Time) error {
|
|||
// make sure mockConn implements the net.Conn interface
|
||||
var _ net.Conn = new(mockConn)
|
||||
|
||||
func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
|
||||
conn := new(mockConn)
|
||||
mc := &mysqlConn{
|
||||
buf: newBuffer(conn),
|
||||
cfg: NewConfig(),
|
||||
netConn: conn,
|
||||
closech: make(chan struct{}),
|
||||
maxAllowedPacket: defaultMaxAllowedPacket,
|
||||
sequence: sequence,
|
||||
}
|
||||
return conn, mc
|
||||
}
|
||||
|
||||
func TestReadPacketSingleByte(t *testing.T) {
|
||||
conn := new(mockConn)
|
||||
mc := &mysqlConn{
|
||||
|
@ -101,7 +120,7 @@ func TestReadPacketSingleByte(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
if len(packet) != 1 {
|
||||
t.Fatalf("unexpected packet lenght: expected %d, got %d", 1, len(packet))
|
||||
t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet))
|
||||
}
|
||||
if packet[0] != 0xff {
|
||||
t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0])
|
||||
|
@ -171,7 +190,7 @@ func TestReadPacketSplit(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
if len(packet) != maxPacketSize {
|
||||
t.Fatalf("unexpected packet lenght: expected %d, got %d", maxPacketSize, len(packet))
|
||||
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet))
|
||||
}
|
||||
if packet[0] != 0x11 {
|
||||
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
|
||||
|
@ -205,7 +224,7 @@ func TestReadPacketSplit(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
if len(packet) != 2*maxPacketSize {
|
||||
t.Fatalf("unexpected packet lenght: expected %d, got %d", 2*maxPacketSize, len(packet))
|
||||
t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet))
|
||||
}
|
||||
if packet[0] != 0x11 {
|
||||
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
|
||||
|
@ -231,7 +250,7 @@ func TestReadPacketSplit(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
if len(packet) != maxPacketSize+42 {
|
||||
t.Fatalf("unexpected packet lenght: expected %d, got %d", maxPacketSize+42, len(packet))
|
||||
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet))
|
||||
}
|
||||
if packet[0] != 0x11 {
|
||||
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
|
||||
|
@ -244,15 +263,16 @@ func TestReadPacketSplit(t *testing.T) {
|
|||
func TestReadPacketFail(t *testing.T) {
|
||||
conn := new(mockConn)
|
||||
mc := &mysqlConn{
|
||||
buf: newBuffer(conn),
|
||||
buf: newBuffer(conn),
|
||||
closech: make(chan struct{}),
|
||||
}
|
||||
|
||||
// illegal empty (stand-alone) packet
|
||||
conn.data = []byte{0x00, 0x00, 0x00, 0x00}
|
||||
conn.maxReads = 1
|
||||
_, err := mc.readPacket()
|
||||
if err != driver.ErrBadConn {
|
||||
t.Errorf("expected ErrBadConn, got %v", err)
|
||||
if err != ErrInvalidConn {
|
||||
t.Errorf("expected ErrInvalidConn, got %v", err)
|
||||
}
|
||||
|
||||
// reset
|
||||
|
@ -263,8 +283,8 @@ func TestReadPacketFail(t *testing.T) {
|
|||
// fail to read header
|
||||
conn.closed = true
|
||||
_, err = mc.readPacket()
|
||||
if err != driver.ErrBadConn {
|
||||
t.Errorf("expected ErrBadConn, got %v", err)
|
||||
if err != ErrInvalidConn {
|
||||
t.Errorf("expected ErrInvalidConn, got %v", err)
|
||||
}
|
||||
|
||||
// reset
|
||||
|
@ -276,7 +296,41 @@ func TestReadPacketFail(t *testing.T) {
|
|||
// fail to read body
|
||||
conn.maxReads = 1
|
||||
_, err = mc.readPacket()
|
||||
if err != driver.ErrBadConn {
|
||||
t.Errorf("expected ErrBadConn, got %v", err)
|
||||
if err != ErrInvalidConn {
|
||||
t.Errorf("expected ErrInvalidConn, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/go-sql-driver/mysql/pull/801
|
||||
// not-NUL terminated plugin_name in init packet
|
||||
func TestRegression801(t *testing.T) {
|
||||
conn := new(mockConn)
|
||||
mc := &mysqlConn{
|
||||
buf: newBuffer(conn),
|
||||
cfg: new(Config),
|
||||
sequence: 42,
|
||||
closech: make(chan struct{}),
|
||||
}
|
||||
|
||||
conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,
|
||||
60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0,
|
||||
0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77,
|
||||
50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95,
|
||||
112, 97, 115, 115, 119, 111, 114, 100}
|
||||
conn.maxReads = 1
|
||||
|
||||
authData, pluginName, err := mc.readHandshakePacket()
|
||||
if err != nil {
|
||||
t.Fatalf("got error: %v", err)
|
||||
}
|
||||
|
||||
if pluginName != "mysql_native_password" {
|
||||
t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName)
|
||||
}
|
||||
|
||||
expectedAuthData := []byte{60, 70, 63, 58, 68, 104, 34, 97, 98, 120, 114,
|
||||
47, 85, 75, 109, 99, 51, 77, 50, 64}
|
||||
if !bytes.Equal(authData, expectedAuthData) {
|
||||
t.Errorf("expected authData '%v', got '%v'", expectedAuthData, authData)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,19 +11,20 @@ package mysql
|
|||
import (
|
||||
"database/sql/driver"
|
||||
"io"
|
||||
"math"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type mysqlField struct {
|
||||
tableName string
|
||||
name string
|
||||
flags fieldFlag
|
||||
fieldType byte
|
||||
decimals byte
|
||||
type resultSet struct {
|
||||
columns []mysqlField
|
||||
columnNames []string
|
||||
done bool
|
||||
}
|
||||
|
||||
type mysqlRows struct {
|
||||
mc *mysqlConn
|
||||
columns []mysqlField
|
||||
mc *mysqlConn
|
||||
rs resultSet
|
||||
finish func()
|
||||
}
|
||||
|
||||
type binaryRows struct {
|
||||
|
@ -34,37 +35,86 @@ type textRows struct {
|
|||
mysqlRows
|
||||
}
|
||||
|
||||
type emptyRows struct{}
|
||||
|
||||
func (rows *mysqlRows) Columns() []string {
|
||||
columns := make([]string, len(rows.columns))
|
||||
if rows.rs.columnNames != nil {
|
||||
return rows.rs.columnNames
|
||||
}
|
||||
|
||||
columns := make([]string, len(rows.rs.columns))
|
||||
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
|
||||
for i := range columns {
|
||||
if tableName := rows.columns[i].tableName; len(tableName) > 0 {
|
||||
columns[i] = tableName + "." + rows.columns[i].name
|
||||
if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 {
|
||||
columns[i] = tableName + "." + rows.rs.columns[i].name
|
||||
} else {
|
||||
columns[i] = rows.columns[i].name
|
||||
columns[i] = rows.rs.columns[i].name
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for i := range columns {
|
||||
columns[i] = rows.columns[i].name
|
||||
columns[i] = rows.rs.columns[i].name
|
||||
}
|
||||
}
|
||||
|
||||
rows.rs.columnNames = columns
|
||||
return columns
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) Close() error {
|
||||
func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string {
|
||||
return rows.rs.columns[i].typeDatabaseName()
|
||||
}
|
||||
|
||||
// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) {
|
||||
// return int64(rows.rs.columns[i].length), true
|
||||
// }
|
||||
|
||||
func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) {
|
||||
return rows.rs.columns[i].flags&flagNotNULL == 0, true
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) {
|
||||
column := rows.rs.columns[i]
|
||||
decimals := int64(column.decimals)
|
||||
|
||||
switch column.fieldType {
|
||||
case fieldTypeDecimal, fieldTypeNewDecimal:
|
||||
if decimals > 0 {
|
||||
return int64(column.length) - 2, decimals, true
|
||||
}
|
||||
return int64(column.length) - 1, decimals, true
|
||||
case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime:
|
||||
return decimals, decimals, true
|
||||
case fieldTypeFloat, fieldTypeDouble:
|
||||
if decimals == 0x1f {
|
||||
return math.MaxInt64, math.MaxInt64, true
|
||||
}
|
||||
return math.MaxInt64, decimals, true
|
||||
}
|
||||
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type {
|
||||
return rows.rs.columns[i].scanType()
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) Close() (err error) {
|
||||
if f := rows.finish; f != nil {
|
||||
f()
|
||||
rows.finish = nil
|
||||
}
|
||||
|
||||
mc := rows.mc
|
||||
if mc == nil {
|
||||
return nil
|
||||
}
|
||||
if mc.netConn == nil {
|
||||
return ErrInvalidConn
|
||||
if err := mc.error(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Remove unread packets from stream
|
||||
err := mc.readUntilEOF()
|
||||
if !rows.rs.done {
|
||||
err = mc.readUntilEOF()
|
||||
}
|
||||
if err == nil {
|
||||
if err = mc.discardResults(); err != nil {
|
||||
return err
|
||||
|
@ -75,10 +125,66 @@ func (rows *mysqlRows) Close() error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) HasNextResultSet() (b bool) {
|
||||
if rows.mc == nil {
|
||||
return false
|
||||
}
|
||||
return rows.mc.status&statusMoreResultsExists != 0
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) nextResultSet() (int, error) {
|
||||
if rows.mc == nil {
|
||||
return 0, io.EOF
|
||||
}
|
||||
if err := rows.mc.error(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Remove unread packets from stream
|
||||
if !rows.rs.done {
|
||||
if err := rows.mc.readUntilEOF(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
rows.rs.done = true
|
||||
}
|
||||
|
||||
if !rows.HasNextResultSet() {
|
||||
rows.mc = nil
|
||||
return 0, io.EOF
|
||||
}
|
||||
rows.rs = resultSet{}
|
||||
return rows.mc.readResultSetHeaderPacket()
|
||||
}
|
||||
|
||||
func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) {
|
||||
for {
|
||||
resLen, err := rows.nextResultSet()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if resLen > 0 {
|
||||
return resLen, nil
|
||||
}
|
||||
|
||||
rows.rs.done = true
|
||||
}
|
||||
}
|
||||
|
||||
func (rows *binaryRows) NextResultSet() error {
|
||||
resLen, err := rows.nextNotEmptyResultSet()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows.rs.columns, err = rows.mc.readColumns(resLen)
|
||||
return err
|
||||
}
|
||||
|
||||
func (rows *binaryRows) Next(dest []driver.Value) error {
|
||||
if mc := rows.mc; mc != nil {
|
||||
if mc.netConn == nil {
|
||||
return ErrInvalidConn
|
||||
if err := mc.error(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Fetch next row from stream
|
||||
|
@ -87,10 +193,20 @@ func (rows *binaryRows) Next(dest []driver.Value) error {
|
|||
return io.EOF
|
||||
}
|
||||
|
||||
func (rows *textRows) NextResultSet() (err error) {
|
||||
resLen, err := rows.nextNotEmptyResultSet()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rows.rs.columns, err = rows.mc.readColumns(resLen)
|
||||
return err
|
||||
}
|
||||
|
||||
func (rows *textRows) Next(dest []driver.Value) error {
|
||||
if mc := rows.mc; mc != nil {
|
||||
if mc.netConn == nil {
|
||||
return ErrInvalidConn
|
||||
if err := mc.error(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Fetch next row from stream
|
||||
|
@ -98,15 +214,3 @@ func (rows *textRows) Next(dest []driver.Value) error {
|
|||
}
|
||||
return io.EOF
|
||||
}
|
||||
|
||||
func (rows emptyRows) Columns() []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rows emptyRows) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rows emptyRows) Next(dest []driver.Value) error {
|
||||
return io.EOF
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ package mysql
|
|||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"strconv"
|
||||
)
|
||||
|
@ -19,11 +20,10 @@ type mysqlStmt struct {
|
|||
mc *mysqlConn
|
||||
id uint32
|
||||
paramCount int
|
||||
columns []mysqlField // cached from the first query
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Close() error {
|
||||
if stmt.mc == nil || stmt.mc.netConn == nil {
|
||||
if stmt.mc == nil || stmt.mc.closed.IsSet() {
|
||||
// driver.Stmt.Close can be called more than once, thus this function
|
||||
// has to be idempotent.
|
||||
// See also Issue #450 and golang/go#16019.
|
||||
|
@ -45,14 +45,14 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter {
|
|||
}
|
||||
|
||||
func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
|
||||
if stmt.mc.netConn == nil {
|
||||
if stmt.mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
// Send command
|
||||
err := stmt.writeExecutePacket(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, stmt.mc.markBadConn(err)
|
||||
}
|
||||
|
||||
mc := stmt.mc
|
||||
|
@ -62,37 +62,45 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) {
|
|||
|
||||
// Read Result
|
||||
resLen, err := mc.readResultSetHeaderPacket()
|
||||
if err == nil {
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
err = mc.readUntilEOF()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Rows
|
||||
err = mc.readUntilEOF()
|
||||
if resLen > 0 {
|
||||
// Columns
|
||||
if err = mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err == nil {
|
||||
return &mysqlResult{
|
||||
affectedRows: int64(mc.affectedRows),
|
||||
insertId: int64(mc.insertId),
|
||||
}, nil
|
||||
|
||||
// Rows
|
||||
if err := mc.readUntilEOF(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return nil, err
|
||||
if err := mc.discardResults(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &mysqlResult{
|
||||
affectedRows: int64(mc.affectedRows),
|
||||
insertId: int64(mc.insertId),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
|
||||
if stmt.mc.netConn == nil {
|
||||
return stmt.query(args)
|
||||
}
|
||||
|
||||
func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) {
|
||||
if stmt.mc.closed.IsSet() {
|
||||
errLog.Print(ErrInvalidConn)
|
||||
return nil, driver.ErrBadConn
|
||||
}
|
||||
// Send command
|
||||
err := stmt.writeExecutePacket(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, stmt.mc.markBadConn(err)
|
||||
}
|
||||
|
||||
mc := stmt.mc
|
||||
|
@ -107,14 +115,15 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
|
|||
|
||||
if resLen > 0 {
|
||||
rows.mc = mc
|
||||
// Columns
|
||||
// If not cached, read them and cache them
|
||||
if stmt.columns == nil {
|
||||
rows.columns, err = mc.readColumns(resLen)
|
||||
stmt.columns = rows.columns
|
||||
} else {
|
||||
rows.columns = stmt.columns
|
||||
err = mc.readUntilEOF()
|
||||
rows.rs.columns, err = mc.readColumns(resLen)
|
||||
} else {
|
||||
rows.rs.done = true
|
||||
|
||||
switch err := rows.NextResultSet(); err {
|
||||
case nil, io.EOF:
|
||||
return rows, nil
|
||||
default:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -123,19 +132,36 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
|
|||
|
||||
type converter struct{}
|
||||
|
||||
// ConvertValue mirrors the reference/default converter in database/sql/driver
|
||||
// with _one_ exception. We support uint64 with their high bit and the default
|
||||
// implementation does not. This function should be kept in sync with
|
||||
// database/sql/driver defaultConverter.ConvertValue() except for that
|
||||
// deliberate difference.
|
||||
func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
|
||||
if driver.IsValue(v) {
|
||||
return v, nil
|
||||
}
|
||||
|
||||
if vr, ok := v.(driver.Valuer); ok {
|
||||
sv, err := callValuerValue(vr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !driver.IsValue(sv) {
|
||||
return nil, fmt.Errorf("non-Value type %T returned from Value", sv)
|
||||
}
|
||||
return sv, nil
|
||||
}
|
||||
|
||||
rv := reflect.ValueOf(v)
|
||||
switch rv.Kind() {
|
||||
case reflect.Ptr:
|
||||
// indirect pointers
|
||||
if rv.IsNil() {
|
||||
return nil, nil
|
||||
} else {
|
||||
return c.ConvertValue(rv.Elem().Interface())
|
||||
}
|
||||
return c.ConvertValue(rv.Elem().Interface())
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
return rv.Int(), nil
|
||||
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
|
||||
|
@ -148,6 +174,38 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
|
|||
return int64(u64), nil
|
||||
case reflect.Float32, reflect.Float64:
|
||||
return rv.Float(), nil
|
||||
case reflect.Bool:
|
||||
return rv.Bool(), nil
|
||||
case reflect.Slice:
|
||||
ek := rv.Type().Elem().Kind()
|
||||
if ek == reflect.Uint8 {
|
||||
return rv.Bytes(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek)
|
||||
case reflect.String:
|
||||
return rv.String(), nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind())
|
||||
}
|
||||
|
||||
var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem()
|
||||
|
||||
// callValuerValue returns vr.Value(), with one exception:
|
||||
// If vr.Value is an auto-generated method on a pointer type and the
|
||||
// pointer is nil, it would panic at runtime in the panicwrap
|
||||
// method. Treat it like nil instead.
|
||||
//
|
||||
// This is so people can implement driver.Value on value types and
|
||||
// still use nil pointers to those types to mean nil/NULL, just like
|
||||
// string/*string.
|
||||
//
|
||||
// This is an exact copy of the same-named unexported function from the
|
||||
// database/sql package.
|
||||
func callValuerValue(vr driver.Valuer) (v driver.Value, err error) {
|
||||
if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr &&
|
||||
rv.IsNil() &&
|
||||
rv.Type().Elem().Implements(valuerReflectType) {
|
||||
return nil, nil
|
||||
}
|
||||
return vr.Value()
|
||||
}
|
||||
|
|
|
@ -0,0 +1,126 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConvertDerivedString(t *testing.T) {
|
||||
type derived string
|
||||
|
||||
output, err := converter{}.ConvertValue(derived("value"))
|
||||
if err != nil {
|
||||
t.Fatal("Derived string type not convertible", err)
|
||||
}
|
||||
|
||||
if output != "value" {
|
||||
t.Fatalf("Derived string type not converted, got %#v %T", output, output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertDerivedByteSlice(t *testing.T) {
|
||||
type derived []uint8
|
||||
|
||||
output, err := converter{}.ConvertValue(derived("value"))
|
||||
if err != nil {
|
||||
t.Fatal("Byte slice not convertible", err)
|
||||
}
|
||||
|
||||
if bytes.Compare(output.([]byte), []byte("value")) != 0 {
|
||||
t.Fatalf("Byte slice not converted, got %#v %T", output, output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertDerivedUnsupportedSlice(t *testing.T) {
|
||||
type derived []int
|
||||
|
||||
_, err := converter{}.ConvertValue(derived{1})
|
||||
if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" {
|
||||
t.Fatal("Unexpected error", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertDerivedBool(t *testing.T) {
|
||||
type derived bool
|
||||
|
||||
output, err := converter{}.ConvertValue(derived(true))
|
||||
if err != nil {
|
||||
t.Fatal("Derived bool type not convertible", err)
|
||||
}
|
||||
|
||||
if output != true {
|
||||
t.Fatalf("Derived bool type not converted, got %#v %T", output, output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertPointer(t *testing.T) {
|
||||
str := "value"
|
||||
|
||||
output, err := converter{}.ConvertValue(&str)
|
||||
if err != nil {
|
||||
t.Fatal("Pointer type not convertible", err)
|
||||
}
|
||||
|
||||
if output != "value" {
|
||||
t.Fatalf("Pointer type not converted, got %#v %T", output, output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertSignedIntegers(t *testing.T) {
|
||||
values := []interface{}{
|
||||
int8(-42),
|
||||
int16(-42),
|
||||
int32(-42),
|
||||
int64(-42),
|
||||
int(-42),
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
output, err := converter{}.ConvertValue(value)
|
||||
if err != nil {
|
||||
t.Fatalf("%T type not convertible %s", value, err)
|
||||
}
|
||||
|
||||
if output != int64(-42) {
|
||||
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertUnsignedIntegers(t *testing.T) {
|
||||
values := []interface{}{
|
||||
uint8(42),
|
||||
uint16(42),
|
||||
uint32(42),
|
||||
uint64(42),
|
||||
uint(42),
|
||||
}
|
||||
|
||||
for _, value := range values {
|
||||
output, err := converter{}.ConvertValue(value)
|
||||
if err != nil {
|
||||
t.Fatalf("%T type not convertible %s", value, err)
|
||||
}
|
||||
|
||||
if output != int64(42) {
|
||||
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
|
||||
}
|
||||
}
|
||||
|
||||
output, err := converter{}.ConvertValue(^uint64(0))
|
||||
if err != nil {
|
||||
t.Fatal("uint64 high-bit not convertible", err)
|
||||
}
|
||||
|
||||
if output != "18446744073709551615" {
|
||||
t.Fatalf("uint64 high-bit not converted, got %#v %T", output, output)
|
||||
}
|
||||
}
|
|
@ -13,7 +13,7 @@ type mysqlTx struct {
|
|||
}
|
||||
|
||||
func (tx *mysqlTx) Commit() (err error) {
|
||||
if tx.mc == nil || tx.mc.netConn == nil {
|
||||
if tx.mc == nil || tx.mc.closed.IsSet() {
|
||||
return ErrInvalidConn
|
||||
}
|
||||
err = tx.mc.exec("COMMIT")
|
||||
|
@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) {
|
|||
}
|
||||
|
||||
func (tx *mysqlTx) Rollback() (err error) {
|
||||
if tx.mc == nil || tx.mc.netConn == nil {
|
||||
if tx.mc == nil || tx.mc.closed.IsSet() {
|
||||
return ErrInvalidConn
|
||||
}
|
||||
err = tx.mc.exec("ROLLBACK")
|
||||
|
|
|
@ -9,23 +9,29 @@
|
|||
package mysql
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"crypto/tls"
|
||||
"database/sql/driver"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Registry for custom tls.Configs
|
||||
var (
|
||||
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
|
||||
tlsConfigLock sync.RWMutex
|
||||
tlsConfigRegistry map[string]*tls.Config
|
||||
)
|
||||
|
||||
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
|
||||
// Use the key as a value in the DSN where tls=value.
|
||||
//
|
||||
// Note: The provided tls.Config is exclusively owned by the driver after
|
||||
// registering it.
|
||||
//
|
||||
// rootCertPool := x509.NewCertPool()
|
||||
// pem, err := ioutil.ReadFile("/path/ca-cert.pem")
|
||||
// if err != nil {
|
||||
|
@ -51,19 +57,32 @@ func RegisterTLSConfig(key string, config *tls.Config) error {
|
|||
return fmt.Errorf("key '%s' is reserved", key)
|
||||
}
|
||||
|
||||
if tlsConfigRegister == nil {
|
||||
tlsConfigRegister = make(map[string]*tls.Config)
|
||||
tlsConfigLock.Lock()
|
||||
if tlsConfigRegistry == nil {
|
||||
tlsConfigRegistry = make(map[string]*tls.Config)
|
||||
}
|
||||
|
||||
tlsConfigRegister[key] = config
|
||||
tlsConfigRegistry[key] = config
|
||||
tlsConfigLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeregisterTLSConfig removes the tls.Config associated with key.
|
||||
func DeregisterTLSConfig(key string) {
|
||||
if tlsConfigRegister != nil {
|
||||
delete(tlsConfigRegister, key)
|
||||
tlsConfigLock.Lock()
|
||||
if tlsConfigRegistry != nil {
|
||||
delete(tlsConfigRegistry, key)
|
||||
}
|
||||
tlsConfigLock.Unlock()
|
||||
}
|
||||
|
||||
func getTLSConfigClone(key string) (config *tls.Config) {
|
||||
tlsConfigLock.RLock()
|
||||
if v, ok := tlsConfigRegistry[key]; ok {
|
||||
config = cloneTLSConfig(v)
|
||||
}
|
||||
tlsConfigLock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Returns the bool value of the input.
|
||||
|
@ -80,119 +99,6 @@ func readBool(input string) (value bool, valid bool) {
|
|||
return
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* Authentication *
|
||||
******************************************************************************/
|
||||
|
||||
// Encrypt password using 4.1+ method
|
||||
func scramblePassword(scramble, password []byte) []byte {
|
||||
if len(password) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// stage1Hash = SHA1(password)
|
||||
crypt := sha1.New()
|
||||
crypt.Write(password)
|
||||
stage1 := crypt.Sum(nil)
|
||||
|
||||
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
|
||||
// inner Hash
|
||||
crypt.Reset()
|
||||
crypt.Write(stage1)
|
||||
hash := crypt.Sum(nil)
|
||||
|
||||
// outer Hash
|
||||
crypt.Reset()
|
||||
crypt.Write(scramble)
|
||||
crypt.Write(hash)
|
||||
scramble = crypt.Sum(nil)
|
||||
|
||||
// token = scrambleHash XOR stage1Hash
|
||||
for i := range scramble {
|
||||
scramble[i] ^= stage1[i]
|
||||
}
|
||||
return scramble
|
||||
}
|
||||
|
||||
// Encrypt password using pre 4.1 (old password) method
|
||||
// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c
|
||||
type myRnd struct {
|
||||
seed1, seed2 uint32
|
||||
}
|
||||
|
||||
const myRndMaxVal = 0x3FFFFFFF
|
||||
|
||||
// Pseudo random number generator
|
||||
func newMyRnd(seed1, seed2 uint32) *myRnd {
|
||||
return &myRnd{
|
||||
seed1: seed1 % myRndMaxVal,
|
||||
seed2: seed2 % myRndMaxVal,
|
||||
}
|
||||
}
|
||||
|
||||
// Tested to be equivalent to MariaDB's floating point variant
|
||||
// http://play.golang.org/p/QHvhd4qved
|
||||
// http://play.golang.org/p/RG0q4ElWDx
|
||||
func (r *myRnd) NextByte() byte {
|
||||
r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal
|
||||
r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal
|
||||
|
||||
return byte(uint64(r.seed1) * 31 / myRndMaxVal)
|
||||
}
|
||||
|
||||
// Generate binary hash from byte string using insecure pre 4.1 method
|
||||
func pwHash(password []byte) (result [2]uint32) {
|
||||
var add uint32 = 7
|
||||
var tmp uint32
|
||||
|
||||
result[0] = 1345345333
|
||||
result[1] = 0x12345671
|
||||
|
||||
for _, c := range password {
|
||||
// skip spaces and tabs in password
|
||||
if c == ' ' || c == '\t' {
|
||||
continue
|
||||
}
|
||||
|
||||
tmp = uint32(c)
|
||||
result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8)
|
||||
result[1] += (result[1] << 8) ^ result[0]
|
||||
add += tmp
|
||||
}
|
||||
|
||||
// Remove sign bit (1<<31)-1)
|
||||
result[0] &= 0x7FFFFFFF
|
||||
result[1] &= 0x7FFFFFFF
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Encrypt password using insecure pre 4.1 method
|
||||
func scrambleOldPassword(scramble, password []byte) []byte {
|
||||
if len(password) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
scramble = scramble[:8]
|
||||
|
||||
hashPw := pwHash(password)
|
||||
hashSc := pwHash(scramble)
|
||||
|
||||
r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1])
|
||||
|
||||
var out [8]byte
|
||||
for i := range out {
|
||||
out[i] = r.NextByte() + 64
|
||||
}
|
||||
|
||||
mask := r.NextByte()
|
||||
for i := range out {
|
||||
out[i] ^= mask
|
||||
}
|
||||
|
||||
return out[:]
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* Time related utils *
|
||||
******************************************************************************/
|
||||
|
@ -519,7 +425,7 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) {
|
|||
|
||||
// Check data length
|
||||
if len(b) >= n {
|
||||
return b[n-int(num) : n], false, n, nil
|
||||
return b[n-int(num) : n : n], false, n, nil
|
||||
}
|
||||
return nil, false, n, io.EOF
|
||||
}
|
||||
|
@ -548,8 +454,8 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) {
|
|||
if len(b) == 0 {
|
||||
return 0, true, 1
|
||||
}
|
||||
switch b[0] {
|
||||
|
||||
switch b[0] {
|
||||
// 251: NULL
|
||||
case 0xfb:
|
||||
return 0, true, 1
|
||||
|
@ -738,3 +644,67 @@ func escapeStringQuotes(buf []byte, v string) []byte {
|
|||
|
||||
return buf[:pos]
|
||||
}
|
||||
|
||||
/******************************************************************************
|
||||
* Sync utils *
|
||||
******************************************************************************/
|
||||
|
||||
// noCopy may be embedded into structs which must not be copied
|
||||
// after the first use.
|
||||
//
|
||||
// See https://github.com/golang/go/issues/8005#issuecomment-190753527
|
||||
// for details.
|
||||
type noCopy struct{}
|
||||
|
||||
// Lock is a no-op used by -copylocks checker from `go vet`.
|
||||
func (*noCopy) Lock() {}
|
||||
|
||||
// atomicBool is a wrapper around uint32 for usage as a boolean value with
|
||||
// atomic access.
|
||||
type atomicBool struct {
|
||||
_noCopy noCopy
|
||||
value uint32
|
||||
}
|
||||
|
||||
// IsSet returns wether the current boolean value is true
|
||||
func (ab *atomicBool) IsSet() bool {
|
||||
return atomic.LoadUint32(&ab.value) > 0
|
||||
}
|
||||
|
||||
// Set sets the value of the bool regardless of the previous value
|
||||
func (ab *atomicBool) Set(value bool) {
|
||||
if value {
|
||||
atomic.StoreUint32(&ab.value, 1)
|
||||
} else {
|
||||
atomic.StoreUint32(&ab.value, 0)
|
||||
}
|
||||
}
|
||||
|
||||
// TrySet sets the value of the bool and returns wether the value changed
|
||||
func (ab *atomicBool) TrySet(value bool) bool {
|
||||
if value {
|
||||
return atomic.SwapUint32(&ab.value, 1) == 0
|
||||
}
|
||||
return atomic.SwapUint32(&ab.value, 0) > 0
|
||||
}
|
||||
|
||||
// atomicError is a wrapper for atomically accessed error values
|
||||
type atomicError struct {
|
||||
_noCopy noCopy
|
||||
value atomic.Value
|
||||
}
|
||||
|
||||
// Set sets the error value regardless of the previous value.
|
||||
// The value must not be nil
|
||||
func (ae *atomicError) Set(value error) {
|
||||
ae.value.Store(value)
|
||||
}
|
||||
|
||||
// Value returns the current error value
|
||||
func (ae *atomicError) Value() error {
|
||||
if v := ae.value.Load(); v != nil {
|
||||
// this will panic if the value doesn't implement the error interface
|
||||
return v.(error)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build go1.7
|
||||
// +build !go1.8
|
||||
|
||||
package mysql
|
||||
|
||||
import "crypto/tls"
|
||||
|
||||
func cloneTLSConfig(c *tls.Config) *tls.Config {
|
||||
return &tls.Config{
|
||||
Rand: c.Rand,
|
||||
Time: c.Time,
|
||||
Certificates: c.Certificates,
|
||||
NameToCertificate: c.NameToCertificate,
|
||||
GetCertificate: c.GetCertificate,
|
||||
RootCAs: c.RootCAs,
|
||||
NextProtos: c.NextProtos,
|
||||
ServerName: c.ServerName,
|
||||
ClientAuth: c.ClientAuth,
|
||||
ClientCAs: c.ClientCAs,
|
||||
InsecureSkipVerify: c.InsecureSkipVerify,
|
||||
CipherSuites: c.CipherSuites,
|
||||
PreferServerCipherSuites: c.PreferServerCipherSuites,
|
||||
SessionTicketsDisabled: c.SessionTicketsDisabled,
|
||||
SessionTicketKey: c.SessionTicketKey,
|
||||
ClientSessionCache: c.ClientSessionCache,
|
||||
MinVersion: c.MinVersion,
|
||||
MaxVersion: c.MaxVersion,
|
||||
CurvePreferences: c.CurvePreferences,
|
||||
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
|
||||
Renegotiation: c.Renegotiation,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,50 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build go1.8
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
func cloneTLSConfig(c *tls.Config) *tls.Config {
|
||||
return c.Clone()
|
||||
}
|
||||
|
||||
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
|
||||
dargs := make([]driver.Value, len(named))
|
||||
for n, param := range named {
|
||||
if len(param.Name) > 0 {
|
||||
// TODO: support the use of Named Parameters #561
|
||||
return nil, errors.New("mysql: driver does not support the use of Named Parameters")
|
||||
}
|
||||
dargs[n] = param.Value
|
||||
}
|
||||
return dargs, nil
|
||||
}
|
||||
|
||||
func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
|
||||
switch sql.IsolationLevel(level) {
|
||||
case sql.LevelRepeatableRead:
|
||||
return "REPEATABLE READ", nil
|
||||
case sql.LevelReadCommitted:
|
||||
return "READ COMMITTED", nil
|
||||
case sql.LevelReadUncommitted:
|
||||
return "READ UNCOMMITTED", nil
|
||||
case sql.LevelSerializable:
|
||||
return "SERIALIZABLE", nil
|
||||
default:
|
||||
return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
|
||||
//
|
||||
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
|
||||
// You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
// +build go1.8
|
||||
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestIsolationLevelMapping(t *testing.T) {
|
||||
data := []struct {
|
||||
level driver.IsolationLevel
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
level: driver.IsolationLevel(sql.LevelReadCommitted),
|
||||
expected: "READ COMMITTED",
|
||||
},
|
||||
{
|
||||
level: driver.IsolationLevel(sql.LevelRepeatableRead),
|
||||
expected: "REPEATABLE READ",
|
||||
},
|
||||
{
|
||||
level: driver.IsolationLevel(sql.LevelReadUncommitted),
|
||||
expected: "READ UNCOMMITTED",
|
||||
},
|
||||
{
|
||||
level: driver.IsolationLevel(sql.LevelSerializable),
|
||||
expected: "SERIALIZABLE",
|
||||
},
|
||||
}
|
||||
|
||||
for i, td := range data {
|
||||
if actual, err := mapIsolationLevel(td.level); actual != td.expected || err != nil {
|
||||
t.Fatal(i, td.expected, actual, err)
|
||||
}
|
||||
}
|
||||
|
||||
// check unsupported mapping
|
||||
expectedErr := "mysql: unsupported isolation level: 7"
|
||||
actual, err := mapIsolationLevel(driver.IsolationLevel(sql.LevelLinearizable))
|
||||
if actual != "" || err == nil {
|
||||
t.Fatal("Expected error on unsupported isolation level")
|
||||
}
|
||||
if err.Error() != expectedErr {
|
||||
t.Fatalf("Expected error to be %q, got %q", expectedErr, err)
|
||||
}
|
||||
}
|
|
@ -11,7 +11,6 @@ package mysql
|
|||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
@ -93,25 +92,6 @@ func TestLengthEncodedInteger(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestOldPass(t *testing.T) {
|
||||
scramble := []byte{9, 8, 7, 6, 5, 4, 3, 2}
|
||||
vectors := []struct {
|
||||
pass string
|
||||
out string
|
||||
}{
|
||||
{" pass", "47575c5a435b4251"},
|
||||
{"pass ", "47575c5a435b4251"},
|
||||
{"123\t456", "575c47505b5b5559"},
|
||||
{"C0mpl!ca ted#PASS123", "5d5d554849584a45"},
|
||||
}
|
||||
for _, tuple := range vectors {
|
||||
ours := scrambleOldPassword(scramble, []byte(tuple.pass))
|
||||
if tuple.out != fmt.Sprintf("%x", ours) {
|
||||
t.Errorf("Failed old password %q", tuple.pass)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatBinaryDateTime(t *testing.T) {
|
||||
rawDate := [11]byte{}
|
||||
binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years
|
||||
|
@ -195,3 +175,83 @@ func TestEscapeQuotes(t *testing.T) {
|
|||
expect("foo''bar", "foo'bar") // affected
|
||||
expect("foo\"bar", "foo\"bar") // not affected
|
||||
}
|
||||
|
||||
func TestAtomicBool(t *testing.T) {
|
||||
var ab atomicBool
|
||||
if ab.IsSet() {
|
||||
t.Fatal("Expected value to be false")
|
||||
}
|
||||
|
||||
ab.Set(true)
|
||||
if ab.value != 1 {
|
||||
t.Fatal("Set(true) did not set value to 1")
|
||||
}
|
||||
if !ab.IsSet() {
|
||||
t.Fatal("Expected value to be true")
|
||||
}
|
||||
|
||||
ab.Set(true)
|
||||
if !ab.IsSet() {
|
||||
t.Fatal("Expected value to be true")
|
||||
}
|
||||
|
||||
ab.Set(false)
|
||||
if ab.value != 0 {
|
||||
t.Fatal("Set(false) did not set value to 0")
|
||||
}
|
||||
if ab.IsSet() {
|
||||
t.Fatal("Expected value to be false")
|
||||
}
|
||||
|
||||
ab.Set(false)
|
||||
if ab.IsSet() {
|
||||
t.Fatal("Expected value to be false")
|
||||
}
|
||||
if ab.TrySet(false) {
|
||||
t.Fatal("Expected TrySet(false) to fail")
|
||||
}
|
||||
if !ab.TrySet(true) {
|
||||
t.Fatal("Expected TrySet(true) to succeed")
|
||||
}
|
||||
if !ab.IsSet() {
|
||||
t.Fatal("Expected value to be true")
|
||||
}
|
||||
|
||||
ab.Set(true)
|
||||
if !ab.IsSet() {
|
||||
t.Fatal("Expected value to be true")
|
||||
}
|
||||
if ab.TrySet(true) {
|
||||
t.Fatal("Expected TrySet(true) to fail")
|
||||
}
|
||||
if !ab.TrySet(false) {
|
||||
t.Fatal("Expected TrySet(false) to succeed")
|
||||
}
|
||||
if ab.IsSet() {
|
||||
t.Fatal("Expected value to be false")
|
||||
}
|
||||
|
||||
ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯
|
||||
}
|
||||
|
||||
func TestAtomicError(t *testing.T) {
|
||||
var ae atomicError
|
||||
if ae.Value() != nil {
|
||||
t.Fatal("Expected value to be nil")
|
||||
}
|
||||
|
||||
ae.Set(ErrMalformPkt)
|
||||
if v := ae.Value(); v != ErrMalformPkt {
|
||||
if v == nil {
|
||||
t.Fatal("Value is still nil")
|
||||
}
|
||||
t.Fatal("Error did not match")
|
||||
}
|
||||
ae.Set(ErrPktSync)
|
||||
if ae.Value() == ErrMalformPkt {
|
||||
t.Fatal("Error still matches old error")
|
||||
}
|
||||
if v := ae.Value(); v != ErrPktSync {
|
||||
t.Fatal("Error did not match")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,8 +1,15 @@
|
|||
language: go
|
||||
|
||||
os:
|
||||
- linux
|
||||
- osx
|
||||
|
||||
go:
|
||||
- 1.6.x
|
||||
- 1.x
|
||||
- 1.7.x
|
||||
- 1.8.x
|
||||
- 1.9.x
|
||||
- 1.10.x
|
||||
|
||||
install:
|
||||
# go-flags
|
||||
|
@ -10,8 +17,7 @@ install:
|
|||
- go build -v ./...
|
||||
|
||||
# linting
|
||||
- go get github.com/golang/lint
|
||||
- go install github.com/golang/lint/golint
|
||||
- go get github.com/golang/lint/golint
|
||||
|
||||
# code coverage
|
||||
- go get golang.org/x/tools/cmd/cover
|
||||
|
|
|
@ -110,7 +110,6 @@ args, err := flags.ParseArgs(&opts, args)
|
|||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
fmt.Printf("Verbosity: %v\n", opts.Verbose)
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Command represents an application command. Commands can be added to the
|
||||
|
@ -229,7 +228,17 @@ func (c *Command) scanSubcommandHandler(parentg *Group) scanHandler {
|
|||
subcommand := mtag.Get("command")
|
||||
|
||||
if len(subcommand) != 0 {
|
||||
ptrval := reflect.NewAt(realval.Type(), unsafe.Pointer(realval.UnsafeAddr()))
|
||||
var ptrval reflect.Value
|
||||
|
||||
if realval.Kind() == reflect.Ptr {
|
||||
ptrval = realval
|
||||
|
||||
if ptrval.IsNil() {
|
||||
ptrval.Set(reflect.New(ptrval.Type().Elem()))
|
||||
}
|
||||
} else {
|
||||
ptrval = realval.Addr()
|
||||
}
|
||||
|
||||
shortDescription := mtag.Get("description")
|
||||
longDescription := mtag.Get("long-description")
|
||||
|
@ -237,6 +246,7 @@ func (c *Command) scanSubcommandHandler(parentg *Group) scanHandler {
|
|||
aliases := mtag.GetMany("alias")
|
||||
|
||||
subc, err := c.AddCommand(subcommand, shortDescription, longDescription, ptrval.Interface())
|
||||
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"reflect"
|
||||
"strings"
|
||||
"unicode/utf8"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// ErrNotPointerToStruct indicates that a provided data container is not
|
||||
|
@ -338,10 +337,22 @@ func (g *Group) scanSubGroupHandler(realval reflect.Value, sfield *reflect.Struc
|
|||
subgroup := mtag.Get("group")
|
||||
|
||||
if len(subgroup) != 0 {
|
||||
ptrval := reflect.NewAt(realval.Type(), unsafe.Pointer(realval.UnsafeAddr()))
|
||||
var ptrval reflect.Value
|
||||
|
||||
if realval.Kind() == reflect.Ptr {
|
||||
ptrval = realval
|
||||
|
||||
if ptrval.IsNil() {
|
||||
ptrval.Set(reflect.New(ptrval.Type()))
|
||||
}
|
||||
} else {
|
||||
ptrval = realval.Addr()
|
||||
}
|
||||
|
||||
description := mtag.Get("description")
|
||||
|
||||
group, err := g.AddGroup(subgroup, description, ptrval.Interface())
|
||||
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
|
|
|
@ -3,9 +3,9 @@ package flags
|
|||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
|
@ -261,9 +261,7 @@ func (option *Option) clearDefault() {
|
|||
usedDefault := option.Default
|
||||
|
||||
if envKey := option.EnvDefaultKey; envKey != "" {
|
||||
// os.Getenv() makes no distinction between undefined and
|
||||
// empty values, so we use syscall.Getenv()
|
||||
if value, ok := syscall.Getenv(envKey); ok {
|
||||
if value, ok := os.LookupEnv(envKey); ok {
|
||||
if option.EnvDefaultDelim != "" {
|
||||
usedDefault = strings.Split(value,
|
||||
option.EnvDefaultDelim)
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// +build !windows,!plan9,!solaris
|
||||
// +build !windows,!plan9,!solaris,!appengine
|
||||
|
||||
package flags
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
// +build windows plan9 solaris
|
||||
// +build windows plan9 solaris appengine
|
||||
|
||||
package flags
|
||||
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
language: go
|
||||
|
||||
go:
|
||||
- 1.6.3
|
||||
- 1.7.1
|
||||
|
||||
install:
|
||||
- go get -v -t -d google.golang.org/appengine/...
|
||||
- mkdir sdk
|
||||
- curl -o sdk.zip "https://storage.googleapis.com/appengine-sdks/featured/go_appengine_sdk_linux_amd64-1.9.40.zip"
|
||||
- unzip -q sdk.zip -d sdk
|
||||
- export APPENGINE_DEV_APPSERVER=$(pwd)/sdk/go_appengine/dev_appserver.py
|
||||
|
||||
script:
|
||||
- go version
|
||||
- go test -v google.golang.org/appengine/...
|
||||
- go test -v -race google.golang.org/appengine/...
|
||||
- sdk/go_appengine/goapp test -v google.golang.org/appengine/...
|
|
@ -0,0 +1,202 @@
|
|||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,73 @@
|
|||
# Go App Engine packages
|
||||
|
||||
[![Build Status](https://travis-ci.org/golang/appengine.svg)](https://travis-ci.org/golang/appengine)
|
||||
|
||||
This repository supports the Go runtime on App Engine,
|
||||
including both the standard App Engine and the
|
||||
"App Engine flexible environment" (formerly known as "Managed VMs").
|
||||
It provides APIs for interacting with App Engine services.
|
||||
Its canonical import path is `google.golang.org/appengine`.
|
||||
|
||||
See https://cloud.google.com/appengine/docs/go/
|
||||
for more information.
|
||||
|
||||
File issue reports and feature requests on the [Google App Engine issue
|
||||
tracker](https://code.google.com/p/googleappengine/issues/entry?template=Go%20defect).
|
||||
|
||||
## Directory structure
|
||||
The top level directory of this repository is the `appengine` package. It
|
||||
contains the
|
||||
basic APIs (e.g. `appengine.NewContext`) that apply across APIs. Specific API
|
||||
packages are in subdirectories (e.g. `datastore`).
|
||||
|
||||
There is an `internal` subdirectory that contains service protocol buffers,
|
||||
plus packages required for connectivity to make API calls. App Engine apps
|
||||
should not directly import any package under `internal`.
|
||||
|
||||
## Updating a Go App Engine app
|
||||
|
||||
This section describes how to update an older Go App Engine app to use
|
||||
these packages. A provided tool, `aefix`, can help automate steps 2 and 3
|
||||
(run `go get google.golang.org/appengine/cmd/aefix` to install it), but
|
||||
read the details below since `aefix` can't perform all the changes.
|
||||
|
||||
### 1. Update YAML files (App Engine flexible environment / Managed VMs only)
|
||||
|
||||
The `app.yaml` file (and YAML files for modules) should have these new lines added:
|
||||
```
|
||||
vm: true
|
||||
```
|
||||
See https://cloud.google.com/appengine/docs/go/modules/#Go_Instance_scaling_and_class for details.
|
||||
|
||||
### 2. Update import paths
|
||||
|
||||
The import paths for App Engine packages are now fully qualified, based at `google.golang.org/appengine`.
|
||||
You will need to update your code to use import paths starting with that; for instance,
|
||||
code importing `appengine/datastore` will now need to import `google.golang.org/appengine/datastore`.
|
||||
|
||||
### 3. Update code using deprecated, removed or modified APIs
|
||||
|
||||
Most App Engine services are available with exactly the same API.
|
||||
A few APIs were cleaned up, and some are not available yet.
|
||||
This list summarises the differences:
|
||||
|
||||
* `appengine.Context` has been replaced with the `Context` type from `golang.org/x/net/context`.
|
||||
* Logging methods that were on `appengine.Context` are now functions in `google.golang.org/appengine/log`.
|
||||
* `appengine.Timeout` has been removed. Use `context.WithTimeout` instead.
|
||||
* `appengine.Datacenter` now takes a `context.Context` argument.
|
||||
* `datastore.PropertyLoadSaver` has been simplified to use slices in place of channels.
|
||||
* `delay.Call` now returns an error.
|
||||
* `search.FieldLoadSaver` now handles document metadata.
|
||||
* `urlfetch.Transport` no longer has a Deadline field; set a deadline on the
|
||||
`context.Context` instead.
|
||||
* `aetest` no longer declares its own Context type, and uses the standard one instead.
|
||||
* `taskqueue.QueueStats` no longer takes a maxTasks argument. That argument has been
|
||||
deprecated and unused for a long time.
|
||||
* `appengine.BackendHostname` and `appengine.BackendInstance` were for the deprecated backends feature.
|
||||
Use `appengine.ModuleHostname`and `appengine.ModuleName` instead.
|
||||
* Most of `appengine/file` and parts of `appengine/blobstore` are deprecated.
|
||||
Use [Google Cloud Storage](https://godoc.org/cloud.google.com/go/storage) if the
|
||||
feature you require is not present in the new
|
||||
[blobstore package](https://google.golang.org/appengine/blobstore).
|
||||
* `appengine/socket` is not required on App Engine flexible environment / Managed VMs.
|
||||
Use the standard `net` package instead.
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
Package aetest provides an API for running dev_appserver for use in tests.
|
||||
|
||||
An example test file:
|
||||
|
||||
package foo_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"google.golang.org/appengine/memcache"
|
||||
"google.golang.org/appengine/aetest"
|
||||
)
|
||||
|
||||
func TestFoo(t *testing.T) {
|
||||
ctx, done, err := aetest.NewContext()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer done()
|
||||
|
||||
it := &memcache.Item{
|
||||
Key: "some-key",
|
||||
Value: []byte("some-value"),
|
||||
}
|
||||
err = memcache.Set(ctx, it)
|
||||
if err != nil {
|
||||
t.Fatalf("Set err: %v", err)
|
||||
}
|
||||
it, err = memcache.Get(ctx, "some-key")
|
||||
if err != nil {
|
||||
t.Fatalf("Get err: %v; want no error", err)
|
||||
}
|
||||
if g, w := string(it.Value), "some-value" ; g != w {
|
||||
t.Errorf("retrieved Item.Value = %q, want %q", g, w)
|
||||
}
|
||||
}
|
||||
|
||||
The environment variable APPENGINE_DEV_APPSERVER specifies the location of the
|
||||
dev_appserver.py executable to use. If unset, the system PATH is consulted.
|
||||
*/
|
||||
package aetest
|
|
@ -0,0 +1,51 @@
|
|||
package aetest
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/appengine"
|
||||
)
|
||||
|
||||
// Instance represents a running instance of the development API Server.
|
||||
type Instance interface {
|
||||
// Close kills the child api_server.py process, releasing its resources.
|
||||
io.Closer
|
||||
// NewRequest returns an *http.Request associated with this instance.
|
||||
NewRequest(method, urlStr string, body io.Reader) (*http.Request, error)
|
||||
}
|
||||
|
||||
// Options is used to specify options when creating an Instance.
|
||||
type Options struct {
|
||||
// AppID specifies the App ID to use during tests.
|
||||
// By default, "testapp".
|
||||
AppID string
|
||||
// StronglyConsistentDatastore is whether the local datastore should be
|
||||
// strongly consistent. This will diverge from production behaviour.
|
||||
StronglyConsistentDatastore bool
|
||||
}
|
||||
|
||||
// NewContext starts an instance of the development API server, and returns
|
||||
// a context that will route all API calls to that server, as well as a
|
||||
// closure that must be called when the Context is no longer required.
|
||||
func NewContext() (context.Context, func(), error) {
|
||||
inst, err := NewInstance(nil)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
req, err := inst.NewRequest("GET", "/", nil)
|
||||
if err != nil {
|
||||
inst.Close()
|
||||
return nil, nil, err
|
||||
}
|
||||
ctx := appengine.NewContext(req)
|
||||
return ctx, func() {
|
||||
inst.Close()
|
||||
}, nil
|
||||
}
|
||||
|
||||
// PrepareDevAppserver is a hook which, if set, will be called before the
|
||||
// dev_appserver.py is started, each time it is started. If aetest.NewContext
|
||||
// is invoked from the goapp test tool, this hook is unnecessary.
|
||||
var PrepareDevAppserver func() error
|
|
@ -0,0 +1,21 @@
|
|||
// +build appengine
|
||||
|
||||
package aetest
|
||||
|
||||
import "appengine/aetest"
|
||||
|
||||
// NewInstance launches a running instance of api_server.py which can be used
|
||||
// for multiple test Contexts that delegate all App Engine API calls to that
|
||||
// instance.
|
||||
// If opts is nil the default values are used.
|
||||
func NewInstance(opts *Options) (Instance, error) {
|
||||
aetest.PrepareDevAppserver = PrepareDevAppserver
|
||||
var aeOpts *aetest.Options
|
||||
if opts != nil {
|
||||
aeOpts = &aetest.Options{
|
||||
AppID: opts.AppID,
|
||||
StronglyConsistentDatastore: opts.StronglyConsistentDatastore,
|
||||
}
|
||||
}
|
||||
return aetest.NewInstance(aeOpts)
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
package aetest
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/appengine"
|
||||
"google.golang.org/appengine/datastore"
|
||||
"google.golang.org/appengine/memcache"
|
||||
"google.golang.org/appengine/user"
|
||||
)
|
||||
|
||||
func TestBasicAPICalls(t *testing.T) {
|
||||
// Only run the test if APPENGINE_DEV_APPSERVER is explicitly set.
|
||||
if os.Getenv("APPENGINE_DEV_APPSERVER") == "" {
|
||||
t.Skip("APPENGINE_DEV_APPSERVER not set")
|
||||
}
|
||||
|
||||
inst, err := NewInstance(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewInstance: %v", err)
|
||||
}
|
||||
defer inst.Close()
|
||||
|
||||
req, err := inst.NewRequest("GET", "http://example.com/page", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
ctx := appengine.NewContext(req)
|
||||
|
||||
it := &memcache.Item{
|
||||
Key: "some-key",
|
||||
Value: []byte("some-value"),
|
||||
}
|
||||
err = memcache.Set(ctx, it)
|
||||
if err != nil {
|
||||
t.Fatalf("Set err: %v", err)
|
||||
}
|
||||
it, err = memcache.Get(ctx, "some-key")
|
||||
if err != nil {
|
||||
t.Fatalf("Get err: %v; want no error", err)
|
||||
}
|
||||
if g, w := string(it.Value), "some-value"; g != w {
|
||||
t.Errorf("retrieved Item.Value = %q, want %q", g, w)
|
||||
}
|
||||
|
||||
type Entity struct{ Value string }
|
||||
e := &Entity{Value: "foo"}
|
||||
k := datastore.NewIncompleteKey(ctx, "Entity", nil)
|
||||
k, err = datastore.Put(ctx, k, e)
|
||||
if err != nil {
|
||||
t.Fatalf("datastore.Put: %v", err)
|
||||
}
|
||||
e = new(Entity)
|
||||
if err := datastore.Get(ctx, k, e); err != nil {
|
||||
t.Fatalf("datastore.Get: %v", err)
|
||||
}
|
||||
if g, w := e.Value, "foo"; g != w {
|
||||
t.Errorf("retrieved Entity.Value = %q, want %q", g, w)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContext(t *testing.T) {
|
||||
// Only run the test if APPENGINE_DEV_APPSERVER is explicitly set.
|
||||
if os.Getenv("APPENGINE_DEV_APPSERVER") == "" {
|
||||
t.Skip("APPENGINE_DEV_APPSERVER not set")
|
||||
}
|
||||
|
||||
// Check that the context methods work.
|
||||
_, done, err := NewContext()
|
||||
if err != nil {
|
||||
t.Fatalf("NewContext: %v", err)
|
||||
}
|
||||
done()
|
||||
}
|
||||
|
||||
func TestUsers(t *testing.T) {
|
||||
// Only run the test if APPENGINE_DEV_APPSERVER is explicitly set.
|
||||
if os.Getenv("APPENGINE_DEV_APPSERVER") == "" {
|
||||
t.Skip("APPENGINE_DEV_APPSERVER not set")
|
||||
}
|
||||
|
||||
inst, err := NewInstance(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewInstance: %v", err)
|
||||
}
|
||||
defer inst.Close()
|
||||
|
||||
req, err := inst.NewRequest("GET", "http://example.com/page", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("NewRequest: %v", err)
|
||||
}
|
||||
ctx := appengine.NewContext(req)
|
||||
|
||||
if user := user.Current(ctx); user != nil {
|
||||
t.Errorf("user.Current initially %v, want nil", user)
|
||||
}
|
||||
|
||||
u := &user.User{
|
||||
Email: "gopher@example.com",
|
||||
Admin: true,
|
||||
}
|
||||
Login(u, req)
|
||||
|
||||
if got := user.Current(ctx); got.Email != u.Email {
|
||||
t.Errorf("user.Current: %v, want %v", got, u)
|
||||
}
|
||||
if admin := user.IsAdmin(ctx); !admin {
|
||||
t.Errorf("user.IsAdmin: %t, want true", admin)
|
||||
}
|
||||
|
||||
Logout(req)
|
||||
if user := user.Current(ctx); user != nil {
|
||||
t.Errorf("user.Current after logout %v, want nil", user)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,276 @@
|
|||
// +build !appengine
|
||||
|
||||
package aetest
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/appengine/internal"
|
||||
)
|
||||
|
||||
// NewInstance launches a running instance of api_server.py which can be used
|
||||
// for multiple test Contexts that delegate all App Engine API calls to that
|
||||
// instance.
|
||||
// If opts is nil the default values are used.
|
||||
func NewInstance(opts *Options) (Instance, error) {
|
||||
i := &instance{
|
||||
opts: opts,
|
||||
appID: "testapp",
|
||||
}
|
||||
if opts != nil && opts.AppID != "" {
|
||||
i.appID = opts.AppID
|
||||
}
|
||||
if err := i.startChild(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func newSessionID() string {
|
||||
var buf [16]byte
|
||||
io.ReadFull(rand.Reader, buf[:])
|
||||
return fmt.Sprintf("%x", buf[:])
|
||||
}
|
||||
|
||||
// instance implements the Instance interface.
|
||||
type instance struct {
|
||||
opts *Options
|
||||
child *exec.Cmd
|
||||
apiURL *url.URL // base URL of API HTTP server
|
||||
adminURL string // base URL of admin HTTP server
|
||||
appDir string
|
||||
appID string
|
||||
relFuncs []func() // funcs to release any associated contexts
|
||||
}
|
||||
|
||||
// NewRequest returns an *http.Request associated with this instance.
|
||||
func (i *instance) NewRequest(method, urlStr string, body io.Reader) (*http.Request, error) {
|
||||
req, err := http.NewRequest(method, urlStr, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Associate this request.
|
||||
release := internal.RegisterTestRequest(req, i.apiURL, func(ctx context.Context) context.Context {
|
||||
ctx = internal.WithAppIDOverride(ctx, "dev~"+i.appID)
|
||||
return ctx
|
||||
})
|
||||
i.relFuncs = append(i.relFuncs, release)
|
||||
|
||||
return req, nil
|
||||
}
|
||||
|
||||
// Close kills the child api_server.py process, releasing its resources.
|
||||
func (i *instance) Close() (err error) {
|
||||
for _, rel := range i.relFuncs {
|
||||
rel()
|
||||
}
|
||||
i.relFuncs = nil
|
||||
if i.child == nil {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
i.child = nil
|
||||
err1 := os.RemoveAll(i.appDir)
|
||||
if err == nil {
|
||||
err = err1
|
||||
}
|
||||
}()
|
||||
|
||||
if p := i.child.Process; p != nil {
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
errc <- i.child.Wait()
|
||||
}()
|
||||
|
||||
// Call the quit handler on the admin server.
|
||||
res, err := http.Get(i.adminURL + "/quit")
|
||||
if err != nil {
|
||||
p.Kill()
|
||||
return fmt.Errorf("unable to call /quit handler: %v", err)
|
||||
}
|
||||
res.Body.Close()
|
||||
|
||||
select {
|
||||
case <-time.After(15 * time.Second):
|
||||
p.Kill()
|
||||
return errors.New("timeout killing child process")
|
||||
case err = <-errc:
|
||||
// Do nothing.
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func findPython() (path string, err error) {
|
||||
for _, name := range []string{"python2.7", "python"} {
|
||||
path, err = exec.LookPath(name)
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func findDevAppserver() (string, error) {
|
||||
if p := os.Getenv("APPENGINE_DEV_APPSERVER"); p != "" {
|
||||
if fileExists(p) {
|
||||
return p, nil
|
||||
}
|
||||
return "", fmt.Errorf("invalid APPENGINE_DEV_APPSERVER environment variable; path %q doesn't exist", p)
|
||||
}
|
||||
return exec.LookPath("dev_appserver.py")
|
||||
}
|
||||
|
||||
var apiServerAddrRE = regexp.MustCompile(`Starting API server at: (\S+)`)
|
||||
var adminServerAddrRE = regexp.MustCompile(`Starting admin server at: (\S+)`)
|
||||
|
||||
func (i *instance) startChild() (err error) {
|
||||
if PrepareDevAppserver != nil {
|
||||
if err := PrepareDevAppserver(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
python, err := findPython()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not find python interpreter: %v", err)
|
||||
}
|
||||
devAppserver, err := findDevAppserver()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Could not find dev_appserver.py: %v", err)
|
||||
}
|
||||
|
||||
i.appDir, err = ioutil.TempDir("", "appengine-aetest")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
os.RemoveAll(i.appDir)
|
||||
}
|
||||
}()
|
||||
err = os.Mkdir(filepath.Join(i.appDir, "app"), 0755)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ioutil.WriteFile(filepath.Join(i.appDir, "app", "app.yaml"), []byte(i.appYAML()), 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = ioutil.WriteFile(filepath.Join(i.appDir, "app", "stubapp.go"), []byte(appSource), 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
appserverArgs := []string{
|
||||
devAppserver,
|
||||
"--port=0",
|
||||
"--api_port=0",
|
||||
"--admin_port=0",
|
||||
"--automatic_restart=false",
|
||||
"--skip_sdk_update_check=true",
|
||||
"--clear_datastore=true",
|
||||
"--clear_search_indexes=true",
|
||||
"--datastore_path", filepath.Join(i.appDir, "datastore"),
|
||||
}
|
||||
if i.opts != nil && i.opts.StronglyConsistentDatastore {
|
||||
appserverArgs = append(appserverArgs, "--datastore_consistency_policy=consistent")
|
||||
}
|
||||
appserverArgs = append(appserverArgs, filepath.Join(i.appDir, "app"))
|
||||
|
||||
i.child = exec.Command(python,
|
||||
appserverArgs...,
|
||||
)
|
||||
i.child.Stdout = os.Stdout
|
||||
var stderr io.Reader
|
||||
stderr, err = i.child.StderrPipe()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stderr = io.TeeReader(stderr, os.Stderr)
|
||||
if err = i.child.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Read stderr until we have read the URLs of the API server and admin interface.
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
s := bufio.NewScanner(stderr)
|
||||
for s.Scan() {
|
||||
if match := apiServerAddrRE.FindStringSubmatch(s.Text()); match != nil {
|
||||
u, err := url.Parse(match[1])
|
||||
if err != nil {
|
||||
errc <- fmt.Errorf("failed to parse API URL %q: %v", match[1], err)
|
||||
return
|
||||
}
|
||||
i.apiURL = u
|
||||
}
|
||||
if match := adminServerAddrRE.FindStringSubmatch(s.Text()); match != nil {
|
||||
i.adminURL = match[1]
|
||||
}
|
||||
if i.adminURL != "" && i.apiURL != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
errc <- s.Err()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-time.After(15 * time.Second):
|
||||
if p := i.child.Process; p != nil {
|
||||
p.Kill()
|
||||
}
|
||||
return errors.New("timeout starting child process")
|
||||
case err := <-errc:
|
||||
if err != nil {
|
||||
return fmt.Errorf("error reading child process stderr: %v", err)
|
||||
}
|
||||
}
|
||||
if i.adminURL == "" {
|
||||
return errors.New("unable to find admin server URL")
|
||||
}
|
||||
if i.apiURL == nil {
|
||||
return errors.New("unable to find API server URL")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *instance) appYAML() string {
|
||||
return fmt.Sprintf(appYAMLTemplate, i.appID)
|
||||
}
|
||||
|
||||
const appYAMLTemplate = `
|
||||
application: %s
|
||||
version: 1
|
||||
runtime: go
|
||||
api_version: go1
|
||||
vm: true
|
||||
|
||||
handlers:
|
||||
- url: /.*
|
||||
script: _go_app
|
||||
`
|
||||
|
||||
const appSource = `
|
||||
package main
|
||||
import "google.golang.org/appengine"
|
||||
func main() { appengine.Main() }
|
||||
`
|
|
@ -0,0 +1,36 @@
|
|||
package aetest
|
||||
|
||||
import (
|
||||
"hash/crc32"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"google.golang.org/appengine/user"
|
||||
)
|
||||
|
||||
// Login causes the provided Request to act as though issued by the given user.
|
||||
func Login(u *user.User, req *http.Request) {
|
||||
req.Header.Set("X-AppEngine-User-Email", u.Email)
|
||||
id := u.ID
|
||||
if id == "" {
|
||||
id = strconv.Itoa(int(crc32.Checksum([]byte(u.Email), crc32.IEEETable)))
|
||||
}
|
||||
req.Header.Set("X-AppEngine-User-Id", id)
|
||||
req.Header.Set("X-AppEngine-User-Federated-Identity", u.Email)
|
||||
req.Header.Set("X-AppEngine-User-Federated-Provider", u.FederatedProvider)
|
||||
if u.Admin {
|
||||
req.Header.Set("X-AppEngine-User-Is-Admin", "1")
|
||||
} else {
|
||||
req.Header.Set("X-AppEngine-User-Is-Admin", "0")
|
||||
}
|
||||
}
|
||||
|
||||
// Logout causes the provided Request to act as though issued by a logged-out
|
||||
// user.
|
||||
func Logout(req *http.Request) {
|
||||
req.Header.Del("X-AppEngine-User-Email")
|
||||
req.Header.Del("X-AppEngine-User-Id")
|
||||
req.Header.Del("X-AppEngine-User-Is-Admin")
|
||||
req.Header.Del("X-AppEngine-User-Federated-Identity")
|
||||
req.Header.Del("X-AppEngine-User-Federated-Provider")
|
||||
}
|
|
@ -0,0 +1,112 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package appengine provides basic functionality for Google App Engine.
|
||||
//
|
||||
// For more information on how to write Go apps for Google App Engine, see:
|
||||
// https://cloud.google.com/appengine/docs/go/
|
||||
package appengine // import "google.golang.org/appengine"
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine/internal"
|
||||
)
|
||||
|
||||
// The gophers party all night; the rabbits provide the beats.
|
||||
|
||||
// Main is the principal entry point for an app running in App Engine.
|
||||
//
|
||||
// On App Engine Flexible it installs a trivial health checker if one isn't
|
||||
// already registered, and starts listening on port 8080 (overridden by the
|
||||
// $PORT environment variable).
|
||||
//
|
||||
// See https://cloud.google.com/appengine/docs/flexible/custom-runtimes#health_check_requests
|
||||
// for details on how to do your own health checking.
|
||||
//
|
||||
// Main is not yet supported on App Engine Standard.
|
||||
//
|
||||
// Main never returns.
|
||||
//
|
||||
// Main is designed so that the app's main package looks like this:
|
||||
//
|
||||
// package main
|
||||
//
|
||||
// import (
|
||||
// "google.golang.org/appengine"
|
||||
//
|
||||
// _ "myapp/package0"
|
||||
// _ "myapp/package1"
|
||||
// )
|
||||
//
|
||||
// func main() {
|
||||
// appengine.Main()
|
||||
// }
|
||||
//
|
||||
// The "myapp/packageX" packages are expected to register HTTP handlers
|
||||
// in their init functions.
|
||||
func Main() {
|
||||
internal.Main()
|
||||
}
|
||||
|
||||
// IsDevAppServer reports whether the App Engine app is running in the
|
||||
// development App Server.
|
||||
func IsDevAppServer() bool {
|
||||
return internal.IsDevAppServer()
|
||||
}
|
||||
|
||||
// NewContext returns a context for an in-flight HTTP request.
|
||||
// This function is cheap.
|
||||
func NewContext(req *http.Request) context.Context {
|
||||
return WithContext(context.Background(), req)
|
||||
}
|
||||
|
||||
// WithContext returns a copy of the parent context
|
||||
// and associates it with an in-flight HTTP request.
|
||||
// This function is cheap.
|
||||
func WithContext(parent context.Context, req *http.Request) context.Context {
|
||||
return internal.WithContext(parent, req)
|
||||
}
|
||||
|
||||
// TODO(dsymonds): Add a Call function here? Otherwise other packages can't access internal.Call.
|
||||
|
||||
// BlobKey is a key for a blobstore blob.
|
||||
//
|
||||
// Conceptually, this type belongs in the blobstore package, but it lives in
|
||||
// the appengine package to avoid a circular dependency: blobstore depends on
|
||||
// datastore, and datastore needs to refer to the BlobKey type.
|
||||
type BlobKey string
|
||||
|
||||
// GeoPoint represents a location as latitude/longitude in degrees.
|
||||
type GeoPoint struct {
|
||||
Lat, Lng float64
|
||||
}
|
||||
|
||||
// Valid returns whether a GeoPoint is within [-90, 90] latitude and [-180, 180] longitude.
|
||||
func (g GeoPoint) Valid() bool {
|
||||
return -90 <= g.Lat && g.Lat <= 90 && -180 <= g.Lng && g.Lng <= 180
|
||||
}
|
||||
|
||||
// APICallFunc defines a function type for handling an API call.
|
||||
// See WithCallOverride.
|
||||
type APICallFunc func(ctx context.Context, service, method string, in, out proto.Message) error
|
||||
|
||||
// WithAPICallFunc returns a copy of the parent context
|
||||
// that will cause API calls to invoke f instead of their normal operation.
|
||||
//
|
||||
// This is intended for advanced users only.
|
||||
func WithAPICallFunc(ctx context.Context, f APICallFunc) context.Context {
|
||||
return internal.WithCallOverride(ctx, internal.CallOverrideFunc(f))
|
||||
}
|
||||
|
||||
// APICall performs an API call.
|
||||
//
|
||||
// This is not intended for general use; it is exported for use in conjunction
|
||||
// with WithAPICallFunc.
|
||||
func APICall(ctx context.Context, service, method string, in, out proto.Message) error {
|
||||
return internal.Call(ctx, service, method, in, out)
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
// Copyright 2014 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package appengine
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestValidGeoPoint(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
pt GeoPoint
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
"valid",
|
||||
GeoPoint{67.21, 13.37},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"high lat",
|
||||
GeoPoint{-90.01, 13.37},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"low lat",
|
||||
GeoPoint{90.01, 13.37},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"high lng",
|
||||
GeoPoint{67.21, 182},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"low lng",
|
||||
GeoPoint{67.21, -181},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
if got := tc.pt.Valid(); got != tc.want {
|
||||
t.Errorf("%s: got %v, want %v", tc.desc, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,20 @@
|
|||
// Copyright 2015 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !appengine
|
||||
|
||||
package appengine
|
||||
|
||||
import (
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine/internal"
|
||||
)
|
||||
|
||||
// BackgroundContext returns a context not associated with a request.
|
||||
// This should only be used when not servicing a request.
|
||||
// This only works in App Engine "flexible environment".
|
||||
func BackgroundContext() context.Context {
|
||||
return internal.BackgroundContext()
|
||||
}
|
|
@ -0,0 +1,276 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package blobstore provides a client for App Engine's persistent blob
|
||||
// storage service.
|
||||
package blobstore // import "google.golang.org/appengine/blobstore"
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"mime"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/textproto"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine"
|
||||
"google.golang.org/appengine/datastore"
|
||||
"google.golang.org/appengine/internal"
|
||||
|
||||
basepb "google.golang.org/appengine/internal/base"
|
||||
blobpb "google.golang.org/appengine/internal/blobstore"
|
||||
)
|
||||
|
||||
const (
|
||||
blobInfoKind = "__BlobInfo__"
|
||||
blobFileIndexKind = "__BlobFileIndex__"
|
||||
zeroKey = appengine.BlobKey("")
|
||||
)
|
||||
|
||||
// BlobInfo is the blob metadata that is stored in the datastore.
|
||||
// Filename may be empty.
|
||||
type BlobInfo struct {
|
||||
BlobKey appengine.BlobKey
|
||||
ContentType string `datastore:"content_type"`
|
||||
CreationTime time.Time `datastore:"creation"`
|
||||
Filename string `datastore:"filename"`
|
||||
Size int64 `datastore:"size"`
|
||||
MD5 string `datastore:"md5_hash"`
|
||||
|
||||
// ObjectName is the Google Cloud Storage name for this blob.
|
||||
ObjectName string `datastore:"gs_object_name"`
|
||||
}
|
||||
|
||||
// isErrFieldMismatch returns whether err is a datastore.ErrFieldMismatch.
|
||||
//
|
||||
// The blobstore stores blob metadata in the datastore. When loading that
|
||||
// metadata, it may contain fields that we don't care about. datastore.Get will
|
||||
// return datastore.ErrFieldMismatch in that case, so we ignore that specific
|
||||
// error.
|
||||
func isErrFieldMismatch(err error) bool {
|
||||
_, ok := err.(*datastore.ErrFieldMismatch)
|
||||
return ok
|
||||
}
|
||||
|
||||
// Stat returns the BlobInfo for a provided blobKey. If no blob was found for
|
||||
// that key, Stat returns datastore.ErrNoSuchEntity.
|
||||
func Stat(c context.Context, blobKey appengine.BlobKey) (*BlobInfo, error) {
|
||||
c, _ = appengine.Namespace(c, "") // Blobstore is always in the empty string namespace
|
||||
dskey := datastore.NewKey(c, blobInfoKind, string(blobKey), 0, nil)
|
||||
bi := &BlobInfo{
|
||||
BlobKey: blobKey,
|
||||
}
|
||||
if err := datastore.Get(c, dskey, bi); err != nil && !isErrFieldMismatch(err) {
|
||||
return nil, err
|
||||
}
|
||||
return bi, nil
|
||||
}
|
||||
|
||||
// Send sets the headers on response to instruct App Engine to send a blob as
|
||||
// the response body. This is more efficient than reading and writing it out
|
||||
// manually and isn't subject to normal response size limits.
|
||||
func Send(response http.ResponseWriter, blobKey appengine.BlobKey) {
|
||||
hdr := response.Header()
|
||||
hdr.Set("X-AppEngine-BlobKey", string(blobKey))
|
||||
|
||||
if hdr.Get("Content-Type") == "" {
|
||||
// This value is known to dev_appserver to mean automatic.
|
||||
// In production this is remapped to the empty value which
|
||||
// means automatic.
|
||||
hdr.Set("Content-Type", "application/vnd.google.appengine.auto")
|
||||
}
|
||||
}
|
||||
|
||||
// UploadURL creates an upload URL for the form that the user will
|
||||
// fill out, passing the application path to load when the POST of the
|
||||
// form is completed. These URLs expire and should not be reused. The
|
||||
// opts parameter may be nil.
|
||||
func UploadURL(c context.Context, successPath string, opts *UploadURLOptions) (*url.URL, error) {
|
||||
req := &blobpb.CreateUploadURLRequest{
|
||||
SuccessPath: proto.String(successPath),
|
||||
}
|
||||
if opts != nil {
|
||||
if n := opts.MaxUploadBytes; n != 0 {
|
||||
req.MaxUploadSizeBytes = &n
|
||||
}
|
||||
if n := opts.MaxUploadBytesPerBlob; n != 0 {
|
||||
req.MaxUploadSizePerBlobBytes = &n
|
||||
}
|
||||
if s := opts.StorageBucket; s != "" {
|
||||
req.GsBucketName = &s
|
||||
}
|
||||
}
|
||||
res := &blobpb.CreateUploadURLResponse{}
|
||||
if err := internal.Call(c, "blobstore", "CreateUploadURL", req, res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return url.Parse(*res.Url)
|
||||
}
|
||||
|
||||
// UploadURLOptions are the options to create an upload URL.
|
||||
type UploadURLOptions struct {
|
||||
MaxUploadBytes int64 // optional
|
||||
MaxUploadBytesPerBlob int64 // optional
|
||||
|
||||
// StorageBucket specifies the Google Cloud Storage bucket in which
|
||||
// to store the blob.
|
||||
// This is required if you use Cloud Storage instead of Blobstore.
|
||||
// Your application must have permission to write to the bucket.
|
||||
// You may optionally specify a bucket name and path in the format
|
||||
// "bucket_name/path", in which case the included path will be the
|
||||
// prefix of the uploaded object's name.
|
||||
StorageBucket string
|
||||
}
|
||||
|
||||
// Delete deletes a blob.
|
||||
func Delete(c context.Context, blobKey appengine.BlobKey) error {
|
||||
return DeleteMulti(c, []appengine.BlobKey{blobKey})
|
||||
}
|
||||
|
||||
// DeleteMulti deletes multiple blobs.
|
||||
func DeleteMulti(c context.Context, blobKey []appengine.BlobKey) error {
|
||||
s := make([]string, len(blobKey))
|
||||
for i, b := range blobKey {
|
||||
s[i] = string(b)
|
||||
}
|
||||
req := &blobpb.DeleteBlobRequest{
|
||||
BlobKey: s,
|
||||
}
|
||||
res := &basepb.VoidProto{}
|
||||
if err := internal.Call(c, "blobstore", "DeleteBlob", req, res); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func errorf(format string, args ...interface{}) error {
|
||||
return fmt.Errorf("blobstore: "+format, args...)
|
||||
}
|
||||
|
||||
// ParseUpload parses the synthetic POST request that your app gets from
|
||||
// App Engine after a user's successful upload of blobs. Given the request,
|
||||
// ParseUpload returns a map of the blobs received (keyed by HTML form
|
||||
// element name) and other non-blob POST parameters.
|
||||
func ParseUpload(req *http.Request) (blobs map[string][]*BlobInfo, other url.Values, err error) {
|
||||
_, params, err := mime.ParseMediaType(req.Header.Get("Content-Type"))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
boundary := params["boundary"]
|
||||
if boundary == "" {
|
||||
return nil, nil, errorf("did not find MIME multipart boundary")
|
||||
}
|
||||
|
||||
blobs = make(map[string][]*BlobInfo)
|
||||
other = make(url.Values)
|
||||
|
||||
mreader := multipart.NewReader(io.MultiReader(req.Body, strings.NewReader("\r\n\r\n")), boundary)
|
||||
for {
|
||||
part, perr := mreader.NextPart()
|
||||
if perr == io.EOF {
|
||||
break
|
||||
}
|
||||
if perr != nil {
|
||||
return nil, nil, errorf("error reading next mime part with boundary %q (len=%d): %v",
|
||||
boundary, len(boundary), perr)
|
||||
}
|
||||
|
||||
bi := &BlobInfo{}
|
||||
ctype, params, err := mime.ParseMediaType(part.Header.Get("Content-Disposition"))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
bi.Filename = params["filename"]
|
||||
formKey := params["name"]
|
||||
|
||||
ctype, params, err = mime.ParseMediaType(part.Header.Get("Content-Type"))
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
bi.BlobKey = appengine.BlobKey(params["blob-key"])
|
||||
if ctype != "message/external-body" || bi.BlobKey == "" {
|
||||
if formKey != "" {
|
||||
slurp, serr := ioutil.ReadAll(part)
|
||||
if serr != nil {
|
||||
return nil, nil, errorf("error reading %q MIME part", formKey)
|
||||
}
|
||||
other[formKey] = append(other[formKey], string(slurp))
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// App Engine sends a MIME header as the body of each MIME part.
|
||||
tp := textproto.NewReader(bufio.NewReader(part))
|
||||
header, mimeerr := tp.ReadMIMEHeader()
|
||||
if mimeerr != nil {
|
||||
return nil, nil, mimeerr
|
||||
}
|
||||
bi.Size, err = strconv.ParseInt(header.Get("Content-Length"), 10, 64)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
bi.ContentType = header.Get("Content-Type")
|
||||
|
||||
// Parse the time from the MIME header like:
|
||||
// X-AppEngine-Upload-Creation: 2011-03-15 21:38:34.712136
|
||||
createDate := header.Get("X-AppEngine-Upload-Creation")
|
||||
if createDate == "" {
|
||||
return nil, nil, errorf("expected to find an X-AppEngine-Upload-Creation header")
|
||||
}
|
||||
bi.CreationTime, err = time.Parse("2006-01-02 15:04:05.000000", createDate)
|
||||
if err != nil {
|
||||
return nil, nil, errorf("error parsing X-AppEngine-Upload-Creation: %s", err)
|
||||
}
|
||||
|
||||
if hdr := header.Get("Content-MD5"); hdr != "" {
|
||||
md5, err := base64.URLEncoding.DecodeString(hdr)
|
||||
if err != nil {
|
||||
return nil, nil, errorf("bad Content-MD5 %q: %v", hdr, err)
|
||||
}
|
||||
bi.MD5 = string(md5)
|
||||
}
|
||||
|
||||
// If the GCS object name was provided, record it.
|
||||
bi.ObjectName = header.Get("X-AppEngine-Cloud-Storage-Object")
|
||||
|
||||
blobs[formKey] = append(blobs[formKey], bi)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Reader is a blob reader.
|
||||
type Reader interface {
|
||||
io.Reader
|
||||
io.ReaderAt
|
||||
io.Seeker
|
||||
}
|
||||
|
||||
// NewReader returns a reader for a blob. It always succeeds; if the blob does
|
||||
// not exist then an error will be reported upon first read.
|
||||
func NewReader(c context.Context, blobKey appengine.BlobKey) Reader {
|
||||
return openBlob(c, blobKey)
|
||||
}
|
||||
|
||||
// BlobKeyForFile returns a BlobKey for a Google Storage file.
|
||||
// The filename should be of the form "/gs/bucket_name/object_name".
|
||||
func BlobKeyForFile(c context.Context, filename string) (appengine.BlobKey, error) {
|
||||
req := &blobpb.CreateEncodedGoogleStorageKeyRequest{
|
||||
Filename: &filename,
|
||||
}
|
||||
res := &blobpb.CreateEncodedGoogleStorageKeyResponse{}
|
||||
if err := internal.Call(c, "blobstore", "CreateEncodedGoogleStorageKey", req, res); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return appengine.BlobKey(*res.BlobKey), nil
|
||||
}
|
|
@ -0,0 +1,183 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/appengine"
|
||||
"google.golang.org/appengine/internal/aetesting"
|
||||
|
||||
pb "google.golang.org/appengine/internal/blobstore"
|
||||
)
|
||||
|
||||
const rbs = readBufferSize
|
||||
|
||||
func min(x, y int) int {
|
||||
if x < y {
|
||||
return x
|
||||
}
|
||||
return y
|
||||
}
|
||||
|
||||
func fakeFetchData(req *pb.FetchDataRequest, res *pb.FetchDataResponse) error {
|
||||
i0 := int(*req.StartIndex)
|
||||
i1 := int(*req.EndIndex + 1) // Blobstore's end-indices are inclusive; Go's are exclusive.
|
||||
bk := *req.BlobKey
|
||||
if i := strings.Index(bk, "."); i != -1 {
|
||||
// Strip everything past the ".".
|
||||
bk = bk[:i]
|
||||
}
|
||||
switch bk {
|
||||
case "a14p":
|
||||
const s = "abcdefghijklmnop"
|
||||
i0 := min(len(s), i0)
|
||||
i1 := min(len(s), i1)
|
||||
res.Data = []byte(s[i0:i1])
|
||||
case "longBlob":
|
||||
res.Data = make([]byte, i1-i0)
|
||||
for i := range res.Data {
|
||||
res.Data[i] = 'A' + uint8(i0/rbs)
|
||||
i0++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// step is one step of a readerTest.
|
||||
// It consists of a Reader method to call, the method arguments
|
||||
// (lenp, offset, whence) and the expected results.
|
||||
type step struct {
|
||||
method string
|
||||
lenp int
|
||||
offset int64
|
||||
whence int
|
||||
want string
|
||||
wantErr error
|
||||
}
|
||||
|
||||
var readerTest = []struct {
|
||||
blobKey string
|
||||
step []step
|
||||
}{
|
||||
{"noSuchBlobKey", []step{
|
||||
{"Read", 8, 0, 0, "", io.EOF},
|
||||
}},
|
||||
{"a14p.0", []step{
|
||||
// Test basic reads.
|
||||
{"Read", 1, 0, 0, "a", nil},
|
||||
{"Read", 3, 0, 0, "bcd", nil},
|
||||
{"Read", 1, 0, 0, "e", nil},
|
||||
{"Read", 2, 0, 0, "fg", nil},
|
||||
// Test Seek.
|
||||
{"Seek", 0, 2, os.SEEK_SET, "2", nil},
|
||||
{"Read", 5, 0, 0, "cdefg", nil},
|
||||
{"Seek", 0, 2, os.SEEK_CUR, "9", nil},
|
||||
{"Read", 1, 0, 0, "j", nil},
|
||||
// Test reads up to and past EOF.
|
||||
{"Read", 5, 0, 0, "klmno", nil},
|
||||
{"Read", 5, 0, 0, "p", nil},
|
||||
{"Read", 5, 0, 0, "", io.EOF},
|
||||
// Test ReadAt.
|
||||
{"ReadAt", 4, 0, 0, "abcd", nil},
|
||||
{"ReadAt", 4, 3, 0, "defg", nil},
|
||||
{"ReadAt", 4, 12, 0, "mnop", nil},
|
||||
{"ReadAt", 4, 13, 0, "nop", io.EOF},
|
||||
{"ReadAt", 4, 99, 0, "", io.EOF},
|
||||
}},
|
||||
{"a14p.1", []step{
|
||||
// Test Seek before any reads.
|
||||
{"Seek", 0, 2, os.SEEK_SET, "2", nil},
|
||||
{"Read", 1, 0, 0, "c", nil},
|
||||
// Test that ReadAt doesn't affect the Read offset.
|
||||
{"ReadAt", 3, 9, 0, "jkl", nil},
|
||||
{"Read", 3, 0, 0, "def", nil},
|
||||
}},
|
||||
{"a14p.2", []step{
|
||||
// Test ReadAt before any reads or seeks.
|
||||
{"ReadAt", 2, 14, 0, "op", nil},
|
||||
}},
|
||||
{"longBlob.0", []step{
|
||||
// Test basic read.
|
||||
{"Read", 1, 0, 0, "A", nil},
|
||||
// Test that Read returns early when the buffer is exhausted.
|
||||
{"Seek", 0, rbs - 2, os.SEEK_SET, strconv.Itoa(rbs - 2), nil},
|
||||
{"Read", 5, 0, 0, "AA", nil},
|
||||
{"Read", 3, 0, 0, "BBB", nil},
|
||||
// Test that what we just read is still in the buffer.
|
||||
{"Seek", 0, rbs - 2, os.SEEK_SET, strconv.Itoa(rbs - 2), nil},
|
||||
{"Read", 5, 0, 0, "AABBB", nil},
|
||||
// Test ReadAt.
|
||||
{"ReadAt", 3, rbs - 4, 0, "AAA", nil},
|
||||
{"ReadAt", 6, rbs - 4, 0, "AAAABB", nil},
|
||||
{"ReadAt", 8, rbs - 4, 0, "AAAABBBB", nil},
|
||||
{"ReadAt", 5, rbs - 4, 0, "AAAAB", nil},
|
||||
{"ReadAt", 2, rbs - 4, 0, "AA", nil},
|
||||
// Test seeking backwards from the Read offset.
|
||||
{"Seek", 0, 2*rbs - 8, os.SEEK_SET, strconv.Itoa(2*rbs - 8), nil},
|
||||
{"Read", 1, 0, 0, "B", nil},
|
||||
{"Read", 1, 0, 0, "B", nil},
|
||||
{"Read", 1, 0, 0, "B", nil},
|
||||
{"Read", 1, 0, 0, "B", nil},
|
||||
{"Read", 8, 0, 0, "BBBBCCCC", nil},
|
||||
}},
|
||||
{"longBlob.1", []step{
|
||||
// Test ReadAt with a slice larger than the buffer size.
|
||||
{"LargeReadAt", 2*rbs - 2, 0, 0, strconv.Itoa(2*rbs - 2), nil},
|
||||
{"LargeReadAt", 2*rbs - 1, 0, 0, strconv.Itoa(2*rbs - 1), nil},
|
||||
{"LargeReadAt", 2*rbs + 0, 0, 0, strconv.Itoa(2*rbs + 0), nil},
|
||||
{"LargeReadAt", 2*rbs + 1, 0, 0, strconv.Itoa(2*rbs + 1), nil},
|
||||
{"LargeReadAt", 2*rbs + 2, 0, 0, strconv.Itoa(2*rbs + 2), nil},
|
||||
{"LargeReadAt", 2*rbs - 2, 1, 0, strconv.Itoa(2*rbs - 2), nil},
|
||||
{"LargeReadAt", 2*rbs - 1, 1, 0, strconv.Itoa(2*rbs - 1), nil},
|
||||
{"LargeReadAt", 2*rbs + 0, 1, 0, strconv.Itoa(2*rbs + 0), nil},
|
||||
{"LargeReadAt", 2*rbs + 1, 1, 0, strconv.Itoa(2*rbs + 1), nil},
|
||||
{"LargeReadAt", 2*rbs + 2, 1, 0, strconv.Itoa(2*rbs + 2), nil},
|
||||
}},
|
||||
}
|
||||
|
||||
func TestReader(t *testing.T) {
|
||||
for _, rt := range readerTest {
|
||||
c := aetesting.FakeSingleContext(t, "blobstore", "FetchData", fakeFetchData)
|
||||
r := NewReader(c, appengine.BlobKey(rt.blobKey))
|
||||
for i, step := range rt.step {
|
||||
var (
|
||||
got string
|
||||
gotErr error
|
||||
n int
|
||||
offset int64
|
||||
)
|
||||
switch step.method {
|
||||
case "LargeReadAt":
|
||||
p := make([]byte, step.lenp)
|
||||
n, gotErr = r.ReadAt(p, step.offset)
|
||||
got = strconv.Itoa(n)
|
||||
case "Read":
|
||||
p := make([]byte, step.lenp)
|
||||
n, gotErr = r.Read(p)
|
||||
got = string(p[:n])
|
||||
case "ReadAt":
|
||||
p := make([]byte, step.lenp)
|
||||
n, gotErr = r.ReadAt(p, step.offset)
|
||||
got = string(p[:n])
|
||||
case "Seek":
|
||||
offset, gotErr = r.Seek(step.offset, step.whence)
|
||||
got = strconv.FormatInt(offset, 10)
|
||||
default:
|
||||
t.Fatalf("unknown method: %s", step.method)
|
||||
}
|
||||
if gotErr != step.wantErr {
|
||||
t.Fatalf("%s step %d: got error %v want %v", rt.blobKey, i, gotErr, step.wantErr)
|
||||
}
|
||||
if got != step.want {
|
||||
t.Fatalf("%s step %d: got %q want %q", rt.blobKey, i, got, step.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,160 @@
|
|||
// Copyright 2012 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine"
|
||||
"google.golang.org/appengine/internal"
|
||||
|
||||
blobpb "google.golang.org/appengine/internal/blobstore"
|
||||
)
|
||||
|
||||
// openBlob returns a reader for a blob. It always succeeds; if the blob does
|
||||
// not exist then an error will be reported upon first read.
|
||||
func openBlob(c context.Context, blobKey appengine.BlobKey) Reader {
|
||||
return &reader{
|
||||
c: c,
|
||||
blobKey: blobKey,
|
||||
}
|
||||
}
|
||||
|
||||
const readBufferSize = 256 * 1024
|
||||
|
||||
// reader is a blob reader. It implements the Reader interface.
|
||||
type reader struct {
|
||||
c context.Context
|
||||
|
||||
// Either blobKey or filename is set:
|
||||
blobKey appengine.BlobKey
|
||||
filename string
|
||||
|
||||
closeFunc func() // is nil if unavailable or already closed.
|
||||
|
||||
// buf is the read buffer. r is how much of buf has been read.
|
||||
// off is the offset of buf[0] relative to the start of the blob.
|
||||
// An invariant is 0 <= r && r <= len(buf).
|
||||
// Reads that don't require an RPC call will increment r but not off.
|
||||
// Seeks may modify r without discarding the buffer, but only if the
|
||||
// invariant can be maintained.
|
||||
mu sync.Mutex
|
||||
buf []byte
|
||||
r int
|
||||
off int64
|
||||
}
|
||||
|
||||
func (r *reader) Close() error {
|
||||
if f := r.closeFunc; f != nil {
|
||||
f()
|
||||
}
|
||||
r.closeFunc = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *reader) Read(p []byte) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
if r.r == len(r.buf) {
|
||||
if err := r.fetch(r.off + int64(r.r)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
n := copy(p, r.buf[r.r:])
|
||||
r.r += n
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *reader) ReadAt(p []byte, off int64) (int, error) {
|
||||
if len(p) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
// Convert relative offsets to absolute offsets.
|
||||
ab0 := r.off + int64(r.r)
|
||||
ab1 := r.off + int64(len(r.buf))
|
||||
ap0 := off
|
||||
ap1 := off + int64(len(p))
|
||||
// Check if we can satisfy the read entirely out of the existing buffer.
|
||||
if r.off <= ap0 && ap1 <= ab1 {
|
||||
// Convert off from an absolute offset to a relative offset.
|
||||
rp0 := int(ap0 - r.off)
|
||||
return copy(p, r.buf[rp0:]), nil
|
||||
}
|
||||
// Restore the original Read/Seek offset after ReadAt completes.
|
||||
defer r.seek(ab0)
|
||||
// Repeatedly fetch and copy until we have filled p.
|
||||
n := 0
|
||||
for len(p) > 0 {
|
||||
if err := r.fetch(off + int64(n)); err != nil {
|
||||
return n, err
|
||||
}
|
||||
r.r = copy(p, r.buf)
|
||||
n += r.r
|
||||
p = p[r.r:]
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *reader) Seek(offset int64, whence int) (ret int64, err error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
switch whence {
|
||||
case os.SEEK_SET:
|
||||
ret = offset
|
||||
case os.SEEK_CUR:
|
||||
ret = r.off + int64(r.r) + offset
|
||||
case os.SEEK_END:
|
||||
return 0, errors.New("seeking relative to the end of a blob isn't supported")
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid Seek whence value: %d", whence)
|
||||
}
|
||||
if ret < 0 {
|
||||
return 0, errors.New("negative Seek offset")
|
||||
}
|
||||
return r.seek(ret)
|
||||
}
|
||||
|
||||
// fetch fetches readBufferSize bytes starting at the given offset. On success,
|
||||
// the data is saved as r.buf.
|
||||
func (r *reader) fetch(off int64) error {
|
||||
req := &blobpb.FetchDataRequest{
|
||||
BlobKey: proto.String(string(r.blobKey)),
|
||||
StartIndex: proto.Int64(off),
|
||||
EndIndex: proto.Int64(off + readBufferSize - 1), // EndIndex is inclusive.
|
||||
}
|
||||
res := &blobpb.FetchDataResponse{}
|
||||
if err := internal.Call(r.c, "blobstore", "FetchData", req, res); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(res.Data) == 0 {
|
||||
return io.EOF
|
||||
}
|
||||
r.buf, r.r, r.off = res.Data, 0, off
|
||||
return nil
|
||||
}
|
||||
|
||||
// seek seeks to the given offset with an effective whence equal to SEEK_SET.
|
||||
// It discards the read buffer if the invariant cannot be maintained.
|
||||
func (r *reader) seek(off int64) (int64, error) {
|
||||
delta := off - r.off
|
||||
if delta >= 0 && delta < int64(len(r.buf)) {
|
||||
r.r = int(delta)
|
||||
return off, nil
|
||||
}
|
||||
r.buf, r.r, r.off = nil, 0, off
|
||||
return off, nil
|
||||
}
|
|
@ -0,0 +1,52 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package capability exposes information about outages and scheduled downtime
|
||||
for specific API capabilities.
|
||||
|
||||
This package does not work in App Engine "flexible environment".
|
||||
|
||||
Example:
|
||||
if !capability.Enabled(c, "datastore_v3", "write") {
|
||||
// show user a different page
|
||||
}
|
||||
*/
|
||||
package capability // import "google.golang.org/appengine/capability"
|
||||
|
||||
import (
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine/internal"
|
||||
"google.golang.org/appengine/log"
|
||||
|
||||
pb "google.golang.org/appengine/internal/capability"
|
||||
)
|
||||
|
||||
// Enabled returns whether an API's capabilities are enabled.
|
||||
// The wildcard "*" capability matches every capability of an API.
|
||||
// If the underlying RPC fails (if the package is unknown, for example),
|
||||
// false is returned and information is written to the application log.
|
||||
func Enabled(ctx context.Context, api, capability string) bool {
|
||||
req := &pb.IsEnabledRequest{
|
||||
Package: &api,
|
||||
Capability: []string{capability},
|
||||
}
|
||||
res := &pb.IsEnabledResponse{}
|
||||
if err := internal.Call(ctx, "capability_service", "IsEnabled", req, res); err != nil {
|
||||
log.Warningf(ctx, "capability.Enabled: RPC failed: %v", err)
|
||||
return false
|
||||
}
|
||||
switch *res.SummaryStatus {
|
||||
case pb.IsEnabledResponse_ENABLED,
|
||||
pb.IsEnabledResponse_SCHEDULED_FUTURE,
|
||||
pb.IsEnabledResponse_SCHEDULED_NOW:
|
||||
return true
|
||||
case pb.IsEnabledResponse_UNKNOWN:
|
||||
log.Errorf(ctx, "capability.Enabled: unknown API capability %s/%s", api, capability)
|
||||
return false
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
|
@ -0,0 +1,83 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package channel implements the server side of App Engine's Channel API.
|
||||
|
||||
Create creates a new channel associated with the given clientID,
|
||||
which must be unique to the client that will use the returned token.
|
||||
|
||||
token, err := channel.Create(c, "player1")
|
||||
if err != nil {
|
||||
// handle error
|
||||
}
|
||||
// return token to the client in an HTTP response
|
||||
|
||||
Send sends a message to the client over the channel identified by clientID.
|
||||
|
||||
channel.Send(c, "player1", "Game over!")
|
||||
*/
|
||||
package channel // import "google.golang.org/appengine/channel"
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine"
|
||||
"google.golang.org/appengine/internal"
|
||||
basepb "google.golang.org/appengine/internal/base"
|
||||
pb "google.golang.org/appengine/internal/channel"
|
||||
)
|
||||
|
||||
// Create creates a channel and returns a token for use by the client.
|
||||
// The clientID is an application-provided string used to identify the client.
|
||||
func Create(c context.Context, clientID string) (token string, err error) {
|
||||
req := &pb.CreateChannelRequest{
|
||||
ApplicationKey: &clientID,
|
||||
}
|
||||
resp := &pb.CreateChannelResponse{}
|
||||
err = internal.Call(c, service, "CreateChannel", req, resp)
|
||||
token = resp.GetToken()
|
||||
return token, remapError(err)
|
||||
}
|
||||
|
||||
// Send sends a message on the channel associated with clientID.
|
||||
func Send(c context.Context, clientID, message string) error {
|
||||
req := &pb.SendMessageRequest{
|
||||
ApplicationKey: &clientID,
|
||||
Message: &message,
|
||||
}
|
||||
resp := &basepb.VoidProto{}
|
||||
return remapError(internal.Call(c, service, "SendChannelMessage", req, resp))
|
||||
}
|
||||
|
||||
// SendJSON is a helper function that sends a JSON-encoded value
|
||||
// on the channel associated with clientID.
|
||||
func SendJSON(c context.Context, clientID string, value interface{}) error {
|
||||
m, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return Send(c, clientID, string(m))
|
||||
}
|
||||
|
||||
// remapError fixes any APIError referencing "xmpp" into one referencing "channel".
|
||||
func remapError(err error) error {
|
||||
if e, ok := err.(*internal.APIError); ok {
|
||||
if e.Service == "xmpp" {
|
||||
e.Service = "channel"
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var service = "xmpp" // prod
|
||||
|
||||
func init() {
|
||||
if appengine.IsDevAppServer() {
|
||||
service = "channel" // dev
|
||||
}
|
||||
internal.RegisterErrorCodeMap("channel", pb.ChannelServiceError_ErrorCode_name)
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
// Copyright 2015 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package channel
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"google.golang.org/appengine/internal"
|
||||
)
|
||||
|
||||
func TestRemapError(t *testing.T) {
|
||||
err := &internal.APIError{
|
||||
Service: "xmpp",
|
||||
}
|
||||
err = remapError(err).(*internal.APIError)
|
||||
if err.Service != "channel" {
|
||||
t.Errorf("err.Service = %q, want %q", err.Service, "channel")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,62 @@
|
|||
// Copyright 2013 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package cloudsql exposes access to Google Cloud SQL databases.
|
||||
|
||||
This package does not work in App Engine "flexible environment".
|
||||
|
||||
This package is intended for MySQL drivers to make App Engine-specific
|
||||
connections. Applications should use this package through database/sql:
|
||||
Select a pure Go MySQL driver that supports this package, and use sql.Open
|
||||
with protocol "cloudsql" and an address of the Cloud SQL instance.
|
||||
|
||||
A Go MySQL driver that has been tested to work well with Cloud SQL
|
||||
is the go-sql-driver:
|
||||
import "database/sql"
|
||||
import _ "github.com/go-sql-driver/mysql"
|
||||
|
||||
db, err := sql.Open("mysql", "user@cloudsql(project-id:instance-name)/dbname")
|
||||
|
||||
|
||||
Another driver that works well with Cloud SQL is the mymysql driver:
|
||||
import "database/sql"
|
||||
import _ "github.com/ziutek/mymysql/godrv"
|
||||
|
||||
db, err := sql.Open("mymysql", "cloudsql:instance-name*dbname/user/password")
|
||||
|
||||
|
||||
Using either of these drivers, you can perform a standard SQL query.
|
||||
This example assumes there is a table named 'users' with
|
||||
columns 'first_name' and 'last_name':
|
||||
|
||||
rows, err := db.Query("SELECT first_name, last_name FROM users")
|
||||
if err != nil {
|
||||
log.Errorf(ctx, "db.Query: %v", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var firstName string
|
||||
var lastName string
|
||||
if err := rows.Scan(&firstName, &lastName); err != nil {
|
||||
log.Errorf(ctx, "rows.Scan: %v", err)
|
||||
continue
|
||||
}
|
||||
log.Infof(ctx, "First: %v - Last: %v", firstName, lastName)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
log.Errorf(ctx, "Row error: %v", err)
|
||||
}
|
||||
*/
|
||||
package cloudsql
|
||||
|
||||
import (
|
||||
"net"
|
||||
)
|
||||
|
||||
// Dial connects to the named Cloud SQL instance.
|
||||
func Dial(instance string) (net.Conn, error) {
|
||||
return connect(instance)
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
// Copyright 2013 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build appengine
|
||||
|
||||
package cloudsql
|
||||
|
||||
import (
|
||||
"net"
|
||||
|
||||
"appengine/cloudsql"
|
||||
)
|
||||
|
||||
func connect(instance string) (net.Conn, error) {
|
||||
return cloudsql.Dial(instance)
|
||||
}
|
|
@ -0,0 +1,16 @@
|
|||
// Copyright 2013 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// +build !appengine
|
||||
|
||||
package cloudsql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
)
|
||||
|
||||
func connect(instance string) (net.Conn, error) {
|
||||
return nil, errors.New(`cloudsql: not supported in App Engine "flexible environment"`)
|
||||
}
|
|
@ -0,0 +1,342 @@
|
|||
// Copyright 2015 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Program aebundler turns a Go app into a fully self-contained tar file.
|
||||
// The app and its subdirectories (if any) are placed under "."
|
||||
// and the dependencies from $GOPATH are placed under ./_gopath/src.
|
||||
// A main func is synthesized if one does not exist.
|
||||
//
|
||||
// A sample Dockerfile to be used with this bundler could look like this:
|
||||
// FROM gcr.io/google_appengine/go-compat
|
||||
// ADD . /app
|
||||
// RUN GOPATH=/app/_gopath go build -tags appenginevm -o /app/_ah/exe
|
||||
package main
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/build"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
output = flag.String("o", "", "name of output tar file or '-' for stdout")
|
||||
rootDir = flag.String("root", ".", "directory name of application root")
|
||||
vm = flag.Bool("vm", true, `bundle an app for App Engine "flexible environment"`)
|
||||
|
||||
skipFiles = map[string]bool{
|
||||
".git": true,
|
||||
".gitconfig": true,
|
||||
".hg": true,
|
||||
".travis.yml": true,
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
newMain = `package main
|
||||
import "google.golang.org/appengine"
|
||||
func main() {
|
||||
appengine.Main()
|
||||
}
|
||||
`
|
||||
)
|
||||
|
||||
func usage() {
|
||||
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, "\t%s -o <file.tar|->\tBundle app to named tar file or stdout\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, "\noptional arguments:\n")
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Usage = usage
|
||||
flag.Parse()
|
||||
|
||||
var tags []string
|
||||
if *vm {
|
||||
tags = append(tags, "appenginevm")
|
||||
} else {
|
||||
tags = append(tags, "appengine")
|
||||
}
|
||||
|
||||
tarFile := *output
|
||||
if tarFile == "" {
|
||||
usage()
|
||||
errorf("Required -o flag not specified.")
|
||||
}
|
||||
|
||||
app, err := analyze(tags)
|
||||
if err != nil {
|
||||
errorf("Error analyzing app: %v", err)
|
||||
}
|
||||
if err := app.bundle(tarFile); err != nil {
|
||||
errorf("Unable to bundle app: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// errorf prints the error message and exits.
|
||||
func errorf(format string, a ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, "aebundler: "+format+"\n", a...)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
type app struct {
|
||||
hasMain bool
|
||||
appFiles []string
|
||||
imports map[string]string
|
||||
}
|
||||
|
||||
// analyze checks the app for building with the given build tags and returns hasMain,
|
||||
// app files, and a map of full directory import names to original import names.
|
||||
func analyze(tags []string) (*app, error) {
|
||||
ctxt := buildContext(tags)
|
||||
hasMain, appFiles, err := checkMain(ctxt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gopath := filepath.SplitList(ctxt.GOPATH)
|
||||
im, err := imports(ctxt, *rootDir, gopath)
|
||||
return &app{
|
||||
hasMain: hasMain,
|
||||
appFiles: appFiles,
|
||||
imports: im,
|
||||
}, err
|
||||
}
|
||||
|
||||
// buildContext returns the context for building the source.
|
||||
func buildContext(tags []string) *build.Context {
|
||||
return &build.Context{
|
||||
GOARCH: build.Default.GOARCH,
|
||||
GOOS: build.Default.GOOS,
|
||||
GOROOT: build.Default.GOROOT,
|
||||
GOPATH: build.Default.GOPATH,
|
||||
Compiler: build.Default.Compiler,
|
||||
BuildTags: append(build.Default.BuildTags, tags...),
|
||||
}
|
||||
}
|
||||
|
||||
// bundle bundles the app into the named tarFile ("-"==stdout).
|
||||
func (s *app) bundle(tarFile string) (err error) {
|
||||
var out io.Writer
|
||||
if tarFile == "-" {
|
||||
out = os.Stdout
|
||||
} else {
|
||||
f, err := os.Create(tarFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if cerr := f.Close(); err == nil {
|
||||
err = cerr
|
||||
}
|
||||
}()
|
||||
out = f
|
||||
}
|
||||
tw := tar.NewWriter(out)
|
||||
|
||||
for srcDir, importName := range s.imports {
|
||||
dstDir := "_gopath/src/" + importName
|
||||
if err = copyTree(tw, dstDir, srcDir); err != nil {
|
||||
return fmt.Errorf("unable to copy directory %v to %v: %v", srcDir, dstDir, err)
|
||||
}
|
||||
}
|
||||
if err := copyTree(tw, ".", *rootDir); err != nil {
|
||||
return fmt.Errorf("unable to copy root directory to /app: %v", err)
|
||||
}
|
||||
if !s.hasMain {
|
||||
if err := synthesizeMain(tw, s.appFiles); err != nil {
|
||||
return fmt.Errorf("unable to synthesize new main func: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tw.Close(); err != nil {
|
||||
return fmt.Errorf("unable to close tar file %v: %v", tarFile, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// synthesizeMain generates a new main func and writes it to the tarball.
|
||||
func synthesizeMain(tw *tar.Writer, appFiles []string) error {
|
||||
appMap := make(map[string]bool)
|
||||
for _, f := range appFiles {
|
||||
appMap[f] = true
|
||||
}
|
||||
var f string
|
||||
for i := 0; i < 100; i++ {
|
||||
f = fmt.Sprintf("app_main%d.go", i)
|
||||
if !appMap[filepath.Join(*rootDir, f)] {
|
||||
break
|
||||
}
|
||||
}
|
||||
if appMap[filepath.Join(*rootDir, f)] {
|
||||
return fmt.Errorf("unable to find unique name for %v", f)
|
||||
}
|
||||
hdr := &tar.Header{
|
||||
Name: f,
|
||||
Mode: 0644,
|
||||
Size: int64(len(newMain)),
|
||||
}
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return fmt.Errorf("unable to write header for %v: %v", f, err)
|
||||
}
|
||||
if _, err := tw.Write([]byte(newMain)); err != nil {
|
||||
return fmt.Errorf("unable to write %v to tar file: %v", f, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// imports returns a map of all import directories (recursively) used by the app.
|
||||
// The return value maps full directory names to original import names.
|
||||
func imports(ctxt *build.Context, srcDir string, gopath []string) (map[string]string, error) {
|
||||
pkg, err := ctxt.ImportDir(srcDir, 0)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to analyze source: %v", err)
|
||||
}
|
||||
|
||||
// Resolve all non-standard-library imports
|
||||
result := make(map[string]string)
|
||||
for _, v := range pkg.Imports {
|
||||
if !strings.Contains(v, ".") {
|
||||
continue
|
||||
}
|
||||
src, err := findInGopath(v, gopath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to find import %v in gopath %v: %v", v, gopath, err)
|
||||
}
|
||||
result[src] = v
|
||||
im, err := imports(ctxt, src, gopath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse package %v: %v", src, err)
|
||||
}
|
||||
for k, v := range im {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// findInGopath searches the gopath for the named import directory.
|
||||
func findInGopath(dir string, gopath []string) (string, error) {
|
||||
for _, v := range gopath {
|
||||
dst := filepath.Join(v, "src", dir)
|
||||
if _, err := os.Stat(dst); err == nil {
|
||||
return dst, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("unable to find package %v in gopath %v", dir, gopath)
|
||||
}
|
||||
|
||||
// copyTree copies srcDir to tar file dstDir, ignoring skipFiles.
|
||||
func copyTree(tw *tar.Writer, dstDir, srcDir string) error {
|
||||
entries, err := ioutil.ReadDir(srcDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read dir %v: %v", srcDir, err)
|
||||
}
|
||||
for _, entry := range entries {
|
||||
n := entry.Name()
|
||||
if skipFiles[n] {
|
||||
continue
|
||||
}
|
||||
s := filepath.Join(srcDir, n)
|
||||
d := filepath.Join(dstDir, n)
|
||||
if entry.IsDir() {
|
||||
if err := copyTree(tw, d, s); err != nil {
|
||||
return fmt.Errorf("unable to copy dir %v to %v: %v", s, d, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := copyFile(tw, d, s); err != nil {
|
||||
return fmt.Errorf("unable to copy dir %v to %v: %v", s, d, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// copyFile copies src to tar file dst.
|
||||
func copyFile(tw *tar.Writer, dst, src string) error {
|
||||
s, err := os.Open(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to open %v: %v", src, err)
|
||||
}
|
||||
defer s.Close()
|
||||
fi, err := s.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to stat %v: %v", src, err)
|
||||
}
|
||||
|
||||
hdr, err := tar.FileInfoHeader(fi, dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create tar header for %v: %v", dst, err)
|
||||
}
|
||||
hdr.Name = dst
|
||||
if err := tw.WriteHeader(hdr); err != nil {
|
||||
return fmt.Errorf("unable to write header for %v: %v", dst, err)
|
||||
}
|
||||
_, err = io.Copy(tw, s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to copy %v to %v: %v", src, dst, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkMain verifies that there is a single "main" function.
|
||||
// It also returns a list of all Go source files in the app.
|
||||
func checkMain(ctxt *build.Context) (bool, []string, error) {
|
||||
pkg, err := ctxt.ImportDir(*rootDir, 0)
|
||||
if err != nil {
|
||||
return false, nil, fmt.Errorf("unable to analyze source: %v", err)
|
||||
}
|
||||
if !pkg.IsCommand() {
|
||||
errorf("Your app's package needs to be changed from %q to \"main\".\n", pkg.Name)
|
||||
}
|
||||
// Search for a "func main"
|
||||
var hasMain bool
|
||||
var appFiles []string
|
||||
for _, f := range pkg.GoFiles {
|
||||
n := filepath.Join(*rootDir, f)
|
||||
appFiles = append(appFiles, n)
|
||||
if hasMain, err = readFile(n); err != nil {
|
||||
return false, nil, fmt.Errorf("error parsing %q: %v", n, err)
|
||||
}
|
||||
}
|
||||
return hasMain, appFiles, nil
|
||||
}
|
||||
|
||||
// isMain returns whether the given function declaration is a main function.
|
||||
// Such a function must be called "main", not have a receiver, and have no arguments or return types.
|
||||
func isMain(f *ast.FuncDecl) bool {
|
||||
ft := f.Type
|
||||
return f.Name.Name == "main" && f.Recv == nil && ft.Params.NumFields() == 0 && ft.Results.NumFields() == 0
|
||||
}
|
||||
|
||||
// readFile reads and parses the Go source code file and returns whether it has a main function.
|
||||
func readFile(filename string) (hasMain bool, err error) {
|
||||
var src []byte
|
||||
src, err = ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
fset := token.NewFileSet()
|
||||
file, err := parser.ParseFile(fset, filename, src, 0)
|
||||
for _, decl := range file.Decls {
|
||||
funcDecl, ok := decl.(*ast.FuncDecl)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if !isMain(funcDecl) {
|
||||
continue
|
||||
}
|
||||
hasMain = true
|
||||
break
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,268 @@
|
|||
// Copyright 2015 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Program aedeploy assists with deploying App Engine "flexible environment" Go apps to production.
|
||||
// A temporary directory is created; the app, its subdirectories, and all its
|
||||
// dependencies from $GOPATH are copied into the directory; then the app
|
||||
// is deployed to production with the provided command.
|
||||
//
|
||||
// The app must be in "package main".
|
||||
//
|
||||
// This command must be issued from within the root directory of the app
|
||||
// (where the app.yaml file is located).
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/build"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
skipFiles = map[string]bool{
|
||||
".git": true,
|
||||
".gitconfig": true,
|
||||
".hg": true,
|
||||
".travis.yml": true,
|
||||
}
|
||||
|
||||
gopathCache = map[string]string{}
|
||||
)
|
||||
|
||||
func usage() {
|
||||
fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0])
|
||||
fmt.Fprintf(os.Stderr, "\t%s gcloud --verbosity debug preview app deploy --version myversion ./app.yaml\tDeploy app to production\n", os.Args[0])
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Usage = usage
|
||||
flag.Parse()
|
||||
if flag.NArg() < 1 {
|
||||
usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if err := aedeploy(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, os.Args[0]+": Error: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
func aedeploy() error {
|
||||
tags := []string{"appenginevm"}
|
||||
app, err := analyze(tags)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmpDir, err := app.bundle()
|
||||
if tmpDir != "" {
|
||||
defer os.RemoveAll(tmpDir)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Chdir(tmpDir); err != nil {
|
||||
return fmt.Errorf("unable to chdir to %v: %v", tmpDir, err)
|
||||
}
|
||||
return deploy()
|
||||
}
|
||||
|
||||
// deploy calls the provided command to deploy the app from the temporary directory.
|
||||
func deploy() error {
|
||||
cmd := exec.Command(flag.Arg(0), flag.Args()[1:]...)
|
||||
cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr
|
||||
if err := cmd.Run(); err != nil {
|
||||
return fmt.Errorf("unable to run %q: %v", strings.Join(flag.Args(), " "), err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type app struct {
|
||||
appFiles []string
|
||||
imports map[string]string
|
||||
}
|
||||
|
||||
// analyze checks the app for building with the given build tags and returns
|
||||
// app files, and a map of full directory import names to original import names.
|
||||
func analyze(tags []string) (*app, error) {
|
||||
ctxt := buildContext(tags)
|
||||
appFiles, err := appFiles(ctxt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gopath := filepath.SplitList(ctxt.GOPATH)
|
||||
im, err := imports(ctxt, ".", gopath)
|
||||
return &app{
|
||||
appFiles: appFiles,
|
||||
imports: im,
|
||||
}, err
|
||||
}
|
||||
|
||||
// buildContext returns the context for building the source.
|
||||
func buildContext(tags []string) *build.Context {
|
||||
return &build.Context{
|
||||
GOARCH: "amd64",
|
||||
GOOS: "linux",
|
||||
GOROOT: build.Default.GOROOT,
|
||||
GOPATH: build.Default.GOPATH,
|
||||
Compiler: build.Default.Compiler,
|
||||
BuildTags: append(defaultBuildTags, tags...),
|
||||
}
|
||||
}
|
||||
|
||||
// All build tags except go1.7, since Go 1.6 is the runtime version.
|
||||
var defaultBuildTags = []string{
|
||||
"go1.1", "go1.2", "go1.3", "go1.4", "go1.5", "go1.6"}
|
||||
|
||||
// bundle bundles the app into a temporary directory.
|
||||
func (s *app) bundle() (tmpdir string, err error) {
|
||||
workDir, err := ioutil.TempDir("", "aedeploy")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to create tmpdir: %v", err)
|
||||
}
|
||||
|
||||
for srcDir, importName := range s.imports {
|
||||
dstDir := "_gopath/src/" + importName
|
||||
if err := copyTree(workDir, dstDir, srcDir); err != nil {
|
||||
return workDir, fmt.Errorf("unable to copy directory %v to %v: %v", srcDir, dstDir, err)
|
||||
}
|
||||
}
|
||||
if err := copyTree(workDir, ".", "."); err != nil {
|
||||
return workDir, fmt.Errorf("unable to copy root directory to /app: %v", err)
|
||||
}
|
||||
return workDir, nil
|
||||
}
|
||||
|
||||
// imports returns a map of all import directories (recursively) used by the app.
|
||||
// The return value maps full directory names to original import names.
|
||||
func imports(ctxt *build.Context, srcDir string, gopath []string) (map[string]string, error) {
|
||||
pkg, err := ctxt.ImportDir(srcDir, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Resolve all non-standard-library imports
|
||||
result := make(map[string]string)
|
||||
for _, v := range pkg.Imports {
|
||||
if !strings.Contains(v, ".") {
|
||||
continue
|
||||
}
|
||||
src, err := findInGopath(v, gopath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to find import %v in gopath %v: %v", v, gopath, err)
|
||||
}
|
||||
if _, ok := result[src]; ok { // Already processed
|
||||
continue
|
||||
}
|
||||
result[src] = v
|
||||
im, err := imports(ctxt, src, gopath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse package %v: %v", src, err)
|
||||
}
|
||||
for k, v := range im {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// findInGopath searches the gopath for the named import directory.
|
||||
func findInGopath(dir string, gopath []string) (string, error) {
|
||||
if v, ok := gopathCache[dir]; ok {
|
||||
return v, nil
|
||||
}
|
||||
for _, v := range gopath {
|
||||
dst := filepath.Join(v, "src", dir)
|
||||
if _, err := os.Stat(dst); err == nil {
|
||||
gopathCache[dir] = dst
|
||||
return dst, nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("unable to find package %v in gopath %v", dir, gopath)
|
||||
}
|
||||
|
||||
// copyTree copies srcDir to dstDir relative to dstRoot, ignoring skipFiles.
|
||||
func copyTree(dstRoot, dstDir, srcDir string) error {
|
||||
d := filepath.Join(dstRoot, dstDir)
|
||||
if err := os.MkdirAll(d, 0755); err != nil {
|
||||
return fmt.Errorf("unable to create directory %q: %v", d, err)
|
||||
}
|
||||
|
||||
entries, err := ioutil.ReadDir(srcDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to read dir %q: %v", srcDir, err)
|
||||
}
|
||||
for _, entry := range entries {
|
||||
n := entry.Name()
|
||||
if skipFiles[n] {
|
||||
continue
|
||||
}
|
||||
s := filepath.Join(srcDir, n)
|
||||
if entry.Mode()&os.ModeSymlink == os.ModeSymlink {
|
||||
if entry, err = os.Stat(s); err != nil {
|
||||
return fmt.Errorf("unable to stat %v: %v", s, err)
|
||||
}
|
||||
}
|
||||
d := filepath.Join(dstDir, n)
|
||||
if entry.IsDir() {
|
||||
if err := copyTree(dstRoot, d, s); err != nil {
|
||||
return fmt.Errorf("unable to copy dir %q to %q: %v", s, d, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err := copyFile(dstRoot, d, s); err != nil {
|
||||
return fmt.Errorf("unable to copy dir %q to %q: %v", s, d, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// copyFile copies src to dst relative to dstRoot.
|
||||
func copyFile(dstRoot, dst, src string) error {
|
||||
s, err := os.Open(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to open %q: %v", src, err)
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
dst = filepath.Join(dstRoot, dst)
|
||||
d, err := os.Create(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to create %q: %v", dst, err)
|
||||
}
|
||||
_, err = io.Copy(d, s)
|
||||
if err != nil {
|
||||
d.Close() // ignore error, copy already failed.
|
||||
return fmt.Errorf("unable to copy %q to %q: %v", src, dst, err)
|
||||
}
|
||||
if err := d.Close(); err != nil {
|
||||
return fmt.Errorf("unable to close %q: %v", dst, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// appFiles returns a list of all Go source files in the app.
|
||||
func appFiles(ctxt *build.Context) ([]string, error) {
|
||||
pkg, err := ctxt.ImportDir(".", 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !pkg.IsCommand() {
|
||||
return nil, fmt.Errorf(`the root of your app needs to be package "main" (currently %q). Please see https://cloud.google.com/appengine/docs/flexible/go/ for more details on structuring your app.`, pkg.Name)
|
||||
}
|
||||
var appFiles []string
|
||||
for _, f := range pkg.GoFiles {
|
||||
n := filepath.Join(".", f)
|
||||
appFiles = append(appFiles, n)
|
||||
}
|
||||
return appFiles, nil
|
||||
}
|
|
@ -0,0 +1,185 @@
|
|||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
ctxPackage = "golang.org/x/net/context"
|
||||
|
||||
newPackageBase = "google.golang.org/"
|
||||
stutterPackage = false
|
||||
)
|
||||
|
||||
func init() {
|
||||
register(fix{
|
||||
"ae",
|
||||
"2016-04-15",
|
||||
aeFn,
|
||||
`Update old App Engine APIs to new App Engine APIs`,
|
||||
})
|
||||
}
|
||||
|
||||
// logMethod is the set of methods on appengine.Context used for logging.
|
||||
var logMethod = map[string]bool{
|
||||
"Debugf": true,
|
||||
"Infof": true,
|
||||
"Warningf": true,
|
||||
"Errorf": true,
|
||||
"Criticalf": true,
|
||||
}
|
||||
|
||||
// mapPackage turns "appengine" into "google.golang.org/appengine", etc.
|
||||
func mapPackage(s string) string {
|
||||
if stutterPackage {
|
||||
s += "/" + path.Base(s)
|
||||
}
|
||||
return newPackageBase + s
|
||||
}
|
||||
|
||||
func aeFn(f *ast.File) bool {
|
||||
// During the walk, we track the last thing seen that looks like
|
||||
// an appengine.Context, and reset it once the walk leaves a func.
|
||||
var lastContext *ast.Ident
|
||||
|
||||
fixed := false
|
||||
|
||||
// Update imports.
|
||||
mainImp := "appengine"
|
||||
for _, imp := range f.Imports {
|
||||
pth, _ := strconv.Unquote(imp.Path.Value)
|
||||
if pth == "appengine" || strings.HasPrefix(pth, "appengine/") {
|
||||
newPth := mapPackage(pth)
|
||||
imp.Path.Value = strconv.Quote(newPth)
|
||||
fixed = true
|
||||
|
||||
if pth == "appengine" {
|
||||
mainImp = newPth
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update any API changes.
|
||||
walk(f, func(n interface{}) {
|
||||
if ft, ok := n.(*ast.FuncType); ok && ft.Params != nil {
|
||||
// See if this func has an `appengine.Context arg`.
|
||||
// If so, remember its identifier.
|
||||
for _, param := range ft.Params.List {
|
||||
if !isPkgDot(param.Type, "appengine", "Context") {
|
||||
continue
|
||||
}
|
||||
if len(param.Names) == 1 {
|
||||
lastContext = param.Names[0]
|
||||
break
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if as, ok := n.(*ast.AssignStmt); ok {
|
||||
if len(as.Lhs) == 1 && len(as.Rhs) == 1 {
|
||||
// If this node is an assignment from an appengine.NewContext invocation,
|
||||
// remember the identifier on the LHS.
|
||||
if isCall(as.Rhs[0], "appengine", "NewContext") {
|
||||
if ident, ok := as.Lhs[0].(*ast.Ident); ok {
|
||||
lastContext = ident
|
||||
return
|
||||
}
|
||||
}
|
||||
// x (=|:=) appengine.Timeout(y, z)
|
||||
// should become
|
||||
// x, _ (=|:=) context.WithTimeout(y, z)
|
||||
if isCall(as.Rhs[0], "appengine", "Timeout") {
|
||||
addImport(f, ctxPackage)
|
||||
as.Lhs = append(as.Lhs, ast.NewIdent("_"))
|
||||
// isCall already did the type checking.
|
||||
sel := as.Rhs[0].(*ast.CallExpr).Fun.(*ast.SelectorExpr)
|
||||
sel.X = ast.NewIdent("context")
|
||||
sel.Sel = ast.NewIdent("WithTimeout")
|
||||
fixed = true
|
||||
return
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// If this node is a FuncDecl, we've finished the function, so reset lastContext.
|
||||
if _, ok := n.(*ast.FuncDecl); ok {
|
||||
lastContext = nil
|
||||
return
|
||||
}
|
||||
|
||||
if call, ok := n.(*ast.CallExpr); ok {
|
||||
if isPkgDot(call.Fun, "appengine", "Datacenter") && len(call.Args) == 0 {
|
||||
insertContext(f, call, lastContext)
|
||||
fixed = true
|
||||
return
|
||||
}
|
||||
if isPkgDot(call.Fun, "taskqueue", "QueueStats") && len(call.Args) == 3 {
|
||||
call.Args = call.Args[:2] // drop last arg
|
||||
fixed = true
|
||||
return
|
||||
}
|
||||
|
||||
sel, ok := call.Fun.(*ast.SelectorExpr)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if lastContext != nil && refersTo(sel.X, lastContext) && logMethod[sel.Sel.Name] {
|
||||
// c.Errorf(...)
|
||||
// should become
|
||||
// log.Errorf(c, ...)
|
||||
addImport(f, mapPackage("appengine/log"))
|
||||
sel.X = &ast.Ident{ // ast.NewIdent doesn't preserve the position.
|
||||
NamePos: sel.X.Pos(),
|
||||
Name: "log",
|
||||
}
|
||||
insertContext(f, call, lastContext)
|
||||
fixed = true
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// Change any `appengine.Context` to `context.Context`.
|
||||
// Do this in a separate walk because the previous walk
|
||||
// wants to identify "appengine.Context".
|
||||
walk(f, func(n interface{}) {
|
||||
expr, ok := n.(ast.Expr)
|
||||
if ok && isPkgDot(expr, "appengine", "Context") {
|
||||
addImport(f, ctxPackage)
|
||||
// isPkgDot did the type checking.
|
||||
n.(*ast.SelectorExpr).X.(*ast.Ident).Name = "context"
|
||||
fixed = true
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
// The changes above might remove the need to import "appengine".
|
||||
// Check if it's used, and drop it if it isn't.
|
||||
if fixed && !usesImport(f, mainImp) {
|
||||
deleteImport(f, mainImp)
|
||||
}
|
||||
|
||||
return fixed
|
||||
}
|
||||
|
||||
// ctx may be nil.
|
||||
func insertContext(f *ast.File, call *ast.CallExpr, ctx *ast.Ident) {
|
||||
if ctx == nil {
|
||||
// context is unknown, so use a plain "ctx".
|
||||
ctx = ast.NewIdent("ctx")
|
||||
} else {
|
||||
// Create a fresh *ast.Ident so we drop the position information.
|
||||
ctx = ast.NewIdent(ctx.Name)
|
||||
}
|
||||
|
||||
call.Args = append([]ast.Expr{ctx}, call.Args...)
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
func init() {
|
||||
addTestCases(aeTests, nil)
|
||||
}
|
||||
|
||||
var aeTests = []testCase{
|
||||
// Collection of fixes:
|
||||
// - imports
|
||||
// - appengine.Timeout -> context.WithTimeout
|
||||
// - add ctx arg to appengine.Datacenter
|
||||
// - logging API
|
||||
{
|
||||
Name: "ae.0",
|
||||
In: `package foo
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"appengine"
|
||||
"appengine/datastore"
|
||||
)
|
||||
|
||||
func f(w http.ResponseWriter, r *http.Request) {
|
||||
c := appengine.NewContext(r)
|
||||
|
||||
c = appengine.Timeout(c, 5*time.Second)
|
||||
err := datastore.ErrNoSuchEntity
|
||||
c.Errorf("Something interesting happened: %v", err)
|
||||
_ = appengine.Datacenter()
|
||||
}
|
||||
`,
|
||||
Out: `package foo
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/appengine"
|
||||
"google.golang.org/appengine/datastore"
|
||||
"google.golang.org/appengine/log"
|
||||
)
|
||||
|
||||
func f(w http.ResponseWriter, r *http.Request) {
|
||||
c := appengine.NewContext(r)
|
||||
|
||||
c, _ = context.WithTimeout(c, 5*time.Second)
|
||||
err := datastore.ErrNoSuchEntity
|
||||
log.Errorf(c, "Something interesting happened: %v", err)
|
||||
_ = appengine.Datacenter(c)
|
||||
}
|
||||
`,
|
||||
},
|
||||
|
||||
// Updating a function that takes an appengine.Context arg.
|
||||
{
|
||||
Name: "ae.1",
|
||||
In: `package foo
|
||||
|
||||
import (
|
||||
"appengine"
|
||||
)
|
||||
|
||||
func LogSomething(c2 appengine.Context) {
|
||||
c2.Warningf("Stand back! I'm going to try science!")
|
||||
}
|
||||
`,
|
||||
Out: `package foo
|
||||
|
||||
import (
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/appengine/log"
|
||||
)
|
||||
|
||||
func LogSomething(c2 context.Context) {
|
||||
log.Warningf(c2, "Stand back! I'm going to try science!")
|
||||
}
|
||||
`,
|
||||
},
|
||||
|
||||
// Less widely used API changes:
|
||||
// - drop maxTasks arg to taskqueue.QueueStats
|
||||
{
|
||||
Name: "ae.2",
|
||||
In: `package foo
|
||||
|
||||
import (
|
||||
"appengine"
|
||||
"appengine/taskqueue"
|
||||
)
|
||||
|
||||
func f(ctx appengine.Context) {
|
||||
stats, err := taskqueue.QueueStats(ctx, []string{"one", "two"}, 0)
|
||||
}
|
||||
`,
|
||||
Out: `package foo
|
||||
|
||||
import (
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/appengine/taskqueue"
|
||||
)
|
||||
|
||||
func f(ctx context.Context) {
|
||||
stats, err := taskqueue.QueueStats(ctx, []string{"one", "two"})
|
||||
}
|
||||
`,
|
||||
},
|
||||
|
||||
// Check that the main "appengine" import will not be dropped
|
||||
// if an appengine.Context -> context.Context change happens
|
||||
// but the appengine package is still referenced.
|
||||
{
|
||||
Name: "ae.3",
|
||||
In: `package foo
|
||||
|
||||
import (
|
||||
"appengine"
|
||||
"io"
|
||||
)
|
||||
|
||||
func f(ctx appengine.Context, w io.Writer) {
|
||||
_ = appengine.IsDevAppServer()
|
||||
}
|
||||
`,
|
||||
Out: `package foo
|
||||
|
||||
import (
|
||||
"golang.org/x/net/context"
|
||||
"google.golang.org/appengine"
|
||||
"io"
|
||||
)
|
||||
|
||||
func f(ctx context.Context, w io.Writer) {
|
||||
_ = appengine.IsDevAppServer()
|
||||
}
|
||||
`,
|
||||
},
|
||||
}
|
|
@ -0,0 +1,848 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"go/token"
|
||||
"os"
|
||||
"path"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type fix struct {
|
||||
name string
|
||||
date string // date that fix was introduced, in YYYY-MM-DD format
|
||||
f func(*ast.File) bool
|
||||
desc string
|
||||
}
|
||||
|
||||
// main runs sort.Sort(byName(fixes)) before printing list of fixes.
|
||||
type byName []fix
|
||||
|
||||
func (f byName) Len() int { return len(f) }
|
||||
func (f byName) Swap(i, j int) { f[i], f[j] = f[j], f[i] }
|
||||
func (f byName) Less(i, j int) bool { return f[i].name < f[j].name }
|
||||
|
||||
// main runs sort.Sort(byDate(fixes)) before applying fixes.
|
||||
type byDate []fix
|
||||
|
||||
func (f byDate) Len() int { return len(f) }
|
||||
func (f byDate) Swap(i, j int) { f[i], f[j] = f[j], f[i] }
|
||||
func (f byDate) Less(i, j int) bool { return f[i].date < f[j].date }
|
||||
|
||||
var fixes []fix
|
||||
|
||||
func register(f fix) {
|
||||
fixes = append(fixes, f)
|
||||
}
|
||||
|
||||
// walk traverses the AST x, calling visit(y) for each node y in the tree but
|
||||
// also with a pointer to each ast.Expr, ast.Stmt, and *ast.BlockStmt,
|
||||
// in a bottom-up traversal.
|
||||
func walk(x interface{}, visit func(interface{})) {
|
||||
walkBeforeAfter(x, nop, visit)
|
||||
}
|
||||
|
||||
func nop(interface{}) {}
|
||||
|
||||
// walkBeforeAfter is like walk but calls before(x) before traversing
|
||||
// x's children and after(x) afterward.
|
||||
func walkBeforeAfter(x interface{}, before, after func(interface{})) {
|
||||
before(x)
|
||||
|
||||
switch n := x.(type) {
|
||||
default:
|
||||
panic(fmt.Errorf("unexpected type %T in walkBeforeAfter", x))
|
||||
|
||||
case nil:
|
||||
|
||||
// pointers to interfaces
|
||||
case *ast.Decl:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case *ast.Expr:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case *ast.Spec:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case *ast.Stmt:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
|
||||
// pointers to struct pointers
|
||||
case **ast.BlockStmt:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case **ast.CallExpr:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case **ast.FieldList:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case **ast.FuncType:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case **ast.Ident:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case **ast.BasicLit:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
|
||||
// pointers to slices
|
||||
case *[]ast.Decl:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case *[]ast.Expr:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case *[]*ast.File:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case *[]*ast.Ident:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case *[]ast.Spec:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
case *[]ast.Stmt:
|
||||
walkBeforeAfter(*n, before, after)
|
||||
|
||||
// These are ordered and grouped to match ../../pkg/go/ast/ast.go
|
||||
case *ast.Field:
|
||||
walkBeforeAfter(&n.Names, before, after)
|
||||
walkBeforeAfter(&n.Type, before, after)
|
||||
walkBeforeAfter(&n.Tag, before, after)
|
||||
case *ast.FieldList:
|
||||
for _, field := range n.List {
|
||||
walkBeforeAfter(field, before, after)
|
||||
}
|
||||
case *ast.BadExpr:
|
||||
case *ast.Ident:
|
||||
case *ast.Ellipsis:
|
||||
walkBeforeAfter(&n.Elt, before, after)
|
||||
case *ast.BasicLit:
|
||||
case *ast.FuncLit:
|
||||
walkBeforeAfter(&n.Type, before, after)
|
||||
walkBeforeAfter(&n.Body, before, after)
|
||||
case *ast.CompositeLit:
|
||||
walkBeforeAfter(&n.Type, before, after)
|
||||
walkBeforeAfter(&n.Elts, before, after)
|
||||
case *ast.ParenExpr:
|
||||
walkBeforeAfter(&n.X, before, after)
|
||||
case *ast.SelectorExpr:
|
||||
walkBeforeAfter(&n.X, before, after)
|
||||
case *ast.IndexExpr:
|
||||
walkBeforeAfter(&n.X, before, after)
|
||||
walkBeforeAfter(&n.Index, before, after)
|
||||
case *ast.SliceExpr:
|
||||
walkBeforeAfter(&n.X, before, after)
|
||||
if n.Low != nil {
|
||||
walkBeforeAfter(&n.Low, before, after)
|
||||
}
|
||||
if n.High != nil {
|
||||
walkBeforeAfter(&n.High, before, after)
|
||||
}
|
||||
case *ast.TypeAssertExpr:
|
||||
walkBeforeAfter(&n.X, before, after)
|
||||
walkBeforeAfter(&n.Type, before, after)
|
||||
case *ast.CallExpr:
|
||||
walkBeforeAfter(&n.Fun, before, after)
|
||||
walkBeforeAfter(&n.Args, before, after)
|
||||
case *ast.StarExpr:
|
||||
walkBeforeAfter(&n.X, before, after)
|
||||
case *ast.UnaryExpr:
|
||||
walkBeforeAfter(&n.X, before, after)
|
||||
case *ast.BinaryExpr:
|
||||
walkBeforeAfter(&n.X, before, after)
|
||||
walkBeforeAfter(&n.Y, before, after)
|
||||
case *ast.KeyValueExpr:
|
||||
walkBeforeAfter(&n.Key, before, after)
|
||||
walkBeforeAfter(&n.Value, before, after)
|
||||
|
||||
case *ast.ArrayType:
|
||||
walkBeforeAfter(&n.Len, before, after)
|
||||
walkBeforeAfter(&n.Elt, before, after)
|
||||
case *ast.StructType:
|
||||
walkBeforeAfter(&n.Fields, before, after)
|
||||
case *ast.FuncType:
|
||||
walkBeforeAfter(&n.Params, before, after)
|
||||
if n.Results != nil {
|
||||
walkBeforeAfter(&n.Results, before, after)
|
||||
}
|
||||
case *ast.InterfaceType:
|
||||
walkBeforeAfter(&n.Methods, before, after)
|
||||
case *ast.MapType:
|
||||
walkBeforeAfter(&n.Key, before, after)
|
||||
walkBeforeAfter(&n.Value, before, after)
|
||||
case *ast.ChanType:
|
||||
walkBeforeAfter(&n.Value, before, after)
|
||||
|
||||
case *ast.BadStmt:
|
||||
case *ast.DeclStmt:
|
||||
walkBeforeAfter(&n.Decl, before, after)
|
||||
case *ast.EmptyStmt:
|
||||
case *ast.LabeledStmt:
|
||||
walkBeforeAfter(&n.Stmt, before, after)
|
||||
case *ast.ExprStmt:
|
||||
walkBeforeAfter(&n.X, before, after)
|
||||
case *ast.SendStmt:
|
||||
walkBeforeAfter(&n.Chan, before, after)
|
||||
walkBeforeAfter(&n.Value, before, after)
|
||||
case *ast.IncDecStmt:
|
||||
walkBeforeAfter(&n.X, before, after)
|
||||
case *ast.AssignStmt:
|
||||
walkBeforeAfter(&n.Lhs, before, after)
|
||||
walkBeforeAfter(&n.Rhs, before, after)
|
||||
case *ast.GoStmt:
|
||||
walkBeforeAfter(&n.Call, before, after)
|
||||
case *ast.DeferStmt:
|
||||
walkBeforeAfter(&n.Call, before, after)
|
||||
case *ast.ReturnStmt:
|
||||
walkBeforeAfter(&n.Results, before, after)
|
||||
case *ast.BranchStmt:
|
||||
case *ast.BlockStmt:
|
||||
walkBeforeAfter(&n.List, before, after)
|
||||
case *ast.IfStmt:
|
||||
walkBeforeAfter(&n.Init, before, after)
|
||||
walkBeforeAfter(&n.Cond, before, after)
|
||||
walkBeforeAfter(&n.Body, before, after)
|
||||
walkBeforeAfter(&n.Else, before, after)
|
||||
case *ast.CaseClause:
|
||||
walkBeforeAfter(&n.List, before, after)
|
||||
walkBeforeAfter(&n.Body, before, after)
|
||||
case *ast.SwitchStmt:
|
||||
walkBeforeAfter(&n.Init, before, after)
|
||||
walkBeforeAfter(&n.Tag, before, after)
|
||||
walkBeforeAfter(&n.Body, before, after)
|
||||
case *ast.TypeSwitchStmt:
|
||||
walkBeforeAfter(&n.Init, before, after)
|
||||
walkBeforeAfter(&n.Assign, before, after)
|
||||
walkBeforeAfter(&n.Body, before, after)
|
||||
case *ast.CommClause:
|
||||
walkBeforeAfter(&n.Comm, before, after)
|
||||
walkBeforeAfter(&n.Body, before, after)
|
||||
case *ast.SelectStmt:
|
||||
walkBeforeAfter(&n.Body, before, after)
|
||||
case *ast.ForStmt:
|
||||
walkBeforeAfter(&n.Init, before, after)
|
||||
walkBeforeAfter(&n.Cond, before, after)
|
||||
walkBeforeAfter(&n.Post, before, after)
|
||||
walkBeforeAfter(&n.Body, before, after)
|
||||
case *ast.RangeStmt:
|
||||
walkBeforeAfter(&n.Key, before, after)
|
||||
walkBeforeAfter(&n.Value, before, after)
|
||||
walkBeforeAfter(&n.X, before, after)
|
||||
walkBeforeAfter(&n.Body, before, after)
|
||||
|
||||
case *ast.ImportSpec:
|
||||
case *ast.ValueSpec:
|
||||
walkBeforeAfter(&n.Type, before, after)
|
||||
walkBeforeAfter(&n.Values, before, after)
|
||||
walkBeforeAfter(&n.Names, before, after)
|
||||
case *ast.TypeSpec:
|
||||
walkBeforeAfter(&n.Type, before, after)
|
||||
|
||||
case *ast.BadDecl:
|
||||
case *ast.GenDecl:
|
||||
walkBeforeAfter(&n.Specs, before, after)
|
||||
case *ast.FuncDecl:
|
||||
if n.Recv != nil {
|
||||
walkBeforeAfter(&n.Recv, before, after)
|
||||
}
|
||||
walkBeforeAfter(&n.Type, before, after)
|
||||
if n.Body != nil {
|
||||
walkBeforeAfter(&n.Body, before, after)
|
||||
}
|
||||
|
||||
case *ast.File:
|
||||
walkBeforeAfter(&n.Decls, before, after)
|
||||
|
||||
case *ast.Package:
|
||||
walkBeforeAfter(&n.Files, before, after)
|
||||
|
||||
case []*ast.File:
|
||||
for i := range n {
|
||||
walkBeforeAfter(&n[i], before, after)
|
||||
}
|
||||
case []ast.Decl:
|
||||
for i := range n {
|
||||
walkBeforeAfter(&n[i], before, after)
|
||||
}
|
||||
case []ast.Expr:
|
||||
for i := range n {
|
||||
walkBeforeAfter(&n[i], before, after)
|
||||
}
|
||||
case []*ast.Ident:
|
||||
for i := range n {
|
||||
walkBeforeAfter(&n[i], before, after)
|
||||
}
|
||||
case []ast.Stmt:
|
||||
for i := range n {
|
||||
walkBeforeAfter(&n[i], before, after)
|
||||
}
|
||||
case []ast.Spec:
|
||||
for i := range n {
|
||||
walkBeforeAfter(&n[i], before, after)
|
||||
}
|
||||
}
|
||||
after(x)
|
||||
}
|
||||
|
||||
// imports returns true if f imports path.
|
||||
func imports(f *ast.File, path string) bool {
|
||||
return importSpec(f, path) != nil
|
||||
}
|
||||
|
||||
// importSpec returns the import spec if f imports path,
|
||||
// or nil otherwise.
|
||||
func importSpec(f *ast.File, path string) *ast.ImportSpec {
|
||||
for _, s := range f.Imports {
|
||||
if importPath(s) == path {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// importPath returns the unquoted import path of s,
|
||||
// or "" if the path is not properly quoted.
|
||||
func importPath(s *ast.ImportSpec) string {
|
||||
t, err := strconv.Unquote(s.Path.Value)
|
||||
if err == nil {
|
||||
return t
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// declImports reports whether gen contains an import of path.
|
||||
func declImports(gen *ast.GenDecl, path string) bool {
|
||||
if gen.Tok != token.IMPORT {
|
||||
return false
|
||||
}
|
||||
for _, spec := range gen.Specs {
|
||||
impspec := spec.(*ast.ImportSpec)
|
||||
if importPath(impspec) == path {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isPkgDot returns true if t is the expression "pkg.name"
|
||||
// where pkg is an imported identifier.
|
||||
func isPkgDot(t ast.Expr, pkg, name string) bool {
|
||||
sel, ok := t.(*ast.SelectorExpr)
|
||||
return ok && isTopName(sel.X, pkg) && sel.Sel.String() == name
|
||||
}
|
||||
|
||||
// isPtrPkgDot returns true if f is the expression "*pkg.name"
|
||||
// where pkg is an imported identifier.
|
||||
func isPtrPkgDot(t ast.Expr, pkg, name string) bool {
|
||||
ptr, ok := t.(*ast.StarExpr)
|
||||
return ok && isPkgDot(ptr.X, pkg, name)
|
||||
}
|
||||
|
||||
// isTopName returns true if n is a top-level unresolved identifier with the given name.
|
||||
func isTopName(n ast.Expr, name string) bool {
|
||||
id, ok := n.(*ast.Ident)
|
||||
return ok && id.Name == name && id.Obj == nil
|
||||
}
|
||||
|
||||
// isName returns true if n is an identifier with the given name.
|
||||
func isName(n ast.Expr, name string) bool {
|
||||
id, ok := n.(*ast.Ident)
|
||||
return ok && id.String() == name
|
||||
}
|
||||
|
||||
// isCall returns true if t is a call to pkg.name.
|
||||
func isCall(t ast.Expr, pkg, name string) bool {
|
||||
call, ok := t.(*ast.CallExpr)
|
||||
return ok && isPkgDot(call.Fun, pkg, name)
|
||||
}
|
||||
|
||||
// If n is an *ast.Ident, isIdent returns it; otherwise isIdent returns nil.
|
||||
func isIdent(n interface{}) *ast.Ident {
|
||||
id, _ := n.(*ast.Ident)
|
||||
return id
|
||||
}
|
||||
|
||||
// refersTo returns true if n is a reference to the same object as x.
|
||||
func refersTo(n ast.Node, x *ast.Ident) bool {
|
||||
id, ok := n.(*ast.Ident)
|
||||
// The test of id.Name == x.Name handles top-level unresolved
|
||||
// identifiers, which all have Obj == nil.
|
||||
return ok && id.Obj == x.Obj && id.Name == x.Name
|
||||
}
|
||||
|
||||
// isBlank returns true if n is the blank identifier.
|
||||
func isBlank(n ast.Expr) bool {
|
||||
return isName(n, "_")
|
||||
}
|
||||
|
||||
// isEmptyString returns true if n is an empty string literal.
|
||||
func isEmptyString(n ast.Expr) bool {
|
||||
lit, ok := n.(*ast.BasicLit)
|
||||
return ok && lit.Kind == token.STRING && len(lit.Value) == 2
|
||||
}
|
||||
|
||||
func warn(pos token.Pos, msg string, args ...interface{}) {
|
||||
if pos.IsValid() {
|
||||
msg = "%s: " + msg
|
||||
arg1 := []interface{}{fset.Position(pos).String()}
|
||||
args = append(arg1, args...)
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, msg+"\n", args...)
|
||||
}
|
||||
|
||||
// countUses returns the number of uses of the identifier x in scope.
|
||||
func countUses(x *ast.Ident, scope []ast.Stmt) int {
|
||||
count := 0
|
||||
ff := func(n interface{}) {
|
||||
if n, ok := n.(ast.Node); ok && refersTo(n, x) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
for _, n := range scope {
|
||||
walk(n, ff)
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
// rewriteUses replaces all uses of the identifier x and !x in scope
|
||||
// with f(x.Pos()) and fnot(x.Pos()).
|
||||
func rewriteUses(x *ast.Ident, f, fnot func(token.Pos) ast.Expr, scope []ast.Stmt) {
|
||||
var lastF ast.Expr
|
||||
ff := func(n interface{}) {
|
||||
ptr, ok := n.(*ast.Expr)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
nn := *ptr
|
||||
|
||||
// The child node was just walked and possibly replaced.
|
||||
// If it was replaced and this is a negation, replace with fnot(p).
|
||||
not, ok := nn.(*ast.UnaryExpr)
|
||||
if ok && not.Op == token.NOT && not.X == lastF {
|
||||
*ptr = fnot(nn.Pos())
|
||||
return
|
||||
}
|
||||
if refersTo(nn, x) {
|
||||
lastF = f(nn.Pos())
|
||||
*ptr = lastF
|
||||
}
|
||||
}
|
||||
for _, n := range scope {
|
||||
walk(n, ff)
|
||||
}
|
||||
}
|
||||
|
||||
// assignsTo returns true if any of the code in scope assigns to or takes the address of x.
|
||||
func assignsTo(x *ast.Ident, scope []ast.Stmt) bool {
|
||||
assigned := false
|
||||
ff := func(n interface{}) {
|
||||
if assigned {
|
||||
return
|
||||
}
|
||||
switch n := n.(type) {
|
||||
case *ast.UnaryExpr:
|
||||
// use of &x
|
||||
if n.Op == token.AND && refersTo(n.X, x) {
|
||||
assigned = true
|
||||
return
|
||||
}
|
||||
case *ast.AssignStmt:
|
||||
for _, l := range n.Lhs {
|
||||
if refersTo(l, x) {
|
||||
assigned = true
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, n := range scope {
|
||||
if assigned {
|
||||
break
|
||||
}
|
||||
walk(n, ff)
|
||||
}
|
||||
return assigned
|
||||
}
|
||||
|
||||
// newPkgDot returns an ast.Expr referring to "pkg.name" at position pos.
|
||||
func newPkgDot(pos token.Pos, pkg, name string) ast.Expr {
|
||||
return &ast.SelectorExpr{
|
||||
X: &ast.Ident{
|
||||
NamePos: pos,
|
||||
Name: pkg,
|
||||
},
|
||||
Sel: &ast.Ident{
|
||||
NamePos: pos,
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// renameTop renames all references to the top-level name old.
|
||||
// It returns true if it makes any changes.
|
||||
func renameTop(f *ast.File, old, new string) bool {
|
||||
var fixed bool
|
||||
|
||||
// Rename any conflicting imports
|
||||
// (assuming package name is last element of path).
|
||||
for _, s := range f.Imports {
|
||||
if s.Name != nil {
|
||||
if s.Name.Name == old {
|
||||
s.Name.Name = new
|
||||
fixed = true
|
||||
}
|
||||
} else {
|
||||
_, thisName := path.Split(importPath(s))
|
||||
if thisName == old {
|
||||
s.Name = ast.NewIdent(new)
|
||||
fixed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rename any top-level declarations.
|
||||
for _, d := range f.Decls {
|
||||
switch d := d.(type) {
|
||||
case *ast.FuncDecl:
|
||||
if d.Recv == nil && d.Name.Name == old {
|
||||
d.Name.Name = new
|
||||
d.Name.Obj.Name = new
|
||||
fixed = true
|
||||
}
|
||||
case *ast.GenDecl:
|
||||
for _, s := range d.Specs {
|
||||
switch s := s.(type) {
|
||||
case *ast.TypeSpec:
|
||||
if s.Name.Name == old {
|
||||
s.Name.Name = new
|
||||
s.Name.Obj.Name = new
|
||||
fixed = true
|
||||
}
|
||||
case *ast.ValueSpec:
|
||||
for _, n := range s.Names {
|
||||
if n.Name == old {
|
||||
n.Name = new
|
||||
n.Obj.Name = new
|
||||
fixed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Rename top-level old to new, both unresolved names
|
||||
// (probably defined in another file) and names that resolve
|
||||
// to a declaration we renamed.
|
||||
walk(f, func(n interface{}) {
|
||||
id, ok := n.(*ast.Ident)
|
||||
if ok && isTopName(id, old) {
|
||||
id.Name = new
|
||||
fixed = true
|
||||
}
|
||||
if ok && id.Obj != nil && id.Name == old && id.Obj.Name == new {
|
||||
id.Name = id.Obj.Name
|
||||
fixed = true
|
||||
}
|
||||
})
|
||||
|
||||
return fixed
|
||||
}
|
||||
|
||||
// matchLen returns the length of the longest prefix shared by x and y.
|
||||
func matchLen(x, y string) int {
|
||||
i := 0
|
||||
for i < len(x) && i < len(y) && x[i] == y[i] {
|
||||
i++
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
// addImport adds the import path to the file f, if absent.
|
||||
func addImport(f *ast.File, ipath string) (added bool) {
|
||||
if imports(f, ipath) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Determine name of import.
|
||||
// Assume added imports follow convention of using last element.
|
||||
_, name := path.Split(ipath)
|
||||
|
||||
// Rename any conflicting top-level references from name to name_.
|
||||
renameTop(f, name, name+"_")
|
||||
|
||||
newImport := &ast.ImportSpec{
|
||||
Path: &ast.BasicLit{
|
||||
Kind: token.STRING,
|
||||
Value: strconv.Quote(ipath),
|
||||
},
|
||||
}
|
||||
|
||||
// Find an import decl to add to.
|
||||
var (
|
||||
bestMatch = -1
|
||||
lastImport = -1
|
||||
impDecl *ast.GenDecl
|
||||
impIndex = -1
|
||||
)
|
||||
for i, decl := range f.Decls {
|
||||
gen, ok := decl.(*ast.GenDecl)
|
||||
if ok && gen.Tok == token.IMPORT {
|
||||
lastImport = i
|
||||
// Do not add to import "C", to avoid disrupting the
|
||||
// association with its doc comment, breaking cgo.
|
||||
if declImports(gen, "C") {
|
||||
continue
|
||||
}
|
||||
|
||||
// Compute longest shared prefix with imports in this block.
|
||||
for j, spec := range gen.Specs {
|
||||
impspec := spec.(*ast.ImportSpec)
|
||||
n := matchLen(importPath(impspec), ipath)
|
||||
if n > bestMatch {
|
||||
bestMatch = n
|
||||
impDecl = gen
|
||||
impIndex = j
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no import decl found, add one after the last import.
|
||||
if impDecl == nil {
|
||||
impDecl = &ast.GenDecl{
|
||||
Tok: token.IMPORT,
|
||||
}
|
||||
f.Decls = append(f.Decls, nil)
|
||||
copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:])
|
||||
f.Decls[lastImport+1] = impDecl
|
||||
}
|
||||
|
||||
// Ensure the import decl has parentheses, if needed.
|
||||
if len(impDecl.Specs) > 0 && !impDecl.Lparen.IsValid() {
|
||||
impDecl.Lparen = impDecl.Pos()
|
||||
}
|
||||
|
||||
insertAt := impIndex + 1
|
||||
if insertAt == 0 {
|
||||
insertAt = len(impDecl.Specs)
|
||||
}
|
||||
impDecl.Specs = append(impDecl.Specs, nil)
|
||||
copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:])
|
||||
impDecl.Specs[insertAt] = newImport
|
||||
if insertAt > 0 {
|
||||
// Assign same position as the previous import,
|
||||
// so that the sorter sees it as being in the same block.
|
||||
prev := impDecl.Specs[insertAt-1]
|
||||
newImport.Path.ValuePos = prev.Pos()
|
||||
newImport.EndPos = prev.Pos()
|
||||
}
|
||||
|
||||
f.Imports = append(f.Imports, newImport)
|
||||
return true
|
||||
}
|
||||
|
||||
// deleteImport deletes the import path from the file f, if present.
|
||||
func deleteImport(f *ast.File, path string) (deleted bool) {
|
||||
oldImport := importSpec(f, path)
|
||||
|
||||
// Find the import node that imports path, if any.
|
||||
for i, decl := range f.Decls {
|
||||
gen, ok := decl.(*ast.GenDecl)
|
||||
if !ok || gen.Tok != token.IMPORT {
|
||||
continue
|
||||
}
|
||||
for j, spec := range gen.Specs {
|
||||
impspec := spec.(*ast.ImportSpec)
|
||||
if oldImport != impspec {
|
||||
continue
|
||||
}
|
||||
|
||||
// We found an import spec that imports path.
|
||||
// Delete it.
|
||||
deleted = true
|
||||
copy(gen.Specs[j:], gen.Specs[j+1:])
|
||||
gen.Specs = gen.Specs[:len(gen.Specs)-1]
|
||||
|
||||
// If this was the last import spec in this decl,
|
||||
// delete the decl, too.
|
||||
if len(gen.Specs) == 0 {
|
||||
copy(f.Decls[i:], f.Decls[i+1:])
|
||||
f.Decls = f.Decls[:len(f.Decls)-1]
|
||||
} else if len(gen.Specs) == 1 {
|
||||
gen.Lparen = token.NoPos // drop parens
|
||||
}
|
||||
if j > 0 {
|
||||
// We deleted an entry but now there will be
|
||||
// a blank line-sized hole where the import was.
|
||||
// Close the hole by making the previous
|
||||
// import appear to "end" where this one did.
|
||||
gen.Specs[j-1].(*ast.ImportSpec).EndPos = impspec.End()
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Delete it from f.Imports.
|
||||
for i, imp := range f.Imports {
|
||||
if imp == oldImport {
|
||||
copy(f.Imports[i:], f.Imports[i+1:])
|
||||
f.Imports = f.Imports[:len(f.Imports)-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// rewriteImport rewrites any import of path oldPath to path newPath.
|
||||
func rewriteImport(f *ast.File, oldPath, newPath string) (rewrote bool) {
|
||||
for _, imp := range f.Imports {
|
||||
if importPath(imp) == oldPath {
|
||||
rewrote = true
|
||||
// record old End, because the default is to compute
|
||||
// it using the length of imp.Path.Value.
|
||||
imp.EndPos = imp.End()
|
||||
imp.Path.Value = strconv.Quote(newPath)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func usesImport(f *ast.File, path string) (used bool) {
|
||||
spec := importSpec(f, path)
|
||||
if spec == nil {
|
||||
return
|
||||
}
|
||||
|
||||
name := spec.Name.String()
|
||||
switch name {
|
||||
case "<nil>":
|
||||
// If the package name is not explicitly specified,
|
||||
// make an educated guess. This is not guaranteed to be correct.
|
||||
lastSlash := strings.LastIndex(path, "/")
|
||||
if lastSlash == -1 {
|
||||
name = path
|
||||
} else {
|
||||
name = path[lastSlash+1:]
|
||||
}
|
||||
case "_", ".":
|
||||
// Not sure if this import is used - err on the side of caution.
|
||||
return true
|
||||
}
|
||||
|
||||
walk(f, func(n interface{}) {
|
||||
sel, ok := n.(*ast.SelectorExpr)
|
||||
if ok && isTopName(sel.X, name) {
|
||||
used = true
|
||||
}
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func expr(s string) ast.Expr {
|
||||
x, err := parser.ParseExpr(s)
|
||||
if err != nil {
|
||||
panic("parsing " + s + ": " + err.Error())
|
||||
}
|
||||
// Remove position information to avoid spurious newlines.
|
||||
killPos(reflect.ValueOf(x))
|
||||
return x
|
||||
}
|
||||
|
||||
var posType = reflect.TypeOf(token.Pos(0))
|
||||
|
||||
func killPos(v reflect.Value) {
|
||||
switch v.Kind() {
|
||||
case reflect.Ptr, reflect.Interface:
|
||||
if !v.IsNil() {
|
||||
killPos(v.Elem())
|
||||
}
|
||||
case reflect.Slice:
|
||||
n := v.Len()
|
||||
for i := 0; i < n; i++ {
|
||||
killPos(v.Index(i))
|
||||
}
|
||||
case reflect.Struct:
|
||||
n := v.NumField()
|
||||
for i := 0; i < n; i++ {
|
||||
f := v.Field(i)
|
||||
if f.Type() == posType {
|
||||
f.SetInt(0)
|
||||
continue
|
||||
}
|
||||
killPos(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// A Rename describes a single renaming.
|
||||
type rename struct {
|
||||
OldImport string // only apply rename if this import is present
|
||||
NewImport string // add this import during rewrite
|
||||
Old string // old name: p.T or *p.T
|
||||
New string // new name: p.T or *p.T
|
||||
}
|
||||
|
||||
func renameFix(tab []rename) func(*ast.File) bool {
|
||||
return func(f *ast.File) bool {
|
||||
return renameFixTab(f, tab)
|
||||
}
|
||||
}
|
||||
|
||||
func parseName(s string) (ptr bool, pkg, nam string) {
|
||||
i := strings.Index(s, ".")
|
||||
if i < 0 {
|
||||
panic("parseName: invalid name " + s)
|
||||
}
|
||||
if strings.HasPrefix(s, "*") {
|
||||
ptr = true
|
||||
s = s[1:]
|
||||
i--
|
||||
}
|
||||
pkg = s[:i]
|
||||
nam = s[i+1:]
|
||||
return
|
||||
}
|
||||
|
||||
func renameFixTab(f *ast.File, tab []rename) bool {
|
||||
fixed := false
|
||||
added := map[string]bool{}
|
||||
check := map[string]bool{}
|
||||
for _, t := range tab {
|
||||
if !imports(f, t.OldImport) {
|
||||
continue
|
||||
}
|
||||
optr, opkg, onam := parseName(t.Old)
|
||||
walk(f, func(n interface{}) {
|
||||
np, ok := n.(*ast.Expr)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
x := *np
|
||||
if optr {
|
||||
p, ok := x.(*ast.StarExpr)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
x = p.X
|
||||
}
|
||||
if !isPkgDot(x, opkg, onam) {
|
||||
return
|
||||
}
|
||||
if t.NewImport != "" && !added[t.NewImport] {
|
||||
addImport(f, t.NewImport)
|
||||
added[t.NewImport] = true
|
||||
}
|
||||
*np = expr(t.New)
|
||||
check[t.OldImport] = true
|
||||
fixed = true
|
||||
})
|
||||
}
|
||||
|
||||
for ipath := range check {
|
||||
if !usesImport(f, ipath) {
|
||||
deleteImport(f, ipath)
|
||||
}
|
||||
}
|
||||
return fixed
|
||||
}
|
|
@ -0,0 +1,258 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"flag"
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/format"
|
||||
"go/parser"
|
||||
"go/scanner"
|
||||
"go/token"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var (
|
||||
fset = token.NewFileSet()
|
||||
exitCode = 0
|
||||
)
|
||||
|
||||
var allowedRewrites = flag.String("r", "",
|
||||
"restrict the rewrites to this comma-separated list")
|
||||
|
||||
var forceRewrites = flag.String("force", "",
|
||||
"force these fixes to run even if the code looks updated")
|
||||
|
||||
var allowed, force map[string]bool
|
||||
|
||||
var doDiff = flag.Bool("diff", false, "display diffs instead of rewriting files")
|
||||
|
||||
// enable for debugging fix failures
|
||||
const debug = false // display incorrectly reformatted source and exit
|
||||
|
||||
func usage() {
|
||||
fmt.Fprintf(os.Stderr, "usage: aefix [-diff] [-r fixname,...] [-force fixname,...] [path ...]\n")
|
||||
flag.PrintDefaults()
|
||||
fmt.Fprintf(os.Stderr, "\nAvailable rewrites are:\n")
|
||||
sort.Sort(byName(fixes))
|
||||
for _, f := range fixes {
|
||||
fmt.Fprintf(os.Stderr, "\n%s\n", f.name)
|
||||
desc := strings.TrimSpace(f.desc)
|
||||
desc = strings.Replace(desc, "\n", "\n\t", -1)
|
||||
fmt.Fprintf(os.Stderr, "\t%s\n", desc)
|
||||
}
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
func main() {
|
||||
flag.Usage = usage
|
||||
flag.Parse()
|
||||
|
||||
sort.Sort(byDate(fixes))
|
||||
|
||||
if *allowedRewrites != "" {
|
||||
allowed = make(map[string]bool)
|
||||
for _, f := range strings.Split(*allowedRewrites, ",") {
|
||||
allowed[f] = true
|
||||
}
|
||||
}
|
||||
|
||||
if *forceRewrites != "" {
|
||||
force = make(map[string]bool)
|
||||
for _, f := range strings.Split(*forceRewrites, ",") {
|
||||
force[f] = true
|
||||
}
|
||||
}
|
||||
|
||||
if flag.NArg() == 0 {
|
||||
if err := processFile("standard input", true); err != nil {
|
||||
report(err)
|
||||
}
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
for i := 0; i < flag.NArg(); i++ {
|
||||
path := flag.Arg(i)
|
||||
switch dir, err := os.Stat(path); {
|
||||
case err != nil:
|
||||
report(err)
|
||||
case dir.IsDir():
|
||||
walkDir(path)
|
||||
default:
|
||||
if err := processFile(path, false); err != nil {
|
||||
report(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
|
||||
const parserMode = parser.ParseComments
|
||||
|
||||
func gofmtFile(f *ast.File) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := format.Node(&buf, fset, f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func processFile(filename string, useStdin bool) error {
|
||||
var f *os.File
|
||||
var err error
|
||||
var fixlog bytes.Buffer
|
||||
|
||||
if useStdin {
|
||||
f = os.Stdin
|
||||
} else {
|
||||
f, err = os.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
}
|
||||
|
||||
src, err := ioutil.ReadAll(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
file, err := parser.ParseFile(fset, filename, src, parserMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Apply all fixes to file.
|
||||
newFile := file
|
||||
fixed := false
|
||||
for _, fix := range fixes {
|
||||
if allowed != nil && !allowed[fix.name] {
|
||||
continue
|
||||
}
|
||||
if fix.f(newFile) {
|
||||
fixed = true
|
||||
fmt.Fprintf(&fixlog, " %s", fix.name)
|
||||
|
||||
// AST changed.
|
||||
// Print and parse, to update any missing scoping
|
||||
// or position information for subsequent fixers.
|
||||
newSrc, err := gofmtFile(newFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newFile, err = parser.ParseFile(fset, filename, newSrc, parserMode)
|
||||
if err != nil {
|
||||
if debug {
|
||||
fmt.Printf("%s", newSrc)
|
||||
report(err)
|
||||
os.Exit(exitCode)
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
if !fixed {
|
||||
return nil
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "%s: fixed %s\n", filename, fixlog.String()[1:])
|
||||
|
||||
// Print AST. We did that after each fix, so this appears
|
||||
// redundant, but it is necessary to generate gofmt-compatible
|
||||
// source code in a few cases. The official gofmt style is the
|
||||
// output of the printer run on a standard AST generated by the parser,
|
||||
// but the source we generated inside the loop above is the
|
||||
// output of the printer run on a mangled AST generated by a fixer.
|
||||
newSrc, err := gofmtFile(newFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if *doDiff {
|
||||
data, err := diff(src, newSrc)
|
||||
if err != nil {
|
||||
return fmt.Errorf("computing diff: %s", err)
|
||||
}
|
||||
fmt.Printf("diff %s fixed/%s\n", filename, filename)
|
||||
os.Stdout.Write(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
if useStdin {
|
||||
os.Stdout.Write(newSrc)
|
||||
return nil
|
||||
}
|
||||
|
||||
return ioutil.WriteFile(f.Name(), newSrc, 0)
|
||||
}
|
||||
|
||||
var gofmtBuf bytes.Buffer
|
||||
|
||||
func gofmt(n interface{}) string {
|
||||
gofmtBuf.Reset()
|
||||
if err := format.Node(&gofmtBuf, fset, n); err != nil {
|
||||
return "<" + err.Error() + ">"
|
||||
}
|
||||
return gofmtBuf.String()
|
||||
}
|
||||
|
||||
func report(err error) {
|
||||
scanner.PrintError(os.Stderr, err)
|
||||
exitCode = 2
|
||||
}
|
||||
|
||||
func walkDir(path string) {
|
||||
filepath.Walk(path, visitFile)
|
||||
}
|
||||
|
||||
func visitFile(path string, f os.FileInfo, err error) error {
|
||||
if err == nil && isGoFile(f) {
|
||||
err = processFile(path, false)
|
||||
}
|
||||
if err != nil {
|
||||
report(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isGoFile(f os.FileInfo) bool {
|
||||
// ignore non-Go files
|
||||
name := f.Name()
|
||||
return !f.IsDir() && !strings.HasPrefix(name, ".") && strings.HasSuffix(name, ".go")
|
||||
}
|
||||
|
||||
func diff(b1, b2 []byte) (data []byte, err error) {
|
||||
f1, err := ioutil.TempFile("", "go-fix")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer os.Remove(f1.Name())
|
||||
defer f1.Close()
|
||||
|
||||
f2, err := ioutil.TempFile("", "go-fix")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer os.Remove(f2.Name())
|
||||
defer f2.Close()
|
||||
|
||||
f1.Write(b1)
|
||||
f2.Write(b2)
|
||||
|
||||
data, err = exec.Command("diff", "-u", f1.Name(), f2.Name()).CombinedOutput()
|
||||
if len(data) > 0 {
|
||||
// diff exits with a non-zero status when the files don't match.
|
||||
// Ignore that failure as long as we get output.
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
|
@ -0,0 +1,129 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"go/ast"
|
||||
"go/parser"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type testCase struct {
|
||||
Name string
|
||||
Fn func(*ast.File) bool
|
||||
In string
|
||||
Out string
|
||||
}
|
||||
|
||||
var testCases []testCase
|
||||
|
||||
func addTestCases(t []testCase, fn func(*ast.File) bool) {
|
||||
// Fill in fn to avoid repetition in definitions.
|
||||
if fn != nil {
|
||||
for i := range t {
|
||||
if t[i].Fn == nil {
|
||||
t[i].Fn = fn
|
||||
}
|
||||
}
|
||||
}
|
||||
testCases = append(testCases, t...)
|
||||
}
|
||||
|
||||
func fnop(*ast.File) bool { return false }
|
||||
|
||||
func parseFixPrint(t *testing.T, fn func(*ast.File) bool, desc, in string, mustBeGofmt bool) (out string, fixed, ok bool) {
|
||||
file, err := parser.ParseFile(fset, desc, in, parserMode)
|
||||
if err != nil {
|
||||
t.Errorf("%s: parsing: %v", desc, err)
|
||||
return
|
||||
}
|
||||
|
||||
outb, err := gofmtFile(file)
|
||||
if err != nil {
|
||||
t.Errorf("%s: printing: %v", desc, err)
|
||||
return
|
||||
}
|
||||
if s := string(outb); in != s && mustBeGofmt {
|
||||
t.Errorf("%s: not gofmt-formatted.\n--- %s\n%s\n--- %s | gofmt\n%s",
|
||||
desc, desc, in, desc, s)
|
||||
tdiff(t, in, s)
|
||||
return
|
||||
}
|
||||
|
||||
if fn == nil {
|
||||
for _, fix := range fixes {
|
||||
if fix.f(file) {
|
||||
fixed = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
fixed = fn(file)
|
||||
}
|
||||
|
||||
outb, err = gofmtFile(file)
|
||||
if err != nil {
|
||||
t.Errorf("%s: printing: %v", desc, err)
|
||||
return
|
||||
}
|
||||
|
||||
return string(outb), fixed, true
|
||||
}
|
||||
|
||||
func TestRewrite(t *testing.T) {
|
||||
for _, tt := range testCases {
|
||||
// Apply fix: should get tt.Out.
|
||||
out, fixed, ok := parseFixPrint(t, tt.Fn, tt.Name, tt.In, true)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
// reformat to get printing right
|
||||
out, _, ok = parseFixPrint(t, fnop, tt.Name, out, false)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if out != tt.Out {
|
||||
t.Errorf("%s: incorrect output.\n", tt.Name)
|
||||
if !strings.HasPrefix(tt.Name, "testdata/") {
|
||||
t.Errorf("--- have\n%s\n--- want\n%s", out, tt.Out)
|
||||
}
|
||||
tdiff(t, out, tt.Out)
|
||||
continue
|
||||
}
|
||||
|
||||
if changed := out != tt.In; changed != fixed {
|
||||
t.Errorf("%s: changed=%v != fixed=%v", tt.Name, changed, fixed)
|
||||
continue
|
||||
}
|
||||
|
||||
// Should not change if run again.
|
||||
out2, fixed2, ok := parseFixPrint(t, tt.Fn, tt.Name+" output", out, true)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if fixed2 {
|
||||
t.Errorf("%s: applied fixes during second round", tt.Name)
|
||||
continue
|
||||
}
|
||||
|
||||
if out2 != out {
|
||||
t.Errorf("%s: changed output after second round of fixes.\n--- output after first round\n%s\n--- output after second round\n%s",
|
||||
tt.Name, out, out2)
|
||||
tdiff(t, out, out2)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func tdiff(t *testing.T, a, b string) {
|
||||
data, err := diff([]byte(a), []byte(b))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
t.Error(string(data))
|
||||
}
|
|
@ -0,0 +1,673 @@
|
|||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"go/ast"
|
||||
"go/token"
|
||||
"os"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Partial type checker.
|
||||
//
|
||||
// The fact that it is partial is very important: the input is
|
||||
// an AST and a description of some type information to
|
||||
// assume about one or more packages, but not all the
|
||||
// packages that the program imports. The checker is
|
||||
// expected to do as much as it can with what it has been
|
||||
// given. There is not enough information supplied to do
|
||||
// a full type check, but the type checker is expected to
|
||||
// apply information that can be derived from variable
|
||||
// declarations, function and method returns, and type switches
|
||||
// as far as it can, so that the caller can still tell the types
|
||||
// of expression relevant to a particular fix.
|
||||
//
|
||||
// TODO(rsc,gri): Replace with go/typechecker.
|
||||
// Doing that could be an interesting test case for go/typechecker:
|
||||
// the constraints about working with partial information will
|
||||
// likely exercise it in interesting ways. The ideal interface would
|
||||
// be to pass typecheck a map from importpath to package API text
|
||||
// (Go source code), but for now we use data structures (TypeConfig, Type).
|
||||
//
|
||||
// The strings mostly use gofmt form.
|
||||
//
|
||||
// A Field or FieldList has as its type a comma-separated list
|
||||
// of the types of the fields. For example, the field list
|
||||
// x, y, z int
|
||||
// has type "int, int, int".
|
||||
|
||||
// The prefix "type " is the type of a type.
|
||||
// For example, given
|
||||
// var x int
|
||||
// type T int
|
||||
// x's type is "int" but T's type is "type int".
|
||||
// mkType inserts the "type " prefix.
|
||||
// getType removes it.
|
||||
// isType tests for it.
|
||||
|
||||
func mkType(t string) string {
|
||||
return "type " + t
|
||||
}
|
||||
|
||||
func getType(t string) string {
|
||||
if !isType(t) {
|
||||
return ""
|
||||
}
|
||||
return t[len("type "):]
|
||||
}
|
||||
|
||||
func isType(t string) bool {
|
||||
return strings.HasPrefix(t, "type ")
|
||||
}
|
||||
|
||||
// TypeConfig describes the universe of relevant types.
|
||||
// For ease of creation, the types are all referred to by string
|
||||
// name (e.g., "reflect.Value"). TypeByName is the only place
|
||||
// where the strings are resolved.
|
||||
|
||||
type TypeConfig struct {
|
||||
Type map[string]*Type
|
||||
Var map[string]string
|
||||
Func map[string]string
|
||||
}
|
||||
|
||||
// typeof returns the type of the given name, which may be of
|
||||
// the form "x" or "p.X".
|
||||
func (cfg *TypeConfig) typeof(name string) string {
|
||||
if cfg.Var != nil {
|
||||
if t := cfg.Var[name]; t != "" {
|
||||
return t
|
||||
}
|
||||
}
|
||||
if cfg.Func != nil {
|
||||
if t := cfg.Func[name]; t != "" {
|
||||
return "func()" + t
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Type describes the Fields and Methods of a type.
|
||||
// If the field or method cannot be found there, it is next
|
||||
// looked for in the Embed list.
|
||||
type Type struct {
|
||||
Field map[string]string // map field name to type
|
||||
Method map[string]string // map method name to comma-separated return types (should start with "func ")
|
||||
Embed []string // list of types this type embeds (for extra methods)
|
||||
Def string // definition of named type
|
||||
}
|
||||
|
||||
// dot returns the type of "typ.name", making its decision
|
||||
// using the type information in cfg.
|
||||
func (typ *Type) dot(cfg *TypeConfig, name string) string {
|
||||
if typ.Field != nil {
|
||||
if t := typ.Field[name]; t != "" {
|
||||
return t
|
||||
}
|
||||
}
|
||||
if typ.Method != nil {
|
||||
if t := typ.Method[name]; t != "" {
|
||||
return t
|
||||
}
|
||||
}
|
||||
|
||||
for _, e := range typ.Embed {
|
||||
etyp := cfg.Type[e]
|
||||
if etyp != nil {
|
||||
if t := etyp.dot(cfg, name); t != "" {
|
||||
return t
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// typecheck type checks the AST f assuming the information in cfg.
|
||||
// It returns two maps with type information:
|
||||
// typeof maps AST nodes to type information in gofmt string form.
|
||||
// assign maps type strings to lists of expressions that were assigned
|
||||
// to values of another type that were assigned to that type.
|
||||
func typecheck(cfg *TypeConfig, f *ast.File) (typeof map[interface{}]string, assign map[string][]interface{}) {
|
||||
typeof = make(map[interface{}]string)
|
||||
assign = make(map[string][]interface{})
|
||||
cfg1 := &TypeConfig{}
|
||||
*cfg1 = *cfg // make copy so we can add locally
|
||||
copied := false
|
||||
|
||||
// gather function declarations
|
||||
for _, decl := range f.Decls {
|
||||
fn, ok := decl.(*ast.FuncDecl)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
typecheck1(cfg, fn.Type, typeof, assign)
|
||||
t := typeof[fn.Type]
|
||||
if fn.Recv != nil {
|
||||
// The receiver must be a type.
|
||||
rcvr := typeof[fn.Recv]
|
||||
if !isType(rcvr) {
|
||||
if len(fn.Recv.List) != 1 {
|
||||
continue
|
||||
}
|
||||
rcvr = mkType(gofmt(fn.Recv.List[0].Type))
|
||||
typeof[fn.Recv.List[0].Type] = rcvr
|
||||
}
|
||||
rcvr = getType(rcvr)
|
||||
if rcvr != "" && rcvr[0] == '*' {
|
||||
rcvr = rcvr[1:]
|
||||
}
|
||||
typeof[rcvr+"."+fn.Name.Name] = t
|
||||
} else {
|
||||
if isType(t) {
|
||||
t = getType(t)
|
||||
} else {
|
||||
t = gofmt(fn.Type)
|
||||
}
|
||||
typeof[fn.Name] = t
|
||||
|
||||
// Record typeof[fn.Name.Obj] for future references to fn.Name.
|
||||
typeof[fn.Name.Obj] = t
|
||||
}
|
||||
}
|
||||
|
||||
// gather struct declarations
|
||||
for _, decl := range f.Decls {
|
||||
d, ok := decl.(*ast.GenDecl)
|
||||
if ok {
|
||||
for _, s := range d.Specs {
|
||||
switch s := s.(type) {
|
||||
case *ast.TypeSpec:
|
||||
if cfg1.Type[s.Name.Name] != nil {
|
||||
break
|
||||
}
|
||||
if !copied {
|
||||
copied = true
|
||||
// Copy map lazily: it's time.
|
||||
cfg1.Type = make(map[string]*Type)
|
||||
for k, v := range cfg.Type {
|
||||
cfg1.Type[k] = v
|
||||
}
|
||||
}
|
||||
t := &Type{Field: map[string]string{}}
|
||||
cfg1.Type[s.Name.Name] = t
|
||||
switch st := s.Type.(type) {
|
||||
case *ast.StructType:
|
||||
for _, f := range st.Fields.List {
|
||||
for _, n := range f.Names {
|
||||
t.Field[n.Name] = gofmt(f.Type)
|
||||
}
|
||||
}
|
||||
case *ast.ArrayType, *ast.StarExpr, *ast.MapType:
|
||||
t.Def = gofmt(st)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
typecheck1(cfg1, f, typeof, assign)
|
||||
return typeof, assign
|
||||
}
|
||||
|
||||
func makeExprList(a []*ast.Ident) []ast.Expr {
|
||||
var b []ast.Expr
|
||||
for _, x := range a {
|
||||
b = append(b, x)
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
||||
// Typecheck1 is the recursive form of typecheck.
|
||||
// It is like typecheck but adds to the information in typeof
|
||||
// instead of allocating a new map.
|
||||
func typecheck1(cfg *TypeConfig, f interface{}, typeof map[interface{}]string, assign map[string][]interface{}) {
|
||||
// set sets the type of n to typ.
|
||||
// If isDecl is true, n is being declared.
|
||||
set := func(n ast.Expr, typ string, isDecl bool) {
|
||||
if typeof[n] != "" || typ == "" {
|
||||
if typeof[n] != typ {
|
||||
assign[typ] = append(assign[typ], n)
|
||||
}
|
||||
return
|
||||
}
|
||||
typeof[n] = typ
|
||||
|
||||
// If we obtained typ from the declaration of x
|
||||
// propagate the type to all the uses.
|
||||
// The !isDecl case is a cheat here, but it makes
|
||||
// up in some cases for not paying attention to
|
||||
// struct fields. The real type checker will be
|
||||
// more accurate so we won't need the cheat.
|
||||
if id, ok := n.(*ast.Ident); ok && id.Obj != nil && (isDecl || typeof[id.Obj] == "") {
|
||||
typeof[id.Obj] = typ
|
||||
}
|
||||
}
|
||||
|
||||
// Type-check an assignment lhs = rhs.
|
||||
// If isDecl is true, this is := so we can update
|
||||
// the types of the objects that lhs refers to.
|
||||
typecheckAssign := func(lhs, rhs []ast.Expr, isDecl bool) {
|
||||
if len(lhs) > 1 && len(rhs) == 1 {
|
||||
if _, ok := rhs[0].(*ast.CallExpr); ok {
|
||||
t := split(typeof[rhs[0]])
|
||||
// Lists should have same length but may not; pair what can be paired.
|
||||
for i := 0; i < len(lhs) && i < len(t); i++ {
|
||||
set(lhs[i], t[i], isDecl)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if len(lhs) == 1 && len(rhs) == 2 {
|
||||
// x = y, ok
|
||||
rhs = rhs[:1]
|
||||
} else if len(lhs) == 2 && len(rhs) == 1 {
|
||||
// x, ok = y
|
||||
lhs = lhs[:1]
|
||||
}
|
||||
|
||||
// Match as much as we can.
|
||||
for i := 0; i < len(lhs) && i < len(rhs); i++ {
|
||||
x, y := lhs[i], rhs[i]
|
||||
if typeof[y] != "" {
|
||||
set(x, typeof[y], isDecl)
|
||||
} else {
|
||||
set(y, typeof[x], false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
expand := func(s string) string {
|
||||
typ := cfg.Type[s]
|
||||
if typ != nil && typ.Def != "" {
|
||||
return typ.Def
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
// The main type check is a recursive algorithm implemented
|
||||
// by walkBeforeAfter(n, before, after).
|
||||
// Most of it is bottom-up, but in a few places we need
|
||||
// to know the type of the function we are checking.
|
||||
// The before function records that information on
|
||||
// the curfn stack.
|
||||
var curfn []*ast.FuncType
|
||||
|
||||
before := func(n interface{}) {
|
||||
// push function type on stack
|
||||
switch n := n.(type) {
|
||||
case *ast.FuncDecl:
|
||||
curfn = append(curfn, n.Type)
|
||||
case *ast.FuncLit:
|
||||
curfn = append(curfn, n.Type)
|
||||
}
|
||||
}
|
||||
|
||||
// After is the real type checker.
|
||||
after := func(n interface{}) {
|
||||
if n == nil {
|
||||
return
|
||||
}
|
||||
if false && reflect.TypeOf(n).Kind() == reflect.Ptr { // debugging trace
|
||||
defer func() {
|
||||
if t := typeof[n]; t != "" {
|
||||
pos := fset.Position(n.(ast.Node).Pos())
|
||||
fmt.Fprintf(os.Stderr, "%s: typeof[%s] = %s\n", pos, gofmt(n), t)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
switch n := n.(type) {
|
||||
case *ast.FuncDecl, *ast.FuncLit:
|
||||
// pop function type off stack
|
||||
curfn = curfn[:len(curfn)-1]
|
||||
|
||||
case *ast.FuncType:
|
||||
typeof[n] = mkType(joinFunc(split(typeof[n.Params]), split(typeof[n.Results])))
|
||||
|
||||
case *ast.FieldList:
|
||||
// Field list is concatenation of sub-lists.
|
||||
t := ""
|
||||
for _, field := range n.List {
|
||||
if t != "" {
|
||||
t += ", "
|
||||
}
|
||||
t += typeof[field]
|
||||
}
|
||||
typeof[n] = t
|
||||
|
||||
case *ast.Field:
|
||||
// Field is one instance of the type per name.
|
||||
all := ""
|
||||
t := typeof[n.Type]
|
||||
if !isType(t) {
|
||||
// Create a type, because it is typically *T or *p.T
|
||||
// and we might care about that type.
|
||||
t = mkType(gofmt(n.Type))
|
||||
typeof[n.Type] = t
|
||||
}
|
||||
t = getType(t)
|
||||
if len(n.Names) == 0 {
|
||||
all = t
|
||||
} else {
|
||||
for _, id := range n.Names {
|
||||
if all != "" {
|
||||
all += ", "
|
||||
}
|
||||
all += t
|
||||
typeof[id.Obj] = t
|
||||
typeof[id] = t
|
||||
}
|
||||
}
|
||||
typeof[n] = all
|
||||
|
||||
case *ast.ValueSpec:
|
||||
// var declaration. Use type if present.
|
||||
if n.Type != nil {
|
||||
t := typeof[n.Type]
|
||||
if !isType(t) {
|
||||
t = mkType(gofmt(n.Type))
|
||||
typeof[n.Type] = t
|
||||
}
|
||||
t = getType(t)
|
||||
for _, id := range n.Names {
|
||||
set(id, t, true)
|
||||
}
|
||||
}
|
||||
// Now treat same as assignment.
|
||||
typecheckAssign(makeExprList(n.Names), n.Values, true)
|
||||
|
||||
case *ast.AssignStmt:
|
||||
typecheckAssign(n.Lhs, n.Rhs, n.Tok == token.DEFINE)
|
||||
|
||||
case *ast.Ident:
|
||||
// Identifier can take its type from underlying object.
|
||||
if t := typeof[n.Obj]; t != "" {
|
||||
typeof[n] = t
|
||||
}
|
||||
|
||||
case *ast.SelectorExpr:
|
||||
// Field or method.
|
||||
name := n.Sel.Name
|
||||
if t := typeof[n.X]; t != "" {
|
||||
if strings.HasPrefix(t, "*") {
|
||||
t = t[1:] // implicit *
|
||||
}
|
||||
if typ := cfg.Type[t]; typ != nil {
|
||||
if t := typ.dot(cfg, name); t != "" {
|
||||
typeof[n] = t
|
||||
return
|
||||
}
|
||||
}
|
||||
tt := typeof[t+"."+name]
|
||||
if isType(tt) {
|
||||
typeof[n] = getType(tt)
|
||||
return
|
||||
}
|
||||
}
|
||||
// Package selector.
|
||||
if x, ok := n.X.(*ast.Ident); ok && x.Obj == nil {
|
||||
str := x.Name + "." + name
|
||||
if cfg.Type[str] != nil {
|
||||
typeof[n] = mkType(str)
|
||||
return
|
||||
}
|
||||
if t := cfg.typeof(x.Name + "." + name); t != "" {
|
||||
typeof[n] = t
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
case *ast.CallExpr:
|
||||
// make(T) has type T.
|
||||
if isTopName(n.Fun, "make") && len(n.Args) >= 1 {
|
||||
typeof[n] = gofmt(n.Args[0])
|
||||
return
|
||||
}
|
||||
// new(T) has type *T
|
||||
if isTopName(n.Fun, "new") && len(n.Args) == 1 {
|
||||
typeof[n] = "*" + gofmt(n.Args[0])
|
||||
return
|
||||
}
|
||||
// Otherwise, use type of function to determine arguments.
|
||||
t := typeof[n.Fun]
|
||||
in, out := splitFunc(t)
|
||||
if in == nil && out == nil {
|
||||
return
|
||||
}
|
||||
typeof[n] = join(out)
|
||||
for i, arg := range n.Args {
|
||||
if i >= len(in) {
|
||||
break
|
||||
}
|
||||
if typeof[arg] == "" {
|
||||
typeof[arg] = in[i]
|
||||
}
|
||||
}
|
||||
|
||||
case *ast.TypeAssertExpr:
|
||||
// x.(type) has type of x.
|
||||
if n.Type == nil {
|
||||
typeof[n] = typeof[n.X]
|
||||
return
|
||||
}
|
||||
// x.(T) has type T.
|
||||
if t := typeof[n.Type]; isType(t) {
|
||||
typeof[n] = getType(t)
|
||||
} else {
|
||||
typeof[n] = gofmt(n.Type)
|
||||
}
|
||||
|
||||
case *ast.SliceExpr:
|
||||
// x[i:j] has type of x.
|
||||
typeof[n] = typeof[n.X]
|
||||
|
||||
case *ast.IndexExpr:
|
||||
// x[i] has key type of x's type.
|
||||
t := expand(typeof[n.X])
|
||||
if strings.HasPrefix(t, "[") || strings.HasPrefix(t, "map[") {
|
||||
// Lazy: assume there are no nested [] in the array
|
||||
// length or map key type.
|
||||
if i := strings.Index(t, "]"); i >= 0 {
|
||||
typeof[n] = t[i+1:]
|
||||
}
|
||||
}
|
||||
|
||||
case *ast.StarExpr:
|
||||
// *x for x of type *T has type T when x is an expr.
|
||||
// We don't use the result when *x is a type, but
|
||||
// compute it anyway.
|
||||
t := expand(typeof[n.X])
|
||||
if isType(t) {
|
||||
typeof[n] = "type *" + getType(t)
|
||||
} else if strings.HasPrefix(t, "*") {
|
||||
typeof[n] = t[len("*"):]
|
||||
}
|
||||
|
||||
case *ast.UnaryExpr:
|
||||
// &x for x of type T has type *T.
|
||||
t := typeof[n.X]
|
||||
if t != "" && n.Op == token.AND {
|
||||
typeof[n] = "*" + t
|
||||
}
|
||||
|
||||
case *ast.CompositeLit:
|
||||
// T{...} has type T.
|
||||
typeof[n] = gofmt(n.Type)
|
||||
|
||||
case *ast.ParenExpr:
|
||||
// (x) has type of x.
|
||||
typeof[n] = typeof[n.X]
|
||||
|
||||
case *ast.RangeStmt:
|
||||
t := expand(typeof[n.X])
|
||||
if t == "" {
|
||||
return
|
||||
}
|
||||
var key, value string
|
||||
if t == "string" {
|
||||
key, value = "int", "rune"
|
||||
} else if strings.HasPrefix(t, "[") {
|
||||
key = "int"
|
||||
if i := strings.Index(t, "]"); i >= 0 {
|
||||
value = t[i+1:]
|
||||
}
|
||||
} else if strings.HasPrefix(t, "map[") {
|
||||
if i := strings.Index(t, "]"); i >= 0 {
|
||||
key, value = t[4:i], t[i+1:]
|
||||
}
|
||||
}
|
||||
changed := false
|
||||
if n.Key != nil && key != "" {
|
||||
changed = true
|
||||
set(n.Key, key, n.Tok == token.DEFINE)
|
||||
}
|
||||
if n.Value != nil && value != "" {
|
||||
changed = true
|
||||
set(n.Value, value, n.Tok == token.DEFINE)
|
||||
}
|
||||
// Ugly failure of vision: already type-checked body.
|
||||
// Do it again now that we have that type info.
|
||||
if changed {
|
||||
typecheck1(cfg, n.Body, typeof, assign)
|
||||
}
|
||||
|
||||
case *ast.TypeSwitchStmt:
|
||||
// Type of variable changes for each case in type switch,
|
||||
// but go/parser generates just one variable.
|
||||
// Repeat type check for each case with more precise
|
||||
// type information.
|
||||
as, ok := n.Assign.(*ast.AssignStmt)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
varx, ok := as.Lhs[0].(*ast.Ident)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
t := typeof[varx]
|
||||
for _, cas := range n.Body.List {
|
||||
cas := cas.(*ast.CaseClause)
|
||||
if len(cas.List) == 1 {
|
||||
// Variable has specific type only when there is
|
||||
// exactly one type in the case list.
|
||||
if tt := typeof[cas.List[0]]; isType(tt) {
|
||||
tt = getType(tt)
|
||||
typeof[varx] = tt
|
||||
typeof[varx.Obj] = tt
|
||||
typecheck1(cfg, cas.Body, typeof, assign)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Restore t.
|
||||
typeof[varx] = t
|
||||
typeof[varx.Obj] = t
|
||||
|
||||
case *ast.ReturnStmt:
|
||||
if len(curfn) == 0 {
|
||||
// Probably can't happen.
|
||||
return
|
||||
}
|
||||
f := curfn[len(curfn)-1]
|
||||
res := n.Results
|
||||
if f.Results != nil {
|
||||
t := split(typeof[f.Results])
|
||||
for i := 0; i < len(res) && i < len(t); i++ {
|
||||
set(res[i], t[i], false)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
walkBeforeAfter(f, before, after)
|
||||
}
|
||||
|
||||
// Convert between function type strings and lists of types.
|
||||
// Using strings makes this a little harder, but it makes
|
||||
// a lot of the rest of the code easier. This will all go away
|
||||
// when we can use go/typechecker directly.
|
||||
|
||||
// splitFunc splits "func(x,y,z) (a,b,c)" into ["x", "y", "z"] and ["a", "b", "c"].
|
||||
func splitFunc(s string) (in, out []string) {
|
||||
if !strings.HasPrefix(s, "func(") {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
i := len("func(") // index of beginning of 'in' arguments
|
||||
nparen := 0
|
||||
for j := i; j < len(s); j++ {
|
||||
switch s[j] {
|
||||
case '(':
|
||||
nparen++
|
||||
case ')':
|
||||
nparen--
|
||||
if nparen < 0 {
|
||||
// found end of parameter list
|
||||
out := strings.TrimSpace(s[j+1:])
|
||||
if len(out) >= 2 && out[0] == '(' && out[len(out)-1] == ')' {
|
||||
out = out[1 : len(out)-1]
|
||||
}
|
||||
return split(s[i:j]), split(out)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// joinFunc is the inverse of splitFunc.
|
||||
func joinFunc(in, out []string) string {
|
||||
outs := ""
|
||||
if len(out) == 1 {
|
||||
outs = " " + out[0]
|
||||
} else if len(out) > 1 {
|
||||
outs = " (" + join(out) + ")"
|
||||
}
|
||||
return "func(" + join(in) + ")" + outs
|
||||
}
|
||||
|
||||
// split splits "int, float" into ["int", "float"] and splits "" into [].
|
||||
func split(s string) []string {
|
||||
out := []string{}
|
||||
i := 0 // current type being scanned is s[i:j].
|
||||
nparen := 0
|
||||
for j := 0; j < len(s); j++ {
|
||||
switch s[j] {
|
||||
case ' ':
|
||||
if i == j {
|
||||
i++
|
||||
}
|
||||
case '(':
|
||||
nparen++
|
||||
case ')':
|
||||
nparen--
|
||||
if nparen < 0 {
|
||||
// probably can't happen
|
||||
return nil
|
||||
}
|
||||
case ',':
|
||||
if nparen == 0 {
|
||||
if i < j {
|
||||
out = append(out, s[i:j])
|
||||
}
|
||||
i = j + 1
|
||||
}
|
||||
}
|
||||
}
|
||||
if nparen != 0 {
|
||||
// probably can't happen
|
||||
return nil
|
||||
}
|
||||
if i < len(s) {
|
||||
out = append(out, s[i:])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// join is the inverse of split.
|
||||
func join(x []string) string {
|
||||
return strings.Join(x, ", ")
|
||||
}
|
|
@ -0,0 +1,406 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine"
|
||||
"google.golang.org/appengine/internal"
|
||||
pb "google.golang.org/appengine/internal/datastore"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidEntityType is returned when functions like Get or Next are
|
||||
// passed a dst or src argument of invalid type.
|
||||
ErrInvalidEntityType = errors.New("datastore: invalid entity type")
|
||||
// ErrInvalidKey is returned when an invalid key is presented.
|
||||
ErrInvalidKey = errors.New("datastore: invalid key")
|
||||
// ErrNoSuchEntity is returned when no entity was found for a given key.
|
||||
ErrNoSuchEntity = errors.New("datastore: no such entity")
|
||||
)
|
||||
|
||||
// ErrFieldMismatch is returned when a field is to be loaded into a different
|
||||
// type than the one it was stored from, or when a field is missing or
|
||||
// unexported in the destination struct.
|
||||
// StructType is the type of the struct pointed to by the destination argument
|
||||
// passed to Get or to Iterator.Next.
|
||||
type ErrFieldMismatch struct {
|
||||
StructType reflect.Type
|
||||
FieldName string
|
||||
Reason string
|
||||
}
|
||||
|
||||
func (e *ErrFieldMismatch) Error() string {
|
||||
return fmt.Sprintf("datastore: cannot load field %q into a %q: %s",
|
||||
e.FieldName, e.StructType, e.Reason)
|
||||
}
|
||||
|
||||
// protoToKey converts a Reference proto to a *Key.
|
||||
func protoToKey(r *pb.Reference) (k *Key, err error) {
|
||||
appID := r.GetApp()
|
||||
namespace := r.GetNameSpace()
|
||||
for _, e := range r.Path.Element {
|
||||
k = &Key{
|
||||
kind: e.GetType(),
|
||||
stringID: e.GetName(),
|
||||
intID: e.GetId(),
|
||||
parent: k,
|
||||
appID: appID,
|
||||
namespace: namespace,
|
||||
}
|
||||
if !k.valid() {
|
||||
return nil, ErrInvalidKey
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// keyToProto converts a *Key to a Reference proto.
|
||||
func keyToProto(defaultAppID string, k *Key) *pb.Reference {
|
||||
appID := k.appID
|
||||
if appID == "" {
|
||||
appID = defaultAppID
|
||||
}
|
||||
n := 0
|
||||
for i := k; i != nil; i = i.parent {
|
||||
n++
|
||||
}
|
||||
e := make([]*pb.Path_Element, n)
|
||||
for i := k; i != nil; i = i.parent {
|
||||
n--
|
||||
e[n] = &pb.Path_Element{
|
||||
Type: &i.kind,
|
||||
}
|
||||
// At most one of {Name,Id} should be set.
|
||||
// Neither will be set for incomplete keys.
|
||||
if i.stringID != "" {
|
||||
e[n].Name = &i.stringID
|
||||
} else if i.intID != 0 {
|
||||
e[n].Id = &i.intID
|
||||
}
|
||||
}
|
||||
var namespace *string
|
||||
if k.namespace != "" {
|
||||
namespace = proto.String(k.namespace)
|
||||
}
|
||||
return &pb.Reference{
|
||||
App: proto.String(appID),
|
||||
NameSpace: namespace,
|
||||
Path: &pb.Path{
|
||||
Element: e,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// multiKeyToProto is a batch version of keyToProto.
|
||||
func multiKeyToProto(appID string, key []*Key) []*pb.Reference {
|
||||
ret := make([]*pb.Reference, len(key))
|
||||
for i, k := range key {
|
||||
ret[i] = keyToProto(appID, k)
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
// multiValid is a batch version of Key.valid. It returns an error, not a
|
||||
// []bool.
|
||||
func multiValid(key []*Key) error {
|
||||
invalid := false
|
||||
for _, k := range key {
|
||||
if !k.valid() {
|
||||
invalid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !invalid {
|
||||
return nil
|
||||
}
|
||||
err := make(appengine.MultiError, len(key))
|
||||
for i, k := range key {
|
||||
if !k.valid() {
|
||||
err[i] = ErrInvalidKey
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// It's unfortunate that the two semantically equivalent concepts pb.Reference
|
||||
// and pb.PropertyValue_ReferenceValue aren't the same type. For example, the
|
||||
// two have different protobuf field numbers.
|
||||
|
||||
// referenceValueToKey is the same as protoToKey except the input is a
|
||||
// PropertyValue_ReferenceValue instead of a Reference.
|
||||
func referenceValueToKey(r *pb.PropertyValue_ReferenceValue) (k *Key, err error) {
|
||||
appID := r.GetApp()
|
||||
namespace := r.GetNameSpace()
|
||||
for _, e := range r.Pathelement {
|
||||
k = &Key{
|
||||
kind: e.GetType(),
|
||||
stringID: e.GetName(),
|
||||
intID: e.GetId(),
|
||||
parent: k,
|
||||
appID: appID,
|
||||
namespace: namespace,
|
||||
}
|
||||
if !k.valid() {
|
||||
return nil, ErrInvalidKey
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// keyToReferenceValue is the same as keyToProto except the output is a
|
||||
// PropertyValue_ReferenceValue instead of a Reference.
|
||||
func keyToReferenceValue(defaultAppID string, k *Key) *pb.PropertyValue_ReferenceValue {
|
||||
ref := keyToProto(defaultAppID, k)
|
||||
pe := make([]*pb.PropertyValue_ReferenceValue_PathElement, len(ref.Path.Element))
|
||||
for i, e := range ref.Path.Element {
|
||||
pe[i] = &pb.PropertyValue_ReferenceValue_PathElement{
|
||||
Type: e.Type,
|
||||
Id: e.Id,
|
||||
Name: e.Name,
|
||||
}
|
||||
}
|
||||
return &pb.PropertyValue_ReferenceValue{
|
||||
App: ref.App,
|
||||
NameSpace: ref.NameSpace,
|
||||
Pathelement: pe,
|
||||
}
|
||||
}
|
||||
|
||||
type multiArgType int
|
||||
|
||||
const (
|
||||
multiArgTypeInvalid multiArgType = iota
|
||||
multiArgTypePropertyLoadSaver
|
||||
multiArgTypeStruct
|
||||
multiArgTypeStructPtr
|
||||
multiArgTypeInterface
|
||||
)
|
||||
|
||||
// checkMultiArg checks that v has type []S, []*S, []I, or []P, for some struct
|
||||
// type S, for some interface type I, or some non-interface non-pointer type P
|
||||
// such that P or *P implements PropertyLoadSaver.
|
||||
//
|
||||
// It returns what category the slice's elements are, and the reflect.Type
|
||||
// that represents S, I or P.
|
||||
//
|
||||
// As a special case, PropertyList is an invalid type for v.
|
||||
func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
|
||||
if v.Kind() != reflect.Slice {
|
||||
return multiArgTypeInvalid, nil
|
||||
}
|
||||
if v.Type() == typeOfPropertyList {
|
||||
return multiArgTypeInvalid, nil
|
||||
}
|
||||
elemType = v.Type().Elem()
|
||||
if reflect.PtrTo(elemType).Implements(typeOfPropertyLoadSaver) {
|
||||
return multiArgTypePropertyLoadSaver, elemType
|
||||
}
|
||||
switch elemType.Kind() {
|
||||
case reflect.Struct:
|
||||
return multiArgTypeStruct, elemType
|
||||
case reflect.Interface:
|
||||
return multiArgTypeInterface, elemType
|
||||
case reflect.Ptr:
|
||||
elemType = elemType.Elem()
|
||||
if elemType.Kind() == reflect.Struct {
|
||||
return multiArgTypeStructPtr, elemType
|
||||
}
|
||||
}
|
||||
return multiArgTypeInvalid, nil
|
||||
}
|
||||
|
||||
// Get loads the entity stored for k into dst, which must be a struct pointer
|
||||
// or implement PropertyLoadSaver. If there is no such entity for the key, Get
|
||||
// returns ErrNoSuchEntity.
|
||||
//
|
||||
// The values of dst's unmatched struct fields are not modified, and matching
|
||||
// slice-typed fields are not reset before appending to them. In particular, it
|
||||
// is recommended to pass a pointer to a zero valued struct on each Get call.
|
||||
//
|
||||
// ErrFieldMismatch is returned when a field is to be loaded into a different
|
||||
// type than the one it was stored from, or when a field is missing or
|
||||
// unexported in the destination struct. ErrFieldMismatch is only returned if
|
||||
// dst is a struct pointer.
|
||||
func Get(c context.Context, key *Key, dst interface{}) error {
|
||||
if dst == nil { // GetMulti catches nil interface; we need to catch nil ptr here
|
||||
return ErrInvalidEntityType
|
||||
}
|
||||
err := GetMulti(c, []*Key{key}, []interface{}{dst})
|
||||
if me, ok := err.(appengine.MultiError); ok {
|
||||
return me[0]
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GetMulti is a batch version of Get.
|
||||
//
|
||||
// dst must be a []S, []*S, []I or []P, for some struct type S, some interface
|
||||
// type I, or some non-interface non-pointer type P such that P or *P
|
||||
// implements PropertyLoadSaver. If an []I, each element must be a valid dst
|
||||
// for Get: it must be a struct pointer or implement PropertyLoadSaver.
|
||||
//
|
||||
// As a special case, PropertyList is an invalid type for dst, even though a
|
||||
// PropertyList is a slice of structs. It is treated as invalid to avoid being
|
||||
// mistakenly passed when []PropertyList was intended.
|
||||
func GetMulti(c context.Context, key []*Key, dst interface{}) error {
|
||||
v := reflect.ValueOf(dst)
|
||||
multiArgType, _ := checkMultiArg(v)
|
||||
if multiArgType == multiArgTypeInvalid {
|
||||
return errors.New("datastore: dst has invalid type")
|
||||
}
|
||||
if len(key) != v.Len() {
|
||||
return errors.New("datastore: key and dst slices have different length")
|
||||
}
|
||||
if len(key) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := multiValid(key); err != nil {
|
||||
return err
|
||||
}
|
||||
req := &pb.GetRequest{
|
||||
Key: multiKeyToProto(internal.FullyQualifiedAppID(c), key),
|
||||
}
|
||||
res := &pb.GetResponse{}
|
||||
if err := internal.Call(c, "datastore_v3", "Get", req, res); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(key) != len(res.Entity) {
|
||||
return errors.New("datastore: internal error: server returned the wrong number of entities")
|
||||
}
|
||||
multiErr, any := make(appengine.MultiError, len(key)), false
|
||||
for i, e := range res.Entity {
|
||||
if e.Entity == nil {
|
||||
multiErr[i] = ErrNoSuchEntity
|
||||
} else {
|
||||
elem := v.Index(i)
|
||||
if multiArgType == multiArgTypePropertyLoadSaver || multiArgType == multiArgTypeStruct {
|
||||
elem = elem.Addr()
|
||||
}
|
||||
if multiArgType == multiArgTypeStructPtr && elem.IsNil() {
|
||||
elem.Set(reflect.New(elem.Type().Elem()))
|
||||
}
|
||||
multiErr[i] = loadEntity(elem.Interface(), e.Entity)
|
||||
}
|
||||
if multiErr[i] != nil {
|
||||
any = true
|
||||
}
|
||||
}
|
||||
if any {
|
||||
return multiErr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put saves the entity src into the datastore with key k. src must be a struct
|
||||
// pointer or implement PropertyLoadSaver; if a struct pointer then any
|
||||
// unexported fields of that struct will be skipped. If k is an incomplete key,
|
||||
// the returned key will be a unique key generated by the datastore.
|
||||
func Put(c context.Context, key *Key, src interface{}) (*Key, error) {
|
||||
k, err := PutMulti(c, []*Key{key}, []interface{}{src})
|
||||
if err != nil {
|
||||
if me, ok := err.(appengine.MultiError); ok {
|
||||
return nil, me[0]
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return k[0], nil
|
||||
}
|
||||
|
||||
// PutMulti is a batch version of Put.
|
||||
//
|
||||
// src must satisfy the same conditions as the dst argument to GetMulti.
|
||||
func PutMulti(c context.Context, key []*Key, src interface{}) ([]*Key, error) {
|
||||
v := reflect.ValueOf(src)
|
||||
multiArgType, _ := checkMultiArg(v)
|
||||
if multiArgType == multiArgTypeInvalid {
|
||||
return nil, errors.New("datastore: src has invalid type")
|
||||
}
|
||||
if len(key) != v.Len() {
|
||||
return nil, errors.New("datastore: key and src slices have different length")
|
||||
}
|
||||
if len(key) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
appID := internal.FullyQualifiedAppID(c)
|
||||
if err := multiValid(key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req := &pb.PutRequest{}
|
||||
for i := range key {
|
||||
elem := v.Index(i)
|
||||
if multiArgType == multiArgTypePropertyLoadSaver || multiArgType == multiArgTypeStruct {
|
||||
elem = elem.Addr()
|
||||
}
|
||||
sProto, err := saveEntity(appID, key[i], elem.Interface())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Entity = append(req.Entity, sProto)
|
||||
}
|
||||
res := &pb.PutResponse{}
|
||||
if err := internal.Call(c, "datastore_v3", "Put", req, res); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(key) != len(res.Key) {
|
||||
return nil, errors.New("datastore: internal error: server returned the wrong number of keys")
|
||||
}
|
||||
ret := make([]*Key, len(key))
|
||||
for i := range ret {
|
||||
var err error
|
||||
ret[i], err = protoToKey(res.Key[i])
|
||||
if err != nil || ret[i].Incomplete() {
|
||||
return nil, errors.New("datastore: internal error: server returned an invalid key")
|
||||
}
|
||||
}
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
// Delete deletes the entity for the given key.
|
||||
func Delete(c context.Context, key *Key) error {
|
||||
err := DeleteMulti(c, []*Key{key})
|
||||
if me, ok := err.(appengine.MultiError); ok {
|
||||
return me[0]
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// DeleteMulti is a batch version of Delete.
|
||||
func DeleteMulti(c context.Context, key []*Key) error {
|
||||
if len(key) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := multiValid(key); err != nil {
|
||||
return err
|
||||
}
|
||||
req := &pb.DeleteRequest{
|
||||
Key: multiKeyToProto(internal.FullyQualifiedAppID(c), key),
|
||||
}
|
||||
res := &pb.DeleteResponse{}
|
||||
return internal.Call(c, "datastore_v3", "Delete", req, res)
|
||||
}
|
||||
|
||||
func namespaceMod(m proto.Message, namespace string) {
|
||||
// pb.Query is the only type that has a name_space field.
|
||||
// All other namespace support in datastore is in the keys.
|
||||
switch m := m.(type) {
|
||||
case *pb.Query:
|
||||
if m.NameSpace == nil {
|
||||
m.NameSpace = &namespace
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
internal.NamespaceMods["datastore_v3"] = namespaceMod
|
||||
internal.RegisterErrorCodeMap("datastore_v3", pb.Error_ErrorCode_name)
|
||||
internal.RegisterTimeoutErrorCode("datastore_v3", int32(pb.Error_TIMEOUT))
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,351 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package datastore provides a client for App Engine's datastore service.
|
||||
|
||||
|
||||
Basic Operations
|
||||
|
||||
Entities are the unit of storage and are associated with a key. A key
|
||||
consists of an optional parent key, a string application ID, a string kind
|
||||
(also known as an entity type), and either a StringID or an IntID. A
|
||||
StringID is also known as an entity name or key name.
|
||||
|
||||
It is valid to create a key with a zero StringID and a zero IntID; this is
|
||||
called an incomplete key, and does not refer to any saved entity. Putting an
|
||||
entity into the datastore under an incomplete key will cause a unique key
|
||||
to be generated for that entity, with a non-zero IntID.
|
||||
|
||||
An entity's contents are a mapping from case-sensitive field names to values.
|
||||
Valid value types are:
|
||||
- signed integers (int, int8, int16, int32 and int64),
|
||||
- bool,
|
||||
- string,
|
||||
- float32 and float64,
|
||||
- []byte (up to 1 megabyte in length),
|
||||
- any type whose underlying type is one of the above predeclared types,
|
||||
- ByteString,
|
||||
- *Key,
|
||||
- time.Time (stored with microsecond precision),
|
||||
- appengine.BlobKey,
|
||||
- appengine.GeoPoint,
|
||||
- structs whose fields are all valid value types,
|
||||
- slices of any of the above.
|
||||
|
||||
Slices of structs are valid, as are structs that contain slices. However, if
|
||||
one struct contains another, then at most one of those can be repeated. This
|
||||
disqualifies recursively defined struct types: any struct T that (directly or
|
||||
indirectly) contains a []T.
|
||||
|
||||
The Get and Put functions load and save an entity's contents. An entity's
|
||||
contents are typically represented by a struct pointer.
|
||||
|
||||
Example code:
|
||||
|
||||
type Entity struct {
|
||||
Value string
|
||||
}
|
||||
|
||||
func handle(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := appengine.NewContext(r)
|
||||
|
||||
k := datastore.NewKey(ctx, "Entity", "stringID", 0, nil)
|
||||
e := new(Entity)
|
||||
if err := datastore.Get(ctx, k, e); err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
|
||||
old := e.Value
|
||||
e.Value = r.URL.Path
|
||||
|
||||
if _, err := datastore.Put(ctx, k, e); err != nil {
|
||||
http.Error(w, err.Error(), 500)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprintf(w, "old=%q\nnew=%q\n", old, e.Value)
|
||||
}
|
||||
|
||||
GetMulti, PutMulti and DeleteMulti are batch versions of the Get, Put and
|
||||
Delete functions. They take a []*Key instead of a *Key, and may return an
|
||||
appengine.MultiError when encountering partial failure.
|
||||
|
||||
|
||||
Properties
|
||||
|
||||
An entity's contents can be represented by a variety of types. These are
|
||||
typically struct pointers, but can also be any type that implements the
|
||||
PropertyLoadSaver interface. If using a struct pointer, you do not have to
|
||||
explicitly implement the PropertyLoadSaver interface; the datastore will
|
||||
automatically convert via reflection. If a struct pointer does implement that
|
||||
interface then those methods will be used in preference to the default
|
||||
behavior for struct pointers. Struct pointers are more strongly typed and are
|
||||
easier to use; PropertyLoadSavers are more flexible.
|
||||
|
||||
The actual types passed do not have to match between Get and Put calls or even
|
||||
across different App Engine requests. It is valid to put a *PropertyList and
|
||||
get that same entity as a *myStruct, or put a *myStruct0 and get a *myStruct1.
|
||||
Conceptually, any entity is saved as a sequence of properties, and is loaded
|
||||
into the destination value on a property-by-property basis. When loading into
|
||||
a struct pointer, an entity that cannot be completely represented (such as a
|
||||
missing field) will result in an ErrFieldMismatch error but it is up to the
|
||||
caller whether this error is fatal, recoverable or ignorable.
|
||||
|
||||
By default, for struct pointers, all properties are potentially indexed, and
|
||||
the property name is the same as the field name (and hence must start with an
|
||||
upper case letter). Fields may have a `datastore:"name,options"` tag. The tag
|
||||
name is the property name, which must be one or more valid Go identifiers
|
||||
joined by ".", but may start with a lower case letter. An empty tag name means
|
||||
to just use the field name. A "-" tag name means that the datastore will
|
||||
ignore that field. If options is "noindex" then the field will not be indexed.
|
||||
If the options is "" then the comma may be omitted. There are no other
|
||||
recognized options.
|
||||
|
||||
Fields (except for []byte) are indexed by default. Strings longer than 1500
|
||||
bytes cannot be indexed; fields used to store long strings should be
|
||||
tagged with "noindex". Similarly, ByteStrings longer than 1500 bytes cannot be
|
||||
indexed.
|
||||
|
||||
Example code:
|
||||
|
||||
// A and B are renamed to a and b.
|
||||
// A, C and J are not indexed.
|
||||
// D's tag is equivalent to having no tag at all (E).
|
||||
// I is ignored entirely by the datastore.
|
||||
// J has tag information for both the datastore and json packages.
|
||||
type TaggedStruct struct {
|
||||
A int `datastore:"a,noindex"`
|
||||
B int `datastore:"b"`
|
||||
C int `datastore:",noindex"`
|
||||
D int `datastore:""`
|
||||
E int
|
||||
I int `datastore:"-"`
|
||||
J int `datastore:",noindex" json:"j"`
|
||||
}
|
||||
|
||||
|
||||
Structured Properties
|
||||
|
||||
If the struct pointed to contains other structs, then the nested or embedded
|
||||
structs are flattened. For example, given these definitions:
|
||||
|
||||
type Inner1 struct {
|
||||
W int32
|
||||
X string
|
||||
}
|
||||
|
||||
type Inner2 struct {
|
||||
Y float64
|
||||
}
|
||||
|
||||
type Inner3 struct {
|
||||
Z bool
|
||||
}
|
||||
|
||||
type Outer struct {
|
||||
A int16
|
||||
I []Inner1
|
||||
J Inner2
|
||||
Inner3
|
||||
}
|
||||
|
||||
then an Outer's properties would be equivalent to those of:
|
||||
|
||||
type OuterEquivalent struct {
|
||||
A int16
|
||||
IDotW []int32 `datastore:"I.W"`
|
||||
IDotX []string `datastore:"I.X"`
|
||||
JDotY float64 `datastore:"J.Y"`
|
||||
Z bool
|
||||
}
|
||||
|
||||
If Outer's embedded Inner3 field was tagged as `datastore:"Foo"` then the
|
||||
equivalent field would instead be: FooDotZ bool `datastore:"Foo.Z"`.
|
||||
|
||||
If an outer struct is tagged "noindex" then all of its implicit flattened
|
||||
fields are effectively "noindex".
|
||||
|
||||
|
||||
The PropertyLoadSaver Interface
|
||||
|
||||
An entity's contents can also be represented by any type that implements the
|
||||
PropertyLoadSaver interface. This type may be a struct pointer, but it does
|
||||
not have to be. The datastore package will call Load when getting the entity's
|
||||
contents, and Save when putting the entity's contents.
|
||||
Possible uses include deriving non-stored fields, verifying fields, or indexing
|
||||
a field only if its value is positive.
|
||||
|
||||
Example code:
|
||||
|
||||
type CustomPropsExample struct {
|
||||
I, J int
|
||||
// Sum is not stored, but should always be equal to I + J.
|
||||
Sum int `datastore:"-"`
|
||||
}
|
||||
|
||||
func (x *CustomPropsExample) Load(ps []datastore.Property) error {
|
||||
// Load I and J as usual.
|
||||
if err := datastore.LoadStruct(x, ps); err != nil {
|
||||
return err
|
||||
}
|
||||
// Derive the Sum field.
|
||||
x.Sum = x.I + x.J
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *CustomPropsExample) Save() ([]datastore.Property, error) {
|
||||
// Validate the Sum field.
|
||||
if x.Sum != x.I + x.J {
|
||||
return errors.New("CustomPropsExample has inconsistent sum")
|
||||
}
|
||||
// Save I and J as usual. The code below is equivalent to calling
|
||||
// "return datastore.SaveStruct(x)", but is done manually for
|
||||
// demonstration purposes.
|
||||
return []datastore.Property{
|
||||
{
|
||||
Name: "I",
|
||||
Value: int64(x.I),
|
||||
},
|
||||
{
|
||||
Name: "J",
|
||||
Value: int64(x.J),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
The *PropertyList type implements PropertyLoadSaver, and can therefore hold an
|
||||
arbitrary entity's contents.
|
||||
|
||||
|
||||
Queries
|
||||
|
||||
Queries retrieve entities based on their properties or key's ancestry. Running
|
||||
a query yields an iterator of results: either keys or (key, entity) pairs.
|
||||
Queries are re-usable and it is safe to call Query.Run from concurrent
|
||||
goroutines. Iterators are not safe for concurrent use.
|
||||
|
||||
Queries are immutable, and are either created by calling NewQuery, or derived
|
||||
from an existing query by calling a method like Filter or Order that returns a
|
||||
new query value. A query is typically constructed by calling NewQuery followed
|
||||
by a chain of zero or more such methods. These methods are:
|
||||
- Ancestor and Filter constrain the entities returned by running a query.
|
||||
- Order affects the order in which they are returned.
|
||||
- Project constrains the fields returned.
|
||||
- Distinct de-duplicates projected entities.
|
||||
- KeysOnly makes the iterator return only keys, not (key, entity) pairs.
|
||||
- Start, End, Offset and Limit define which sub-sequence of matching entities
|
||||
to return. Start and End take cursors, Offset and Limit take integers. Start
|
||||
and Offset affect the first result, End and Limit affect the last result.
|
||||
If both Start and Offset are set, then the offset is relative to Start.
|
||||
If both End and Limit are set, then the earliest constraint wins. Limit is
|
||||
relative to Start+Offset, not relative to End. As a special case, a
|
||||
negative limit means unlimited.
|
||||
|
||||
Example code:
|
||||
|
||||
type Widget struct {
|
||||
Description string
|
||||
Price int
|
||||
}
|
||||
|
||||
func handle(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := appengine.NewContext(r)
|
||||
q := datastore.NewQuery("Widget").
|
||||
Filter("Price <", 1000).
|
||||
Order("-Price")
|
||||
b := new(bytes.Buffer)
|
||||
for t := q.Run(ctx); ; {
|
||||
var x Widget
|
||||
key, err := t.Next(&x)
|
||||
if err == datastore.Done {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
serveError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
fmt.Fprintf(b, "Key=%v\nWidget=%#v\n\n", key, x)
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
io.Copy(w, b)
|
||||
}
|
||||
|
||||
|
||||
Transactions
|
||||
|
||||
RunInTransaction runs a function in a transaction.
|
||||
|
||||
Example code:
|
||||
|
||||
type Counter struct {
|
||||
Count int
|
||||
}
|
||||
|
||||
func inc(ctx context.Context, key *datastore.Key) (int, error) {
|
||||
var x Counter
|
||||
if err := datastore.Get(ctx, key, &x); err != nil && err != datastore.ErrNoSuchEntity {
|
||||
return 0, err
|
||||
}
|
||||
x.Count++
|
||||
if _, err := datastore.Put(ctx, key, &x); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return x.Count, nil
|
||||
}
|
||||
|
||||
func handle(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := appengine.NewContext(r)
|
||||
var count int
|
||||
err := datastore.RunInTransaction(ctx, func(ctx context.Context) error {
|
||||
var err1 error
|
||||
count, err1 = inc(ctx, datastore.NewKey(ctx, "Counter", "singleton", 0, nil))
|
||||
return err1
|
||||
}, nil)
|
||||
if err != nil {
|
||||
serveError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
fmt.Fprintf(w, "Count=%d", count)
|
||||
}
|
||||
|
||||
|
||||
Metadata
|
||||
|
||||
The datastore package provides access to some of App Engine's datastore
|
||||
metadata. This metadata includes information about the entity groups,
|
||||
namespaces, entity kinds, and properties in the datastore, as well as the
|
||||
property representations for each property.
|
||||
|
||||
Example code:
|
||||
|
||||
func handle(w http.ResponseWriter, r *http.Request) {
|
||||
// Print all the kinds in the datastore, with all the indexed
|
||||
// properties (and their representations) for each.
|
||||
ctx := appengine.NewContext(r)
|
||||
|
||||
kinds, err := datastore.Kinds(ctx)
|
||||
if err != nil {
|
||||
serveError(ctx, w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
for _, kind := range kinds {
|
||||
fmt.Fprintf(w, "%s:\n", kind)
|
||||
props, err := datastore.KindProperties(ctx, kind)
|
||||
if err != nil {
|
||||
fmt.Fprintln(w, "\t(unable to retrieve properties)")
|
||||
continue
|
||||
}
|
||||
for p, rep := range props {
|
||||
fmt.Fprintf(w, "\t-%s (%s)\n", p, strings.Join(", ", rep))
|
||||
}
|
||||
}
|
||||
}
|
||||
*/
|
||||
package datastore // import "google.golang.org/appengine/datastore"
|
|
@ -0,0 +1,309 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine/internal"
|
||||
pb "google.golang.org/appengine/internal/datastore"
|
||||
)
|
||||
|
||||
// Key represents the datastore key for a stored entity, and is immutable.
|
||||
type Key struct {
|
||||
kind string
|
||||
stringID string
|
||||
intID int64
|
||||
parent *Key
|
||||
appID string
|
||||
namespace string
|
||||
}
|
||||
|
||||
// Kind returns the key's kind (also known as entity type).
|
||||
func (k *Key) Kind() string {
|
||||
return k.kind
|
||||
}
|
||||
|
||||
// StringID returns the key's string ID (also known as an entity name or key
|
||||
// name), which may be "".
|
||||
func (k *Key) StringID() string {
|
||||
return k.stringID
|
||||
}
|
||||
|
||||
// IntID returns the key's integer ID, which may be 0.
|
||||
func (k *Key) IntID() int64 {
|
||||
return k.intID
|
||||
}
|
||||
|
||||
// Parent returns the key's parent key, which may be nil.
|
||||
func (k *Key) Parent() *Key {
|
||||
return k.parent
|
||||
}
|
||||
|
||||
// AppID returns the key's application ID.
|
||||
func (k *Key) AppID() string {
|
||||
return k.appID
|
||||
}
|
||||
|
||||
// Namespace returns the key's namespace.
|
||||
func (k *Key) Namespace() string {
|
||||
return k.namespace
|
||||
}
|
||||
|
||||
// Incomplete returns whether the key does not refer to a stored entity.
|
||||
// In particular, whether the key has a zero StringID and a zero IntID.
|
||||
func (k *Key) Incomplete() bool {
|
||||
return k.stringID == "" && k.intID == 0
|
||||
}
|
||||
|
||||
// valid returns whether the key is valid.
|
||||
func (k *Key) valid() bool {
|
||||
if k == nil {
|
||||
return false
|
||||
}
|
||||
for ; k != nil; k = k.parent {
|
||||
if k.kind == "" || k.appID == "" {
|
||||
return false
|
||||
}
|
||||
if k.stringID != "" && k.intID != 0 {
|
||||
return false
|
||||
}
|
||||
if k.parent != nil {
|
||||
if k.parent.Incomplete() {
|
||||
return false
|
||||
}
|
||||
if k.parent.appID != k.appID || k.parent.namespace != k.namespace {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Equal returns whether two keys are equal.
|
||||
func (k *Key) Equal(o *Key) bool {
|
||||
for k != nil && o != nil {
|
||||
if k.kind != o.kind || k.stringID != o.stringID || k.intID != o.intID || k.appID != o.appID || k.namespace != o.namespace {
|
||||
return false
|
||||
}
|
||||
k, o = k.parent, o.parent
|
||||
}
|
||||
return k == o
|
||||
}
|
||||
|
||||
// root returns the furthest ancestor of a key, which may be itself.
|
||||
func (k *Key) root() *Key {
|
||||
for k.parent != nil {
|
||||
k = k.parent
|
||||
}
|
||||
return k
|
||||
}
|
||||
|
||||
// marshal marshals the key's string representation to the buffer.
|
||||
func (k *Key) marshal(b *bytes.Buffer) {
|
||||
if k.parent != nil {
|
||||
k.parent.marshal(b)
|
||||
}
|
||||
b.WriteByte('/')
|
||||
b.WriteString(k.kind)
|
||||
b.WriteByte(',')
|
||||
if k.stringID != "" {
|
||||
b.WriteString(k.stringID)
|
||||
} else {
|
||||
b.WriteString(strconv.FormatInt(k.intID, 10))
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the key.
|
||||
func (k *Key) String() string {
|
||||
if k == nil {
|
||||
return ""
|
||||
}
|
||||
b := bytes.NewBuffer(make([]byte, 0, 512))
|
||||
k.marshal(b)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type gobKey struct {
|
||||
Kind string
|
||||
StringID string
|
||||
IntID int64
|
||||
Parent *gobKey
|
||||
AppID string
|
||||
Namespace string
|
||||
}
|
||||
|
||||
func keyToGobKey(k *Key) *gobKey {
|
||||
if k == nil {
|
||||
return nil
|
||||
}
|
||||
return &gobKey{
|
||||
Kind: k.kind,
|
||||
StringID: k.stringID,
|
||||
IntID: k.intID,
|
||||
Parent: keyToGobKey(k.parent),
|
||||
AppID: k.appID,
|
||||
Namespace: k.namespace,
|
||||
}
|
||||
}
|
||||
|
||||
func gobKeyToKey(gk *gobKey) *Key {
|
||||
if gk == nil {
|
||||
return nil
|
||||
}
|
||||
return &Key{
|
||||
kind: gk.Kind,
|
||||
stringID: gk.StringID,
|
||||
intID: gk.IntID,
|
||||
parent: gobKeyToKey(gk.Parent),
|
||||
appID: gk.AppID,
|
||||
namespace: gk.Namespace,
|
||||
}
|
||||
}
|
||||
|
||||
func (k *Key) GobEncode() ([]byte, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
if err := gob.NewEncoder(buf).Encode(keyToGobKey(k)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (k *Key) GobDecode(buf []byte) error {
|
||||
gk := new(gobKey)
|
||||
if err := gob.NewDecoder(bytes.NewBuffer(buf)).Decode(gk); err != nil {
|
||||
return err
|
||||
}
|
||||
*k = *gobKeyToKey(gk)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (k *Key) MarshalJSON() ([]byte, error) {
|
||||
return []byte(`"` + k.Encode() + `"`), nil
|
||||
}
|
||||
|
||||
func (k *Key) UnmarshalJSON(buf []byte) error {
|
||||
if len(buf) < 2 || buf[0] != '"' || buf[len(buf)-1] != '"' {
|
||||
return errors.New("datastore: bad JSON key")
|
||||
}
|
||||
k2, err := DecodeKey(string(buf[1 : len(buf)-1]))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*k = *k2
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encode returns an opaque representation of the key
|
||||
// suitable for use in HTML and URLs.
|
||||
// This is compatible with the Python and Java runtimes.
|
||||
func (k *Key) Encode() string {
|
||||
ref := keyToProto("", k)
|
||||
|
||||
b, err := proto.Marshal(ref)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Trailing padding is stripped.
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=")
|
||||
}
|
||||
|
||||
// DecodeKey decodes a key from the opaque representation returned by Encode.
|
||||
func DecodeKey(encoded string) (*Key, error) {
|
||||
// Re-add padding.
|
||||
if m := len(encoded) % 4; m != 0 {
|
||||
encoded += strings.Repeat("=", 4-m)
|
||||
}
|
||||
|
||||
b, err := base64.URLEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ref := new(pb.Reference)
|
||||
if err := proto.Unmarshal(b, ref); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return protoToKey(ref)
|
||||
}
|
||||
|
||||
// NewIncompleteKey creates a new incomplete key.
|
||||
// kind cannot be empty.
|
||||
func NewIncompleteKey(c context.Context, kind string, parent *Key) *Key {
|
||||
return NewKey(c, kind, "", 0, parent)
|
||||
}
|
||||
|
||||
// NewKey creates a new key.
|
||||
// kind cannot be empty.
|
||||
// Either one or both of stringID and intID must be zero. If both are zero,
|
||||
// the key returned is incomplete.
|
||||
// parent must either be a complete key or nil.
|
||||
func NewKey(c context.Context, kind, stringID string, intID int64, parent *Key) *Key {
|
||||
// If there's a parent key, use its namespace.
|
||||
// Otherwise, use any namespace attached to the context.
|
||||
var namespace string
|
||||
if parent != nil {
|
||||
namespace = parent.namespace
|
||||
} else {
|
||||
namespace = internal.NamespaceFromContext(c)
|
||||
}
|
||||
|
||||
return &Key{
|
||||
kind: kind,
|
||||
stringID: stringID,
|
||||
intID: intID,
|
||||
parent: parent,
|
||||
appID: internal.FullyQualifiedAppID(c),
|
||||
namespace: namespace,
|
||||
}
|
||||
}
|
||||
|
||||
// AllocateIDs returns a range of n integer IDs with the given kind and parent
|
||||
// combination. kind cannot be empty; parent may be nil. The IDs in the range
|
||||
// returned will not be used by the datastore's automatic ID sequence generator
|
||||
// and may be used with NewKey without conflict.
|
||||
//
|
||||
// The range is inclusive at the low end and exclusive at the high end. In
|
||||
// other words, valid intIDs x satisfy low <= x && x < high.
|
||||
//
|
||||
// If no error is returned, low + n == high.
|
||||
func AllocateIDs(c context.Context, kind string, parent *Key, n int) (low, high int64, err error) {
|
||||
if kind == "" {
|
||||
return 0, 0, errors.New("datastore: AllocateIDs given an empty kind")
|
||||
}
|
||||
if n < 0 {
|
||||
return 0, 0, fmt.Errorf("datastore: AllocateIDs given a negative count: %d", n)
|
||||
}
|
||||
if n == 0 {
|
||||
return 0, 0, nil
|
||||
}
|
||||
req := &pb.AllocateIdsRequest{
|
||||
ModelKey: keyToProto("", NewIncompleteKey(c, kind, parent)),
|
||||
Size: proto.Int64(int64(n)),
|
||||
}
|
||||
res := &pb.AllocateIdsResponse{}
|
||||
if err := internal.Call(c, "datastore_v3", "AllocateIds", req, res); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
// The protobuf is inclusive at both ends. Idiomatic Go (e.g. slices, for loops)
|
||||
// is inclusive at the low end and exclusive at the high end, so we add 1.
|
||||
low = res.GetStart()
|
||||
high = res.GetEnd() + 1
|
||||
if low+int64(n) != high {
|
||||
return 0, 0, fmt.Errorf("datastore: internal error: could not allocate %d IDs", n)
|
||||
}
|
||||
return low, high, nil
|
||||
}
|
|
@ -0,0 +1,204 @@
|
|||
// Copyright 2011 Google Inc. All Rights Reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine/internal"
|
||||
)
|
||||
|
||||
func TestKeyEncoding(t *testing.T) {
|
||||
testCases := []struct {
|
||||
desc string
|
||||
key *Key
|
||||
exp string
|
||||
}{
|
||||
{
|
||||
desc: "A simple key with an int ID",
|
||||
key: &Key{
|
||||
kind: "Person",
|
||||
intID: 1,
|
||||
appID: "glibrary",
|
||||
},
|
||||
exp: "aghnbGlicmFyeXIMCxIGUGVyc29uGAEM",
|
||||
},
|
||||
{
|
||||
desc: "A simple key with a string ID",
|
||||
key: &Key{
|
||||
kind: "Graph",
|
||||
stringID: "graph:7-day-active",
|
||||
appID: "glibrary",
|
||||
},
|
||||
exp: "aghnbGlicmFyeXIdCxIFR3JhcGgiEmdyYXBoOjctZGF5LWFjdGl2ZQw",
|
||||
},
|
||||
{
|
||||
desc: "A key with a parent",
|
||||
key: &Key{
|
||||
kind: "WordIndex",
|
||||
intID: 1033,
|
||||
parent: &Key{
|
||||
kind: "WordIndex",
|
||||
intID: 1020032,
|
||||
appID: "glibrary",
|
||||
},
|
||||
appID: "glibrary",
|
||||
},
|
||||
exp: "aghnbGlicmFyeXIhCxIJV29yZEluZGV4GIChPgwLEglXb3JkSW5kZXgYiQgM",
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
enc := tc.key.Encode()
|
||||
if enc != tc.exp {
|
||||
t.Errorf("%s: got %q, want %q", tc.desc, enc, tc.exp)
|
||||
}
|
||||
|
||||
key, err := DecodeKey(tc.exp)
|
||||
if err != nil {
|
||||
t.Errorf("%s: failed decoding key: %v", tc.desc, err)
|
||||
continue
|
||||
}
|
||||
if !key.Equal(tc.key) {
|
||||
t.Errorf("%s: decoded key %v, want %v", tc.desc, key, tc.key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyGob(t *testing.T) {
|
||||
k := &Key{
|
||||
kind: "Gopher",
|
||||
intID: 3,
|
||||
parent: &Key{
|
||||
kind: "Mom",
|
||||
stringID: "narwhal",
|
||||
appID: "gopher-con",
|
||||
},
|
||||
appID: "gopher-con",
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if err := gob.NewEncoder(buf).Encode(k); err != nil {
|
||||
t.Fatalf("gob encode failed: %v", err)
|
||||
}
|
||||
|
||||
k2 := new(Key)
|
||||
if err := gob.NewDecoder(buf).Decode(k2); err != nil {
|
||||
t.Fatalf("gob decode failed: %v", err)
|
||||
}
|
||||
if !k2.Equal(k) {
|
||||
t.Errorf("gob round trip of %v produced %v", k, k2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilKeyGob(t *testing.T) {
|
||||
type S struct {
|
||||
Key *Key
|
||||
}
|
||||
s1 := new(S)
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if err := gob.NewEncoder(buf).Encode(s1); err != nil {
|
||||
t.Fatalf("gob encode failed: %v", err)
|
||||
}
|
||||
|
||||
s2 := new(S)
|
||||
if err := gob.NewDecoder(buf).Decode(s2); err != nil {
|
||||
t.Fatalf("gob decode failed: %v", err)
|
||||
}
|
||||
if s2.Key != nil {
|
||||
t.Errorf("gob round trip of nil key produced %v", s2.Key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeyJSON(t *testing.T) {
|
||||
k := &Key{
|
||||
kind: "Gopher",
|
||||
intID: 2,
|
||||
parent: &Key{
|
||||
kind: "Mom",
|
||||
stringID: "narwhal",
|
||||
appID: "gopher-con",
|
||||
},
|
||||
appID: "gopher-con",
|
||||
}
|
||||
exp := `"` + k.Encode() + `"`
|
||||
|
||||
buf, err := json.Marshal(k)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal failed: %v", err)
|
||||
}
|
||||
if s := string(buf); s != exp {
|
||||
t.Errorf("JSON encoding of key %v: got %q, want %q", k, s, exp)
|
||||
}
|
||||
|
||||
k2 := new(Key)
|
||||
if err := json.Unmarshal(buf, k2); err != nil {
|
||||
t.Fatalf("json.Unmarshal failed: %v", err)
|
||||
}
|
||||
if !k2.Equal(k) {
|
||||
t.Errorf("JSON round trip of %v produced %v", k, k2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilKeyJSON(t *testing.T) {
|
||||
type S struct {
|
||||
Key *Key
|
||||
}
|
||||
s1 := new(S)
|
||||
|
||||
buf, err := json.Marshal(s1)
|
||||
if err != nil {
|
||||
t.Fatalf("json.Marshal failed: %v", err)
|
||||
}
|
||||
|
||||
s2 := new(S)
|
||||
if err := json.Unmarshal(buf, s2); err != nil {
|
||||
t.Fatalf("json.Unmarshal failed: %v", err)
|
||||
}
|
||||
if s2.Key != nil {
|
||||
t.Errorf("JSON round trip of nil key produced %v", s2.Key)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncompleteKeyWithParent(t *testing.T) {
|
||||
c := internal.WithAppIDOverride(context.Background(), "s~some-app")
|
||||
|
||||
// fadduh is a complete key.
|
||||
fadduh := NewKey(c, "Person", "", 1, nil)
|
||||
if fadduh.Incomplete() {
|
||||
t.Fatalf("fadduh is incomplete")
|
||||
}
|
||||
|
||||
// robert is an incomplete key with fadduh as a parent.
|
||||
robert := NewIncompleteKey(c, "Person", fadduh)
|
||||
if !robert.Incomplete() {
|
||||
t.Fatalf("robert is complete")
|
||||
}
|
||||
|
||||
// Both should be valid keys.
|
||||
if !fadduh.valid() {
|
||||
t.Errorf("fadduh is invalid: %v", fadduh)
|
||||
}
|
||||
if !robert.valid() {
|
||||
t.Errorf("robert is invalid: %v", robert)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamespace(t *testing.T) {
|
||||
key := &Key{
|
||||
kind: "Person",
|
||||
intID: 1,
|
||||
appID: "s~some-app",
|
||||
namespace: "mynamespace",
|
||||
}
|
||||
if g, w := key.Namespace(), "mynamespace"; g != w {
|
||||
t.Errorf("key.Namespace() = %q, want %q", g, w)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,334 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"google.golang.org/appengine"
|
||||
pb "google.golang.org/appengine/internal/datastore"
|
||||
)
|
||||
|
||||
var (
|
||||
typeOfBlobKey = reflect.TypeOf(appengine.BlobKey(""))
|
||||
typeOfByteSlice = reflect.TypeOf([]byte(nil))
|
||||
typeOfByteString = reflect.TypeOf(ByteString(nil))
|
||||
typeOfGeoPoint = reflect.TypeOf(appengine.GeoPoint{})
|
||||
typeOfTime = reflect.TypeOf(time.Time{})
|
||||
)
|
||||
|
||||
// typeMismatchReason returns a string explaining why the property p could not
|
||||
// be stored in an entity field of type v.Type().
|
||||
func typeMismatchReason(p Property, v reflect.Value) string {
|
||||
entityType := "empty"
|
||||
switch p.Value.(type) {
|
||||
case int64:
|
||||
entityType = "int"
|
||||
case bool:
|
||||
entityType = "bool"
|
||||
case string:
|
||||
entityType = "string"
|
||||
case float64:
|
||||
entityType = "float"
|
||||
case *Key:
|
||||
entityType = "*datastore.Key"
|
||||
case time.Time:
|
||||
entityType = "time.Time"
|
||||
case appengine.BlobKey:
|
||||
entityType = "appengine.BlobKey"
|
||||
case appengine.GeoPoint:
|
||||
entityType = "appengine.GeoPoint"
|
||||
case ByteString:
|
||||
entityType = "datastore.ByteString"
|
||||
case []byte:
|
||||
entityType = "[]byte"
|
||||
}
|
||||
return fmt.Sprintf("type mismatch: %s versus %v", entityType, v.Type())
|
||||
}
|
||||
|
||||
type propertyLoader struct {
|
||||
// m holds the number of times a substruct field like "Foo.Bar.Baz" has
|
||||
// been seen so far. The map is constructed lazily.
|
||||
m map[string]int
|
||||
}
|
||||
|
||||
func (l *propertyLoader) load(codec *structCodec, structValue reflect.Value, p Property, requireSlice bool) string {
|
||||
var v reflect.Value
|
||||
// Traverse a struct's struct-typed fields.
|
||||
for name := p.Name; ; {
|
||||
decoder, ok := codec.byName[name]
|
||||
if !ok {
|
||||
return "no such struct field"
|
||||
}
|
||||
v = structValue.Field(decoder.index)
|
||||
if !v.IsValid() {
|
||||
return "no such struct field"
|
||||
}
|
||||
if !v.CanSet() {
|
||||
return "cannot set struct field"
|
||||
}
|
||||
|
||||
if decoder.substructCodec == nil {
|
||||
break
|
||||
}
|
||||
|
||||
if v.Kind() == reflect.Slice {
|
||||
if l.m == nil {
|
||||
l.m = make(map[string]int)
|
||||
}
|
||||
index := l.m[p.Name]
|
||||
l.m[p.Name] = index + 1
|
||||
for v.Len() <= index {
|
||||
v.Set(reflect.Append(v, reflect.New(v.Type().Elem()).Elem()))
|
||||
}
|
||||
structValue = v.Index(index)
|
||||
requireSlice = false
|
||||
} else {
|
||||
structValue = v
|
||||
}
|
||||
// Strip the "I." from "I.X".
|
||||
name = name[len(codec.byIndex[decoder.index].name):]
|
||||
codec = decoder.substructCodec
|
||||
}
|
||||
|
||||
var slice reflect.Value
|
||||
if v.Kind() == reflect.Slice && v.Type().Elem().Kind() != reflect.Uint8 {
|
||||
slice = v
|
||||
v = reflect.New(v.Type().Elem()).Elem()
|
||||
} else if requireSlice {
|
||||
return "multiple-valued property requires a slice field type"
|
||||
}
|
||||
|
||||
// Convert indexValues to a Go value with a meaning derived from the
|
||||
// destination type.
|
||||
pValue := p.Value
|
||||
if iv, ok := pValue.(indexValue); ok {
|
||||
meaning := pb.Property_NO_MEANING
|
||||
switch v.Type() {
|
||||
case typeOfBlobKey:
|
||||
meaning = pb.Property_BLOBKEY
|
||||
case typeOfByteSlice:
|
||||
meaning = pb.Property_BLOB
|
||||
case typeOfByteString:
|
||||
meaning = pb.Property_BYTESTRING
|
||||
case typeOfGeoPoint:
|
||||
meaning = pb.Property_GEORSS_POINT
|
||||
case typeOfTime:
|
||||
meaning = pb.Property_GD_WHEN
|
||||
}
|
||||
var err error
|
||||
pValue, err = propValue(iv.value, meaning)
|
||||
if err != nil {
|
||||
return err.Error()
|
||||
}
|
||||
}
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
x, ok := pValue.(int64)
|
||||
if !ok && pValue != nil {
|
||||
return typeMismatchReason(p, v)
|
||||
}
|
||||
if v.OverflowInt(x) {
|
||||
return fmt.Sprintf("value %v overflows struct field of type %v", x, v.Type())
|
||||
}
|
||||
v.SetInt(x)
|
||||
case reflect.Bool:
|
||||
x, ok := pValue.(bool)
|
||||
if !ok && pValue != nil {
|
||||
return typeMismatchReason(p, v)
|
||||
}
|
||||
v.SetBool(x)
|
||||
case reflect.String:
|
||||
switch x := pValue.(type) {
|
||||
case appengine.BlobKey:
|
||||
v.SetString(string(x))
|
||||
case ByteString:
|
||||
v.SetString(string(x))
|
||||
case string:
|
||||
v.SetString(x)
|
||||
default:
|
||||
if pValue != nil {
|
||||
return typeMismatchReason(p, v)
|
||||
}
|
||||
}
|
||||
case reflect.Float32, reflect.Float64:
|
||||
x, ok := pValue.(float64)
|
||||
if !ok && pValue != nil {
|
||||
return typeMismatchReason(p, v)
|
||||
}
|
||||
if v.OverflowFloat(x) {
|
||||
return fmt.Sprintf("value %v overflows struct field of type %v", x, v.Type())
|
||||
}
|
||||
v.SetFloat(x)
|
||||
case reflect.Ptr:
|
||||
x, ok := pValue.(*Key)
|
||||
if !ok && pValue != nil {
|
||||
return typeMismatchReason(p, v)
|
||||
}
|
||||
if _, ok := v.Interface().(*Key); !ok {
|
||||
return typeMismatchReason(p, v)
|
||||
}
|
||||
v.Set(reflect.ValueOf(x))
|
||||
case reflect.Struct:
|
||||
switch v.Type() {
|
||||
case typeOfTime:
|
||||
x, ok := pValue.(time.Time)
|
||||
if !ok && pValue != nil {
|
||||
return typeMismatchReason(p, v)
|
||||
}
|
||||
v.Set(reflect.ValueOf(x))
|
||||
case typeOfGeoPoint:
|
||||
x, ok := pValue.(appengine.GeoPoint)
|
||||
if !ok && pValue != nil {
|
||||
return typeMismatchReason(p, v)
|
||||
}
|
||||
v.Set(reflect.ValueOf(x))
|
||||
default:
|
||||
return typeMismatchReason(p, v)
|
||||
}
|
||||
case reflect.Slice:
|
||||
x, ok := pValue.([]byte)
|
||||
if !ok {
|
||||
if y, yok := pValue.(ByteString); yok {
|
||||
x, ok = []byte(y), true
|
||||
}
|
||||
}
|
||||
if !ok && pValue != nil {
|
||||
return typeMismatchReason(p, v)
|
||||
}
|
||||
if v.Type().Elem().Kind() != reflect.Uint8 {
|
||||
return typeMismatchReason(p, v)
|
||||
}
|
||||
v.SetBytes(x)
|
||||
default:
|
||||
return typeMismatchReason(p, v)
|
||||
}
|
||||
if slice.IsValid() {
|
||||
slice.Set(reflect.Append(slice, v))
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// loadEntity loads an EntityProto into PropertyLoadSaver or struct pointer.
|
||||
func loadEntity(dst interface{}, src *pb.EntityProto) (err error) {
|
||||
props, err := protoToProperties(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if e, ok := dst.(PropertyLoadSaver); ok {
|
||||
return e.Load(props)
|
||||
}
|
||||
return LoadStruct(dst, props)
|
||||
}
|
||||
|
||||
func (s structPLS) Load(props []Property) error {
|
||||
var fieldName, reason string
|
||||
var l propertyLoader
|
||||
for _, p := range props {
|
||||
if errStr := l.load(s.codec, s.v, p, p.Multiple); errStr != "" {
|
||||
// We don't return early, as we try to load as many properties as possible.
|
||||
// It is valid to load an entity into a struct that cannot fully represent it.
|
||||
// That case returns an error, but the caller is free to ignore it.
|
||||
fieldName, reason = p.Name, errStr
|
||||
}
|
||||
}
|
||||
if reason != "" {
|
||||
return &ErrFieldMismatch{
|
||||
StructType: s.v.Type(),
|
||||
FieldName: fieldName,
|
||||
Reason: reason,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func protoToProperties(src *pb.EntityProto) ([]Property, error) {
|
||||
props, rawProps := src.Property, src.RawProperty
|
||||
out := make([]Property, 0, len(props)+len(rawProps))
|
||||
for {
|
||||
var (
|
||||
x *pb.Property
|
||||
noIndex bool
|
||||
)
|
||||
if len(props) > 0 {
|
||||
x, props = props[0], props[1:]
|
||||
} else if len(rawProps) > 0 {
|
||||
x, rawProps = rawProps[0], rawProps[1:]
|
||||
noIndex = true
|
||||
} else {
|
||||
break
|
||||
}
|
||||
|
||||
var value interface{}
|
||||
if x.Meaning != nil && *x.Meaning == pb.Property_INDEX_VALUE {
|
||||
value = indexValue{x.Value}
|
||||
} else {
|
||||
var err error
|
||||
value, err = propValue(x.Value, x.GetMeaning())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
out = append(out, Property{
|
||||
Name: x.GetName(),
|
||||
Value: value,
|
||||
NoIndex: noIndex,
|
||||
Multiple: x.GetMultiple(),
|
||||
})
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// propValue returns a Go value that combines the raw PropertyValue with a
|
||||
// meaning. For example, an Int64Value with GD_WHEN becomes a time.Time.
|
||||
func propValue(v *pb.PropertyValue, m pb.Property_Meaning) (interface{}, error) {
|
||||
switch {
|
||||
case v.Int64Value != nil:
|
||||
if m == pb.Property_GD_WHEN {
|
||||
return fromUnixMicro(*v.Int64Value), nil
|
||||
} else {
|
||||
return *v.Int64Value, nil
|
||||
}
|
||||
case v.BooleanValue != nil:
|
||||
return *v.BooleanValue, nil
|
||||
case v.StringValue != nil:
|
||||
if m == pb.Property_BLOB {
|
||||
return []byte(*v.StringValue), nil
|
||||
} else if m == pb.Property_BLOBKEY {
|
||||
return appengine.BlobKey(*v.StringValue), nil
|
||||
} else if m == pb.Property_BYTESTRING {
|
||||
return ByteString(*v.StringValue), nil
|
||||
} else {
|
||||
return *v.StringValue, nil
|
||||
}
|
||||
case v.DoubleValue != nil:
|
||||
return *v.DoubleValue, nil
|
||||
case v.Referencevalue != nil:
|
||||
key, err := referenceValueToKey(v.Referencevalue)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key, nil
|
||||
case v.Pointvalue != nil:
|
||||
// NOTE: Strangely, latitude maps to X, longitude to Y.
|
||||
return appengine.GeoPoint{Lat: v.Pointvalue.GetX(), Lng: v.Pointvalue.GetY()}, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// indexValue is a Property value that is created when entities are loaded from
|
||||
// an index, such as from a projection query.
|
||||
//
|
||||
// Such Property values do not contain all of the metadata required to be
|
||||
// faithfully represented as a Go value, and are instead represented as an
|
||||
// opaque indexValue. Load the properties into a concrete struct type (e.g. by
|
||||
// passing a struct pointer to Iterator.Next) to reconstruct actual Go values
|
||||
// of type int, string, time.Time, etc.
|
||||
type indexValue struct {
|
||||
value *pb.PropertyValue
|
||||
}
|
|
@ -0,0 +1,78 @@
|
|||
// Copyright 2016 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package datastore
|
||||
|
||||
import "golang.org/x/net/context"
|
||||
|
||||
// Datastore kinds for the metadata entities.
|
||||
const (
|
||||
namespaceKind = "__namespace__"
|
||||
kindKind = "__kind__"
|
||||
propertyKind = "__property__"
|
||||
)
|
||||
|
||||
// Namespaces returns all the datastore namespaces.
|
||||
func Namespaces(ctx context.Context) ([]string, error) {
|
||||
// TODO(djd): Support range queries.
|
||||
q := NewQuery(namespaceKind).KeysOnly()
|
||||
keys, err := q.GetAll(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// The empty namespace key uses a numeric ID (==1), but luckily
|
||||
// the string ID defaults to "" for numeric IDs anyway.
|
||||
return keyNames(keys), nil
|
||||
}
|
||||
|
||||
// Kinds returns the names of all the kinds in the current namespace.
|
||||
func Kinds(ctx context.Context) ([]string, error) {
|
||||
// TODO(djd): Support range queries.
|
||||
q := NewQuery(kindKind).KeysOnly()
|
||||
keys, err := q.GetAll(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keyNames(keys), nil
|
||||
}
|
||||
|
||||
// keyNames returns a slice of the provided keys' names (string IDs).
|
||||
func keyNames(keys []*Key) []string {
|
||||
n := make([]string, 0, len(keys))
|
||||
for _, k := range keys {
|
||||
n = append(n, k.StringID())
|
||||
}
|
||||
return n
|
||||
}
|
||||
|
||||
// KindProperties returns all the indexed properties for the given kind.
|
||||
// The properties are returned as a map of property names to a slice of the
|
||||
// representation types. The representation types for the supported Go property
|
||||
// types are:
|
||||
// "INT64": signed integers and time.Time
|
||||
// "DOUBLE": float32 and float64
|
||||
// "BOOLEAN": bool
|
||||
// "STRING": string, []byte and ByteString
|
||||
// "POINT": appengine.GeoPoint
|
||||
// "REFERENCE": *Key
|
||||
// "USER": (not used in the Go runtime)
|
||||
func KindProperties(ctx context.Context, kind string) (map[string][]string, error) {
|
||||
// TODO(djd): Support range queries.
|
||||
kindKey := NewKey(ctx, kindKind, kind, 0, nil)
|
||||
q := NewQuery(propertyKind).Ancestor(kindKey)
|
||||
|
||||
propMap := map[string][]string{}
|
||||
props := []struct {
|
||||
Repr []string `datastore:property_representation`
|
||||
}{}
|
||||
|
||||
keys, err := q.GetAll(ctx, &props)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i, p := range props {
|
||||
propMap[keys[i].StringID()] = p.Repr
|
||||
}
|
||||
return propMap, nil
|
||||
}
|
|
@ -0,0 +1,296 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Entities with more than this many indexed properties will not be saved.
|
||||
const maxIndexedProperties = 20000
|
||||
|
||||
// []byte fields more than 1 megabyte long will not be loaded or saved.
|
||||
const maxBlobLen = 1 << 20
|
||||
|
||||
// Property is a name/value pair plus some metadata. A datastore entity's
|
||||
// contents are loaded and saved as a sequence of Properties. An entity can
|
||||
// have multiple Properties with the same name, provided that p.Multiple is
|
||||
// true on all of that entity's Properties with that name.
|
||||
type Property struct {
|
||||
// Name is the property name.
|
||||
Name string
|
||||
// Value is the property value. The valid types are:
|
||||
// - int64
|
||||
// - bool
|
||||
// - string
|
||||
// - float64
|
||||
// - ByteString
|
||||
// - *Key
|
||||
// - time.Time
|
||||
// - appengine.BlobKey
|
||||
// - appengine.GeoPoint
|
||||
// - []byte (up to 1 megabyte in length)
|
||||
// This set is smaller than the set of valid struct field types that the
|
||||
// datastore can load and save. A Property Value cannot be a slice (apart
|
||||
// from []byte); use multiple Properties instead. Also, a Value's type
|
||||
// must be explicitly on the list above; it is not sufficient for the
|
||||
// underlying type to be on that list. For example, a Value of "type
|
||||
// myInt64 int64" is invalid. Smaller-width integers and floats are also
|
||||
// invalid. Again, this is more restrictive than the set of valid struct
|
||||
// field types.
|
||||
//
|
||||
// A Value will have an opaque type when loading entities from an index,
|
||||
// such as via a projection query. Load entities into a struct instead
|
||||
// of a PropertyLoadSaver when using a projection query.
|
||||
//
|
||||
// A Value may also be the nil interface value; this is equivalent to
|
||||
// Python's None but not directly representable by a Go struct. Loading
|
||||
// a nil-valued property into a struct will set that field to the zero
|
||||
// value.
|
||||
Value interface{}
|
||||
// NoIndex is whether the datastore cannot index this property.
|
||||
NoIndex bool
|
||||
// Multiple is whether the entity can have multiple properties with
|
||||
// the same name. Even if a particular instance only has one property with
|
||||
// a certain name, Multiple should be true if a struct would best represent
|
||||
// it as a field of type []T instead of type T.
|
||||
Multiple bool
|
||||
}
|
||||
|
||||
// ByteString is a short byte slice (up to 1500 bytes) that can be indexed.
|
||||
type ByteString []byte
|
||||
|
||||
// PropertyLoadSaver can be converted from and to a slice of Properties.
|
||||
type PropertyLoadSaver interface {
|
||||
Load([]Property) error
|
||||
Save() ([]Property, error)
|
||||
}
|
||||
|
||||
// PropertyList converts a []Property to implement PropertyLoadSaver.
|
||||
type PropertyList []Property
|
||||
|
||||
var (
|
||||
typeOfPropertyLoadSaver = reflect.TypeOf((*PropertyLoadSaver)(nil)).Elem()
|
||||
typeOfPropertyList = reflect.TypeOf(PropertyList(nil))
|
||||
)
|
||||
|
||||
// Load loads all of the provided properties into l.
|
||||
// It does not first reset *l to an empty slice.
|
||||
func (l *PropertyList) Load(p []Property) error {
|
||||
*l = append(*l, p...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Save saves all of l's properties as a slice or Properties.
|
||||
func (l *PropertyList) Save() ([]Property, error) {
|
||||
return *l, nil
|
||||
}
|
||||
|
||||
// validPropertyName returns whether name consists of one or more valid Go
|
||||
// identifiers joined by ".".
|
||||
func validPropertyName(name string) bool {
|
||||
if name == "" {
|
||||
return false
|
||||
}
|
||||
for _, s := range strings.Split(name, ".") {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
first := true
|
||||
for _, c := range s {
|
||||
if first {
|
||||
first = false
|
||||
if c != '_' && !unicode.IsLetter(c) {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
if c != '_' && !unicode.IsLetter(c) && !unicode.IsDigit(c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// structTag is the parsed `datastore:"name,options"` tag of a struct field.
|
||||
// If a field has no tag, or the tag has an empty name, then the structTag's
|
||||
// name is just the field name. A "-" name means that the datastore ignores
|
||||
// that field.
|
||||
type structTag struct {
|
||||
name string
|
||||
noIndex bool
|
||||
}
|
||||
|
||||
// structCodec describes how to convert a struct to and from a sequence of
|
||||
// properties.
|
||||
type structCodec struct {
|
||||
// byIndex gives the structTag for the i'th field.
|
||||
byIndex []structTag
|
||||
// byName gives the field codec for the structTag with the given name.
|
||||
byName map[string]fieldCodec
|
||||
// hasSlice is whether a struct or any of its nested or embedded structs
|
||||
// has a slice-typed field (other than []byte).
|
||||
hasSlice bool
|
||||
// complete is whether the structCodec is complete. An incomplete
|
||||
// structCodec may be encountered when walking a recursive struct.
|
||||
complete bool
|
||||
}
|
||||
|
||||
// fieldCodec is a struct field's index and, if that struct field's type is
|
||||
// itself a struct, that substruct's structCodec.
|
||||
type fieldCodec struct {
|
||||
index int
|
||||
substructCodec *structCodec
|
||||
}
|
||||
|
||||
// structCodecs collects the structCodecs that have already been calculated.
|
||||
var (
|
||||
structCodecsMutex sync.Mutex
|
||||
structCodecs = make(map[reflect.Type]*structCodec)
|
||||
)
|
||||
|
||||
// getStructCodec returns the structCodec for the given struct type.
|
||||
func getStructCodec(t reflect.Type) (*structCodec, error) {
|
||||
structCodecsMutex.Lock()
|
||||
defer structCodecsMutex.Unlock()
|
||||
return getStructCodecLocked(t)
|
||||
}
|
||||
|
||||
// getStructCodecLocked implements getStructCodec. The structCodecsMutex must
|
||||
// be held when calling this function.
|
||||
func getStructCodecLocked(t reflect.Type) (ret *structCodec, retErr error) {
|
||||
c, ok := structCodecs[t]
|
||||
if ok {
|
||||
return c, nil
|
||||
}
|
||||
c = &structCodec{
|
||||
byIndex: make([]structTag, t.NumField()),
|
||||
byName: make(map[string]fieldCodec),
|
||||
}
|
||||
|
||||
// Add c to the structCodecs map before we are sure it is good. If t is
|
||||
// a recursive type, it needs to find the incomplete entry for itself in
|
||||
// the map.
|
||||
structCodecs[t] = c
|
||||
defer func() {
|
||||
if retErr != nil {
|
||||
delete(structCodecs, t)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := range c.byIndex {
|
||||
f := t.Field(i)
|
||||
tags := strings.Split(f.Tag.Get("datastore"), ",")
|
||||
name := tags[0]
|
||||
opts := make(map[string]bool)
|
||||
for _, t := range tags[1:] {
|
||||
opts[t] = true
|
||||
}
|
||||
if name == "" {
|
||||
if !f.Anonymous {
|
||||
name = f.Name
|
||||
}
|
||||
} else if name == "-" {
|
||||
c.byIndex[i] = structTag{name: name}
|
||||
continue
|
||||
} else if !validPropertyName(name) {
|
||||
return nil, fmt.Errorf("datastore: struct tag has invalid property name: %q", name)
|
||||
}
|
||||
|
||||
substructType, fIsSlice := reflect.Type(nil), false
|
||||
switch f.Type.Kind() {
|
||||
case reflect.Struct:
|
||||
substructType = f.Type
|
||||
case reflect.Slice:
|
||||
if f.Type.Elem().Kind() == reflect.Struct {
|
||||
substructType = f.Type.Elem()
|
||||
}
|
||||
fIsSlice = f.Type != typeOfByteSlice
|
||||
c.hasSlice = c.hasSlice || fIsSlice
|
||||
}
|
||||
|
||||
if substructType != nil && substructType != typeOfTime && substructType != typeOfGeoPoint {
|
||||
if name != "" {
|
||||
name = name + "."
|
||||
}
|
||||
sub, err := getStructCodecLocked(substructType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !sub.complete {
|
||||
return nil, fmt.Errorf("datastore: recursive struct: field %q", f.Name)
|
||||
}
|
||||
if fIsSlice && sub.hasSlice {
|
||||
return nil, fmt.Errorf(
|
||||
"datastore: flattening nested structs leads to a slice of slices: field %q", f.Name)
|
||||
}
|
||||
c.hasSlice = c.hasSlice || sub.hasSlice
|
||||
for relName := range sub.byName {
|
||||
absName := name + relName
|
||||
if _, ok := c.byName[absName]; ok {
|
||||
return nil, fmt.Errorf("datastore: struct tag has repeated property name: %q", absName)
|
||||
}
|
||||
c.byName[absName] = fieldCodec{index: i, substructCodec: sub}
|
||||
}
|
||||
} else {
|
||||
if _, ok := c.byName[name]; ok {
|
||||
return nil, fmt.Errorf("datastore: struct tag has repeated property name: %q", name)
|
||||
}
|
||||
c.byName[name] = fieldCodec{index: i}
|
||||
}
|
||||
|
||||
c.byIndex[i] = structTag{
|
||||
name: name,
|
||||
noIndex: opts["noindex"],
|
||||
}
|
||||
}
|
||||
c.complete = true
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// structPLS adapts a struct to be a PropertyLoadSaver.
|
||||
type structPLS struct {
|
||||
v reflect.Value
|
||||
codec *structCodec
|
||||
}
|
||||
|
||||
// newStructPLS returns a PropertyLoadSaver for the struct pointer p.
|
||||
func newStructPLS(p interface{}) (PropertyLoadSaver, error) {
|
||||
v := reflect.ValueOf(p)
|
||||
if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
|
||||
return nil, ErrInvalidEntityType
|
||||
}
|
||||
v = v.Elem()
|
||||
codec, err := getStructCodec(v.Type())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return structPLS{v, codec}, nil
|
||||
}
|
||||
|
||||
// LoadStruct loads the properties from p to dst.
|
||||
// dst must be a struct pointer.
|
||||
func LoadStruct(dst interface{}, p []Property) error {
|
||||
x, err := newStructPLS(dst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return x.Load(p)
|
||||
}
|
||||
|
||||
// SaveStruct returns the properties from src as a slice of Properties.
|
||||
// src must be a struct pointer.
|
||||
func SaveStruct(src interface{}) ([]Property, error) {
|
||||
x, err := newStructPLS(src)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return x.Save()
|
||||
}
|
|
@ -0,0 +1,604 @@
|
|||
// Copyright 2011 Google Inc. All Rights Reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/appengine"
|
||||
)
|
||||
|
||||
func TestValidPropertyName(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
want bool
|
||||
}{
|
||||
// Invalid names.
|
||||
{"", false},
|
||||
{"'", false},
|
||||
{".", false},
|
||||
{"..", false},
|
||||
{".foo", false},
|
||||
{"0", false},
|
||||
{"00", false},
|
||||
{"X.X.4.X.X", false},
|
||||
{"\n", false},
|
||||
{"\x00", false},
|
||||
{"abc\xffz", false},
|
||||
{"foo.", false},
|
||||
{"foo..", false},
|
||||
{"foo..bar", false},
|
||||
{"☃", false},
|
||||
{`"`, false},
|
||||
// Valid names.
|
||||
{"AB", true},
|
||||
{"Abc", true},
|
||||
{"X.X.X.X.X", true},
|
||||
{"_", true},
|
||||
{"_0", true},
|
||||
{"a", true},
|
||||
{"a_B", true},
|
||||
{"f00", true},
|
||||
{"f0o", true},
|
||||
{"fo0", true},
|
||||
{"foo", true},
|
||||
{"foo.bar", true},
|
||||
{"foo.bar.baz", true},
|
||||
{"世界", true},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
got := validPropertyName(tc.name)
|
||||
if got != tc.want {
|
||||
t.Errorf("%q: got %v, want %v", tc.name, got, tc.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStructCodec(t *testing.T) {
|
||||
type oStruct struct {
|
||||
O int
|
||||
}
|
||||
type pStruct struct {
|
||||
P int
|
||||
Q int
|
||||
}
|
||||
type rStruct struct {
|
||||
R int
|
||||
S pStruct
|
||||
T oStruct
|
||||
oStruct
|
||||
}
|
||||
type uStruct struct {
|
||||
U int
|
||||
v int
|
||||
}
|
||||
type vStruct struct {
|
||||
V string `datastore:",noindex"`
|
||||
}
|
||||
oStructCodec := &structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "O"},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"O": {index: 0},
|
||||
},
|
||||
complete: true,
|
||||
}
|
||||
pStructCodec := &structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "P"},
|
||||
{name: "Q"},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"P": {index: 0},
|
||||
"Q": {index: 1},
|
||||
},
|
||||
complete: true,
|
||||
}
|
||||
rStructCodec := &structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "R"},
|
||||
{name: "S."},
|
||||
{name: "T."},
|
||||
{name: ""},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"R": {index: 0},
|
||||
"S.P": {index: 1, substructCodec: pStructCodec},
|
||||
"S.Q": {index: 1, substructCodec: pStructCodec},
|
||||
"T.O": {index: 2, substructCodec: oStructCodec},
|
||||
"O": {index: 3, substructCodec: oStructCodec},
|
||||
},
|
||||
complete: true,
|
||||
}
|
||||
uStructCodec := &structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "U"},
|
||||
{name: "v"},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"U": {index: 0},
|
||||
"v": {index: 1},
|
||||
},
|
||||
complete: true,
|
||||
}
|
||||
vStructCodec := &structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "V", noIndex: true},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"V": {index: 0},
|
||||
},
|
||||
complete: true,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
structValue interface{}
|
||||
want *structCodec
|
||||
}{
|
||||
{
|
||||
"oStruct",
|
||||
oStruct{},
|
||||
oStructCodec,
|
||||
},
|
||||
{
|
||||
"pStruct",
|
||||
pStruct{},
|
||||
pStructCodec,
|
||||
},
|
||||
{
|
||||
"rStruct",
|
||||
rStruct{},
|
||||
rStructCodec,
|
||||
},
|
||||
{
|
||||
"uStruct",
|
||||
uStruct{},
|
||||
uStructCodec,
|
||||
},
|
||||
{
|
||||
"non-basic fields",
|
||||
struct {
|
||||
B appengine.BlobKey
|
||||
K *Key
|
||||
T time.Time
|
||||
}{},
|
||||
&structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "B"},
|
||||
{name: "K"},
|
||||
{name: "T"},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"B": {index: 0},
|
||||
"K": {index: 1},
|
||||
"T": {index: 2},
|
||||
},
|
||||
complete: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"struct tags with ignored embed",
|
||||
struct {
|
||||
A int `datastore:"a,noindex"`
|
||||
B int `datastore:"b"`
|
||||
C int `datastore:",noindex"`
|
||||
D int `datastore:""`
|
||||
E int
|
||||
I int `datastore:"-"`
|
||||
J int `datastore:",noindex" json:"j"`
|
||||
oStruct `datastore:"-"`
|
||||
}{},
|
||||
&structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "a", noIndex: true},
|
||||
{name: "b", noIndex: false},
|
||||
{name: "C", noIndex: true},
|
||||
{name: "D", noIndex: false},
|
||||
{name: "E", noIndex: false},
|
||||
{name: "-", noIndex: false},
|
||||
{name: "J", noIndex: true},
|
||||
{name: "-", noIndex: false},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"a": {index: 0},
|
||||
"b": {index: 1},
|
||||
"C": {index: 2},
|
||||
"D": {index: 3},
|
||||
"E": {index: 4},
|
||||
"J": {index: 6},
|
||||
},
|
||||
complete: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"unexported fields",
|
||||
struct {
|
||||
A int
|
||||
b int
|
||||
C int `datastore:"x"`
|
||||
d int `datastore:"Y"`
|
||||
}{},
|
||||
&structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "A"},
|
||||
{name: "b"},
|
||||
{name: "x"},
|
||||
{name: "Y"},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"A": {index: 0},
|
||||
"b": {index: 1},
|
||||
"x": {index: 2},
|
||||
"Y": {index: 3},
|
||||
},
|
||||
complete: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"nested and embedded structs",
|
||||
struct {
|
||||
A int
|
||||
B int
|
||||
CC oStruct
|
||||
DDD rStruct
|
||||
oStruct
|
||||
}{},
|
||||
&structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "A"},
|
||||
{name: "B"},
|
||||
{name: "CC."},
|
||||
{name: "DDD."},
|
||||
{name: ""},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"A": {index: 0},
|
||||
"B": {index: 1},
|
||||
"CC.O": {index: 2, substructCodec: oStructCodec},
|
||||
"DDD.R": {index: 3, substructCodec: rStructCodec},
|
||||
"DDD.S.P": {index: 3, substructCodec: rStructCodec},
|
||||
"DDD.S.Q": {index: 3, substructCodec: rStructCodec},
|
||||
"DDD.T.O": {index: 3, substructCodec: rStructCodec},
|
||||
"DDD.O": {index: 3, substructCodec: rStructCodec},
|
||||
"O": {index: 4, substructCodec: oStructCodec},
|
||||
},
|
||||
complete: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"struct tags with nested and embedded structs",
|
||||
struct {
|
||||
A int `datastore:"-"`
|
||||
B int `datastore:"w"`
|
||||
C oStruct `datastore:"xx"`
|
||||
D rStruct `datastore:"y"`
|
||||
oStruct `datastore:"z"`
|
||||
}{},
|
||||
&structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "-"},
|
||||
{name: "w"},
|
||||
{name: "xx."},
|
||||
{name: "y."},
|
||||
{name: "z."},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"w": {index: 1},
|
||||
"xx.O": {index: 2, substructCodec: oStructCodec},
|
||||
"y.R": {index: 3, substructCodec: rStructCodec},
|
||||
"y.S.P": {index: 3, substructCodec: rStructCodec},
|
||||
"y.S.Q": {index: 3, substructCodec: rStructCodec},
|
||||
"y.T.O": {index: 3, substructCodec: rStructCodec},
|
||||
"y.O": {index: 3, substructCodec: rStructCodec},
|
||||
"z.O": {index: 4, substructCodec: oStructCodec},
|
||||
},
|
||||
complete: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"unexported nested and embedded structs",
|
||||
struct {
|
||||
a int
|
||||
B int
|
||||
c uStruct
|
||||
D uStruct
|
||||
uStruct
|
||||
}{},
|
||||
&structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "a"},
|
||||
{name: "B"},
|
||||
{name: "c."},
|
||||
{name: "D."},
|
||||
{name: ""},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"a": {index: 0},
|
||||
"B": {index: 1},
|
||||
"c.U": {index: 2, substructCodec: uStructCodec},
|
||||
"c.v": {index: 2, substructCodec: uStructCodec},
|
||||
"D.U": {index: 3, substructCodec: uStructCodec},
|
||||
"D.v": {index: 3, substructCodec: uStructCodec},
|
||||
"U": {index: 4, substructCodec: uStructCodec},
|
||||
"v": {index: 4, substructCodec: uStructCodec},
|
||||
},
|
||||
complete: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"noindex nested struct",
|
||||
struct {
|
||||
A oStruct `datastore:",noindex"`
|
||||
}{},
|
||||
&structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "A.", noIndex: true},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"A.O": {index: 0, substructCodec: oStructCodec},
|
||||
},
|
||||
complete: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"noindex slice",
|
||||
struct {
|
||||
A []string `datastore:",noindex"`
|
||||
}{},
|
||||
&structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "A", noIndex: true},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"A": {index: 0},
|
||||
},
|
||||
hasSlice: true,
|
||||
complete: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"noindex embedded struct slice",
|
||||
struct {
|
||||
// vStruct has a single field, V, also with noindex.
|
||||
A []vStruct `datastore:",noindex"`
|
||||
}{},
|
||||
&structCodec{
|
||||
byIndex: []structTag{
|
||||
{name: "A.", noIndex: true},
|
||||
},
|
||||
byName: map[string]fieldCodec{
|
||||
"A.V": {index: 0, substructCodec: vStructCodec},
|
||||
},
|
||||
hasSlice: true,
|
||||
complete: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
got, err := getStructCodec(reflect.TypeOf(tc.structValue))
|
||||
if err != nil {
|
||||
t.Errorf("%s: getStructCodec: %v", tc.desc, err)
|
||||
continue
|
||||
}
|
||||
if !reflect.DeepEqual(got, tc.want) {
|
||||
t.Errorf("%s\ngot %+v\nwant %+v\n", tc.desc, got, tc.want)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepeatedPropertyName(t *testing.T) {
|
||||
good := []interface{}{
|
||||
struct {
|
||||
A int `datastore:"-"`
|
||||
}{},
|
||||
struct {
|
||||
A int `datastore:"b"`
|
||||
B int
|
||||
}{},
|
||||
struct {
|
||||
A int
|
||||
B int `datastore:"B"`
|
||||
}{},
|
||||
struct {
|
||||
A int `datastore:"B"`
|
||||
B int `datastore:"-"`
|
||||
}{},
|
||||
struct {
|
||||
A int `datastore:"-"`
|
||||
B int `datastore:"A"`
|
||||
}{},
|
||||
struct {
|
||||
A int `datastore:"B"`
|
||||
B int `datastore:"A"`
|
||||
}{},
|
||||
struct {
|
||||
A int `datastore:"B"`
|
||||
B int `datastore:"C"`
|
||||
C int `datastore:"A"`
|
||||
}{},
|
||||
struct {
|
||||
A int `datastore:"B"`
|
||||
B int `datastore:"C"`
|
||||
C int `datastore:"D"`
|
||||
}{},
|
||||
}
|
||||
bad := []interface{}{
|
||||
struct {
|
||||
A int `datastore:"B"`
|
||||
B int
|
||||
}{},
|
||||
struct {
|
||||
A int
|
||||
B int `datastore:"A"`
|
||||
}{},
|
||||
struct {
|
||||
A int `datastore:"C"`
|
||||
B int `datastore:"C"`
|
||||
}{},
|
||||
struct {
|
||||
A int `datastore:"B"`
|
||||
B int `datastore:"C"`
|
||||
C int `datastore:"B"`
|
||||
}{},
|
||||
}
|
||||
testGetStructCodec(t, good, bad)
|
||||
}
|
||||
|
||||
func TestFlatteningNestedStructs(t *testing.T) {
|
||||
type deepGood struct {
|
||||
A struct {
|
||||
B []struct {
|
||||
C struct {
|
||||
D int
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type deepBad struct {
|
||||
A struct {
|
||||
B []struct {
|
||||
C struct {
|
||||
D []int
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
type iSay struct {
|
||||
Tomato int
|
||||
}
|
||||
type youSay struct {
|
||||
Tomato int
|
||||
}
|
||||
type tweedledee struct {
|
||||
Dee int `datastore:"D"`
|
||||
}
|
||||
type tweedledum struct {
|
||||
Dum int `datastore:"D"`
|
||||
}
|
||||
|
||||
good := []interface{}{
|
||||
struct {
|
||||
X []struct {
|
||||
Y string
|
||||
}
|
||||
}{},
|
||||
struct {
|
||||
X []struct {
|
||||
Y []byte
|
||||
}
|
||||
}{},
|
||||
struct {
|
||||
P []int
|
||||
X struct {
|
||||
Y []int
|
||||
}
|
||||
}{},
|
||||
struct {
|
||||
X struct {
|
||||
Y []int
|
||||
}
|
||||
Q []int
|
||||
}{},
|
||||
struct {
|
||||
P []int
|
||||
X struct {
|
||||
Y []int
|
||||
}
|
||||
Q []int
|
||||
}{},
|
||||
struct {
|
||||
deepGood
|
||||
}{},
|
||||
struct {
|
||||
DG deepGood
|
||||
}{},
|
||||
struct {
|
||||
Foo struct {
|
||||
Z int `datastore:"X"`
|
||||
} `datastore:"A"`
|
||||
Bar struct {
|
||||
Z int `datastore:"Y"`
|
||||
} `datastore:"A"`
|
||||
}{},
|
||||
}
|
||||
bad := []interface{}{
|
||||
struct {
|
||||
X []struct {
|
||||
Y []string
|
||||
}
|
||||
}{},
|
||||
struct {
|
||||
X []struct {
|
||||
Y []int
|
||||
}
|
||||
}{},
|
||||
struct {
|
||||
deepBad
|
||||
}{},
|
||||
struct {
|
||||
DB deepBad
|
||||
}{},
|
||||
struct {
|
||||
iSay
|
||||
youSay
|
||||
}{},
|
||||
struct {
|
||||
tweedledee
|
||||
tweedledum
|
||||
}{},
|
||||
struct {
|
||||
Foo struct {
|
||||
Z int
|
||||
} `datastore:"A"`
|
||||
Bar struct {
|
||||
Z int
|
||||
} `datastore:"A"`
|
||||
}{},
|
||||
}
|
||||
testGetStructCodec(t, good, bad)
|
||||
}
|
||||
|
||||
func testGetStructCodec(t *testing.T, good []interface{}, bad []interface{}) {
|
||||
for _, x := range good {
|
||||
if _, err := getStructCodec(reflect.TypeOf(x)); err != nil {
|
||||
t.Errorf("type %T: got non-nil error (%s), want nil", x, err)
|
||||
}
|
||||
}
|
||||
for _, x := range bad {
|
||||
if _, err := getStructCodec(reflect.TypeOf(x)); err == nil {
|
||||
t.Errorf("type %T: got nil error, want non-nil", x)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNilKeyIsStored(t *testing.T) {
|
||||
x := struct {
|
||||
K *Key
|
||||
I int
|
||||
}{}
|
||||
p := PropertyList{}
|
||||
// Save x as properties.
|
||||
p1, _ := SaveStruct(&x)
|
||||
p.Load(p1)
|
||||
// Set x's fields to non-zero.
|
||||
x.K = &Key{}
|
||||
x.I = 2
|
||||
// Load x from properties.
|
||||
p2, _ := p.Save()
|
||||
LoadStruct(&x, p2)
|
||||
// Check that x's fields were set to zero.
|
||||
if x.K != nil {
|
||||
t.Errorf("K field was not zero")
|
||||
}
|
||||
if x.I != 0 {
|
||||
t.Errorf("I field was not zero")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,724 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine/internal"
|
||||
pb "google.golang.org/appengine/internal/datastore"
|
||||
)
|
||||
|
||||
type operator int
|
||||
|
||||
const (
|
||||
lessThan operator = iota
|
||||
lessEq
|
||||
equal
|
||||
greaterEq
|
||||
greaterThan
|
||||
)
|
||||
|
||||
var operatorToProto = map[operator]*pb.Query_Filter_Operator{
|
||||
lessThan: pb.Query_Filter_LESS_THAN.Enum(),
|
||||
lessEq: pb.Query_Filter_LESS_THAN_OR_EQUAL.Enum(),
|
||||
equal: pb.Query_Filter_EQUAL.Enum(),
|
||||
greaterEq: pb.Query_Filter_GREATER_THAN_OR_EQUAL.Enum(),
|
||||
greaterThan: pb.Query_Filter_GREATER_THAN.Enum(),
|
||||
}
|
||||
|
||||
// filter is a conditional filter on query results.
|
||||
type filter struct {
|
||||
FieldName string
|
||||
Op operator
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
type sortDirection int
|
||||
|
||||
const (
|
||||
ascending sortDirection = iota
|
||||
descending
|
||||
)
|
||||
|
||||
var sortDirectionToProto = map[sortDirection]*pb.Query_Order_Direction{
|
||||
ascending: pb.Query_Order_ASCENDING.Enum(),
|
||||
descending: pb.Query_Order_DESCENDING.Enum(),
|
||||
}
|
||||
|
||||
// order is a sort order on query results.
|
||||
type order struct {
|
||||
FieldName string
|
||||
Direction sortDirection
|
||||
}
|
||||
|
||||
// NewQuery creates a new Query for a specific entity kind.
|
||||
//
|
||||
// An empty kind means to return all entities, including entities created and
|
||||
// managed by other App Engine features, and is called a kindless query.
|
||||
// Kindless queries cannot include filters or sort orders on property values.
|
||||
func NewQuery(kind string) *Query {
|
||||
return &Query{
|
||||
kind: kind,
|
||||
limit: -1,
|
||||
}
|
||||
}
|
||||
|
||||
// Query represents a datastore query.
|
||||
type Query struct {
|
||||
kind string
|
||||
ancestor *Key
|
||||
filter []filter
|
||||
order []order
|
||||
projection []string
|
||||
|
||||
distinct bool
|
||||
keysOnly bool
|
||||
eventual bool
|
||||
limit int32
|
||||
offset int32
|
||||
start *pb.CompiledCursor
|
||||
end *pb.CompiledCursor
|
||||
|
||||
err error
|
||||
}
|
||||
|
||||
func (q *Query) clone() *Query {
|
||||
x := *q
|
||||
// Copy the contents of the slice-typed fields to a new backing store.
|
||||
if len(q.filter) > 0 {
|
||||
x.filter = make([]filter, len(q.filter))
|
||||
copy(x.filter, q.filter)
|
||||
}
|
||||
if len(q.order) > 0 {
|
||||
x.order = make([]order, len(q.order))
|
||||
copy(x.order, q.order)
|
||||
}
|
||||
return &x
|
||||
}
|
||||
|
||||
// Ancestor returns a derivative query with an ancestor filter.
|
||||
// The ancestor should not be nil.
|
||||
func (q *Query) Ancestor(ancestor *Key) *Query {
|
||||
q = q.clone()
|
||||
if ancestor == nil {
|
||||
q.err = errors.New("datastore: nil query ancestor")
|
||||
return q
|
||||
}
|
||||
q.ancestor = ancestor
|
||||
return q
|
||||
}
|
||||
|
||||
// EventualConsistency returns a derivative query that returns eventually
|
||||
// consistent results.
|
||||
// It only has an effect on ancestor queries.
|
||||
func (q *Query) EventualConsistency() *Query {
|
||||
q = q.clone()
|
||||
q.eventual = true
|
||||
return q
|
||||
}
|
||||
|
||||
// Filter returns a derivative query with a field-based filter.
|
||||
// The filterStr argument must be a field name followed by optional space,
|
||||
// followed by an operator, one of ">", "<", ">=", "<=", or "=".
|
||||
// Fields are compared against the provided value using the operator.
|
||||
// Multiple filters are AND'ed together.
|
||||
func (q *Query) Filter(filterStr string, value interface{}) *Query {
|
||||
q = q.clone()
|
||||
filterStr = strings.TrimSpace(filterStr)
|
||||
if len(filterStr) < 1 {
|
||||
q.err = errors.New("datastore: invalid filter: " + filterStr)
|
||||
return q
|
||||
}
|
||||
f := filter{
|
||||
FieldName: strings.TrimRight(filterStr, " ><=!"),
|
||||
Value: value,
|
||||
}
|
||||
switch op := strings.TrimSpace(filterStr[len(f.FieldName):]); op {
|
||||
case "<=":
|
||||
f.Op = lessEq
|
||||
case ">=":
|
||||
f.Op = greaterEq
|
||||
case "<":
|
||||
f.Op = lessThan
|
||||
case ">":
|
||||
f.Op = greaterThan
|
||||
case "=":
|
||||
f.Op = equal
|
||||
default:
|
||||
q.err = fmt.Errorf("datastore: invalid operator %q in filter %q", op, filterStr)
|
||||
return q
|
||||
}
|
||||
q.filter = append(q.filter, f)
|
||||
return q
|
||||
}
|
||||
|
||||
// Order returns a derivative query with a field-based sort order. Orders are
|
||||
// applied in the order they are added. The default order is ascending; to sort
|
||||
// in descending order prefix the fieldName with a minus sign (-).
|
||||
func (q *Query) Order(fieldName string) *Query {
|
||||
q = q.clone()
|
||||
fieldName = strings.TrimSpace(fieldName)
|
||||
o := order{
|
||||
Direction: ascending,
|
||||
FieldName: fieldName,
|
||||
}
|
||||
if strings.HasPrefix(fieldName, "-") {
|
||||
o.Direction = descending
|
||||
o.FieldName = strings.TrimSpace(fieldName[1:])
|
||||
} else if strings.HasPrefix(fieldName, "+") {
|
||||
q.err = fmt.Errorf("datastore: invalid order: %q", fieldName)
|
||||
return q
|
||||
}
|
||||
if len(o.FieldName) == 0 {
|
||||
q.err = errors.New("datastore: empty order")
|
||||
return q
|
||||
}
|
||||
q.order = append(q.order, o)
|
||||
return q
|
||||
}
|
||||
|
||||
// Project returns a derivative query that yields only the given fields. It
|
||||
// cannot be used with KeysOnly.
|
||||
func (q *Query) Project(fieldNames ...string) *Query {
|
||||
q = q.clone()
|
||||
q.projection = append([]string(nil), fieldNames...)
|
||||
return q
|
||||
}
|
||||
|
||||
// Distinct returns a derivative query that yields de-duplicated entities with
|
||||
// respect to the set of projected fields. It is only used for projection
|
||||
// queries.
|
||||
func (q *Query) Distinct() *Query {
|
||||
q = q.clone()
|
||||
q.distinct = true
|
||||
return q
|
||||
}
|
||||
|
||||
// KeysOnly returns a derivative query that yields only keys, not keys and
|
||||
// entities. It cannot be used with projection queries.
|
||||
func (q *Query) KeysOnly() *Query {
|
||||
q = q.clone()
|
||||
q.keysOnly = true
|
||||
return q
|
||||
}
|
||||
|
||||
// Limit returns a derivative query that has a limit on the number of results
|
||||
// returned. A negative value means unlimited.
|
||||
func (q *Query) Limit(limit int) *Query {
|
||||
q = q.clone()
|
||||
if limit < math.MinInt32 || limit > math.MaxInt32 {
|
||||
q.err = errors.New("datastore: query limit overflow")
|
||||
return q
|
||||
}
|
||||
q.limit = int32(limit)
|
||||
return q
|
||||
}
|
||||
|
||||
// Offset returns a derivative query that has an offset of how many keys to
|
||||
// skip over before returning results. A negative value is invalid.
|
||||
func (q *Query) Offset(offset int) *Query {
|
||||
q = q.clone()
|
||||
if offset < 0 {
|
||||
q.err = errors.New("datastore: negative query offset")
|
||||
return q
|
||||
}
|
||||
if offset > math.MaxInt32 {
|
||||
q.err = errors.New("datastore: query offset overflow")
|
||||
return q
|
||||
}
|
||||
q.offset = int32(offset)
|
||||
return q
|
||||
}
|
||||
|
||||
// Start returns a derivative query with the given start point.
|
||||
func (q *Query) Start(c Cursor) *Query {
|
||||
q = q.clone()
|
||||
if c.cc == nil {
|
||||
q.err = errors.New("datastore: invalid cursor")
|
||||
return q
|
||||
}
|
||||
q.start = c.cc
|
||||
return q
|
||||
}
|
||||
|
||||
// End returns a derivative query with the given end point.
|
||||
func (q *Query) End(c Cursor) *Query {
|
||||
q = q.clone()
|
||||
if c.cc == nil {
|
||||
q.err = errors.New("datastore: invalid cursor")
|
||||
return q
|
||||
}
|
||||
q.end = c.cc
|
||||
return q
|
||||
}
|
||||
|
||||
// toProto converts the query to a protocol buffer.
|
||||
func (q *Query) toProto(dst *pb.Query, appID string) error {
|
||||
if len(q.projection) != 0 && q.keysOnly {
|
||||
return errors.New("datastore: query cannot both project and be keys-only")
|
||||
}
|
||||
dst.Reset()
|
||||
dst.App = proto.String(appID)
|
||||
if q.kind != "" {
|
||||
dst.Kind = proto.String(q.kind)
|
||||
}
|
||||
if q.ancestor != nil {
|
||||
dst.Ancestor = keyToProto(appID, q.ancestor)
|
||||
if q.eventual {
|
||||
dst.Strong = proto.Bool(false)
|
||||
}
|
||||
}
|
||||
if q.projection != nil {
|
||||
dst.PropertyName = q.projection
|
||||
if q.distinct {
|
||||
dst.GroupByPropertyName = q.projection
|
||||
}
|
||||
}
|
||||
if q.keysOnly {
|
||||
dst.KeysOnly = proto.Bool(true)
|
||||
dst.RequirePerfectPlan = proto.Bool(true)
|
||||
}
|
||||
for _, qf := range q.filter {
|
||||
if qf.FieldName == "" {
|
||||
return errors.New("datastore: empty query filter field name")
|
||||
}
|
||||
p, errStr := valueToProto(appID, qf.FieldName, reflect.ValueOf(qf.Value), false)
|
||||
if errStr != "" {
|
||||
return errors.New("datastore: bad query filter value type: " + errStr)
|
||||
}
|
||||
xf := &pb.Query_Filter{
|
||||
Op: operatorToProto[qf.Op],
|
||||
Property: []*pb.Property{p},
|
||||
}
|
||||
if xf.Op == nil {
|
||||
return errors.New("datastore: unknown query filter operator")
|
||||
}
|
||||
dst.Filter = append(dst.Filter, xf)
|
||||
}
|
||||
for _, qo := range q.order {
|
||||
if qo.FieldName == "" {
|
||||
return errors.New("datastore: empty query order field name")
|
||||
}
|
||||
xo := &pb.Query_Order{
|
||||
Property: proto.String(qo.FieldName),
|
||||
Direction: sortDirectionToProto[qo.Direction],
|
||||
}
|
||||
if xo.Direction == nil {
|
||||
return errors.New("datastore: unknown query order direction")
|
||||
}
|
||||
dst.Order = append(dst.Order, xo)
|
||||
}
|
||||
if q.limit >= 0 {
|
||||
dst.Limit = proto.Int32(q.limit)
|
||||
}
|
||||
if q.offset != 0 {
|
||||
dst.Offset = proto.Int32(q.offset)
|
||||
}
|
||||
dst.CompiledCursor = q.start
|
||||
dst.EndCompiledCursor = q.end
|
||||
dst.Compile = proto.Bool(true)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Count returns the number of results for the query.
|
||||
//
|
||||
// The running time and number of API calls made by Count scale linearly with
|
||||
// the sum of the query's offset and limit. Unless the result count is
|
||||
// expected to be small, it is best to specify a limit; otherwise Count will
|
||||
// continue until it finishes counting or the provided context expires.
|
||||
func (q *Query) Count(c context.Context) (int, error) {
|
||||
// Check that the query is well-formed.
|
||||
if q.err != nil {
|
||||
return 0, q.err
|
||||
}
|
||||
|
||||
// Run a copy of the query, with keysOnly true (if we're not a projection,
|
||||
// since the two are incompatible), and an adjusted offset. We also set the
|
||||
// limit to zero, as we don't want any actual entity data, just the number
|
||||
// of skipped results.
|
||||
newQ := q.clone()
|
||||
newQ.keysOnly = len(newQ.projection) == 0
|
||||
newQ.limit = 0
|
||||
if q.limit < 0 {
|
||||
// If the original query was unlimited, set the new query's offset to maximum.
|
||||
newQ.offset = math.MaxInt32
|
||||
} else {
|
||||
newQ.offset = q.offset + q.limit
|
||||
if newQ.offset < 0 {
|
||||
// Do the best we can, in the presence of overflow.
|
||||
newQ.offset = math.MaxInt32
|
||||
}
|
||||
}
|
||||
req := &pb.Query{}
|
||||
if err := newQ.toProto(req, internal.FullyQualifiedAppID(c)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
res := &pb.QueryResult{}
|
||||
if err := internal.Call(c, "datastore_v3", "RunQuery", req, res); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// n is the count we will return. For example, suppose that our original
|
||||
// query had an offset of 4 and a limit of 2008: the count will be 2008,
|
||||
// provided that there are at least 2012 matching entities. However, the
|
||||
// RPCs will only skip 1000 results at a time. The RPC sequence is:
|
||||
// call RunQuery with (offset, limit) = (2012, 0) // 2012 == newQ.offset
|
||||
// response has (skippedResults, moreResults) = (1000, true)
|
||||
// n += 1000 // n == 1000
|
||||
// call Next with (offset, limit) = (1012, 0) // 1012 == newQ.offset - n
|
||||
// response has (skippedResults, moreResults) = (1000, true)
|
||||
// n += 1000 // n == 2000
|
||||
// call Next with (offset, limit) = (12, 0) // 12 == newQ.offset - n
|
||||
// response has (skippedResults, moreResults) = (12, false)
|
||||
// n += 12 // n == 2012
|
||||
// // exit the loop
|
||||
// n -= 4 // n == 2008
|
||||
var n int32
|
||||
for {
|
||||
// The QueryResult should have no actual entity data, just skipped results.
|
||||
if len(res.Result) != 0 {
|
||||
return 0, errors.New("datastore: internal error: Count request returned too much data")
|
||||
}
|
||||
n += res.GetSkippedResults()
|
||||
if !res.GetMoreResults() {
|
||||
break
|
||||
}
|
||||
if err := callNext(c, res, newQ.offset-n, 0); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
n -= q.offset
|
||||
if n < 0 {
|
||||
// If the offset was greater than the number of matching entities,
|
||||
// return 0 instead of negative.
|
||||
n = 0
|
||||
}
|
||||
return int(n), nil
|
||||
}
|
||||
|
||||
// callNext issues a datastore_v3/Next RPC to advance a cursor, such as that
|
||||
// returned by a query with more results.
|
||||
func callNext(c context.Context, res *pb.QueryResult, offset, limit int32) error {
|
||||
if res.Cursor == nil {
|
||||
return errors.New("datastore: internal error: server did not return a cursor")
|
||||
}
|
||||
req := &pb.NextRequest{
|
||||
Cursor: res.Cursor,
|
||||
}
|
||||
if limit >= 0 {
|
||||
req.Count = proto.Int32(limit)
|
||||
}
|
||||
if offset != 0 {
|
||||
req.Offset = proto.Int32(offset)
|
||||
}
|
||||
if res.CompiledCursor != nil {
|
||||
req.Compile = proto.Bool(true)
|
||||
}
|
||||
res.Reset()
|
||||
return internal.Call(c, "datastore_v3", "Next", req, res)
|
||||
}
|
||||
|
||||
// GetAll runs the query in the given context and returns all keys that match
|
||||
// that query, as well as appending the values to dst.
|
||||
//
|
||||
// dst must have type *[]S or *[]*S or *[]P, for some struct type S or some non-
|
||||
// interface, non-pointer type P such that P or *P implements PropertyLoadSaver.
|
||||
//
|
||||
// As a special case, *PropertyList is an invalid type for dst, even though a
|
||||
// PropertyList is a slice of structs. It is treated as invalid to avoid being
|
||||
// mistakenly passed when *[]PropertyList was intended.
|
||||
//
|
||||
// The keys returned by GetAll will be in a 1-1 correspondence with the entities
|
||||
// added to dst.
|
||||
//
|
||||
// If q is a ``keys-only'' query, GetAll ignores dst and only returns the keys.
|
||||
//
|
||||
// The running time and number of API calls made by GetAll scale linearly with
|
||||
// with the sum of the query's offset and limit. Unless the result count is
|
||||
// expected to be small, it is best to specify a limit; otherwise GetAll will
|
||||
// continue until it finishes collecting results or the provided context
|
||||
// expires.
|
||||
func (q *Query) GetAll(c context.Context, dst interface{}) ([]*Key, error) {
|
||||
var (
|
||||
dv reflect.Value
|
||||
mat multiArgType
|
||||
elemType reflect.Type
|
||||
errFieldMismatch error
|
||||
)
|
||||
if !q.keysOnly {
|
||||
dv = reflect.ValueOf(dst)
|
||||
if dv.Kind() != reflect.Ptr || dv.IsNil() {
|
||||
return nil, ErrInvalidEntityType
|
||||
}
|
||||
dv = dv.Elem()
|
||||
mat, elemType = checkMultiArg(dv)
|
||||
if mat == multiArgTypeInvalid || mat == multiArgTypeInterface {
|
||||
return nil, ErrInvalidEntityType
|
||||
}
|
||||
}
|
||||
|
||||
var keys []*Key
|
||||
for t := q.Run(c); ; {
|
||||
k, e, err := t.next()
|
||||
if err == Done {
|
||||
break
|
||||
}
|
||||
if err != nil {
|
||||
return keys, err
|
||||
}
|
||||
if !q.keysOnly {
|
||||
ev := reflect.New(elemType)
|
||||
if elemType.Kind() == reflect.Map {
|
||||
// This is a special case. The zero values of a map type are
|
||||
// not immediately useful; they have to be make'd.
|
||||
//
|
||||
// Funcs and channels are similar, in that a zero value is not useful,
|
||||
// but even a freshly make'd channel isn't useful: there's no fixed
|
||||
// channel buffer size that is always going to be large enough, and
|
||||
// there's no goroutine to drain the other end. Theoretically, these
|
||||
// types could be supported, for example by sniffing for a constructor
|
||||
// method or requiring prior registration, but for now it's not a
|
||||
// frequent enough concern to be worth it. Programmers can work around
|
||||
// it by explicitly using Iterator.Next instead of the Query.GetAll
|
||||
// convenience method.
|
||||
x := reflect.MakeMap(elemType)
|
||||
ev.Elem().Set(x)
|
||||
}
|
||||
if err = loadEntity(ev.Interface(), e); err != nil {
|
||||
if _, ok := err.(*ErrFieldMismatch); ok {
|
||||
// We continue loading entities even in the face of field mismatch errors.
|
||||
// If we encounter any other error, that other error is returned. Otherwise,
|
||||
// an ErrFieldMismatch is returned.
|
||||
errFieldMismatch = err
|
||||
} else {
|
||||
return keys, err
|
||||
}
|
||||
}
|
||||
if mat != multiArgTypeStructPtr {
|
||||
ev = ev.Elem()
|
||||
}
|
||||
dv.Set(reflect.Append(dv, ev))
|
||||
}
|
||||
keys = append(keys, k)
|
||||
}
|
||||
return keys, errFieldMismatch
|
||||
}
|
||||
|
||||
// Run runs the query in the given context.
|
||||
func (q *Query) Run(c context.Context) *Iterator {
|
||||
if q.err != nil {
|
||||
return &Iterator{err: q.err}
|
||||
}
|
||||
t := &Iterator{
|
||||
c: c,
|
||||
limit: q.limit,
|
||||
q: q,
|
||||
prevCC: q.start,
|
||||
}
|
||||
var req pb.Query
|
||||
if err := q.toProto(&req, internal.FullyQualifiedAppID(c)); err != nil {
|
||||
t.err = err
|
||||
return t
|
||||
}
|
||||
if err := internal.Call(c, "datastore_v3", "RunQuery", &req, &t.res); err != nil {
|
||||
t.err = err
|
||||
return t
|
||||
}
|
||||
offset := q.offset - t.res.GetSkippedResults()
|
||||
for offset > 0 && t.res.GetMoreResults() {
|
||||
t.prevCC = t.res.CompiledCursor
|
||||
if err := callNext(t.c, &t.res, offset, t.limit); err != nil {
|
||||
t.err = err
|
||||
break
|
||||
}
|
||||
skip := t.res.GetSkippedResults()
|
||||
if skip < 0 {
|
||||
t.err = errors.New("datastore: internal error: negative number of skipped_results")
|
||||
break
|
||||
}
|
||||
offset -= skip
|
||||
}
|
||||
if offset < 0 {
|
||||
t.err = errors.New("datastore: internal error: query offset was overshot")
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// Iterator is the result of running a query.
|
||||
type Iterator struct {
|
||||
c context.Context
|
||||
err error
|
||||
// res is the result of the most recent RunQuery or Next API call.
|
||||
res pb.QueryResult
|
||||
// i is how many elements of res.Result we have iterated over.
|
||||
i int
|
||||
// limit is the limit on the number of results this iterator should return.
|
||||
// A negative value means unlimited.
|
||||
limit int32
|
||||
// q is the original query which yielded this iterator.
|
||||
q *Query
|
||||
// prevCC is the compiled cursor that marks the end of the previous batch
|
||||
// of results.
|
||||
prevCC *pb.CompiledCursor
|
||||
}
|
||||
|
||||
// Done is returned when a query iteration has completed.
|
||||
var Done = errors.New("datastore: query has no more results")
|
||||
|
||||
// Next returns the key of the next result. When there are no more results,
|
||||
// Done is returned as the error.
|
||||
//
|
||||
// If the query is not keys only and dst is non-nil, it also loads the entity
|
||||
// stored for that key into the struct pointer or PropertyLoadSaver dst, with
|
||||
// the same semantics and possible errors as for the Get function.
|
||||
func (t *Iterator) Next(dst interface{}) (*Key, error) {
|
||||
k, e, err := t.next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if dst != nil && !t.q.keysOnly {
|
||||
err = loadEntity(dst, e)
|
||||
}
|
||||
return k, err
|
||||
}
|
||||
|
||||
func (t *Iterator) next() (*Key, *pb.EntityProto, error) {
|
||||
if t.err != nil {
|
||||
return nil, nil, t.err
|
||||
}
|
||||
|
||||
// Issue datastore_v3/Next RPCs as necessary.
|
||||
for t.i == len(t.res.Result) {
|
||||
if !t.res.GetMoreResults() {
|
||||
t.err = Done
|
||||
return nil, nil, t.err
|
||||
}
|
||||
t.prevCC = t.res.CompiledCursor
|
||||
if err := callNext(t.c, &t.res, 0, t.limit); err != nil {
|
||||
t.err = err
|
||||
return nil, nil, t.err
|
||||
}
|
||||
if t.res.GetSkippedResults() != 0 {
|
||||
t.err = errors.New("datastore: internal error: iterator has skipped results")
|
||||
return nil, nil, t.err
|
||||
}
|
||||
t.i = 0
|
||||
if t.limit >= 0 {
|
||||
t.limit -= int32(len(t.res.Result))
|
||||
if t.limit < 0 {
|
||||
t.err = errors.New("datastore: internal error: query returned more results than the limit")
|
||||
return nil, nil, t.err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract the key from the t.i'th element of t.res.Result.
|
||||
e := t.res.Result[t.i]
|
||||
t.i++
|
||||
if e.Key == nil {
|
||||
return nil, nil, errors.New("datastore: internal error: server did not return a key")
|
||||
}
|
||||
k, err := protoToKey(e.Key)
|
||||
if err != nil || k.Incomplete() {
|
||||
return nil, nil, errors.New("datastore: internal error: server returned an invalid key")
|
||||
}
|
||||
return k, e, nil
|
||||
}
|
||||
|
||||
// Cursor returns a cursor for the iterator's current location.
|
||||
func (t *Iterator) Cursor() (Cursor, error) {
|
||||
if t.err != nil && t.err != Done {
|
||||
return Cursor{}, t.err
|
||||
}
|
||||
// If we are at either end of the current batch of results,
|
||||
// return the compiled cursor at that end.
|
||||
skipped := t.res.GetSkippedResults()
|
||||
if t.i == 0 && skipped == 0 {
|
||||
if t.prevCC == nil {
|
||||
// A nil pointer (of type *pb.CompiledCursor) means no constraint:
|
||||
// passing it as the end cursor of a new query means unlimited results
|
||||
// (glossing over the integer limit parameter for now).
|
||||
// A non-nil pointer to an empty pb.CompiledCursor means the start:
|
||||
// passing it as the end cursor of a new query means 0 results.
|
||||
// If prevCC was nil, then the original query had no start cursor, but
|
||||
// Iterator.Cursor should return "the start" instead of unlimited.
|
||||
return Cursor{&zeroCC}, nil
|
||||
}
|
||||
return Cursor{t.prevCC}, nil
|
||||
}
|
||||
if t.i == len(t.res.Result) {
|
||||
return Cursor{t.res.CompiledCursor}, nil
|
||||
}
|
||||
// Otherwise, re-run the query offset to this iterator's position, starting from
|
||||
// the most recent compiled cursor. This is done on a best-effort basis, as it
|
||||
// is racy; if a concurrent process has added or removed entities, then the
|
||||
// cursor returned may be inconsistent.
|
||||
q := t.q.clone()
|
||||
q.start = t.prevCC
|
||||
q.offset = skipped + int32(t.i)
|
||||
q.limit = 0
|
||||
q.keysOnly = len(q.projection) == 0
|
||||
t1 := q.Run(t.c)
|
||||
_, _, err := t1.next()
|
||||
if err != Done {
|
||||
if err == nil {
|
||||
err = fmt.Errorf("datastore: internal error: zero-limit query did not have zero results")
|
||||
}
|
||||
return Cursor{}, err
|
||||
}
|
||||
return Cursor{t1.res.CompiledCursor}, nil
|
||||
}
|
||||
|
||||
var zeroCC pb.CompiledCursor
|
||||
|
||||
// Cursor is an iterator's position. It can be converted to and from an opaque
|
||||
// string. A cursor can be used from different HTTP requests, but only with a
|
||||
// query with the same kind, ancestor, filter and order constraints.
|
||||
type Cursor struct {
|
||||
cc *pb.CompiledCursor
|
||||
}
|
||||
|
||||
// String returns a base-64 string representation of a cursor.
|
||||
func (c Cursor) String() string {
|
||||
if c.cc == nil {
|
||||
return ""
|
||||
}
|
||||
b, err := proto.Marshal(c.cc)
|
||||
if err != nil {
|
||||
// The only way to construct a Cursor with a non-nil cc field is to
|
||||
// unmarshal from the byte representation. We panic if the unmarshal
|
||||
// succeeds but the marshaling of the unchanged protobuf value fails.
|
||||
panic(fmt.Sprintf("datastore: internal error: malformed cursor: %v", err))
|
||||
}
|
||||
return strings.TrimRight(base64.URLEncoding.EncodeToString(b), "=")
|
||||
}
|
||||
|
||||
// Decode decodes a cursor from its base-64 string representation.
|
||||
func DecodeCursor(s string) (Cursor, error) {
|
||||
if s == "" {
|
||||
return Cursor{&zeroCC}, nil
|
||||
}
|
||||
if n := len(s) % 4; n != 0 {
|
||||
s += strings.Repeat("=", 4-n)
|
||||
}
|
||||
b, err := base64.URLEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return Cursor{}, err
|
||||
}
|
||||
cc := &pb.CompiledCursor{}
|
||||
if err := proto.Unmarshal(b, cc); err != nil {
|
||||
return Cursor{}, err
|
||||
}
|
||||
return Cursor{cc}, nil
|
||||
}
|
|
@ -0,0 +1,583 @@
|
|||
// Copyright 2011 Google Inc. All Rights Reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
"google.golang.org/appengine/internal"
|
||||
"google.golang.org/appengine/internal/aetesting"
|
||||
pb "google.golang.org/appengine/internal/datastore"
|
||||
)
|
||||
|
||||
var (
|
||||
path1 = &pb.Path{
|
||||
Element: []*pb.Path_Element{
|
||||
{
|
||||
Type: proto.String("Gopher"),
|
||||
Id: proto.Int64(6),
|
||||
},
|
||||
},
|
||||
}
|
||||
path2 = &pb.Path{
|
||||
Element: []*pb.Path_Element{
|
||||
{
|
||||
Type: proto.String("Gopher"),
|
||||
Id: proto.Int64(6),
|
||||
},
|
||||
{
|
||||
Type: proto.String("Gopher"),
|
||||
Id: proto.Int64(8),
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func fakeRunQuery(in *pb.Query, out *pb.QueryResult) error {
|
||||
expectedIn := &pb.Query{
|
||||
App: proto.String("dev~fake-app"),
|
||||
Kind: proto.String("Gopher"),
|
||||
Compile: proto.Bool(true),
|
||||
}
|
||||
if !proto.Equal(in, expectedIn) {
|
||||
return fmt.Errorf("unsupported argument: got %v want %v", in, expectedIn)
|
||||
}
|
||||
*out = pb.QueryResult{
|
||||
Result: []*pb.EntityProto{
|
||||
{
|
||||
Key: &pb.Reference{
|
||||
App: proto.String("s~test-app"),
|
||||
Path: path1,
|
||||
},
|
||||
EntityGroup: path1,
|
||||
Property: []*pb.Property{
|
||||
{
|
||||
Meaning: pb.Property_TEXT.Enum(),
|
||||
Name: proto.String("Name"),
|
||||
Value: &pb.PropertyValue{
|
||||
StringValue: proto.String("George"),
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: proto.String("Height"),
|
||||
Value: &pb.PropertyValue{
|
||||
Int64Value: proto.Int64(32),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Key: &pb.Reference{
|
||||
App: proto.String("s~test-app"),
|
||||
Path: path2,
|
||||
},
|
||||
EntityGroup: path1, // ancestor is George
|
||||
Property: []*pb.Property{
|
||||
{
|
||||
Meaning: pb.Property_TEXT.Enum(),
|
||||
Name: proto.String("Name"),
|
||||
Value: &pb.PropertyValue{
|
||||
StringValue: proto.String("Rufus"),
|
||||
},
|
||||
},
|
||||
// No height for Rufus.
|
||||
},
|
||||
},
|
||||
},
|
||||
MoreResults: proto.Bool(false),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type StructThatImplementsPLS struct{}
|
||||
|
||||
func (StructThatImplementsPLS) Load(p []Property) error { return nil }
|
||||
func (StructThatImplementsPLS) Save() ([]Property, error) { return nil, nil }
|
||||
|
||||
var _ PropertyLoadSaver = StructThatImplementsPLS{}
|
||||
|
||||
type StructPtrThatImplementsPLS struct{}
|
||||
|
||||
func (*StructPtrThatImplementsPLS) Load(p []Property) error { return nil }
|
||||
func (*StructPtrThatImplementsPLS) Save() ([]Property, error) { return nil, nil }
|
||||
|
||||
var _ PropertyLoadSaver = &StructPtrThatImplementsPLS{}
|
||||
|
||||
type PropertyMap map[string]Property
|
||||
|
||||
func (m PropertyMap) Load(props []Property) error {
|
||||
for _, p := range props {
|
||||
if p.Multiple {
|
||||
return errors.New("PropertyMap does not support multiple properties")
|
||||
}
|
||||
m[p.Name] = p
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m PropertyMap) Save() ([]Property, error) {
|
||||
props := make([]Property, 0, len(m))
|
||||
for _, p := range m {
|
||||
if p.Multiple {
|
||||
return nil, errors.New("PropertyMap does not support multiple properties")
|
||||
}
|
||||
props = append(props, p)
|
||||
}
|
||||
return props, nil
|
||||
}
|
||||
|
||||
var _ PropertyLoadSaver = PropertyMap{}
|
||||
|
||||
type Gopher struct {
|
||||
Name string
|
||||
Height int
|
||||
}
|
||||
|
||||
// typeOfEmptyInterface is the type of interface{}, but we can't use
|
||||
// reflect.TypeOf((interface{})(nil)) directly because TypeOf takes an
|
||||
// interface{}.
|
||||
var typeOfEmptyInterface = reflect.TypeOf((*interface{})(nil)).Elem()
|
||||
|
||||
func TestCheckMultiArg(t *testing.T) {
|
||||
testCases := []struct {
|
||||
v interface{}
|
||||
mat multiArgType
|
||||
elemType reflect.Type
|
||||
}{
|
||||
// Invalid cases.
|
||||
{nil, multiArgTypeInvalid, nil},
|
||||
{Gopher{}, multiArgTypeInvalid, nil},
|
||||
{&Gopher{}, multiArgTypeInvalid, nil},
|
||||
{PropertyList{}, multiArgTypeInvalid, nil}, // This is a special case.
|
||||
{PropertyMap{}, multiArgTypeInvalid, nil},
|
||||
{[]*PropertyList(nil), multiArgTypeInvalid, nil},
|
||||
{[]*PropertyMap(nil), multiArgTypeInvalid, nil},
|
||||
{[]**Gopher(nil), multiArgTypeInvalid, nil},
|
||||
{[]*interface{}(nil), multiArgTypeInvalid, nil},
|
||||
// Valid cases.
|
||||
{
|
||||
[]PropertyList(nil),
|
||||
multiArgTypePropertyLoadSaver,
|
||||
reflect.TypeOf(PropertyList{}),
|
||||
},
|
||||
{
|
||||
[]PropertyMap(nil),
|
||||
multiArgTypePropertyLoadSaver,
|
||||
reflect.TypeOf(PropertyMap{}),
|
||||
},
|
||||
{
|
||||
[]StructThatImplementsPLS(nil),
|
||||
multiArgTypePropertyLoadSaver,
|
||||
reflect.TypeOf(StructThatImplementsPLS{}),
|
||||
},
|
||||
{
|
||||
[]StructPtrThatImplementsPLS(nil),
|
||||
multiArgTypePropertyLoadSaver,
|
||||
reflect.TypeOf(StructPtrThatImplementsPLS{}),
|
||||
},
|
||||
{
|
||||
[]Gopher(nil),
|
||||
multiArgTypeStruct,
|
||||
reflect.TypeOf(Gopher{}),
|
||||
},
|
||||
{
|
||||
[]*Gopher(nil),
|
||||
multiArgTypeStructPtr,
|
||||
reflect.TypeOf(Gopher{}),
|
||||
},
|
||||
{
|
||||
[]interface{}(nil),
|
||||
multiArgTypeInterface,
|
||||
typeOfEmptyInterface,
|
||||
},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
mat, elemType := checkMultiArg(reflect.ValueOf(tc.v))
|
||||
if mat != tc.mat || elemType != tc.elemType {
|
||||
t.Errorf("checkMultiArg(%T): got %v, %v want %v, %v",
|
||||
tc.v, mat, elemType, tc.mat, tc.elemType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSimpleQuery(t *testing.T) {
|
||||
struct1 := Gopher{Name: "George", Height: 32}
|
||||
struct2 := Gopher{Name: "Rufus"}
|
||||
pList1 := PropertyList{
|
||||
{
|
||||
Name: "Name",
|
||||
Value: "George",
|
||||
},
|
||||
{
|
||||
Name: "Height",
|
||||
Value: int64(32),
|
||||
},
|
||||
}
|
||||
pList2 := PropertyList{
|
||||
{
|
||||
Name: "Name",
|
||||
Value: "Rufus",
|
||||
},
|
||||
}
|
||||
pMap1 := PropertyMap{
|
||||
"Name": Property{
|
||||
Name: "Name",
|
||||
Value: "George",
|
||||
},
|
||||
"Height": Property{
|
||||
Name: "Height",
|
||||
Value: int64(32),
|
||||
},
|
||||
}
|
||||
pMap2 := PropertyMap{
|
||||
"Name": Property{
|
||||
Name: "Name",
|
||||
Value: "Rufus",
|
||||
},
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
dst interface{}
|
||||
want interface{}
|
||||
}{
|
||||
// The destination must have type *[]P, *[]S or *[]*S, for some non-interface
|
||||
// type P such that *P implements PropertyLoadSaver, or for some struct type S.
|
||||
{new([]Gopher), &[]Gopher{struct1, struct2}},
|
||||
{new([]*Gopher), &[]*Gopher{&struct1, &struct2}},
|
||||
{new([]PropertyList), &[]PropertyList{pList1, pList2}},
|
||||
{new([]PropertyMap), &[]PropertyMap{pMap1, pMap2}},
|
||||
|
||||
// Any other destination type is invalid.
|
||||
{0, nil},
|
||||
{Gopher{}, nil},
|
||||
{PropertyList{}, nil},
|
||||
{PropertyMap{}, nil},
|
||||
{[]int{}, nil},
|
||||
{[]Gopher{}, nil},
|
||||
{[]PropertyList{}, nil},
|
||||
{new(int), nil},
|
||||
{new(Gopher), nil},
|
||||
{new(PropertyList), nil}, // This is a special case.
|
||||
{new(PropertyMap), nil},
|
||||
{new([]int), nil},
|
||||
{new([]map[int]int), nil},
|
||||
{new([]map[string]Property), nil},
|
||||
{new([]map[string]interface{}), nil},
|
||||
{new([]*int), nil},
|
||||
{new([]*map[int]int), nil},
|
||||
{new([]*map[string]Property), nil},
|
||||
{new([]*map[string]interface{}), nil},
|
||||
{new([]**Gopher), nil},
|
||||
{new([]*PropertyList), nil},
|
||||
{new([]*PropertyMap), nil},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
nCall := 0
|
||||
c := aetesting.FakeSingleContext(t, "datastore_v3", "RunQuery", func(in *pb.Query, out *pb.QueryResult) error {
|
||||
nCall++
|
||||
return fakeRunQuery(in, out)
|
||||
})
|
||||
c = internal.WithAppIDOverride(c, "dev~fake-app")
|
||||
|
||||
var (
|
||||
expectedErr error
|
||||
expectedNCall int
|
||||
)
|
||||
if tc.want == nil {
|
||||
expectedErr = ErrInvalidEntityType
|
||||
} else {
|
||||
expectedNCall = 1
|
||||
}
|
||||
keys, err := NewQuery("Gopher").GetAll(c, tc.dst)
|
||||
if err != expectedErr {
|
||||
t.Errorf("dst type %T: got error [%v], want [%v]", tc.dst, err, expectedErr)
|
||||
continue
|
||||
}
|
||||
if nCall != expectedNCall {
|
||||
t.Errorf("dst type %T: Context.Call was called an incorrect number of times: got %d want %d", tc.dst, nCall, expectedNCall)
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
key1 := NewKey(c, "Gopher", "", 6, nil)
|
||||
expectedKeys := []*Key{
|
||||
key1,
|
||||
NewKey(c, "Gopher", "", 8, key1),
|
||||
}
|
||||
if l1, l2 := len(keys), len(expectedKeys); l1 != l2 {
|
||||
t.Errorf("dst type %T: got %d keys, want %d keys", tc.dst, l1, l2)
|
||||
continue
|
||||
}
|
||||
for i, key := range keys {
|
||||
if key.AppID() != "s~test-app" {
|
||||
t.Errorf(`dst type %T: Key #%d's AppID = %q, want "s~test-app"`, tc.dst, i, key.AppID())
|
||||
continue
|
||||
}
|
||||
if !keysEqual(key, expectedKeys[i]) {
|
||||
t.Errorf("dst type %T: got key #%d %v, want %v", tc.dst, i, key, expectedKeys[i])
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.dst, tc.want) {
|
||||
t.Errorf("dst type %T: Entities got %+v, want %+v", tc.dst, tc.dst, tc.want)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// keysEqual is like (*Key).Equal, but ignores the App ID.
|
||||
func keysEqual(a, b *Key) bool {
|
||||
for a != nil && b != nil {
|
||||
if a.Kind() != b.Kind() || a.StringID() != b.StringID() || a.IntID() != b.IntID() {
|
||||
return false
|
||||
}
|
||||
a, b = a.Parent(), b.Parent()
|
||||
}
|
||||
return a == b
|
||||
}
|
||||
|
||||
func TestQueriesAreImmutable(t *testing.T) {
|
||||
// Test that deriving q2 from q1 does not modify q1.
|
||||
q0 := NewQuery("foo")
|
||||
q1 := NewQuery("foo")
|
||||
q2 := q1.Offset(2)
|
||||
if !reflect.DeepEqual(q0, q1) {
|
||||
t.Errorf("q0 and q1 were not equal")
|
||||
}
|
||||
if reflect.DeepEqual(q1, q2) {
|
||||
t.Errorf("q1 and q2 were equal")
|
||||
}
|
||||
|
||||
// Test that deriving from q4 twice does not conflict, even though
|
||||
// q4 has a long list of order clauses. This tests that the arrays
|
||||
// backed by a query's slice of orders are not shared.
|
||||
f := func() *Query {
|
||||
q := NewQuery("bar")
|
||||
// 47 is an ugly number that is unlikely to be near a re-allocation
|
||||
// point in repeated append calls. For example, it's not near a power
|
||||
// of 2 or a multiple of 10.
|
||||
for i := 0; i < 47; i++ {
|
||||
q = q.Order(fmt.Sprintf("x%d", i))
|
||||
}
|
||||
return q
|
||||
}
|
||||
q3 := f().Order("y")
|
||||
q4 := f()
|
||||
q5 := q4.Order("y")
|
||||
q6 := q4.Order("z")
|
||||
if !reflect.DeepEqual(q3, q5) {
|
||||
t.Errorf("q3 and q5 were not equal")
|
||||
}
|
||||
if reflect.DeepEqual(q5, q6) {
|
||||
t.Errorf("q5 and q6 were equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilterParser(t *testing.T) {
|
||||
testCases := []struct {
|
||||
filterStr string
|
||||
wantOK bool
|
||||
wantFieldName string
|
||||
wantOp operator
|
||||
}{
|
||||
// Supported ops.
|
||||
{"x<", true, "x", lessThan},
|
||||
{"x <", true, "x", lessThan},
|
||||
{"x <", true, "x", lessThan},
|
||||
{" x < ", true, "x", lessThan},
|
||||
{"x <=", true, "x", lessEq},
|
||||
{"x =", true, "x", equal},
|
||||
{"x >=", true, "x", greaterEq},
|
||||
{"x >", true, "x", greaterThan},
|
||||
{"in >", true, "in", greaterThan},
|
||||
{"in>", true, "in", greaterThan},
|
||||
// Valid but (currently) unsupported ops.
|
||||
{"x!=", false, "", 0},
|
||||
{"x !=", false, "", 0},
|
||||
{" x != ", false, "", 0},
|
||||
{"x IN", false, "", 0},
|
||||
{"x in", false, "", 0},
|
||||
// Invalid ops.
|
||||
{"x EQ", false, "", 0},
|
||||
{"x lt", false, "", 0},
|
||||
{"x <>", false, "", 0},
|
||||
{"x >>", false, "", 0},
|
||||
{"x ==", false, "", 0},
|
||||
{"x =<", false, "", 0},
|
||||
{"x =>", false, "", 0},
|
||||
{"x !", false, "", 0},
|
||||
{"x ", false, "", 0},
|
||||
{"x", false, "", 0},
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
q := NewQuery("foo").Filter(tc.filterStr, 42)
|
||||
if ok := q.err == nil; ok != tc.wantOK {
|
||||
t.Errorf("%q: ok=%t, want %t", tc.filterStr, ok, tc.wantOK)
|
||||
continue
|
||||
}
|
||||
if !tc.wantOK {
|
||||
continue
|
||||
}
|
||||
if len(q.filter) != 1 {
|
||||
t.Errorf("%q: len=%d, want %d", tc.filterStr, len(q.filter), 1)
|
||||
continue
|
||||
}
|
||||
got, want := q.filter[0], filter{tc.wantFieldName, tc.wantOp, 42}
|
||||
if got != want {
|
||||
t.Errorf("%q: got %v, want %v", tc.filterStr, got, want)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestQueryToProto(t *testing.T) {
|
||||
// The context is required to make Keys for the test cases.
|
||||
var got *pb.Query
|
||||
NoErr := errors.New("No error")
|
||||
c := aetesting.FakeSingleContext(t, "datastore_v3", "RunQuery", func(in *pb.Query, out *pb.QueryResult) error {
|
||||
got = in
|
||||
return NoErr // return a non-nil error so Run doesn't keep going.
|
||||
})
|
||||
c = internal.WithAppIDOverride(c, "dev~fake-app")
|
||||
|
||||
testCases := []struct {
|
||||
desc string
|
||||
query *Query
|
||||
want *pb.Query
|
||||
err string
|
||||
}{
|
||||
{
|
||||
desc: "empty",
|
||||
query: NewQuery(""),
|
||||
want: &pb.Query{},
|
||||
},
|
||||
{
|
||||
desc: "standard query",
|
||||
query: NewQuery("kind").Order("-I").Filter("I >", 17).Filter("U =", "Dave").Limit(7).Offset(42),
|
||||
want: &pb.Query{
|
||||
Kind: proto.String("kind"),
|
||||
Filter: []*pb.Query_Filter{
|
||||
{
|
||||
Op: pb.Query_Filter_GREATER_THAN.Enum(),
|
||||
Property: []*pb.Property{
|
||||
{
|
||||
Name: proto.String("I"),
|
||||
Value: &pb.PropertyValue{Int64Value: proto.Int64(17)},
|
||||
Multiple: proto.Bool(false),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Op: pb.Query_Filter_EQUAL.Enum(),
|
||||
Property: []*pb.Property{
|
||||
{
|
||||
Name: proto.String("U"),
|
||||
Value: &pb.PropertyValue{StringValue: proto.String("Dave")},
|
||||
Multiple: proto.Bool(false),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Order: []*pb.Query_Order{
|
||||
{
|
||||
Property: proto.String("I"),
|
||||
Direction: pb.Query_Order_DESCENDING.Enum(),
|
||||
},
|
||||
},
|
||||
Limit: proto.Int32(7),
|
||||
Offset: proto.Int32(42),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "ancestor",
|
||||
query: NewQuery("").Ancestor(NewKey(c, "kind", "Mummy", 0, nil)),
|
||||
want: &pb.Query{
|
||||
Ancestor: &pb.Reference{
|
||||
App: proto.String("dev~fake-app"),
|
||||
Path: &pb.Path{
|
||||
Element: []*pb.Path_Element{{Type: proto.String("kind"), Name: proto.String("Mummy")}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "projection",
|
||||
query: NewQuery("").Project("A", "B"),
|
||||
want: &pb.Query{
|
||||
PropertyName: []string{"A", "B"},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "projection with distinct",
|
||||
query: NewQuery("").Project("A", "B").Distinct(),
|
||||
want: &pb.Query{
|
||||
PropertyName: []string{"A", "B"},
|
||||
GroupByPropertyName: []string{"A", "B"},
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "keys only",
|
||||
query: NewQuery("").KeysOnly(),
|
||||
want: &pb.Query{
|
||||
KeysOnly: proto.Bool(true),
|
||||
RequirePerfectPlan: proto.Bool(true),
|
||||
},
|
||||
},
|
||||
{
|
||||
desc: "empty filter",
|
||||
query: NewQuery("kind").Filter("=", 17),
|
||||
err: "empty query filter field nam",
|
||||
},
|
||||
{
|
||||
desc: "bad filter type",
|
||||
query: NewQuery("kind").Filter("M =", map[string]bool{}),
|
||||
err: "bad query filter value type",
|
||||
},
|
||||
{
|
||||
desc: "bad filter operator",
|
||||
query: NewQuery("kind").Filter("I <<=", 17),
|
||||
err: `invalid operator "<<=" in filter "I <<="`,
|
||||
},
|
||||
{
|
||||
desc: "empty order",
|
||||
query: NewQuery("kind").Order(""),
|
||||
err: "empty order",
|
||||
},
|
||||
{
|
||||
desc: "bad order direction",
|
||||
query: NewQuery("kind").Order("+I"),
|
||||
err: `invalid order: "+I`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testCases {
|
||||
got = nil
|
||||
if _, err := tt.query.Run(c).Next(nil); err != NoErr {
|
||||
if tt.err == "" || !strings.Contains(err.Error(), tt.err) {
|
||||
t.Errorf("%s: error %v, want %q", tt.desc, err, tt.err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if tt.err != "" {
|
||||
t.Errorf("%s: no error, want %q", tt.desc, tt.err)
|
||||
continue
|
||||
}
|
||||
// Fields that are common to all protos.
|
||||
tt.want.App = proto.String("dev~fake-app")
|
||||
tt.want.Compile = proto.Bool(true)
|
||||
if !proto.Equal(got, tt.want) {
|
||||
t.Errorf("%s:\ngot %v\nwant %v", tt.desc, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,300 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
|
||||
"google.golang.org/appengine"
|
||||
pb "google.golang.org/appengine/internal/datastore"
|
||||
)
|
||||
|
||||
func toUnixMicro(t time.Time) int64 {
|
||||
// We cannot use t.UnixNano() / 1e3 because we want to handle times more than
|
||||
// 2^63 nanoseconds (which is about 292 years) away from 1970, and those cannot
|
||||
// be represented in the numerator of a single int64 divide.
|
||||
return t.Unix()*1e6 + int64(t.Nanosecond()/1e3)
|
||||
}
|
||||
|
||||
func fromUnixMicro(t int64) time.Time {
|
||||
return time.Unix(t/1e6, (t%1e6)*1e3).UTC()
|
||||
}
|
||||
|
||||
var (
|
||||
minTime = time.Unix(int64(math.MinInt64)/1e6, (int64(math.MinInt64)%1e6)*1e3)
|
||||
maxTime = time.Unix(int64(math.MaxInt64)/1e6, (int64(math.MaxInt64)%1e6)*1e3)
|
||||
)
|
||||
|
||||
// valueToProto converts a named value to a newly allocated Property.
|
||||
// The returned error string is empty on success.
|
||||
func valueToProto(defaultAppID, name string, v reflect.Value, multiple bool) (p *pb.Property, errStr string) {
|
||||
var (
|
||||
pv pb.PropertyValue
|
||||
unsupported bool
|
||||
)
|
||||
switch v.Kind() {
|
||||
case reflect.Invalid:
|
||||
// No-op.
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
pv.Int64Value = proto.Int64(v.Int())
|
||||
case reflect.Bool:
|
||||
pv.BooleanValue = proto.Bool(v.Bool())
|
||||
case reflect.String:
|
||||
pv.StringValue = proto.String(v.String())
|
||||
case reflect.Float32, reflect.Float64:
|
||||
pv.DoubleValue = proto.Float64(v.Float())
|
||||
case reflect.Ptr:
|
||||
if k, ok := v.Interface().(*Key); ok {
|
||||
if k != nil {
|
||||
pv.Referencevalue = keyToReferenceValue(defaultAppID, k)
|
||||
}
|
||||
} else {
|
||||
unsupported = true
|
||||
}
|
||||
case reflect.Struct:
|
||||
switch t := v.Interface().(type) {
|
||||
case time.Time:
|
||||
if t.Before(minTime) || t.After(maxTime) {
|
||||
return nil, "time value out of range"
|
||||
}
|
||||
pv.Int64Value = proto.Int64(toUnixMicro(t))
|
||||
case appengine.GeoPoint:
|
||||
if !t.Valid() {
|
||||
return nil, "invalid GeoPoint value"
|
||||
}
|
||||
// NOTE: Strangely, latitude maps to X, longitude to Y.
|
||||
pv.Pointvalue = &pb.PropertyValue_PointValue{X: &t.Lat, Y: &t.Lng}
|
||||
default:
|
||||
unsupported = true
|
||||
}
|
||||
case reflect.Slice:
|
||||
if b, ok := v.Interface().([]byte); ok {
|
||||
pv.StringValue = proto.String(string(b))
|
||||
} else {
|
||||
// nvToProto should already catch slice values.
|
||||
// If we get here, we have a slice of slice values.
|
||||
unsupported = true
|
||||
}
|
||||
default:
|
||||
unsupported = true
|
||||
}
|
||||
if unsupported {
|
||||
return nil, "unsupported datastore value type: " + v.Type().String()
|
||||
}
|
||||
p = &pb.Property{
|
||||
Name: proto.String(name),
|
||||
Value: &pv,
|
||||
Multiple: proto.Bool(multiple),
|
||||
}
|
||||
if v.IsValid() {
|
||||
switch v.Interface().(type) {
|
||||
case []byte:
|
||||
p.Meaning = pb.Property_BLOB.Enum()
|
||||
case ByteString:
|
||||
p.Meaning = pb.Property_BYTESTRING.Enum()
|
||||
case appengine.BlobKey:
|
||||
p.Meaning = pb.Property_BLOBKEY.Enum()
|
||||
case time.Time:
|
||||
p.Meaning = pb.Property_GD_WHEN.Enum()
|
||||
case appengine.GeoPoint:
|
||||
p.Meaning = pb.Property_GEORSS_POINT.Enum()
|
||||
}
|
||||
}
|
||||
return p, ""
|
||||
}
|
||||
|
||||
// saveEntity saves an EntityProto into a PropertyLoadSaver or struct pointer.
|
||||
func saveEntity(defaultAppID string, key *Key, src interface{}) (*pb.EntityProto, error) {
|
||||
var err error
|
||||
var props []Property
|
||||
if e, ok := src.(PropertyLoadSaver); ok {
|
||||
props, err = e.Save()
|
||||
} else {
|
||||
props, err = SaveStruct(src)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return propertiesToProto(defaultAppID, key, props)
|
||||
}
|
||||
|
||||
func saveStructProperty(props *[]Property, name string, noIndex, multiple bool, v reflect.Value) error {
|
||||
p := Property{
|
||||
Name: name,
|
||||
NoIndex: noIndex,
|
||||
Multiple: multiple,
|
||||
}
|
||||
switch x := v.Interface().(type) {
|
||||
case *Key:
|
||||
p.Value = x
|
||||
case time.Time:
|
||||
p.Value = x
|
||||
case appengine.BlobKey:
|
||||
p.Value = x
|
||||
case appengine.GeoPoint:
|
||||
p.Value = x
|
||||
case ByteString:
|
||||
p.Value = x
|
||||
default:
|
||||
switch v.Kind() {
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||
p.Value = v.Int()
|
||||
case reflect.Bool:
|
||||
p.Value = v.Bool()
|
||||
case reflect.String:
|
||||
p.Value = v.String()
|
||||
case reflect.Float32, reflect.Float64:
|
||||
p.Value = v.Float()
|
||||
case reflect.Slice:
|
||||
if v.Type().Elem().Kind() == reflect.Uint8 {
|
||||
p.NoIndex = true
|
||||
p.Value = v.Bytes()
|
||||
}
|
||||
case reflect.Struct:
|
||||
if !v.CanAddr() {
|
||||
return fmt.Errorf("datastore: unsupported struct field: value is unaddressable")
|
||||
}
|
||||
sub, err := newStructPLS(v.Addr().Interface())
|
||||
if err != nil {
|
||||
return fmt.Errorf("datastore: unsupported struct field: %v", err)
|
||||
}
|
||||
return sub.(structPLS).save(props, name, noIndex, multiple)
|
||||
}
|
||||
}
|
||||
if p.Value == nil {
|
||||
return fmt.Errorf("datastore: unsupported struct field type: %v", v.Type())
|
||||
}
|
||||
*props = append(*props, p)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s structPLS) Save() ([]Property, error) {
|
||||
var props []Property
|
||||
if err := s.save(&props, "", false, false); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return props, nil
|
||||
}
|
||||
|
||||
func (s structPLS) save(props *[]Property, prefix string, noIndex, multiple bool) error {
|
||||
for i, t := range s.codec.byIndex {
|
||||
if t.name == "-" {
|
||||
continue
|
||||
}
|
||||
name := t.name
|
||||
if prefix != "" {
|
||||
name = prefix + name
|
||||
}
|
||||
v := s.v.Field(i)
|
||||
if !v.IsValid() || !v.CanSet() {
|
||||
continue
|
||||
}
|
||||
noIndex1 := noIndex || t.noIndex
|
||||
// For slice fields that aren't []byte, save each element.
|
||||
if v.Kind() == reflect.Slice && v.Type().Elem().Kind() != reflect.Uint8 {
|
||||
for j := 0; j < v.Len(); j++ {
|
||||
if err := saveStructProperty(props, name, noIndex1, true, v.Index(j)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Otherwise, save the field itself.
|
||||
if err := saveStructProperty(props, name, noIndex1, multiple, v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func propertiesToProto(defaultAppID string, key *Key, props []Property) (*pb.EntityProto, error) {
|
||||
e := &pb.EntityProto{
|
||||
Key: keyToProto(defaultAppID, key),
|
||||
}
|
||||
if key.parent == nil {
|
||||
e.EntityGroup = &pb.Path{}
|
||||
} else {
|
||||
e.EntityGroup = keyToProto(defaultAppID, key.root()).Path
|
||||
}
|
||||
prevMultiple := make(map[string]bool)
|
||||
|
||||
for _, p := range props {
|
||||
if pm, ok := prevMultiple[p.Name]; ok {
|
||||
if !pm || !p.Multiple {
|
||||
return nil, fmt.Errorf("datastore: multiple Properties with Name %q, but Multiple is false", p.Name)
|
||||
}
|
||||
} else {
|
||||
prevMultiple[p.Name] = p.Multiple
|
||||
}
|
||||
|
||||
x := &pb.Property{
|
||||
Name: proto.String(p.Name),
|
||||
Value: new(pb.PropertyValue),
|
||||
Multiple: proto.Bool(p.Multiple),
|
||||
}
|
||||
switch v := p.Value.(type) {
|
||||
case int64:
|
||||
x.Value.Int64Value = proto.Int64(v)
|
||||
case bool:
|
||||
x.Value.BooleanValue = proto.Bool(v)
|
||||
case string:
|
||||
x.Value.StringValue = proto.String(v)
|
||||
if p.NoIndex {
|
||||
x.Meaning = pb.Property_TEXT.Enum()
|
||||
}
|
||||
case float64:
|
||||
x.Value.DoubleValue = proto.Float64(v)
|
||||
case *Key:
|
||||
if v != nil {
|
||||
x.Value.Referencevalue = keyToReferenceValue(defaultAppID, v)
|
||||
}
|
||||
case time.Time:
|
||||
if v.Before(minTime) || v.After(maxTime) {
|
||||
return nil, fmt.Errorf("datastore: time value out of range")
|
||||
}
|
||||
x.Value.Int64Value = proto.Int64(toUnixMicro(v))
|
||||
x.Meaning = pb.Property_GD_WHEN.Enum()
|
||||
case appengine.BlobKey:
|
||||
x.Value.StringValue = proto.String(string(v))
|
||||
x.Meaning = pb.Property_BLOBKEY.Enum()
|
||||
case appengine.GeoPoint:
|
||||
if !v.Valid() {
|
||||
return nil, fmt.Errorf("datastore: invalid GeoPoint value")
|
||||
}
|
||||
// NOTE: Strangely, latitude maps to X, longitude to Y.
|
||||
x.Value.Pointvalue = &pb.PropertyValue_PointValue{X: &v.Lat, Y: &v.Lng}
|
||||
x.Meaning = pb.Property_GEORSS_POINT.Enum()
|
||||
case []byte:
|
||||
x.Value.StringValue = proto.String(string(v))
|
||||
x.Meaning = pb.Property_BLOB.Enum()
|
||||
if !p.NoIndex {
|
||||
return nil, fmt.Errorf("datastore: cannot index a []byte valued Property with Name %q", p.Name)
|
||||
}
|
||||
case ByteString:
|
||||
x.Value.StringValue = proto.String(string(v))
|
||||
x.Meaning = pb.Property_BYTESTRING.Enum()
|
||||
default:
|
||||
if p.Value != nil {
|
||||
return nil, fmt.Errorf("datastore: invalid Value type for a Property with Name %q", p.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if p.NoIndex {
|
||||
e.RawProperty = append(e.RawProperty, x)
|
||||
} else {
|
||||
e.Property = append(e.Property, x)
|
||||
if len(e.Property) > maxIndexedProperties {
|
||||
return nil, errors.New("datastore: too many indexed properties")
|
||||
}
|
||||
}
|
||||
}
|
||||
return e, nil
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
// Copyright 2012 Google Inc. All Rights Reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestUnixMicro(t *testing.T) {
|
||||
// Test that all these time.Time values survive a round trip to unix micros.
|
||||
testCases := []time.Time{
|
||||
{},
|
||||
time.Date(2, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(23, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(234, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Unix(-1e6, -1000),
|
||||
time.Unix(-1e6, 0),
|
||||
time.Unix(-1e6, +1000),
|
||||
time.Unix(-60, -1000),
|
||||
time.Unix(-60, 0),
|
||||
time.Unix(-60, +1000),
|
||||
time.Unix(-1, -1000),
|
||||
time.Unix(-1, 0),
|
||||
time.Unix(-1, +1000),
|
||||
time.Unix(0, -3000),
|
||||
time.Unix(0, -2000),
|
||||
time.Unix(0, -1000),
|
||||
time.Unix(0, 0),
|
||||
time.Unix(0, +1000),
|
||||
time.Unix(0, +2000),
|
||||
time.Unix(+60, -1000),
|
||||
time.Unix(+60, 0),
|
||||
time.Unix(+60, +1000),
|
||||
time.Unix(+1e6, -1000),
|
||||
time.Unix(+1e6, 0),
|
||||
time.Unix(+1e6, +1000),
|
||||
time.Date(1999, 12, 31, 23, 59, 59, 999000, time.UTC),
|
||||
time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
time.Date(2006, 1, 2, 15, 4, 5, 678000, time.UTC),
|
||||
time.Date(2009, 11, 10, 23, 0, 0, 0, time.UTC),
|
||||
time.Date(3456, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
for _, tc := range testCases {
|
||||
got := fromUnixMicro(toUnixMicro(tc))
|
||||
if !got.Equal(tc) {
|
||||
t.Errorf("got %q, want %q", got, tc)
|
||||
}
|
||||
}
|
||||
|
||||
// Test that a time.Time that isn't an integral number of microseconds
|
||||
// is not perfectly reconstructed after a round trip.
|
||||
t0 := time.Unix(0, 123)
|
||||
t1 := fromUnixMicro(toUnixMicro(t0))
|
||||
if t1.Nanosecond()%1000 != 0 || t0.Nanosecond()%1000 == 0 {
|
||||
t.Errorf("quantization to µs: got %q with %d ns, started with %d ns", t1, t1.Nanosecond(), t0.Nanosecond())
|
||||
}
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine/internal"
|
||||
pb "google.golang.org/appengine/internal/datastore"
|
||||
)
|
||||
|
||||
func init() {
|
||||
internal.RegisterTransactionSetter(func(x *pb.Query, t *pb.Transaction) {
|
||||
x.Transaction = t
|
||||
})
|
||||
internal.RegisterTransactionSetter(func(x *pb.GetRequest, t *pb.Transaction) {
|
||||
x.Transaction = t
|
||||
})
|
||||
internal.RegisterTransactionSetter(func(x *pb.PutRequest, t *pb.Transaction) {
|
||||
x.Transaction = t
|
||||
})
|
||||
internal.RegisterTransactionSetter(func(x *pb.DeleteRequest, t *pb.Transaction) {
|
||||
x.Transaction = t
|
||||
})
|
||||
}
|
||||
|
||||
// ErrConcurrentTransaction is returned when a transaction is rolled back due
|
||||
// to a conflict with a concurrent transaction.
|
||||
var ErrConcurrentTransaction = errors.New("datastore: concurrent transaction")
|
||||
|
||||
// RunInTransaction runs f in a transaction. It calls f with a transaction
|
||||
// context tc that f should use for all App Engine operations.
|
||||
//
|
||||
// If f returns nil, RunInTransaction attempts to commit the transaction,
|
||||
// returning nil if it succeeds. If the commit fails due to a conflicting
|
||||
// transaction, RunInTransaction retries f, each time with a new transaction
|
||||
// context. It gives up and returns ErrConcurrentTransaction after three
|
||||
// failed attempts. The number of attempts can be configured by specifying
|
||||
// TransactionOptions.Attempts.
|
||||
//
|
||||
// If f returns non-nil, then any datastore changes will not be applied and
|
||||
// RunInTransaction returns that same error. The function f is not retried.
|
||||
//
|
||||
// Note that when f returns, the transaction is not yet committed. Calling code
|
||||
// must be careful not to assume that any of f's changes have been committed
|
||||
// until RunInTransaction returns nil.
|
||||
//
|
||||
// Since f may be called multiple times, f should usually be idempotent.
|
||||
// datastore.Get is not idempotent when unmarshaling slice fields.
|
||||
//
|
||||
// Nested transactions are not supported; c may not be a transaction context.
|
||||
func RunInTransaction(c context.Context, f func(tc context.Context) error, opts *TransactionOptions) error {
|
||||
xg := false
|
||||
if opts != nil {
|
||||
xg = opts.XG
|
||||
}
|
||||
attempts := 3
|
||||
if opts != nil && opts.Attempts > 0 {
|
||||
attempts = opts.Attempts
|
||||
}
|
||||
for i := 0; i < attempts; i++ {
|
||||
if err := internal.RunTransactionOnce(c, f, xg); err != internal.ErrConcurrentTransaction {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return ErrConcurrentTransaction
|
||||
}
|
||||
|
||||
// TransactionOptions are the options for running a transaction.
|
||||
type TransactionOptions struct {
|
||||
// XG is whether the transaction can cross multiple entity groups. In
|
||||
// comparison, a single group transaction is one where all datastore keys
|
||||
// used have the same root key. Note that cross group transactions do not
|
||||
// have the same behavior as single group transactions. In particular, it
|
||||
// is much more likely to see partially applied transactions in different
|
||||
// entity groups, in global queries.
|
||||
// It is valid to set XG to true even if the transaction is within a
|
||||
// single entity group.
|
||||
XG bool
|
||||
// Attempts controls the number of retries to perform when commits fail
|
||||
// due to a conflicting transaction. If omitted, it defaults to 3.
|
||||
Attempts int
|
||||
}
|
|
@ -0,0 +1,278 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
/*
|
||||
Package delay provides a way to execute code outside the scope of a
|
||||
user request by using the taskqueue API.
|
||||
|
||||
To declare a function that may be executed later, call Func
|
||||
in a top-level assignment context, passing it an arbitrary string key
|
||||
and a function whose first argument is of type context.Context.
|
||||
var laterFunc = delay.Func("key", myFunc)
|
||||
It is also possible to use a function literal.
|
||||
var laterFunc = delay.Func("key", func(c context.Context, x string) {
|
||||
// ...
|
||||
})
|
||||
|
||||
To call a function, invoke its Call method.
|
||||
laterFunc.Call(c, "something")
|
||||
A function may be called any number of times. If the function has any
|
||||
return arguments, and the last one is of type error, the function may
|
||||
return a non-nil error to signal that the function should be retried.
|
||||
|
||||
The arguments to functions may be of any type that is encodable by the gob
|
||||
package. If an argument is of interface type, it is the client's responsibility
|
||||
to register with the gob package whatever concrete type may be passed for that
|
||||
argument; see http://golang.org/pkg/gob/#Register for details.
|
||||
|
||||
Any errors during initialization or execution of a function will be
|
||||
logged to the application logs. Error logs that occur during initialization will
|
||||
be associated with the request that invoked the Call method.
|
||||
|
||||
The state of a function invocation that has not yet successfully
|
||||
executed is preserved by combining the file name in which it is declared
|
||||
with the string key that was passed to the Func function. Updating an app
|
||||
with pending function invocations is safe as long as the relevant
|
||||
functions have the (filename, key) combination preserved.
|
||||
|
||||
The delay package uses the Task Queue API to create tasks that call the
|
||||
reserved application path "/_ah/queue/go/delay".
|
||||
This path must not be marked as "login: required" in app.yaml;
|
||||
it must be marked as "login: admin" or have no access restriction.
|
||||
*/
|
||||
package delay // import "google.golang.org/appengine/delay"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"runtime"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine"
|
||||
"google.golang.org/appengine/log"
|
||||
"google.golang.org/appengine/taskqueue"
|
||||
)
|
||||
|
||||
// Function represents a function that may have a delayed invocation.
|
||||
type Function struct {
|
||||
fv reflect.Value // Kind() == reflect.Func
|
||||
key string
|
||||
err error // any error during initialization
|
||||
}
|
||||
|
||||
const (
|
||||
// The HTTP path for invocations.
|
||||
path = "/_ah/queue/go/delay"
|
||||
// Use the default queue.
|
||||
queue = ""
|
||||
)
|
||||
|
||||
var (
|
||||
// registry of all delayed functions
|
||||
funcs = make(map[string]*Function)
|
||||
|
||||
// precomputed types
|
||||
contextType = reflect.TypeOf((*context.Context)(nil)).Elem()
|
||||
errorType = reflect.TypeOf((*error)(nil)).Elem()
|
||||
|
||||
// errors
|
||||
errFirstArg = errors.New("first argument must be context.Context")
|
||||
)
|
||||
|
||||
// Func declares a new Function. The second argument must be a function with a
|
||||
// first argument of type context.Context.
|
||||
// This function must be called at program initialization time. That means it
|
||||
// must be called in a global variable declaration or from an init function.
|
||||
// This restriction is necessary because the instance that delays a function
|
||||
// call may not be the one that executes it. Only the code executed at program
|
||||
// initialization time is guaranteed to have been run by an instance before it
|
||||
// receives a request.
|
||||
func Func(key string, i interface{}) *Function {
|
||||
f := &Function{fv: reflect.ValueOf(i)}
|
||||
|
||||
// Derive unique, somewhat stable key for this func.
|
||||
_, file, _, _ := runtime.Caller(1)
|
||||
f.key = file + ":" + key
|
||||
|
||||
t := f.fv.Type()
|
||||
if t.Kind() != reflect.Func {
|
||||
f.err = errors.New("not a function")
|
||||
return f
|
||||
}
|
||||
if t.NumIn() == 0 || t.In(0) != contextType {
|
||||
f.err = errFirstArg
|
||||
return f
|
||||
}
|
||||
|
||||
// Register the function's arguments with the gob package.
|
||||
// This is required because they are marshaled inside a []interface{}.
|
||||
// gob.Register only expects to be called during initialization;
|
||||
// that's fine because this function expects the same.
|
||||
for i := 0; i < t.NumIn(); i++ {
|
||||
// Only concrete types may be registered. If the argument has
|
||||
// interface type, the client is resposible for registering the
|
||||
// concrete types it will hold.
|
||||
if t.In(i).Kind() == reflect.Interface {
|
||||
continue
|
||||
}
|
||||
gob.Register(reflect.Zero(t.In(i)).Interface())
|
||||
}
|
||||
|
||||
if old := funcs[f.key]; old != nil {
|
||||
old.err = fmt.Errorf("multiple functions registered for %s in %s", key, file)
|
||||
}
|
||||
funcs[f.key] = f
|
||||
return f
|
||||
}
|
||||
|
||||
type invocation struct {
|
||||
Key string
|
||||
Args []interface{}
|
||||
}
|
||||
|
||||
// Call invokes a delayed function.
|
||||
// err := f.Call(c, ...)
|
||||
// is equivalent to
|
||||
// t, _ := f.Task(...)
|
||||
// _, err := taskqueue.Add(c, t, "")
|
||||
func (f *Function) Call(c context.Context, args ...interface{}) error {
|
||||
t, err := f.Task(args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = taskqueueAdder(c, t, queue)
|
||||
return err
|
||||
}
|
||||
|
||||
// Task creates a Task that will invoke the function.
|
||||
// Its parameters may be tweaked before adding it to a queue.
|
||||
// Users should not modify the Path or Payload fields of the returned Task.
|
||||
func (f *Function) Task(args ...interface{}) (*taskqueue.Task, error) {
|
||||
if f.err != nil {
|
||||
return nil, fmt.Errorf("delay: func is invalid: %v", f.err)
|
||||
}
|
||||
|
||||
nArgs := len(args) + 1 // +1 for the context.Context
|
||||
ft := f.fv.Type()
|
||||
minArgs := ft.NumIn()
|
||||
if ft.IsVariadic() {
|
||||
minArgs--
|
||||
}
|
||||
if nArgs < minArgs {
|
||||
return nil, fmt.Errorf("delay: too few arguments to func: %d < %d", nArgs, minArgs)
|
||||
}
|
||||
if !ft.IsVariadic() && nArgs > minArgs {
|
||||
return nil, fmt.Errorf("delay: too many arguments to func: %d > %d", nArgs, minArgs)
|
||||
}
|
||||
|
||||
// Check arg types.
|
||||
for i := 1; i < nArgs; i++ {
|
||||
at := reflect.TypeOf(args[i-1])
|
||||
var dt reflect.Type
|
||||
if i < minArgs {
|
||||
// not a variadic arg
|
||||
dt = ft.In(i)
|
||||
} else {
|
||||
// a variadic arg
|
||||
dt = ft.In(minArgs).Elem()
|
||||
}
|
||||
// nil arguments won't have a type, so they need special handling.
|
||||
if at == nil {
|
||||
// nil interface
|
||||
switch dt.Kind() {
|
||||
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||
continue // may be nil
|
||||
}
|
||||
return nil, fmt.Errorf("delay: argument %d has wrong type: %v is not nilable", i, dt)
|
||||
}
|
||||
switch at.Kind() {
|
||||
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
|
||||
av := reflect.ValueOf(args[i-1])
|
||||
if av.IsNil() {
|
||||
// nil value in interface; not supported by gob, so we replace it
|
||||
// with a nil interface value
|
||||
args[i-1] = nil
|
||||
}
|
||||
}
|
||||
if !at.AssignableTo(dt) {
|
||||
return nil, fmt.Errorf("delay: argument %d has wrong type: %v is not assignable to %v", i, at, dt)
|
||||
}
|
||||
}
|
||||
|
||||
inv := invocation{
|
||||
Key: f.key,
|
||||
Args: args,
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
if err := gob.NewEncoder(buf).Encode(inv); err != nil {
|
||||
return nil, fmt.Errorf("delay: gob encoding failed: %v", err)
|
||||
}
|
||||
|
||||
return &taskqueue.Task{
|
||||
Path: path,
|
||||
Payload: buf.Bytes(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
var taskqueueAdder = taskqueue.Add // for testing
|
||||
|
||||
func init() {
|
||||
http.HandleFunc(path, func(w http.ResponseWriter, req *http.Request) {
|
||||
runFunc(appengine.NewContext(req), w, req)
|
||||
})
|
||||
}
|
||||
|
||||
func runFunc(c context.Context, w http.ResponseWriter, req *http.Request) {
|
||||
defer req.Body.Close()
|
||||
|
||||
var inv invocation
|
||||
if err := gob.NewDecoder(req.Body).Decode(&inv); err != nil {
|
||||
log.Errorf(c, "delay: failed decoding task payload: %v", err)
|
||||
log.Warningf(c, "delay: dropping task")
|
||||
return
|
||||
}
|
||||
|
||||
f := funcs[inv.Key]
|
||||
if f == nil {
|
||||
log.Errorf(c, "delay: no func with key %q found", inv.Key)
|
||||
log.Warningf(c, "delay: dropping task")
|
||||
return
|
||||
}
|
||||
|
||||
ft := f.fv.Type()
|
||||
in := []reflect.Value{reflect.ValueOf(c)}
|
||||
for _, arg := range inv.Args {
|
||||
var v reflect.Value
|
||||
if arg != nil {
|
||||
v = reflect.ValueOf(arg)
|
||||
} else {
|
||||
// Task was passed a nil argument, so we must construct
|
||||
// the zero value for the argument here.
|
||||
n := len(in) // we're constructing the nth argument
|
||||
var at reflect.Type
|
||||
if !ft.IsVariadic() || n < ft.NumIn()-1 {
|
||||
at = ft.In(n)
|
||||
} else {
|
||||
at = ft.In(ft.NumIn() - 1).Elem()
|
||||
}
|
||||
v = reflect.Zero(at)
|
||||
}
|
||||
in = append(in, v)
|
||||
}
|
||||
out := f.fv.Call(in)
|
||||
|
||||
if n := ft.NumOut(); n > 0 && ft.Out(n-1) == errorType {
|
||||
if errv := out[n-1]; !errv.IsNil() {
|
||||
log.Errorf(c, "delay: func failed (will retry): %v", errv.Interface())
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,375 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package delay
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/protobuf/proto"
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine/internal"
|
||||
"google.golang.org/appengine/taskqueue"
|
||||
)
|
||||
|
||||
type CustomType struct {
|
||||
N int
|
||||
}
|
||||
|
||||
type CustomInterface interface {
|
||||
N() int
|
||||
}
|
||||
|
||||
type CustomImpl int
|
||||
|
||||
func (c CustomImpl) N() int { return int(c) }
|
||||
|
||||
// CustomImpl needs to be registered with gob.
|
||||
func init() {
|
||||
gob.Register(CustomImpl(0))
|
||||
}
|
||||
|
||||
var (
|
||||
invalidFunc = Func("invalid", func() {})
|
||||
|
||||
regFuncRuns = 0
|
||||
regFuncMsg = ""
|
||||
regFunc = Func("reg", func(c context.Context, arg string) {
|
||||
regFuncRuns++
|
||||
regFuncMsg = arg
|
||||
})
|
||||
|
||||
custFuncTally = 0
|
||||
custFunc = Func("cust", func(c context.Context, ct *CustomType, ci CustomInterface) {
|
||||
a, b := 2, 3
|
||||
if ct != nil {
|
||||
a = ct.N
|
||||
}
|
||||
if ci != nil {
|
||||
b = ci.N()
|
||||
}
|
||||
custFuncTally += a + b
|
||||
})
|
||||
|
||||
anotherCustFunc = Func("cust2", func(c context.Context, n int, ct *CustomType, ci CustomInterface) {
|
||||
})
|
||||
|
||||
varFuncMsg = ""
|
||||
varFunc = Func("variadic", func(c context.Context, format string, args ...int) {
|
||||
// convert []int to []interface{} for fmt.Sprintf.
|
||||
as := make([]interface{}, len(args))
|
||||
for i, a := range args {
|
||||
as[i] = a
|
||||
}
|
||||
varFuncMsg = fmt.Sprintf(format, as...)
|
||||
})
|
||||
|
||||
errFuncRuns = 0
|
||||
errFuncErr = errors.New("error!")
|
||||
errFunc = Func("err", func(c context.Context) error {
|
||||
errFuncRuns++
|
||||
if errFuncRuns == 1 {
|
||||
return nil
|
||||
}
|
||||
return errFuncErr
|
||||
})
|
||||
|
||||
dupeWhich = 0
|
||||
dupe1Func = Func("dupe", func(c context.Context) {
|
||||
if dupeWhich == 0 {
|
||||
dupeWhich = 1
|
||||
}
|
||||
})
|
||||
dupe2Func = Func("dupe", func(c context.Context) {
|
||||
if dupeWhich == 0 {
|
||||
dupeWhich = 2
|
||||
}
|
||||
})
|
||||
)
|
||||
|
||||
type fakeContext struct {
|
||||
ctx context.Context
|
||||
logging [][]interface{}
|
||||
}
|
||||
|
||||
func newFakeContext() *fakeContext {
|
||||
f := new(fakeContext)
|
||||
f.ctx = internal.WithCallOverride(context.Background(), f.call)
|
||||
f.ctx = internal.WithLogOverride(f.ctx, f.logf)
|
||||
return f
|
||||
}
|
||||
|
||||
func (f *fakeContext) call(ctx context.Context, service, method string, in, out proto.Message) error {
|
||||
panic("should never be called")
|
||||
}
|
||||
|
||||
var logLevels = map[int64]string{1: "INFO", 3: "ERROR"}
|
||||
|
||||
func (f *fakeContext) logf(level int64, format string, args ...interface{}) {
|
||||
f.logging = append(f.logging, append([]interface{}{logLevels[level], format}, args...))
|
||||
}
|
||||
|
||||
func TestInvalidFunction(t *testing.T) {
|
||||
c := newFakeContext()
|
||||
|
||||
if got, want := invalidFunc.Call(c.ctx), fmt.Errorf("delay: func is invalid: %s", errFirstArg); got.Error() != want.Error() {
|
||||
t.Errorf("Incorrect error: got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestVariadicFunctionArguments(t *testing.T) {
|
||||
// Check the argument type validation for variadic functions.
|
||||
|
||||
c := newFakeContext()
|
||||
|
||||
calls := 0
|
||||
taskqueueAdder = func(c context.Context, t *taskqueue.Task, _ string) (*taskqueue.Task, error) {
|
||||
calls++
|
||||
return t, nil
|
||||
}
|
||||
|
||||
varFunc.Call(c.ctx, "hi")
|
||||
varFunc.Call(c.ctx, "%d", 12)
|
||||
varFunc.Call(c.ctx, "%d %d %d", 3, 1, 4)
|
||||
if calls != 3 {
|
||||
t.Errorf("Got %d calls to taskqueueAdder, want 3", calls)
|
||||
}
|
||||
|
||||
if got, want := varFunc.Call(c.ctx, "%d %s", 12, "a string is bad"), errors.New("delay: argument 3 has wrong type: string is not assignable to int"); got.Error() != want.Error() {
|
||||
t.Errorf("Incorrect error: got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBadArguments(t *testing.T) {
|
||||
// Try running regFunc with different sets of inappropriate arguments.
|
||||
|
||||
c := newFakeContext()
|
||||
|
||||
tests := []struct {
|
||||
args []interface{} // all except context
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
args: nil,
|
||||
wantErr: "delay: too few arguments to func: 1 < 2",
|
||||
},
|
||||
{
|
||||
args: []interface{}{"lala", 53},
|
||||
wantErr: "delay: too many arguments to func: 3 > 2",
|
||||
},
|
||||
{
|
||||
args: []interface{}{53},
|
||||
wantErr: "delay: argument 1 has wrong type: int is not assignable to string",
|
||||
},
|
||||
}
|
||||
for i, tc := range tests {
|
||||
got := regFunc.Call(c.ctx, tc.args...)
|
||||
if got.Error() != tc.wantErr {
|
||||
t.Errorf("Call %v: got %q, want %q", i, got, tc.wantErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunningFunction(t *testing.T) {
|
||||
c := newFakeContext()
|
||||
|
||||
// Fake out the adding of a task.
|
||||
var task *taskqueue.Task
|
||||
taskqueueAdder = func(_ context.Context, tk *taskqueue.Task, queue string) (*taskqueue.Task, error) {
|
||||
if queue != "" {
|
||||
t.Errorf(`Got queue %q, expected ""`, queue)
|
||||
}
|
||||
task = tk
|
||||
return tk, nil
|
||||
}
|
||||
|
||||
regFuncRuns, regFuncMsg = 0, "" // reset state
|
||||
const msg = "Why, hello!"
|
||||
regFunc.Call(c.ctx, msg)
|
||||
|
||||
// Simulate the Task Queue service.
|
||||
req, err := http.NewRequest("POST", path, bytes.NewBuffer(task.Payload))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed making http.Request: %v", err)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
runFunc(c.ctx, rw, req)
|
||||
|
||||
if regFuncRuns != 1 {
|
||||
t.Errorf("regFuncRuns: got %d, want 1", regFuncRuns)
|
||||
}
|
||||
if regFuncMsg != msg {
|
||||
t.Errorf("regFuncMsg: got %q, want %q", regFuncMsg, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomType(t *testing.T) {
|
||||
c := newFakeContext()
|
||||
|
||||
// Fake out the adding of a task.
|
||||
var task *taskqueue.Task
|
||||
taskqueueAdder = func(_ context.Context, tk *taskqueue.Task, queue string) (*taskqueue.Task, error) {
|
||||
if queue != "" {
|
||||
t.Errorf(`Got queue %q, expected ""`, queue)
|
||||
}
|
||||
task = tk
|
||||
return tk, nil
|
||||
}
|
||||
|
||||
custFuncTally = 0 // reset state
|
||||
custFunc.Call(c.ctx, &CustomType{N: 11}, CustomImpl(13))
|
||||
|
||||
// Simulate the Task Queue service.
|
||||
req, err := http.NewRequest("POST", path, bytes.NewBuffer(task.Payload))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed making http.Request: %v", err)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
runFunc(c.ctx, rw, req)
|
||||
|
||||
if custFuncTally != 24 {
|
||||
t.Errorf("custFuncTally = %d, want 24", custFuncTally)
|
||||
}
|
||||
|
||||
// Try the same, but with nil values; one is a nil pointer (and thus a non-nil interface value),
|
||||
// and the other is a nil interface value.
|
||||
custFuncTally = 0 // reset state
|
||||
custFunc.Call(c.ctx, (*CustomType)(nil), nil)
|
||||
|
||||
// Simulate the Task Queue service.
|
||||
req, err = http.NewRequest("POST", path, bytes.NewBuffer(task.Payload))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed making http.Request: %v", err)
|
||||
}
|
||||
rw = httptest.NewRecorder()
|
||||
runFunc(c.ctx, rw, req)
|
||||
|
||||
if custFuncTally != 5 {
|
||||
t.Errorf("custFuncTally = %d, want 5", custFuncTally)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunningVariadic(t *testing.T) {
|
||||
c := newFakeContext()
|
||||
|
||||
// Fake out the adding of a task.
|
||||
var task *taskqueue.Task
|
||||
taskqueueAdder = func(_ context.Context, tk *taskqueue.Task, queue string) (*taskqueue.Task, error) {
|
||||
if queue != "" {
|
||||
t.Errorf(`Got queue %q, expected ""`, queue)
|
||||
}
|
||||
task = tk
|
||||
return tk, nil
|
||||
}
|
||||
|
||||
varFuncMsg = "" // reset state
|
||||
varFunc.Call(c.ctx, "Amiga %d has %d KB RAM", 500, 512)
|
||||
|
||||
// Simulate the Task Queue service.
|
||||
req, err := http.NewRequest("POST", path, bytes.NewBuffer(task.Payload))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed making http.Request: %v", err)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
runFunc(c.ctx, rw, req)
|
||||
|
||||
const expected = "Amiga 500 has 512 KB RAM"
|
||||
if varFuncMsg != expected {
|
||||
t.Errorf("varFuncMsg = %q, want %q", varFuncMsg, expected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorFunction(t *testing.T) {
|
||||
c := newFakeContext()
|
||||
|
||||
// Fake out the adding of a task.
|
||||
var task *taskqueue.Task
|
||||
taskqueueAdder = func(_ context.Context, tk *taskqueue.Task, queue string) (*taskqueue.Task, error) {
|
||||
if queue != "" {
|
||||
t.Errorf(`Got queue %q, expected ""`, queue)
|
||||
}
|
||||
task = tk
|
||||
return tk, nil
|
||||
}
|
||||
|
||||
errFunc.Call(c.ctx)
|
||||
|
||||
// Simulate the Task Queue service.
|
||||
// The first call should succeed; the second call should fail.
|
||||
{
|
||||
req, err := http.NewRequest("POST", path, bytes.NewBuffer(task.Payload))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed making http.Request: %v", err)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
runFunc(c.ctx, rw, req)
|
||||
}
|
||||
{
|
||||
req, err := http.NewRequest("POST", path, bytes.NewBuffer(task.Payload))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed making http.Request: %v", err)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
runFunc(c.ctx, rw, req)
|
||||
if rw.Code != http.StatusInternalServerError {
|
||||
t.Errorf("Got status code %d, want %d", rw.Code, http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
wantLogging := [][]interface{}{
|
||||
{"ERROR", "delay: func failed (will retry): %v", errFuncErr},
|
||||
}
|
||||
if !reflect.DeepEqual(c.logging, wantLogging) {
|
||||
t.Errorf("Incorrect logging: got %+v, want %+v", c.logging, wantLogging)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicateFunction(t *testing.T) {
|
||||
c := newFakeContext()
|
||||
|
||||
// Fake out the adding of a task.
|
||||
var task *taskqueue.Task
|
||||
taskqueueAdder = func(_ context.Context, tk *taskqueue.Task, queue string) (*taskqueue.Task, error) {
|
||||
if queue != "" {
|
||||
t.Errorf(`Got queue %q, expected ""`, queue)
|
||||
}
|
||||
task = tk
|
||||
return tk, nil
|
||||
}
|
||||
|
||||
if err := dupe1Func.Call(c.ctx); err == nil {
|
||||
t.Error("dupe1Func.Call did not return error")
|
||||
}
|
||||
if task != nil {
|
||||
t.Error("dupe1Func.Call posted a task")
|
||||
}
|
||||
if err := dupe2Func.Call(c.ctx); err != nil {
|
||||
t.Errorf("dupe2Func.Call error: %v", err)
|
||||
}
|
||||
if task == nil {
|
||||
t.Fatalf("dupe2Func.Call did not post a task")
|
||||
}
|
||||
|
||||
// Simulate the Task Queue service.
|
||||
req, err := http.NewRequest("POST", path, bytes.NewBuffer(task.Payload))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed making http.Request: %v", err)
|
||||
}
|
||||
rw := httptest.NewRecorder()
|
||||
runFunc(c.ctx, rw, req)
|
||||
|
||||
if dupeWhich == 1 {
|
||||
t.Error("dupe2Func.Call used old registered function")
|
||||
} else if dupeWhich != 2 {
|
||||
t.Errorf("dupeWhich = %d; want 2", dupeWhich)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
# Demo application for App Engine "flexible environment".
|
||||
runtime: go
|
||||
vm: true
|
||||
api_version: go1
|
||||
|
||||
handlers:
|
||||
# Favicon. Without this, the browser hits this once per page view.
|
||||
- url: /favicon.ico
|
||||
static_files: favicon.ico
|
||||
upload: favicon.ico
|
||||
|
||||
# Main app. All the real work is here.
|
||||
- url: /.*
|
||||
script: _go_app
|
Binary file not shown.
After Width: | Height: | Size: 1.1 KiB |
|
@ -0,0 +1,109 @@
|
|||
// Copyright 2011 Google Inc. All rights reserved.
|
||||
// Use of this source code is governed by the Apache 2.0
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// This example only works on App Engine "flexible environment".
|
||||
// +build !appengine
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"html/template"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/context"
|
||||
|
||||
"google.golang.org/appengine"
|
||||
"google.golang.org/appengine/datastore"
|
||||
"google.golang.org/appengine/log"
|
||||
"google.golang.org/appengine/user"
|
||||
)
|
||||
|
||||
var initTime time.Time
|
||||
|
||||
type Greeting struct {
|
||||
Author string
|
||||
Content string
|
||||
Date time.Time
|
||||
}
|
||||
|
||||
func main() {
|
||||
http.HandleFunc("/", handleMainPage)
|
||||
http.HandleFunc("/sign", handleSign)
|
||||
appengine.Main()
|
||||
}
|
||||
|
||||
// guestbookKey returns the key used for all guestbook entries.
|
||||
func guestbookKey(ctx context.Context) *datastore.Key {
|
||||
// The string "default_guestbook" here could be varied to have multiple guestbooks.
|
||||
return datastore.NewKey(ctx, "Guestbook", "default_guestbook", 0, nil)
|
||||
}
|
||||
|
||||
var tpl = template.Must(template.ParseGlob("templates/*.html"))
|
||||
|
||||
func handleMainPage(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
http.Error(w, "GET requests only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if r.URL.Path != "/" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := appengine.NewContext(r)
|
||||
tic := time.Now()
|
||||
q := datastore.NewQuery("Greeting").Ancestor(guestbookKey(ctx)).Order("-Date").Limit(10)
|
||||
var gg []*Greeting
|
||||
if _, err := q.GetAll(ctx, &gg); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
log.Errorf(ctx, "GetAll: %v", err)
|
||||
return
|
||||
}
|
||||
log.Infof(ctx, "Datastore lookup took %s", time.Since(tic).String())
|
||||
log.Infof(ctx, "Rendering %d greetings", len(gg))
|
||||
|
||||
var email, logout, login string
|
||||
if u := user.Current(ctx); u != nil {
|
||||
logout, _ = user.LogoutURL(ctx, "/")
|
||||
email = u.Email
|
||||
} else {
|
||||
login, _ = user.LoginURL(ctx, "/")
|
||||
}
|
||||
data := struct {
|
||||
Greetings []*Greeting
|
||||
Login, Logout, Email string
|
||||
}{
|
||||
Greetings: gg,
|
||||
Login: login,
|
||||
Logout: logout,
|
||||
Email: email,
|
||||
}
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := tpl.ExecuteTemplate(w, "guestbook.html", data); err != nil {
|
||||
log.Errorf(ctx, "%v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func handleSign(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "POST requests only", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
ctx := appengine.NewContext(r)
|
||||
g := &Greeting{
|
||||
Content: r.FormValue("content"),
|
||||
Date: time.Now(),
|
||||
}
|
||||
if u := user.Current(ctx); u != nil {
|
||||
g.Author = u.String()
|
||||
}
|
||||
key := datastore.NewIncompleteKey(ctx, "Greeting", guestbookKey(ctx))
|
||||
if _, err := datastore.Put(ctx, key, g); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
// Redirect with 303 which causes the subsequent request to use GET.
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
indexes:
|
||||
|
||||
- kind: Greeting
|
||||
ancestor: yes
|
||||
properties:
|
||||
- name: Date
|
||||
direction: desc
|
26
vendor/google.golang.org/appengine/demos/guestbook/templates/guestbook.html
generated
vendored
Normal file
26
vendor/google.golang.org/appengine/demos/guestbook/templates/guestbook.html
generated
vendored
Normal file
|
@ -0,0 +1,26 @@
|
|||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Guestbook Demo</title>
|
||||
</head>
|
||||
<body>
|
||||
<p>
|
||||
{{with .Email}}You are currently logged in as {{.}}.{{end}}
|
||||
{{with .Login}}<a href="{{.}}">Sign in</a>{{end}}
|
||||
{{with .Logout}}<a href="{{.}}">Sign out</a>{{end}}
|
||||
</p>
|
||||
|
||||
{{range .Greetings }}
|
||||
<p>
|
||||
{{with .Author}}<b>{{.}}</b>{{else}}An anonymous person{{end}}
|
||||
on <em>{{.Date.Format "3:04pm, Mon 2 Jan"}}</em>
|
||||
wrote <blockquote>{{.Content}}</blockquote>
|
||||
</p>
|
||||
{{end}}
|
||||
|
||||
<form action="/sign" method="post">
|
||||
<div><textarea name="content" rows="3" cols="60"></textarea></div>
|
||||
<div><input type="submit" value="Sign Guestbook"></div>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
|
@ -0,0 +1,10 @@
|
|||
runtime: go
|
||||
api_version: go1
|
||||
vm: true
|
||||
|
||||
handlers:
|
||||
- url: /favicon.ico
|
||||
static_files: favicon.ico
|
||||
upload: favicon.ico
|
||||
- url: /.*
|
||||
script: _go_app
|
Binary file not shown.
After Width: | Height: | Size: 1.1 KiB |
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue