diff --git a/client/client.go b/client/client.go index aac7ea7..f331317 100644 --- a/client/client.go +++ b/client/client.go @@ -2,9 +2,10 @@ package main import ( "fmt" + "os" + "github.com/xiaoqidun/qqwry" "github.com/xiaoqidun/qqwry/assets" - "os" ) func init() { diff --git a/go.mod b/go.mod index 55cc094..a381bc6 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.20 require ( github.com/ipipdotnet/ipdb-go v1.3.3 - golang.org/x/text v0.31.0 + golang.org/x/text v0.32.0 ) diff --git a/go.sum b/go.sum index c236e21..ab1a303 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ github.com/ipipdotnet/ipdb-go v1.3.3 h1:GLSAW9ypLUd6EF9QNK2Uhxew9Jzs4XMJ9gOZEFnJm7U= github.com/ipipdotnet/ipdb-go v1.3.3/go.mod h1:yZ+8puwe3R37a/3qRftXo40nZVQbxYDLqls9o5foexs= -golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM= -golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= diff --git a/qqwry.go b/qqwry.go index fff5026..a8421ad 100644 --- a/qqwry.go +++ b/qqwry.go @@ -1,235 +1,75 @@ -package qqwry - -import ( - "bytes" - "encoding/binary" - "errors" - "github.com/ipipdotnet/ipdb-go" - "golang.org/x/text/encoding/simplifiedchinese" - "golang.org/x/text/transform" - "io" - "net" - "os" - "strings" - "sync" -) - -var ( - data []byte - dataLen uint32 - ipdbCity *ipdb.City - dataType = dataTypeDat - locationCache = &sync.Map{} -) - -const ( - dataTypeDat = 0 - dataTypeIpdb = 1 -) - -const ( - indexLen = 7 - redirectMode1 = 0x01 - redirectMode2 = 0x02 -) - -type Location struct { - Country string // 国家 - Province string // 省份 - City string // 城市 - District string // 区县 - ISP string // 运营商 - IP string // IP地址 -} - -func byte3ToUInt32(data []byte) uint32 { - i := uint32(data[0]) & 0xff - i |= (uint32(data[1]) << 8) & 0xff00 - i |= (uint32(data[2]) << 16) & 0xff0000 - return i -} - -func gb18030Decode(src []byte) string { - in := bytes.NewReader(src) - out := transform.NewReader(in, simplifiedchinese.GB18030.NewDecoder()) - d, _ := io.ReadAll(out) - return string(d) -} - -// QueryIP 从内存或缓存查询IP -func QueryIP(ip string) (location *Location, err error) { - if v, ok := locationCache.Load(ip); ok { - return v.(*Location), nil - } - switch dataType { - case dataTypeDat: - return QueryIPDat(ip) - case dataTypeIpdb: - return QueryIPIpdb(ip) - default: - return nil, errors.New("data type not support") - } -} - -// QueryIPDat 从dat查询IP,仅加载dat格式数据库时使用 -func QueryIPDat(ipv4 string) (location *Location, err error) { - ip := net.ParseIP(ipv4).To4() - if ip == nil { - return nil, errors.New("ip is not ipv4") - } - ip32 := binary.BigEndian.Uint32(ip) - posA := binary.LittleEndian.Uint32(data[:4]) - posZ := binary.LittleEndian.Uint32(data[4:8]) - var offset uint32 = 0 - for { - mid := posA + (((posZ-posA)/indexLen)>>1)*indexLen - buf := data[mid : mid+indexLen] - _ip := binary.LittleEndian.Uint32(buf[:4]) - if posZ-posA == indexLen { - offset = byte3ToUInt32(buf[4:]) - buf = data[mid+indexLen : mid+indexLen+indexLen] - if ip32 < binary.LittleEndian.Uint32(buf[:4]) { - break - } else { - offset = 0 - break - } - } - if _ip > ip32 { - posZ = mid - } else if _ip < ip32 { - posA = mid - } else if _ip == ip32 { - offset = byte3ToUInt32(buf[4:]) - break - } - } - if offset <= 0 { - return nil, errors.New("ip not found") - } - posM := offset + 4 - mode := data[posM] - var ispPos uint32 - var addr, isp string - switch mode { - case redirectMode1: - posC := byte3ToUInt32(data[posM+1 : posM+4]) - mode = data[posC] - posCA := posC - if mode == redirectMode2 { - posCA = byte3ToUInt32(data[posC+1 : posC+4]) - posC += 4 - } - for i := posCA; i < dataLen; i++ { - if data[i] == 0 { - addr = string(data[posCA:i]) - break - } - } - if mode != redirectMode2 { - posC += uint32(len(addr) + 1) - } - ispPos = posC - case redirectMode2: - posCA := byte3ToUInt32(data[posM+1 : posM+4]) - for i := posCA; i < dataLen; i++ { - if data[i] == 0 { - addr = string(data[posCA:i]) - break - } - } - ispPos = offset + 8 - default: - posCA := offset + 4 - for i := posCA; i < dataLen; i++ { - if data[i] == 0 { - addr = string(data[posCA:i]) - break - } - } - ispPos = offset + uint32(5+len(addr)) - } - if addr != "" { - addr = strings.TrimSpace(gb18030Decode([]byte(addr))) - } - ispMode := data[ispPos] - if ispMode == redirectMode1 || ispMode == redirectMode2 { - ispPos = byte3ToUInt32(data[ispPos+1 : ispPos+4]) - } - if ispPos > 0 { - for i := ispPos; i < dataLen; i++ { - if data[i] == 0 { - isp = string(data[ispPos:i]) - if isp != "" { - if strings.Contains(isp, "CZ88.NET") { - isp = "" - } else { - isp = strings.TrimSpace(gb18030Decode([]byte(isp))) - } - } - break - } - } - } - location = SplitResult(addr, isp, ipv4) - locationCache.Store(ipv4, location) - return location, nil -} - -// QueryIPIpdb 从ipdb查询IP,仅加载ipdb格式数据库时使用 -func QueryIPIpdb(ip string) (location *Location, err error) { - ret, err := ipdbCity.Find(ip, "CN") - if err != nil { - return - } - location = SplitResult(ret[0], ret[1], ip) - locationCache.Store(ip, location) - return location, nil -} - -// LoadData 从内存加载IP数据库 -func LoadData(database []byte) { - if string(database[6:11]) == "build" { - dataType = dataTypeIpdb - loadCity, err := ipdb.NewCityFromBytes(database) - if err != nil { - panic(err) - } - ipdbCity = loadCity - return - } - data = database - dataLen = uint32(len(data)) -} - -// LoadFile 从文件加载IP数据库 -func LoadFile(filepath string) (err error) { - body, err := os.ReadFile(filepath) - if err != nil { - return - } - LoadData(body) - return -} - -// SplitResult 按照调整后的纯真社区版IP库地理位置格式返回结果 -func SplitResult(addr string, isp string, ipv4 string) (location *Location) { - location = &Location{ISP: isp, IP: ipv4} - splitList := strings.Split(addr, "–") - for i := 0; i < len(splitList); i++ { - switch i { - case 0: - location.Country = splitList[i] - case 1: - location.Province = splitList[i] - case 2: - location.City = splitList[i] - case 3: - location.District = splitList[i] - } - } - if location.Country == "局域网" { - location.ISP = location.Country - } - return -} +package qqwry + +import ( + "errors" + "os" + + "github.com/ipipdotnet/ipdb-go" +) + +const ( + dataTypeDat = 0 + dataTypeIpdb = 1 +) + +// Client IP查询客户端 +// 字段: data DAT数据库, dataLen DAT数据库长度, ipdbCity IPDB数据库, dataType 数据类型, cache 结果缓存 +type Client struct { + data []byte + dataLen uint32 + ipdbCity *ipdb.City + dataType int + cache *Cache +} + +// NewClient 创建新的IP查询客户端 +// 入参: filePath 文件路径 +// 返回: c 客户端实例, err 错误信息 +func NewClient(filePath string) (c *Client, err error) { + body, err := os.ReadFile(filePath) + if err != nil { + return nil, err + } + return NewClientFromData(body) +} + +// NewClientFromData 从数据创建新的IP查询客户端 +// 入参: body DAT数据库或IPDB数据库 +// 返回: c 客户端实例, err 错误信息 +func NewClientFromData(body []byte) (c *Client, err error) { + c = &Client{} + c.cache = NewCache(10000) + if len(body) > 11 && string(body[6:11]) == "build" { + c.dataType = dataTypeIpdb + c.ipdbCity, err = ipdb.NewCityFromBytes(body) + if err != nil { + return nil, err + } + } else { + c.dataType = dataTypeDat + c.data = body + c.dataLen = uint32(len(c.data)) + } + return c, nil +} + +// QueryIP 查询IP +// 入参: ip IP地址 +// 返回: location 位置信息, err 错误信息 +func (c *Client) QueryIP(ip string) (location *Location, err error) { + if v, ok := c.cache.Get(ip); ok { + return v.Clone(), nil + } + switch c.dataType { + case dataTypeDat: + location, err = c.queryIPDat(ip) + case dataTypeIpdb: + location, err = c.queryIPIpdb(ip) + default: + return nil, errors.New("data type not support") + } + if err == nil && location != nil { + c.cache.Add(ip, location) + } + return location, err +} diff --git a/qqwry_cache.go b/qqwry_cache.go new file mode 100644 index 0000000..0925ac6 --- /dev/null +++ b/qqwry_cache.go @@ -0,0 +1,66 @@ +package qqwry + +import ( + "container/list" + "sync" +) + +// cacheEntry 缓存条目 +type cacheEntry struct { + key string + value *Location +} + +// Cache 简单的LRU缓存实现 +// 字段: capacity 容量, list 双向链表, items哈希表, lock 互斥锁 +type Cache struct { + capacity int + list *list.List + items map[string]*list.Element + lock sync.Mutex +} + +// NewCache 创建新的LRU缓存 +// 入参: capacity 缓存容量 +// 返回: *Cache 缓存实例 +func NewCache(capacity int) *Cache { + return &Cache{ + capacity: capacity, + list: list.New(), + items: make(map[string]*list.Element), + } +} + +// Get 获取缓存 +// 入参: key 键 +// 返回: value 值, ok 是否存在 +func (c *Cache) Get(key string) (value *Location, ok bool) { + c.lock.Lock() + defer c.lock.Unlock() + if ent, ok := c.items[key]; ok { + c.list.MoveToFront(ent) + return ent.Value.(*cacheEntry).value, true + } + return nil, false +} + +// Add 添加缓存 +// 入参: key 键, value 值 +func (c *Cache) Add(key string, value *Location) { + c.lock.Lock() + defer c.lock.Unlock() + if ent, ok := c.items[key]; ok { + c.list.MoveToFront(ent) + ent.Value.(*cacheEntry).value = value + return + } + ent := c.list.PushFront(&cacheEntry{key: key, value: value}) + c.items[key] = ent + if c.list.Len() > c.capacity { + back := c.list.Back() + if back != nil { + c.list.Remove(back) + delete(c.items, back.Value.(*cacheEntry).key) + } + } +} diff --git a/qqwry_dat.go b/qqwry_dat.go new file mode 100644 index 0000000..c336609 --- /dev/null +++ b/qqwry_dat.go @@ -0,0 +1,141 @@ +package qqwry + +import ( + "encoding/binary" + "errors" + "net" + "strings" + + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/transform" +) + +const ( + indexLen = 7 + redirectMode1 = 0x01 + redirectMode2 = 0x02 +) + +// byte3ToUInt32 将3字节切片转换为uint32 +// 入参: data 3字节切片 +// 返回: uint32转换后的值 +func byte3ToUInt32(data []byte) uint32 { + i := uint32(data[0]) & 0xff + i |= (uint32(data[1]) << 8) & 0xff00 + i |= (uint32(data[2]) << 16) & 0xff0000 + return i +} + +// gb18030Decode 将GB18030编码解码为UTF-8 +// 入参: src GB18030编码的字节切片 +// 返回: string UTF-8编码的字符串 +func gb18030Decode(src []byte) string { + d, _, _ := transform.Bytes(simplifiedchinese.GB18030.NewDecoder(), src) + return string(d) +} + +// queryIPDat 从DAT数据库查询IP +// 入参: ipv4 IPv4地址 +// 返回: location 位置信息, err 错误信息 +func (c *Client) queryIPDat(ipv4 string) (location *Location, err error) { + ip := net.ParseIP(ipv4).To4() + if ip == nil { + return nil, errors.New("ip is not ipv4") + } + ip32 := binary.BigEndian.Uint32(ip) + posA := binary.LittleEndian.Uint32(c.data[:4]) + posZ := binary.LittleEndian.Uint32(c.data[4:8]) + var offset uint32 = 0 + for { + mid := posA + (((posZ-posA)/indexLen)>>1)*indexLen + buf := c.data[mid : mid+indexLen] + _ip := binary.LittleEndian.Uint32(buf[:4]) + if posZ-posA == indexLen { + offset = byte3ToUInt32(buf[4:]) + buf = c.data[mid+indexLen : mid+indexLen+indexLen] + if ip32 < binary.LittleEndian.Uint32(buf[:4]) { + break + } else { + offset = 0 + break + } + } + if _ip > ip32 { + posZ = mid + } else if _ip < ip32 { + posA = mid + } else if _ip == ip32 { + offset = byte3ToUInt32(buf[4:]) + break + } + } + if offset <= 0 { + return nil, errors.New("ip not found") + } + posM := offset + 4 + mode := c.data[posM] + var ispPos uint32 + var addr, isp string + switch mode { + case redirectMode1: + posC := byte3ToUInt32(c.data[posM+1 : posM+4]) + mode = c.data[posC] + posCA := posC + if mode == redirectMode2 { + posCA = byte3ToUInt32(c.data[posC+1 : posC+4]) + posC += 4 + } + for i := posCA; i < c.dataLen; i++ { + if c.data[i] == 0 { + addr = string(c.data[posCA:i]) + break + } + } + if mode != redirectMode2 { + posC += uint32(len(addr) + 1) + } + ispPos = posC + case redirectMode2: + posCA := byte3ToUInt32(c.data[posM+1 : posM+4]) + for i := posCA; i < c.dataLen; i++ { + if c.data[i] == 0 { + addr = string(c.data[posCA:i]) + break + } + } + ispPos = offset + 8 + default: + posCA := offset + 4 + for i := posCA; i < c.dataLen; i++ { + if c.data[i] == 0 { + addr = string(c.data[posCA:i]) + break + } + } + ispPos = offset + uint32(5+len(addr)) + } + if addr != "" { + addr = strings.TrimSpace(gb18030Decode([]byte(addr))) + } + ispMode := c.data[ispPos] + if ispMode == redirectMode1 || ispMode == redirectMode2 { + ispPos = byte3ToUInt32(c.data[ispPos+1 : ispPos+4]) + } + if ispPos > 0 { + for i := ispPos; i < c.dataLen; i++ { + if c.data[i] == 0 { + isp = string(c.data[ispPos:i]) + if isp != "" { + if strings.Contains(isp, "CZ88.NET") { + isp = "" + } else { + isp = strings.TrimSpace(gb18030Decode([]byte(isp))) + } + } + break + } + } + } + location = SplitResult(addr, isp, ipv4) + return location, nil +} diff --git a/qqwry_global.go b/qqwry_global.go new file mode 100644 index 0000000..942d0c6 --- /dev/null +++ b/qqwry_global.go @@ -0,0 +1,66 @@ +package qqwry + +import ( + "os" + "sync" +) + +// defaultClient 默认客户端,用于向后兼容 +var ( + clientLock sync.RWMutex + defaultClient = &Client{dataType: dataTypeDat} +) + +// LoadData 从内存加载IP数据库 +// 入参: database DAT数据库或IPDB数据库 +func LoadData(database []byte) { + c, err := NewClientFromData(database) + if err != nil { + panic(err) + } + clientLock.Lock() + defaultClient = c + clientLock.Unlock() +} + +// LoadFile 从文件加载IP数据库 +// 入参: filepath 文件路径 +// 返回: err 错误信息 +func LoadFile(filepath string) (err error) { + body, err := os.ReadFile(filepath) + if err != nil { + return + } + LoadData(body) + return +} + +// QueryIP 从内存或缓存查询IP +// 入参: ip IP地址 +// 返回: location 位置信息, err 错误信息 +func QueryIP(ip string) (location *Location, err error) { + clientLock.RLock() + c := defaultClient + clientLock.RUnlock() + return c.QueryIP(ip) +} + +// QueryIPDat 从DAT数据库查询IP,仅加载DAT格式数据库时使用 +// 入参: ipv4 IPv4地址 +// 返回: location 位置信息, err 错误信息 +func QueryIPDat(ipv4 string) (location *Location, err error) { + clientLock.RLock() + c := defaultClient + clientLock.RUnlock() + return c.queryIPDat(ipv4) +} + +// QueryIPIpdb 从IPDB数据库查询IP,仅加载IPDB格式数据库时使用 +// 入参: ip IP地址 +// 返回: location 位置信息, err 错误信息 +func QueryIPIpdb(ip string) (location *Location, err error) { + clientLock.RLock() + c := defaultClient + clientLock.RUnlock() + return c.queryIPIpdb(ip) +} diff --git a/qqwry_ipdb.go b/qqwry_ipdb.go new file mode 100644 index 0000000..3490edb --- /dev/null +++ b/qqwry_ipdb.go @@ -0,0 +1,13 @@ +package qqwry + +// queryIPIpdb 从IPDB数据库查询IP +// 入参: ip IP地址 +// 返回: location 位置信息, err 错误信息 +func (c *Client) queryIPIpdb(ip string) (location *Location, err error) { + ret, err := c.ipdbCity.Find(ip, "CN") + if err != nil { + return + } + location = SplitResult(ret[0], ret[1], ip) + return location, nil +} diff --git a/qqwry_model.go b/qqwry_model.go new file mode 100644 index 0000000..f7d4892 --- /dev/null +++ b/qqwry_model.go @@ -0,0 +1,53 @@ +package qqwry + +import ( + "strings" +) + +// Location IP位置信息 +// 字段: Country 国家, Province 省份, City 城市, District 区县, ISP 运营商, IP IP地址 +type Location struct { + Country string // 国家 + Province string // 省份 + City string // 城市 + District string // 区县 + ISP string // 运营商 + IP string // IP地址 +} + +// Clone 克隆Location对象 +// 返回: newLocation 克隆后的对象 +func (l *Location) Clone() *Location { + return &Location{ + Country: l.Country, + Province: l.Province, + City: l.City, + District: l.District, + ISP: l.ISP, + IP: l.IP, + } +} + +// SplitResult 按照调整后的纯真社区版IP库地理位置格式返回结果 +// 入参: addr 地址信息, isp 运营商信息, ipv4 IP地址 +// 返回: location 位置信息 +func SplitResult(addr string, isp string, ipv4 string) (location *Location) { + location = &Location{ISP: isp, IP: ipv4} + splitList := strings.Split(addr, "–") + for i := 0; i < len(splitList); i++ { + switch i { + case 0: + location.Country = splitList[i] + case 1: + location.Province = splitList[i] + case 2: + location.City = splitList[i] + case 3: + location.District = splitList[i] + } + } + if location.Country == "局域网" { + location.ISP = location.Country + } + return +} diff --git a/qqwry_test.go b/qqwry_test.go index 365e042..9e185ad 100644 --- a/qqwry_test.go +++ b/qqwry_test.go @@ -1,32 +1,101 @@ package qqwry import ( + "fmt" "testing" ) -func init() { - if err := LoadFile("assets/qqwry.ipdb"); err != nil { - panic(err) +// TestClient_QueryIP 测试实例IP查询功能 +func TestClient_QueryIP(t *testing.T) { + tests := []struct { + name string + filePath string + ipAddrList []string + }{ + { + name: "DAT数据库", + filePath: "assets/qqwry.dat", + ipAddrList: []string{ + "119.29.29.29", + "8.8.8.8", + }, + }, + { + name: "IPDB数据库", + filePath: "assets/qqwry.ipdb", + ipAddrList: []string{ + "119.29.29.29", + "8.8.8.8", + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := NewClient(tt.filePath) + if err != nil { + t.Fatal(err) + } + for _, ip := range tt.ipAddrList { + location, err := client.QueryIP(ip) + if err != nil { + t.Error(err) + continue + } + fmt.Printf("国家:%s,省份:%s,城市:%s,区县:%s,运营商:%s\n", + location.Country, + location.Province, + location.City, + location.District, + location.ISP, + ) + } + }) } } -func TestQueryIP(t *testing.T) { - queryIp := "119.29.29.29" - location, err := QueryIP(queryIp) - if err != nil { - t.Fatal(err) +// TestGlobal_QueryIP 测试全局IP查询功能 +func TestGlobal_QueryIP(t *testing.T) { + tests := []struct { + name string + filePath string + ipAddrList []string + }{ + { + name: "兼容性-DAT", + filePath: "assets/qqwry.dat", + ipAddrList: []string{ + "119.29.29.29", + "8.8.8.8", + }, + }, + { + name: "兼容性-IPDB", + filePath: "assets/qqwry.ipdb", + ipAddrList: []string{ + "119.29.29.29", + "8.8.8.8", + }, + }, } - emptyVal := func(val string) string { - if val != "" { - return val - } - return "未知" + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := LoadFile(tt.filePath); err != nil { + t.Fatal(err) + } + for _, ip := range tt.ipAddrList { + location, err := QueryIP(ip) + if err != nil { + t.Error(err) + continue + } + fmt.Printf("国家:%s,省份:%s,城市:%s,区县:%s,运营商:%s\n", + location.Country, + location.Province, + location.City, + location.District, + location.ISP, + ) + } + }) } - t.Logf("国家:%s,省份:%s,城市:%s,区县:%s,运营商:%s", - emptyVal(location.Country), - emptyVal(location.Province), - emptyVal(location.City), - emptyVal(location.District), - emptyVal(location.ISP), - ) } diff --git a/server/server.go b/server/server.go index 6be73df..5e11c05 100644 --- a/server/server.go +++ b/server/server.go @@ -3,10 +3,11 @@ package main import ( "encoding/json" "flag" - "github.com/xiaoqidun/qqwry" - "github.com/xiaoqidun/qqwry/assets" "net" "net/http" + + "github.com/xiaoqidun/qqwry" + "github.com/xiaoqidun/qqwry/assets" ) type resp struct {