1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
use crate::Backend;
use axum::{
    extract::{
        ws::{Message, WebSocket, WebSocketUpgrade},
        State,
    },
    response::Response,
    routing::get,
    Router,
};
use buildsrs_database::{entity::Builder, AnyMetadata, BoxError, Error as DatabaseError};
use buildsrs_protocol::{
    ssh_key::Fingerprint,
    types::{Challenge, JobKind},
    *,
};
use futures::StreamExt;
use tracing::*;

/// Errors that can end a builder jobs websocket session.
#[derive(thiserror::Error, Debug)]
pub enum WebSocketError {
    /// The stream ended before the client sent a `Hello` message.
    #[error("Missing Hello message")]
    MissingHello,
    /// The client's `ChallengeResponse` did not match the challenge that was sent.
    #[error("Challenge incorrect")]
    ChallengeError,
    /// The websocket stream ended while waiting for a client message.
    #[error("Stream is closed")]
    StreamClosed,
    /// Transport-level error reported by axum's websocket.
    #[error(transparent)]
    Axum(#[from] axum::Error),
    /// A message could not be serialized or deserialized as JSON.
    #[error(transparent)]
    Json(#[from] serde_json::Error),
    /// A signed client message failed signature verification.
    #[error(transparent)]
    Signature(#[from] SignatureError),
    /// Error from the database layer (`buildsrs_database::Error`).
    #[error(transparent)]
    Database(#[from] DatabaseError),
    /// Boxed error (`buildsrs_database::BoxError`) from the database layer.
    #[error(transparent)]
    DatabaseBox(#[from] BoxError),
}

/// Waits for the client's `Hello` and returns the fingerprint it announces.
///
/// Non-text frames and decoded messages other than `Hello` are ignored. The `Hello`
/// is not signature-checked here, because the builder's public key is only known
/// once the fingerprint has been looked up.
///
/// # Errors
///
/// Fails with [`WebSocketError::MissingHello`] if the stream ends before a `Hello`
/// arrives. Transport and JSON decoding errors are propagated.
async fn extract_fingerprint(socket: &mut WebSocket) -> Result<Fingerprint, WebSocketError> {
    while let Some(frame) = socket.next().await {
        if let Message::Text(text) = frame? {
            let signed: SignedMessage<ClientMessage> = serde_json::from_str(&text)?;
            if let ClientMessage::Hello(hello) = signed.message {
                return Ok(hello.fingerprint);
            }
        }
    }
    Err(WebSocketError::MissingHello)
}

/// Per-connection state for a builder's jobs websocket session.
struct Connection {
    /// Websocket connected to the builder.
    websocket: WebSocket,
    /// Builder record found via the fingerprint from the client's `Hello`; its
    /// `public_key` is used to verify every message received afterwards.
    builder: Builder,
    /// Database handle used to open writers when assigning jobs.
    database: AnyMetadata,
}

impl Connection {
    /// Receives the next client message and verifies its signature.
    ///
    /// Non-text frames are skipped. Each text frame is decoded as a
    /// `SignedMessage<ClientMessage>` and verified against the builder's public key
    /// before its payload is returned.
    ///
    /// # Errors
    ///
    /// Returns [`WebSocketError::StreamClosed`] when the stream ends. Transport, JSON
    /// decoding and signature verification errors are propagated.
    async fn recv(&mut self) -> Result<ClientMessage, WebSocketError> {
        while let Some(message) = self.websocket.next().await {
            let message: SignedMessage<ClientMessage> = match message? {
                Message::Text(message) => serde_json::from_str(&message)?,
                _ => continue,
            };
            message.verify(&self.builder.public_key)?;
            return Ok(message.message);
        }
        Err(WebSocketError::StreamClosed)
    }

    /// Serializes `message` as JSON and sends it to the builder as a text frame.
    ///
    /// # Errors
    ///
    /// Propagates JSON serialization and websocket transport errors.
    async fn send(&mut self, message: ServerMessage) -> Result<(), WebSocketError> {
        self.websocket
            .send(Message::Text(serde_json::to_string(&message)?))
            .await?;
        Ok(())
    }

    /// Authenticates the builder with a random challenge.
    ///
    /// Sends a `ChallengeRequest` and waits for a `ChallengeResponse`, discarding any
    /// other messages that arrive in the meantime. `recv` verifies every message
    /// against the builder's public key, so (assuming `verify` rejects bad signatures)
    /// a matching response can only come from the holder of the builder's private key.
    ///
    /// # Errors
    ///
    /// Returns [`WebSocketError::ChallengeError`] if the echoed challenge differs.
    /// Errors from `send`/`recv` are propagated.
    async fn challenge(&mut self) -> Result<(), WebSocketError> {
        let challenge = Challenge::random();
        self.send(ServerMessage::ChallengeRequest(challenge.clone()))
            .await?;
        loop {
            let message = self.recv().await?;
            match message {
                ClientMessage::ChallengeResponse(response) => {
                    return if challenge == response {
                        Ok(())
                    } else {
                        Err(WebSocketError::ChallengeError)
                    }
                }
                _ => continue,
            }
        }
    }

    /// Assigns a job to this builder and sends it back as a `JobResponse`.
    ///
    /// Every job is currently dispatched as [`JobKind::Metadata`].
    ///
    /// # Errors
    ///
    /// Propagates database errors and errors from sending the response.
    // The request payload is not used yet. The allow silences
    // `no_effect_underscore_binding`, which presumably fires on the rebinding that
    // `async fn` desugaring generates for the underscore-prefixed parameter.
    #[allow(clippy::no_effect_underscore_binding)]
    async fn handle_job_request(&mut self, _request: &JobRequest) -> Result<(), WebSocketError> {
        let writer = self.database.write().await?;
        let job = writer.job_request(self.builder.uuid).await?;
        let job = writer.job_info(job).await?;
        // NOTE(review): the assignment is committed before the response is sent; if the
        // send below fails, the job stays assigned to this builder without having been
        // delivered.
        writer.commit().await?;
        info!("Created job for {:?}: {:?}", self.builder, job);
        let message = ServerMessage::JobResponse(Job {
            kind: JobKind::Metadata,
            name: job.name,
            source: job.url,
            uuid: job.uuid,
            version: job.version,
        });
        self.send(message).await?;
        Ok(())
    }

    /// Serves job requests from an authenticated builder until the session ends.
    ///
    /// The session ends with `Ok(())` in two cases: the builder closes the
    /// connection, or it sends a `Hello`/`ChallengeResponse`, which is not expected
    /// after authentication.
    ///
    /// # Errors
    ///
    /// Propagates transport, decoding, verification and job-handling errors.
    async fn handle(&mut self) -> Result<(), WebSocketError> {
        loop {
            let message = match self.recv().await {
                Ok(message) => message,
                // A builder hanging up is the normal way a serving session ends; don't
                // report it as a failure.
                Err(WebSocketError::StreamClosed) => return Ok(()),
                Err(error) => return Err(error),
            };
            match message {
                ClientMessage::Hello(_) | ClientMessage::ChallengeResponse(_) => return Ok(()),
                ClientMessage::JobRequest(request) => self.handle_job_request(&request).await?,
            }
        }
    }
}

impl Backend {
    /// Handle jobs websocket connection.
    ///
    /// The session runs in three steps:
    ///
    /// 1. wait for the client's `Hello` and read the builder fingerprint from it;
    /// 2. look the builder up in the database by that fingerprint;
    /// 3. authenticate the builder with a signed challenge, then serve its job requests.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - no `Hello` arrives;
    /// - the builder lookup or another database operation fails;
    /// - the challenge is answered incorrectly;
    /// - the websocket fails.
    pub async fn handle_jobs(&self, mut websocket: WebSocket) -> Result<(), WebSocketError> {
        let fingerprint = extract_fingerprint(&mut websocket).await?;
        // Keep the read handle scoped to the lookup. Binding it at function scope would
        // keep it (and whatever it holds, e.g. a pooled connection) alive for the whole
        // websocket session.
        let builder = {
            let database = self.database().read().await?;
            let uuid = database.builder_lookup(&fingerprint.to_string()).await?;
            database.builder_get(uuid).await?
        };
        let mut connection = Connection {
            websocket,
            builder,
            database: self.database().clone(),
        };
        connection.challenge().await?;
        connection.handle().await?;
        Ok(())
    }
}

async fn jobs_websocket(State(backend): State<Backend>, ws: WebSocketUpgrade) -> Response {
    ws.on_upgrade(move |socket| {
        let backend = backend.clone();
        async move {
            match backend.handle_jobs(socket).await {
                Ok(()) => {}
                Err(error) => error!("{error:#}"),
            }
        }
    })
}

/// Builds the router for this module, which serves the builder jobs websocket at `/jobs`.
pub fn routes() -> Router<Backend> {
    let jobs_route = get(jobs_websocket);
    Router::new().route("/jobs", jobs_route)
}