Untitled Git - lemmy.git/blob - crates/routes/src/feeds.rs
Make functions work with both connection and pool (#3420)
[lemmy.git] / crates / routes / src / feeds.rs
1 use actix_web::{error::ErrorBadRequest, web, Error, HttpRequest, HttpResponse, Result};
2 use anyhow::anyhow;
3 use chrono::{DateTime, NaiveDateTime, Utc};
4 use lemmy_api_common::context::LemmyContext;
5 use lemmy_db_schema::{
6   newtypes::LocalUserId,
7   source::{community::Community, local_user::LocalUser, person::Person},
8   traits::{ApubActor, Crud},
9   utils::DbPool,
10   CommentSortType,
11   ListingType,
12   SortType,
13 };
14 use lemmy_db_views::{
15   post_view::PostQuery,
16   structs::{PostView, SiteView},
17 };
18 use lemmy_db_views_actor::{
19   comment_reply_view::CommentReplyQuery,
20   person_mention_view::PersonMentionQuery,
21   structs::{CommentReplyView, PersonMentionView},
22 };
23 use lemmy_utils::{claims::Claims, error::LemmyError, utils::markdown::markdown_to_html};
24 use once_cell::sync::Lazy;
25 use rss::{
26   extension::dublincore::DublinCoreExtensionBuilder,
27   ChannelBuilder,
28   GuidBuilder,
29   Item,
30   ItemBuilder,
31 };
32 use serde::Deserialize;
33 use std::{collections::BTreeMap, str::FromStr};
34
35 const RSS_FETCH_LIMIT: i64 = 20;
36
37 #[derive(Deserialize)]
38 struct Params {
39   sort: Option<String>,
40   limit: Option<i64>,
41   page: Option<i64>,
42 }
43
44 impl Params {
45   fn sort_type(&self) -> Result<SortType, Error> {
46     let sort_query = self
47       .sort
48       .clone()
49       .unwrap_or_else(|| SortType::Hot.to_string());
50     SortType::from_str(&sort_query).map_err(ErrorBadRequest)
51   }
52   fn get_limit(&self) -> i64 {
53     self.limit.unwrap_or(RSS_FETCH_LIMIT)
54   }
55   fn get_page(&self) -> i64 {
56     self.page.unwrap_or(1)
57   }
58 }
59
/// Kind of feed selected by the `{type}` segment of `/feeds/{type}/{name}.xml`.
enum RequestType {
  /// `c` — posts in the community named by `{name}`.
  Community,
  /// `u` — posts created by the user named by `{name}`.
  User,
  /// `front` — subscribed posts of the user whose JWT is passed as `{name}`.
  Front,
  /// `inbox` — replies and mentions of the user whose JWT is passed as `{name}`.
  Inbox,
}
66
67 pub fn config(cfg: &mut web::ServiceConfig) {
68   cfg
69     .route("/feeds/{type}/{name}.xml", web::get().to(get_feed))
70     .route("/feeds/all.xml", web::get().to(get_all_feed))
71     .route("/feeds/local.xml", web::get().to(get_local_feed));
72 }
73
74 static RSS_NAMESPACE: Lazy<BTreeMap<String, String>> = Lazy::new(|| {
75   let mut h = BTreeMap::new();
76   h.insert(
77     "dc".to_string(),
78     rss::extension::dublincore::NAMESPACE.to_string(),
79   );
80   h
81 });
82
83 #[tracing::instrument(skip_all)]
84 async fn get_all_feed(
85   info: web::Query<Params>,
86   context: web::Data<LemmyContext>,
87 ) -> Result<HttpResponse, Error> {
88   Ok(
89     get_feed_data(
90       &context,
91       ListingType::All,
92       info.sort_type()?,
93       info.get_limit(),
94       info.get_page(),
95     )
96     .await?,
97   )
98 }
99
100 #[tracing::instrument(skip_all)]
101 async fn get_local_feed(
102   info: web::Query<Params>,
103   context: web::Data<LemmyContext>,
104 ) -> Result<HttpResponse, Error> {
105   Ok(
106     get_feed_data(
107       &context,
108       ListingType::Local,
109       info.sort_type()?,
110       info.get_limit(),
111       info.get_page(),
112     )
113     .await?,
114   )
115 }
116
117 #[tracing::instrument(skip_all)]
118 async fn get_feed_data(
119   context: &LemmyContext,
120   listing_type: ListingType,
121   sort_type: SortType,
122   limit: i64,
123   page: i64,
124 ) -> Result<HttpResponse, LemmyError> {
125   let site_view = SiteView::read_local(&mut context.pool()).await?;
126
127   let posts = PostQuery::builder()
128     .pool(&mut context.pool())
129     .listing_type(Some(listing_type))
130     .sort(Some(sort_type))
131     .limit(Some(limit))
132     .page(Some(page))
133     .build()
134     .list()
135     .await?;
136
137   let items = create_post_items(posts, &context.settings().get_protocol_and_hostname())?;
138
139   let mut channel_builder = ChannelBuilder::default();
140   channel_builder
141     .namespaces(RSS_NAMESPACE.clone())
142     .title(&format!("{} - {}", site_view.site.name, listing_type))
143     .link(context.settings().get_protocol_and_hostname())
144     .items(items);
145
146   if let Some(site_desc) = site_view.site.description {
147     channel_builder.description(&site_desc);
148   }
149
150   let rss = channel_builder.build().to_string();
151   Ok(
152     HttpResponse::Ok()
153       .content_type("application/rss+xml")
154       .body(rss),
155   )
156 }
157
158 #[tracing::instrument(skip_all)]
159 async fn get_feed(
160   req: HttpRequest,
161   info: web::Query<Params>,
162   context: web::Data<LemmyContext>,
163 ) -> Result<HttpResponse, Error> {
164   let req_type: String = req.match_info().get("type").unwrap_or("none").parse()?;
165   let param: String = req.match_info().get("name").unwrap_or("none").parse()?;
166
167   let request_type = match req_type.as_str() {
168     "u" => RequestType::User,
169     "c" => RequestType::Community,
170     "front" => RequestType::Front,
171     "inbox" => RequestType::Inbox,
172     _ => return Err(ErrorBadRequest(LemmyError::from(anyhow!("wrong_type")))),
173   };
174
175   let jwt_secret = context.secret().jwt_secret.clone();
176   let protocol_and_hostname = context.settings().get_protocol_and_hostname();
177
178   let builder = match request_type {
179     RequestType::User => {
180       get_feed_user(
181         &mut context.pool(),
182         &info.sort_type()?,
183         &info.get_limit(),
184         &info.get_page(),
185         &param,
186         &protocol_and_hostname,
187       )
188       .await
189     }
190     RequestType::Community => {
191       get_feed_community(
192         &mut context.pool(),
193         &info.sort_type()?,
194         &info.get_limit(),
195         &info.get_page(),
196         &param,
197         &protocol_and_hostname,
198       )
199       .await
200     }
201     RequestType::Front => {
202       get_feed_front(
203         &mut context.pool(),
204         &jwt_secret,
205         &info.sort_type()?,
206         &info.get_limit(),
207         &info.get_page(),
208         &param,
209         &protocol_and_hostname,
210       )
211       .await
212     }
213     RequestType::Inbox => {
214       get_feed_inbox(
215         &mut context.pool(),
216         &jwt_secret,
217         &param,
218         &protocol_and_hostname,
219       )
220       .await
221     }
222   }
223   .map_err(ErrorBadRequest)?;
224
225   let rss = builder.build().to_string();
226
227   Ok(
228     HttpResponse::Ok()
229       .content_type("application/rss+xml")
230       .body(rss),
231   )
232 }
233
234 #[tracing::instrument(skip_all)]
235 async fn get_feed_user(
236   pool: &mut DbPool<'_>,
237   sort_type: &SortType,
238   limit: &i64,
239   page: &i64,
240   user_name: &str,
241   protocol_and_hostname: &str,
242 ) -> Result<ChannelBuilder, LemmyError> {
243   let site_view = SiteView::read_local(pool).await?;
244   let person = Person::read_from_name(pool, user_name, false).await?;
245
246   let posts = PostQuery::builder()
247     .pool(pool)
248     .listing_type(Some(ListingType::All))
249     .sort(Some(*sort_type))
250     .creator_id(Some(person.id))
251     .limit(Some(*limit))
252     .page(Some(*page))
253     .build()
254     .list()
255     .await?;
256
257   let items = create_post_items(posts, protocol_and_hostname)?;
258
259   let mut channel_builder = ChannelBuilder::default();
260   channel_builder
261     .namespaces(RSS_NAMESPACE.clone())
262     .title(&format!("{} - {}", site_view.site.name, person.name))
263     .link(person.actor_id.to_string())
264     .items(items);
265
266   Ok(channel_builder)
267 }
268
269 #[tracing::instrument(skip_all)]
270 async fn get_feed_community(
271   pool: &mut DbPool<'_>,
272   sort_type: &SortType,
273   limit: &i64,
274   page: &i64,
275   community_name: &str,
276   protocol_and_hostname: &str,
277 ) -> Result<ChannelBuilder, LemmyError> {
278   let site_view = SiteView::read_local(pool).await?;
279   let community = Community::read_from_name(pool, community_name, false).await?;
280
281   let posts = PostQuery::builder()
282     .pool(pool)
283     .sort(Some(*sort_type))
284     .community_id(Some(community.id))
285     .limit(Some(*limit))
286     .page(Some(*page))
287     .build()
288     .list()
289     .await?;
290
291   let items = create_post_items(posts, protocol_and_hostname)?;
292
293   let mut channel_builder = ChannelBuilder::default();
294   channel_builder
295     .namespaces(RSS_NAMESPACE.clone())
296     .title(&format!("{} - {}", site_view.site.name, community.name))
297     .link(community.actor_id.to_string())
298     .items(items);
299
300   if let Some(community_desc) = community.description {
301     channel_builder.description(&community_desc);
302   }
303
304   Ok(channel_builder)
305 }
306
307 #[tracing::instrument(skip_all)]
308 async fn get_feed_front(
309   pool: &mut DbPool<'_>,
310   jwt_secret: &str,
311   sort_type: &SortType,
312   limit: &i64,
313   page: &i64,
314   jwt: &str,
315   protocol_and_hostname: &str,
316 ) -> Result<ChannelBuilder, LemmyError> {
317   let site_view = SiteView::read_local(pool).await?;
318   let local_user_id = LocalUserId(Claims::decode(jwt, jwt_secret)?.claims.sub);
319   let local_user = LocalUser::read(pool, local_user_id).await?;
320
321   let posts = PostQuery::builder()
322     .pool(pool)
323     .listing_type(Some(ListingType::Subscribed))
324     .local_user(Some(&local_user))
325     .sort(Some(*sort_type))
326     .limit(Some(*limit))
327     .page(Some(*page))
328     .build()
329     .list()
330     .await?;
331
332   let items = create_post_items(posts, protocol_and_hostname)?;
333
334   let mut channel_builder = ChannelBuilder::default();
335   channel_builder
336     .namespaces(RSS_NAMESPACE.clone())
337     .title(&format!("{} - Subscribed", site_view.site.name))
338     .link(protocol_and_hostname)
339     .items(items);
340
341   if let Some(site_desc) = site_view.site.description {
342     channel_builder.description(&site_desc);
343   }
344
345   Ok(channel_builder)
346 }
347
348 #[tracing::instrument(skip_all)]
349 async fn get_feed_inbox(
350   pool: &mut DbPool<'_>,
351   jwt_secret: &str,
352   jwt: &str,
353   protocol_and_hostname: &str,
354 ) -> Result<ChannelBuilder, LemmyError> {
355   let site_view = SiteView::read_local(pool).await?;
356   let local_user_id = LocalUserId(Claims::decode(jwt, jwt_secret)?.claims.sub);
357   let local_user = LocalUser::read(pool, local_user_id).await?;
358   let person_id = local_user.person_id;
359   let show_bot_accounts = local_user.show_bot_accounts;
360
361   let sort = CommentSortType::New;
362
363   let replies = CommentReplyQuery::builder()
364     .pool(pool)
365     .recipient_id(Some(person_id))
366     .my_person_id(Some(person_id))
367     .show_bot_accounts(Some(show_bot_accounts))
368     .sort(Some(sort))
369     .limit(Some(RSS_FETCH_LIMIT))
370     .build()
371     .list()
372     .await?;
373
374   let mentions = PersonMentionQuery::builder()
375     .pool(pool)
376     .recipient_id(Some(person_id))
377     .my_person_id(Some(person_id))
378     .show_bot_accounts(Some(show_bot_accounts))
379     .sort(Some(sort))
380     .limit(Some(RSS_FETCH_LIMIT))
381     .build()
382     .list()
383     .await?;
384
385   let items = create_reply_and_mention_items(replies, mentions, protocol_and_hostname)?;
386
387   let mut channel_builder = ChannelBuilder::default();
388   channel_builder
389     .namespaces(RSS_NAMESPACE.clone())
390     .title(&format!("{} - Inbox", site_view.site.name))
391     .link(format!("{protocol_and_hostname}/inbox",))
392     .items(items);
393
394   if let Some(site_desc) = site_view.site.description {
395     channel_builder.description(&site_desc);
396   }
397
398   Ok(channel_builder)
399 }
400
401 #[tracing::instrument(skip_all)]
402 fn create_reply_and_mention_items(
403   replies: Vec<CommentReplyView>,
404   mentions: Vec<PersonMentionView>,
405   protocol_and_hostname: &str,
406 ) -> Result<Vec<Item>, LemmyError> {
407   let mut reply_items: Vec<Item> = replies
408     .iter()
409     .map(|r| {
410       let reply_url = format!("{}/comment/{}", protocol_and_hostname, r.comment.id);
411       build_item(
412         &r.creator.name,
413         &r.comment.published,
414         &reply_url,
415         &r.comment.content,
416         protocol_and_hostname,
417       )
418     })
419     .collect::<Result<Vec<Item>, LemmyError>>()?;
420
421   let mut mention_items: Vec<Item> = mentions
422     .iter()
423     .map(|m| {
424       let mention_url = format!("{}/comment/{}", protocol_and_hostname, m.comment.id);
425       build_item(
426         &m.creator.name,
427         &m.comment.published,
428         &mention_url,
429         &m.comment.content,
430         protocol_and_hostname,
431       )
432     })
433     .collect::<Result<Vec<Item>, LemmyError>>()?;
434
435   reply_items.append(&mut mention_items);
436   Ok(reply_items)
437 }
438
439 #[tracing::instrument(skip_all)]
440 fn build_item(
441   creator_name: &str,
442   published: &NaiveDateTime,
443   url: &str,
444   content: &str,
445   protocol_and_hostname: &str,
446 ) -> Result<Item, LemmyError> {
447   let mut i = ItemBuilder::default();
448   i.title(format!("Reply from {creator_name}"));
449   let author_url = format!("{protocol_and_hostname}/u/{creator_name}");
450   i.author(format!(
451     "/u/{creator_name} <a href=\"{author_url}\">(link)</a>"
452   ));
453   let dt = DateTime::<Utc>::from_utc(*published, Utc);
454   i.pub_date(dt.to_rfc2822());
455   i.comments(url.to_owned());
456   let guid = GuidBuilder::default().permalink(true).value(url).build();
457   i.guid(guid);
458   i.link(url.to_owned());
459   // TODO add images
460   let html = markdown_to_html(content);
461   i.description(html);
462   Ok(i.build())
463 }
464
465 #[tracing::instrument(skip_all)]
466 fn create_post_items(
467   posts: Vec<PostView>,
468   protocol_and_hostname: &str,
469 ) -> Result<Vec<Item>, LemmyError> {
470   let mut items: Vec<Item> = Vec::new();
471
472   for p in posts {
473     let mut i = ItemBuilder::default();
474     let mut dc_extension = DublinCoreExtensionBuilder::default();
475
476     i.title(p.post.name);
477
478     dc_extension.creators(vec![p.creator.actor_id.to_string()]);
479
480     let dt = DateTime::<Utc>::from_utc(p.post.published, Utc);
481     i.pub_date(dt.to_rfc2822());
482
483     let post_url = format!("{}/post/{}", protocol_and_hostname, p.post.id);
484     i.comments(post_url.clone());
485     let guid = GuidBuilder::default()
486       .permalink(true)
487       .value(&post_url)
488       .build();
489     i.guid(guid);
490
491     let community_url = format!("{}/c/{}", protocol_and_hostname, p.community.name);
492
493     // TODO add images
494     let mut description = format!("submitted by <a href=\"{}\">{}</a> to <a href=\"{}\">{}</a><br>{} points | <a href=\"{}\">{} comments</a>",
495     p.creator.actor_id,
496     p.creator.name,
497     community_url,
498     p.community.name,
499     p.counts.score,
500     post_url,
501     p.counts.comments);
502
503     // If its a url post, add it to the description
504     if let Some(url) = p.post.url {
505       let link_html = format!("<br><a href=\"{url}\">{url}</a>");
506       description.push_str(&link_html);
507       i.link(url.to_string());
508     } else {
509       i.link(post_url.clone());
510     }
511
512     if let Some(body) = p.post.body {
513       let html = markdown_to_html(&body);
514       description.push_str(&html);
515     }
516
517     i.description(description);
518
519     i.dublin_core_ext(dc_extension.build());
520     items.push(i.build());
521   }
522
523   Ok(items)
524 }