diff --git a/kit/file/file.go b/kit/file/file.go index 1273130..92c0d9f 100644 --- a/kit/file/file.go +++ b/kit/file/file.go @@ -5,26 +5,113 @@ import ( "mime/multipart" "net/http" "os" + "pandax/kit/biz" + "strconv" + "sync" ) -// DownloadFile 会将url下载到本地文件,它会在下载时写入,而不是将整个文件加载到内存中。 -func DownloadFile(url, filepath string) error { - // Get the data - resp, err := http.Get(url) +const ( + MaxConcurrency = 16 // 最大并发数 +) + +type DownloadTask struct { + URL string + FilePath string +} + +func DownloadFileWithConcurrency(url, filepath string) error { + resp, err := http.Head(url) if err != nil { return err } defer resp.Body.Close() - // Create the file - out, err := os.Create(filepath) + fileSize, err := strconv.Atoi(resp.Header.Get("Content-Length")) if err != nil { return err } - defer out.Close() - // Write the body to file - _, err = io.Copy(out, resp.Body) - return err + + file, err := os.OpenFile(filepath, os.O_RDWR|os.O_CREATE, 0666) + if err != nil { + return err + } + defer file.Close() + + // 检查本地文件大小 + localFileSize, err := file.Seek(0, io.SeekEnd) + if err != nil { + return err + } + + // 计算剩余未下载的文件大小 + remainingSize := fileSize - int(localFileSize) + + // 计算每个片段的大小 + chunkSize := remainingSize / MaxConcurrency + + // 创建等待组,用于等待所有goroutine完成 + var wg sync.WaitGroup + wg.Add(MaxConcurrency) + + // 创建并发下载任务 + for i := 0; i < MaxConcurrency; i++ { + start := localFileSize + int64(i*chunkSize) + end := start + int64(chunkSize) - 1 + + // 最后一个片段的结束位置可能超过文件大小,需要修正 + if i == MaxConcurrency-1 { + end = int64(fileSize) - 1 + } + + go func(index int, start, end int64) { + defer wg.Done() + + err := downloadChunk(url, filepath, start, end) + if err != nil { + biz.NewBizErr("文件下载失败") + // 处理下载错误 + } + }(i, start, end) + } + + // 等待所有goroutine完成 + wg.Wait() + + return nil +} + +func downloadChunk(url, filepath string, start, end int64) error { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return err + } + + // 设置Range头部 + req.Header.Set("Range", "bytes="+strconv.FormatInt(start, 10)+"-"+strconv.FormatInt(end, 10)) + + resp, err := http.DefaultClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + file, err := os.OpenFile(filepath, os.O_RDWR, 0666) + if err != nil { + return err + } + defer file.Close() + + _, err = file.Seek(start, io.SeekStart) + if err != nil { + return err + } + + _, err = io.CopyN(file, resp.Body, end-start+1) + if err != nil { + return err + } + + return nil } func SaveUploadedFile(file *multipart.FileHeader, dst string) error {