diff --git a/config.go b/config.go index 4c93c01..5cb43d0 100644 --- a/config.go +++ b/config.go @@ -2,15 +2,13 @@ package main import ( "encoding/json" - "fmt" - "io/ioutil" "log" "os" "os/user" "path/filepath" ) -// Config: Configuration structure. +// Config holds the application configuration. type Config struct { RaidTablePath string `json:"raid_table_path"` Services []string `json:"services"` @@ -18,7 +16,7 @@ type Config struct { EncryptionKey string `json:"encryption_key"` } -// ReadConfig: Read the configuration file. +// ReadConfig reads and parses the configuration file. func (a *App) ReadConfig() { usr, err := user.Current() if err != nil { @@ -37,27 +35,33 @@ func (a *App) ReadConfig() { // Determine which configuration to use. var configFile string - if _, err := os.Stat(app.flags.ConfigPath); err == nil && app.flags.ConfigPath != "" { + if app.flags.ConfigPath != "" { + if _, err := os.Stat(app.flags.ConfigPath); err != nil { + log.Fatalln("Specified configuration file does not exist:", app.flags.ConfigPath) + } configFile = app.flags.ConfigPath - } else if _, err := os.Stat(localConfig); err == nil { - configFile = localConfig - } else if _, err := os.Stat(homeDirConfig); err == nil { - configFile = homeDirConfig - } else if _, err := os.Stat(etcConfig); err == nil { - configFile = etcConfig } else { - log.Println("Unable to find a configuration file.") - return + // Search standard paths in priority order. + for _, candidate := range []string{localConfig, homeDirConfig, etcConfig} { + if _, err := os.Stat(candidate); err == nil { + configFile = candidate + break + } + } + if configFile == "" { + log.Println("Unable to find a configuration file.") + return + } } - jsonFile, err := ioutil.ReadFile(configFile) + jsonFile, err := os.ReadFile(configFile) if err != nil { - fmt.Printf("Error reading JSON file: %s\n", err) + log.Printf("Error reading JSON file: %s\n", err) return } err = json.Unmarshal(jsonFile, &app.config) if err != nil { - fmt.Printf("Error parsing JSON file: %s\n", err) + log.Printf("Error parsing JSON file: %s\n", err) } } diff --git a/flags.go b/flags.go index 57b188f..b8efac4 100644 --- a/flags.go +++ b/flags.go @@ -6,14 +6,14 @@ import ( "os" ) -// Flags: Configuration options for cli execution. +// Flags holds configuration options for CLI execution. type Flags struct { ConfigPath string EncryptionKey string EncryptionPassword string } -// Init: Parses configuration options. +// Init parses configuration options from command-line flags. func (f *Flags) Init() { flag.Usage = func() { fmt.Printf("raid-mount: Mounts raid drives and starts services\n\nUsage:\n") @@ -29,7 +29,7 @@ func (f *Flags) Init() { flag.StringVar(&f.ConfigPath, "c", "", usage+" (shorthand)") flag.StringVar(&f.EncryptionKey, "encryption-key", "", "Keyfile to decrypt drives") - usage = "Password to decrypt drives" + usage = "Password to decrypt drives (visible in process list; prefer RAID_MOUNT_ENCRYPTION_PASSWORD env var)" flag.StringVar(&f.EncryptionPassword, "encryption-password", "", usage) flag.StringVar(&f.EncryptionPassword, "p", "", usage+" (shorthand)") diff --git a/main.go b/main.go index d67dfef..cfcfe53 100644 --- a/main.go +++ b/main.go @@ -14,18 +14,18 @@ import ( "golang.org/x/term" ) -// RaidMount: Mount point details. +// RaidMount holds mount point details parsed from the raid table. type RaidMount struct { Source string - Target string - FSType string - Flags string + Target string + FSType string + Flags string CryptName string Encrypted bool Parallel bool } -// App: Global application structure. +// App is the global application structure. type App struct { flags *Flags config Config @@ -33,7 +33,7 @@ type App struct { var app *App -// isMounted: Checks the linux mounts for a target mountpoint to see if it is mounted. +// isMounted checks /proc/mounts for a target mountpoint to determine if it is mounted. func isMounted(target string) bool { file, err := os.Open("/proc/mounts") if err != nil { @@ -54,104 +54,177 @@ func isMounted(target string) bool { return false } -func mountDrive(mount RaidMount, encryptionPassword string, wg *sync.WaitGroup) { - // Make sure we tell the wait group that we're done when the mount is done. - defer wg.Done() +// closeLUKS attempts to close a LUKS volume by name, logging any failure. +func closeLUKS(cryptName string) { + cmd := exec.Command("cryptsetup", "close", cryptName) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + log.Printf("Failed to close LUKS volume %s: %v\n", cryptName, err) + } +} + +// mountBindfs handles mounting a bindfs FUSE filesystem. +// +// The Flags field uses a "|" separator to distinguish bindfs native flags from +// FUSE -o options: +// +// "resolve-symlinks|allow_other" +// └─ passed as --resolve-symlinks └─ passed as -o allow_other +// +// Either side of the "|" may be empty. If no "|" is present the entire Flags +// string is treated as bindfs native flags with no -o options. +func mountBindfs(mount RaidMount) error { + if isMounted(mount.Target) { + fmt.Println(mount.Target, "is already mounted") + return nil + } + + // Split flags into bindfs native flags and FUSE -o options. + var bindfsFlags []string + var fuseOpts string + + parts := strings.SplitN(mount.Flags, "|", 2) + nativeRaw := strings.TrimSpace(parts[0]) + if len(parts) == 2 { + fuseOpts = strings.TrimSpace(parts[1]) + } + + // Each comma-separated native flag becomes a --flag argument. + if nativeRaw != "" { + for _, f := range strings.Split(nativeRaw, ",") { + f = strings.TrimSpace(f) + if f != "" { + bindfsFlags = append(bindfsFlags, "--"+f) + } + } + } + + // Build the full argument list. + args := bindfsFlags + if fuseOpts != "" { + args = append(args, "-o", fuseOpts) + } + args = append(args, mount.Source, mount.Target) + + fmt.Println("bindfs", args) + cmd := exec.Command("bindfs", args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + if err := cmd.Run(); err != nil { + return fmt.Errorf("bindfs %s: %w", mount.Target, err) + } + + if !isMounted(mount.Target) { + return fmt.Errorf("unable to mount: %s", mount.Target) + } + return nil +} + +// mountDrive decrypts (if needed) and mounts a single drive. Returns an error +// instead of calling log.Fatal so the caller can coordinate shutdown safely. +func mountDrive(mount RaidMount, encryptionPassword string) error { + // Dispatch bindfs mounts to their own handler. bindfs mounts are never + // encrypted so we skip the cryptsetup path entirely. + if mount.FSType == "bindfs" { + return mountBindfs(mount) + } + + // Track whether we opened the LUKS volume ourselves so we can clean up on failure. + openedLUKS := false + // If encrypted, decrypt the drive. if mount.Encrypted { // Check the device path to see if the encrypted drive is already decrypted. dmPath := "/dev/mapper/" + mount.CryptName if _, err := os.Stat(dmPath); err == nil { fmt.Println("Already decrypted:", mount.CryptName) - return + } else { + // Decrypt the drive. + args := []string{ + "open", + mount.Source, + mount.CryptName, + } + + // If encryption key file was provided, add argument. + if app.config.EncryptionKey != "" { + args = append(args, "--key-file="+app.config.EncryptionKey) + } + + fmt.Println("cryptsetup", args) + cmd := exec.Command("cryptsetup", args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + stdin, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("cryptsetup stdin pipe for %s: %w", mount.CryptName, err) + } + + if err := cmd.Start(); err != nil { + return fmt.Errorf("cryptsetup start %s: %w", mount.CryptName, err) + } + + // If password was provided, send it to cryptsetup and close stdin + // so the process receives EOF and does not block. + if encryptionPassword != "" { + fmt.Fprintln(stdin, encryptionPassword) + } + stdin.Close() + + if err := cmd.Wait(); err != nil { + return fmt.Errorf("cryptsetup open %s: %w", mount.CryptName, err) + } + + // If we cannot verify that it is decrypted, the mount will not work. + if _, err := os.Stat(dmPath); err != nil { + return fmt.Errorf("unable to decrypt: %s", mount.CryptName) + } + openedLUKS = true } - // Decrypt the drive. - args := []string{ - "open", - mount.Source, - mount.CryptName, - } - - // If encryption key file was provided, add argument. - if app.config.EncryptionKey != "" { - args = append(args, "--key-file="+app.config.EncryptionKey) - } - - fmt.Println("cryptsetup", args) - cmd := exec.Command("cryptsetup", args...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - - stdin, err := cmd.StdinPipe() - if err != nil { - log.Fatalln(err) - } - - // If password was provided, send it to cryptsetup. - if encryptionPassword != "" { - fmt.Fprintln(stdin, encryptionPassword) - } - - // Run cryptsetup to decrypt drive and any error is fatal due to it preventing all required drives from mounting. - err = cmd.Start() - if err != nil { - log.Fatalln(err) - } - - err = cmd.Wait() - if err != nil { - log.Fatalln(err) - } - - // If we cannot verify that its decrypted, then we need to stop as mount won't work. - if _, err := os.Stat(dmPath); err != nil { - log.Fatalln("Unable to decrypt:", mount.CryptName) - } - - // Now that its decrypted, update the source path for mounting. + // Now that it is decrypted, update the source path for mounting. mount.Source = dmPath } - // If we're already mounted on this mountpoint, skip to the next one. + // If we're already mounted on this mountpoint, skip. if isMounted(mount.Target) { fmt.Println(mount.Target, "is already mounted") - return + return nil } - // Mount the mountpoint. - args := []string{ - "-t", - mount.FSType, - "-o", - mount.Flags, - mount.Source, - mount.Target, + // Build mount arguments, only adding -o if flags are non-empty. + args := []string{"-t", mount.FSType} + if mount.Flags != "" { + args = append(args, "-o", mount.Flags) } + args = append(args, mount.Source, mount.Target) fmt.Println("mount", args) cmd := exec.Command("mount", args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - // Run mount to mount the mountpoint, any error is fatal as we want to ensure that mountpoints mount. - err := cmd.Start() - if err != nil { - log.Fatalln(err) + if err := cmd.Run(); err != nil { + if mount.Encrypted && openedLUKS { + closeLUKS(mount.CryptName) + } + return fmt.Errorf("mount %s: %w", mount.Target, err) } - err = cmd.Wait() - if err != nil { - log.Fatalln(err) - } - - // Verified that it actually mounted. + // Verify that it actually mounted. if !isMounted(mount.Target) { - log.Fatalln("Unable to mount:", mount.Target) + if mount.Encrypted && openedLUKS { + closeLUKS(mount.CryptName) + } + return fmt.Errorf("unable to mount: %s", mount.Target) } + return nil } -// main: Starting application function. +// main is the entry point for the application. func main() { // Only allow running as root. if os.Getuid() != 0 { @@ -171,7 +244,7 @@ func main() { } var raidMounts []RaidMount - hasEncryptedDrives := false // If there are encrypted drives, we require a password to decrypt them. + hasEncryptedDrives := false // Open the raid mountpoint table file. raidTab, err := os.Open(app.config.RaidTablePath) @@ -197,7 +270,7 @@ func main() { continue } - // If line is not 5 fields, some formatting is wrong in the table. We will just log/ignore this line. + // If line is not 6 fields, some formatting is wrong in the table. if len(args) != 6 { log.Println("Line does not have correct number of arguments:", line) continue @@ -214,7 +287,7 @@ func main() { Parallel: false, } - // If the CryptName field is not none, then it is an encrypted drive. We must set the variables for logic below to easily determine if it has encryption. + // If the CryptName field is not none, then it is an encrypted drive. if mount.CryptName != "none" { mount.Encrypted = true hasEncryptedDrives = true @@ -250,9 +323,11 @@ func main() { } } - // If the encryption password was not provided and an encryption key not provided and there is a mountpoint that is encrypted, - // request the password from the user. + // Resolve the encryption password from flag, environment variable, or interactive prompt. encryptionPassword := app.flags.EncryptionPassword + if encryptionPassword == "" { + encryptionPassword = os.Getenv("RAID_MOUNT_ENCRYPTION_PASSWORD") + } if encryptionPassword == "" && app.config.EncryptionKey == "" && hasEncryptedDrives { fmt.Print("Please enter the encryption password: ") @@ -265,43 +340,58 @@ func main() { encryptionPassword = string(bytePassword) } - // With each mountpoint, decrypt and mount. + // With each mountpoint, decrypt and mount. Errors are collected so that a + // single failure does not silently kill goroutines via os.Exit. var wg sync.WaitGroup + var mu sync.Mutex + var mountErrors []error + for _, mount := range raidMounts { - // If this task is not parallel, wait for previous tasks to complete before processing. + // A non-parallel entry acts as a barrier: wait for all prior mounts to + // complete and abort if any of them failed. if !mount.Parallel { wg.Wait() + mu.Lock() + if len(mountErrors) > 0 { + for _, e := range mountErrors { + log.Println(e) + } + log.Fatalln("Aborting due to mount errors.") + } + mu.Unlock() } - // Add 1 to the wait group as we're spawning a task. - wg.Add(1) - // Mount the drive. - go mountDrive(mount, encryptionPassword, &wg) - } - // Now that all mounts are in progress, we wait before starting services. - wg.Wait() - // Now that all mountpoints are mounted, start the services in configuration. - for _, service := range app.config.Services { - // Start the service. - args := []string{ - "start", - service, + wg.Add(1) + go func(m RaidMount) { + defer wg.Done() + if err := mountDrive(m, encryptionPassword); err != nil { + mu.Lock() + mountErrors = append(mountErrors, err) + mu.Unlock() + } + }(mount) + } + + // Wait for all remaining mounts and check for errors before starting services. + wg.Wait() + if len(mountErrors) > 0 { + for _, e := range mountErrors { + log.Println(e) } + log.Fatalln("Aborting due to mount errors.") + } + + // Now that all mountpoints are mounted, start the configured services. + for _, service := range app.config.Services { + args := []string{"start", service} fmt.Println("systemctl", args) cmd := exec.Command("systemctl", args...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr - // Start systemctl, any error is not fatal to allow other services to start. - err = cmd.Start() - if err != nil { - log.Println(err) - } - - err = cmd.Wait() - if err != nil { - log.Println(err) + if err := cmd.Run(); err != nil { + log.Println("Failed to start service", service+":", err) } } } diff --git a/raidtab.example b/raidtab.example index 3cc4818..515b3c6 100644 --- a/raidtab.example +++ b/raidtab.example @@ -1,6 +1,14 @@ # Source Target FSType Flags CryptName Parallel +# +# The Parallel field controls mount ordering. Entries with Parallel=1 are +# launched concurrently with the preceding entries. An entry with Parallel=0 +# acts as a barrier: all previously launched mounts must complete before it +# starts. This lets you express dependencies by position — for example, +# individual drives can mount in parallel, then a mergerfs union that depends +# on them uses Parallel=0 to wait. + /dev/sdb1 /mnt/sdb1 xfs defaults none 1 /dev/sdc1 /mnt/sdc1 xfs defaults sdc1 1 -# Merged -/mnt/sdb1:/mnt/sdc1 /mnt/merged mergerfs config=/etc/mergerfs.ini,allow_other,use_ino,fsname=merged none 0 \ No newline at end of file +# Merged — waits for the parallel mounts above to finish before starting. +/mnt/sdb1:/mnt/sdc1 /mnt/merged mergerfs config=/etc/mergerfs.ini,allow_other,use_ino,fsname=merged none 0