/* client.c */


#include "tunnel.h"
#include "client.h"
#include "socket_comm.h"
#include "sys/time.h"
#include <assert.h>
#include <string.h>
  6. struct client {
  7. char remote_ip[128];
  8. int remote_port;
  9. int ssh_port;
  10. int free_connection;
  11. int id_idx;
  12. struct ring_buffer* wait_closed;
  13. int max_fd;
  14. fd_set fd_rset;
  15. fd_set fd_wset;
  16. int time;
  17. int cnt;
  18. struct client_info all_fds[TOTAL_CONNECTION];
  19. int all_ids[TOTAL_CONNECTION];
  20. };
  21. static int
  22. get_id(struct client* c) {
  23. int i, ret = -1;
  24. for (i = c->id_idx; i < c->id_idx + TOTAL_CONNECTION; ++i) {
  25. int idx = i % TOTAL_CONNECTION;
  26. ++c->id_idx;
  27. if (c->all_fds[idx].fd == 0) {
  28. ret = i;
  29. break;
  30. }
  31. }
  32. return ret;
  33. }
  34. static int
  35. connect_to(struct client* c, int ssh_id){
  36. if (c->cnt >= TOTAL_CONNECTION) {
  37. fprintf(stderr, "%s client max connection.....\n", get_time());
  38. return -1;
  39. }
  40. const char* ip;
  41. int port;
  42. if (ssh_id < 0) {
  43. if (c->free_connection > 0) return -1;
  44. struct timeval tv;
  45. gettimeofday(&tv, NULL);
  46. if (tv.tv_sec - c->time < FREE_CONNECT_TIME) {
  47. return -1;
  48. }else {
  49. c->time = tv.tv_sec;
  50. }
  51. ip = c->remote_ip;
  52. port = c->remote_port;
  53. } else {
  54. ip = "0.0.0.0";
  55. port = c->ssh_port;
  56. }
  57. int id;
  58. int idx;
  59. struct client_info* info;
  60. struct ring_buffer* rb;
  61. struct addrinfo hints;
  62. struct addrinfo* res = NULL;
  63. struct addrinfo* ai_ptr = NULL;
  64. memset(&hints, 0, sizeof(hints));
  65. hints.ai_family = AF_UNSPEC;
  66. hints.ai_socktype = SOCK_STREAM;
  67. hints.ai_protocol = IPPROTO_TCP;
  68. char portstr[16];
  69. sprintf(portstr, "%d", port);
  70. int status = getaddrinfo(ip, portstr, &hints, &res);
  71. if (status != 0) {
  72. return -1;
  73. }
  74. int sock = -1;
  75. for (ai_ptr = res; ai_ptr != NULL; ai_ptr = ai_ptr->ai_next) {
  76. sock = socket(ai_ptr->ai_family, ai_ptr->ai_socktype, ai_ptr->ai_protocol);
  77. if (sock < 0) {
  78. continue;
  79. }
  80. set_keep_alive(sock);
  81. sp_nonblocking(sock);
  82. status = connect(sock, ai_ptr->ai_addr, ai_ptr->ai_addrlen);
  83. if (status != 0 && errno != EINPROGRESS) {
  84. close(sock);
  85. sock = -1;
  86. continue;
  87. }
  88. break;
  89. }
  90. if (sock < 0) {
  91. goto _failed;
  92. }
  93. id = get_id(c);
  94. assert(id != -1);
  95. idx = id % TOTAL_CONNECTION;
  96. info = &c->all_fds[idx];
  97. info->fd = sock;
  98. info->id = id;
  99. info->to_id = ssh_id;
  100. snprintf(info->client_ip, sizeof(info->client_ip), "%s:%d", ip, port);
  101. rb = alloc_ring_buffer(MAX_CLIENT_BUFFER);
  102. info->buffer = rb;
  103. c->all_ids[c->cnt++] = id;
  104. if (ssh_id < 0) {
  105. c->free_connection += 1;
  106. }
  107. if (status != 0) {
  108. //connect no block, need check after
  109. FD_SET(sock, &c->fd_wset);
  110. info->connect_type = SOCKET_CONNECTING;
  111. }else {
  112. //success
  113. FD_SET(sock, &c->fd_rset);
  114. info->connect_type = SOCKET_CONNECTED;
  115. struct sockaddr* addr = ai_ptr->ai_addr;
  116. void* sin_addr = (ai_ptr->ai_family == AF_INET) ? (void*)&((struct sockaddr_in*)addr)->sin_addr : (void*)&((struct sockaddr_in6*)addr)->sin6_addr;
  117. inet_ntop(ai_ptr->ai_family, sin_addr, info->client_ip, sizeof(info->client_ip));
  118. fprintf(stderr, "%s connected to %s. \n", get_time(), info->client_ip);
  119. }
  120. if (c->max_fd < sock + 1) {
  121. c->max_fd = sock + 1;
  122. }
  123. return id;
  124. _failed:
  125. freeaddrinfo(res);
  126. return -1;
  127. }
  128. static void
  129. do_close(struct client* c, struct client_info* info) {
  130. int i;
  131. for (i = 0; i < c->cnt; ++i) {
  132. if (c->all_ids[i] == info->id) {
  133. memcpy(c->all_ids + i, c->all_ids + i + 1, (c->cnt - i - 1) * sizeof(int));
  134. --c->cnt;
  135. break;
  136. }
  137. }
  138. FD_CLR(info->fd, &c->fd_rset);
  139. FD_CLR(info->fd, &c->fd_wset);
  140. close(info->fd);
  141. if (info->to_id == -1) {
  142. assert(c->free_connection == 1);
  143. c->free_connection = 0;
  144. }
  145. if (info->to_id >= 0 && c->all_fds[info->to_id % TOTAL_CONNECTION].id == info->to_id ) {
  146. int len;
  147. char* id_buffer = get_ring_buffer_write_ptr(c->wait_closed, &len);
  148. int id_len = sizeof(int);
  149. assert(id_buffer && len >= id_len);
  150. memcpy(id_buffer, &info->to_id, id_len);
  151. move_ring_buffer_write_pos(c->wait_closed, id_len);
  152. }
  153. if (info->connect_type == SOCKET_CONNECTED)
  154. {
  155. fprintf(stderr, "%s client disconnect from %s.\n", get_time(), info->client_ip);
  156. }
  157. info->to_id = -1;
  158. info->id = -1;
  159. info->fd = 0;
  160. info->connect_type = 0;
  161. free_ring_buffer(info->buffer);
  162. info->buffer = NULL;
  163. memset(info->client_ip, 0, sizeof(info->client_ip));
  164. }
  165. static int
  166. report_connect(struct client* c, struct client_info* info) {
  167. int error;
  168. socklen_t len = sizeof(error);
  169. int code = getsockopt(info->fd, SOL_SOCKET, SO_ERROR, &error, &len);
  170. if (code != 0 || error != 0) {
  171. //connect fail, close it
  172. fprintf(stderr, "%s client: connect to %s error :%s. \n", get_time(), info->client_ip, strerror(error));
  173. do_close(c, info);
  174. return -1;
  175. }
  176. info->connect_type = SOCKET_CONNECTED;
  177. FD_SET(info->fd, &c->fd_rset);
  178. union sockaddr_all u;
  179. socklen_t slen = sizeof(u);
  180. if (getpeername(info->fd, &u.s, &slen) == 0){
  181. void* sin_addr = (u.s.sa_family == AF_INET) ? (void*)&u.v4.sin_addr : (void*)&u.v6.sin6_addr;
  182. inet_ntop(u.s.sa_family, sin_addr, info->client_ip, sizeof(info->client_ip));
  183. }
  184. fprintf(stderr, "%s connected to %s. \n", get_time(), info->client_ip);
  185. return 0;
  186. }
  187. static int
  188. do_read(struct client* c, struct client_info* info) {
  189. assert(info->connect_type == SOCKET_CONNECTED);
  190. int to_id = info->to_id;
  191. if (to_id < 0) {
  192. to_id = connect_to(c, info->id);
  193. if (to_id == -1) {
  194. do_close(c, info);
  195. return -1;
  196. }
  197. info->to_id = to_id;
  198. c->free_connection -= 1;
  199. c->time = 0;
  200. }
  201. int len;
  202. struct client_info* to_info = &c->all_fds[to_id % TOTAL_CONNECTION];
  203. if (to_info->id != to_id) {
  204. do_close(c, info);
  205. return -1;
  206. }
  207. char* buffer = get_ring_buffer_write_ptr(to_info->buffer, &len);
  208. if (!buffer) {
  209. return 0; //buff fulled
  210. }
  211. int n = (int)read(info->fd, buffer, len);
  212. if (n == -1) {
  213. switch (errno) {
  214. case EAGAIN:
  215. fprintf(stderr, "%s read fd error:EAGAIN.\n", get_time());
  216. break;
  217. case EINTR:
  218. break;
  219. default:
  220. fprintf(stderr, "%s client: read (id=%d) error :%s. \n", get_time(), info->id, strerror(errno));
  221. do_close(c, info);
  222. return -1;
  223. }
  224. return 1;
  225. }
  226. if (n == 0) {
  227. do_close(c, info); //normal close
  228. return -1;
  229. }
  230. move_ring_buffer_write_pos(to_info->buffer, n);
  231. FD_SET(to_info->fd, &c->fd_wset);
  232. if (n == len && !is_ring_buffer_empty(to_info->buffer)) {
  233. fprintf(stderr, "%s client: read again.\n", get_time());
  234. return do_read(c, info);
  235. }
  236. return 1;
  237. }
  238. static int
  239. do_write(struct client* c, struct client_info* info, int wait_closed) {
  240. if (info->connect_type == SOCKET_CONNECTING) {
  241. if (wait_closed) return 0;
  242. if (report_connect(c, info) == -1) {
  243. return -1;
  244. }else if (is_ring_buffer_empty(info->buffer)) {
  245. FD_CLR(info->fd, &c->fd_wset);
  246. }
  247. return 0;
  248. }
  249. int len;
  250. char* buffer = get_ring_buffer_read_ptr(info->buffer, &len);
  251. if (!buffer) {
  252. return 0;
  253. }
  254. int writed_len = 0;
  255. char need_break = 0;
  256. while (!need_break && writed_len < len) {
  257. int n = write(info->fd, buffer, len - writed_len);
  258. if (n < 0) {
  259. switch (errno) {
  260. case EINTR:
  261. n = 0;
  262. break;
  263. case EAGAIN:
  264. n = 0;
  265. need_break = 1;
  266. break;
  267. default:
  268. need_break = 1;
  269. fprintf(stderr, "%s socket-client: write to (id=%d) error :%s.\n", get_time(), info->id, strerror(errno));
  270. do_close(c, info);
  271. return -1;
  272. }
  273. }
  274. else {
  275. writed_len += n;
  276. buffer += n;
  277. }
  278. }
  279. move_ring_buffer_read_pos(info->buffer, writed_len);
  280. if (is_ring_buffer_empty(info->buffer)) {
  281. FD_CLR(info->fd, &c->fd_wset);
  282. } else if (writed_len == len) {
  283. fprintf(stderr, "%s client: write again.\n", get_time());
  284. return do_write(c, info, wait_closed);
  285. }
  286. return 1;
  287. }
  288. static void
  289. pre_check_close(struct client* c) {
  290. int len;
  291. char* id_buffer = get_ring_buffer_read_ptr(c->wait_closed, &len);
  292. if (!id_buffer) return;
  293. int id_len = sizeof(int);
  294. assert(len % id_len == 0);
  295. int tmp = len;
  296. while (len > 0) {
  297. int* id = (int*)id_buffer;
  298. int idx = *id % TOTAL_CONNECTION;
  299. id_buffer += id_len;
  300. len -= len;
  301. struct client_info* info = c->all_fds + idx;
  302. if (info->fd > 0 && info->id == *id) {
  303. if (do_write(c, info, 1) != -1) {
  304. do_close(c, info);
  305. }
  306. }
  307. }
  308. move_ring_buffer_read_pos(c->wait_closed, tmp);
  309. }
  310. static void*
  311. client_thread(void* param) {
  312. struct client_param* cp = (struct client_param*)param;
  313. struct client c;
  314. memset(&c, 0, sizeof(c));
  315. sprintf(c.remote_ip, "%s", cp->remote_ip);
  316. c.remote_port = cp->p1;
  317. c.ssh_port = cp->p2;
  318. c.wait_closed = alloc_ring_buffer(sizeof(int) * TOTAL_CONNECTION);
  319. FD_ZERO(&c.fd_rset);
  320. FD_ZERO(&c.fd_wset);
  321. FD_SET(cp->pid, &c.fd_rset);
  322. c.max_fd = cp->pid + 1;
  323. sp_nonblocking(cp->pid);
  324. while (1) {
  325. pre_check_close(&c);
  326. if (connect_to(&c, -1) == -1 && c.cnt == 0) {
  327. c.max_fd = cp->pid + 1;
  328. int buff = 0;
  329. int n = (int)read(cp->pid, &buff, sizeof(int));
  330. if (n > 0) {
  331. break;
  332. }
  333. sleep(1);
  334. continue;
  335. }
  336. fd_set r_set = c.fd_rset;
  337. fd_set w_set = c.fd_wset;
  338. int cnt = select(c.max_fd, &r_set, &w_set, NULL, NULL);
  339. if (cnt == -1) {
  340. fprintf(stderr, "%s select error: %s.\n", get_time(), strerror(errno));
  341. continue;
  342. }
  343. int i;
  344. for (i = c.cnt - 1; i >= 0 && cnt > 0; --i) {
  345. int id = c.all_ids[i] % TOTAL_CONNECTION;
  346. struct client_info* info = &c.all_fds[id];
  347. assert(c.all_ids[i] == info->id);
  348. int fd = info->fd;
  349. assert(fd > 0);
  350. if (FD_ISSET(fd, &r_set)) {
  351. // read
  352. --cnt;
  353. if (do_read(&c, info) == -1) continue;
  354. }
  355. if (FD_ISSET(fd, &w_set)) {
  356. //write
  357. --cnt;
  358. if (do_write(&c, info, 0) == -1) continue;
  359. }
  360. }
  361. if (FD_ISSET(cp->pid, &r_set)) {
  362. //exit
  363. break;
  364. }
  365. }
  366. fprintf(stderr, "%s ====================CLIENT: SEND LAST DATA BEGIN===================.\n", get_time());
  367. int i;
  368. for (i = c.cnt - 1; i >= 0; --i) {
  369. int id = c.all_ids[i] % TOTAL_CONNECTION;
  370. struct client_info* info = &c.all_fds[id];
  371. assert(c.all_ids[i] == info->id);
  372. if (do_write(&c, info, 1) != -1) {
  373. do_close(&c, info);
  374. }
  375. }
  376. fprintf(stderr, "%s ====================CLIENT: SEND LAST DATA END=====================.\n", get_time());
  377. free_ring_buffer(c.wait_closed);
  378. assert(c.cnt == 0);
  379. return NULL;
  380. }
  381. pthread_t
  382. start_client(struct client_param* cp) {
  383. pthread_t pid;
  384. if (pthread_create(&pid, NULL, client_thread, cp)) {
  385. fprintf(stderr, "%s Create client thread failed.\n", get_time());
  386. exit(1);
  387. return 0;
  388. }
  389. return pid;
  390. }