1 Commits

Author SHA1 Message Date
Zoe
adac21ce29 revert flattening (#3), add -C 2025-05-22 15:52:34 -05:00
2 changed files with 81 additions and 106 deletions

168
main.go
View File

@@ -13,9 +13,7 @@ import (
"os/signal" "os/signal"
"path" "path"
"path/filepath" "path/filepath"
"regexp"
"runtime" "runtime"
"sort"
"strings" "strings"
"sync" "sync"
"syscall" "syscall"
@@ -25,8 +23,6 @@ import (
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
) )
var executableName string
//go:embed embed/zqdgr.config.json //go:embed embed/zqdgr.config.json
var zqdgrConfig []byte var zqdgrConfig []byte
@@ -48,6 +44,7 @@ type Config struct {
} }
type Script struct { type Script struct {
zqdgr *ZQDGR
command *exec.Cmd command *exec.Cmd
mutex sync.Mutex mutex sync.Mutex
scriptName string scriptName string
@@ -57,60 +54,44 @@ type Script struct {
exitCode int exitCode int
} }
func flattenZQDGRScript(commandString string) string { type ZQDGR struct {
keys := make([]string, 0, len(config.Scripts)) Config Config
for k := range config.Scripts { WorkingDirectory string
keys = append(keys, k) EnableWebSocket bool
WSServer *WSServer
} }
// Sort the keys in descending order in order to prevent scripts that might be substrings of other scripts to type WSServer struct {
// evaluate first. upgrader websocket.Upgrader
sort.Slice(keys, func(i, j int) bool { clients map[*websocket.Conn]bool
return len(keys[i]) > len(keys[j]) clientsMux sync.Mutex
})
// escape scripts to be evaluated via regex
escapedKeys := make([]string, len(keys))
for i, key := range keys {
escapedKeys[i] = regexp.QuoteMeta(key)
}
pattern := `\b(` + executableName + `)\b` + `\s+` + `\b(` + strings.Join(escapedKeys, "|") + `)\b`
re := regexp.MustCompile(pattern)
currentCommand := commandString
for {
previousCommand := currentCommand
currentCommand = re.ReplaceAllStringFunc(currentCommand, func(match string) string {
// match the script name, not the whole `zqdgr script` command
match = strings.Split(match, " ")[1]
if val, ok := config.Scripts[match]; ok {
return val
}
return match
})
// If the current command has not changed, we have completely evaluated the command.
if currentCommand == previousCommand {
break
}
} }
if re.MatchString(currentCommand) { func NewZQDGR(enableWebSocket bool, configDir string) *ZQDGR {
fmt.Println("Error: circular dependency detected in scripts") zqdgr := &ZQDGR{
os.Exit(1) WorkingDirectory: configDir,
} }
return currentCommand zqdgr.loadConfig()
zqdgr.EnableWebSocket = enableWebSocket
zqdgr.WSServer = &WSServer{
upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
},
clients: make(map[*websocket.Conn]bool),
clientsMux: sync.Mutex{},
} }
func NewCommand(scriptName string, args ...string) *exec.Cmd { return zqdgr
if script, ok := config.Scripts[scriptName]; ok { }
func (zqdgr *ZQDGR) NewCommand(scriptName string, args ...string) *exec.Cmd {
if script, ok := zqdgr.Config.Scripts[scriptName]; ok {
fullCmd := strings.Join(append([]string{script}, args...), " ") fullCmd := strings.Join(append([]string{script}, args...), " ")
fullCmd = flattenZQDGRScript(fullCmd)
var cmd *exec.Cmd var cmd *exec.Cmd
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
cmd = exec.Command("cmd", "/C", fullCmd) cmd = exec.Command("cmd", "/C", fullCmd)
@@ -118,6 +99,8 @@ func NewCommand(scriptName string, args ...string) *exec.Cmd {
cmd = exec.Command("sh", "-c", fullCmd) cmd = exec.Command("sh", "-c", fullCmd)
} }
cmd.Dir = zqdgr.WorkingDirectory
cmd.SysProcAttr = &syscall.SysProcAttr{ cmd.SysProcAttr = &syscall.SysProcAttr{
Setpgid: true, Setpgid: true,
} }
@@ -132,8 +115,8 @@ func NewCommand(scriptName string, args ...string) *exec.Cmd {
} }
} }
func NewScript(scriptName string, args ...string) *Script { func (zqdgr *ZQDGR) NewScript(scriptName string, args ...string) *Script {
command := NewCommand(scriptName, args...) command := zqdgr.NewCommand(scriptName, args...)
if command == nil { if command == nil {
log.Fatal("script not found") log.Fatal("script not found")
@@ -141,6 +124,7 @@ func NewScript(scriptName string, args ...string) *Script {
} }
return &Script{ return &Script{
zqdgr: zqdgr,
command: command, command: command,
scriptName: scriptName, scriptName: scriptName,
isRestarting: false, isRestarting: false,
@@ -186,7 +170,7 @@ func (s *Script) Restart() error {
if s.command.Process != nil { if s.command.Process != nil {
var signal syscall.Signal var signal syscall.Signal
switch config.ShutdownSignal { switch s.zqdgr.Config.ShutdownSignal {
case "SIGINT": case "SIGINT":
signal = syscall.SIGINT signal = syscall.SIGINT
case "SIGTERM": case "SIGTERM":
@@ -202,7 +186,7 @@ func (s *Script) Restart() error {
} }
} }
s.command = NewCommand(s.scriptName) s.command = s.zqdgr.NewCommand(s.scriptName)
if s.command == nil { if s.command == nil {
// this should never happen // this should never happen
@@ -217,17 +201,17 @@ func (s *Script) Restart() error {
err := s.Start() err := s.Start()
// tell the websocket clients to refresh // tell the websocket clients to refresh
if enableWebSocket { if s.zqdgr.EnableWebSocket {
clientsMux.Lock() s.zqdgr.WSServer.clientsMux.Lock()
for client := range clients { for client := range s.zqdgr.WSServer.clients {
err := client.WriteMessage(websocket.TextMessage, []byte("refresh")) err := client.WriteMessage(websocket.TextMessage, []byte("refresh"))
if err != nil { if err != nil {
log.Printf("error broadcasting refresh: %v", err) log.Printf("error broadcasting refresh: %v", err)
client.Close() client.Close()
delete(clients, client) delete(s.zqdgr.WSServer.clients, client)
} }
} }
clientsMux.Unlock() s.zqdgr.WSServer.clientsMux.Unlock()
} }
return err return err
@@ -237,49 +221,36 @@ func (s *Script) Wait() {
s.wg.Wait() s.wg.Wait()
} }
func handleWs(w http.ResponseWriter, r *http.Request) { func (wsServer *WSServer) handleWs(w http.ResponseWriter, r *http.Request) {
conn, err := upgrader.Upgrade(w, r, nil) conn, err := wsServer.upgrader.Upgrade(w, r, nil)
if err != nil { if err != nil {
log.Printf("error upgrading connection: %v", err) log.Printf("error upgrading connection: %v", err)
return return
} }
clientsMux.Lock() wsServer.clientsMux.Lock()
clients[conn] = true wsServer.clients[conn] = true
clientsMux.Unlock() wsServer.clientsMux.Unlock()
for { for {
_, _, err := conn.ReadMessage() _, _, err := conn.ReadMessage()
if err != nil { if err != nil {
clientsMux.Lock() wsServer.clientsMux.Lock()
delete(clients, conn) delete(wsServer.clients, conn)
clientsMux.Unlock() wsServer.clientsMux.Unlock()
break break
} }
} }
} }
var ( func (zqdgr *ZQDGR) loadConfig() error {
enableWebSocket = false data, err := os.ReadFile(path.Join(zqdgr.WorkingDirectory, "zqdgr.config.json"))
config Config
script *Script
upgrader = websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return true
},
}
clients = make(map[*websocket.Conn]bool)
clientsMux sync.Mutex
)
func loadConfig() error {
data, err := os.ReadFile("zqdgr.config.json")
if err == nil { if err == nil {
if err := json.Unmarshal(data, &config); err != nil { if err := json.Unmarshal(data, &zqdgr.Config); err != nil {
return fmt.Errorf("error parsing config file: %v", err) return fmt.Errorf("error parsing config file: %v", err)
} }
} else { } else {
config = Config{ zqdgr.Config = Config{
Scripts: map[string]string{ Scripts: map[string]string{
"build": "go build", "build": "go build",
"run": "go run main.go", "run": "go run main.go",
@@ -293,22 +264,23 @@ func loadConfig() error {
func main() { func main() {
noWs := flag.Bool("no-ws", false, "Disable WebSocket server") noWs := flag.Bool("no-ws", false, "Disable WebSocket server")
configDir := flag.String("config", ".", "Path to the config directory")
flag.StringVar(configDir, "C", *configDir, "Path to the config directory")
flag.Parse() flag.Parse()
if err := loadConfig(); err != nil { os.Args = flag.Args()
log.Fatal(err)
} zqdgr := NewZQDGR(*noWs, *configDir)
var command string var command string
var commandArgs []string var commandArgs []string
// get the name of the executable, and if it's a path then get the base name for i, arg := range os.Args {
// this is mainly for testing
executableName = path.Base(os.Args[0])
for i, arg := range os.Args[1:] {
if arg == "--" { if arg == "--" {
if i+2 < len(os.Args) {
commandArgs = os.Args[i+2:] commandArgs = os.Args[i+2:]
}
break break
} }
@@ -462,7 +434,7 @@ func main() {
log.Fatal("please specify a script to run") log.Fatal("please specify a script to run")
} }
watchMode = true watchMode = true
for i := 0; i < len(commandArgs); i++ { for i := range commandArgs {
if strings.HasPrefix(commandArgs[i], "-") { if strings.HasPrefix(commandArgs[i], "-") {
continue continue
} }
@@ -473,7 +445,7 @@ func main() {
scriptName = command scriptName = command
} }
script = NewScript(scriptName, commandArgs...) script := zqdgr.NewScript(scriptName, commandArgs...)
if err := script.Start(); err != nil { if err := script.Start(); err != nil {
log.Fatal(err) log.Fatal(err)
@@ -487,7 +459,7 @@ func main() {
log.Println("Received signal, exiting...") log.Println("Received signal, exiting...")
if script.command != nil { if script.command != nil {
var signal syscall.Signal var signal syscall.Signal
switch config.ShutdownSignal { switch zqdgr.Config.ShutdownSignal {
case "SIGINT": case "SIGINT":
signal = syscall.SIGINT signal = syscall.SIGINT
case "SIGTERM": case "SIGTERM":
@@ -506,10 +478,10 @@ func main() {
if watchMode { if watchMode {
if !*noWs { if !*noWs {
enableWebSocket = true zqdgr.EnableWebSocket = true
go func() { go func() {
http.HandleFunc("/ws", handleWs) http.HandleFunc("/ws", zqdgr.WSServer.handleWs)
log.Printf("WebSocket server running on :2067") log.Printf("WebSocket server running on :2067")
if err := http.ListenAndServe(":2067", nil); err != nil { if err := http.ListenAndServe(":2067", nil); err != nil {
log.Printf("WebSocket server error: %v", err) log.Printf("WebSocket server error: %v", err)
@@ -517,7 +489,7 @@ func main() {
}() }()
} }
if config.Pattern == "" { if zqdgr.Config.Pattern == "" {
log.Fatal("watch pattern not specified in config") log.Fatal("watch pattern not specified in config")
} }
@@ -525,7 +497,7 @@ func main() {
var currentPattern string var currentPattern string
inMatch := false inMatch := false
// iterate over every letter in the pattern // iterate over every letter in the pattern
for _, p := range config.Pattern { for _, p := range zqdgr.Config.Pattern {
if string(p) == "{" { if string(p) == "{" {
if inMatch { if inMatch {
log.Fatal("unmatched { in pattern") log.Fatal("unmatched { in pattern")
@@ -561,7 +533,7 @@ func main() {
} }
watcherConfig := WatcherConfig{ watcherConfig := WatcherConfig{
excludedDirs: globList(config.ExcludedDirs), excludedDirs: globList(zqdgr.Config.ExcludedDirs),
pattern: paternArray, pattern: paternArray,
} }

View File

@@ -14,9 +14,12 @@
"dev": "sleep 5; echo 'test' && sleep 2 && echo 'test2'", "dev": "sleep 5; echo 'test' && sleep 2 && echo 'test2'",
"test": "zqdgr test:1 && zqdgr test:2 && zqdgr test:3 && zqdgr test:4", "test": "zqdgr test:1 && zqdgr test:2 && zqdgr test:3 && zqdgr test:4",
"test:1": "echo 'a'", "test:1": "echo 'a'",
"test:2": "false", "test:2": "true",
"test:3": "echo 'b'", "test:3": "echo 'b'",
"test:4": "zqdgr test:3", "test:4": "zqdgr test:3",
"test:5": "zqdgr test:6",
"test:6": "zqdgr test:7",
"test:7": "zqdgr test:5",
"recursive": "zqdgr recursive" "recursive": "zqdgr recursive"
}, },
"pattern": "**/*.go" "pattern": "**/*.go"