# thu-learn-downloader-optimized / thu_learn_downloader / download / downloader.py
import logging
import os
import re
from collections.abc import Sequence
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Optional, Self

import dateutil.parser
import tenacity
from requests import Response
from rich.console import Group
from rich.live import Live
from rich.markup import escape
from rich.panel import Panel
from rich.progress import (
    BarColumn,
    DownloadColumn,
    MofNCompleteColumn,
    Progress,
    TaskID,
    TaskProgressColumn,
    TextColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
    TransferSpeedColumn,
)
from rich.style import Style, StyleType

from thu_learn_downloader.client.client import Client
from thu_learn_downloader.client.course import Course
from thu_learn_downloader.client.document import Document, DocumentClass
from thu_learn_downloader.client.homework import Homework
from thu_learn_downloader.client.semester import Semester

from . import description, filename, style
from .selector import Selector


class Downloader:
    """Mirror semesters, courses, documents and homework attachments to disk.

    Listing (semesters -> courses -> documents/homeworks) runs on the calling
    thread; every individual file transfer is submitted to a thread pool.
    Progress is shown in a rich ``Live`` display made of two panels: one with
    a row per in-flight download and one with the listing stages.

    Use it as a context manager: entering starts the live display, exiting
    shuts the pool down (waiting for queued transfers) and then stops the
    display.
    """

    prefix: Path  # root directory every output path is built under
    selector: Selector  # decides which semesters/courses/content kinds to sync

    executor: Executor  # pool that runs do_sync_file jobs

    live: Live
    semesters_task_id: TaskID
    courses_task_id: TaskID
    documents_task_id: TaskID  # shared by the document and homework stages

    # Characters replaced by "_" in every component of a homework path: the
    # characters reserved in Windows file names plus ASCII control characters.
    _INVALID_PATH_CHARS = re.compile(r'[<>:"\\|?*\x00-\x1F]')

    def __init__(
        self,
        prefix: Optional[Path] = None,
        selector: Optional[Selector] = None,
        jobs: int = 8,
    ) -> None:
        """Create the thread pool and the progress display.

        Args:
            prefix: Output root directory; defaults to ``~/thu-learn``.
            selector: Filter for what to sync; defaults to ``Selector()``.
            jobs: Number of worker threads. The download panel is also sized
                for this many rows (plus 2 lines for the panel border).
        """
        self.prefix = prefix or Path.home() / "thu-learn"
        self.selector = selector or Selector()
        self.executor = ThreadPoolExecutor(max_workers=jobs)

        self.progress_prepare = Progress(
            TextColumn("{task.description}", style="bold bright_blue"),
            BarColumn(),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
        )
        self.progress_download = Progress(
            TextColumn("{task.description}", style="bold"),
            BarColumn(),
            DownloadColumn(),
            TaskProgressColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            TransferSpeedColumn(),
        )
        self.semesters_task_id = self.progress_prepare.add_task(description="Semesters")
        self.courses_task_id = self.progress_prepare.add_task(description="Courses")
        self.documents_task_id = self.progress_prepare.add_task(description="Documents")
        self.live = Live(
            Group(
                Panel(self.progress_download, height=jobs + 2),
                Panel(self.progress_prepare),
            )
        )

    def __enter__(self) -> Self:
        """Start the live display and the thread pool."""
        self.live.__enter__()
        self.executor.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Wait for pending transfers, then stop the live display.

        The pool is shut down first so transfers still running can keep
        logging to the live console until they finish.
        """
        self.executor.__exit__(exc_type, exc_val, exc_tb)
        self.live.__exit__(exc_type, exc_val, exc_tb)

    @tenacity.retry(
        stop=tenacity.stop_after_attempt(max_attempt_number=4),
        before_sleep=tenacity.before_sleep_log(logging.getLogger(), logging.DEBUG),
    )
    def download(
        self,
        client: Client,
        url: str,
        output: Path,
        *,
        task_id: TaskID,
    ) -> None:
        """Stream ``url`` into ``output``, advancing progress task ``task_id``.

        Retried up to 4 attempts in total by tenacity. Each attempt restarts the
        progress bar from zero, so bytes from a failed attempt are not counted
        twice. The response is always closed, and the task row is hidden again
        whether the attempt succeeds or fails.

        Raises:
            requests.HTTPError: if the server answers with an error status
                (tenacity wraps the final failure in ``tenacity.RetryError``).
        """
        self.progress_download.update(task_id=task_id, completed=0, visible=True)
        try:
            with client.get(url=url, stream=True) as response:
                # Don't save an error page under the target file name.
                response.raise_for_status()
                os.makedirs(name=output.parent, exist_ok=True)
                with output.open(mode="wb") as file:
                    for chunk in response.iter_content(chunk_size=8192):
                        bytes_written: int = file.write(chunk)
                        self.progress_download.advance(
                            task_id=task_id, advance=bytes_written
                        )
        finally:
            self.progress_download.update(task_id=task_id, visible=False)

    def do_sync_file(
        self,
        client: Client,
        url: str,
        output: Path,
        description: str,
        *,
        remote_size: Optional[int] = None,
        remote_time: Optional[datetime] = None,
        style: Optional[StyleType] = None,
    ) -> None:
        """Download ``url`` to ``output`` unless an equally sized file exists.

        If ``remote_size`` or ``remote_time`` is not supplied, the URL is
        requested once just to read the ``Content-Length`` / ``Date`` headers.
        That probe response is closed without reading the body. When the size
        is known and matches the local file, the download is skipped.
        Otherwise the file is fetched and, if a remote time is known, its
        access/modification times are set to it. A SKIPPED or SUCCESS line is
        logged to the live console with ``description``.
        """
        if remote_size is None or remote_time is None:
            with client.get(url=url, stream=True) as response:
                headers = response.headers
            if remote_size is None:
                try:
                    remote_size = int(headers["Content-Length"])
                except (KeyError, TypeError, ValueError):
                    remote_size = None
            if remote_time is None:
                # NOTE(review): "Date" is when the server produced the response,
                # not necessarily the file's last modification — confirm intent.
                try:
                    remote_time = dateutil.parser.parse(headers["Date"])
                except (KeyError, TypeError, ValueError, OverflowError):
                    remote_time = None
        if (
            remote_size is not None
            and output.exists()
            and output.stat().st_size == remote_size
        ):
            skip_style = style or Style()
            if not isinstance(skip_style, Style):
                skip_style = Style.parse(skip_style)
            self.live.console.log(
                "[reverse] SKIPPED [/]", description, style=skip_style + Style(dim=True)
            )
            return
        task_id: TaskID = self.progress_download.add_task(
            description=description, total=remote_size
        )
        self.download(client=client, url=url, output=output, task_id=task_id)
        if remote_time is not None:
            os.utime(
                path=output, times=(remote_time.timestamp(), remote_time.timestamp())
            )
        self.live.console.log("[reverse] SUCCESS [/]", description, style=style)

    def sync_file(
        self,
        client: Client,
        url: str,
        output: Path,
        description: str,
        *,
        remote_size: Optional[int] = None,
        remote_time: Optional[datetime] = None,
        style: Optional[StyleType] = None,
    ) -> None:
        """Schedule :meth:`do_sync_file` on the thread pool.

        Returns immediately. A failure in the background job is logged to the
        live console via :meth:`_report_failure` instead of being lost with the
        discarded future.
        """
        future = self.executor.submit(
            self.do_sync_file,
            client=client,
            url=url,
            output=output,
            description=description,
            remote_size=remote_size,
            remote_time=remote_time,
            style=style,
        )
        future.add_done_callback(
            lambda done: self._report_failure(done, description=description)
        )

    def _report_failure(self, future: Future[None], description: str) -> None:
        """Log the exception, if any, raised by a finished ``do_sync_file`` job."""
        if future.cancelled():
            return
        error = future.exception()
        if error is None:
            return
        logging.getLogger().debug("failed to sync %s", description, exc_info=error)
        self.live.console.log(
            "[reverse] FAILED [/]", description, escape(repr(error)), style="bold red"
        )

    def sync_semesters(self, semesters: Sequence[Semester]) -> None:
        """Sync every semester, restricted to ``selector.semesters`` if non-empty."""
        if self.selector.semesters:
            semesters = [
                semester
                for semester in semesters
                if semester.id in self.selector.semesters
            ]
        self.progress_prepare.reset(task_id=self.semesters_task_id)
        for semester in self.progress_prepare.track(
            sequence=semesters,
            total=len(semesters),
            task_id=self.semesters_task_id,
            description="Semesters",
        ):
            self.sync_semester(semester=semester)

    def sync_semester(self, semester: Semester) -> None:
        """Sync all courses of ``semester``."""
        self.sync_courses(semester=semester, courses=semester.courses)

    def sync_courses(self, semester: Semester, courses: Sequence[Course]) -> None:
        """Sync every course, restricted to ``selector.courses`` if non-empty."""
        if self.selector.courses:
            courses = [
                course for course in courses if course.id in self.selector.courses
            ]
        self.progress_prepare.reset(task_id=self.courses_task_id)
        for course in self.progress_prepare.track(
            sequence=courses,
            total=len(courses),
            task_id=self.courses_task_id,
            description="Courses",
        ):
            self.sync_course(semester=semester, course=course)

    def sync_course(self, semester: Semester, course: Course) -> None:
        """Sync a course's documents and/or homeworks, as enabled by the selector."""
        if self.selector.document:
            self.sync_documents(
                semester=semester,
                course=course,
                document_classes=course.document_classes,
                documents=course.documents,
            )
        if self.selector.homework:
            self.sync_homeworks(
                semester=semester, course=course, homeworks=course.homeworks
            )

    def sync_documents(
        self,
        semester: Semester,
        course: Course,
        document_classes: Sequence[DocumentClass],
        documents: Sequence[Document],
    ) -> None:
        """Sync every document, numbering them from 1 in listing order.

        Each document is paired with its class by ``document.class_id``. A
        document whose class id is not in ``document_classes`` raises KeyError.
        """
        document_class_map: dict[str, DocumentClass] = {
            document_class.id: document_class for document_class in document_classes
        }
        self.progress_prepare.reset(task_id=self.documents_task_id)
        for index, document in enumerate(
            self.progress_prepare.track(
                sequence=documents,
                total=len(documents),
                task_id=self.documents_task_id,
                description="Documents",
            ),
            start=1,
        ):
            self.sync_document(
                semester=semester,
                course=course,
                document_class=document_class_map[document.class_id],
                document=document,
                index=index,
            )

    def sync_document(
        self,
        semester: Semester,
        course: Course,
        document_class: DocumentClass,
        document: Document,
        index: int,
    ) -> None:
        """Schedule one document download, using its listed size and upload time."""
        self.sync_file(
            client=document.client,
            url=document.download_url,
            output=filename.document(
                prefix=self.prefix,
                semester=semester,
                course=course,
                document_class=document_class,
                document=document,
                index=index,
            ),
            description=description.document(
                semester=semester,
                course=course,
                document_class=document_class,
                document=document,
                index=index,
            ),
            remote_size=document.size,
            remote_time=document.upload_time,
            style=style.DOCUMENT,
        )

    def sync_homeworks(
        self, semester: Semester, course: Course, homeworks: Sequence[Homework]
    ) -> None:
        """Sync every homework, reusing the documents progress row."""
        self.progress_prepare.reset(task_id=self.documents_task_id)
        for homework in self.progress_prepare.track(
            sequence=homeworks,
            total=len(homeworks),
            task_id=self.documents_task_id,
            description="Homeworks",
        ):
            self.sync_homework(semester=semester, course=course, homework=homework)

    @classmethod
    def _sanitize_path(cls, path: Path) -> Path:
        """Replace invalid file-name characters in each component of ``path``.

        The anchor (root or drive, e.g. ``C:\\``) is kept as-is. Sanitizing the
        whole ``str(path)`` instead would turn Windows drive colons and ``\\``
        separators into ``_``.
        """
        parts = path.parts[1:] if path.anchor else path.parts
        return Path(
            path.anchor, *(cls._INVALID_PATH_CHARS.sub("_", part) for part in parts)
        )

    def sync_homework(
        self, semester: Semester, course: Course, homework: Homework
    ) -> None:
        """Create the homework's directory and schedule all of its attachments."""
        readme_path: Path = self._sanitize_path(
            filename.homework(
                prefix=self.prefix, semester=semester, course=course, homework=homework
            )
        )

        # Created even when the homework has no attachments.
        os.makedirs(readme_path.parent, exist_ok=True)
        # README generation is currently disabled; to re-enable:
        # readme_path.write_text(homework.markdown)
        for attachment in homework.attachments:
            self.sync_file(
                client=homework.client,
                url=attachment.download_url,
                output=filename.attachment(
                    prefix=self.prefix,
                    semester=semester,
                    course=course,
                    homework=homework,
                    attachment=attachment,
                ),
                description=description.attachment(
                    semester=semester,
                    course=course,
                    homework=homework,
                    attachment=attachment,
                ),
                style=style.HOMEWORK,
            )