architex/lib/matrix_server/state_resolution.ex
Pim Kunis 6f8c224d50 Implement client joined rooms endpoint
Track which rooms a local account has joined
Add some documentation to modules
2021-08-23 12:59:12 +02:00

302 lines
8.9 KiB
Elixir

defmodule MatrixServer.StateResolution do
@moduledoc """
Functions for resolving the state of a Matrix room.
Currently, only state resolution from room version 2 is supported,
see [the Matrix docs](https://spec.matrix.org/unstable/rooms/v2/).
The current implementation of the state resolution algorithm performs
rather badly.
Each time state is resolved, all events in the room are fetched from
the database and loaded into memory.
This is mostly so I didn't have to worry about fetching events from the
database when developing this initial implementation.
Then, the state is calculated using the new event's previous events and auth
events.
To prevent loading all events into memory, and calculating the whole state each
time, we should make snapshots of the state of a room at regular intervals.
It looks like Dendrite does this too.
"""
import Ecto.Query
alias MatrixServer.{Repo, Event, Room}
alias MatrixServer.StateResolution.Authorization
@type state_set :: map()
def resolve(event), do: resolve(event, true)
def resolve(%Event{room_id: room_id} = event, apply_state) do
room_events =
Event
|> where([e], e.room_id == ^room_id)
|> select([e], {e.event_id, e})
|> Repo.all()
|> Enum.into(%{})
resolve(event, room_events, apply_state)
end
def resolve(
%Event{type: type, state_key: state_key, prev_events: prev_event_ids} = event,
room_events,
apply_state
) do
state_sets =
prev_event_ids
|> Enum.map(&room_events[&1])
|> Enum.map(&resolve(&1, room_events, true))
resolved_state = do_resolve(state_sets, room_events)
if apply_state and Event.is_state_event(event) do
Map.put(resolved_state, {type, state_key}, event)
else
resolved_state
end
end
def resolve_forward_extremities(%Event{room_id: room_id}) do
room_events =
Event
|> where([e], e.room_id == ^room_id)
|> select([e], {e.event_id, e})
|> Repo.all()
|> Enum.into(%{})
Event
|> where([e], e.room_id == ^room_id)
|> join(:inner, [e], r in Room, on: e.room_id == r.id)
|> where([e, r], e.event_id == fragment("ANY(?)", r.forward_extremities))
|> Repo.all()
|> Enum.map(&resolve/1)
|> do_resolve(room_events)
end
def full_auth_chain(events, room_events) do
events
|> Enum.map(&auth_chain(&1, room_events))
|> Enum.reduce(MapSet.new(), &MapSet.union/2)
end
def update_state_set(
%Event{type: event_type, state_key: state_key} = event,
state_set
) do
Map.put(state_set, {event_type, state_key}, event)
end
defp do_resolve([], _), do: %{}
defp do_resolve(state_sets, room_events) do
{unconflicted_state_map, conflicted_state_set} = calculate_conflict(state_sets, room_events)
if MapSet.size(conflicted_state_set) == 0 do
unconflicted_state_map
else
do_resolve(state_sets, room_events, unconflicted_state_map, conflicted_state_set)
end
end
defp do_resolve(state_sets, room_events, unconflicted_state_map, conflicted_state_set) do
full_conflicted_set =
MapSet.union(conflicted_state_set, auth_difference(state_sets, room_events))
conflicted_control_event_ids =
full_conflicted_set
|> Enum.filter(&Event.is_control_event(room_events[&1]))
|> MapSet.new()
conflicted_control_events_with_auth_ids =
conflicted_control_event_ids
|> Enum.map(&room_events[&1])
|> full_auth_chain(room_events)
|> MapSet.intersection(full_conflicted_set)
|> MapSet.union(conflicted_control_event_ids)
sorted_control_events =
conflicted_control_events_with_auth_ids
|> Enum.map(&room_events[&1])
|> Enum.sort(rev_top_pow_order(room_events))
partial_resolved_state =
iterative_auth_checks(sorted_control_events, unconflicted_state_map, room_events)
resolved_power_levels = partial_resolved_state[{"m.room.power_levels", ""}]
conflicted_control_events_with_auth_ids
|> MapSet.difference(full_conflicted_set)
|> Enum.sort(mainline_order(resolved_power_levels, room_events))
|> Enum.map(&room_events[&1])
|> iterative_auth_checks(partial_resolved_state, room_events)
|> Map.merge(unconflicted_state_map)
end
defp calculate_conflict(state_sets, room_events) do
{unconflicted, conflicted} =
state_sets
|> Enum.flat_map(&Map.keys/1)
|> MapSet.new()
|> Enum.into(%{}, fn state_pair ->
events =
Enum.map(state_sets, fn
state_set when is_map_key(state_set, state_pair) -> state_set[state_pair].event_id
_ -> nil
end)
|> MapSet.new()
{state_pair, events}
end)
|> Enum.split_with(fn {_, event_ids} ->
MapSet.size(event_ids) == 1
end)
unconflicted_state_map =
Enum.into(unconflicted, %{}, fn {state_pair, event_ids} ->
event_id = MapSet.to_list(event_ids) |> hd()
{state_pair, room_events[event_id]}
end)
conflicted_state_set =
Enum.reduce(conflicted, MapSet.new(), fn {_, events}, acc ->
MapSet.union(acc, events)
end)
|> MapSet.delete(nil)
{unconflicted_state_map, conflicted_state_set}
end
defp auth_difference(state_sets, room_events) do
full_auth_chains =
Enum.map(state_sets, fn state_set ->
state_set
|> Map.values()
|> full_auth_chain(room_events)
end)
auth_chain_union = Enum.reduce(full_auth_chains, MapSet.new(), &MapSet.union/2)
auth_chain_intersection = Enum.reduce(full_auth_chains, MapSet.new(), &MapSet.intersection/2)
MapSet.difference(auth_chain_union, auth_chain_intersection)
end
defp auth_chain(%Event{auth_events: auth_events}, room_events) do
auth_events
|> Enum.map(&room_events[&1])
|> Enum.reduce(MapSet.new(), fn %Event{event_id: auth_event_id} = auth_event, acc ->
auth_event
|> auth_chain(room_events)
|> MapSet.union(acc)
|> MapSet.put(auth_event_id)
end)
end
defp rev_top_pow_order(room_events) do
fn %Event{origin_server_ts: timestamp1, event_id: event_id1} = event1,
%Event{origin_server_ts: timestamp2, event_id: event_id2} = event2 ->
power1 = get_power_level(event1, room_events)
power2 = get_power_level(event2, room_events)
if power1 == power2 do
if timestamp1 == timestamp2 do
event_id1 <= event_id2
else
timestamp1 < timestamp2
end
else
power1 < power2
end
end
end
defp get_power_level(%Event{sender: sender, auth_events: auth_event_ids}, room_events) do
pl_event_id =
Enum.find(auth_event_ids, fn id ->
room_events[id].type == "m.room.power_levels"
end)
# TODO: refactor
case room_events[pl_event_id] do
%Event{content: %{"users" => pl_users}} -> Map.get(pl_users, to_string(sender), 0)
nil -> 0
end
end
defp mainline_order(event, room_events) do
mainline_map =
event
|> mainline(room_events)
|> Enum.with_index()
|> Enum.into(%{})
fn event_id1, event_id2 ->
%Event{origin_server_ts: timestamp1} = event1 = room_events[event_id1]
%Event{origin_server_ts: timestamp2} = event2 = room_events[event_id2]
mainline_depth1 = get_mainline_depth(mainline_map, event1, room_events)
mainline_depth2 = get_mainline_depth(mainline_map, event2, room_events)
if mainline_depth1 == mainline_depth2 do
if timestamp1 == timestamp2 do
event_id1 <= event_id2
else
timestamp1 < timestamp2
end
else
mainline_depth1 < mainline_depth2
end
end
end
defp get_mainline_depth(mainline_map, event, room_events) do
mainline = mainline(event, room_events)
result =
Enum.find_value(mainline, fn mainline_event ->
if Map.has_key?(mainline_map, mainline_event) do
{:ok, mainline_map[mainline_event]}
else
nil
end
end)
case result do
{:ok, index} -> -index
nil -> nil
end
end
defp mainline(event, room_events) do
event
|> mainline([], room_events)
|> Enum.reverse()
end
defp mainline(%Event{auth_events: auth_event_ids} = event, acc, room_events) do
pl_event_id =
Enum.find(auth_event_ids, fn id ->
room_events[id].type == "m.room.power_levels"
end)
case room_events[pl_event_id] do
%Event{} = pl_event -> mainline(pl_event, [event | acc], room_events)
nil -> [event | acc]
end
end
defp iterative_auth_checks(events, state_set, room_events) do
Enum.reduce(events, state_set, fn event, acc ->
if authorized?(event, acc, room_events), do: update_state_set(event, acc), else: acc
end)
end
defp authorized?(%Event{auth_events: auth_event_ids} = event, state_set, room_events) do
state_set =
auth_event_ids
|> Enum.map(&room_events[&1])
|> Enum.reduce(state_set, &update_state_set/2)
Authorization.authorized?(event, state_set)
end
end