123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481 |
- #include "tunnel.h"
- #include "client.h"
- #include "socket_comm.h"
- #include "sys/time.h"
- #include <assert.h>
- struct client {
- char remote_ip[128];
- int remote_port;
- int ssh_port;
- int free_connection;
- int id_idx;
- struct ring_buffer* wait_closed;
- int max_fd;
- fd_set fd_rset;
- fd_set fd_wset;
- int time;
- int cnt;
- struct client_info all_fds[TOTAL_CONNECTION];
- int all_ids[TOTAL_CONNECTION];
- };
- static int
- get_id(struct client* c) {
- int i, ret = -1;
- for (i = c->id_idx; i < c->id_idx + TOTAL_CONNECTION; ++i) {
- int idx = i % TOTAL_CONNECTION;
- ++c->id_idx;
- if (c->all_fds[idx].fd == 0) {
- ret = i;
- break;
- }
- }
- return ret;
- }
/*
 * Open a new non-blocking TCP connection and register it in a free slot.
 *
 * ssh_id < 0  : open a free connection to c->remote_ip:c->remote_port.
 *               Refused while another free connection exists, or when fewer
 *               than FREE_CONNECT_TIME seconds have passed since the last
 *               attempt (c->time).
 * ssh_id >= 0 : open the local leg to 0.0.0.0:c->ssh_port; the new slot's
 *               to_id is set to ssh_id, pairing it with that connection.
 *
 * While connect() is in progress, the socket is watched for writability;
 * report_connect finishes it. Once connected, it is watched for readability.
 *
 * Returns the new connection id, or -1 on failure.
 */
static int
connect_to(struct client* c, int ssh_id) {
    if (c->cnt >= TOTAL_CONNECTION) {
        fprintf(stderr, "%s client max connection.....\n", get_time());
        return -1;
    }

    const char* ip;
    int port;
    if (ssh_id < 0) {
        if (c->free_connection > 0) return -1;
        struct timeval tv;
        gettimeofday(&tv, NULL);
        if (tv.tv_sec - c->time < FREE_CONNECT_TIME) {
            return -1;
        }
        c->time = tv.tv_sec;
        ip = c->remote_ip;
        port = c->remote_port;
    } else {
        ip = "0.0.0.0";
        port = c->ssh_port;
    }

    struct addrinfo hints;
    memset(&hints, 0, sizeof(hints));
    hints.ai_family = AF_UNSPEC;
    hints.ai_socktype = SOCK_STREAM;
    hints.ai_protocol = IPPROTO_TCP;

    char portstr[16];
    snprintf(portstr, sizeof(portstr), "%d", port);

    struct addrinfo* res = NULL;
    if (getaddrinfo(ip, portstr, &hints, &res) != 0) {
        return -1;
    }

    int status = -1;
    int sock = -1;
    struct addrinfo* ai_ptr;
    for (ai_ptr = res; ai_ptr != NULL; ai_ptr = ai_ptr->ai_next) {
        sock = socket(ai_ptr->ai_family, ai_ptr->ai_socktype, ai_ptr->ai_protocol);
        if (sock < 0) {
            continue;
        }
        /* fd_set is fixed-size: FD_SET on an fd >= FD_SETSIZE is undefined. */
        if (sock >= FD_SETSIZE) {
            close(sock);
            sock = -1;
            continue;
        }
        set_keep_alive(sock);
        sp_nonblocking(sock);
        status = connect(sock, ai_ptr->ai_addr, ai_ptr->ai_addrlen);
        if (status != 0 && errno != EINPROGRESS) {
            close(sock);
            sock = -1;
            continue;
        }
        break;
    }
    if (sock < 0) {
        freeaddrinfo(res);
        return -1;
    }

    struct ring_buffer* rb = alloc_ring_buffer(MAX_CLIENT_BUFFER);
    if (rb == NULL) {
        close(sock);
        freeaddrinfo(res);
        return -1;
    }

    int id = get_id(c);
    assert(id != -1);
    struct client_info* info = &c->all_fds[id % TOTAL_CONNECTION];
    info->fd = sock;
    info->id = id;
    info->to_id = ssh_id;
    info->buffer = rb;
    snprintf(info->client_ip, sizeof(info->client_ip), "%s:%d", ip, port);
    c->all_ids[c->cnt++] = id;
    if (ssh_id < 0) {
        c->free_connection += 1;
    }

    if (status != 0) {
        /* Non-blocking connect in progress: writability signals completion. */
        FD_SET(sock, &c->fd_wset);
        info->connect_type = SOCKET_CONNECTING;
    } else {
        /* connect() completed immediately. */
        FD_SET(sock, &c->fd_rset);
        info->connect_type = SOCKET_CONNECTED;
        struct sockaddr* addr = ai_ptr->ai_addr;
        void* sin_addr = (ai_ptr->ai_family == AF_INET)
            ? (void*)&((struct sockaddr_in*)addr)->sin_addr
            : (void*)&((struct sockaddr_in6*)addr)->sin6_addr;
        inet_ntop(ai_ptr->ai_family, sin_addr, info->client_ip, sizeof(info->client_ip));
        fprintf(stderr, "%s connected to %s. \n", get_time(), info->client_ip);
    }
    if (c->max_fd < sock + 1) {
        c->max_fd = sock + 1;
    }
    /* ai_ptr points into res, so release the list only after its last use. */
    freeaddrinfo(res);
    return id;
}
- static void
- do_close(struct client* c, struct client_info* info) {
- int i;
- for (i = 0; i < c->cnt; ++i) {
- if (c->all_ids[i] == info->id) {
- memcpy(c->all_ids + i, c->all_ids + i + 1, (c->cnt - i - 1) * sizeof(int));
- --c->cnt;
- break;
- }
- }
- FD_CLR(info->fd, &c->fd_rset);
- FD_CLR(info->fd, &c->fd_wset);
- close(info->fd);
- if (info->to_id == -1) {
- assert(c->free_connection == 1);
- c->free_connection = 0;
- }
- if (info->to_id >= 0 && c->all_fds[info->to_id % TOTAL_CONNECTION].id == info->to_id ) {
- int len;
- char* id_buffer = get_ring_buffer_write_ptr(c->wait_closed, &len);
- int id_len = sizeof(int);
- assert(id_buffer && len >= id_len);
- memcpy(id_buffer, &info->to_id, id_len);
- move_ring_buffer_write_pos(c->wait_closed, id_len);
- }
- if (info->connect_type == SOCKET_CONNECTED)
- {
- fprintf(stderr, "%s client disconnect from %s.\n", get_time(), info->client_ip);
- }
- info->to_id = -1;
- info->id = -1;
- info->fd = 0;
- info->connect_type = 0;
- free_ring_buffer(info->buffer);
- info->buffer = NULL;
- memset(info->client_ip, 0, sizeof(info->client_ip));
- }
- static int
- report_connect(struct client* c, struct client_info* info) {
- int error;
- socklen_t len = sizeof(error);
- int code = getsockopt(info->fd, SOL_SOCKET, SO_ERROR, &error, &len);
- if (code != 0 || error != 0) {
- //connect fail, close it
- fprintf(stderr, "%s client: connect to %s error :%s. \n", get_time(), info->client_ip, strerror(error));
- do_close(c, info);
- return -1;
- }
- info->connect_type = SOCKET_CONNECTED;
- FD_SET(info->fd, &c->fd_rset);
- union sockaddr_all u;
- socklen_t slen = sizeof(u);
- if (getpeername(info->fd, &u.s, &slen) == 0){
- void* sin_addr = (u.s.sa_family == AF_INET) ? (void*)&u.v4.sin_addr : (void*)&u.v6.sin6_addr;
- inet_ntop(u.s.sa_family, sin_addr, info->client_ip, sizeof(info->client_ip));
- }
- fprintf(stderr, "%s connected to %s. \n", get_time(), info->client_ip);
- return 0;
- }
- static int
- do_read(struct client* c, struct client_info* info) {
- assert(info->connect_type == SOCKET_CONNECTED);
- int to_id = info->to_id;
- if (to_id < 0) {
- to_id = connect_to(c, info->id);
- if (to_id == -1) {
- do_close(c, info);
- return -1;
- }
- info->to_id = to_id;
- c->free_connection -= 1;
- c->time = 0;
- }
- int len;
- struct client_info* to_info = &c->all_fds[to_id % TOTAL_CONNECTION];
- if (to_info->id != to_id) {
- do_close(c, info);
- return -1;
- }
- char* buffer = get_ring_buffer_write_ptr(to_info->buffer, &len);
- if (!buffer) {
- return 0; //buff fulled
- }
- int n = (int)read(info->fd, buffer, len);
- if (n == -1) {
- switch (errno) {
- case EAGAIN:
- fprintf(stderr, "%s read fd error:EAGAIN.\n", get_time());
- break;
- case EINTR:
- break;
- default:
- fprintf(stderr, "%s client: read (id=%d) error :%s. \n", get_time(), info->id, strerror(errno));
- do_close(c, info);
- return -1;
- }
- return 1;
- }
- if (n == 0) {
- do_close(c, info); //normal close
- return -1;
- }
- move_ring_buffer_write_pos(to_info->buffer, n);
- FD_SET(to_info->fd, &c->fd_wset);
- if (n == len && !is_ring_buffer_empty(to_info->buffer)) {
- fprintf(stderr, "%s client: read again.\n", get_time());
- return do_read(c, info);
- }
- return 1;
- }
- static int
- do_write(struct client* c, struct client_info* info, int wait_closed) {
- if (info->connect_type == SOCKET_CONNECTING) {
- if (wait_closed) return 0;
- if (report_connect(c, info) == -1) {
- return -1;
- }else if (is_ring_buffer_empty(info->buffer)) {
- FD_CLR(info->fd, &c->fd_wset);
- }
- return 0;
- }
- int len;
- char* buffer = get_ring_buffer_read_ptr(info->buffer, &len);
- if (!buffer) {
- return 0;
- }
- int writed_len = 0;
- char need_break = 0;
- while (!need_break && writed_len < len) {
- int n = write(info->fd, buffer, len - writed_len);
- if (n < 0) {
- switch (errno) {
- case EINTR:
- n = 0;
- break;
- case EAGAIN:
- n = 0;
- need_break = 1;
- break;
- default:
- need_break = 1;
- fprintf(stderr, "%s socket-client: write to (id=%d) error :%s.\n", get_time(), info->id, strerror(errno));
- do_close(c, info);
- return -1;
- }
- }
- else {
- writed_len += n;
- buffer += n;
- }
- }
- move_ring_buffer_read_pos(info->buffer, writed_len);
- if (is_ring_buffer_empty(info->buffer)) {
- FD_CLR(info->fd, &c->fd_wset);
- } else if (writed_len == len) {
- fprintf(stderr, "%s client: write again.\n", get_time());
- return do_write(c, info, wait_closed);
- }
- return 1;
- }
- static void
- pre_check_close(struct client* c) {
- int len;
- char* id_buffer = get_ring_buffer_read_ptr(c->wait_closed, &len);
- if (!id_buffer) return;
- int id_len = sizeof(int);
- assert(len % id_len == 0);
- int tmp = len;
- while (len > 0) {
- int* id = (int*)id_buffer;
- int idx = *id % TOTAL_CONNECTION;
- id_buffer += id_len;
- len -= len;
- struct client_info* info = c->all_fds + idx;
- if (info->fd > 0 && info->id == *id) {
- if (do_write(c, info, 1) != -1) {
- do_close(c, info);
- }
- }
- }
- move_ring_buffer_read_pos(c->wait_closed, tmp);
- }
- static void*
- client_thread(void* param) {
- struct client_param* cp = (struct client_param*)param;
- struct client c;
- memset(&c, 0, sizeof(c));
- sprintf(c.remote_ip, "%s", cp->remote_ip);
- c.remote_port = cp->p1;
- c.ssh_port = cp->p2;
- c.wait_closed = alloc_ring_buffer(sizeof(int) * TOTAL_CONNECTION);
- FD_ZERO(&c.fd_rset);
- FD_ZERO(&c.fd_wset);
- FD_SET(cp->pid, &c.fd_rset);
- c.max_fd = cp->pid + 1;
- sp_nonblocking(cp->pid);
- while (1) {
- pre_check_close(&c);
-
- if (connect_to(&c, -1) == -1 && c.cnt == 0) {
- c.max_fd = cp->pid + 1;
- int buff = 0;
- int n = (int)read(cp->pid, &buff, sizeof(int));
- if (n > 0) {
- break;
- }
- sleep(1);
- continue;
- }
- fd_set r_set = c.fd_rset;
- fd_set w_set = c.fd_wset;
- int cnt = select(c.max_fd, &r_set, &w_set, NULL, NULL);
- if (cnt == -1) {
- fprintf(stderr, "%s select error: %s.\n", get_time(), strerror(errno));
- continue;
- }
- int i;
- for (i = c.cnt - 1; i >= 0 && cnt > 0; --i) {
- int id = c.all_ids[i] % TOTAL_CONNECTION;
- struct client_info* info = &c.all_fds[id];
- assert(c.all_ids[i] == info->id);
- int fd = info->fd;
- assert(fd > 0);
- if (FD_ISSET(fd, &r_set)) {
- // read
- --cnt;
- if (do_read(&c, info) == -1) continue;
- }
- if (FD_ISSET(fd, &w_set)) {
- //write
- --cnt;
- if (do_write(&c, info, 0) == -1) continue;
- }
- }
- if (FD_ISSET(cp->pid, &r_set)) {
- //exit
- break;
- }
- }
- fprintf(stderr, "%s ====================CLIENT: SEND LAST DATA BEGIN===================.\n", get_time());
- int i;
- for (i = c.cnt - 1; i >= 0; --i) {
- int id = c.all_ids[i] % TOTAL_CONNECTION;
- struct client_info* info = &c.all_fds[id];
- assert(c.all_ids[i] == info->id);
- if (do_write(&c, info, 1) != -1) {
- do_close(&c, info);
- }
- }
- fprintf(stderr, "%s ====================CLIENT: SEND LAST DATA END=====================.\n", get_time());
- free_ring_buffer(c.wait_closed);
- assert(c.cnt == 0);
- return NULL;
- }
- pthread_t
- start_client(struct client_param* cp) {
- pthread_t pid;
- if (pthread_create(&pid, NULL, client_thread, cp)) {
- fprintf(stderr, "%s Create client thread failed.\n", get_time());
- exit(1);
- return 0;
- }
- return pid;
- }
|