server.c 10 KB


  #include "tunnel.h"
  #include "server.h"
  #include "buffer.h"
  #include "socket_comm.h"
  #include <assert.h>
  #include <limits.h>
  #include <string.h>
  #include <unistd.h>
  7. struct server_info
  8. {
  9. int listen_port[2];
  10. int listen_fd[2];
  11. struct client_info client[TOTAL_CONNECTION];
  12. int client_id[TOTAL_CONNECTION];
  13. int client_cnt;
  14. int max_fd;
  15. int id_idx;
  16. struct ring_buffer* wait_closed;
  17. int listen_id;
  18. fd_set fd_rset;
  19. fd_set fd_wset;
  20. };
  /*
   * Create a non-blocking TCP socket listening on host:port.
   *
   * host: address to bind; NULL or "" means "0.0.0.0" (all IPv4 interfaces).
   * port: TCP port number.
   *
   * Only the first address returned by getaddrinfo() is tried. SO_REUSEADDR is
   * set so the port can be rebound while old connections are in TIME_WAIT.
   * Returns the listening descriptor, or -1 on failure (reason logged to stderr).
   * Every failure path releases both the socket and the getaddrinfo() list.
   */
  static int
  server_init(const char* host, int port) {
      int fd;
      int err;
      int reuse = 1;
      struct addrinfo ai_hints;
      struct addrinfo *ai_list = NULL;
      char portstr[16];
      if (host == NULL || host[0] == 0) {
          host = "0.0.0.0";
      }
      snprintf(portstr, sizeof(portstr), "%d", port);
      memset(&ai_hints, 0, sizeof(ai_hints));
      ai_hints.ai_protocol = IPPROTO_TCP;
      ai_hints.ai_socktype = SOCK_STREAM;
      ai_hints.ai_family = AF_UNSPEC;
      int status = getaddrinfo(host, portstr, &ai_hints, &ai_list);
      if (status != 0) {
          fprintf(stderr, "%s resolve %s:%d failed: %s.\n", get_time(), host, port, gai_strerror(status));
          return -1;
      }
      fd = socket(ai_list->ai_family, ai_list->ai_socktype, 0);
      if (fd < 0) {
          err = errno;  /* save before get_time() can clobber it */
          fprintf(stderr, "%s create socket for port %d failed: %s.\n", get_time(), port, strerror(err));
          goto failed_free;
      }
      if (setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, (void *)&reuse, sizeof(reuse)) == -1) {
          err = errno;
          fprintf(stderr, "%s setsockopt port %d failed: %s.\n", get_time(), port, strerror(err));
          goto failed_close;
      }
      if (bind(fd, ai_list->ai_addr, ai_list->ai_addrlen) != 0) {
          err = errno;
          fprintf(stderr, "%s bind port %d failed: %s.\n", get_time(), port, strerror(err));
          goto failed_close;
      }
      if (listen(fd, 32) == -1) {
          fprintf(stderr, "%s listen port %d failed.\n", get_time(), port);
          /* previously returned here without freeaddrinfo(), leaking ai_list */
          goto failed_close;
      }
      freeaddrinfo(ai_list);
      sp_nonblocking(fd);
      return fd;
  failed_close:
      close(fd);
  failed_free:
      freeaddrinfo(ai_list);
      return -1;
  }
  66. static int
  67. get_id(struct server_info* s) {
  68. int i, ret = -1;
  69. for (i = s->id_idx; i < s->id_idx + TOTAL_CONNECTION; ++i) {
  70. int idx = i % TOTAL_CONNECTION;
  71. ++s->id_idx;
  72. if (s->client[idx].fd == 0) {
  73. ret = i;
  74. break;
  75. }
  76. }
  77. return ret;
  78. }
  79. static void
  80. do_accept(struct server_info* s, int listen_fd) {
  81. union sockaddr_all addr;
  82. socklen_t len = sizeof(addr);
  83. int fd = accept(listen_fd, &addr.s, &len);
  84. if (fd == -1) {
  85. int err = errno;
  86. if (err != EAGAIN) {
  87. fprintf(stderr, "%s accept error: %s.\n", get_time(), strerror(err));
  88. }
  89. return;
  90. }
  91. if (s->client_cnt >= TOTAL_CONNECTION) {
  92. close(fd);
  93. fprintf(stderr, "%s accept error, connection max.............\n", get_time());
  94. return;
  95. }
  96. if (listen_fd == s->listen_fd[0]) {
  97. if (s->listen_id < 0) {
  98. //TODO:可以加个连接缓存,等待下个tunnel上来后,直接连上
  99. close(fd);
  100. fprintf(stderr, "%s accept error, no available connection for now.............\n", get_time());
  101. return;
  102. }
  103. }else {
  104. if (s->listen_id >= 0) {
  105. close(fd);
  106. fprintf(stderr, "%s accept error, tunnel only need one available.............\n", get_time());
  107. return;
  108. }
  109. }
  110. int id = get_id(s);
  111. assert(id != -1);
  112. struct ring_buffer* rb = alloc_ring_buffer(MAX_CLIENT_BUFFER);
  113. struct client_info* nc = s->client + id % TOTAL_CONNECTION;
  114. assert(nc->fd == 0);
  115. nc->id = id;
  116. nc->fd = fd;
  117. nc->buffer = rb;
  118. nc->to_id = -1;
  119. void* sin_addr = (addr.s.sa_family == AF_INET) ? (void*)&addr.v4.sin_addr : (void *)&addr.v6.sin6_addr;
  120. int sin_port = ntohs((addr.s.sa_family == AF_INET) ? addr.v4.sin_port : addr.v6.sin6_port);
  121. static char tmp[128];
  122. if (inet_ntop(addr.s.sa_family, sin_addr, tmp, sizeof(tmp))) {
  123. snprintf(nc->client_ip, sizeof(nc->client_ip), "%s:%d", tmp, sin_port);
  124. }
  125. fprintf(stderr, "%s client %s connected.\n", get_time(), nc->client_ip);
  126. s->client_id[s->client_cnt++] = id;
  127. FD_SET(fd, &s->fd_rset);
  128. if (s->max_fd < fd + 1) {
  129. s->max_fd = fd + 1;//TODO:最大的fd被close后是否要处理下
  130. }
  131. set_keep_alive(fd);
  132. sp_nonblocking(fd);
  133. if (listen_fd == s->listen_fd[0]) {
  134. nc->to_id = s->listen_id;
  135. struct client_info* listen_nc = s->client + nc->to_id % TOTAL_CONNECTION;
  136. assert(listen_nc->fd >= 0 && listen_nc->id == nc->to_id);
  137. listen_nc->to_id = nc->id;
  138. s->listen_id = -1;
  139. }else {
  140. s->listen_id = id;
  141. }
  142. }
  143. static void
  144. do_close(struct server_info* s, struct client_info* c) {
  145. int i;
  146. for (i = 0; i < s->client_cnt; ++i) {
  147. if (s->client_id[i] == c->id) {
  148. memcpy(s->client_id + i, s->client_id + i + 1, (s->client_cnt - i - 1) * sizeof(int));
  149. --s->client_cnt;
  150. break;
  151. }
  152. }
  153. FD_CLR(c->fd, &s->fd_rset);
  154. FD_CLR(c->fd, &s->fd_wset);
  155. close(c->fd);
  156. int idx = c->id % TOTAL_CONNECTION;
  157. assert(&s->client[idx] == c);
  158. if (c->id == s->listen_id) {
  159. s->listen_id = -1;
  160. }
  161. if (c->to_id >= 0 && s->client[c->to_id % TOTAL_CONNECTION].id == c->to_id) {
  162. int len;
  163. char* id_buffer = get_ring_buffer_write_ptr(s->wait_closed, &len);
  164. int id_len = sizeof(int);
  165. assert(id_buffer && len >= id_len);
  166. memcpy(id_buffer, &c->to_id, id_len);
  167. move_ring_buffer_write_pos(s->wait_closed, id_len);
  168. }
  169. c->to_id = -1;
  170. c->id = -1;
  171. c->fd = 0;
  172. free_ring_buffer(c->buffer);
  173. c->buffer = NULL;
  174. fprintf(stderr, "%s client %s disconnect.\n", get_time(), c->client_ip);
  175. memset(c->client_ip, 0, sizeof(c->client_ip));
  176. }
  177. /*
  178. static int
  179. try_write(struct server_info* s, struct client_info* c) {
  180. int len;
  181. char* buffer = get_ring_buffer_read_ptr(c->buffer, &len);
  182. if (!buffer) {
  183. return 0; //empty
  184. }
  185. int n = write(c->fd, buffer, len);
  186. if (n < 0) {
  187. switch (errno) {
  188. case EINTR:
  189. case EAGAIN:
  190. break;
  191. default:
  192. fprintf(stderr, "server: write to (id=%d) error :%s.\n", c->id, strerror(errno));
  193. do_close(s, c);
  194. return -1;
  195. }
  196. }else {
  197. move_ring_buffer_read_pos(c->buffer, n);
  198. }
  199. if (!is_ring_buffer_empty(c->buffer)) {
  200. FD_SET(c->fd, &s->fd_wset);
  201. }
  202. return 1;
  203. }
  204. */
  205. static int
  206. do_read(struct server_info* s, struct client_info* c) {
  207. int id = c->to_id;
  208. if (id < 0) {
  209. do_close(s, c); //only when client disconnect
  210. return -1;
  211. }
  212. struct client_info* to_c = s->client + id % TOTAL_CONNECTION;
  213. if (to_c->id != id) {
  214. do_close(s, c);
  215. return -1;
  216. }
  217. struct ring_buffer* rb = to_c->buffer;
  218. int len;
  219. char* start_buffer = get_ring_buffer_write_ptr(rb, &len);
  220. if (!start_buffer) {
  221. return 0; //buff fulled
  222. }
  223. int n = (int)read(c->fd, start_buffer, len);
  224. if (n == -1) {
  225. switch (errno) {
  226. case EAGAIN:
  227. fprintf(stderr, "%s read fd error:EAGAIN.\n", get_time());
  228. break;
  229. case EINTR:
  230. break;
  231. default:
  232. fprintf(stderr, "%s server: read (id=%d) error :%s.\n", get_time(), c->id, strerror(errno));
  233. do_close(s, c);
  234. return -1;
  235. }
  236. return 1;
  237. }
  238. if (n == 0) {
  239. do_close(s, c); //normal close
  240. return -1;
  241. }
  242. move_ring_buffer_write_pos(rb, n);
  243. FD_SET(to_c->fd, &s->fd_wset);
  244. if (n == len && !is_ring_buffer_empty(rb)) {
  245. fprintf(stderr, "%s server: read again.\n", get_time());
  246. return do_read(s, c);
  247. }
  248. return 1;
  249. }
  250. static int
  251. do_write(struct server_info* s, struct client_info* c) {
  252. int len;
  253. char* buffer = get_ring_buffer_read_ptr(c->buffer, &len);
  254. if (!buffer) {
  255. return 0;
  256. }
  257. int writed_len = 0;
  258. char need_break = 0;
  259. while (!need_break && writed_len < len) {
  260. int n = write(c->fd, buffer, len - writed_len);
  261. if (n < 0) {
  262. switch (errno) {
  263. case EINTR:
  264. n = 0;
  265. break;
  266. case EAGAIN:
  267. n = 0;
  268. need_break = 1;
  269. break;
  270. default:
  271. need_break = 1;
  272. fprintf(stderr, "%s socket-server: write to (id=%d) error :%s.\n", get_time(), c->id, strerror(errno));
  273. do_close(s, c);
  274. return -1;
  275. }
  276. } else {
  277. writed_len += n;
  278. buffer += n;
  279. }
  280. }
  281. move_ring_buffer_read_pos(c->buffer, writed_len);
  282. if (is_ring_buffer_empty(c->buffer)) {
  283. FD_CLR(c->fd, &s->fd_wset);
  284. } else if (writed_len == len) {
  285. fprintf(stderr, "%s server: write again.\n", get_time());
  286. return do_write(s, c);
  287. }
  288. return 1;
  289. }
  290. static void
  291. pre_check_close(struct server_info* s) {
  292. int len;
  293. char* id_buffer = get_ring_buffer_read_ptr(s->wait_closed, &len);
  294. if (!id_buffer) return;
  295. int id_len = sizeof(int);
  296. assert(len % id_len == 0);
  297. int tmp = len;
  298. while (len > 0) {
  299. int* id = (int*)id_buffer;
  300. int idx = *id % TOTAL_CONNECTION;
  301. id_buffer += id_len;
  302. len -= len;
  303. struct client_info* c = s->client + idx;
  304. if (c->fd > 0) {
  305. if (do_write(s, c) != -1) {
  306. do_close(s, c);
  307. }
  308. }
  309. }
  310. move_ring_buffer_read_pos(s->wait_closed, tmp);
  311. }
  312. static void*
  313. server_thread(void* param) {
  314. struct server_param *tp = (struct server_param*)param;
  315. int fd1 = server_init(NULL, tp->listen_port[0]);
  316. if (fd1 == -1) {
  317. return NULL;
  318. }
  319. int fd2 = server_init(NULL, tp->listen_port[1]);
  320. if (fd2 == -1) {
  321. close(fd1);
  322. return NULL;
  323. }
  324. struct server_info s;
  325. memset(&s, 0, sizeof(s));
  326. s.listen_fd[0] = fd1;
  327. s.listen_fd[1] = fd2;
  328. s.listen_port[0] = tp->listen_port[0];
  329. s.listen_port[1] = tp->listen_port[1];
  330. int tmp_fd = fd1 > fd2 ? fd1: fd2;
  331. tmp_fd = tp->pid > tmp_fd ? tp->pid : tmp_fd;
  332. s.max_fd = tmp_fd + 1;
  333. s.listen_id = -1;
  334. s.wait_closed = alloc_ring_buffer(TOTAL_CONNECTION * sizeof(int));
  335. FD_ZERO(&s.fd_wset);
  336. FD_ZERO(&s.fd_rset);
  337. FD_SET(fd1, &s.fd_rset);
  338. FD_SET(fd2, &s.fd_rset);
  339. FD_SET(tp->pid, &s.fd_rset);
  340. while (1) {
  341. pre_check_close(&s);
  342. fd_set r_set = s.fd_rset;
  343. fd_set w_set = s.fd_wset;
  344. int cnt = select(s.max_fd, &r_set, &w_set, NULL, NULL);
  345. if (cnt == -1) {
  346. fprintf(stderr, "%s select error %s.\n", get_time(), strerror(errno));
  347. continue;
  348. }
  349. if (FD_ISSET(s.listen_fd[1], &r_set)) {
  350. //accept
  351. --cnt;
  352. do_accept(&s, s.listen_fd[1]);
  353. }
  354. if (FD_ISSET(s.listen_fd[0], &r_set)) {
  355. //accept
  356. --cnt;
  357. do_accept(&s, s.listen_fd[0]);
  358. }
  359. int i;
  360. for (i = s.client_cnt - 1; i >= 0 && cnt > 0; --i) {
  361. int id = s.client_id[i] % TOTAL_CONNECTION;
  362. struct client_info* c = &s.client[id];
  363. int fd = c->fd;
  364. assert(fd > 0);
  365. if (FD_ISSET(fd, &r_set)) {
  366. //read
  367. --cnt;
  368. if (do_read(&s, c) == -1) continue;
  369. }
  370. if (FD_ISSET(fd, &w_set)) {
  371. //write
  372. --cnt;
  373. if (do_write(&s, c) == -1) continue;
  374. }
  375. }
  376. if (FD_ISSET(tp->pid, &r_set)) {
  377. //exit
  378. break;
  379. }
  380. }
  381. close(s.listen_fd[0]);
  382. close(s.listen_fd[1]);
  383. //try send the last buffer
  384. fprintf(stderr, "%s ====================SERVER: SEND LAST DATA BEGIN===================.\n", get_time());
  385. int i;
  386. for (i = s.client_cnt - 1; i >= 0; --i) {
  387. int id = s.client_id[i] % TOTAL_CONNECTION;
  388. struct client_info* c = &s.client[id];
  389. int fd = c->fd;
  390. assert(fd > 0);
  391. if (do_write(&s, c) != -1) {
  392. do_close(&s, c);
  393. }
  394. }
  395. fprintf(stderr, "%s ====================SERVER SEND LAST DATA END=====================.\n", get_time());
  396. free_ring_buffer(s.wait_closed);
  397. assert(s.client_cnt == 0);
  398. return NULL;
  399. }
  400. pthread_t
  401. start_server(struct server_param* tp) {
  402. pthread_t pid;
  403. if (pthread_create(&pid, NULL, server_thread, tp)) {
  404. fprintf(stderr, "%s Create server thread failed.\n", get_time());
  405. exit(1);
  406. return 0;
  407. }
  408. return pid;
  409. }