Use sqlx to simplify things
This commit is contained in:
parent
3c9d37f7fb
commit
54c43696ac
|
@ -19,6 +19,12 @@
|
|||
revision = "96dc06278ce32a0e9d957d590bb987c81ee66407"
|
||||
version = "v1.3.0"
|
||||
|
||||
[[projects]]
|
||||
branch = "master"
|
||||
name = "github.com/jmoiron/sqlx"
|
||||
packages = [".","reflectx"]
|
||||
revision = "99f3ad6d85ae53d0fecf788ab62d0e9734b3c117"
|
||||
|
||||
[[projects]]
|
||||
name = "github.com/mattn/go-sqlite3"
|
||||
packages = ["."]
|
||||
|
@ -41,7 +47,7 @@
|
|||
branch = "master"
|
||||
name = "golang.org/x/net"
|
||||
packages = ["context"]
|
||||
revision = "a8b9294777976932365dabb6640cf1468d95c70f"
|
||||
revision = "dc871a5d77e227f5bbf6545176ef3eeebf87e76e"
|
||||
|
||||
[[projects]]
|
||||
branch = "v2"
|
||||
|
@ -52,6 +58,6 @@
|
|||
[solve-meta]
|
||||
analyzer-name = "dep"
|
||||
analyzer-version = 1
|
||||
inputs-digest = "da08912cf0f9aa88d93fa8fb8cf01a421f16b7977841c275f883c0c6821adcca"
|
||||
inputs-digest = "2597d02f0d1ff0af313642458ae19f0dabc6e5464adc94013e82fa3285a75c4e"
|
||||
solver-name = "gps-cdcl"
|
||||
solver-version = 1
|
||||
|
|
14
anonircd.go
14
anonircd.go
|
@ -28,7 +28,8 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/jessevdk/go-flags"
|
||||
irc "gopkg.in/sorcix/irc.v2"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/sorcix/irc.v2"
|
||||
)
|
||||
|
||||
var prefixAnonymous = irc.Prefix{"Anonymous", "Anon", "IRC"}
|
||||
|
@ -45,11 +46,8 @@ const letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
|||
const writebuffersize = 10
|
||||
|
||||
const (
|
||||
PERMISSION_USER = 0
|
||||
PERMISSION_SUPERADMIN = 1
|
||||
PERMISSION_ADMIN = 2
|
||||
PERMISSION_MODERATOR = 3
|
||||
PERMISSION_VIP = 4
|
||||
CHANNEL_LOBBY = "#"
|
||||
CHANNEL_SERVER = "&"
|
||||
)
|
||||
|
||||
var debugMode = false
|
||||
|
@ -67,7 +65,7 @@ func main() {
|
|||
|
||||
_, err := flags.Parse(&opts)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Panicf("%+v", errors.Wrap(err, "failed to parse flags"))
|
||||
}
|
||||
|
||||
if opts.Debug > 0 {
|
||||
|
@ -85,7 +83,7 @@ func main() {
|
|||
s := NewServer(opts.ConfigFile)
|
||||
err = s.loadConfig()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Panicf("%+v", errors.Wrap(err, "failed to load configuration file"))
|
||||
}
|
||||
s.connectDatabase()
|
||||
defer s.closeDatabase()
|
||||
|
|
68
channel.go
68
channel.go
|
@ -2,16 +2,19 @@ package main
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gopkg.in/sorcix/irc.v2"
|
||||
)
|
||||
|
||||
type Channel struct {
|
||||
Entity
|
||||
|
||||
clients *sync.Map
|
||||
logs []*ChannelLog
|
||||
logs map[int64]*ChannelLog
|
||||
|
||||
topic string
|
||||
topictime int64
|
||||
|
@ -23,6 +26,7 @@ type ChannelLog struct {
|
|||
Timestamp int64
|
||||
Client string
|
||||
IP string
|
||||
Account int
|
||||
Action string
|
||||
Message string
|
||||
}
|
||||
|
@ -30,11 +34,11 @@ type ChannelLog struct {
|
|||
const CHANNEL_LOGS_PER_PAGE = 25
|
||||
|
||||
func (cl *ChannelLog) Identifier(index int) string {
|
||||
return fmt.Sprintf("%03d%02d", index+1, cl.Timestamp%100)
|
||||
return fmt.Sprintf("%03d%02d", index, cl.Timestamp%100)
|
||||
}
|
||||
|
||||
func (cl *ChannelLog) Print(index int, channel string) string {
|
||||
return strings.TrimSpace(fmt.Sprintf("%s %s %5s %4s %s", time.Unix(0, cl.Timestamp).Format(time.Stamp), channel, cl.Identifier(index), cl.Action, cl.Message))
|
||||
return strings.TrimSpace(fmt.Sprintf("%s %s %s %4s %s", time.Unix(0, cl.Timestamp).Format(time.Stamp), channel, cl.Identifier(index), cl.Action, cl.Message))
|
||||
}
|
||||
|
||||
func NewChannel(identifier string) *Channel {
|
||||
|
@ -42,6 +46,7 @@ func NewChannel(identifier string) *Channel {
|
|||
c.Initialize(ENTITY_CHANNEL, identifier)
|
||||
|
||||
c.clients = new(sync.Map)
|
||||
c.logs = make(map[int64]*ChannelLog)
|
||||
|
||||
return c
|
||||
}
|
||||
|
@ -50,13 +55,14 @@ func (c *Channel) Log(client *Client, action string, message string) {
|
|||
c.Lock()
|
||||
defer c.Unlock()
|
||||
|
||||
// TODO: Log size limiting, max capacity will be 998 entries
|
||||
// TODO: Log size limiting, max capacity will be 999 entries
|
||||
// Log hash of IP address which is used later when connecting/joining
|
||||
|
||||
c.logs = append(c.logs, &ChannelLog{Timestamp: time.Now().UTC().UnixNano(), Client: client.identifier, IP: client.ip, Action: action, Message: message})
|
||||
nano := time.Now().UTC().UnixNano()
|
||||
c.logs[nano] = &ChannelLog{Timestamp: nano, Client: client.identifier, IP: client.iphash, Account: client.account, Action: action, Message: message}
|
||||
}
|
||||
|
||||
func (c *Channel) RevealLog(page int) []string {
|
||||
func (c *Channel) RevealLog(page int, full bool) []string {
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
|
@ -66,14 +72,30 @@ func (c *Channel) RevealLog(page int) []string {
|
|||
var ls []string
|
||||
logsRemain := false
|
||||
j := 0
|
||||
for i, l := range c.logs {
|
||||
|
||||
var nanos int64arr
|
||||
for n := range c.logs {
|
||||
nanos = append(nanos, n)
|
||||
}
|
||||
sort.Sort(nanos)
|
||||
|
||||
// To perform the opertion you want
|
||||
var l *ChannelLog
|
||||
var ok bool
|
||||
for i, nano := range nanos {
|
||||
if l, ok = c.logs[nano]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if page == -1 || i >= (CHANNEL_LOGS_PER_PAGE*(page-1)) {
|
||||
if page > -1 && j == CHANNEL_LOGS_PER_PAGE {
|
||||
logsRemain = true
|
||||
break
|
||||
if full || (l.Action != irc.JOIN && l.Action != irc.PART) {
|
||||
if page > -1 && j == CHANNEL_LOGS_PER_PAGE {
|
||||
logsRemain = true
|
||||
break
|
||||
}
|
||||
ls = append(ls, l.Print(i, c.identifier))
|
||||
j++
|
||||
}
|
||||
ls = append(ls, l.Print(i, c.identifier))
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -96,21 +118,31 @@ func (c *Channel) RevealLog(page int) []string {
|
|||
return ls
|
||||
}
|
||||
|
||||
func (c *Channel) RevealHash(identifier string) string {
|
||||
func (c *Channel) RevealInfo(identifier string) (string, int) {
|
||||
if len(identifier) != 5 {
|
||||
return ""
|
||||
return "", 0
|
||||
}
|
||||
|
||||
c.RLock()
|
||||
defer c.RUnlock()
|
||||
|
||||
for i, l := range c.logs {
|
||||
if l.Identifier(i) == identifier {
|
||||
return l.IP
|
||||
var nanos int64arr
|
||||
for n := range c.logs {
|
||||
nanos = append(nanos, n)
|
||||
}
|
||||
sort.Sort(nanos)
|
||||
|
||||
var l *ChannelLog
|
||||
var ok bool
|
||||
for i, nano := range nanos {
|
||||
if l, ok = c.logs[nano]; !ok {
|
||||
continue
|
||||
} else if l.Identifier(i) == identifier {
|
||||
return l.IP, l.Account
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
return "", 0
|
||||
}
|
||||
|
||||
func (c *Channel) HasClient(client string) bool {
|
||||
|
|
119
client.go
119
client.go
|
@ -1,16 +1,21 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"sync"
|
||||
|
||||
"strings"
|
||||
|
||||
"fmt"
|
||||
|
||||
irc "gopkg.in/sorcix/irc.v2"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
Entity
|
||||
ip string
|
||||
iphash string
|
||||
|
||||
ssl bool
|
||||
nick string
|
||||
|
@ -39,9 +44,7 @@ func NewClient(identifier string, conn net.Conn, ssl bool) *Client {
|
|||
return nil
|
||||
}
|
||||
|
||||
c.ip = generateHash(ip)
|
||||
// TODO: Check bans, return nil
|
||||
|
||||
c.iphash = generateHash(ip)
|
||||
c.ssl = ssl
|
||||
c.nick = "*"
|
||||
c.conn = conn
|
||||
|
@ -53,8 +56,21 @@ func NewClient(identifier string, conn net.Conn, ssl bool) *Client {
|
|||
return c
|
||||
}
|
||||
|
||||
func (c *Client) getAccount() (*DBAccount, error) {
|
||||
if c.account == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
acc, err := db.Account(c.account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &acc, nil
|
||||
}
|
||||
|
||||
func (c *Client) registered() bool {
|
||||
// TODO
|
||||
// TODO get account and check if it is valid
|
||||
return c.account > 0
|
||||
}
|
||||
|
||||
|
@ -90,6 +106,95 @@ func (c *Client) sendNotice(message string) {
|
|||
c.sendMessage("*** " + message)
|
||||
}
|
||||
|
||||
func (c *Client) accessDenied() {
|
||||
c.sendNotice("Access denied")
|
||||
func (c *Client) accessDenied(permissionRequired int) {
|
||||
ex := ""
|
||||
if permissionRequired > PERMISSION_CLIENT {
|
||||
ex = fmt.Sprintf(", that command is available to %ss only", strings.ToLower(permissionLabels[permissionRequired]))
|
||||
if permissionRequired == PERMISSION_REGISTERED {
|
||||
ex += " - Reply HELP for more info (see REGISTER and IDENTIFY)"
|
||||
}
|
||||
}
|
||||
|
||||
c.sendNotice("Access denied" + ex)
|
||||
}
|
||||
|
||||
func (c *Client) identify(username string, password string) bool {
|
||||
accountid, err := db.Auth(username, password)
|
||||
if err != nil {
|
||||
log.Panicf("%+v", err)
|
||||
}
|
||||
|
||||
account, err := db.Account(accountid)
|
||||
if err != nil {
|
||||
log.Panicf("%+v", err)
|
||||
} else if account.ID == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
c.account = accountid
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) getPermission(channel string) int {
|
||||
if c.account == 0 {
|
||||
return PERMISSION_CLIENT
|
||||
}
|
||||
|
||||
p, err := db.GetPermission(c.account, channel)
|
||||
if err != nil {
|
||||
log.Panicf("%+v", err)
|
||||
}
|
||||
|
||||
return p.Permission
|
||||
}
|
||||
|
||||
func (c *Client) globalPermission() int {
|
||||
return c.getPermission("&")
|
||||
}
|
||||
|
||||
func (c *Client) canUse(command string, channel string) bool {
|
||||
command = strings.ToUpper(command)
|
||||
req := c.permissionRequired(command)
|
||||
|
||||
globalPermission := c.globalPermission()
|
||||
if globalPermission >= req {
|
||||
return true
|
||||
} else if containsString(serverCommands, command) {
|
||||
return false
|
||||
}
|
||||
|
||||
return c.getPermission(channel) >= req
|
||||
}
|
||||
|
||||
func (c *Client) permissionRequired(command string) int {
|
||||
command = strings.ToUpper(command)
|
||||
for permissionRequired, commands := range commandRestrictions {
|
||||
for _, cmd := range commands {
|
||||
if cmd == command {
|
||||
return permissionRequired
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (c *Client) isBanned(channel string) (bool, string) {
|
||||
b, err := db.BanAddr(c.iphash, channel)
|
||||
if err != nil {
|
||||
log.Panicf("%+v", err)
|
||||
}
|
||||
|
||||
if b.Channel == "" && c.account > 0 {
|
||||
b, err = db.BanAccount(c.account, channel)
|
||||
if err != nil {
|
||||
log.Panicf("%+v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if b.Channel != "" {
|
||||
return true, b.Reason
|
||||
}
|
||||
|
||||
return false, ""
|
||||
}
|
||||
|
|
252
database.go
252
database.go
|
@ -1,21 +1,23 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/securecookie"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const DATABASE_VERSION = 1
|
||||
|
||||
var ErrAccountExists = errors.New("account exists")
|
||||
var ErrChannelExists = errors.New("channel exists")
|
||||
var ErrAccountExists = errors.New("account already exists")
|
||||
var ErrChannelExists = errors.New("channel already exists")
|
||||
var ErrChannelDoesNotExist = errors.New("channel does not exist")
|
||||
|
||||
var tables = map[string][]string{
|
||||
"meta": {
|
||||
|
@ -41,10 +43,15 @@ var tables = map[string][]string{
|
|||
"`expires` INTEGER NULL",
|
||||
"`reason` TEXT NULL"}}
|
||||
|
||||
const (
|
||||
BAN_TYPE_ADDRESS = 1
|
||||
BAN_TYPE_ACCOUNT = 2
|
||||
)
|
||||
|
||||
type DBAccount struct {
|
||||
ID int
|
||||
Username string
|
||||
Permission int
|
||||
ID int
|
||||
Username string
|
||||
Password string
|
||||
}
|
||||
|
||||
type DBChannel struct {
|
||||
|
@ -60,6 +67,12 @@ type DBPermission struct {
|
|||
Permission int
|
||||
}
|
||||
|
||||
type DBMode struct {
|
||||
Channel string
|
||||
Mode string
|
||||
Value string
|
||||
}
|
||||
|
||||
type DBBan struct {
|
||||
Channel string
|
||||
Type int
|
||||
|
@ -69,12 +82,12 @@ type DBBan struct {
|
|||
}
|
||||
|
||||
type Database struct {
|
||||
db *sql.DB
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
func (d *Database) Connect(driver string, dataSource string) error {
|
||||
var err error
|
||||
d.db, err = sql.Open(driver, dataSource)
|
||||
d.db, err = sqlx.Connect(driver, dataSource)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to connect to %s database", driver)
|
||||
}
|
||||
|
@ -110,7 +123,7 @@ func (d *Database) CreateTables() error {
|
|||
|
||||
func (d *Database) Migrate() error {
|
||||
rows, err := d.db.Query("SELECT `value` FROM meta WHERE `key`=? LIMIT 1", "version")
|
||||
if err != nil {
|
||||
if p(err) {
|
||||
return errors.Wrap(err, "failed to fetch database version")
|
||||
}
|
||||
|
||||
|
@ -129,7 +142,7 @@ func (d *Database) Migrate() error {
|
|||
}
|
||||
|
||||
if version == -1 {
|
||||
panic("Unable to migrate database: database version unknown")
|
||||
log.Panic("Unable to migrate database: database version unknown")
|
||||
} else if version == 0 {
|
||||
_, err := d.db.Exec("UPDATE meta SET `value`=? WHERE `key`=?", strconv.Itoa(DATABASE_VERSION), "version")
|
||||
if err != nil {
|
||||
|
@ -143,23 +156,26 @@ func (d *Database) Migrate() error {
|
|||
}
|
||||
|
||||
func (d *Database) Initialize() error {
|
||||
username := ""
|
||||
err := d.db.QueryRow("SELECT username FROM accounts").Scan(&username)
|
||||
if err == sql.ErrNoRows {
|
||||
err := d.AddAccount("admin", "password")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create first account")
|
||||
}
|
||||
|
||||
ac := &DBChannel{Channel: "&", Topic: "Secret Area of VIP Quality"}
|
||||
d.AddChannel(1, ac)
|
||||
|
||||
uc := &DBChannel{Channel: "#", Topic: "Welcome to AnonIRC"}
|
||||
d.AddChannel(1, uc)
|
||||
} else if err != nil {
|
||||
return errors.Wrap(err, "failed to check for first account")
|
||||
a, err := db.Account(1)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to initialize")
|
||||
}
|
||||
|
||||
if a.ID > 0 {
|
||||
return nil // Admin account exists
|
||||
}
|
||||
|
||||
err = d.AddAccount("admin", "password")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to create initial administrator account")
|
||||
}
|
||||
|
||||
ac := &DBChannel{Channel: CHANNEL_SERVER, Topic: "Secret Area of VIP Quality"}
|
||||
d.AddChannel(1, ac)
|
||||
|
||||
uc := &DBChannel{Channel: CHANNEL_LOBBY, Topic: "Welcome to AnonIRC"}
|
||||
d.AddChannel(1, uc)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -173,37 +189,21 @@ func (d *Database) Close() error {
|
|||
|
||||
// Accounts
|
||||
|
||||
func (d *Database) Account(id int) (*DBAccount, error) {
|
||||
rows, err := d.db.Query("SELECT id, username FROM accounts WHERE id=?", id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to fetch account")
|
||||
}
|
||||
|
||||
var a *DBAccount
|
||||
for rows.Next() {
|
||||
a = new(DBAccount)
|
||||
err = rows.Scan(&a.ID, &a.Username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to scan account")
|
||||
}
|
||||
func (d *Database) Account(id int) (DBAccount, error) {
|
||||
a := DBAccount{}
|
||||
err := d.db.Get(&a, "SELECT * FROM accounts WHERE id=? LIMIT 1", id)
|
||||
if p(err) {
|
||||
return a, errors.Wrap(err, "failed to fetch account")
|
||||
}
|
||||
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (d *Database) AccountU(username string) (*DBAccount, error) {
|
||||
rows, err := d.db.Query("SELECT id, username FROM accounts WHERE username=?", generateHash(username))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to fetch account by username")
|
||||
}
|
||||
|
||||
var a *DBAccount
|
||||
for rows.Next() {
|
||||
a = new(DBAccount)
|
||||
err = rows.Scan(&a.ID, &a.Username)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to scan account")
|
||||
}
|
||||
func (d *Database) AccountU(username string) (DBAccount, error) {
|
||||
a := DBAccount{}
|
||||
err := d.db.Get(&a, "SELECT * FROM accounts WHERE username=? LIMIT 1", generateHash(username))
|
||||
if p(err) {
|
||||
return a, errors.Wrap(err, "failed to fetch account by username")
|
||||
}
|
||||
|
||||
return a, nil
|
||||
|
@ -212,20 +212,13 @@ func (d *Database) AccountU(username string) (*DBAccount, error) {
|
|||
// TODO: Lockout on too many failed attempts
|
||||
func (d *Database) Auth(username string, password string) (int, error) {
|
||||
// TODO: Salt in config
|
||||
rows, err := d.db.Query("SELECT id FROM accounts WHERE username=? AND password=?", generateHash(username), generateHash(username+"-"+password))
|
||||
if err != nil {
|
||||
a := DBAccount{}
|
||||
err := d.db.Get(&a, "SELECT * FROM accounts WHERE username=? AND password=? LIMIT 1", generateHash(username), generateHash(username+"-"+password))
|
||||
if p(err) {
|
||||
return 0, errors.Wrap(err, "failed to authenticate account")
|
||||
}
|
||||
|
||||
accountid := 0
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&accountid)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to authenticate account")
|
||||
}
|
||||
}
|
||||
|
||||
return accountid, nil
|
||||
return a.ID, nil
|
||||
}
|
||||
|
||||
func (d *Database) GenerateToken() string {
|
||||
|
@ -236,7 +229,7 @@ func (d *Database) AddAccount(username string, password string) error {
|
|||
ex, err := d.AccountU(username)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to search for existing account while adding account")
|
||||
} else if ex != nil {
|
||||
} else if ex.ID > 0 {
|
||||
return ErrAccountExists
|
||||
}
|
||||
|
||||
|
@ -252,7 +245,7 @@ func (d *Database) SetUsername(accountid int, username string, password string)
|
|||
ex, err := d.AccountU(username)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to search for existing account while setting username")
|
||||
} else if ex != nil {
|
||||
} else if ex.ID > 0 {
|
||||
return ErrAccountExists
|
||||
}
|
||||
|
||||
|
@ -275,37 +268,21 @@ func (d *Database) SetPassword(accountid int, username string, password string)
|
|||
|
||||
// Channels
|
||||
|
||||
func (d *Database) ChannelID(id int) (*DBChannel, error) {
|
||||
rows, err := d.db.Query("SELECT channel, topic FROM channels WHERE id=?", id)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to fetch channel")
|
||||
}
|
||||
|
||||
var c *DBChannel
|
||||
for rows.Next() {
|
||||
c = new(DBChannel)
|
||||
err = rows.Scan(&c.Channel, &c.Topic)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to scan channel")
|
||||
}
|
||||
func (d *Database) ChannelID(id int) (DBChannel, error) {
|
||||
c := DBChannel{}
|
||||
err := d.db.Get(&c, "SELECT * FROM channels WHERE id=? LIMIT 1", id)
|
||||
if p(err) {
|
||||
return c, errors.Wrap(err, "failed to fetch channel")
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (d *Database) Channel(channel string) (*DBChannel, error) {
|
||||
rows, err := d.db.Query("SELECT channel, topic FROM channels WHERE channel=?", generateHash(channel))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to fetch channel by key")
|
||||
}
|
||||
|
||||
var c *DBChannel
|
||||
for rows.Next() {
|
||||
c = new(DBChannel)
|
||||
err = rows.Scan(&c.Channel, &c.Topic)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to scan channel")
|
||||
}
|
||||
func (d *Database) Channel(channel string) (DBChannel, error) {
|
||||
c := DBChannel{}
|
||||
err := d.db.Get(&c, "SELECT * FROM channels WHERE channel=? LIMIT 1", generateHash(channel))
|
||||
if p(err) {
|
||||
return c, errors.Wrap(err, "failed to fetch channel by key")
|
||||
}
|
||||
|
||||
return c, nil
|
||||
|
@ -315,7 +292,7 @@ func (d *Database) AddChannel(accountid int, channel *DBChannel) error {
|
|||
ex, err := d.Channel(channel.Channel)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to search for existing channel while adding channel")
|
||||
} else if ex != nil {
|
||||
} else if ex.Channel != "" {
|
||||
return ErrChannelExists
|
||||
}
|
||||
|
||||
|
@ -333,51 +310,51 @@ func (d *Database) AddChannel(accountid int, channel *DBChannel) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
func (d *Database) GetPermission(accountid int, channel string) (int, error) {
|
||||
rows, err := d.db.Query("SELECT permission FROM permissions WHERE account=? AND channel=?", accountid, generateHash(channel))
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to authenticate account")
|
||||
|
||||
// Permissions
|
||||
|
||||
func (d *Database) GetPermission(accountid int, channel string) (DBPermission, error) {
|
||||
dbp := DBPermission{}
|
||||
|
||||
// Return REGISTERED by default
|
||||
dbp.Permission = PERMISSION_REGISTERED
|
||||
|
||||
err := d.db.Get(&dbp, "SELECT * FROM permissions WHERE account=? AND channel=? LIMIT 1", accountid, generateHash(channel))
|
||||
if p(err) {
|
||||
return dbp, errors.Wrap(err, "failed to fetch permission")
|
||||
}
|
||||
|
||||
permission := PERMISSION_USER
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&permission)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "failed to authenticate account")
|
||||
}
|
||||
}
|
||||
|
||||
return permission, nil
|
||||
return dbp, nil
|
||||
}
|
||||
|
||||
func (d *Database) SetPermission(accountid int, channel string, permission int) error {
|
||||
acc, err := d.Account(accountid)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if acc == nil {
|
||||
log.Panicf("%+v", err)
|
||||
} else if acc.ID == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
ch, err := d.Channel(channel)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to fetch channel while setting permission")
|
||||
} else if ch == nil {
|
||||
} else if ch.Channel == "" {
|
||||
return nil
|
||||
}
|
||||
chh := generateHash(channel)
|
||||
|
||||
rows, err := d.db.Query("SELECT permission FROM permissions WHERE account=? AND channel=?", accountid, chh)
|
||||
dbp, err := d.GetPermission(accountid, chh)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to set permission")
|
||||
}
|
||||
|
||||
if !rows.Next() {
|
||||
_, err = d.db.Exec("INSERT INTO permissions (channel, account, permission) VALUES (?, ?, ?)", chh, accountid, permission)
|
||||
if dbp.Channel != "" {
|
||||
_, err = d.db.Exec("UPDATE permissions SET permission=? WHERE account=? AND channel=?", permission, accountid, chh)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to set permission")
|
||||
}
|
||||
} else {
|
||||
_, err = d.db.Exec("UPDATE permissions SET permission=? WHERE account=? AND channel=?", permission, accountid, chh)
|
||||
_, err = d.db.Exec("INSERT INTO permissions (channel, account, permission) VALUES (?, ?, ?)", chh, accountid, permission)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to set permission")
|
||||
}
|
||||
|
@ -385,3 +362,56 @@ func (d *Database) SetPermission(accountid int, channel string, permission int)
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Bans
|
||||
|
||||
func (d *Database) Ban(banid int) (DBBan, error) {
|
||||
b := DBBan{}
|
||||
err := d.db.Get(&b, "SELECT * FROM bans WHERE id=? LIMIT 1", banid)
|
||||
if p(err) {
|
||||
return b, errors.Wrap(err, "failed to fetch ban")
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (d *Database) BanAddr(addrhash string, channel string) (DBBan, error) {
|
||||
b := DBBan{}
|
||||
err := d.db.Get(&b, "SELECT * FROM bans WHERE channel=? AND `type`=? AND target=?", generateHash(channel), BAN_TYPE_ADDRESS, addrhash)
|
||||
if p(err) {
|
||||
return b, errors.Wrap(err, "failed to fetch ban")
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (d *Database) BanAccount(accountid int, channel string) (DBBan, error) {
|
||||
b := DBBan{}
|
||||
err := d.db.Get(&b, "SELECT * FROM bans WHERE channel=? AND `type`=? AND target=?", generateHash(channel), BAN_TYPE_ACCOUNT, accountid)
|
||||
if p(err) {
|
||||
return b, errors.Wrap(err, "failed to fetch ban")
|
||||
}
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (d *Database) AddBan(b DBBan) error {
|
||||
var err error
|
||||
|
||||
// Channel-specific (not server-wide)
|
||||
if b.Channel != "&" {
|
||||
ex, err := d.Channel(b.Channel)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to search for existing ban while adding ban")
|
||||
} else if ex.Channel == "" {
|
||||
return ErrChannelDoesNotExist
|
||||
}
|
||||
}
|
||||
|
||||
_, err = d.db.Exec("INSERT INTO bans (`channel`, `type`, `target`, `expires`, `reason`) VALUES (?, ?, ?, ?, ?, ?)", b.Channel, b.Type, b.Target, b.Expires, b.Reason)
|
||||
if p(err) {
|
||||
return errors.Wrap(err, "failed to add ban")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
417
server.go
417
server.go
|
@ -10,17 +10,16 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"sort"
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
"github.com/pkg/errors"
|
||||
"golang.org/x/crypto/sha3"
|
||||
irc "gopkg.in/sorcix/irc.v2"
|
||||
"gopkg.in/sorcix/irc.v2"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -30,11 +29,13 @@ const (
|
|||
// User commands
|
||||
COMMAND_REGISTER = "REGISTER"
|
||||
COMMAND_IDENTIFY = "IDENTIFY"
|
||||
COMMAND_TOKEN = "TOKEN"
|
||||
COMMAND_USERNAME = "USERNAME"
|
||||
COMMAND_PASSWORD = "PASSWORD"
|
||||
|
||||
// Channel/server commands
|
||||
COMMAND_FOUND = "FOUND"
|
||||
COMMAND_DROP = "DROP"
|
||||
COMMAND_GRANT = "GRANT"
|
||||
COMMAND_REVEAL = "REVEAL"
|
||||
COMMAND_KICK = "KICK"
|
||||
|
@ -47,6 +48,81 @@ const (
|
|||
COMMAND_UPGRADE = "UPGRADE"
|
||||
)
|
||||
|
||||
var serverCommands = []string{COMMAND_KILL, COMMAND_STATS, COMMAND_REHASH, COMMAND_UPGRADE}
|
||||
|
||||
const (
|
||||
PERMISSION_CLIENT = 0
|
||||
PERMISSION_REGISTERED = 1
|
||||
PERMISSION_VIP = 2
|
||||
PERMISSION_MODERATOR = 3
|
||||
PERMISSION_ADMIN = 4
|
||||
PERMISSION_SUPERADMIN = 5
|
||||
)
|
||||
|
||||
var permissionLabels = map[int]string{
|
||||
PERMISSION_CLIENT: "Client",
|
||||
PERMISSION_REGISTERED: "Registered",
|
||||
PERMISSION_VIP: "VIP",
|
||||
PERMISSION_MODERATOR: "Moderator",
|
||||
PERMISSION_ADMIN: "Administrator",
|
||||
PERMISSION_SUPERADMIN: "Super Administrator",
|
||||
}
|
||||
|
||||
var commandRestrictions = map[int][]string{
|
||||
PERMISSION_REGISTERED: {COMMAND_TOKEN, COMMAND_USERNAME, COMMAND_PASSWORD, COMMAND_FOUND},
|
||||
PERMISSION_MODERATOR: {COMMAND_REVEAL, COMMAND_KICK, COMMAND_BAN},
|
||||
PERMISSION_ADMIN: {COMMAND_GRANT},
|
||||
PERMISSION_SUPERADMIN: {COMMAND_DROP, COMMAND_KILL, COMMAND_STATS, COMMAND_REHASH, COMMAND_UPGRADE}}
|
||||
|
||||
var helpDuration = "Duration can be 0 to never expire, or e.g. 30m, 1h, 2d, 3w"
|
||||
var commandUsage = map[string][]string{
|
||||
COMMAND_HELP: {"[command]",
|
||||
"Print info regarding all commands or a specific command"},
|
||||
COMMAND_INFO: {"[channel]",
|
||||
"When a channel is specified, prints info including whether it is registered",
|
||||
"Without a channel, server info is printed"},
|
||||
COMMAND_REGISTER: {"<username> <password>",
|
||||
"Create an account, allowing you to found channels and moderate existing channels",
|
||||
"See IDENTIFY, FOUND, GRANT"},
|
||||
COMMAND_IDENTIFY: {"[username] <password>",
|
||||
"Identify to a previously registered account",
|
||||
"If username is omitted, it will be replaced with your current nick",
|
||||
"Note that you may automatically identify when connecting by specifying a server password of your username and password separated by a colon - Example: admin:hunter2"},
|
||||
COMMAND_TOKEN: {"<channel>",
|
||||
"Returns a token which can be used by channel administrators to grant special access to your account"},
|
||||
COMMAND_USERNAME: {"<username> <password> <new username> <confirm new username>",
|
||||
"Change your username"},
|
||||
COMMAND_PASSWORD: {"<username> <password> <new password> <confirm new password>",
|
||||
"Change your password"},
|
||||
COMMAND_FOUND: {"<channel>",
|
||||
"Register a channel"},
|
||||
COMMAND_GRANT: {"<channel> [account] [updated access]",
|
||||
"When an account token isn't specified, all permissions are listed",
|
||||
"View or update a user's access level by specifying their account token",
|
||||
"To remove an account, set their access level to User"},
|
||||
COMMAND_REVEAL: {"<channel> [page] [full]",
|
||||
"Print channel log, allowing KICK/BAN to be used",
|
||||
fmt.Sprintf("Results start at page 1, %d per page", CHANNEL_LOGS_PER_PAGE),
|
||||
"All log entries are returned when viewing page -1",
|
||||
"By default joins and parts are hidden, use 'full' to show them"},
|
||||
COMMAND_KICK: {"<channel> <5 digit log number> [reason]",
|
||||
"Kick a user from a channel"},
|
||||
COMMAND_BAN: {"<channel> <5 digit log number> <duration> [reason]",
|
||||
"Kick and ban a user from a channel",
|
||||
helpDuration},
|
||||
COMMAND_DROP: {"<channel> <confirm channel>",
|
||||
"Delete all channel data, allowing it to be FOUNDed again"},
|
||||
COMMAND_KILL: {"<channel> <5 digit log number> <duration> [reason]",
|
||||
"Disconnect and ban a user from the server",
|
||||
helpDuration},
|
||||
COMMAND_STATS: {"",
|
||||
"Print the current number of clients and channels"},
|
||||
COMMAND_REHASH: {"",
|
||||
"Reload the server configuration"},
|
||||
COMMAND_UPGRADE: {"",
|
||||
"Upgrade the server without disconnecting clients"},
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Salt string
|
||||
DBDriver string
|
||||
|
@ -59,7 +135,6 @@ type Server struct {
|
|||
config *Config
|
||||
configfile string
|
||||
created int64
|
||||
db *Database
|
||||
clients *sync.Map
|
||||
channels *sync.Map
|
||||
odyssey *os.File
|
||||
|
@ -71,12 +146,13 @@ type Server struct {
|
|||
*sync.RWMutex
|
||||
}
|
||||
|
||||
var db = &Database{}
|
||||
|
||||
func NewServer(configfile string) *Server {
|
||||
s := &Server{}
|
||||
s.config = &Config{}
|
||||
s.configfile = configfile
|
||||
s.created = time.Now().Unix()
|
||||
s.db = new(Database)
|
||||
s.clients = new(sync.Map)
|
||||
s.channels = new(sync.Map)
|
||||
s.odysseymutex = new(sync.RWMutex)
|
||||
|
@ -150,6 +226,17 @@ func (s *Server) getClient(client string) *Client {
|
|||
func (s *Server) getClients(channel string) map[string]*Client {
|
||||
clients := make(map[string]*Client)
|
||||
|
||||
if channel == "" {
|
||||
s.clients.Range(func(k, v interface{}) bool {
|
||||
cl := s.getClient(k.(string))
|
||||
if cl != nil {
|
||||
clients[cl.identifier] = cl
|
||||
}
|
||||
return true
|
||||
})
|
||||
return clients
|
||||
}
|
||||
|
||||
ch := s.getChannel(channel)
|
||||
if ch == nil {
|
||||
return clients
|
||||
|
@ -157,7 +244,7 @@ func (s *Server) getClients(channel string) map[string]*Client {
|
|||
|
||||
ch.clients.Range(func(k, v interface{}) bool {
|
||||
cl := s.getClient(k.(string))
|
||||
if channel == "" || cl != nil {
|
||||
if cl != nil {
|
||||
clients[cl.identifier] = cl
|
||||
}
|
||||
return true
|
||||
|
@ -177,23 +264,15 @@ func (s *Server) clientCount() int {
|
|||
}
|
||||
|
||||
func (s *Server) revealClient(channel string, identifier string) *Client {
|
||||
if len(identifier) != 5 {
|
||||
riphash, raccount := s.revealClientInfo(channel, identifier)
|
||||
if riphash == "" && raccount == 0 {
|
||||
log.Println("hash not found")
|
||||
return nil
|
||||
}
|
||||
|
||||
ch := s.getChannel(channel)
|
||||
if ch == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rip := ch.RevealHash(identifier)
|
||||
if rip == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
cls := s.getClients(ch.identifier)
|
||||
log.Println("have hash")
|
||||
cls := s.getClients("")
|
||||
for _, rcl := range cls {
|
||||
if rcl.ip == rip {
|
||||
if rcl.iphash == riphash || (rcl.account > 0 && rcl.account == raccount) {
|
||||
return rcl
|
||||
}
|
||||
}
|
||||
|
@ -201,6 +280,19 @@ func (s *Server) revealClient(channel string, identifier string) *Client {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) revealClientInfo(channel string, identifier string) (string, int) {
|
||||
if len(identifier) != 5 {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
ch := s.getChannel(channel)
|
||||
if ch == nil {
|
||||
return "", 0
|
||||
}
|
||||
|
||||
return ch.RevealInfo(identifier)
|
||||
}
|
||||
|
||||
func (s *Server) inChannel(channel string, client string) bool {
|
||||
ch := s.getChannel(channel)
|
||||
if ch != nil {
|
||||
|
@ -224,8 +316,8 @@ func (s *Server) joinChannel(channel string, client string) {
|
|||
if len(channel) == 0 {
|
||||
return
|
||||
} else if channel[0] == '&' {
|
||||
if s.globalPermission(cl) == PERMISSION_USER {
|
||||
cl.accessDenied()
|
||||
if cl.globalPermission() < PERMISSION_VIP {
|
||||
cl.accessDenied(0)
|
||||
return
|
||||
}
|
||||
} else if channel[0] != '#' {
|
||||
|
@ -241,6 +333,16 @@ func (s *Server) joinChannel(channel string, client string) {
|
|||
return
|
||||
}
|
||||
|
||||
banned, reason := cl.isBanned(channel)
|
||||
if banned {
|
||||
ex := ""
|
||||
if reason != "" {
|
||||
ex = ". Reason: " + reason
|
||||
}
|
||||
cl.sendNotice("Unable to join " + channel + ": You are banned" + ex)
|
||||
return
|
||||
}
|
||||
|
||||
ch.clients.Store(client, s.clientsInChannel(channel, client)+1)
|
||||
cl.write(&irc.Message{cl.getPrefix(), irc.JOIN, []string{channel}})
|
||||
ch.Log(cl, irc.JOIN, "")
|
||||
|
@ -272,7 +374,7 @@ func (s *Server) partAllChannels(client string, reason string) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *Server) revealChannelLog(channel string, client string, page int) {
|
||||
func (s *Server) revealChannelLog(channel string, client string, page int, full bool) {
|
||||
cl := s.getClient(client)
|
||||
if cl == nil {
|
||||
return
|
||||
|
@ -288,7 +390,7 @@ func (s *Server) revealChannelLog(channel string, client string, page int) {
|
|||
return
|
||||
}
|
||||
|
||||
r := ch.RevealLog(page)
|
||||
r := ch.RevealLog(page, full)
|
||||
for _, rev := range r {
|
||||
cl.sendMessage(rev)
|
||||
}
|
||||
|
@ -443,11 +545,11 @@ func (s *Server) handleTopic(channel string, client string, topic string) {
|
|||
return
|
||||
}
|
||||
|
||||
chp, err := s.db.GetPermission(cl.account, channel)
|
||||
chp, err := db.GetPermission(cl.account, channel)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if ch.hasMode("t") && (cl.account == 0 || chp == PERMISSION_USER) {
|
||||
cl.accessDenied()
|
||||
log.Panicf("%+v", err)
|
||||
} else if ch.hasMode("t") && (chp.Permission < PERMISSION_VIP) {
|
||||
cl.accessDenied(PERMISSION_VIP)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -485,9 +587,9 @@ func (s *Server) handleMode(c *Client, params []string) {
|
|||
// Send channel creation time
|
||||
c.writeMessage(strings.Join([]string{"329", c.nick, params[0], fmt.Sprintf("%d", int32(ch.created))}, " "), []string{})
|
||||
} else if len(params) > 1 && len(params[1]) > 0 && (params[1][0] == '+' || params[1][0] == '-') {
|
||||
if !s.hasPermission(c, irc.MODE, params[0]) {
|
||||
if !c.canUse(irc.MODE, params[0]) {
|
||||
// TODO: Send proper mode denied message
|
||||
c.accessDenied()
|
||||
c.accessDenied(c.permissionRequired(irc.MODE))
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -582,55 +684,11 @@ func (s *Server) handleMode(c *Client, params []string) {
|
|||
}
|
||||
|
||||
func (s *Server) buildUsage(cl *Client, command string) map[string][]string {
|
||||
helpDuration := "Duration can be 0 to never expire, or e.g. 30m, 1h, 2d, 3w"
|
||||
|
||||
var commandUsage = map[string][]string{
|
||||
COMMAND_HELP: {"[command]",
|
||||
"Print info regarding all commands or a specific command"},
|
||||
COMMAND_INFO: {"[channel]",
|
||||
"Print info such as whether a channel is registered",
|
||||
"If channel is omitted, server info is printed instead"},
|
||||
COMMAND_REGISTER: {"<username> <password>",
|
||||
"Create an account, allowing you to found channels and moderate existing channels (with permission)",
|
||||
"See IDENTIFY"},
|
||||
COMMAND_IDENTIFY: {"[username] <password>",
|
||||
"Identify to a previously registered account",
|
||||
"If username is omitted, it will be replaced with your current nick",
|
||||
"Note that you may identify when connecting by sending the following server password:",
|
||||
"Your username and password: <username>:<password>",
|
||||
"Or your username and token: <username>:<token>",
|
||||
"E.g. admin:hunter2"},
|
||||
COMMAND_USERNAME: {"<username> <password> <new username> <confirm new username>",
|
||||
"Change your username"},
|
||||
COMMAND_PASSWORD: {"<username> <password> <new password> <confirm new password>",
|
||||
"Change your password"},
|
||||
COMMAND_FOUND: {"<channel>",
|
||||
"Register a channel"},
|
||||
COMMAND_REVEAL: {"<channel> [page]",
|
||||
"Print channel log, allowing KICK/BAN to be used",
|
||||
fmt.Sprintf("Results start at page 1, %d per page", CHANNEL_LOGS_PER_PAGE),
|
||||
"All log entries are returned when viewing page -1"},
|
||||
COMMAND_KICK: {"<channel> <5 digit log number> [reason]",
|
||||
"Kick a user from a channel"},
|
||||
COMMAND_BAN: {"<channel> <5 digit log number> <duration> [reason]",
|
||||
"Kick and ban a user from a channel",
|
||||
helpDuration},
|
||||
COMMAND_KILL: {"<channel> <5 digit log number> <duration> [reason]",
|
||||
"Disconnect and ban a user from the server",
|
||||
helpDuration},
|
||||
COMMAND_STATS: {"",
|
||||
"Print the current number of clients and channels"},
|
||||
COMMAND_REHASH: {"",
|
||||
"Reload the server configuration"},
|
||||
COMMAND_UPGRADE: {"",
|
||||
"Upgrade the server without disconnecting clients"},
|
||||
}
|
||||
|
||||
u := map[string][]string{}
|
||||
command = strings.ToUpper(command)
|
||||
for cmd, usage := range commandUsage {
|
||||
if command == COMMAND_HELP || cmd == command {
|
||||
if s.hasPermission(cl, cmd, "") {
|
||||
if cl.canUse(cmd, "") {
|
||||
u[cmd] = usage
|
||||
}
|
||||
}
|
||||
|
@ -668,6 +726,15 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
|
||||
var err error
|
||||
command = strings.ToUpper(command)
|
||||
ch := ""
|
||||
if len(params) > 0 {
|
||||
ch = params[0]
|
||||
}
|
||||
if !cl.canUse(command, ch) {
|
||||
cl.accessDenied(cl.permissionRequired(command))
|
||||
return
|
||||
}
|
||||
|
||||
switch command {
|
||||
case COMMAND_HELP:
|
||||
cmd := command
|
||||
|
@ -677,7 +744,8 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
s.sendUsage(cl, cmd)
|
||||
return
|
||||
case COMMAND_INFO:
|
||||
cl.sendMessage("AnonIRCd https://github.com/sageru-6ch/anonircd")
|
||||
// TODO: when channel is supplied, send whether it is registered and show a notice that it is dropping soon if no super admins have logged in in X days
|
||||
cl.sendMessage("Server info: AnonIRCd https://github.com/sageru-6ch/anonircd")
|
||||
return
|
||||
case COMMAND_REGISTER:
|
||||
if len(params) == 0 {
|
||||
|
@ -699,12 +767,24 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
password = params[1]
|
||||
}
|
||||
|
||||
authSuccess := s.identify(cl, username, password)
|
||||
authSuccess := cl.identify(username, password)
|
||||
if authSuccess {
|
||||
cl.sendNotice("Identified successfully")
|
||||
|
||||
if s.globalPermission(cl) != PERMISSION_USER {
|
||||
s.joinChannel("&", cl.identifier)
|
||||
if cl.globalPermission() >= PERMISSION_VIP {
|
||||
s.joinChannel(CHANNEL_SERVER, cl.identifier)
|
||||
}
|
||||
|
||||
for clch := range s.getChannels(cl.identifier) {
|
||||
banned, br := cl.isBanned(clch)
|
||||
if banned {
|
||||
reason := "Banned"
|
||||
if br != "" {
|
||||
reason += ": " + br
|
||||
}
|
||||
s.partChannel(clch, cl.identifier, reason)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cl.sendNotice("Failed to identify, incorrect username/password")
|
||||
|
@ -714,7 +794,7 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
cl.sendError("You must identify before using that command")
|
||||
}
|
||||
|
||||
if len(params) == 0 || len(params) > 4 {
|
||||
if len(params) == 0 || len(params) < 4 {
|
||||
s.sendUsage(cl, command)
|
||||
return
|
||||
}
|
||||
|
@ -725,9 +805,9 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
}
|
||||
// TODO: Alphanumeric username
|
||||
|
||||
accid, err := s.db.Auth(params[0], params[1])
|
||||
accid, err := db.Auth(params[0], params[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Panicf("%+v", err)
|
||||
}
|
||||
|
||||
if accid == 0 {
|
||||
|
@ -735,17 +815,13 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
return
|
||||
}
|
||||
|
||||
err = s.db.SetUsername(accid, params[2], params[1])
|
||||
err = db.SetUsername(accid, params[2], params[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Panicf("%+v", err)
|
||||
}
|
||||
cl.sendMessage("Username changed successfully")
|
||||
case COMMAND_PASSWORD:
|
||||
if cl.account == 0 {
|
||||
cl.sendError("You must identify before using that command")
|
||||
}
|
||||
|
||||
if len(params) == 0 || len(params) > 4 {
|
||||
if len(params) == 0 || len(params) < 4 {
|
||||
s.sendUsage(cl, command)
|
||||
return
|
||||
}
|
||||
|
@ -755,9 +831,9 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
return
|
||||
}
|
||||
|
||||
accid, err := s.db.Auth(params[0], params[1])
|
||||
accid, err := db.Auth(params[0], params[1])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Panicf("%+v", err)
|
||||
}
|
||||
|
||||
if accid == 0 {
|
||||
|
@ -765,9 +841,9 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
return
|
||||
}
|
||||
|
||||
err = s.db.SetPassword(accid, params[0], params[2])
|
||||
err = db.SetPassword(accid, params[0], params[2])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Panicf("%+v", err)
|
||||
}
|
||||
cl.sendMessage("Password changed successfully")
|
||||
case COMMAND_REVEAL:
|
||||
|
@ -783,11 +859,6 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
return
|
||||
}
|
||||
|
||||
if !s.hasPermission(cl, COMMAND_REVEAL, params[0]) {
|
||||
cl.accessDenied()
|
||||
return
|
||||
}
|
||||
|
||||
page := 1
|
||||
if len(params) > 1 {
|
||||
page, err = strconv.Atoi(params[1])
|
||||
|
@ -797,7 +868,14 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
}
|
||||
}
|
||||
|
||||
s.revealChannelLog(params[0], cl.identifier, page)
|
||||
full := false
|
||||
if len(params) > 2 {
|
||||
if strings.ToLower(params[2]) == "full" {
|
||||
full = true
|
||||
}
|
||||
}
|
||||
|
||||
s.revealChannelLog(params[0], cl.identifier, page, full)
|
||||
case COMMAND_KICK:
|
||||
if len(params) < 2 {
|
||||
s.sendUsage(cl, command)
|
||||
|
@ -810,11 +888,6 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
return
|
||||
}
|
||||
|
||||
if !s.hasPermission(cl, COMMAND_KICK, params[0]) {
|
||||
cl.accessDenied()
|
||||
return
|
||||
}
|
||||
|
||||
rcl := s.revealClient(params[0], params[1])
|
||||
if rcl == nil {
|
||||
cl.sendError("Unable to kick, client not found or no longer connected")
|
||||
|
@ -839,30 +912,22 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
return
|
||||
}
|
||||
|
||||
if !s.hasPermission(cl, COMMAND_BAN, params[0]) {
|
||||
cl.accessDenied()
|
||||
return
|
||||
}
|
||||
|
||||
rcl := s.revealClient(params[0], params[1])
|
||||
if rcl == nil {
|
||||
cl.sendError("Unable to ban, client not found or no longer connected")
|
||||
return
|
||||
}
|
||||
|
||||
reason := "Banned"
|
||||
reason := strings.Join(params[3:], " ")
|
||||
|
||||
partmsg := "Banned"
|
||||
if len(params) > 3 {
|
||||
reason = fmt.Sprintf("%s: %s", reason, strings.Join(params[3:], " "))
|
||||
partmsg = fmt.Sprintf("%s: %s", partmsg, reason)
|
||||
}
|
||||
// TODO: Apply ban in DB
|
||||
s.partChannel(ch.identifier, rcl.identifier, reason)
|
||||
s.partChannel(ch.identifier, rcl.identifier, partmsg)
|
||||
cl.sendMessage(fmt.Sprintf("Banned %s %s", params[0], params[1]))
|
||||
case COMMAND_KILL:
|
||||
if s.globalPermission(cl) != PERMISSION_SUPERADMIN {
|
||||
cl.accessDenied()
|
||||
return
|
||||
}
|
||||
|
||||
if len(params) < 3 {
|
||||
s.sendUsage(cl, command)
|
||||
return
|
||||
|
@ -882,17 +947,9 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
s.killClient(rcl)
|
||||
cl.sendMessage(fmt.Sprintf("Killed %s %s", params[0], params[1]))
|
||||
case COMMAND_STATS:
|
||||
if s.globalPermission(cl) != PERMISSION_SUPERADMIN {
|
||||
cl.accessDenied()
|
||||
return
|
||||
}
|
||||
|
||||
cl.sendMessage(fmt.Sprintf("%d clients in %d channels", s.clientCount(), s.channelCount()))
|
||||
case COMMAND_REHASH:
|
||||
if s.globalPermission(cl) != PERMISSION_SUPERADMIN {
|
||||
cl.accessDenied()
|
||||
return
|
||||
}
|
||||
|
||||
err := s.reload()
|
||||
if err != nil {
|
||||
|
@ -901,11 +958,6 @@ func (s *Server) handleUserCommand(client string, command string, params []strin
|
|||
cl.sendMessage("Reloaded configuration")
|
||||
}
|
||||
case COMMAND_UPGRADE:
|
||||
if s.globalPermission(cl) != PERMISSION_SUPERADMIN {
|
||||
cl.accessDenied()
|
||||
return
|
||||
}
|
||||
|
||||
// TODO
|
||||
}
|
||||
}
|
||||
|
@ -952,76 +1004,6 @@ func (s *Server) handlePrivmsg(channel string, client string, message string) {
|
|||
ch.Log(cl, "CHAT", message)
|
||||
}
|
||||
|
||||
func (s *Server) identify(c *Client, username string, password string) bool {
|
||||
accountid, err := s.db.Auth(username, password)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var account *DBAccount
|
||||
if accountid > 0 {
|
||||
account, err = s.db.Account(accountid)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
c.account = accountid
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) globalPermission(c *Client) int {
|
||||
globalPermission := PERMISSION_USER
|
||||
if c.account > 0 {
|
||||
gp, err := s.db.GetPermission(c.account, "&")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
globalPermission = gp
|
||||
}
|
||||
|
||||
return globalPermission
|
||||
}
|
||||
|
||||
func (s *Server) hasPermission(c *Client, command string, channel string) bool {
|
||||
globalPermission := s.globalPermission(c)
|
||||
if globalPermission == PERMISSION_SUPERADMIN {
|
||||
return true
|
||||
}
|
||||
|
||||
chp := PERMISSION_USER
|
||||
var err error
|
||||
if channel != "" {
|
||||
chp, err = s.db.GetPermission(c.account, channel)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
command = strings.ToUpper(command)
|
||||
if command == COMMAND_HELP || command == COMMAND_INFO || command == COMMAND_REGISTER || command == COMMAND_IDENTIFY {
|
||||
return true
|
||||
} else if command == COMMAND_USERNAME || command == COMMAND_PASSWORD || command == COMMAND_FOUND || command == COMMAND_GRANT || command == COMMAND_KICK || command == COMMAND_BAN || command == irc.MODE {
|
||||
if !c.registered() {
|
||||
return false
|
||||
} else if channel == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
if command == COMMAND_GRANT {
|
||||
return chp == PERMISSION_SUPERADMIN
|
||||
} else if command == COMMAND_KICK || command == COMMAND_BAN || command == irc.MODE {
|
||||
return chp == PERMISSION_SUPERADMIN || chp == PERMISSION_ADMIN || chp == PERMISSION_MODERATOR
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Server) handleRead(c *Client) {
|
||||
for {
|
||||
if c.state == ENTITY_STATE_TERMINATING {
|
||||
|
@ -1069,16 +1051,16 @@ func (s *Server) handleRead(c *Client) {
|
|||
c.writeMessage(motdcode, []string{" " + motdmsg})
|
||||
}
|
||||
|
||||
s.joinChannel("#", c.identifier)
|
||||
if s.globalPermission(c) != PERMISSION_USER {
|
||||
s.joinChannel("&", c.identifier)
|
||||
s.joinChannel(CHANNEL_LOBBY, c.identifier)
|
||||
if c.globalPermission() >= PERMISSION_VIP {
|
||||
s.joinChannel(CHANNEL_SERVER, c.identifier)
|
||||
}
|
||||
} else if msg.Command == irc.PASS && c.user == "" && len(msg.Params) > 0 && len(msg.Params[0]) > 0 {
|
||||
// TODO: Add auth and multiple failed attempts ban
|
||||
authSuccess := false
|
||||
psplit := strings.SplitN(msg.Params[0], ":", 2)
|
||||
if len(psplit) == 2 {
|
||||
authSuccess = s.identify(c, psplit[0], psplit[1])
|
||||
authSuccess = c.identify(psplit[0], psplit[1])
|
||||
}
|
||||
|
||||
if !authSuccess {
|
||||
|
@ -1139,10 +1121,15 @@ func (s *Server) handleRead(c *Client) {
|
|||
key := k.(string)
|
||||
ch := v.(*Channel)
|
||||
|
||||
if ch != nil && !ch.hasMode("p") && !ch.hasMode("s") {
|
||||
chans[key] = s.clientsInChannel(key, c.identifier)
|
||||
if key[0] == '&' && c.globalPermission() < PERMISSION_VIP {
|
||||
return true
|
||||
}
|
||||
|
||||
if ch == nil || ch.hasMode("p") || ch.hasMode("s") {
|
||||
return true
|
||||
}
|
||||
|
||||
chans[key] = s.clientsInChannel(key, c.identifier)
|
||||
return true
|
||||
})
|
||||
|
||||
|
@ -1249,7 +1236,14 @@ func (s *Server) handleConnection(conn net.Conn, ssl bool) {
|
|||
}
|
||||
|
||||
client := NewClient(identifier, conn, ssl)
|
||||
if client == nil {
|
||||
banned := true
|
||||
reason := ""
|
||||
if client != nil {
|
||||
banned, reason = client.isBanned("")
|
||||
}
|
||||
if banned {
|
||||
// TODO: Send banned message
|
||||
_ = reason
|
||||
return // Banned
|
||||
}
|
||||
s.clients.Store(client.identifier, client)
|
||||
|
@ -1258,7 +1252,6 @@ func (s *Server) handleConnection(conn net.Conn, ssl bool) {
|
|||
s.handleRead(client) // Block until the connection is closed
|
||||
|
||||
s.killClient(client)
|
||||
client.conn.Close()
|
||||
}
|
||||
|
||||
func (s *Server) killClient(c *Client) {
|
||||
|
@ -1360,16 +1353,16 @@ func (s *Server) pingClients() {
|
|||
}
|
||||
|
||||
func (s *Server) connectDatabase() {
|
||||
err := s.db.Connect(s.config.DBDriver, s.config.DBSource)
|
||||
err := db.Connect(s.config.DBDriver, s.config.DBSource)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Panicf("%+v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) closeDatabase() {
|
||||
err := s.db.Close()
|
||||
err := db.Close()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Panicf("%+v", err)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
27
utilities.go
27
utilities.go
|
@ -1,15 +1,14 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"crypto/md5"
|
||||
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/sha3"
|
||||
)
|
||||
|
||||
|
@ -20,6 +19,12 @@ type Pair struct {
|
|||
|
||||
type PairList []Pair
|
||||
|
||||
type int64arr []int64
|
||||
|
||||
func (a int64arr) Len() int { return len(a) }
|
||||
func (a int64arr) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a int64arr) Less(i, j int) bool { return a[i] < a[j] }
|
||||
|
||||
func (p PairList) Len() int {
|
||||
return len(p)
|
||||
}
|
||||
|
@ -62,3 +67,17 @@ func generateHash(s string) string {
|
|||
|
||||
return base64.URLEncoding.EncodeToString(sha512.Sum(nil))
|
||||
}
|
||||
|
||||
func containsString(s []string, e string) bool {
|
||||
for _, a := range s {
|
||||
if a == e {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Problem
|
||||
func p(err error) bool {
|
||||
return err != nil && err != sql.ErrNoRows
|
||||
}
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
# Compiled Object files, Static and Dynamic libs (Shared Objects)
|
||||
*.o
|
||||
*.a
|
||||
*.so
|
||||
|
||||
# Folders
|
||||
_obj
|
||||
_test
|
||||
|
||||
# Architecture specific extensions/prefixes
|
||||
*.[568vq]
|
||||
[568vq].out
|
||||
|
||||
*.cgo1.go
|
||||
*.cgo2.c
|
||||
_cgo_defun.c
|
||||
_cgo_gotypes.go
|
||||
_cgo_export.*
|
||||
|
||||
_testmain.go
|
||||
|
||||
*.exe
|
||||
tags
|
||||
environ
|
|
@ -0,0 +1,23 @@
|
|||
Copyright (c) 2013, Jason Moiron
|
||||
|
||||
Permission is hereby granted, free of charge, to any person
|
||||
obtaining a copy of this software and associated documentation
|
||||
files (the "Software"), to deal in the Software without
|
||||
restriction, including without limitation the rights to use,
|
||||
copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following
|
||||
conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
|
||||
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
||||
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
||||
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
|
@ -0,0 +1,185 @@
|
|||
# sqlx
|
||||
|
||||
[![Build Status](https://drone.io/github.com/jmoiron/sqlx/status.png)](https://drone.io/github.com/jmoiron/sqlx/latest) [![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/jmoiron/sqlx) [![license](http://img.shields.io/badge/license-MIT-red.svg?style=flat)](https://raw.githubusercontent.com/jmoiron/sqlx/master/LICENSE)
|
||||
|
||||
sqlx is a library which provides a set of extensions on go's standard
|
||||
`database/sql` library. The sqlx versions of `sql.DB`, `sql.TX`, `sql.Stmt`,
|
||||
et al. all leave the underlying interfaces untouched, so that their interfaces
|
||||
are a superset on the standard ones. This makes it relatively painless to
|
||||
integrate existing codebases using database/sql with sqlx.
|
||||
|
||||
Major additional concepts are:
|
||||
|
||||
* Marshal rows into structs (with embedded struct support), maps, and slices
|
||||
* Named parameter support including prepared statements
|
||||
* `Get` and `Select` to go quickly from query to struct/slice
|
||||
|
||||
In addition to the [godoc API documentation](http://godoc.org/github.com/jmoiron/sqlx),
|
||||
there is also some [standard documentation](http://jmoiron.github.io/sqlx/) that
|
||||
explains how to use `database/sql` along with sqlx.
|
||||
|
||||
## Recent Changes
|
||||
|
||||
* sqlx/types.JsonText has been renamed to JSONText to follow Go naming conventions.
|
||||
|
||||
This breaks backwards compatibility, but it's in a way that is trivially fixable
|
||||
(`s/JsonText/JSONText/g`). The `types` package is both experimental and not in
|
||||
active development currently.
|
||||
|
||||
* Using Go 1.6 and below with `types.JSONText` and `types.GzippedText` can be _potentially unsafe_, **especially** when used with common auto-scan sqlx idioms like `Select` and `Get`. See [golang bug #13905](https://github.com/golang/go/issues/13905).
|
||||
|
||||
### Backwards Compatibility
|
||||
|
||||
There is no Go1-like promise of absolute stability, but I take the issue seriously
|
||||
and will maintain the library in a compatible state unless vital bugs prevent me
|
||||
from doing so. Since [#59](https://github.com/jmoiron/sqlx/issues/59) and
|
||||
[#60](https://github.com/jmoiron/sqlx/issues/60) necessitated breaking behavior,
|
||||
a wider API cleanup was done at the time of fixing. It's possible this will happen
|
||||
in future; if it does, a git tag will be provided for users requiring the old
|
||||
behavior to continue to use it until such a time as they can migrate.
|
||||
|
||||
## install
|
||||
|
||||
go get github.com/jmoiron/sqlx
|
||||
|
||||
## issues
|
||||
|
||||
Row headers can be ambiguous (`SELECT 1 AS a, 2 AS a`), and the result of
|
||||
`Columns()` does not fully qualify column names in queries like:
|
||||
|
||||
```sql
|
||||
SELECT a.id, a.name, b.id, b.name FROM foos AS a JOIN foos AS b ON a.parent = b.id;
|
||||
```
|
||||
|
||||
making a struct or map destination ambiguous. Use `AS` in your queries
|
||||
to give columns distinct names, `rows.Scan` to scan them manually, or
|
||||
`SliceScan` to get a slice of results.
|
||||
|
||||
## usage
|
||||
|
||||
Below is an example which shows some common use cases for sqlx. Check
|
||||
[sqlx_test.go](https://github.com/jmoiron/sqlx/blob/master/sqlx_test.go) for more
|
||||
usage.
|
||||
|
||||
|
||||
```go
|
||||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
var schema = `
|
||||
CREATE TABLE person (
|
||||
first_name text,
|
||||
last_name text,
|
||||
email text
|
||||
);
|
||||
|
||||
CREATE TABLE place (
|
||||
country text,
|
||||
city text NULL,
|
||||
telcode integer
|
||||
)`
|
||||
|
||||
type Person struct {
|
||||
FirstName string `db:"first_name"`
|
||||
LastName string `db:"last_name"`
|
||||
Email string
|
||||
}
|
||||
|
||||
type Place struct {
|
||||
Country string
|
||||
City sql.NullString
|
||||
TelCode int
|
||||
}
|
||||
|
||||
func main() {
|
||||
// this Pings the database trying to connect, panics on error
|
||||
// use sqlx.Open() for sql.Open() semantics
|
||||
db, err := sqlx.Connect("postgres", "user=foo dbname=bar sslmode=disable")
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
|
||||
// exec the schema or fail; multi-statement Exec behavior varies between
|
||||
// database drivers; pq will exec them all, sqlite3 won't, ymmv
|
||||
db.MustExec(schema)
|
||||
|
||||
tx := db.MustBegin()
|
||||
tx.MustExec("INSERT INTO person (first_name, last_name, email) VALUES ($1, $2, $3)", "Jason", "Moiron", "jmoiron@jmoiron.net")
|
||||
tx.MustExec("INSERT INTO person (first_name, last_name, email) VALUES ($1, $2, $3)", "John", "Doe", "johndoeDNE@gmail.net")
|
||||
tx.MustExec("INSERT INTO place (country, city, telcode) VALUES ($1, $2, $3)", "United States", "New York", "1")
|
||||
tx.MustExec("INSERT INTO place (country, telcode) VALUES ($1, $2)", "Hong Kong", "852")
|
||||
tx.MustExec("INSERT INTO place (country, telcode) VALUES ($1, $2)", "Singapore", "65")
|
||||
// Named queries can use structs, so if you have an existing struct (i.e. person := &Person{}) that you have populated, you can pass it in as &person
|
||||
tx.NamedExec("INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)", &Person{"Jane", "Citizen", "jane.citzen@example.com"})
|
||||
tx.Commit()
|
||||
|
||||
// Query the database, storing results in a []Person (wrapped in []interface{})
|
||||
people := []Person{}
|
||||
db.Select(&people, "SELECT * FROM person ORDER BY first_name ASC")
|
||||
jason, john := people[0], people[1]
|
||||
|
||||
fmt.Printf("%#v\n%#v", jason, john)
|
||||
// Person{FirstName:"Jason", LastName:"Moiron", Email:"jmoiron@jmoiron.net"}
|
||||
// Person{FirstName:"John", LastName:"Doe", Email:"johndoeDNE@gmail.net"}
|
||||
|
||||
// You can also get a single result, a la QueryRow
|
||||
jason = Person{}
|
||||
err = db.Get(&jason, "SELECT * FROM person WHERE first_name=$1", "Jason")
|
||||
fmt.Printf("%#v\n", jason)
|
||||
// Person{FirstName:"Jason", LastName:"Moiron", Email:"jmoiron@jmoiron.net"}
|
||||
|
||||
// if you have null fields and use SELECT *, you must use sql.Null* in your struct
|
||||
places := []Place{}
|
||||
err = db.Select(&places, "SELECT * FROM place ORDER BY telcode ASC")
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
usa, singsing, honkers := places[0], places[1], places[2]
|
||||
|
||||
fmt.Printf("%#v\n%#v\n%#v\n", usa, singsing, honkers)
|
||||
// Place{Country:"United States", City:sql.NullString{String:"New York", Valid:true}, TelCode:1}
|
||||
// Place{Country:"Singapore", City:sql.NullString{String:"", Valid:false}, TelCode:65}
|
||||
// Place{Country:"Hong Kong", City:sql.NullString{String:"", Valid:false}, TelCode:852}
|
||||
|
||||
// Loop through rows using only one struct
|
||||
place := Place{}
|
||||
rows, err := db.Queryx("SELECT * FROM place")
|
||||
for rows.Next() {
|
||||
err := rows.StructScan(&place)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
fmt.Printf("%#v\n", place)
|
||||
}
|
||||
// Place{Country:"United States", City:sql.NullString{String:"New York", Valid:true}, TelCode:1}
|
||||
// Place{Country:"Hong Kong", City:sql.NullString{String:"", Valid:false}, TelCode:852}
|
||||
// Place{Country:"Singapore", City:sql.NullString{String:"", Valid:false}, TelCode:65}
|
||||
|
||||
// Named queries, using `:name` as the bindvar. Automatic bindvar support
|
||||
// which takes into account the dbtype based on the driverName on sqlx.Open/Connect
|
||||
_, err = db.NamedExec(`INSERT INTO person (first_name,last_name,email) VALUES (:first,:last,:email)`,
|
||||
map[string]interface{}{
|
||||
"first": "Bin",
|
||||
"last": "Smuth",
|
||||
"email": "bensmith@allblacks.nz",
|
||||
})
|
||||
|
||||
// Selects Mr. Smith from the database
|
||||
rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:fn`, map[string]interface{}{"fn": "Bin"})
|
||||
|
||||
// Named queries can also use structs. Their bind names follow the same rules
|
||||
// as the name -> db mapping, so struct fields are lowercased and the `db` tag
|
||||
// is taken into consideration.
|
||||
rows, err = db.NamedQuery(`SELECT * FROM person WHERE first_name=:first_name`, jason)
|
||||
}
|
||||
```
|
||||
|
|
@ -0,0 +1,207 @@
|
|||
package sqlx
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx/reflectx"
|
||||
)
|
||||
|
||||
// Bindvar types supported by Rebind, BindMap and BindStruct.
|
||||
const (
|
||||
UNKNOWN = iota
|
||||
QUESTION
|
||||
DOLLAR
|
||||
NAMED
|
||||
)
|
||||
|
||||
// BindType returns the bindtype for a given database given a drivername.
|
||||
func BindType(driverName string) int {
|
||||
switch driverName {
|
||||
case "postgres", "pgx":
|
||||
return DOLLAR
|
||||
case "mysql":
|
||||
return QUESTION
|
||||
case "sqlite3":
|
||||
return QUESTION
|
||||
case "oci8", "ora", "goracle":
|
||||
return NAMED
|
||||
}
|
||||
return UNKNOWN
|
||||
}
|
||||
|
||||
// FIXME: this should be able to be tolerant of escaped ?'s in queries without
|
||||
// losing much speed, and should be to avoid confusion.
|
||||
|
||||
// Rebind a query from the default bindtype (QUESTION) to the target bindtype.
|
||||
func Rebind(bindType int, query string) string {
|
||||
switch bindType {
|
||||
case QUESTION, UNKNOWN:
|
||||
return query
|
||||
}
|
||||
|
||||
// Add space enough for 10 params before we have to allocate
|
||||
rqb := make([]byte, 0, len(query)+10)
|
||||
|
||||
var i, j int
|
||||
|
||||
for i = strings.Index(query, "?"); i != -1; i = strings.Index(query, "?") {
|
||||
rqb = append(rqb, query[:i]...)
|
||||
|
||||
switch bindType {
|
||||
case DOLLAR:
|
||||
rqb = append(rqb, '$')
|
||||
case NAMED:
|
||||
rqb = append(rqb, ':', 'a', 'r', 'g')
|
||||
}
|
||||
|
||||
j++
|
||||
rqb = strconv.AppendInt(rqb, int64(j), 10)
|
||||
|
||||
query = query[i+1:]
|
||||
}
|
||||
|
||||
return string(append(rqb, query...))
|
||||
}
|
||||
|
||||
// Experimental implementation of Rebind which uses a bytes.Buffer. The code is
|
||||
// much simpler and should be more resistant to odd unicode, but it is twice as
|
||||
// slow. Kept here for benchmarking purposes and to possibly replace Rebind if
|
||||
// problems arise with its somewhat naive handling of unicode.
|
||||
func rebindBuff(bindType int, query string) string {
|
||||
if bindType != DOLLAR {
|
||||
return query
|
||||
}
|
||||
|
||||
b := make([]byte, 0, len(query))
|
||||
rqb := bytes.NewBuffer(b)
|
||||
j := 1
|
||||
for _, r := range query {
|
||||
if r == '?' {
|
||||
rqb.WriteRune('$')
|
||||
rqb.WriteString(strconv.Itoa(j))
|
||||
j++
|
||||
} else {
|
||||
rqb.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
return rqb.String()
|
||||
}
|
||||
|
||||
// In expands slice values in args, returning the modified query string
|
||||
// and a new arg list that can be executed by a database. The `query` should
|
||||
// use the `?` bindVar. The return value uses the `?` bindVar.
|
||||
func In(query string, args ...interface{}) (string, []interface{}, error) {
|
||||
// argMeta stores reflect.Value and length for slices and
|
||||
// the value itself for non-slice arguments
|
||||
type argMeta struct {
|
||||
v reflect.Value
|
||||
i interface{}
|
||||
length int
|
||||
}
|
||||
|
||||
var flatArgsCount int
|
||||
var anySlices bool
|
||||
|
||||
meta := make([]argMeta, len(args))
|
||||
|
||||
for i, arg := range args {
|
||||
v := reflect.ValueOf(arg)
|
||||
t := reflectx.Deref(v.Type())
|
||||
|
||||
if t.Kind() == reflect.Slice {
|
||||
meta[i].length = v.Len()
|
||||
meta[i].v = v
|
||||
|
||||
anySlices = true
|
||||
flatArgsCount += meta[i].length
|
||||
|
||||
if meta[i].length == 0 {
|
||||
return "", nil, errors.New("empty slice passed to 'in' query")
|
||||
}
|
||||
} else {
|
||||
meta[i].i = arg
|
||||
flatArgsCount++
|
||||
}
|
||||
}
|
||||
|
||||
// don't do any parsing if there aren't any slices; note that this means
|
||||
// some errors that we might have caught below will not be returned.
|
||||
if !anySlices {
|
||||
return query, args, nil
|
||||
}
|
||||
|
||||
newArgs := make([]interface{}, 0, flatArgsCount)
|
||||
buf := bytes.NewBuffer(make([]byte, 0, len(query)+len(", ?")*flatArgsCount))
|
||||
|
||||
var arg, offset int
|
||||
|
||||
for i := strings.IndexByte(query[offset:], '?'); i != -1; i = strings.IndexByte(query[offset:], '?') {
|
||||
if arg >= len(meta) {
|
||||
// if an argument wasn't passed, lets return an error; this is
|
||||
// not actually how database/sql Exec/Query works, but since we are
|
||||
// creating an argument list programmatically, we want to be able
|
||||
// to catch these programmer errors earlier.
|
||||
return "", nil, errors.New("number of bindVars exceeds arguments")
|
||||
}
|
||||
|
||||
argMeta := meta[arg]
|
||||
arg++
|
||||
|
||||
// not a slice, continue.
|
||||
// our questionmark will either be written before the next expansion
|
||||
// of a slice or after the loop when writing the rest of the query
|
||||
if argMeta.length == 0 {
|
||||
offset = offset + i + 1
|
||||
newArgs = append(newArgs, argMeta.i)
|
||||
continue
|
||||
}
|
||||
|
||||
// write everything up to and including our ? character
|
||||
buf.WriteString(query[:offset+i+1])
|
||||
|
||||
for si := 1; si < argMeta.length; si++ {
|
||||
buf.WriteString(", ?")
|
||||
}
|
||||
|
||||
newArgs = appendReflectSlice(newArgs, argMeta.v, argMeta.length)
|
||||
|
||||
// slice the query and reset the offset. this avoids some bookkeeping for
|
||||
// the write after the loop
|
||||
query = query[offset+i+1:]
|
||||
offset = 0
|
||||
}
|
||||
|
||||
buf.WriteString(query)
|
||||
|
||||
if arg < len(meta) {
|
||||
return "", nil, errors.New("number of bindVars less than number arguments")
|
||||
}
|
||||
|
||||
return buf.String(), newArgs, nil
|
||||
}
|
||||
|
||||
func appendReflectSlice(args []interface{}, v reflect.Value, vlen int) []interface{} {
|
||||
switch val := v.Interface().(type) {
|
||||
case []interface{}:
|
||||
args = append(args, val...)
|
||||
case []int:
|
||||
for i := range val {
|
||||
args = append(args, val[i])
|
||||
}
|
||||
case []string:
|
||||
for i := range val {
|
||||
args = append(args, val[i])
|
||||
}
|
||||
default:
|
||||
for si := 0; si < vlen; si++ {
|
||||
args = append(args, v.Index(si).Interface())
|
||||
}
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
|
@ -0,0 +1,12 @@
|
|||
// Package sqlx provides general purpose extensions to database/sql.
|
||||
//
|
||||
// It is intended to seamlessly wrap database/sql and provide convenience
|
||||
// methods which are useful in the development of database driven applications.
|
||||
// None of the underlying database/sql methods are changed. Instead all extended
|
||||
// behavior is implemented through new methods defined on wrapper types.
|
||||
//
|
||||
// Additions include scanning into structs, named query support, rebinding
|
||||
// queries for different drivers, convenient shorthands for common error handling
|
||||
// and more.
|
||||
//
|
||||
package sqlx
|
|
@ -0,0 +1,344 @@
|
|||
package sqlx
|
||||
|
||||
// Named Query Support
|
||||
//
|
||||
// * BindMap - bind query bindvars to map/struct args
|
||||
// * NamedExec, NamedQuery - named query w/ struct or map
|
||||
// * NamedStmt - a pre-compiled named query which is a prepared statement
|
||||
//
|
||||
// Internal Interfaces:
|
||||
//
|
||||
// * compileNamedQuery - rebind a named query, returning a query and list of names
|
||||
// * bindArgs, bindMapArgs, bindAnyArgs - given a list of names, return an arglist
|
||||
//
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"unicode"
|
||||
|
||||
"github.com/jmoiron/sqlx/reflectx"
|
||||
)
|
||||
|
||||
// NamedStmt is a prepared statement that executes named queries. Prepare it
|
||||
// how you would execute a NamedQuery, but pass in a struct or map when executing.
|
||||
type NamedStmt struct {
|
||||
Params []string
|
||||
QueryString string
|
||||
Stmt *Stmt
|
||||
}
|
||||
|
||||
// Close closes the named statement.
|
||||
func (n *NamedStmt) Close() error {
|
||||
return n.Stmt.Close()
|
||||
}
|
||||
|
||||
// Exec executes a named statement using the struct passed.
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) Exec(arg interface{}) (sql.Result, error) {
|
||||
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
|
||||
if err != nil {
|
||||
return *new(sql.Result), err
|
||||
}
|
||||
return n.Stmt.Exec(args...)
|
||||
}
|
||||
|
||||
// Query executes a named statement using the struct argument, returning rows.
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) Query(arg interface{}) (*sql.Rows, error) {
|
||||
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return n.Stmt.Query(args...)
|
||||
}
|
||||
|
||||
// QueryRow executes a named statement against the database. Because sqlx cannot
|
||||
// create a *sql.Row with an error condition pre-set for binding errors, sqlx
|
||||
// returns a *sqlx.Row instead.
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) QueryRow(arg interface{}) *Row {
|
||||
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
|
||||
if err != nil {
|
||||
return &Row{err: err}
|
||||
}
|
||||
return n.Stmt.QueryRowx(args...)
|
||||
}
|
||||
|
||||
// MustExec execs a NamedStmt, panicing on error
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) MustExec(arg interface{}) sql.Result {
|
||||
res, err := n.Exec(arg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// Queryx using this NamedStmt
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) Queryx(arg interface{}) (*Rows, error) {
|
||||
r, err := n.Query(arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err
|
||||
}
|
||||
|
||||
// QueryRowx this NamedStmt. Because of limitations with QueryRow, this is
|
||||
// an alias for QueryRow.
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) QueryRowx(arg interface{}) *Row {
|
||||
return n.QueryRow(arg)
|
||||
}
|
||||
|
||||
// Select using this NamedStmt
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) Select(dest interface{}, arg interface{}) error {
|
||||
rows, err := n.Queryx(arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// if something happens here, we want to make sure the rows are Closed
|
||||
defer rows.Close()
|
||||
return scanAll(rows, dest, false)
|
||||
}
|
||||
|
||||
// Get using this NamedStmt
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) Get(dest interface{}, arg interface{}) error {
|
||||
r := n.QueryRowx(arg)
|
||||
return r.scanAny(dest, false)
|
||||
}
|
||||
|
||||
// Unsafe creates an unsafe version of the NamedStmt
|
||||
func (n *NamedStmt) Unsafe() *NamedStmt {
|
||||
r := &NamedStmt{Params: n.Params, Stmt: n.Stmt, QueryString: n.QueryString}
|
||||
r.Stmt.unsafe = true
|
||||
return r
|
||||
}
|
||||
|
||||
// A union interface of preparer and binder, required to be able to prepare
|
||||
// named statements (as the bindtype must be determined).
|
||||
type namedPreparer interface {
|
||||
Preparer
|
||||
binder
|
||||
}
|
||||
|
||||
func prepareNamed(p namedPreparer, query string) (*NamedStmt, error) {
|
||||
bindType := BindType(p.DriverName())
|
||||
q, args, err := compileNamedQuery([]byte(query), bindType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt, err := Preparex(p, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &NamedStmt{
|
||||
QueryString: q,
|
||||
Params: args,
|
||||
Stmt: stmt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func bindAnyArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
|
||||
if maparg, ok := arg.(map[string]interface{}); ok {
|
||||
return bindMapArgs(names, maparg)
|
||||
}
|
||||
return bindArgs(names, arg, m)
|
||||
}
|
||||
|
||||
// private interface to generate a list of interfaces from a given struct
|
||||
// type, given a list of names to pull out of the struct. Used by public
|
||||
// BindStruct interface.
|
||||
func bindArgs(names []string, arg interface{}, m *reflectx.Mapper) ([]interface{}, error) {
|
||||
arglist := make([]interface{}, 0, len(names))
|
||||
|
||||
// grab the indirected value of arg
|
||||
v := reflect.ValueOf(arg)
|
||||
for v = reflect.ValueOf(arg); v.Kind() == reflect.Ptr; {
|
||||
v = v.Elem()
|
||||
}
|
||||
|
||||
fields := m.TraversalsByName(v.Type(), names)
|
||||
for i, t := range fields {
|
||||
if len(t) == 0 {
|
||||
return arglist, fmt.Errorf("could not find name %s in %#v", names[i], arg)
|
||||
}
|
||||
val := reflectx.FieldByIndexesReadOnly(v, t)
|
||||
arglist = append(arglist, val.Interface())
|
||||
}
|
||||
|
||||
return arglist, nil
|
||||
}
|
||||
|
||||
// like bindArgs, but for maps.
|
||||
func bindMapArgs(names []string, arg map[string]interface{}) ([]interface{}, error) {
|
||||
arglist := make([]interface{}, 0, len(names))
|
||||
|
||||
for _, name := range names {
|
||||
val, ok := arg[name]
|
||||
if !ok {
|
||||
return arglist, fmt.Errorf("could not find name %s in %#v", name, arg)
|
||||
}
|
||||
arglist = append(arglist, val)
|
||||
}
|
||||
return arglist, nil
|
||||
}
|
||||
|
||||
// bindStruct binds a named parameter query with fields from a struct argument.
|
||||
// The rules for binding field names to parameter names follow the same
|
||||
// conventions as for StructScan, including obeying the `db` struct tags.
|
||||
func bindStruct(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
|
||||
bound, names, err := compileNamedQuery([]byte(query), bindType)
|
||||
if err != nil {
|
||||
return "", []interface{}{}, err
|
||||
}
|
||||
|
||||
arglist, err := bindArgs(names, arg, m)
|
||||
if err != nil {
|
||||
return "", []interface{}{}, err
|
||||
}
|
||||
|
||||
return bound, arglist, nil
|
||||
}
|
||||
|
||||
// bindMap binds a named parameter query with a map of arguments.
|
||||
func bindMap(bindType int, query string, args map[string]interface{}) (string, []interface{}, error) {
|
||||
bound, names, err := compileNamedQuery([]byte(query), bindType)
|
||||
if err != nil {
|
||||
return "", []interface{}{}, err
|
||||
}
|
||||
|
||||
arglist, err := bindMapArgs(names, args)
|
||||
return bound, arglist, err
|
||||
}
|
||||
|
||||
// -- Compilation of Named Queries
|
||||
|
||||
// Allow digits and letters in bind params; additionally runes are
|
||||
// checked against underscores, meaning that bind params can have be
|
||||
// alphanumeric with underscores. Mind the difference between unicode
|
||||
// digits and numbers, where '5' is a digit but '五' is not.
|
||||
var allowedBindRunes = []*unicode.RangeTable{unicode.Letter, unicode.Digit}
|
||||
|
||||
// FIXME: this function isn't safe for unicode named params, as a failing test
|
||||
// can testify. This is not a regression but a failure of the original code
|
||||
// as well. It should be modified to range over runes in a string rather than
|
||||
// bytes, even though this is less convenient and slower. Hopefully the
|
||||
// addition of the prepared NamedStmt (which will only do this once) will make
|
||||
// up for the slightly slower ad-hoc NamedExec/NamedQuery.
|
||||
|
||||
// compile a NamedQuery into an unbound query (using the '?' bindvar) and
|
||||
// a list of names.
|
||||
func compileNamedQuery(qs []byte, bindType int) (query string, names []string, err error) {
|
||||
names = make([]string, 0, 10)
|
||||
rebound := make([]byte, 0, len(qs))
|
||||
|
||||
inName := false
|
||||
last := len(qs) - 1
|
||||
currentVar := 1
|
||||
name := make([]byte, 0, 10)
|
||||
|
||||
for i, b := range qs {
|
||||
// a ':' while we're in a name is an error
|
||||
if b == ':' {
|
||||
// if this is the second ':' in a '::' escape sequence, append a ':'
|
||||
if inName && i > 0 && qs[i-1] == ':' {
|
||||
rebound = append(rebound, ':')
|
||||
inName = false
|
||||
continue
|
||||
} else if inName {
|
||||
err = errors.New("unexpected `:` while reading named param at " + strconv.Itoa(i))
|
||||
return query, names, err
|
||||
}
|
||||
inName = true
|
||||
name = []byte{}
|
||||
// if we're in a name, and this is an allowed character, continue
|
||||
} else if inName && (unicode.IsOneOf(allowedBindRunes, rune(b)) || b == '_' || b == '.') && i != last {
|
||||
// append the byte to the name if we are in a name and not on the last byte
|
||||
name = append(name, b)
|
||||
// if we're in a name and it's not an allowed character, the name is done
|
||||
} else if inName {
|
||||
inName = false
|
||||
// if this is the final byte of the string and it is part of the name, then
|
||||
// make sure to add it to the name
|
||||
if i == last && unicode.IsOneOf(allowedBindRunes, rune(b)) {
|
||||
name = append(name, b)
|
||||
}
|
||||
// add the string representation to the names list
|
||||
names = append(names, string(name))
|
||||
// add a proper bindvar for the bindType
|
||||
switch bindType {
|
||||
// oracle only supports named type bind vars even for positional
|
||||
case NAMED:
|
||||
rebound = append(rebound, ':')
|
||||
rebound = append(rebound, name...)
|
||||
case QUESTION, UNKNOWN:
|
||||
rebound = append(rebound, '?')
|
||||
case DOLLAR:
|
||||
rebound = append(rebound, '$')
|
||||
for _, b := range strconv.Itoa(currentVar) {
|
||||
rebound = append(rebound, byte(b))
|
||||
}
|
||||
currentVar++
|
||||
}
|
||||
// add this byte to string unless it was not part of the name
|
||||
if i != last {
|
||||
rebound = append(rebound, b)
|
||||
} else if !unicode.IsOneOf(allowedBindRunes, rune(b)) {
|
||||
rebound = append(rebound, b)
|
||||
}
|
||||
} else {
|
||||
// this is a normal byte and should just go onto the rebound query
|
||||
rebound = append(rebound, b)
|
||||
}
|
||||
}
|
||||
|
||||
return string(rebound), names, err
|
||||
}
|
||||
|
||||
// BindNamed binds a struct or a map to a query with named parameters.
|
||||
// DEPRECATED: use sqlx.Named` instead of this, it may be removed in future.
|
||||
func BindNamed(bindType int, query string, arg interface{}) (string, []interface{}, error) {
|
||||
return bindNamedMapper(bindType, query, arg, mapper())
|
||||
}
|
||||
|
||||
// Named takes a query using named parameters and an argument and
|
||||
// returns a new query with a list of args that can be executed by
|
||||
// a database. The return value uses the `?` bindvar.
|
||||
func Named(query string, arg interface{}) (string, []interface{}, error) {
|
||||
return bindNamedMapper(QUESTION, query, arg, mapper())
|
||||
}
|
||||
|
||||
func bindNamedMapper(bindType int, query string, arg interface{}, m *reflectx.Mapper) (string, []interface{}, error) {
|
||||
if maparg, ok := arg.(map[string]interface{}); ok {
|
||||
return bindMap(bindType, query, maparg)
|
||||
}
|
||||
return bindStruct(bindType, query, arg, m)
|
||||
}
|
||||
|
||||
// NamedQuery binds a named query and then runs Query on the result using the
|
||||
// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with
|
||||
// map[string]interface{} types.
|
||||
func NamedQuery(e Ext, query string, arg interface{}) (*Rows, error) {
|
||||
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return e.Queryx(q, args...)
|
||||
}
|
||||
|
||||
// NamedExec uses BindStruct to get a query executable by the driver and
|
||||
// then runs Exec on the result. Returns an error from the binding
|
||||
// or the query excution itself.
|
||||
func NamedExec(e Ext, query string, arg interface{}) (sql.Result, error) {
|
||||
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return e.Exec(q, args...)
|
||||
}
|
|
@ -0,0 +1,132 @@
|
|||
// +build go1.8
|
||||
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
// A union interface of contextPreparer and binder, required to be able to
|
||||
// prepare named statements with context (as the bindtype must be determined).
|
||||
type namedPreparerContext interface {
|
||||
PreparerContext
|
||||
binder
|
||||
}
|
||||
|
||||
func prepareNamedContext(ctx context.Context, p namedPreparerContext, query string) (*NamedStmt, error) {
|
||||
bindType := BindType(p.DriverName())
|
||||
q, args, err := compileNamedQuery([]byte(query), bindType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
stmt, err := PreparexContext(ctx, p, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &NamedStmt{
|
||||
QueryString: q,
|
||||
Params: args,
|
||||
Stmt: stmt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ExecContext executes a named statement using the struct passed.
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) ExecContext(ctx context.Context, arg interface{}) (sql.Result, error) {
|
||||
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
|
||||
if err != nil {
|
||||
return *new(sql.Result), err
|
||||
}
|
||||
return n.Stmt.ExecContext(ctx, args...)
|
||||
}
|
||||
|
||||
// QueryContext executes a named statement using the struct argument, returning rows.
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) QueryContext(ctx context.Context, arg interface{}) (*sql.Rows, error) {
|
||||
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return n.Stmt.QueryContext(ctx, args...)
|
||||
}
|
||||
|
||||
// QueryRowContext executes a named statement against the database. Because sqlx cannot
|
||||
// create a *sql.Row with an error condition pre-set for binding errors, sqlx
|
||||
// returns a *sqlx.Row instead.
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) QueryRowContext(ctx context.Context, arg interface{}) *Row {
|
||||
args, err := bindAnyArgs(n.Params, arg, n.Stmt.Mapper)
|
||||
if err != nil {
|
||||
return &Row{err: err}
|
||||
}
|
||||
return n.Stmt.QueryRowxContext(ctx, args...)
|
||||
}
|
||||
|
||||
// MustExecContext execs a NamedStmt, panicing on error
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) MustExecContext(ctx context.Context, arg interface{}) sql.Result {
|
||||
res, err := n.ExecContext(ctx, arg)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// QueryxContext using this NamedStmt
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) QueryxContext(ctx context.Context, arg interface{}) (*Rows, error) {
|
||||
r, err := n.QueryContext(ctx, arg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Rows{Rows: r, Mapper: n.Stmt.Mapper, unsafe: isUnsafe(n)}, err
|
||||
}
|
||||
|
||||
// QueryRowxContext this NamedStmt. Because of limitations with QueryRow, this is
|
||||
// an alias for QueryRow.
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) QueryRowxContext(ctx context.Context, arg interface{}) *Row {
|
||||
return n.QueryRowContext(ctx, arg)
|
||||
}
|
||||
|
||||
// SelectContext using this NamedStmt
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) SelectContext(ctx context.Context, dest interface{}, arg interface{}) error {
|
||||
rows, err := n.QueryxContext(ctx, arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// if something happens here, we want to make sure the rows are Closed
|
||||
defer rows.Close()
|
||||
return scanAll(rows, dest, false)
|
||||
}
|
||||
|
||||
// GetContext using this NamedStmt
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (n *NamedStmt) GetContext(ctx context.Context, dest interface{}, arg interface{}) error {
|
||||
r := n.QueryRowxContext(ctx, arg)
|
||||
return r.scanAny(dest, false)
|
||||
}
|
||||
|
||||
// NamedQueryContext binds a named query and then runs Query on the result using the
|
||||
// provided Ext (sqlx.Tx, sqlx.Db). It works with both structs and with
|
||||
// map[string]interface{} types.
|
||||
func NamedQueryContext(ctx context.Context, e ExtContext, query string, arg interface{}) (*Rows, error) {
|
||||
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return e.QueryxContext(ctx, q, args...)
|
||||
}
|
||||
|
||||
// NamedExecContext uses BindStruct to get a query executable by the driver and
|
||||
// then runs Exec on the result. Returns an error from the binding
|
||||
// or the query excution itself.
|
||||
func NamedExecContext(ctx context.Context, e ExtContext, query string, arg interface{}) (sql.Result, error) {
|
||||
q, args, err := bindNamedMapper(BindType(e.DriverName()), query, arg, mapperFor(e))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return e.ExecContext(ctx, q, args...)
|
||||
}
|
|
@ -0,0 +1,136 @@
|
|||
// +build go1.8
|
||||
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNamedContextQueries(t *testing.T) {
|
||||
RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
|
||||
loadDefaultFixture(db, t)
|
||||
test := Test{t}
|
||||
var ns *NamedStmt
|
||||
var err error
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Check that invalid preparations fail
|
||||
ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first:name")
|
||||
if err == nil {
|
||||
t.Error("Expected an error with invalid prepared statement.")
|
||||
}
|
||||
|
||||
ns, err = db.PrepareNamedContext(ctx, "invalid sql")
|
||||
if err == nil {
|
||||
t.Error("Expected an error with invalid prepared statement.")
|
||||
}
|
||||
|
||||
// Check closing works as anticipated
|
||||
ns, err = db.PrepareNamedContext(ctx, "SELECT * FROM person WHERE first_name=:first_name")
|
||||
test.Error(err)
|
||||
err = ns.Close()
|
||||
test.Error(err)
|
||||
|
||||
ns, err = db.PrepareNamedContext(ctx, `
|
||||
SELECT first_name, last_name, email
|
||||
FROM person WHERE first_name=:first_name AND email=:email`)
|
||||
test.Error(err)
|
||||
|
||||
// test Queryx w/ uses Query
|
||||
p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"}
|
||||
|
||||
rows, err := ns.QueryxContext(ctx, p)
|
||||
test.Error(err)
|
||||
for rows.Next() {
|
||||
var p2 Person
|
||||
rows.StructScan(&p2)
|
||||
if p.FirstName != p2.FirstName {
|
||||
t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName)
|
||||
}
|
||||
if p.LastName != p2.LastName {
|
||||
t.Errorf("got %s, expected %s", p.LastName, p2.LastName)
|
||||
}
|
||||
if p.Email != p2.Email {
|
||||
t.Errorf("got %s, expected %s", p.Email, p2.Email)
|
||||
}
|
||||
}
|
||||
|
||||
// test Select
|
||||
people := make([]Person, 0, 5)
|
||||
err = ns.SelectContext(ctx, &people, p)
|
||||
test.Error(err)
|
||||
|
||||
if len(people) != 1 {
|
||||
t.Errorf("got %d results, expected %d", len(people), 1)
|
||||
}
|
||||
if p.FirstName != people[0].FirstName {
|
||||
t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName)
|
||||
}
|
||||
if p.LastName != people[0].LastName {
|
||||
t.Errorf("got %s, expected %s", p.LastName, people[0].LastName)
|
||||
}
|
||||
if p.Email != people[0].Email {
|
||||
t.Errorf("got %s, expected %s", p.Email, people[0].Email)
|
||||
}
|
||||
|
||||
// test Exec
|
||||
ns, err = db.PrepareNamedContext(ctx, `
|
||||
INSERT INTO person (first_name, last_name, email)
|
||||
VALUES (:first_name, :last_name, :email)`)
|
||||
test.Error(err)
|
||||
|
||||
js := Person{
|
||||
FirstName: "Julien",
|
||||
LastName: "Savea",
|
||||
Email: "jsavea@ab.co.nz",
|
||||
}
|
||||
_, err = ns.ExecContext(ctx, js)
|
||||
test.Error(err)
|
||||
|
||||
// Make sure we can pull him out again
|
||||
p2 := Person{}
|
||||
db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email)
|
||||
if p2.Email != js.Email {
|
||||
t.Errorf("expected %s, got %s", js.Email, p2.Email)
|
||||
}
|
||||
|
||||
// test Txn NamedStmts
|
||||
tx := db.MustBeginTx(ctx, nil)
|
||||
txns := tx.NamedStmtContext(ctx, ns)
|
||||
|
||||
// We're going to add Steven in this txn
|
||||
sl := Person{
|
||||
FirstName: "Steven",
|
||||
LastName: "Luatua",
|
||||
Email: "sluatua@ab.co.nz",
|
||||
}
|
||||
|
||||
_, err = txns.ExecContext(ctx, sl)
|
||||
test.Error(err)
|
||||
// then rollback...
|
||||
tx.Rollback()
|
||||
// looking for Steven after a rollback should fail
|
||||
err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email)
|
||||
if err != sql.ErrNoRows {
|
||||
t.Errorf("expected no rows error, got %v", err)
|
||||
}
|
||||
|
||||
// now do the same, but commit
|
||||
tx = db.MustBeginTx(ctx, nil)
|
||||
txns = tx.NamedStmtContext(ctx, ns)
|
||||
_, err = txns.ExecContext(ctx, sl)
|
||||
test.Error(err)
|
||||
tx.Commit()
|
||||
|
||||
// looking for Steven after a Commit should succeed
|
||||
err = db.GetContext(ctx, &p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email)
|
||||
test.Error(err)
|
||||
if p2.Email != sl.Email {
|
||||
t.Errorf("expected %s, got %s", sl.Email, p2.Email)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
|
@ -0,0 +1,227 @@
|
|||
package sqlx
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCompileQuery(t *testing.T) {
|
||||
table := []struct {
|
||||
Q, R, D, N string
|
||||
V []string
|
||||
}{
|
||||
// basic test for named parameters, invalid char ',' terminating
|
||||
{
|
||||
Q: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
|
||||
R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
|
||||
D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`,
|
||||
N: `INSERT INTO foo (a,b,c,d) VALUES (:name, :age, :first, :last)`,
|
||||
V: []string{"name", "age", "first", "last"},
|
||||
},
|
||||
// This query tests a named parameter ending the string as well as numbers
|
||||
{
|
||||
Q: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`,
|
||||
R: `SELECT * FROM a WHERE first_name=? AND last_name=?`,
|
||||
D: `SELECT * FROM a WHERE first_name=$1 AND last_name=$2`,
|
||||
N: `SELECT * FROM a WHERE first_name=:name1 AND last_name=:name2`,
|
||||
V: []string{"name1", "name2"},
|
||||
},
|
||||
{
|
||||
Q: `SELECT "::foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
|
||||
R: `SELECT ":foo" FROM a WHERE first_name=? AND last_name=?`,
|
||||
D: `SELECT ":foo" FROM a WHERE first_name=$1 AND last_name=$2`,
|
||||
N: `SELECT ":foo" FROM a WHERE first_name=:name1 AND last_name=:name2`,
|
||||
V: []string{"name1", "name2"},
|
||||
},
|
||||
{
|
||||
Q: `SELECT 'a::b::c' || first_name, '::::ABC::_::' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
|
||||
R: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=? AND last_name=?`,
|
||||
D: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=$1 AND last_name=$2`,
|
||||
N: `SELECT 'a:b:c' || first_name, '::ABC:_:' FROM person WHERE first_name=:first_name AND last_name=:last_name`,
|
||||
V: []string{"first_name", "last_name"},
|
||||
},
|
||||
/* This unicode awareness test sadly fails, because of our byte-wise worldview.
|
||||
* We could certainly iterate by Rune instead, though it's a great deal slower,
|
||||
* it's probably the RightWay(tm)
|
||||
{
|
||||
Q: `INSERT INTO foo (a,b,c,d) VALUES (:あ, :b, :キコ, :名前)`,
|
||||
R: `INSERT INTO foo (a,b,c,d) VALUES (?, ?, ?, ?)`,
|
||||
D: `INSERT INTO foo (a,b,c,d) VALUES ($1, $2, $3, $4)`,
|
||||
N: []string{"name", "age", "first", "last"},
|
||||
},
|
||||
*/
|
||||
}
|
||||
|
||||
for _, test := range table {
|
||||
qr, names, err := compileNamedQuery([]byte(test.Q), QUESTION)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
if qr != test.R {
|
||||
t.Errorf("expected %s, got %s", test.R, qr)
|
||||
}
|
||||
if len(names) != len(test.V) {
|
||||
t.Errorf("expected %#v, got %#v", test.V, names)
|
||||
} else {
|
||||
for i, name := range names {
|
||||
if name != test.V[i] {
|
||||
t.Errorf("expected %dth name to be %s, got %s", i+1, test.V[i], name)
|
||||
}
|
||||
}
|
||||
}
|
||||
qd, _, _ := compileNamedQuery([]byte(test.Q), DOLLAR)
|
||||
if qd != test.D {
|
||||
t.Errorf("\nexpected: `%s`\ngot: `%s`", test.D, qd)
|
||||
}
|
||||
|
||||
qq, _, _ := compileNamedQuery([]byte(test.Q), NAMED)
|
||||
if qq != test.N {
|
||||
t.Errorf("\nexpected: `%s`\ngot: `%s`\n(len: %d vs %d)", test.N, qq, len(test.N), len(qq))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type Test struct {
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (t Test) Error(err error, msg ...interface{}) {
|
||||
if err != nil {
|
||||
if len(msg) == 0 {
|
||||
t.t.Error(err)
|
||||
} else {
|
||||
t.t.Error(msg...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t Test) Errorf(err error, format string, args ...interface{}) {
|
||||
if err != nil {
|
||||
t.t.Errorf(format, args...)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamedQueries(t *testing.T) {
|
||||
RunWithSchema(defaultSchema, t, func(db *DB, t *testing.T) {
|
||||
loadDefaultFixture(db, t)
|
||||
test := Test{t}
|
||||
var ns *NamedStmt
|
||||
var err error
|
||||
|
||||
// Check that invalid preparations fail
|
||||
ns, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first:name")
|
||||
if err == nil {
|
||||
t.Error("Expected an error with invalid prepared statement.")
|
||||
}
|
||||
|
||||
ns, err = db.PrepareNamed("invalid sql")
|
||||
if err == nil {
|
||||
t.Error("Expected an error with invalid prepared statement.")
|
||||
}
|
||||
|
||||
// Check closing works as anticipated
|
||||
ns, err = db.PrepareNamed("SELECT * FROM person WHERE first_name=:first_name")
|
||||
test.Error(err)
|
||||
err = ns.Close()
|
||||
test.Error(err)
|
||||
|
||||
ns, err = db.PrepareNamed(`
|
||||
SELECT first_name, last_name, email
|
||||
FROM person WHERE first_name=:first_name AND email=:email`)
|
||||
test.Error(err)
|
||||
|
||||
// test Queryx w/ uses Query
|
||||
p := Person{FirstName: "Jason", LastName: "Moiron", Email: "jmoiron@jmoiron.net"}
|
||||
|
||||
rows, err := ns.Queryx(p)
|
||||
test.Error(err)
|
||||
for rows.Next() {
|
||||
var p2 Person
|
||||
rows.StructScan(&p2)
|
||||
if p.FirstName != p2.FirstName {
|
||||
t.Errorf("got %s, expected %s", p.FirstName, p2.FirstName)
|
||||
}
|
||||
if p.LastName != p2.LastName {
|
||||
t.Errorf("got %s, expected %s", p.LastName, p2.LastName)
|
||||
}
|
||||
if p.Email != p2.Email {
|
||||
t.Errorf("got %s, expected %s", p.Email, p2.Email)
|
||||
}
|
||||
}
|
||||
|
||||
// test Select
|
||||
people := make([]Person, 0, 5)
|
||||
err = ns.Select(&people, p)
|
||||
test.Error(err)
|
||||
|
||||
if len(people) != 1 {
|
||||
t.Errorf("got %d results, expected %d", len(people), 1)
|
||||
}
|
||||
if p.FirstName != people[0].FirstName {
|
||||
t.Errorf("got %s, expected %s", p.FirstName, people[0].FirstName)
|
||||
}
|
||||
if p.LastName != people[0].LastName {
|
||||
t.Errorf("got %s, expected %s", p.LastName, people[0].LastName)
|
||||
}
|
||||
if p.Email != people[0].Email {
|
||||
t.Errorf("got %s, expected %s", p.Email, people[0].Email)
|
||||
}
|
||||
|
||||
// test Exec
|
||||
ns, err = db.PrepareNamed(`
|
||||
INSERT INTO person (first_name, last_name, email)
|
||||
VALUES (:first_name, :last_name, :email)`)
|
||||
test.Error(err)
|
||||
|
||||
js := Person{
|
||||
FirstName: "Julien",
|
||||
LastName: "Savea",
|
||||
Email: "jsavea@ab.co.nz",
|
||||
}
|
||||
_, err = ns.Exec(js)
|
||||
test.Error(err)
|
||||
|
||||
// Make sure we can pull him out again
|
||||
p2 := Person{}
|
||||
db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), js.Email)
|
||||
if p2.Email != js.Email {
|
||||
t.Errorf("expected %s, got %s", js.Email, p2.Email)
|
||||
}
|
||||
|
||||
// test Txn NamedStmts
|
||||
tx := db.MustBegin()
|
||||
txns := tx.NamedStmt(ns)
|
||||
|
||||
// We're going to add Steven in this txn
|
||||
sl := Person{
|
||||
FirstName: "Steven",
|
||||
LastName: "Luatua",
|
||||
Email: "sluatua@ab.co.nz",
|
||||
}
|
||||
|
||||
_, err = txns.Exec(sl)
|
||||
test.Error(err)
|
||||
// then rollback...
|
||||
tx.Rollback()
|
||||
// looking for Steven after a rollback should fail
|
||||
err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email)
|
||||
if err != sql.ErrNoRows {
|
||||
t.Errorf("expected no rows error, got %v", err)
|
||||
}
|
||||
|
||||
// now do the same, but commit
|
||||
tx = db.MustBegin()
|
||||
txns = tx.NamedStmt(ns)
|
||||
_, err = txns.Exec(sl)
|
||||
test.Error(err)
|
||||
tx.Commit()
|
||||
|
||||
// looking for Steven after a Commit should succeed
|
||||
err = db.Get(&p2, db.Rebind("SELECT * FROM person WHERE email=?"), sl.Email)
|
||||
test.Error(err)
|
||||
if p2.Email != sl.Email {
|
||||
t.Errorf("expected %s, got %s", sl.Email, p2.Email)
|
||||
}
|
||||
|
||||
})
|
||||
}
|
|
@ -0,0 +1,17 @@
|
|||
# reflectx
|
||||
|
||||
The sqlx package has special reflect needs. In particular, it needs to:
|
||||
|
||||
* be able to map a name to a field
|
||||
* understand embedded structs
|
||||
* understand mapping names to fields by a particular tag
|
||||
* user specified name -> field mapping functions
|
||||
|
||||
These behaviors mimic the behaviors by the standard library marshallers and also the
|
||||
behavior of standard Go accessors.
|
||||
|
||||
The first two are amply taken care of by `Reflect.Value.FieldByName`, and the third is
|
||||
addressed by `Reflect.Value.FieldByNameFunc`, but these don't quite understand struct
|
||||
tags in the ways that are vital to most marshallers, and they are slow.
|
||||
|
||||
This reflectx package extends reflect to achieve these goals.
|
|
@ -0,0 +1,422 @@
|
|||
// Package reflectx implements extensions to the standard reflect lib suitable
|
||||
// for implementing marshalling and unmarshalling packages. The main Mapper type
|
||||
// allows for Go-compatible named attribute access, including accessing embedded
|
||||
// struct attributes and the ability to use functions and struct tags to
|
||||
// customize field names.
|
||||
//
|
||||
package reflectx
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A FieldInfo is metadata for a struct field.
|
||||
type FieldInfo struct {
|
||||
Index []int
|
||||
Path string
|
||||
Field reflect.StructField
|
||||
Zero reflect.Value
|
||||
Name string
|
||||
Options map[string]string
|
||||
Embedded bool
|
||||
Children []*FieldInfo
|
||||
Parent *FieldInfo
|
||||
}
|
||||
|
||||
// A StructMap is an index of field metadata for a struct.
|
||||
type StructMap struct {
|
||||
Tree *FieldInfo
|
||||
Index []*FieldInfo
|
||||
Paths map[string]*FieldInfo
|
||||
Names map[string]*FieldInfo
|
||||
}
|
||||
|
||||
// GetByPath returns a *FieldInfo for a given string path.
|
||||
func (f StructMap) GetByPath(path string) *FieldInfo {
|
||||
return f.Paths[path]
|
||||
}
|
||||
|
||||
// GetByTraversal returns a *FieldInfo for a given integer path. It is
|
||||
// analogous to reflect.FieldByIndex, but using the cached traversal
|
||||
// rather than re-executing the reflect machinery each time.
|
||||
func (f StructMap) GetByTraversal(index []int) *FieldInfo {
|
||||
if len(index) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tree := f.Tree
|
||||
for _, i := range index {
|
||||
if i >= len(tree.Children) || tree.Children[i] == nil {
|
||||
return nil
|
||||
}
|
||||
tree = tree.Children[i]
|
||||
}
|
||||
return tree
|
||||
}
|
||||
|
||||
// Mapper is a general purpose mapper of names to struct fields. A Mapper
|
||||
// behaves like most marshallers in the standard library, obeying a field tag
|
||||
// for name mapping but also providing a basic transform function.
|
||||
type Mapper struct {
|
||||
cache map[reflect.Type]*StructMap
|
||||
tagName string
|
||||
tagMapFunc func(string) string
|
||||
mapFunc func(string) string
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewMapper returns a new mapper using the tagName as its struct field tag.
|
||||
// If tagName is the empty string, it is ignored.
|
||||
func NewMapper(tagName string) *Mapper {
|
||||
return &Mapper{
|
||||
cache: make(map[reflect.Type]*StructMap),
|
||||
tagName: tagName,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMapperTagFunc returns a new mapper which contains a mapper for field names
|
||||
// AND a mapper for tag values. This is useful for tags like json which can
|
||||
// have values like "name,omitempty".
|
||||
func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper {
|
||||
return &Mapper{
|
||||
cache: make(map[reflect.Type]*StructMap),
|
||||
tagName: tagName,
|
||||
mapFunc: mapFunc,
|
||||
tagMapFunc: tagMapFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMapperFunc returns a new mapper which optionally obeys a field tag and
|
||||
// a struct field name mapper func given by f. Tags will take precedence, but
|
||||
// for any other field, the mapped name will be f(field.Name)
|
||||
func NewMapperFunc(tagName string, f func(string) string) *Mapper {
|
||||
return &Mapper{
|
||||
cache: make(map[reflect.Type]*StructMap),
|
||||
tagName: tagName,
|
||||
mapFunc: f,
|
||||
}
|
||||
}
|
||||
|
||||
// TypeMap returns a mapping of field strings to int slices representing
|
||||
// the traversal down the struct to reach the field.
|
||||
func (m *Mapper) TypeMap(t reflect.Type) *StructMap {
|
||||
m.mutex.Lock()
|
||||
mapping, ok := m.cache[t]
|
||||
if !ok {
|
||||
mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc)
|
||||
m.cache[t] = mapping
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return mapping
|
||||
}
|
||||
|
||||
// FieldMap returns the mapper's mapping of field names to reflect values. Panics
|
||||
// if v's Kind is not Struct, or v is not Indirectable to a struct kind.
|
||||
func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
mustBe(v, reflect.Struct)
|
||||
|
||||
r := map[string]reflect.Value{}
|
||||
tm := m.TypeMap(v.Type())
|
||||
for tagName, fi := range tm.Names {
|
||||
r[tagName] = FieldByIndexes(v, fi.Index)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// FieldByName returns a field by its mapped name as a reflect.Value.
|
||||
// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind.
|
||||
// Returns zero Value if the name is not found.
|
||||
func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
mustBe(v, reflect.Struct)
|
||||
|
||||
tm := m.TypeMap(v.Type())
|
||||
fi, ok := tm.Names[name]
|
||||
if !ok {
|
||||
return v
|
||||
}
|
||||
return FieldByIndexes(v, fi.Index)
|
||||
}
|
||||
|
||||
// FieldsByName returns a slice of values corresponding to the slice of names
|
||||
// for the value. Panics if v's Kind is not Struct or v is not Indirectable
|
||||
// to a struct Kind. Returns zero Value for each name not found.
|
||||
func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
mustBe(v, reflect.Struct)
|
||||
|
||||
tm := m.TypeMap(v.Type())
|
||||
vals := make([]reflect.Value, 0, len(names))
|
||||
for _, name := range names {
|
||||
fi, ok := tm.Names[name]
|
||||
if !ok {
|
||||
vals = append(vals, *new(reflect.Value))
|
||||
} else {
|
||||
vals = append(vals, FieldByIndexes(v, fi.Index))
|
||||
}
|
||||
}
|
||||
return vals
|
||||
}
|
||||
|
||||
// TraversalsByName returns a slice of int slices which represent the struct
|
||||
// traversals for each mapped name. Panics if t is not a struct or Indirectable
|
||||
// to a struct. Returns empty int slice for each name not found.
|
||||
func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int {
|
||||
t = Deref(t)
|
||||
mustBe(t, reflect.Struct)
|
||||
tm := m.TypeMap(t)
|
||||
|
||||
r := make([][]int, 0, len(names))
|
||||
for _, name := range names {
|
||||
fi, ok := tm.Names[name]
|
||||
if !ok {
|
||||
r = append(r, []int{})
|
||||
} else {
|
||||
r = append(r, fi.Index)
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// FieldByIndexes returns a value for the field given by the struct traversal
|
||||
// for the given value.
|
||||
func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
|
||||
for _, i := range indexes {
|
||||
v = reflect.Indirect(v).Field(i)
|
||||
// if this is a pointer and it's nil, allocate a new value and set it
|
||||
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
alloc := reflect.New(Deref(v.Type()))
|
||||
v.Set(alloc)
|
||||
}
|
||||
if v.Kind() == reflect.Map && v.IsNil() {
|
||||
v.Set(reflect.MakeMap(v.Type()))
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// FieldByIndexesReadOnly returns a value for a particular struct traversal,
|
||||
// but is not concerned with allocating nil pointers because the value is
|
||||
// going to be used for reading and not setting.
|
||||
func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value {
|
||||
for _, i := range indexes {
|
||||
v = reflect.Indirect(v).Field(i)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// Deref is Indirect for reflect.Types
|
||||
func Deref(t reflect.Type) reflect.Type {
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// -- helpers & utilities --
|
||||
|
||||
type kinder interface {
|
||||
Kind() reflect.Kind
|
||||
}
|
||||
|
||||
// mustBe checks a value against a kind, panicing with a reflect.ValueError
|
||||
// if the kind isn't that which is required.
|
||||
func mustBe(v kinder, expected reflect.Kind) {
|
||||
if k := v.Kind(); k != expected {
|
||||
panic(&reflect.ValueError{Method: methodName(), Kind: k})
|
||||
}
|
||||
}
|
||||
|
||||
// methodName returns the caller of the function calling methodName
|
||||
func methodName() string {
|
||||
pc, _, _, _ := runtime.Caller(2)
|
||||
f := runtime.FuncForPC(pc)
|
||||
if f == nil {
|
||||
return "unknown method"
|
||||
}
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
type typeQueue struct {
|
||||
t reflect.Type
|
||||
fi *FieldInfo
|
||||
pp string // Parent path
|
||||
}
|
||||
|
||||
// A copying append that creates a new slice each time.
|
||||
func apnd(is []int, i int) []int {
|
||||
x := make([]int, len(is)+1)
|
||||
for p, n := range is {
|
||||
x[p] = n
|
||||
}
|
||||
x[len(x)-1] = i
|
||||
return x
|
||||
}
|
||||
|
||||
type mapf func(string) string
|
||||
|
||||
// parseName parses the tag and the target name for the given field using
|
||||
// the tagName (eg 'json' for `json:"foo"` tags), mapFunc for mapping the
|
||||
// field's name to a target name, and tagMapFunc for mapping the tag to
|
||||
// a target name.
|
||||
func parseName(field reflect.StructField, tagName string, mapFunc, tagMapFunc mapf) (tag, fieldName string) {
|
||||
// first, set the fieldName to the field's name
|
||||
fieldName = field.Name
|
||||
// if a mapFunc is set, use that to override the fieldName
|
||||
if mapFunc != nil {
|
||||
fieldName = mapFunc(fieldName)
|
||||
}
|
||||
|
||||
// if there's no tag to look for, return the field name
|
||||
if tagName == "" {
|
||||
return "", fieldName
|
||||
}
|
||||
|
||||
// if this tag is not set using the normal convention in the tag,
|
||||
// then return the fieldname.. this check is done because according
|
||||
// to the reflect documentation:
|
||||
// If the tag does not have the conventional format,
|
||||
// the value returned by Get is unspecified.
|
||||
// which doesn't sound great.
|
||||
if !strings.Contains(string(field.Tag), tagName+":") {
|
||||
return "", fieldName
|
||||
}
|
||||
|
||||
// at this point we're fairly sure that we have a tag, so lets pull it out
|
||||
tag = field.Tag.Get(tagName)
|
||||
|
||||
// if we have a mapper function, call it on the whole tag
|
||||
// XXX: this is a change from the old version, which pulled out the name
|
||||
// before the tagMapFunc could be run, but I think this is the right way
|
||||
if tagMapFunc != nil {
|
||||
tag = tagMapFunc(tag)
|
||||
}
|
||||
|
||||
// finally, split the options from the name
|
||||
parts := strings.Split(tag, ",")
|
||||
fieldName = parts[0]
|
||||
|
||||
return tag, fieldName
|
||||
}
|
||||
|
||||
// parseOptions parses options out of a tag string, skipping the name
|
||||
func parseOptions(tag string) map[string]string {
|
||||
parts := strings.Split(tag, ",")
|
||||
options := make(map[string]string, len(parts))
|
||||
if len(parts) > 1 {
|
||||
for _, opt := range parts[1:] {
|
||||
// short circuit potentially expensive split op
|
||||
if strings.Contains(opt, "=") {
|
||||
kv := strings.Split(opt, "=")
|
||||
options[kv[0]] = kv[1]
|
||||
continue
|
||||
}
|
||||
options[opt] = ""
|
||||
}
|
||||
}
|
||||
return options
|
||||
}
|
||||
|
||||
// getMapping returns a mapping for the t type, using the tagName, mapFunc and
|
||||
// tagMapFunc to determine the canonical names of fields.
|
||||
func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc mapf) *StructMap {
|
||||
m := []*FieldInfo{}
|
||||
|
||||
root := &FieldInfo{}
|
||||
queue := []typeQueue{}
|
||||
queue = append(queue, typeQueue{Deref(t), root, ""})
|
||||
|
||||
QueueLoop:
|
||||
for len(queue) != 0 {
|
||||
// pop the first item off of the queue
|
||||
tq := queue[0]
|
||||
queue = queue[1:]
|
||||
|
||||
// ignore recursive field
|
||||
for p := tq.fi.Parent; p != nil; p = p.Parent {
|
||||
if tq.fi.Field.Type == p.Field.Type {
|
||||
continue QueueLoop
|
||||
}
|
||||
}
|
||||
|
||||
nChildren := 0
|
||||
if tq.t.Kind() == reflect.Struct {
|
||||
nChildren = tq.t.NumField()
|
||||
}
|
||||
tq.fi.Children = make([]*FieldInfo, nChildren)
|
||||
|
||||
// iterate through all of its fields
|
||||
for fieldPos := 0; fieldPos < nChildren; fieldPos++ {
|
||||
|
||||
f := tq.t.Field(fieldPos)
|
||||
|
||||
// parse the tag and the target name using the mapping options for this field
|
||||
tag, name := parseName(f, tagName, mapFunc, tagMapFunc)
|
||||
|
||||
// if the name is "-", disabled via a tag, skip it
|
||||
if name == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
fi := FieldInfo{
|
||||
Field: f,
|
||||
Name: name,
|
||||
Zero: reflect.New(f.Type).Elem(),
|
||||
Options: parseOptions(tag),
|
||||
}
|
||||
|
||||
// if the path is empty this path is just the name
|
||||
if tq.pp == "" {
|
||||
fi.Path = fi.Name
|
||||
} else {
|
||||
fi.Path = tq.pp + "." + fi.Name
|
||||
}
|
||||
|
||||
// skip unexported fields
|
||||
if len(f.PkgPath) != 0 && !f.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
// bfs search of anonymous embedded structs
|
||||
if f.Anonymous {
|
||||
pp := tq.pp
|
||||
if tag != "" {
|
||||
pp = fi.Path
|
||||
}
|
||||
|
||||
fi.Embedded = true
|
||||
fi.Index = apnd(tq.fi.Index, fieldPos)
|
||||
nChildren := 0
|
||||
ft := Deref(f.Type)
|
||||
if ft.Kind() == reflect.Struct {
|
||||
nChildren = ft.NumField()
|
||||
}
|
||||
fi.Children = make([]*FieldInfo, nChildren)
|
||||
queue = append(queue, typeQueue{Deref(f.Type), &fi, pp})
|
||||
} else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) {
|
||||
fi.Index = apnd(tq.fi.Index, fieldPos)
|
||||
fi.Children = make([]*FieldInfo, Deref(f.Type).NumField())
|
||||
queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path})
|
||||
}
|
||||
|
||||
fi.Index = apnd(tq.fi.Index, fieldPos)
|
||||
fi.Parent = tq.fi
|
||||
tq.fi.Children[fieldPos] = &fi
|
||||
m = append(m, &fi)
|
||||
}
|
||||
}
|
||||
|
||||
flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}}
|
||||
for _, fi := range flds.Index {
|
||||
flds.Paths[fi.Path] = fi
|
||||
if fi.Name != "" && !fi.Embedded {
|
||||
flds.Names[fi.Path] = fi
|
||||
}
|
||||
}
|
||||
|
||||
return flds
|
||||
}
|
|
@ -0,0 +1,905 @@
|
|||
package reflectx
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func ival(v reflect.Value) int {
|
||||
return v.Interface().(int)
|
||||
}
|
||||
|
||||
func TestBasic(t *testing.T) {
|
||||
type Foo struct {
|
||||
A int
|
||||
B int
|
||||
C int
|
||||
}
|
||||
|
||||
f := Foo{1, 2, 3}
|
||||
fv := reflect.ValueOf(f)
|
||||
m := NewMapperFunc("", func(s string) string { return s })
|
||||
|
||||
v := m.FieldByName(fv, "A")
|
||||
if ival(v) != f.A {
|
||||
t.Errorf("Expecting %d, got %d", ival(v), f.A)
|
||||
}
|
||||
v = m.FieldByName(fv, "B")
|
||||
if ival(v) != f.B {
|
||||
t.Errorf("Expecting %d, got %d", f.B, ival(v))
|
||||
}
|
||||
v = m.FieldByName(fv, "C")
|
||||
if ival(v) != f.C {
|
||||
t.Errorf("Expecting %d, got %d", f.C, ival(v))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicEmbedded(t *testing.T) {
|
||||
type Foo struct {
|
||||
A int
|
||||
}
|
||||
|
||||
type Bar struct {
|
||||
Foo // `db:""` is implied for an embedded struct
|
||||
B int
|
||||
C int `db:"-"`
|
||||
}
|
||||
|
||||
type Baz struct {
|
||||
A int
|
||||
Bar `db:"Bar"`
|
||||
}
|
||||
|
||||
m := NewMapperFunc("db", func(s string) string { return s })
|
||||
|
||||
z := Baz{}
|
||||
z.A = 1
|
||||
z.B = 2
|
||||
z.C = 4
|
||||
z.Bar.Foo.A = 3
|
||||
|
||||
zv := reflect.ValueOf(z)
|
||||
fields := m.TypeMap(reflect.TypeOf(z))
|
||||
|
||||
if len(fields.Index) != 5 {
|
||||
t.Errorf("Expecting 5 fields")
|
||||
}
|
||||
|
||||
// for _, fi := range fields.Index {
|
||||
// log.Println(fi)
|
||||
// }
|
||||
|
||||
v := m.FieldByName(zv, "A")
|
||||
if ival(v) != z.A {
|
||||
t.Errorf("Expecting %d, got %d", z.A, ival(v))
|
||||
}
|
||||
v = m.FieldByName(zv, "Bar.B")
|
||||
if ival(v) != z.Bar.B {
|
||||
t.Errorf("Expecting %d, got %d", z.Bar.B, ival(v))
|
||||
}
|
||||
v = m.FieldByName(zv, "Bar.A")
|
||||
if ival(v) != z.Bar.Foo.A {
|
||||
t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v))
|
||||
}
|
||||
v = m.FieldByName(zv, "Bar.C")
|
||||
if _, ok := v.Interface().(int); ok {
|
||||
t.Errorf("Expecting Bar.C to not exist")
|
||||
}
|
||||
|
||||
fi := fields.GetByPath("Bar.C")
|
||||
if fi != nil {
|
||||
t.Errorf("Bar.C should not exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedSimple(t *testing.T) {
|
||||
type UUID [16]byte
|
||||
type MyID struct {
|
||||
UUID
|
||||
}
|
||||
type Item struct {
|
||||
ID MyID
|
||||
}
|
||||
z := Item{}
|
||||
|
||||
m := NewMapper("db")
|
||||
m.TypeMap(reflect.TypeOf(z))
|
||||
}
|
||||
|
||||
func TestBasicEmbeddedWithTags(t *testing.T) {
|
||||
type Foo struct {
|
||||
A int `db:"a"`
|
||||
}
|
||||
|
||||
type Bar struct {
|
||||
Foo // `db:""` is implied for an embedded struct
|
||||
B int `db:"b"`
|
||||
}
|
||||
|
||||
type Baz struct {
|
||||
A int `db:"a"`
|
||||
Bar // `db:""` is implied for an embedded struct
|
||||
}
|
||||
|
||||
m := NewMapper("db")
|
||||
|
||||
z := Baz{}
|
||||
z.A = 1
|
||||
z.B = 2
|
||||
z.Bar.Foo.A = 3
|
||||
|
||||
zv := reflect.ValueOf(z)
|
||||
fields := m.TypeMap(reflect.TypeOf(z))
|
||||
|
||||
if len(fields.Index) != 5 {
|
||||
t.Errorf("Expecting 5 fields")
|
||||
}
|
||||
|
||||
// for _, fi := range fields.index {
|
||||
// log.Println(fi)
|
||||
// }
|
||||
|
||||
v := m.FieldByName(zv, "a")
|
||||
if ival(v) != z.Bar.Foo.A { // the dominant field
|
||||
t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v))
|
||||
}
|
||||
v = m.FieldByName(zv, "b")
|
||||
if ival(v) != z.B {
|
||||
t.Errorf("Expecting %d, got %d", z.B, ival(v))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlatTags(t *testing.T) {
|
||||
m := NewMapper("db")
|
||||
|
||||
type Asset struct {
|
||||
Title string `db:"title"`
|
||||
}
|
||||
type Post struct {
|
||||
Author string `db:"author,required"`
|
||||
Asset Asset `db:""`
|
||||
}
|
||||
// Post columns: (author title)
|
||||
|
||||
post := Post{Author: "Joe", Asset: Asset{Title: "Hello"}}
|
||||
pv := reflect.ValueOf(post)
|
||||
|
||||
v := m.FieldByName(pv, "author")
|
||||
if v.Interface().(string) != post.Author {
|
||||
t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(pv, "title")
|
||||
if v.Interface().(string) != post.Asset.Title {
|
||||
t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedStruct(t *testing.T) {
|
||||
m := NewMapper("db")
|
||||
|
||||
type Details struct {
|
||||
Active bool `db:"active"`
|
||||
}
|
||||
type Asset struct {
|
||||
Title string `db:"title"`
|
||||
Details Details `db:"details"`
|
||||
}
|
||||
type Post struct {
|
||||
Author string `db:"author,required"`
|
||||
Asset `db:"asset"`
|
||||
}
|
||||
// Post columns: (author asset.title asset.details.active)
|
||||
|
||||
post := Post{
|
||||
Author: "Joe",
|
||||
Asset: Asset{Title: "Hello", Details: Details{Active: true}},
|
||||
}
|
||||
pv := reflect.ValueOf(post)
|
||||
|
||||
v := m.FieldByName(pv, "author")
|
||||
if v.Interface().(string) != post.Author {
|
||||
t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(pv, "title")
|
||||
if _, ok := v.Interface().(string); ok {
|
||||
t.Errorf("Expecting field to not exist")
|
||||
}
|
||||
v = m.FieldByName(pv, "asset.title")
|
||||
if v.Interface().(string) != post.Asset.Title {
|
||||
t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(pv, "asset.details.active")
|
||||
if v.Interface().(bool) != post.Asset.Details.Active {
|
||||
t.Errorf("Expecting %v, got %v", post.Asset.Details.Active, v.Interface().(bool))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInlineStruct(t *testing.T) {
|
||||
m := NewMapperTagFunc("db", strings.ToLower, nil)
|
||||
|
||||
type Employee struct {
|
||||
Name string
|
||||
ID int
|
||||
}
|
||||
type Boss Employee
|
||||
type person struct {
|
||||
Employee `db:"employee"`
|
||||
Boss `db:"boss"`
|
||||
}
|
||||
// employees columns: (employee.name employee.id boss.name boss.id)
|
||||
|
||||
em := person{Employee: Employee{Name: "Joe", ID: 2}, Boss: Boss{Name: "Dick", ID: 1}}
|
||||
ev := reflect.ValueOf(em)
|
||||
|
||||
fields := m.TypeMap(reflect.TypeOf(em))
|
||||
if len(fields.Index) != 6 {
|
||||
t.Errorf("Expecting 6 fields")
|
||||
}
|
||||
|
||||
v := m.FieldByName(ev, "employee.name")
|
||||
if v.Interface().(string) != em.Employee.Name {
|
||||
t.Errorf("Expecting %s, got %s", em.Employee.Name, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(ev, "boss.id")
|
||||
if ival(v) != em.Boss.ID {
|
||||
t.Errorf("Expecting %v, got %v", em.Boss.ID, ival(v))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecursiveStruct(t *testing.T) {
|
||||
type Person struct {
|
||||
Parent *Person
|
||||
}
|
||||
m := NewMapperFunc("db", strings.ToLower)
|
||||
var p *Person
|
||||
m.TypeMap(reflect.TypeOf(p))
|
||||
}
|
||||
|
||||
func TestFieldsEmbedded(t *testing.T) {
|
||||
m := NewMapper("db")
|
||||
|
||||
type Person struct {
|
||||
Name string `db:"name,size=64"`
|
||||
}
|
||||
type Place struct {
|
||||
Name string `db:"name"`
|
||||
}
|
||||
type Article struct {
|
||||
Title string `db:"title"`
|
||||
}
|
||||
type PP struct {
|
||||
Person `db:"person,required"`
|
||||
Place `db:",someflag"`
|
||||
Article `db:",required"`
|
||||
}
|
||||
// PP columns: (person.name name title)
|
||||
|
||||
pp := PP{}
|
||||
pp.Person.Name = "Peter"
|
||||
pp.Place.Name = "Toronto"
|
||||
pp.Article.Title = "Best city ever"
|
||||
|
||||
fields := m.TypeMap(reflect.TypeOf(pp))
|
||||
// for i, f := range fields {
|
||||
// log.Println(i, f)
|
||||
// }
|
||||
|
||||
ppv := reflect.ValueOf(pp)
|
||||
|
||||
v := m.FieldByName(ppv, "person.name")
|
||||
if v.Interface().(string) != pp.Person.Name {
|
||||
t.Errorf("Expecting %s, got %s", pp.Person.Name, v.Interface().(string))
|
||||
}
|
||||
|
||||
v = m.FieldByName(ppv, "name")
|
||||
if v.Interface().(string) != pp.Place.Name {
|
||||
t.Errorf("Expecting %s, got %s", pp.Place.Name, v.Interface().(string))
|
||||
}
|
||||
|
||||
v = m.FieldByName(ppv, "title")
|
||||
if v.Interface().(string) != pp.Article.Title {
|
||||
t.Errorf("Expecting %s, got %s", pp.Article.Title, v.Interface().(string))
|
||||
}
|
||||
|
||||
fi := fields.GetByPath("person")
|
||||
if _, ok := fi.Options["required"]; !ok {
|
||||
t.Errorf("Expecting required option to be set")
|
||||
}
|
||||
if !fi.Embedded {
|
||||
t.Errorf("Expecting field to be embedded")
|
||||
}
|
||||
if len(fi.Index) != 1 || fi.Index[0] != 0 {
|
||||
t.Errorf("Expecting index to be [0]")
|
||||
}
|
||||
|
||||
fi = fields.GetByPath("person.name")
|
||||
if fi == nil {
|
||||
t.Errorf("Expecting person.name to exist")
|
||||
}
|
||||
if fi.Path != "person.name" {
|
||||
t.Errorf("Expecting %s, got %s", "person.name", fi.Path)
|
||||
}
|
||||
if fi.Options["size"] != "64" {
|
||||
t.Errorf("Expecting %s, got %s", "64", fi.Options["size"])
|
||||
}
|
||||
|
||||
fi = fields.GetByTraversal([]int{1, 0})
|
||||
if fi == nil {
|
||||
t.Errorf("Expecting traveral to exist")
|
||||
}
|
||||
if fi.Path != "name" {
|
||||
t.Errorf("Expecting %s, got %s", "name", fi.Path)
|
||||
}
|
||||
|
||||
fi = fields.GetByTraversal([]int{2})
|
||||
if fi == nil {
|
||||
t.Errorf("Expecting traversal to exist")
|
||||
}
|
||||
if _, ok := fi.Options["required"]; !ok {
|
||||
t.Errorf("Expecting required option to be set")
|
||||
}
|
||||
|
||||
trs := m.TraversalsByName(reflect.TypeOf(pp), []string{"person.name", "name", "title"})
|
||||
if !reflect.DeepEqual(trs, [][]int{{0, 0}, {1, 0}, {2, 0}}) {
|
||||
t.Errorf("Expecting traversal: %v", trs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPtrFields(t *testing.T) {
|
||||
m := NewMapperTagFunc("db", strings.ToLower, nil)
|
||||
type Asset struct {
|
||||
Title string
|
||||
}
|
||||
type Post struct {
|
||||
*Asset `db:"asset"`
|
||||
Author string
|
||||
}
|
||||
|
||||
post := &Post{Author: "Joe", Asset: &Asset{Title: "Hiyo"}}
|
||||
pv := reflect.ValueOf(post)
|
||||
|
||||
fields := m.TypeMap(reflect.TypeOf(post))
|
||||
if len(fields.Index) != 3 {
|
||||
t.Errorf("Expecting 3 fields")
|
||||
}
|
||||
|
||||
v := m.FieldByName(pv, "asset.title")
|
||||
if v.Interface().(string) != post.Asset.Title {
|
||||
t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(pv, "author")
|
||||
if v.Interface().(string) != post.Author {
|
||||
t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamedPtrFields(t *testing.T) {
|
||||
m := NewMapperTagFunc("db", strings.ToLower, nil)
|
||||
|
||||
type User struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
type Asset struct {
|
||||
Title string
|
||||
|
||||
Owner *User `db:"owner"`
|
||||
}
|
||||
type Post struct {
|
||||
Author string
|
||||
|
||||
Asset1 *Asset `db:"asset1"`
|
||||
Asset2 *Asset `db:"asset2"`
|
||||
}
|
||||
|
||||
post := &Post{Author: "Joe", Asset1: &Asset{Title: "Hiyo", Owner: &User{"Username"}}} // Let Asset2 be nil
|
||||
pv := reflect.ValueOf(post)
|
||||
|
||||
fields := m.TypeMap(reflect.TypeOf(post))
|
||||
if len(fields.Index) != 9 {
|
||||
t.Errorf("Expecting 9 fields")
|
||||
}
|
||||
|
||||
v := m.FieldByName(pv, "asset1.title")
|
||||
if v.Interface().(string) != post.Asset1.Title {
|
||||
t.Errorf("Expecting %s, got %s", post.Asset1.Title, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(pv, "asset1.owner.name")
|
||||
if v.Interface().(string) != post.Asset1.Owner.Name {
|
||||
t.Errorf("Expecting %s, got %s", post.Asset1.Owner.Name, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(pv, "asset2.title")
|
||||
if v.Interface().(string) != post.Asset2.Title {
|
||||
t.Errorf("Expecting %s, got %s", post.Asset2.Title, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(pv, "asset2.owner.name")
|
||||
if v.Interface().(string) != post.Asset2.Owner.Name {
|
||||
t.Errorf("Expecting %s, got %s", post.Asset2.Owner.Name, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(pv, "author")
|
||||
if v.Interface().(string) != post.Author {
|
||||
t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldMap(t *testing.T) {
|
||||
type Foo struct {
|
||||
A int
|
||||
B int
|
||||
C int
|
||||
}
|
||||
|
||||
f := Foo{1, 2, 3}
|
||||
m := NewMapperFunc("db", strings.ToLower)
|
||||
|
||||
fm := m.FieldMap(reflect.ValueOf(f))
|
||||
|
||||
if len(fm) != 3 {
|
||||
t.Errorf("Expecting %d keys, got %d", 3, len(fm))
|
||||
}
|
||||
if fm["a"].Interface().(int) != 1 {
|
||||
t.Errorf("Expecting %d, got %d", 1, ival(fm["a"]))
|
||||
}
|
||||
if fm["b"].Interface().(int) != 2 {
|
||||
t.Errorf("Expecting %d, got %d", 2, ival(fm["b"]))
|
||||
}
|
||||
if fm["c"].Interface().(int) != 3 {
|
||||
t.Errorf("Expecting %d, got %d", 3, ival(fm["c"]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagNameMapping(t *testing.T) {
|
||||
type Strategy struct {
|
||||
StrategyID string `protobuf:"bytes,1,opt,name=strategy_id" json:"strategy_id,omitempty"`
|
||||
StrategyName string
|
||||
}
|
||||
|
||||
m := NewMapperTagFunc("json", strings.ToUpper, func(value string) string {
|
||||
if strings.Contains(value, ",") {
|
||||
return strings.Split(value, ",")[0]
|
||||
}
|
||||
return value
|
||||
})
|
||||
strategy := Strategy{"1", "Alpah"}
|
||||
mapping := m.TypeMap(reflect.TypeOf(strategy))
|
||||
|
||||
for _, key := range []string{"strategy_id", "STRATEGYNAME"} {
|
||||
if fi := mapping.GetByPath(key); fi == nil {
|
||||
t.Errorf("Expecting to find key %s in mapping but did not.", key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMapping(t *testing.T) {
|
||||
type Person struct {
|
||||
ID int
|
||||
Name string
|
||||
WearsGlasses bool `db:"wears_glasses"`
|
||||
}
|
||||
|
||||
m := NewMapperFunc("db", strings.ToLower)
|
||||
p := Person{1, "Jason", true}
|
||||
mapping := m.TypeMap(reflect.TypeOf(p))
|
||||
|
||||
for _, key := range []string{"id", "name", "wears_glasses"} {
|
||||
if fi := mapping.GetByPath(key); fi == nil {
|
||||
t.Errorf("Expecting to find key %s in mapping but did not.", key)
|
||||
}
|
||||
}
|
||||
|
||||
type SportsPerson struct {
|
||||
Weight int
|
||||
Age int
|
||||
Person
|
||||
}
|
||||
s := SportsPerson{Weight: 100, Age: 30, Person: p}
|
||||
mapping = m.TypeMap(reflect.TypeOf(s))
|
||||
for _, key := range []string{"id", "name", "wears_glasses", "weight", "age"} {
|
||||
if fi := mapping.GetByPath(key); fi == nil {
|
||||
t.Errorf("Expecting to find key %s in mapping but did not.", key)
|
||||
}
|
||||
}
|
||||
|
||||
type RugbyPlayer struct {
|
||||
Position int
|
||||
IsIntense bool `db:"is_intense"`
|
||||
IsAllBlack bool `db:"-"`
|
||||
SportsPerson
|
||||
}
|
||||
r := RugbyPlayer{12, true, false, s}
|
||||
mapping = m.TypeMap(reflect.TypeOf(r))
|
||||
for _, key := range []string{"id", "name", "wears_glasses", "weight", "age", "position", "is_intense"} {
|
||||
if fi := mapping.GetByPath(key); fi == nil {
|
||||
t.Errorf("Expecting to find key %s in mapping but did not.", key)
|
||||
}
|
||||
}
|
||||
|
||||
if fi := mapping.GetByPath("isallblack"); fi != nil {
|
||||
t.Errorf("Expecting to ignore `IsAllBlack` field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetByTraversal(t *testing.T) {
|
||||
type C struct {
|
||||
C0 int
|
||||
C1 int
|
||||
}
|
||||
type B struct {
|
||||
B0 string
|
||||
B1 *C
|
||||
}
|
||||
type A struct {
|
||||
A0 int
|
||||
A1 B
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
Index []int
|
||||
ExpectedName string
|
||||
ExpectNil bool
|
||||
}{
|
||||
{
|
||||
Index: []int{0},
|
||||
ExpectedName: "A0",
|
||||
},
|
||||
{
|
||||
Index: []int{1, 0},
|
||||
ExpectedName: "B0",
|
||||
},
|
||||
{
|
||||
Index: []int{1, 1, 1},
|
||||
ExpectedName: "C1",
|
||||
},
|
||||
{
|
||||
Index: []int{3, 4, 5},
|
||||
ExpectNil: true,
|
||||
},
|
||||
{
|
||||
Index: []int{},
|
||||
ExpectNil: true,
|
||||
},
|
||||
{
|
||||
Index: nil,
|
||||
ExpectNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
m := NewMapperFunc("db", func(n string) string { return n })
|
||||
tm := m.TypeMap(reflect.TypeOf(A{}))
|
||||
|
||||
for i, tc := range testCases {
|
||||
fi := tm.GetByTraversal(tc.Index)
|
||||
if tc.ExpectNil {
|
||||
if fi != nil {
|
||||
t.Errorf("%d: expected nil, got %v", i, fi)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if fi == nil {
|
||||
t.Errorf("%d: expected %s, got nil", i, tc.ExpectedName)
|
||||
continue
|
||||
}
|
||||
|
||||
if fi.Name != tc.ExpectedName {
|
||||
t.Errorf("%d: expected %s, got %s", i, tc.ExpectedName, fi.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestMapperMethodsByName tests Mapper methods FieldByName and TraversalsByName
|
||||
func TestMapperMethodsByName(t *testing.T) {
|
||||
type C struct {
|
||||
C0 string
|
||||
C1 int
|
||||
}
|
||||
type B struct {
|
||||
B0 *C `db:"B0"`
|
||||
B1 C `db:"B1"`
|
||||
B2 string `db:"B2"`
|
||||
}
|
||||
type A struct {
|
||||
A0 *B `db:"A0"`
|
||||
B `db:"A1"`
|
||||
A2 int
|
||||
a3 int
|
||||
}
|
||||
|
||||
val := &A{
|
||||
A0: &B{
|
||||
B0: &C{C0: "0", C1: 1},
|
||||
B1: C{C0: "2", C1: 3},
|
||||
B2: "4",
|
||||
},
|
||||
B: B{
|
||||
B0: nil,
|
||||
B1: C{C0: "5", C1: 6},
|
||||
B2: "7",
|
||||
},
|
||||
A2: 8,
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
Name string
|
||||
ExpectInvalid bool
|
||||
ExpectedValue interface{}
|
||||
ExpectedIndexes []int
|
||||
}{
|
||||
{
|
||||
Name: "A0.B0.C0",
|
||||
ExpectedValue: "0",
|
||||
ExpectedIndexes: []int{0, 0, 0},
|
||||
},
|
||||
{
|
||||
Name: "A0.B0.C1",
|
||||
ExpectedValue: 1,
|
||||
ExpectedIndexes: []int{0, 0, 1},
|
||||
},
|
||||
{
|
||||
Name: "A0.B1.C0",
|
||||
ExpectedValue: "2",
|
||||
ExpectedIndexes: []int{0, 1, 0},
|
||||
},
|
||||
{
|
||||
Name: "A0.B1.C1",
|
||||
ExpectedValue: 3,
|
||||
ExpectedIndexes: []int{0, 1, 1},
|
||||
},
|
||||
{
|
||||
Name: "A0.B2",
|
||||
ExpectedValue: "4",
|
||||
ExpectedIndexes: []int{0, 2},
|
||||
},
|
||||
{
|
||||
Name: "A1.B0.C0",
|
||||
ExpectedValue: "",
|
||||
ExpectedIndexes: []int{1, 0, 0},
|
||||
},
|
||||
{
|
||||
Name: "A1.B0.C1",
|
||||
ExpectedValue: 0,
|
||||
ExpectedIndexes: []int{1, 0, 1},
|
||||
},
|
||||
{
|
||||
Name: "A1.B1.C0",
|
||||
ExpectedValue: "5",
|
||||
ExpectedIndexes: []int{1, 1, 0},
|
||||
},
|
||||
{
|
||||
Name: "A1.B1.C1",
|
||||
ExpectedValue: 6,
|
||||
ExpectedIndexes: []int{1, 1, 1},
|
||||
},
|
||||
{
|
||||
Name: "A1.B2",
|
||||
ExpectedValue: "7",
|
||||
ExpectedIndexes: []int{1, 2},
|
||||
},
|
||||
{
|
||||
Name: "A2",
|
||||
ExpectedValue: 8,
|
||||
ExpectedIndexes: []int{2},
|
||||
},
|
||||
{
|
||||
Name: "XYZ",
|
||||
ExpectInvalid: true,
|
||||
ExpectedIndexes: []int{},
|
||||
},
|
||||
{
|
||||
Name: "a3",
|
||||
ExpectInvalid: true,
|
||||
ExpectedIndexes: []int{},
|
||||
},
|
||||
}
|
||||
|
||||
// build the names array from the test cases
|
||||
names := make([]string, len(testCases))
|
||||
for i, tc := range testCases {
|
||||
names[i] = tc.Name
|
||||
}
|
||||
m := NewMapperFunc("db", func(n string) string { return n })
|
||||
v := reflect.ValueOf(val)
|
||||
values := m.FieldsByName(v, names)
|
||||
if len(values) != len(testCases) {
|
||||
t.Errorf("expected %d values, got %d", len(testCases), len(values))
|
||||
t.FailNow()
|
||||
}
|
||||
indexes := m.TraversalsByName(v.Type(), names)
|
||||
if len(indexes) != len(testCases) {
|
||||
t.Errorf("expected %d traversals, got %d", len(testCases), len(indexes))
|
||||
t.FailNow()
|
||||
}
|
||||
for i, val := range values {
|
||||
tc := testCases[i]
|
||||
traversal := indexes[i]
|
||||
if !reflect.DeepEqual(tc.ExpectedIndexes, traversal) {
|
||||
t.Errorf("expected %v, got %v", tc.ExpectedIndexes, traversal)
|
||||
t.FailNow()
|
||||
}
|
||||
val = reflect.Indirect(val)
|
||||
if tc.ExpectInvalid {
|
||||
if val.IsValid() {
|
||||
t.Errorf("%d: expected zero value, got %v", i, val)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !val.IsValid() {
|
||||
t.Errorf("%d: expected valid value, got %v", i, val)
|
||||
continue
|
||||
}
|
||||
actualValue := reflect.Indirect(val).Interface()
|
||||
if !reflect.DeepEqual(tc.ExpectedValue, actualValue) {
|
||||
t.Errorf("%d: expected %v, got %v", i, tc.ExpectedValue, actualValue)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFieldByIndexes(t *testing.T) {
|
||||
type C struct {
|
||||
C0 bool
|
||||
C1 string
|
||||
C2 int
|
||||
C3 map[string]int
|
||||
}
|
||||
type B struct {
|
||||
B1 C
|
||||
B2 *C
|
||||
}
|
||||
type A struct {
|
||||
A1 B
|
||||
A2 *B
|
||||
}
|
||||
testCases := []struct {
|
||||
value interface{}
|
||||
indexes []int
|
||||
expectedValue interface{}
|
||||
readOnly bool
|
||||
}{
|
||||
{
|
||||
value: A{
|
||||
A1: B{B1: C{C0: true}},
|
||||
},
|
||||
indexes: []int{0, 0, 0},
|
||||
expectedValue: true,
|
||||
readOnly: true,
|
||||
},
|
||||
{
|
||||
value: A{
|
||||
A2: &B{B2: &C{C1: "answer"}},
|
||||
},
|
||||
indexes: []int{1, 1, 1},
|
||||
expectedValue: "answer",
|
||||
readOnly: true,
|
||||
},
|
||||
{
|
||||
value: &A{},
|
||||
indexes: []int{1, 1, 3},
|
||||
expectedValue: map[string]int{},
|
||||
},
|
||||
}
|
||||
|
||||
for i, tc := range testCases {
|
||||
checkResults := func(v reflect.Value) {
|
||||
if tc.expectedValue == nil {
|
||||
if !v.IsNil() {
|
||||
t.Errorf("%d: expected nil, actual %v", i, v.Interface())
|
||||
}
|
||||
} else {
|
||||
if !reflect.DeepEqual(tc.expectedValue, v.Interface()) {
|
||||
t.Errorf("%d: expected %v, actual %v", i, tc.expectedValue, v.Interface())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
checkResults(FieldByIndexes(reflect.ValueOf(tc.value), tc.indexes))
|
||||
if tc.readOnly {
|
||||
checkResults(FieldByIndexesReadOnly(reflect.ValueOf(tc.value), tc.indexes))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMustBe(t *testing.T) {
|
||||
typ := reflect.TypeOf(E1{})
|
||||
mustBe(typ, reflect.Struct)
|
||||
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
valueErr, ok := r.(*reflect.ValueError)
|
||||
if !ok {
|
||||
t.Errorf("unexpected Method: %s", valueErr.Method)
|
||||
t.Error("expected panic with *reflect.ValueError")
|
||||
return
|
||||
}
|
||||
if valueErr.Method != "github.com/jmoiron/sqlx/reflectx.TestMustBe" {
|
||||
}
|
||||
if valueErr.Kind != reflect.String {
|
||||
t.Errorf("unexpected Kind: %s", valueErr.Kind)
|
||||
}
|
||||
} else {
|
||||
t.Error("expected panic")
|
||||
}
|
||||
}()
|
||||
|
||||
typ = reflect.TypeOf("string")
|
||||
mustBe(typ, reflect.Struct)
|
||||
t.Error("got here, didn't expect to")
|
||||
}
|
||||
|
||||
type E1 struct {
|
||||
A int
|
||||
}
|
||||
type E2 struct {
|
||||
E1
|
||||
B int
|
||||
}
|
||||
type E3 struct {
|
||||
E2
|
||||
C int
|
||||
}
|
||||
type E4 struct {
|
||||
E3
|
||||
D int
|
||||
}
|
||||
|
||||
func BenchmarkFieldNameL1(b *testing.B) {
|
||||
e4 := E4{D: 1}
|
||||
for i := 0; i < b.N; i++ {
|
||||
v := reflect.ValueOf(e4)
|
||||
f := v.FieldByName("D")
|
||||
if f.Interface().(int) != 1 {
|
||||
b.Fatal("Wrong value.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFieldNameL4(b *testing.B) {
|
||||
e4 := E4{}
|
||||
e4.A = 1
|
||||
for i := 0; i < b.N; i++ {
|
||||
v := reflect.ValueOf(e4)
|
||||
f := v.FieldByName("A")
|
||||
if f.Interface().(int) != 1 {
|
||||
b.Fatal("Wrong value.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFieldPosL1(b *testing.B) {
|
||||
e4 := E4{D: 1}
|
||||
for i := 0; i < b.N; i++ {
|
||||
v := reflect.ValueOf(e4)
|
||||
f := v.Field(1)
|
||||
if f.Interface().(int) != 1 {
|
||||
b.Fatal("Wrong value.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFieldPosL4(b *testing.B) {
|
||||
e4 := E4{}
|
||||
e4.A = 1
|
||||
for i := 0; i < b.N; i++ {
|
||||
v := reflect.ValueOf(e4)
|
||||
f := v.Field(0)
|
||||
f = f.Field(0)
|
||||
f = f.Field(0)
|
||||
f = f.Field(0)
|
||||
if f.Interface().(int) != 1 {
|
||||
b.Fatal("Wrong value.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkFieldByIndexL4(b *testing.B) {
|
||||
e4 := E4{}
|
||||
e4.A = 1
|
||||
idx := []int{0, 0, 0, 0}
|
||||
for i := 0; i < b.N; i++ {
|
||||
v := reflect.ValueOf(e4)
|
||||
f := FieldByIndexes(v, idx)
|
||||
if f.Interface().(int) != 1 {
|
||||
b.Fatal("Wrong value.")
|
||||
}
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,335 @@
|
|||
// +build go1.8
|
||||
|
||||
package sqlx
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// ConnectContext to a database and verify with a ping.
|
||||
func ConnectContext(ctx context.Context, driverName, dataSourceName string) (*DB, error) {
|
||||
db, err := Open(driverName, dataSourceName)
|
||||
if err != nil {
|
||||
return db, err
|
||||
}
|
||||
err = db.PingContext(ctx)
|
||||
return db, err
|
||||
}
|
||||
|
||||
// QueryerContext is an interface used by GetContext and SelectContext
|
||||
type QueryerContext interface {
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error)
|
||||
QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row
|
||||
}
|
||||
|
||||
// PreparerContext is an interface used by PreparexContext.
|
||||
type PreparerContext interface {
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
}
|
||||
|
||||
// ExecerContext is an interface used by MustExecContext and LoadFileContext
|
||||
type ExecerContext interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
}
|
||||
|
||||
// ExtContext is a union interface which can bind, query, and exec, with Context
|
||||
// used by NamedQueryContext and NamedExecContext.
|
||||
type ExtContext interface {
|
||||
binder
|
||||
QueryerContext
|
||||
ExecerContext
|
||||
}
|
||||
|
||||
// SelectContext executes a query using the provided Queryer, and StructScans
|
||||
// each row into dest, which must be a slice. If the slice elements are
|
||||
// scannable, then the result set must have only one column. Otherwise,
|
||||
// StructScan is used. The *sql.Rows are closed automatically.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func SelectContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error {
|
||||
rows, err := q.QueryxContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// if something happens here, we want to make sure the rows are Closed
|
||||
defer rows.Close()
|
||||
return scanAll(rows, dest, false)
|
||||
}
|
||||
|
||||
// PreparexContext prepares a statement.
|
||||
//
|
||||
// The provided context is used for the preparation of the statement, not for
|
||||
// the execution of the statement.
|
||||
func PreparexContext(ctx context.Context, p PreparerContext, query string) (*Stmt, error) {
|
||||
s, err := p.PrepareContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Stmt{Stmt: s, unsafe: isUnsafe(p), Mapper: mapperFor(p)}, err
|
||||
}
|
||||
|
||||
// GetContext does a QueryRow using the provided Queryer, and scans the
|
||||
// resulting row to dest. If dest is scannable, the result must only have one
|
||||
// column. Otherwise, StructScan is used. Get will return sql.ErrNoRows like
|
||||
// row.Scan would. Any placeholder parameters are replaced with supplied args.
|
||||
// An error is returned if the result set is empty.
|
||||
func GetContext(ctx context.Context, q QueryerContext, dest interface{}, query string, args ...interface{}) error {
|
||||
r := q.QueryRowxContext(ctx, query, args...)
|
||||
return r.scanAny(dest, false)
|
||||
}
|
||||
|
||||
// LoadFileContext exec's every statement in a file (as a single call to Exec).
|
||||
// LoadFileContext may return a nil *sql.Result if errors are encountered
|
||||
// locating or reading the file at path. LoadFile reads the entire file into
|
||||
// memory, so it is not suitable for loading large data dumps, but can be useful
|
||||
// for initializing schemas or loading indexes.
|
||||
//
|
||||
// FIXME: this does not really work with multi-statement files for mattn/go-sqlite3
|
||||
// or the go-mysql-driver/mysql drivers; pq seems to be an exception here. Detecting
|
||||
// this by requiring something with DriverName() and then attempting to split the
|
||||
// queries will be difficult to get right, and its current driver-specific behavior
|
||||
// is deemed at least not complex in its incorrectness.
|
||||
func LoadFileContext(ctx context.Context, e ExecerContext, path string) (*sql.Result, error) {
|
||||
realpath, err := filepath.Abs(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
contents, err := ioutil.ReadFile(realpath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := e.ExecContext(ctx, string(contents))
|
||||
return &res, err
|
||||
}
|
||||
|
||||
// MustExecContext execs the query using e and panics if there was an error.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func MustExecContext(ctx context.Context, e ExecerContext, query string, args ...interface{}) sql.Result {
|
||||
res, err := e.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// PrepareNamedContext returns an sqlx.NamedStmt
|
||||
func (db *DB) PrepareNamedContext(ctx context.Context, query string) (*NamedStmt, error) {
|
||||
return prepareNamedContext(ctx, db, query)
|
||||
}
|
||||
|
||||
// NamedQueryContext using this DB.
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (db *DB) NamedQueryContext(ctx context.Context, query string, arg interface{}) (*Rows, error) {
|
||||
return NamedQueryContext(ctx, db, query, arg)
|
||||
}
|
||||
|
||||
// NamedExecContext using this DB.
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (db *DB) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
|
||||
return NamedExecContext(ctx, db, query, arg)
|
||||
}
|
||||
|
||||
// SelectContext using this DB.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func (db *DB) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
return SelectContext(ctx, db, dest, query, args...)
|
||||
}
|
||||
|
||||
// GetContext using this DB.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
// An error is returned if the result set is empty.
|
||||
func (db *DB) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
return GetContext(ctx, db, dest, query, args...)
|
||||
}
|
||||
|
||||
// PreparexContext returns an sqlx.Stmt instead of a sql.Stmt.
|
||||
//
|
||||
// The provided context is used for the preparation of the statement, not for
|
||||
// the execution of the statement.
|
||||
func (db *DB) PreparexContext(ctx context.Context, query string) (*Stmt, error) {
|
||||
return PreparexContext(ctx, db, query)
|
||||
}
|
||||
|
||||
// QueryxContext queries the database and returns an *sqlx.Rows.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func (db *DB) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
|
||||
r, err := db.DB.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Rows{Rows: r, unsafe: db.unsafe, Mapper: db.Mapper}, err
|
||||
}
|
||||
|
||||
// QueryRowxContext queries the database and returns an *sqlx.Row.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func (db *DB) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
|
||||
rows, err := db.DB.QueryContext(ctx, query, args...)
|
||||
return &Row{rows: rows, err: err, unsafe: db.unsafe, Mapper: db.Mapper}
|
||||
}
|
||||
|
||||
// MustBeginTx starts a transaction, and panics on error. Returns an *sqlx.Tx instead
|
||||
// of an *sql.Tx.
|
||||
//
|
||||
// The provided context is used until the transaction is committed or rolled
|
||||
// back. If the context is canceled, the sql package will roll back the
|
||||
// transaction. Tx.Commit will return an error if the context provided to
|
||||
// MustBeginContext is canceled.
|
||||
func (db *DB) MustBeginTx(ctx context.Context, opts *sql.TxOptions) *Tx {
|
||||
tx, err := db.BeginTxx(ctx, opts)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return tx
|
||||
}
|
||||
|
||||
// MustExecContext (panic) runs MustExec using this database.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func (db *DB) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result {
|
||||
return MustExecContext(ctx, db, query, args...)
|
||||
}
|
||||
|
||||
// BeginTxx begins a transaction and returns an *sqlx.Tx instead of an
|
||||
// *sql.Tx.
|
||||
//
|
||||
// The provided context is used until the transaction is committed or rolled
|
||||
// back. If the context is canceled, the sql package will roll back the
|
||||
// transaction. Tx.Commit will return an error if the context provided to
|
||||
// BeginxContext is canceled.
|
||||
func (db *DB) BeginTxx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
|
||||
tx, err := db.DB.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Tx{Tx: tx, driverName: db.driverName, unsafe: db.unsafe, Mapper: db.Mapper}, err
|
||||
}
|
||||
|
||||
// StmtxContext returns a version of the prepared statement which runs within a
|
||||
// transaction. Provided stmt can be either *sql.Stmt or *sqlx.Stmt.
|
||||
func (tx *Tx) StmtxContext(ctx context.Context, stmt interface{}) *Stmt {
|
||||
var s *sql.Stmt
|
||||
switch v := stmt.(type) {
|
||||
case Stmt:
|
||||
s = v.Stmt
|
||||
case *Stmt:
|
||||
s = v.Stmt
|
||||
case sql.Stmt:
|
||||
s = &v
|
||||
case *sql.Stmt:
|
||||
s = v
|
||||
default:
|
||||
panic(fmt.Sprintf("non-statement type %v passed to Stmtx", reflect.ValueOf(stmt).Type()))
|
||||
}
|
||||
return &Stmt{Stmt: tx.StmtContext(ctx, s), Mapper: tx.Mapper}
|
||||
}
|
||||
|
||||
// NamedStmtContext returns a version of the prepared statement which runs
|
||||
// within a transaction.
|
||||
func (tx *Tx) NamedStmtContext(ctx context.Context, stmt *NamedStmt) *NamedStmt {
|
||||
return &NamedStmt{
|
||||
QueryString: stmt.QueryString,
|
||||
Params: stmt.Params,
|
||||
Stmt: tx.StmtxContext(ctx, stmt.Stmt),
|
||||
}
|
||||
}
|
||||
|
||||
// MustExecContext runs MustExecContext within a transaction.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func (tx *Tx) MustExecContext(ctx context.Context, query string, args ...interface{}) sql.Result {
|
||||
return MustExecContext(ctx, tx, query, args...)
|
||||
}
|
||||
|
||||
// QueryxContext within a transaction and context.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func (tx *Tx) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
|
||||
r, err := tx.Tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Rows{Rows: r, unsafe: tx.unsafe, Mapper: tx.Mapper}, err
|
||||
}
|
||||
|
||||
// SelectContext within a transaction and context.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func (tx *Tx) SelectContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
return SelectContext(ctx, tx, dest, query, args...)
|
||||
}
|
||||
|
||||
// GetContext within a transaction and context.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
// An error is returned if the result set is empty.
|
||||
func (tx *Tx) GetContext(ctx context.Context, dest interface{}, query string, args ...interface{}) error {
|
||||
return GetContext(ctx, tx, dest, query, args...)
|
||||
}
|
||||
|
||||
// QueryRowxContext within a transaction and context.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func (tx *Tx) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
|
||||
rows, err := tx.Tx.QueryContext(ctx, query, args...)
|
||||
return &Row{rows: rows, err: err, unsafe: tx.unsafe, Mapper: tx.Mapper}
|
||||
}
|
||||
|
||||
// NamedExecContext using this Tx.
|
||||
// Any named placeholder parameters are replaced with fields from arg.
|
||||
func (tx *Tx) NamedExecContext(ctx context.Context, query string, arg interface{}) (sql.Result, error) {
|
||||
return NamedExecContext(ctx, tx, query, arg)
|
||||
}
|
||||
|
||||
// SelectContext using the prepared statement.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func (s *Stmt) SelectContext(ctx context.Context, dest interface{}, args ...interface{}) error {
|
||||
return SelectContext(ctx, &qStmt{s}, dest, "", args...)
|
||||
}
|
||||
|
||||
// GetContext using the prepared statement.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
// An error is returned if the result set is empty.
|
||||
func (s *Stmt) GetContext(ctx context.Context, dest interface{}, args ...interface{}) error {
|
||||
return GetContext(ctx, &qStmt{s}, dest, "", args...)
|
||||
}
|
||||
|
||||
// MustExecContext (panic) using this statement. Note that the query portion of
|
||||
// the error output will be blank, as Stmt does not expose its query.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func (s *Stmt) MustExecContext(ctx context.Context, args ...interface{}) sql.Result {
|
||||
return MustExecContext(ctx, &qStmt{s}, "", args...)
|
||||
}
|
||||
|
||||
// QueryRowxContext using this statement.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func (s *Stmt) QueryRowxContext(ctx context.Context, args ...interface{}) *Row {
|
||||
qs := &qStmt{s}
|
||||
return qs.QueryRowxContext(ctx, "", args...)
|
||||
}
|
||||
|
||||
// QueryxContext using this statement.
|
||||
// Any placeholder parameters are replaced with supplied args.
|
||||
func (s *Stmt) QueryxContext(ctx context.Context, args ...interface{}) (*Rows, error) {
|
||||
qs := &qStmt{s}
|
||||
return qs.QueryxContext(ctx, "", args...)
|
||||
}
|
||||
|
||||
func (q *qStmt) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
return q.Stmt.QueryContext(ctx, args...)
|
||||
}
|
||||
|
||||
func (q *qStmt) QueryxContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
|
||||
r, err := q.Stmt.QueryContext(ctx, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Rows{Rows: r, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}, err
|
||||
}
|
||||
|
||||
func (q *qStmt) QueryRowxContext(ctx context.Context, query string, args ...interface{}) *Row {
|
||||
rows, err := q.Stmt.QueryContext(ctx, args...)
|
||||
return &Row{rows: rows, err: err, unsafe: q.Stmt.unsafe, Mapper: q.Stmt.Mapper}
|
||||
}
|
||||
|
||||
func (q *qStmt) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
return q.Stmt.ExecContext(ctx, args...)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,5 @@
|
|||
# types
|
||||
|
||||
The types package provides some useful types which implement the `sql.Scanner`
|
||||
and `driver.Valuer` interfaces, suitable for use as scan and value targets with
|
||||
database/sql.
|
|
@ -0,0 +1,172 @@
|
|||
package types
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"io/ioutil"
|
||||
)
|
||||
|
||||
// GzippedText is a []byte which transparently gzips data being submitted to
|
||||
// a database and ungzips data being Scanned from a database.
|
||||
type GzippedText []byte
|
||||
|
||||
// Value implements the driver.Valuer interface, gzipping the raw value of
|
||||
// this GzippedText.
|
||||
func (g GzippedText) Value() (driver.Value, error) {
|
||||
b := make([]byte, 0, len(g))
|
||||
buf := bytes.NewBuffer(b)
|
||||
w := gzip.NewWriter(buf)
|
||||
w.Write(g)
|
||||
w.Close()
|
||||
return buf.Bytes(), nil
|
||||
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface, ungzipping the value coming off
|
||||
// the wire and storing the raw result in the GzippedText.
|
||||
func (g *GzippedText) Scan(src interface{}) error {
|
||||
var source []byte
|
||||
switch src.(type) {
|
||||
case string:
|
||||
source = []byte(src.(string))
|
||||
case []byte:
|
||||
source = src.([]byte)
|
||||
default:
|
||||
return errors.New("Incompatible type for GzippedText")
|
||||
}
|
||||
reader, err := gzip.NewReader(bytes.NewReader(source))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer reader.Close()
|
||||
b, err := ioutil.ReadAll(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*g = GzippedText(b)
|
||||
return nil
|
||||
}
|
||||
|
||||
// JSONText is a json.RawMessage, which is a []byte underneath.
|
||||
// Value() validates the json format in the source, and returns an error if
|
||||
// the json is not valid. Scan does no validation. JSONText additionally
|
||||
// implements `Unmarshal`, which unmarshals the json within to an interface{}
|
||||
type JSONText json.RawMessage
|
||||
|
||||
var emptyJSON = JSONText("{}")
|
||||
|
||||
// MarshalJSON returns the *j as the JSON encoding of j.
|
||||
func (j JSONText) MarshalJSON() ([]byte, error) {
|
||||
if len(j) == 0 {
|
||||
return emptyJSON, nil
|
||||
}
|
||||
return j, nil
|
||||
}
|
||||
|
||||
// UnmarshalJSON sets *j to a copy of data
|
||||
func (j *JSONText) UnmarshalJSON(data []byte) error {
|
||||
if j == nil {
|
||||
return errors.New("JSONText: UnmarshalJSON on nil pointer")
|
||||
}
|
||||
*j = append((*j)[0:0], data...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value returns j as a value. This does a validating unmarshal into another
|
||||
// RawMessage. If j is invalid json, it returns an error.
|
||||
func (j JSONText) Value() (driver.Value, error) {
|
||||
var m json.RawMessage
|
||||
var err = j.Unmarshal(&m)
|
||||
if err != nil {
|
||||
return []byte{}, err
|
||||
}
|
||||
return []byte(j), nil
|
||||
}
|
||||
|
||||
// Scan stores the src in *j. No validation is done.
|
||||
func (j *JSONText) Scan(src interface{}) error {
|
||||
var source []byte
|
||||
switch t := src.(type) {
|
||||
case string:
|
||||
source = []byte(t)
|
||||
case []byte:
|
||||
if len(t) == 0 {
|
||||
source = emptyJSON
|
||||
} else {
|
||||
source = t
|
||||
}
|
||||
case nil:
|
||||
*j = emptyJSON
|
||||
default:
|
||||
return errors.New("Incompatible type for JSONText")
|
||||
}
|
||||
*j = JSONText(append((*j)[0:0], source...))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unmarshal unmarshal's the json in j to v, as in json.Unmarshal.
|
||||
func (j *JSONText) Unmarshal(v interface{}) error {
|
||||
if len(*j) == 0 {
|
||||
*j = emptyJSON
|
||||
}
|
||||
return json.Unmarshal([]byte(*j), v)
|
||||
}
|
||||
|
||||
// String supports pretty printing for JSONText types.
|
||||
func (j JSONText) String() string {
|
||||
return string(j)
|
||||
}
|
||||
|
||||
// NullJSONText represents a JSONText that may be null.
|
||||
// NullJSONText implements the scanner interface so
|
||||
// it can be used as a scan destination, similar to NullString.
|
||||
type NullJSONText struct {
|
||||
JSONText
|
||||
Valid bool // Valid is true if JSONText is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (n *NullJSONText) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
n.JSONText, n.Valid = emptyJSON, false
|
||||
return nil
|
||||
}
|
||||
n.Valid = true
|
||||
return n.JSONText.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (n NullJSONText) Value() (driver.Value, error) {
|
||||
if !n.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return n.JSONText.Value()
|
||||
}
|
||||
|
||||
// BitBool is an implementation of a bool for the MySQL type BIT(1).
|
||||
// This type allows you to avoid wasting an entire byte for MySQL's boolean type TINYINT.
|
||||
type BitBool bool
|
||||
|
||||
// Value implements the driver.Valuer interface,
|
||||
// and turns the BitBool into a bitfield (BIT(1)) for MySQL storage.
|
||||
func (b BitBool) Value() (driver.Value, error) {
|
||||
if b {
|
||||
return []byte{1}, nil
|
||||
}
|
||||
return []byte{0}, nil
|
||||
}
|
||||
|
||||
// Scan implements the sql.Scanner interface,
|
||||
// and turns the bitfield incoming from MySQL into a BitBool
|
||||
func (b *BitBool) Scan(src interface{}) error {
|
||||
v, ok := src.([]byte)
|
||||
if !ok {
|
||||
return errors.New("bad []byte type assertion")
|
||||
}
|
||||
*b = v[0] == 1
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,127 @@
|
|||
package types
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestGzipText(t *testing.T) {
|
||||
g := GzippedText("Hello, world")
|
||||
v, err := g.Value()
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
err = (&g).Scan(v)
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
if string(g) != "Hello, world" {
|
||||
t.Errorf("Was expecting the string we sent in (Hello World), got %s", string(g))
|
||||
}
|
||||
}
|
||||
|
||||
func TestJSONText(t *testing.T) {
|
||||
j := JSONText(`{"foo": 1, "bar": 2}`)
|
||||
v, err := j.Value()
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
err = (&j).Scan(v)
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
m := map[string]interface{}{}
|
||||
j.Unmarshal(&m)
|
||||
|
||||
if m["foo"].(float64) != 1 || m["bar"].(float64) != 2 {
|
||||
t.Errorf("Expected valid json but got some garbage instead? %#v", m)
|
||||
}
|
||||
|
||||
j = JSONText(`{"foo": 1, invalid, false}`)
|
||||
v, err = j.Value()
|
||||
if err == nil {
|
||||
t.Errorf("Was expecting invalid json to fail!")
|
||||
}
|
||||
|
||||
j = JSONText("")
|
||||
v, err = j.Value()
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
|
||||
err = (&j).Scan(v)
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
|
||||
j = JSONText(nil)
|
||||
v, err = j.Value()
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
|
||||
err = (&j).Scan(v)
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullJSONText(t *testing.T) {
|
||||
j := NullJSONText{}
|
||||
err := j.Scan(`{"foo": 1, "bar": 2}`)
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
v, err := j.Value()
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
err = (&j).Scan(v)
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
m := map[string]interface{}{}
|
||||
j.Unmarshal(&m)
|
||||
|
||||
if m["foo"].(float64) != 1 || m["bar"].(float64) != 2 {
|
||||
t.Errorf("Expected valid json but got some garbage instead? %#v", m)
|
||||
}
|
||||
|
||||
j = NullJSONText{}
|
||||
err = j.Scan(nil)
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
if j.Valid != false {
|
||||
t.Errorf("Expected valid to be false, but got true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBitBool(t *testing.T) {
|
||||
// Test true value
|
||||
var b BitBool = true
|
||||
|
||||
v, err := b.Value()
|
||||
if err != nil {
|
||||
t.Errorf("Cannot return error")
|
||||
}
|
||||
err = (&b).Scan(v)
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
if !b {
|
||||
t.Errorf("Was expecting the bool we sent in (true), got %v", b)
|
||||
}
|
||||
|
||||
// Test false value
|
||||
b = false
|
||||
|
||||
v, err = b.Value()
|
||||
if err != nil {
|
||||
t.Errorf("Cannot return error")
|
||||
}
|
||||
err = (&b).Scan(v)
|
||||
if err != nil {
|
||||
t.Errorf("Was not expecting an error")
|
||||
}
|
||||
if b {
|
||||
t.Errorf("Was expecting the bool we sent in (false), got %v", b)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
// Created by cgo -godefs - DO NOT EDIT
|
||||
// cgo -godefs defs_darwin.go
|
||||
|
||||
package socket
|
||||
|
||||
const (
|
||||
sysAF_UNSPEC = 0x0
|
||||
sysAF_INET = 0x2
|
||||
sysAF_INET6 = 0x1e
|
||||
|
||||
sysSOCK_RAW = 0x3
|
||||
)
|
||||
|
||||
type iovec struct {
|
||||
Base *byte
|
||||
Len uint64
|
||||
}
|
||||
|
||||
type msghdr struct {
|
||||
Name *byte
|
||||
Namelen uint32
|
||||
Pad_cgo_0 [4]byte
|
||||
Iov *iovec
|
||||
Iovlen int32
|
||||
Pad_cgo_1 [4]byte
|
||||
Control *byte
|
||||
Controllen uint32
|
||||
Flags int32
|
||||
}
|
||||
|
||||
type cmsghdr struct {
|
||||
Len uint32
|
||||
Level int32
|
||||
Type int32
|
||||
}
|
||||
|
||||
type sockaddrInet struct {
|
||||
Len uint8
|
||||
Family uint8
|
||||
Port uint16
|
||||
Addr [4]byte /* in_addr */
|
||||
Zero [8]int8
|
||||
}
|
||||
|
||||
type sockaddrInet6 struct {
|
||||
Len uint8
|
||||
Family uint8
|
||||
Port uint16
|
||||
Flowinfo uint32
|
||||
Addr [16]byte /* in6_addr */
|
||||
Scope_id uint32
|
||||
}
|
||||
|
||||
const (
|
||||
sizeofIovec = 0x10
|
||||
sizeofMsghdr = 0x30
|
||||
sizeofCmsghdr = 0xc
|
||||
|
||||
sizeofSockaddrInet = 0x10
|
||||
sizeofSockaddrInet6 = 0x1c
|
||||
)
|
Loading…
Reference in New Issue