一种getline的快速实现

也许只是没用的技巧,但谁不喜欢折腾代码

背景

这个需求很常见

  • 需要一个getline()获取若干行字符串(const char*
  • 某些API传入这些字符串时需要得知字符串的长度
  • 很自然地想到strlen

不满

很显然strlen()线性扫描引起不适,为什么为了得到长度就要一个个字符地找

能不能更优雅一点

猜想

假如,char[]中在默认情况下全是'\0',有什么启发,是不是可以直接二分?

只需找到第一个'\0'在哪就行了!

可问题是第二遍查找,当再次cin.getline(buf, sizeof(buf))的时候,你需要重新memset

当上一次获得的字符串足够长的时候,你需要memset的成本跟一次$O(N)$找strlen的成本基本一致

改进

上面的问题是:

  • memset就会覆盖,无法满足后缀单调全\0可以二分的条件
  • memset再二分,成本和strlen一致

那么我们打破条件即可,构造一种条件使得无需memset也不会覆盖\0

最简单的就是append:构造一个buffer,往后面追加而不是覆盖,是不是超简单的

struct IoResult {
    const char *buf;
    size_t len;

    operator bool() { return buf != nullptr; }
};

template <size_t N>
class FastIo {
public:

    IoResult getline(std::istream &in = std::cin) {
        if(cur > N) clear();
        if(!in.getline(_buf + cur, M - cur)) return {nullptr, 0};
        size_t bound = strlen();
        IoResult result = {_buf + cur, bound - cur};
        cur = bound;
        return result;
    }

private:
    void clear() {
        memset(_buf, 0, cur);
        cur = 0;
    }

    size_t strlen() {
        size_t lo = cur, hi = M-1;
        while(lo < hi) {
            size_t mid = lo + (hi-lo >> 1);
            if(_buf[mid] == '\0') hi = mid;
            else lo = mid+1;
        }
        return lo;
    }

    static constexpr size_t M  = N << 1;
    char _buf[M] {};
    size_t cur = 0;
};

再改进

上面的朴素实现是能动的,但是二分时的check能不能直接用CPU字长套一波位运算更爽快点

可以的

struct IoResult {
    const char *buf;
    size_t len;

    operator bool() { return buf != nullptr; }
};

template <size_t N>
class FastIo {
public:

    IoResult getline(std::istream &in = std::cin) {
        if(cur > N) clear();
        if(!in.getline(_buf + cur, M - cur)) return {nullptr, 0};
        size_t bound = strlen2();
        IoResult result = {_buf + cur, bound - cur};
        cur = bound;
        return result;
    }

private:
    void clear() {
        memset(_buf, 0, cur);
        cur = 0;
    }

    size_t strlen2() {
        size_t lo = cur >> 3;
        size_t hi = M-1 >> 3;
        while(lo < hi) {
            size_t mid = lo + (hi-lo >> 1);
            auto chars = *((long long*)(_buf + (mid << 3)));
            if((chars & 0xff) == 0) hi = mid;
            else lo = mid + 1;
        }
        lo = ((lo ? lo-1 : lo) << 3);
        for(int i = lo, j = lo + 8; i <= j; ++i) { // 恰为8的倍数时需要=
            if(_buf[i] == '\0') return i;
        }
        return M-1;
    }

    static constexpr size_t M  = ((N+3 >> 2) << 3) + 8;
    char _buf[M] {};
    size_t cur = 0;
};

这里有点hard code,默认就是64位了

需要补充的几点是:

  • 二分时的check是指,在mid为起始的连续8个char中,只需判断最后1个char是否为\0即可
  • 当恰好遇到8的倍数时,需要特判下一个char
  • M的意思是指:我既期望至少是M的两倍,又期望是8的倍数,更为了前面的特判条件加上8防止访问越界
    • 这里用的操作是获得4的上取整倍数再乘上2的意思,别问为什么不直接拿8,我喜欢

实测

直接上代码吧

int main() {
    std::vector<char> text;
    const int MAXN = 1e8 + 11;
    for(int i = 0; i < MAXN; ++i) text.emplace_back('a' + (i%26));
    const int round[] = {10, 100, 1000, 10000, 100000};
    for(int i = 0, j = 0; i < MAXN; ++i) {
        if((i % round[j]) == 0) {
            text[i] = '\n';
            j = (j + 1) % 4;
        }
    }
    std::istringstream is(text.data());
    std::cin.rdbuf(is.rdbuf());

    auto t1 = clock();
    char buf[100007];
    size_t len = 0;
    while(std::cin.getline(buf, sizeof(buf))) {
        len += strlen(buf);  

    }

    // size_t len = 0;
    // FastIo<100007> io;
    // while(auto res = io.getline()) {
    //     len += res.len;
    // }
    auto t2 = clock();
    std::cout << (1.0*t2 - t1)/CLOCKS_PER_SEC << "s " << "len = " << len << std::endl;
}
std::getline   0.019879s
fastIo.getline 0.008329s

测试环境为-O3

彩蛋

实际上,fastIostrlenstrlen2跑的差不多快,但我觉得strlen2版本比较酷炫

cin.getline加上memset(buf, 0, curLen)其实效率也挺高的,但是使用上的便利程度是必须要考虑的

发表评论

邮箱地址不会被公开。 必填项已用*标注