diff --git a/pkg/epoch/update.go b/pkg/epoch/update.go index 12c398b..fa9d842 100644 --- a/pkg/epoch/update.go +++ b/pkg/epoch/update.go @@ -10,142 +10,158 @@ import ( "os" "path/filepath" "regexp" + "slices" "strings" ) type UpdateStats struct { - Updated int - Current int - Outdated int - LogMessages []string - Error error + Updated int + Current int + Outdated int + Error error + MessageBuf []string } func Update(wowdir string, force bool, removeUnknown bool, skipDownload bool) UpdateStats { stats := UpdateStats{ - LogMessages: make([]string, 0), - Error: nil, + Error: nil, + MessageBuf: make([]string, 0), } - manifest, err := GetManifest() - if err != nil { - stats.Error = fmt.Errorf("Failed to get manifest: %v\n", err) - return stats - } + msgChan := make(chan string) + done := make(chan bool) - for _, file := range manifest.Files { - path := strings.ReplaceAll(file.Path, `\`, `/`) - path = strings.TrimLeft(path, `\`) - - localPath := filepath.Join(wowdir, path) - localDir := filepath.Dir(localPath) - if _, err = os.Stat(localDir); os.IsNotExist(err) { - err = os.MkdirAll(localDir, 0755) - if err != nil { - stats.Error = fmt.Errorf("failed to create directory %s: %v", localDir, err) - return stats - } + go func() { + manifest, err := GetManifest() + if err != nil { + stats.Error = fmt.Errorf("Failed to get manifest: %v\n", err) + done <- true + return } - if !force { - if _, err = os.Stat(localPath); err == nil { - data, err := os.ReadFile(localPath) - if err != nil { - stats.Error = fmt.Errorf("failed to read %s: %v", localPath, err) - return stats - } - hashBytes := md5.Sum(data) - hash := hex.EncodeToString(hashBytes[:]) - if hash == file.Hash { - stats.LogMessages = append(stats.LogMessages, fmt.Sprintf("File %s is up to date", localPath)) - stats.Current += 1 - continue - } else { - stats.Outdated += 1 - } - } - } - - if !skipDownload { - outFile, err := os.Create(localPath) - if err != nil { - stats.Error = fmt.Errorf("failed to create file %s: %v", localPath, err) - return stats - } - - downloadSuccess := false - for _, url := range []string{file.Urls.Cloudflare, file.Urls.Digitalocean, file.Urls.None} { - resp, err := http.Get(url) - if err != nil { - if resp != nil { - resp.Body.Close() - } - stats.LogMessages = append(stats.LogMessages, fmt.Sprintf("Failed to download %s: %v", url, err)) - continue - } - if resp.StatusCode != http.StatusOK { - resp.Body.Close() - stats.LogMessages = append(stats.LogMessages, fmt.Sprintf("HTTP Status %d", resp.StatusCode)) - continue - } - - _, err = io.Copy(outFile, resp.Body) - if err != nil { - stats.LogMessages = append(stats.LogMessages, fmt.Sprintf("Failed to write file %s: %v", localPath, err)) - resp.Body.Close() - continue - } - - resp.Body.Close() - downloadSuccess = true - stats.LogMessages = append(stats.LogMessages, fmt.Sprintf("Successfully downloaded %s", localPath)) - break - } - - outFile.Close() - if !downloadSuccess { - stats.Error = fmt.Errorf("Failed to download updates, see above messages") - return stats - } - - stats.Updated += 1 - } - } - - if removeUnknown { - patches := make([]string, 0) - patchreg := regexp.MustCompile(`patch-[A-Za-z].MPQ`) - for _, file := range manifest.Files { - if patchreg.MatchString(file.Path) { - patches = append(patches, strings.Split(file.Path, "Data\\")[1]) + path := strings.ReplaceAll(file.Path, `\`, `/`) + path = strings.TrimLeft(path, `\`) + + localPath := filepath.Join(wowdir, path) + localDir := filepath.Dir(localPath) + if _, err = os.Stat(localDir); os.IsNotExist(err) { + err = os.MkdirAll(localDir, 0755) + if err != nil { + stats.Error = fmt.Errorf("failed to create directory %s: %v", localDir, err) + done <- true + return + } + } + + if !force { + if _, err = os.Stat(localPath); err == nil { + data, err := os.ReadFile(localPath) + if err != nil { + stats.Error = fmt.Errorf("failed to read %s: %v", localPath, err) + done <- true + return + } + hashBytes := md5.Sum(data) + hash := hex.EncodeToString(hashBytes[:]) + if hash == file.Hash { + msgChan <- fmt.Sprintf("File %s is up to date", localPath) + stats.Current += 1 + continue + } else { + stats.Outdated += 1 + } + } + } + + if !skipDownload { + outFile, err := os.Create(localPath) + if err != nil { + stats.Error = fmt.Errorf("failed to create file %s: %v", localPath, err) + done <- true + return + } + + downloadSuccess := false + for _, url := range []string{file.Urls.Cloudflare, file.Urls.Digitalocean, file.Urls.None} { + resp, err := http.Get(url) + if err != nil { + if resp != nil { + resp.Body.Close() + } + msgChan <- fmt.Sprintf("Failed to download %s: %v", url, err) + continue + } + if resp.StatusCode != http.StatusOK { + resp.Body.Close() + msgChan <- fmt.Sprintf("HTTP Status %d", resp.StatusCode) + continue + } + + _, err = io.Copy(outFile, resp.Body) + if err != nil { + msgChan <- fmt.Sprintf("Failed to write file %s: %v", localPath, err) + resp.Body.Close() + continue + } + + resp.Body.Close() + downloadSuccess = true + msgChan <- fmt.Sprintf("Successfully downloaded %s", localPath) + break + } + + outFile.Close() + if !downloadSuccess { + stats.Error = fmt.Errorf("Failed to download updates, see above messages") + done <- true + return + } + + stats.Updated += 1 } } - err = filepath.WalkDir(filepath.Join(wowdir, "Data"), func(path string, d fs.DirEntry, err error) error { - if !d.IsDir() && patchreg.MatchString(d.Name()) { - del := true - for _, patch := range patches { - if patch == d.Name() { - del = false - break - } - } - if del { - err = os.Remove(path) - if err != nil { - return err - } - stats.LogMessages = append(stats.LogMessages, fmt.Sprintf("Removed unknown patch %s", d.Name())) + if removeUnknown { + patches := make([]string, 0) + patchreg := regexp.MustCompile(`patch-[A-Za-z].MPQ`) + + for _, file := range manifest.Files { + if patchreg.MatchString(file.Path) { + patches = append(patches, strings.Split(file.Path, "Data\\")[1]) } } - return nil - }) - if err != nil { - stats.Error = fmt.Errorf("failed to delete unknown patches: %v", err) + err = filepath.WalkDir(filepath.Join(wowdir, "Data"), func(path string, d fs.DirEntry, err error) error { + if !d.IsDir() && patchreg.MatchString(d.Name()) { + del := true + if slices.Contains(patches, d.Name()) { + del = false + } + if del { + err = os.Remove(path) + if err != nil { + return err + } + msgChan <- fmt.Sprintf("Removed unknown patch %s", d.Name()) + } + } + return nil + }) + if err != nil { + stats.Error = fmt.Errorf("failed to delete unknown patches: %v", err) + } + } + done <- true + }() + + for { + select { + case msg := <-msgChan: + fmt.Println(msg) + stats.MessageBuf = append(stats.MessageBuf, msg) + case <-done: + return stats } } - - return stats }