use crate::{DynStrategy, Output};
use anyhow::Result;
use buildsrs_protocol::{types::*, *};
use futures::{SinkExt, StreamExt};
use ssh_key::{HashAlg, PrivateKey};
use std::time::Duration;
use tokio::{
net::TcpStream,
select,
sync::mpsc::{channel, Receiver, Sender},
task::JoinSet,
time::{interval, Interval},
};
use tokio_tungstenite::{connect_async, tungstenite::Message, MaybeTlsStream, WebSocketStream};
use tracing::*;
use url::Url;
/// WebSocket to the server, carried over a TCP stream that may or may not be wrapped in TLS.
type WebSocket = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// An event that a running job task can send back to the [`Connection`] event loop.
///
/// Every job task is handed a `Sender<Event>`, but no task sends anything yet, and the event
/// loop currently discards whatever it receives.
// `Build` is never constructed yet; the allow keeps the job-to-connection channel typed
// without a dead-code warning.
#[allow(dead_code)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Event {
    /// Build event with a string payload (presumably build output or an id — TODO confirm).
    Build(String),
}
/// Client side of a server connection: authenticates over a WebSocket, requests jobs, and
/// runs the jobs it receives as background tasks.
pub struct Connection {
    /// Upper bound on running job tasks; another job is requested only while fewer are running.
    parallel: usize,
    /// Strategy cloned into each job task, where it creates a builder for the job's source.
    strategy: DynStrategy,
    /// Timer whose ticks trigger a check for spare job capacity.
    poll_timer: Interval,
    /// Key passed to `SignedMessage::new` for every outgoing message; its public key's
    /// fingerprint is sent in the hello.
    private_key: PrivateKey,
    /// The WebSocket to the server.
    websocket: WebSocket,
    /// Currently running job tasks.
    tasks: JoinSet<()>,
    /// Receiving end of the job-task event channel (events are currently discarded).
    receiver: Receiver<Event>,
    /// Sending end of the job-task event channel; a clone is handed to each job task.
    sender: Sender<Event>,
}
impl Connection {
pub async fn connect(
strategy: DynStrategy,
private_key: PrivateKey,
url: &Url,
) -> Result<Self> {
let (websocket, _) = connect_async(url.as_str()).await?;
Ok(Self::new(strategy, websocket, private_key))
}
pub fn new(strategy: DynStrategy, websocket: WebSocket, private_key: PrivateKey) -> Self {
let (sender, receiver) = channel(16);
Self {
parallel: 16,
strategy,
poll_timer: interval(Duration::from_secs(1)),
private_key,
websocket,
sender,
receiver,
tasks: Default::default(),
}
}
pub async fn send(&mut self, message: ClientMessage) -> Result<()> {
let signed = SignedMessage::new(&self.private_key, message)?;
let json = serde_json::to_string(&signed)?;
self.websocket.send(Message::Text(json)).await?;
Ok(())
}
pub async fn recv(websocket: &mut WebSocket) -> Result<ServerMessage> {
match websocket.next().await {
Some(Ok(Message::Text(text))) => {
info!("Got message {text}");
Ok(serde_json::from_str(&text)?)
}
_ => todo!(),
}
}
pub async fn authenticate(&mut self) -> Result<()> {
let fingerprint = self.private_key.public_key().fingerprint(HashAlg::Sha512);
self.send(ClientMessage::Hello(ClientHello { fingerprint }))
.await?;
let challenge = loop {
if let Some(message) = self.websocket.next().await {
let message: ServerMessage = match message? {
Message::Text(text) => serde_json::from_str(&text)?,
_other => continue,
};
match message {
ServerMessage::ChallengeRequest(challenge) => break challenge,
_ => continue,
}
}
};
self.send(ClientMessage::ChallengeResponse(challenge))
.await?;
Ok(())
}
pub async fn tasks_sync(&mut self) -> Result<()> {
if self.tasks.len() < self.parallel {
info!("Requesting another task");
self.send(ClientMessage::JobRequest(JobRequest {})).await?;
}
Ok(())
}
pub async fn handle_iter(&mut self) -> Result<()> {
info!("Waiting for event");
select! {
_tick = self.poll_timer.tick() => self.tasks_sync().await?,
message = Self::recv(&mut self.websocket) => self.handle_message(message?),
_result = self.tasks.join_next(), if !self.tasks.is_empty() => self.handle_done().await?,
_event = self.receiver.recv() => {},
}
Ok(())
}
#[instrument(skip(self))]
pub async fn handle(&mut self) -> Result<()> {
loop {
self.handle_iter().await?;
}
}
#[allow(dead_code)]
async fn handle_done(&mut self) -> Result<()> {
Ok(())
}
fn handle_message(&mut self, message: ServerMessage) {
info!("Got message {message:?}");
match message {
ServerMessage::JobList(jobs) => {
for job in jobs.jobs {
self.handle_job(job);
}
}
ServerMessage::JobResponse(job) => self.handle_job(job),
ServerMessage::ChallengeRequest(_) => unreachable!(),
}
}
fn handle_job(&mut self, job: Job) {
info!("Got job {job:?}");
let sender = self.sender.clone();
self.tasks
.spawn(Self::job(self.strategy.clone(), job, sender));
}
#[allow(clippy::no_effect_underscore_binding)]
async fn job(strategy: DynStrategy, job: Job, _sender: Sender<Event>) {
let builder = strategy.builder_from_url(&job.source, b"").await.unwrap();
let mut stream = builder.metadata().await.unwrap();
while let Some(output) = stream.next().await {
match output.unwrap() {
Output::Build(data) => debug!("{}", String::from_utf8_lossy(&data)),
Output::Data(data) => debug!("{data:?}"),
Output::Fetch(_data) => {}
}
}
tokio::time::sleep(Duration::from_secs(10)).await;
}
}