Refactor main into soak and squeeze

This commit is contained in:
Trevor Slocum 2019-06-22 05:44:41 -07:00
parent 719f1eb407
commit f65569e99b
1 changed files with 86 additions and 39 deletions

125
main.go
View File

@ -32,13 +32,18 @@ import (
"os/exec"
)
// BufferSize determines the size of the input/output buffer. The buffer is
// used directly when soaking initially. If the input is larger than the buffer
// it is written to a temporary file instead. The buffer is reused when
// writing to and reading from the temporary file.
const BufferSize = 64 * 1024 // 64 KiB
func CheckError(err error) {
if err != nil {
log.Fatal(err)
}
}
var (
buf = make([]byte, BufferSize)
bufInitialRead int
appendFile bool
tmpFile *os.File
)
func init() {
log.SetFlags(0)
@ -46,72 +51,97 @@ func init() {
}
func main() {
appendFile := flag.Bool("a", false, "append")
flag.BoolVar(&appendFile, "a", false, "append")
flag.Usage = func() {
_, _ = fmt.Fprintf(os.Stderr, "%s [-a] [file]: soak up all input from stdin and write it to [file] or stdout\n", os.Args[0])
fmt.Fprintf(os.Stderr, "%s [-a] [file]: soak up all input from stdin and write it to [file] or stdout\n", os.Args[0])
}
flag.Parse()
writeFileName := flag.Arg(0)
err := soak()
if err != nil {
log.Fatal(err)
}
defer cleanUp()
tmpFile, err := ioutil.TempFile("", "sponge-*.tmp")
CheckError(err)
defer func(tmpFile *os.File) {
if tmpFile == nil {
return // File was closed elsewhere
err = squeeze(flag.Arg(0))
if err != nil {
log.Fatal(err)
}
}
func soak() error {
var err error
bufInitialRead, err = io.ReadFull(os.Stdin, buf)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return err
} else if bufInitialRead == BufferSize && err != io.ErrUnexpectedEOF {
tmpFile, err = ioutil.TempFile("", "sponge-*.tmp")
if err != nil {
return err
}
tmpFileName := tmpFile.Name()
_ = tmpFile.Close()
_ = os.Remove(tmpFileName)
}(tmpFile)
_, err = tmpFile.Write(buf[:bufInitialRead])
if err != nil {
tmpFile.Close()
return err
}
buf := make([]byte, BufferSize)
_, err = io.CopyBuffer(tmpFile, os.Stdin, buf)
if err != nil {
tmpFile.Close()
return err
}
_, err = io.CopyBuffer(tmpFile, os.Stdin, buf)
CheckError(err)
_, err = tmpFile.Seek(io.SeekStart, 0)
if err != nil {
tmpFile.Close()
return err
}
}
_, err = tmpFile.Seek(io.SeekStart, 0)
CheckError(err)
return nil
}
func squeeze(filename string) error {
var out io.Writer
if writeFileName != "" {
if filename != "" {
regularFile := true
fileInfo, err := os.Stat(writeFileName)
fileInfo, err := os.Stat(filename)
if err != nil {
if !os.IsNotExist(err) {
CheckError(err)
return err
}
} else {
regularFile = fileInfo.Mode().IsRegular()
}
if regularFile && !*appendFile {
tmpFileName := tmpFile.Name()
if tmpFile != nil && regularFile && !appendFile {
tmpFile.Close()
tmpFile = nil
err = os.Rename(tmpFileName, writeFileName)
// Rename temporary file
err = os.Rename(tmpFile.Name(), filename)
if err != nil {
// Fall back to mv
cmd := exec.Command("mv", tmpFileName, writeFileName)
cmd := exec.Command("mv", tmpFile.Name(), filename)
err = cmd.Run()
}
CheckError(err)
return
tmpFile = nil
return err
}
openFlags := os.O_WRONLY
if regularFile {
openFlags |= os.O_CREATE
}
if *appendFile {
if appendFile {
openFlags |= os.O_APPEND
} else if regularFile {
openFlags |= os.O_TRUNC
}
writeFile, err := os.OpenFile(writeFileName, openFlags, 0644)
CheckError(err)
writeFile, err := os.OpenFile(filename, openFlags, 0644)
if err != nil {
return err
}
defer writeFile.Close()
out = writeFile
@ -119,6 +149,23 @@ func main() {
out = os.Stdout
}
_, err = io.CopyBuffer(out, tmpFile, buf)
CheckError(err)
var err error
if tmpFile != nil {
_, err = io.CopyBuffer(out, tmpFile, buf)
} else {
_, err = out.Write(buf[:bufInitialRead])
}
return err
}
func cleanUp() {
if tmpFile == nil {
return
}
tmpFile.Close()
err := os.Remove(tmpFile.Name())
if err != nil {
log.Fatal(err)
}
}