123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496 |
- #include "tunnel.h"
- #include "server.h"
- #include "buffer.h"
- #include "socket_comm.h"
- #include <assert.h>
- #include <unistd.h>
- struct server_info
- {
- int listen_port[2];
- int listen_fd[2];
- struct client_info client[TOTAL_CONNECTION];
- int client_id[TOTAL_CONNECTION];
- int client_cnt;
- int max_fd;
- int id_idx;
- struct ring_buffer* wait_closed;
- int listen_id;
- fd_set fd_rset;
- fd_set fd_wset;
- };
/*
 * Create a TCP listening socket bound to host:port.
 *
 * host  Address to bind to; NULL or an empty string selects "0.0.0.0".
 * port  TCP port number.
 *
 * Returns the listening descriptor (after passing it to sp_nonblocking()),
 * or -1 on any failure. On failure nothing is leaked: the socket is closed
 * and the getaddrinfo() result is released on every error path.
 */
static int
server_init(const char* host, int port) {
    int fd = -1;
    int reuse = 1;
    struct addrinfo ai_hints;
    struct addrinfo *ai_list = NULL;
    char portstr[16];

    if (host == NULL || host[0] == 0) {
        host = "0.0.0.0";
    }
    /* Bounded formatting; 16 bytes is enough for any int. */
    snprintf(portstr, sizeof(portstr), "%d", port);

    memset(&ai_hints, 0, sizeof(ai_hints));
    ai_hints.ai_protocol = IPPROTO_TCP;
    ai_hints.ai_socktype = SOCK_STREAM;
    ai_hints.ai_family = AF_UNSPEC;

    if (getaddrinfo(host, portstr, &ai_hints, &ai_list) != 0) {
        return -1;
    }

    fd = socket(ai_list->ai_family, ai_list->ai_socktype, 0);
    if (fd < 0) {
        goto _failed_fd;
    }
    if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (void *)&reuse, sizeof(int)) == -1) {
        goto _failed;
    }
    if (bind(fd, (struct sockaddr *)ai_list->ai_addr, ai_list->ai_addrlen) != 0) {
        goto _failed;
    }
    if (listen(fd, 32) == -1) {
        fprintf(stderr, "%s listen port %d failed.\n", get_time(), port);
        /* Use the shared cleanup so ai_list is freed here as well. */
        goto _failed;
    }

    freeaddrinfo(ai_list);
    sp_nonblocking(fd);
    return fd;

_failed:
    close(fd);
_failed_fd:
    freeaddrinfo(ai_list);
    return -1;
}
- static int
- get_id(struct server_info* s) {
- int i, ret = -1;
- for (i = s->id_idx; i < s->id_idx + TOTAL_CONNECTION; ++i) {
- int idx = i % TOTAL_CONNECTION;
- ++s->id_idx;
- if (s->client[idx].fd == 0) {
- ret = i;
- break;
- }
- }
- return ret;
- }
- static void
- do_accept(struct server_info* s, int listen_fd) {
- union sockaddr_all addr;
- socklen_t len = sizeof(addr);
- int fd = accept(listen_fd, &addr.s, &len);
- if (fd == -1) {
- int err = errno;
- if (err != EAGAIN) {
- fprintf(stderr, "%s accept error: %s.\n", get_time(), strerror(err));
- }
- return;
- }
- if (s->client_cnt >= TOTAL_CONNECTION) {
- close(fd);
- fprintf(stderr, "%s accept error, connection max.............\n", get_time());
- return;
- }
- if (listen_fd == s->listen_fd[0]) {
- if (s->listen_id < 0) {
- //TODO:可以加个连接缓存,等待下个tunnel上来后,直接连上
- close(fd);
- fprintf(stderr, "%s accept error, no available connection for now.............\n", get_time());
- return;
- }
- }else {
- if (s->listen_id >= 0) {
- close(fd);
- fprintf(stderr, "%s accept error, tunnel only need one available.............\n", get_time());
- return;
- }
- }
- int id = get_id(s);
- assert(id != -1);
- struct ring_buffer* rb = alloc_ring_buffer(MAX_CLIENT_BUFFER);
- struct client_info* nc = s->client + id % TOTAL_CONNECTION;
- assert(nc->fd == 0);
- nc->id = id;
- nc->fd = fd;
- nc->buffer = rb;
- nc->to_id = -1;
- void* sin_addr = (addr.s.sa_family == AF_INET) ? (void*)&addr.v4.sin_addr : (void *)&addr.v6.sin6_addr;
- int sin_port = ntohs((addr.s.sa_family == AF_INET) ? addr.v4.sin_port : addr.v6.sin6_port);
- static char tmp[128];
- if (inet_ntop(addr.s.sa_family, sin_addr, tmp, sizeof(tmp))) {
- snprintf(nc->client_ip, sizeof(nc->client_ip), "%s:%d", tmp, sin_port);
- }
- fprintf(stderr, "%s client %s connected.\n", get_time(), nc->client_ip);
- s->client_id[s->client_cnt++] = id;
- FD_SET(fd, &s->fd_rset);
- if (s->max_fd < fd + 1) {
- s->max_fd = fd + 1;//TODO:最大的fd被close后是否要处理下
- }
- set_keep_alive(fd);
- sp_nonblocking(fd);
- if (listen_fd == s->listen_fd[0]) {
- nc->to_id = s->listen_id;
- struct client_info* listen_nc = s->client + nc->to_id % TOTAL_CONNECTION;
- assert(listen_nc->fd >= 0 && listen_nc->id == nc->to_id);
- listen_nc->to_id = nc->id;
- s->listen_id = -1;
- }else {
- s->listen_id = id;
- }
- }
- static void
- do_close(struct server_info* s, struct client_info* c) {
- int i;
- for (i = 0; i < s->client_cnt; ++i) {
- if (s->client_id[i] == c->id) {
- memcpy(s->client_id + i, s->client_id + i + 1, (s->client_cnt - i - 1) * sizeof(int));
- --s->client_cnt;
- break;
- }
- }
- FD_CLR(c->fd, &s->fd_rset);
- FD_CLR(c->fd, &s->fd_wset);
- close(c->fd);
- int idx = c->id % TOTAL_CONNECTION;
- assert(&s->client[idx] == c);
- if (c->id == s->listen_id) {
- s->listen_id = -1;
- }
- if (c->to_id >= 0 && s->client[c->to_id % TOTAL_CONNECTION].id == c->to_id) {
- int len;
- char* id_buffer = get_ring_buffer_write_ptr(s->wait_closed, &len);
- int id_len = sizeof(int);
- assert(id_buffer && len >= id_len);
- memcpy(id_buffer, &c->to_id, id_len);
- move_ring_buffer_write_pos(s->wait_closed, id_len);
- }
- c->to_id = -1;
- c->id = -1;
- c->fd = 0;
- free_ring_buffer(c->buffer);
- c->buffer = NULL;
- fprintf(stderr, "%s client %s disconnect.\n", get_time(), c->client_ip);
- memset(c->client_ip, 0, sizeof(c->client_ip));
- }
- /*
- static int
- try_write(struct server_info* s, struct client_info* c) {
- int len;
- char* buffer = get_ring_buffer_read_ptr(c->buffer, &len);
- if (!buffer) {
- return 0; //empty
- }
- int n = write(c->fd, buffer, len);
- if (n < 0) {
- switch (errno) {
- case EINTR:
- case EAGAIN:
- break;
- default:
- fprintf(stderr, "server: write to (id=%d) error :%s.\n", c->id, strerror(errno));
- do_close(s, c);
- return -1;
- }
- }else {
- move_ring_buffer_read_pos(c->buffer, n);
- }
- if (!is_ring_buffer_empty(c->buffer)) {
- FD_SET(c->fd, &s->fd_wset);
- }
- return 1;
- }
- */
- static int
- do_read(struct server_info* s, struct client_info* c) {
- int id = c->to_id;
- if (id < 0) {
- do_close(s, c); //only when client disconnect
- return -1;
- }
- struct client_info* to_c = s->client + id % TOTAL_CONNECTION;
- if (to_c->id != id) {
- do_close(s, c);
- return -1;
- }
- struct ring_buffer* rb = to_c->buffer;
- int len;
- char* start_buffer = get_ring_buffer_write_ptr(rb, &len);
- if (!start_buffer) {
- return 0; //buff fulled
- }
- int n = (int)read(c->fd, start_buffer, len);
- if (n == -1) {
- switch (errno) {
- case EAGAIN:
- fprintf(stderr, "%s read fd error:EAGAIN.\n", get_time());
- break;
- case EINTR:
- break;
- default:
- fprintf(stderr, "%s server: read (id=%d) error :%s.\n", get_time(), c->id, strerror(errno));
- do_close(s, c);
- return -1;
- }
- return 1;
- }
- if (n == 0) {
- do_close(s, c); //normal close
- return -1;
- }
- move_ring_buffer_write_pos(rb, n);
- FD_SET(to_c->fd, &s->fd_wset);
- if (n == len && !is_ring_buffer_empty(rb)) {
- fprintf(stderr, "%s server: read again.\n", get_time());
- return do_read(s, c);
- }
- return 1;
- }
- static int
- do_write(struct server_info* s, struct client_info* c) {
- int len;
- char* buffer = get_ring_buffer_read_ptr(c->buffer, &len);
- if (!buffer) {
- return 0;
- }
- int writed_len = 0;
- char need_break = 0;
- while (!need_break && writed_len < len) {
- int n = write(c->fd, buffer, len - writed_len);
- if (n < 0) {
- switch (errno) {
- case EINTR:
- n = 0;
- break;
- case EAGAIN:
- n = 0;
- need_break = 1;
- break;
- default:
- need_break = 1;
- fprintf(stderr, "%s socket-server: write to (id=%d) error :%s.\n", get_time(), c->id, strerror(errno));
- do_close(s, c);
- return -1;
- }
- } else {
- writed_len += n;
- buffer += n;
- }
- }
- move_ring_buffer_read_pos(c->buffer, writed_len);
- if (is_ring_buffer_empty(c->buffer)) {
- FD_CLR(c->fd, &s->fd_wset);
- } else if (writed_len == len) {
- fprintf(stderr, "%s server: write again.\n", get_time());
- return do_write(s, c);
- }
- return 1;
- }
- static void
- pre_check_close(struct server_info* s) {
- int len;
- char* id_buffer = get_ring_buffer_read_ptr(s->wait_closed, &len);
- if (!id_buffer) return;
- int id_len = sizeof(int);
- assert(len % id_len == 0);
- int tmp = len;
-
- while (len > 0) {
- int* id = (int*)id_buffer;
- int idx = *id % TOTAL_CONNECTION;
- id_buffer += id_len;
- len -= len;
- struct client_info* c = s->client + idx;
- if (c->fd > 0) {
- if (do_write(s, c) != -1) {
- do_close(s, c);
- }
- }
- }
- move_ring_buffer_read_pos(s->wait_closed, tmp);
- }
- static void*
- server_thread(void* param) {
- struct server_param *tp = (struct server_param*)param;
- int fd1 = server_init(NULL, tp->listen_port[0]);
- if (fd1 == -1) {
- return NULL;
- }
-
- int fd2 = server_init(NULL, tp->listen_port[1]);
- if (fd2 == -1) {
- close(fd1);
- return NULL;
- }
- struct server_info s;
- memset(&s, 0, sizeof(s));
- s.listen_fd[0] = fd1;
- s.listen_fd[1] = fd2;
- s.listen_port[0] = tp->listen_port[0];
- s.listen_port[1] = tp->listen_port[1];
- int tmp_fd = fd1 > fd2 ? fd1: fd2;
- tmp_fd = tp->pid > tmp_fd ? tp->pid : tmp_fd;
- s.max_fd = tmp_fd + 1;
- s.listen_id = -1;
- s.wait_closed = alloc_ring_buffer(TOTAL_CONNECTION * sizeof(int));
- FD_ZERO(&s.fd_wset);
- FD_ZERO(&s.fd_rset);
- FD_SET(fd1, &s.fd_rset);
- FD_SET(fd2, &s.fd_rset);
- FD_SET(tp->pid, &s.fd_rset);
- while (1) {
- pre_check_close(&s);
- fd_set r_set = s.fd_rset;
- fd_set w_set = s.fd_wset;
- int cnt = select(s.max_fd, &r_set, &w_set, NULL, NULL);
- if (cnt == -1) {
- fprintf(stderr, "%s select error %s.\n", get_time(), strerror(errno));
- continue;
- }
- if (FD_ISSET(s.listen_fd[1], &r_set)) {
- //accept
- --cnt;
- do_accept(&s, s.listen_fd[1]);
- }
- if (FD_ISSET(s.listen_fd[0], &r_set)) {
- //accept
- --cnt;
- do_accept(&s, s.listen_fd[0]);
- }
- int i;
- for (i = s.client_cnt - 1; i >= 0 && cnt > 0; --i) {
- int id = s.client_id[i] % TOTAL_CONNECTION;
- struct client_info* c = &s.client[id];
- int fd = c->fd;
- assert(fd > 0);
- if (FD_ISSET(fd, &r_set)) {
- //read
- --cnt;
- if (do_read(&s, c) == -1) continue;
- }
- if (FD_ISSET(fd, &w_set)) {
- //write
- --cnt;
- if (do_write(&s, c) == -1) continue;
- }
- }
- if (FD_ISSET(tp->pid, &r_set)) {
- //exit
- break;
- }
- }
- close(s.listen_fd[0]);
- close(s.listen_fd[1]);
- //try send the last buffer
- fprintf(stderr, "%s ====================SERVER: SEND LAST DATA BEGIN===================.\n", get_time());
- int i;
- for (i = s.client_cnt - 1; i >= 0; --i) {
- int id = s.client_id[i] % TOTAL_CONNECTION;
- struct client_info* c = &s.client[id];
- int fd = c->fd;
- assert(fd > 0);
- if (do_write(&s, c) != -1) {
- do_close(&s, c);
- }
- }
- fprintf(stderr, "%s ====================SERVER SEND LAST DATA END=====================.\n", get_time());
- free_ring_buffer(s.wait_closed);
- assert(s.client_cnt == 0);
- return NULL;
- }
- pthread_t
- start_server(struct server_param* tp) {
- pthread_t pid;
- if (pthread_create(&pid, NULL, server_thread, tp)) {
- fprintf(stderr, "%s Create server thread failed.\n", get_time());
- exit(1);
- return 0;
- }
- return pid;
- }
|