# Newer
# Older
# thu-learn-downloader-optimized / thu_learn_downloader / downloader.py
# @Qin Li Qin Li on 24 Feb 2023 5 KB fix: migrate to python 3.11
import os
import os.path
import time
from concurrent.futures import Executor, ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Optional

import requests
from rich.console import Console
from rich.progress import (
    BarColumn,
    DownloadColumn,
    Progress,
    TaskID,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)
from rich.style import Style, StyleType

from .constants import (
    CHUNK_SIZE,
    FAILURE_PREFIX,
    MAX_ACTIVE_TASKS,
    RETRY_PREFIX,
    SKIPPED_PREFIX,
    SUCCESS_PREFIX,
)


def download_once(
    url: str,
    output: str | Path,
    session: Optional[requests.Session],
    raw_size: Optional[int] = None,
    upload_time: Optional[datetime] = None,
    progress: Progress = Progress(),
    task_id: TaskID = TaskID(0),
) -> bool:
    """Make a single attempt to download ``url`` into ``output``.

    Args:
        url: Address to fetch with a streaming GET request.
        output: Destination file; missing parent directories are created.
        session: Session used for the request; the ``requests`` module itself
            is used when this is ``None``.
        raw_size: Fallback size in bytes, used only when the response carries
            no (or a zero) ``Content-Length`` header.
        upload_time: When given, enables the "already downloaded" check and
            is stamped onto the file as its access/modification time.
        progress: Progress display that owns ``task_id``.
        task_id: Progress task that is reset, started and advanced.

    Returns:
        ``True`` if the file was written, ``False`` if it was skipped because
        an existing file has the expected size and a modification time that
        is not older than ``upload_time``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        Other errors from ``requests``, the file system or ``progress``
        propagate unchanged so the caller can decide whether to retry.
    """
    output = Path(output)

    # Using the response as a context manager releases the connection on
    # every exit path, including the early "skipped" return and exceptions.
    with (session or requests).get(url=url, stream=True) as resp:
        # Without this an error page would be saved as if it were the file.
        resp.raise_for_status()
        raw_size = int(resp.headers.get("Content-Length", 0)) or raw_size

        if (
            upload_time
            and os.path.exists(output)
            and os.path.getsize(output) == raw_size
        ):
            mtime: float = os.path.getmtime(output)
            if mtime >= upload_time.timestamp():
                return False

        os.makedirs(output.parent, exist_ok=True)
        progress.reset(task_id=task_id, total=raw_size)
        with open(file=output, mode="wb") as fp:
            progress.start_task(task_id=task_id)
            for chunk in resp.iter_content(chunk_size=CHUNK_SIZE):
                bytes_written = fp.write(chunk)
                progress.advance(task_id=task_id, advance=bytes_written)
            if not raw_size:
                # Size was unknown up front: use what was actually written.
                progress.update(task_id=task_id, total=fp.tell())

    if upload_time:
        # Stamp the file so the skip check above works on the next run.
        mtime = upload_time.timestamp()
        os.utime(path=output, times=(mtime, mtime))

    return True


def download(
    url: str,
    output: str | Path,
    session: Optional[requests.Session],
    raw_size: Optional[int] = None,
    upload_time: Optional[datetime] = None,
    progress: Progress = Progress(),
    task_id: TaskID = TaskID(0),
    *,
    description: str = "",
    max_retries: int = 4,
    console: Console = Console(),
    style: StyleType = Style(color="bright_cyan", bold=True),
) -> None:
    """Download ``url`` to ``output``, retrying on failure and logging the result.

    ``download_once`` is called up to ``max_retries + 1`` times. Every failed
    attempt is logged with ``FAILURE_PREFIX`` and every attempt after the
    first is announced with ``RETRY_PREFIX`` and its attempt number. The loop
    stops at the first attempt that does not raise, which is logged with
    ``SUCCESS_PREFIX`` (file written) or ``SKIPPED_PREFIX`` (file skipped).
    The progress task is hidden afterwards regardless of the outcome.

    Args:
        url: Address to fetch.
        output: Destination file path.
        session: Session forwarded to ``download_once``.
        raw_size: Size hint forwarded to ``download_once``.
        upload_time: Timestamp forwarded to ``download_once``.
        progress: Progress display that owns ``task_id``.
        task_id: Progress task representing this download.
        description: Text used in log lines; defaults to the task's own
            description when empty.
        max_retries: Number of additional attempts after the first one.
        console: Console that receives the log lines.
        style: Base style (or style definition string) for the log lines.
    """
    if not description:
        description = progress.tasks[task_id].description
    style = style if isinstance(style, Style) else Style.parse(style)
    for i in range(max_retries + 1):
        try:
            if i:
                console.log(
                    RETRY_PREFIX,
                    i,
                    description,
                    style=style + Style(color="bright_yellow"),
                )
            rtn: bool = download_once(
                url=url,
                output=output,
                session=session,
                raw_size=raw_size,
                upload_time=upload_time,
                progress=progress,
                task_id=task_id,
            )
        except Exception:
            # Only ordinary errors are retried; KeyboardInterrupt and
            # SystemExit must propagate instead of triggering another attempt.
            console.log(
                FAILURE_PREFIX,
                description,
                style=style + Style(color="bright_red"),
            )
        else:
            if rtn:
                console.log(SUCCESS_PREFIX, description, style=style)
            else:
                console.log(
                    SKIPPED_PREFIX,
                    description,
                    style=style + Style(dim=True),
                )
            break
    # Hiding the task frees a slot for schedulers that count visible tasks.
    progress.update(task_id=task_id, visible=False)


class Downloader:
    """Run downloads on a thread pool behind one shared progress display."""

    pool: Executor
    progress: Progress

    def __init__(self) -> None:
        """Create the progress display with its columns and the worker pool."""
        columns = (
            TextColumn("{task.description}"),
            BarColumn(),
            DownloadColumn(),
            TaskProgressColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            TransferSpeedColumn(),
        )
        self.progress = Progress(*columns)
        self.pool = ThreadPoolExecutor()

    def schedule_download(
        self,
        url: str,
        output: str | Path,
        session: Optional[requests.Session],
        raw_size: Optional[int] = None,
        upload_time: Optional[datetime] = None,
        max_retries: int = 4,
        console: Console = Console(),
        style: StyleType = Style(color="bright_cyan", bold=True),
        *,
        description: str,
    ) -> None:
        """Queue one download, waiting first for a free progress slot.

        Blocks, polling once per second, while the number of visible progress
        tasks is at least ``MAX_ACTIVE_TASKS``. Then a new, not-yet-started
        task is added to ``self.progress`` and ``download`` is submitted to
        ``self.pool`` with that task.

        Args:
            url: Address to fetch.
            output: Destination file path.
            session: Session handed through to ``download``.
            raw_size: Size hint, also used as the task's initial total.
            upload_time: Timestamp handed through to ``download``.
            max_retries: Retry budget handed through to ``download``.
            console: Console handed through to ``download``.
            style: Style (or style definition string) used to render the
                task description and handed through to ``download``.
            description: Label for the progress task and the log lines.
        """
        # The progress object and task id are supplied by this class, which
        # is why they are not parameters here.
        while (
            sum(1 for task in self.progress.tasks if task.visible)
            >= MAX_ACTIVE_TASKS
        ):
            time.sleep(1)

        resolved_style = style if isinstance(style, Style) else Style.parse(style)
        new_task: TaskID = self.progress.add_task(
            description=resolved_style.render(description),
            start=False,
            total=raw_size,
        )

        job_kwargs = {
            "url": url,
            "output": output,
            "session": session,
            "raw_size": raw_size,
            "upload_time": upload_time,
            "progress": self.progress,
            "task_id": new_task,
            "description": description,
            "max_retries": max_retries,
            "console": console,
            "style": resolved_style,
        }
        self.pool.submit(download, **job_kwargs)