Use sipper in download_progress

This commit is contained in:
Héctor Ramón Jiménez 2025-02-11 00:56:14 +01:00
parent 54ffbbf043
commit 05618ea9b3
No known key found for this signature in database
GPG key ID: 7CC46565708259A7
2 changed files with 37 additions and 38 deletions

View file

@ -1,16 +1,14 @@
use iced::futures::{SinkExt, Stream, StreamExt}; use iced::futures::StreamExt;
use iced::stream::try_channel; use iced::task::{sipper, Straw};
use std::sync::Arc; use std::sync::Arc;
pub fn download( pub fn download(url: impl AsRef<str>) -> impl Straw<(), Progress, Error> {
url: impl AsRef<str>, sipper(move |mut progress| async move {
) -> impl Stream<Item = Result<Progress, Error>> {
try_channel(1, move |mut output| async move {
let response = reqwest::get(url.as_ref()).await?; let response = reqwest::get(url.as_ref()).await?;
let total = response.content_length().ok_or(Error::NoContentLength)?; let total = response.content_length().ok_or(Error::NoContentLength)?;
let _ = output.send(Progress::Downloading { percent: 0.0 }).await; let _ = progress.send(Progress { percent: 0.0 }).await;
let mut byte_stream = response.bytes_stream(); let mut byte_stream = response.bytes_stream();
let mut downloaded = 0; let mut downloaded = 0;
@ -19,23 +17,20 @@ pub fn download(
let bytes = next_bytes?; let bytes = next_bytes?;
downloaded += bytes.len(); downloaded += bytes.len();
let _ = output let _ = progress
.send(Progress::Downloading { .send(Progress {
percent: 100.0 * downloaded as f32 / total as f32, percent: 100.0 * downloaded as f32 / total as f32,
}) })
.await; .await;
} }
let _ = output.send(Progress::Finished).await;
Ok(()) Ok(())
}) })
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum Progress { pub struct Progress {
Downloading { percent: f32 }, pub percent: f32,
Finished,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]

View file

@ -25,7 +25,7 @@ struct Example {
pub enum Message { pub enum Message {
Add, Add,
Download(usize), Download(usize),
DownloadProgressed(usize, Result<download::Progress, download::Error>), DownloadUpdated(usize, Update),
} }
impl Example { impl Example {
@ -52,15 +52,13 @@ impl Example {
let task = download.start(); let task = download.start();
task.map(move |progress| { task.map(move |update| Message::DownloadUpdated(index, update))
Message::DownloadProgressed(index, progress)
})
} }
Message::DownloadProgressed(id, progress) => { Message::DownloadUpdated(id, update) => {
if let Some(download) = if let Some(download) =
self.downloads.iter_mut().find(|download| download.id == id) self.downloads.iter_mut().find(|download| download.id == id)
{ {
download.progress(progress); download.update(update);
} }
Task::none() Task::none()
@ -95,6 +93,12 @@ struct Download {
state: State, state: State,
} }
#[derive(Debug, Clone)]
pub enum Update {
Downloading(download::Progress),
Finished(Result<(), download::Error>),
}
#[derive(Debug)] #[derive(Debug)]
enum State { enum State {
Idle, Idle,
@ -111,18 +115,20 @@ impl Download {
} }
} }
pub fn start( pub fn start(&mut self) -> Task<Update> {
&mut self,
) -> Task<Result<download::Progress, download::Error>> {
match self.state { match self.state {
State::Idle { .. } State::Idle { .. }
| State::Finished { .. } | State::Finished { .. }
| State::Errored { .. } => { | State::Errored { .. } => {
let (task, handle) = Task::stream(download( let (task, handle) = Task::sip(
download(
"https://huggingface.co/\ "https://huggingface.co/\
mattshumer/Reflection-Llama-3.1-70B/\ mattshumer/Reflection-Llama-3.1-70B/\
resolve/main/model-00001-of-00162.safetensors", resolve/main/model-00001-of-00162.safetensors",
)) ),
Update::Downloading,
Update::Finished,
)
.abortable(); .abortable();
self.state = State::Downloading { self.state = State::Downloading {
@ -136,20 +142,18 @@ impl Download {
} }
} }
pub fn progress( pub fn update(&mut self, update: Update) {
&mut self,
new_progress: Result<download::Progress, download::Error>,
) {
if let State::Downloading { progress, .. } = &mut self.state { if let State::Downloading { progress, .. } = &mut self.state {
match new_progress { match update {
Ok(download::Progress::Downloading { percent }) => { Update::Downloading(new_progress) => {
*progress = percent; *progress = new_progress.percent;
} }
Ok(download::Progress::Finished) => { Update::Finished(result) => {
self.state = State::Finished; self.state = if result.is_ok() {
} State::Finished
Err(_error) => { } else {
self.state = State::Errored; State::Errored
};
} }
} }
} }