/**
@file udp_socket_manager.c
@brief Manager for UDP socket flow
@details Copyright (c) 2025 Acronis International GmbH
@author Denis Kopyrin (denis.kopyrin@acronis.com)
@since $Id: $
*/
#include "udp_socket_manager.h"
#include "hashtable_compat.h"
#include "hash_fast.h"
#include "memory.h"
#include "net_compat.h"
#include "net_events.h"
#include <linux/in.h>
#include <linux/in6.h>
#include <linux/jiffies.h>
#include <linux/list.h>
#include <net/ipv6.h>
#include <net/sock.h>
// log2 of the bucket count used by both flow hashtables (2^15 buckets each)
#define TABLE_SIZE_BITS 15
// Upper bound on the number of tracked flows; sweep_impl() evicts the oldest
// entries once flows_count exceeds it. Half of the bucket count.
#define TABLE_SIZE (1 << (TABLE_SIZE_BITS - 1)) // 16384
// Fallback definition for environments whose list.h does not provide
// list_first_entry_or_null(): yields NULL for an empty list, else the first entry.
#ifndef list_first_entry_or_null
#define list_first_entry_or_null(ptr, type, member) (list_empty(ptr) ? NULL : list_first_entry(ptr, type, member))
#endif
#ifdef KERNEL_MOCK
#include "mock/mock.h"
#endif
// Superblock magic of sockfs ("SOCK" in ASCII); defined here only if the
// included headers did not already provide it.
#ifndef SOCKFS_MAGIC
#define SOCKFS_MAGIC 0x534F434B
#endif
// Lifetime of a flow entry since its last refresh: 30 seconds, in jiffies
#define TTL msecs_to_jiffies(30000)
// Variable-length flow key: the socket inode followed by 'data_sz' bytes of
// packed address data (IPv4 or IPv6 layout).
// NOTE: key_hash() hashes 'sizeof(inode) + data_sz' bytes starting at 'inode',
// so 'inode' must be immediately followed by 'data' with no padding in between.
typedef struct
{
// number of valid bytes in 'data'
size_t data_sz;
// used as a key from socket
const struct inode *inode;
// packed address/port tuple; compared and hashed as raw bytes
char data[];
} hashtable_udp_socket_key_t;
// I am being overly explicit here to avoid any potential issues with padding
// Key payload for AF_INET sockets. The fields are filled by pack4() with the raw
// bytes of sockaddr_in::sin_addr / sin_port; no byte-order conversion is applied.
// PACKED is defined outside this file - presumably a packed attribute, TODO confirm.
typedef struct PACKED
{
// destination address of the datagram
uint32_t remote;
// address the socket is bound to
uint32_t local;
uint16_t remote_port;
uint16_t local_port;
} hashtable_udp_socket_key_data_ipv4_t;
// Key payload for AF_INET6 sockets. The fields are filled by pack6() with the raw
// bytes of sockaddr_in6::sin6_addr / sin6_port; no byte-order conversion is applied.
typedef struct PACKED
{
// destination address of the datagram (128 bits)
uint32_t remote[4];
// address the socket is bound to (128 bits)
uint32_t local[4];
uint16_t remote_port;
uint16_t local_port;
} hashtable_udp_socket_key_data_ipv6_t;
// Fixed-size key large enough for either address family, built by
// socket_make_key(). Callers pass '&key.header' to the generic key helpers, which
// read the payload through 'header.data', so the union has to start exactly where
// the flexible array member of the header starts.
typedef struct
{
hashtable_udp_socket_key_t header;
// only the member matching the socket family is filled; header.data_sz holds its size
union
{
hashtable_udp_socket_key_data_ipv4_t ipv4;
hashtable_udp_socket_key_data_ipv6_t ipv6;
};
} hashtable_udp_socket_key_common_t;
// One tracked UDP "flow". A node is linked into both hashtables and the LRU list
// at the same time and is removed from all three together (see erase_impl()).
typedef struct
{
// starts at 1 (node_alloc); RCU lookups take extra references with
// atomic_inc_not_zero(); the node is freed via call_rcu() when it reaches 0
atomic_t refcount;
// hashed by inode
struct hlist_node hash_inode_node;
// hashed by inode + key
struct hlist_node hash_addr_node;
// true while the node is linked into seen_flows_lru_list
bool lru_list_inserted;
// jiffies value after which sweep_impl() evicts the node
unsigned long lru_deadline;
struct list_head lru_list_node;
// used by node_put() to defer the free
struct rcu_head rcu;
// must stay the last member: ends with a flexible array
hashtable_udp_socket_key_t key;
} hashtable_udp_socket_node_t;
// State of the UDP flow tracker. Lookups on the hashtables run under RCU;
// every modification in this file is done with 'table_writer_lock' held.
typedef struct udp_socket_manager
{
struct mutex table_writer_lock;
// new flows are inserted only while this is true
bool active;
// number of nodes currently linked into the tables
ssize_t flows_count;
// Hashtable for all the flows that were seen by UDP.
// Note that this is more of a hack: UDP normally has no notion of a
// "flow" like TCP does, but to avoid spamming the client with the
// same messages, a "flow" is introduced.
DECLARE_HASHTABLE(seen_flows_addr_hashtable , TABLE_SIZE_BITS);
// The 'inode' hashtable holds the same entries as the 'addr' hashtable,
// but an inode can appear several times - sendmsg might be called with multiple addresses.
// For lookups normally 'seen_flows_addr_hashtable' should be used.
DECLARE_HASHTABLE(seen_flows_inode_hashtable, TABLE_SIZE_BITS);
// nodes in refresh order: the head is the least recently refreshed flow
struct list_head seen_flows_lru_list;
} udp_socket_manager_t;
// Single manager instance; allocated by udp_socket_manager_init() and
// released by udp_socket_manager_deinit().
static udp_socket_manager_t* global_udp_socket_manager;
// MARK: UDP socket node
// Allocates a node carrying a private copy of 'key'.
// The node starts with a single reference and is not linked into any table or
// list. Returns NULL when the allocation fails.
static hashtable_udp_socket_node_t* node_alloc(const hashtable_udp_socket_key_t* key)
{
    const size_t payload_sz = key->data_sz;
    hashtable_udp_socket_node_t* fresh = mem_alloc(sizeof(*fresh) + payload_sz);

    if (fresh) {
        // The key is copied member by member instead of with one memcpy over
        // the whole struct - the original author notes the kernel complains
        // about that form.
        fresh->key.inode = key->inode;
        fresh->key.data_sz = payload_sz;
        memcpy(fresh->key.data, key->data, payload_sz);

        fresh->lru_list_inserted = false;
        atomic_set(&fresh->refcount, 1);
    }
    return fresh;
}
// RCU callback scheduled by node_put(): releases the node that embeds 'rcu'.
static void node_rcu_free(struct rcu_head *rcu)
{
    mem_free(container_of(rcu, hashtable_udp_socket_node_t, rcu));
}
// Drops one reference to 'node'. When the last reference goes away the node is
// handed to call_rcu() so it is freed through node_rcu_free().
static void node_put(hashtable_udp_socket_node_t* node) {
    if (!atomic_dec_and_test(&node->refcount))
        return;

    call_rcu(&node->rcu, node_rcu_free);
}
// Copies the raw address and port of an IPv4 socket address into 'out_addr' and
// 'out_port'. Returns false, writing nothing, when 'sa_len' is shorter than
// struct sockaddr_in.
static bool pack4(const struct sockaddr_in* sa, int sa_len, void* out_addr, void* out_port)
{
    const bool long_enough = sa_len >= (int) sizeof(struct sockaddr_in);

    if (long_enough) {
        __builtin_memcpy(out_port, &sa->sin_port, sizeof(sa->sin_port));
        __builtin_memcpy(out_addr, &sa->sin_addr, sizeof(sa->sin_addr));
    }
    return long_enough;
}
// Copies the raw address and port of an IPv6 socket address into 'out_addr' and
// 'out_port'. Returns false, writing nothing, when 'sa_len' is below
// SIN6_LEN_RFC2133.
static bool pack6(const struct sockaddr_in6* sa6, int sa_len, void* out_addr, void* out_port)
{
    const bool long_enough = sa_len >= SIN6_LEN_RFC2133;

    if (long_enough) {
        __builtin_memcpy(out_port, &sa6->sin6_port, sizeof(sa6->sin6_port));
        __builtin_memcpy(out_addr, &sa6->sin6_addr, sizeof(sa6->sin6_addr));
    }
    return long_enough;
}
// Stores the socket address 'sa' as the remote endpoint of 'key', using the
// layout selected by 'family'. Returns false for families other than
// AF_INET / AF_INET6 or when 'sa_len' is too short for the family.
static bool pack_remote(hashtable_udp_socket_key_common_t *key, int family, void* sa, int sa_len)
{
    if (family == AF_INET)
        return pack4((struct sockaddr_in*) sa, sa_len, &key->ipv4.remote, &key->ipv4.remote_port);

    if (family == AF_INET6)
        return pack6((struct sockaddr_in6*) sa, sa_len, &key->ipv6.remote, &key->ipv6.remote_port);

    return false;
}
// Stores the socket address 'sa' as the local endpoint of 'key', using the
// layout selected by 'family'. Returns false for families other than
// AF_INET / AF_INET6 or when 'sa_len' is too short for the family.
static bool pack_local(hashtable_udp_socket_key_common_t *key, int family, void* sa, int sa_len)
{
    if (family == AF_INET)
        return pack4((struct sockaddr_in*) sa, sa_len, &key->ipv4.local, &key->ipv4.local_port);

    if (family == AF_INET6)
        return pack6((struct sockaddr_in6*) sa, sa_len, &key->ipv6.local, &key->ipv6.local_port);

    return false;
}
// MARK: UDP socket key
// Builds the flow key for a datagram sent on 'sock' with message header 'msg'.
// The key consists of the socket inode plus the remote and local endpoints.
// The remote endpoint is taken from msg->msg_name when the caller supplied one,
// otherwise it is queried from the socket via sock_to_addr().
// Returns false for socket families other than AF_INET / AF_INET6 and when
// either endpoint cannot be packed; 'key' may be partially written in that case.
// NOTE(review): msg->msg_name is interpreted according to the socket family; its
// own sa_family is not inspected here.
static bool socket_make_key(hashtable_udp_socket_key_common_t *key, struct socket *sock, struct msghdr *msg)
{
    const int family = sock->sk->sk_family;
    struct sockaddr_storage peer_addr;
    struct sockaddr_storage own_addr;
    void *remote;
    int remote_len;
    int own_len;

    if (family == AF_INET)
        key->header.data_sz = sizeof(hashtable_udp_socket_key_data_ipv4_t);
    else if (family == AF_INET6)
        key->header.data_sz = sizeof(hashtable_udp_socket_key_data_ipv6_t);
    else
        return false;

    key->header.inode = SOCK_INODE(sock);

    // Pick the source of the remote endpoint
    if (msg->msg_name) {
        remote = msg->msg_name;
        remote_len = msg->msg_namelen;
    } else {
        remote_len = sock_to_addr(sock, &peer_addr, PEER_REMOTE_ALWAYS);
        remote = &peer_addr;
    }
    if (!pack_remote(key, family, remote, remote_len))
        return false;

    own_len = sock_to_addr(sock, &own_addr, PEER_LOCAL);
    return pack_local(key, family, &own_addr, own_len);
}
// Two keys are equal when payload size, inode and payload bytes all match.
static bool key_equal(const hashtable_udp_socket_key_t* k1, const hashtable_udp_socket_key_t* k2)
{
    if (k1->data_sz != k2->data_sz)
        return false;

    if (k1->inode != k2->inode)
        return false;

    return memcmp(k1->data, k2->data, k1->data_sz) == 0;
}
// Bucket index of 'key' in the 'addr' hashtable.
// The hashed span starts at the 'inode' member and runs through the payload, and
// the result is shifted right by (64 - TABLE_SIZE_BITS).
static int key_hash(const hashtable_udp_socket_key_t* key)
{
    const size_t hashed_len = key->data_sz + sizeof(key->inode);

    return murmur_hash(&key->inode, hashed_len) >> (64 - TABLE_SIZE_BITS);
}
// Bucket index of 'inode' in the 'inode' hashtable, derived from the pointer value.
static int inode_hash(const struct inode* inode)
{
    const uint64_t inode_addr = (uint64_t) inode;

    return moremur_hash(inode_addr, TABLE_SIZE_BITS);
}
// MARK: UDP socket table
// Tells whether at least one flow of 'inode' is present in bucket 'hash' of the
// 'inode' hashtable. Different inodes may share a bucket and one inode may own
// several nodes, so the first match is enough.
// Walks the bucket with the RCU list iterator; the caller in this file holds
// rcu_read_lock() around the call.
static bool inode_is_inserted_rcu(int hash, const struct inode* inode) {
    struct hlist_head *bucket = &global_udp_socket_manager->seen_flows_inode_hashtable[hash];
    hashtable_udp_socket_node_t *cur;
    bool present = false;

    hlist_for_each_entry_rcu(cur, bucket, hash_inode_node) {
        if (cur->key.inode != inode)
            continue;
        present = true;
        break;
    }
    return present;
}
// RCU lookup of 'key' in bucket 'hash' of the 'addr' hashtable.
// On a match the node is returned with one extra reference that the caller must
// release with node_put(). Returns NULL when there is no match, or when the
// matching node has already dropped to a zero refcount (it is being freed).
static hashtable_udp_socket_node_t* find_ref_rcu(int hash, const hashtable_udp_socket_key_t* key) {
    struct hlist_head *bucket = &global_udp_socket_manager->seen_flows_addr_hashtable[hash];
    hashtable_udp_socket_node_t *cand;

    hlist_for_each_entry_rcu(cand, bucket, hash_addr_node) {
        if (key_equal(&cand->key, key))
            return atomic_inc_not_zero(&cand->refcount) ? cand : NULL;
    }
    return NULL;
}
// Plain (non-RCU) lookup of 'key' in bucket 'hash' of the 'addr' hashtable.
// No reference is taken on the returned node; the only caller in this file holds
// table_writer_lock. Returns NULL when the key is not present.
static hashtable_udp_socket_node_t* find(int hash, const hashtable_udp_socket_key_t* key) {
    struct hlist_head *bucket = &global_udp_socket_manager->seen_flows_addr_hashtable[hash];
    hashtable_udp_socket_node_t *cand;
    hashtable_udp_socket_node_t *match = NULL;

    hlist_for_each_entry(cand, bucket, hash_addr_node) {
        if (!key_equal(&cand->key, key))
            continue;
        match = cand;
        break;
    }
    return match;
}
// Removes 'node' from both hashtables and from the LRU list, decrements the flow
// counter and drops one reference on the node.
// The node must currently be linked into all three containers. Every caller in
// this file holds table_writer_lock.
static void erase_impl(hashtable_udp_socket_node_t *node)
{
// Unpublish from the RCU-read tables; readers already walking a bucket may
// still see the node until they finish.
hash_del_rcu(&node->hash_inode_node);
hash_del_rcu(&node->hash_addr_node);
list_del(&node->lru_list_node);
// lets refresh_impl() recognise a node that was erased while a caller still held a reference
node->lru_list_inserted = false;
global_udp_socket_manager->flows_count--;
// must stay last: this may schedule the node for freeing
node_put(node);
}
// Erases every flow that belongs to 'inode' from bucket 'hash' of the 'inode'
// hashtable. Uses the removal-safe iterator because erase_impl() unlinks the
// current entry. The caller in this file holds table_writer_lock.
static void inode_erase_all_impl(int hash, const struct inode* inode) {
    struct hlist_head *bucket = &global_udp_socket_manager->seen_flows_inode_hashtable[hash];
    hashtable_udp_socket_node_t *cur;
    struct hlist_node* next;

    hlist_for_each_entry_safe(cur, next, bucket, hash_inode_node) {
        if (cur->key.inode != inode)
            continue;
        erase_impl(cur);
    }
}
// Marks 'node' as just used: pushes its deadline TTL jiffies into the future and
// moves it to the tail of the LRU list. Does nothing for a node that is not on
// the LRU list (it has been erased in the meantime).
// Every caller in this file holds table_writer_lock.
static void refresh_impl(hashtable_udp_socket_node_t* node)
{
    if (!node->lru_list_inserted)
        return;

    list_del(&node->lru_list_node);
    list_add_tail(&node->lru_list_node, &global_udp_socket_manager->seen_flows_lru_list);
    node->lru_deadline = jiffies + TTL;
}
// MARK: UDP socket manager
int udp_socket_manager_init(void)
{
global_udp_socket_manager = vmem_alloc(sizeof(udp_socket_manager_t));
if (!global_udp_socket_manager)
return -ENOMEM;
mutex_init(&global_udp_socket_manager->table_writer_lock);
global_udp_socket_manager->active = false;
global_udp_socket_manager->flows_count = 0;
hash_init(global_udp_socket_manager->seen_flows_inode_hashtable);
hash_init(global_udp_socket_manager->seen_flows_addr_hashtable);
INIT_LIST_HEAD(&global_udp_socket_manager->seen_flows_lru_list);
return 0;
}
// Releases the global manager. Safe to call when the manager was never
// allocated or has already been released.
// NOTE(review): nodes still linked into the tables are not released here, and
// nodes released earlier are freed through call_rcu(); presumably the caller
// runs udp_socket_manager_deactivate() first and waits for pending RCU
// callbacks (rcu_barrier()) before the module goes away - confirm.
void udp_socket_manager_deinit(void)
{
    if (!global_udp_socket_manager)
        return;

    vmem_free(global_udp_socket_manager);
    // Clear the pointer so the guard above really protects against a second
    // call instead of freeing the same memory twice.
    global_udp_socket_manager = NULL;
}
// Enables flow tracking: from now on udp_socket_manager_sendmsg() may insert
// new flows. The flag is flipped under the writer lock.
void udp_socket_manager_activate(void)
{
    struct mutex *lock = &global_udp_socket_manager->table_writer_lock;

    mutex_lock(lock);
    global_udp_socket_manager->active = true;
    mutex_unlock(lock);
}
// Disables flow tracking and drops every tracked flow.
// Does nothing when the manager is already inactive.
void udp_socket_manager_deactivate(void)
{
    hashtable_udp_socket_node_t *victim;

    mutex_lock(&global_udp_socket_manager->table_writer_lock);
    if (!global_udp_socket_manager->active)
        goto unlock;

    // Every tracked node sits on the LRU list, so draining it empties the tables
    while ((victim = list_first_entry_or_null(&global_udp_socket_manager->seen_flows_lru_list, hashtable_udp_socket_node_t, lru_list_node)) != NULL)
        erase_impl(victim);

    global_udp_socket_manager->active = false;

unlock:
    mutex_unlock(&global_udp_socket_manager->table_writer_lock);
}
// Hook for inode destruction: forgets every flow recorded for 'inode'.
// Inodes without a superblock or whose superblock magic is not SOCKFS_MAGIC are
// ignored. A lock-free RCU probe is done first so that the writer lock is taken
// only when the inode actually has tracked flows.
void udp_socket_manager_inode_free_security(const struct inode *inode)
{
    bool tracked;
    int bucket;

    if (!inode->i_sb || inode->i_sb->s_magic != SOCKFS_MAGIC)
        return;

    // Same computation as inode_hash(); the two must stay in sync
    bucket = moremur_hash((uint64_t) inode, TABLE_SIZE_BITS);

    rcu_read_lock();
    tracked = inode_is_inserted_rcu(bucket, inode);
    rcu_read_unlock();

    if (tracked) {
        mutex_lock(&global_udp_socket_manager->table_writer_lock);
        inode_erase_all_impl(bucket, inode);
        mutex_unlock(&global_udp_socket_manager->table_writer_lock);
    }
}
// Evicts stale flows. First removes entries from the head of the LRU list while
// their deadline has passed, then keeps removing the oldest entries while the
// number of flows exceeds TABLE_SIZE.
// Every caller in this file holds table_writer_lock.
static void sweep_impl(void)
{
    struct list_head *lru = &global_udp_socket_manager->seen_flows_lru_list;
    hashtable_udp_socket_node_t *oldest;

    // Pass 1: expired entries; the list is in refresh order, so stop at the
    // first entry that is still alive
    for (;;) {
        oldest = list_first_entry_or_null(lru, hashtable_udp_socket_node_t, lru_list_node);
        if (!oldest || !time_after(jiffies, oldest->lru_deadline))
            break;
        erase_impl(oldest);
    }

    // Pass 2: enforce the capacity limit
    for (;;) {
        if (global_udp_socket_manager->flows_count <= TABLE_SIZE)
            break;
        oldest = list_first_entry_or_null(lru, hashtable_udp_socket_node_t, lru_list_node);
        if (!oldest)
            break;
        erase_impl(oldest);
    }
}
// Hook for a UDP sendmsg on 'sock'. Reports the datagram through
// net_event_sendmsg_udp() only when its (inode, remote, local) flow is not yet
// tracked; otherwise the existing flow is merely refreshed.
//
// task_info - passed through to net_event_sendmsg_udp()
// sock, msg - socket and message header the flow key is derived from
// size      - currently unused
//
// Nothing is reported when the key cannot be built, when node allocation
// fails, or when the manager is not active. The event is emitted after the
// writer lock has been released.
void udp_socket_manager_sendmsg(task_info_t* task_info, struct socket *sock, struct msghdr *msg, int size)
{
    hashtable_udp_socket_key_common_t flow_key;
    hashtable_udp_socket_node_t *known;
    hashtable_udp_socket_node_t *fresh;
    bool first_sighting = false;
    int addr_bucket;
    int inode_bucket;

    // TODO: Analyze QUIC + DNS
    (void) size;

    if (!socket_make_key(&flow_key, sock, msg))
        return;

    addr_bucket = key_hash(&flow_key.header);
    inode_bucket = inode_hash(SOCK_INODE(sock));

    // Fast path: the flow is already tracked, just refresh it
    rcu_read_lock();
    known = find_ref_rcu(addr_bucket, &flow_key.header);
    rcu_read_unlock();
    if (known) {
        mutex_lock(&global_udp_socket_manager->table_writer_lock);
        refresh_impl(known);
        sweep_impl();
        mutex_unlock(&global_udp_socket_manager->table_writer_lock);
        node_put(known);
        return;
    }

    // Slow path: prepare a node outside the lock, then insert it unless someone
    // else inserted the same flow in the meantime
    fresh = node_alloc(&flow_key.header);
    if (!fresh)
        return;

    mutex_lock(&global_udp_socket_manager->table_writer_lock);
    if (!global_udp_socket_manager->active) {
        node_put(fresh);
    } else {
        sweep_impl();
        known = find(addr_bucket, &flow_key.header);
        if (known) {
            refresh_impl(known);
            node_put(fresh);
        } else {
            hlist_add_head_rcu(&fresh->hash_inode_node, &global_udp_socket_manager->seen_flows_inode_hashtable[inode_bucket]);
            hlist_add_head_rcu(&fresh->hash_addr_node, &global_udp_socket_manager->seen_flows_addr_hashtable[addr_bucket]);
            fresh->lru_deadline = jiffies + TTL;
            list_add_tail(&fresh->lru_list_node, &global_udp_socket_manager->seen_flows_lru_list);
            fresh->lru_list_inserted = true;
            global_udp_socket_manager->flows_count++;
            first_sighting = true;
        }
    }
    mutex_unlock(&global_udp_socket_manager->table_writer_lock);

    if (first_sighting)
        net_event_sendmsg_udp(task_info, sock, msg);
}