package main import ( "bufio" "bytes" "encoding/binary" "fmt" "io" "net" "os" "os/exec" "strings" "time" "github.com/creack/pty" "github.com/rs/zerolog/log" "github.com/urfave/cli/v2" ) // 帧类型 1字节 // 帧长度 4字节 type Config struct { AddrServer string } const HEART_BEAT_INTERVAL = time.Second * 5 // 心跳超时时间 const ( MSG_TYPE_UNKOWNM = iota MSG_TYPE_HEARTBEAT MSG_TYPE_REGISTER MSG_TYPE_SESSION_CREATE MSG_TYPE_SESSION_DATA MSG_TYPE_SESSION_DESTORY MSG_TYPE_TUNNEL_CREATE MSG_TYPE_MAX ) type Device struct { id string category string desc string /* description of the device */ conn net.Conn create_time int64 /* connection time */ active time.Time registered bool closed uint32 send chan []byte // Buffered channel of outbound messages. } func main() { app := &cli.App{ Name: "XZRobot Ops Server", Usage: "The Server Side For xzrobot ops", Version: "1.0.0", Commands: []*cli.Command{ { Name: "run", Usage: "Run Server", Flags: []cli.Flag{ &cli.StringFlag{ Name: "log", Value: "log.txt", Usage: "log file path", }, &cli.StringFlag{ Name: "conf", Aliases: []string{"c"}, Value: "./rttys.conf", Usage: "config file to load", }, &cli.StringFlag{ Name: "addr-dev", Value: ":9011", Usage: "address to listen device", }, &cli.StringFlag{ Name: "addr-user", Value: ":9012", Usage: "address to listen user", }, &cli.StringFlag{ Name: "db", Value: "sqlite://database.db", Usage: "database source", }, }, Action: func(c *cli.Context) error { runClient(c) return nil }, }, }, Action: func(c *cli.Context) error { c.App.Command("run").Run(c) return nil }, } err := app.Run(os.Args) if err != nil { fmt.Println(err) os.Exit(1) } } func runClient(c *cli.Context) { cfg := &Config{ AddrServer: "localhost:9011", } tcpConn, err := createTcpConn(cfg.AddrServer) if err != nil { fmt.Println("TCP Connect Error! " + err.Error()) return } fmt.Println(tcpConn.LocalAddr().String() + " : Client Connected") dev := &Device{ id: "1234567890", category: "device", conn: tcpConn, create_time: time.Now().Unix(), active: time.Now(), registered: false, closed: 0, send: make(chan []byte, 100), } go dev.readLoop() //registe s := make([][]byte, 3) s[0] = []byte("deviceid") //id s[1] = []byte(dev.id) //desc s[2] = []byte(dev.id) //token dev.WriteMsg(MSG_TYPE_REGISTER, bytes.Join(s, []byte{0})) go dev.writeLoop() select {} } func (dev *Device) WriteMsg(typ int, data []byte) { b := []byte{byte(typ), 0, 0, 0, 0} binary.BigEndian.PutUint32(b[1:], uint32(len(data))) dev.send <- append(b, data...) } func (dev *Device) readLoop() { defer dev.conn.Close() reader := bufio.NewReader(dev.conn) for { b, err := reader.Peek(5) if err != nil { if err != io.EOF && !strings.Contains(err.Error(), "use of closed network connection") { log.Error().Msg(err.Error()) } return } reader.Discard(5) msg_type := b[0] if msg_type >= MSG_TYPE_MAX { log.Error().Msgf("invalid msg type: %d", msg_type) return } msg_length := binary.BigEndian.Uint32(b[1:]) data := make([]byte, msg_length) _, err = io.ReadFull(reader, data) if err != nil { log.Error().Msg(err.Error()) return } dev.active = time.Now() switch msg_type { case MSG_TYPE_HEARTBEAT: // log.Info().Msgf("Receive Heartbeat Time: %d", time.Now().Unix()) case MSG_TYPE_REGISTER: dev.registered = true log.Info().Msgf("Device Registry Success") case MSG_TYPE_TUNNEL_CREATE: log.Info().Msgf("Receive Tunnel Create") cmd := exec.Command("bash") ff, err := pty.Start(cmd) if err != nil { fmt.Println("Create Pty Error! " + err.Error()) return } remoteConn, err := createTcpConn("localhost:10001") if err != nil { fmt.Println("Create remoteAddr Error! " + err.Error()) return } go func() { defer ff.Close() defer remoteConn.Close() _, err := io.Copy(ff, remoteConn) if err != nil { fmt.Println("Copy error! " + err.Error()) return } }() go func() { defer ff.Close() defer remoteConn.Close() _, err := io.Copy(remoteConn, ff) if err != nil { fmt.Println("Copy error! " + err.Error()) return } }() // createTunnel("localhost:10001", "localhost:20001") default: log.Error().Msgf("invalid msg type: %d", msg_type) } } } func (dev *Device) writeLoop() { ticker := time.NewTicker(time.Second) defer dev.conn.Close() ninactive := 0 lastHeartbeat := time.Now() for { select { case msg, ok := <-dev.send: if !ok { return } _, err := dev.conn.Write(msg) if err != nil { log.Error().Msg(err.Error()) return } case <-ticker.C: now := time.Now() if now.Sub(dev.active) > HEART_BEAT_INTERVAL*3/2 { if dev.id == "" { return } log.Error().Msgf("Inactive device in long time: %s", dev.id) if ninactive > 3 { log.Error().Msgf("Inactive 3 times, now kill it: %s", dev.id) return } ninactive = ninactive + 1 } if now.Sub(lastHeartbeat) > HEART_BEAT_INTERVAL-1 { lastHeartbeat = now if len(dev.send) < 1 { // log.Info().Msgf("Send Heartbeat Time: %d", time.Now().Unix()) dev.WriteMsg(MSG_TYPE_HEARTBEAT, []byte{}) } } } } } func createTcpConn(addr string) (*net.TCPConn, error) { tcpConn, err := net.ResolveTCPAddr("tcp", addr) if err != nil { return nil, err } conn, err := net.DialTCP("tcp", nil, tcpConn) if err != nil { return nil, err } return conn, nil } func createTunnel(localAddr string, remoteAddr string) { localConn, err := createTcpConn(localAddr) if err != nil { fmt.Println("Create LocalConn Error! " + err.Error()) return } remoteConn, err := createTcpConn(remoteAddr) if err != nil { fmt.Println("Create remoteAddr Error! " + err.Error()) return } go func() { defer localConn.Close() defer remoteConn.Close() _, err := io.Copy(localConn, remoteConn) if err != nil { fmt.Println("Copy error! " + err.Error()) return } }() go func() { defer localConn.Close() defer remoteConn.Close() _, err := io.Copy(remoteConn, localConn) if err != nil { fmt.Println("Copy error! " + err.Error()) return } }() }