프로젝트/암호화 모듈

[프로젝트-CRYPTO] Big Integer 구현

-=HaeJuK=- 2024. 12. 30. 17:20
반응형
#git URL
git clone https://github.com/HaeJuk-Lab/crypto.git

왜 Big Integer와 같은 타입이 필요한가?

  1. RSA 및 ECC에서 큰 정수 연산
    • RSA-2048에서는 617자리 이상의 정수 연산이 필요합니다.
    • 기본 int, long long은 64비트(약 20자리)까지만 다루기 때문에 가변 길이의 정수(Big Integer) 구현이 필요합니다.
  2. 비트 기반 연산
    • 블록 암호 알고리즘(ARIA, AES, SEED)은 128비트, 192비트, 256비트 키를 다룹니다.
    • 이를 효율적으로 다루기 위해 128비트 이상의 고정 크기 타입이 필요합니다.
  3. 가변 길이 버퍼 및 데이터
    • 해시, 블록 암호화에서는 블록 단위로 데이터를 처리하므로, 고정 크기 블록 타입가변 길이 타입이 필요합니다

 

🔹 핵심 타입 목록

  1. Big Integer (RSA 등에서 사용)
  2. 128비트, 256비트 정수 (블록 암호화에서 사용)
  3. ByteArray (가변 길이 바이트 버퍼)
  4. BitVector (비트 단위 연산이 필요한 경우)

 

class hxcBigInteger {
private:
    std::vector<uint32_t> digits;  // 32-bit words in little-endian order
    bool negative;                 // Sign indicator

    // Helper function to remove leading zeros
    void Trim() {
        while (!digits.empty() && digits.back() == 0) {
            digits.pop_back();
        }
        if (digits.empty()) {
            digits.push_back(0);
            negative = false;
        }
    }

public:
    // Default constructor
    hxcBigInteger(int64_t value = 0) : negative(value < 0) {
        uint64_t absValue = (value < 0) ? -value : value;
        while (absValue > 0) {
            digits.push_back(static_cast<uint32_t>(absValue & 0xFFFFFFFF));
            absValue >>= 32;
        }
        if (digits.empty()) {
            digits.push_back(0);
        }
    }

    // Constructor from string
    hxcBigInteger(const std::string& str) {
        *this = FromString(str);
    }

    // Convert string to hxcBigInteger
    static hxcBigInteger FromString(const std::string& str) {
        hxcBigInteger result;
        result.negative = (str[0] == '-');
        size_t start = (result.negative || str[0] == '+') ? 1 : 0;

        for (size_t i = start; i < str.length(); ++i) {
            result = result * 10 + (str[i] - '0');
        }
        return result;
    }

    // Addition operator
    hxcBigInteger operator+(const hxcBigInteger& other) const {
        if (negative == other.negative) {
            hxcBigInteger result;
            result.negative = negative;
            uint64_t carry = 0;
            size_t maxSize = std::max(digits.size(), other.digits.size());

            for (size_t i = 0; i < maxSize || carry; ++i) {
                uint64_t sum = carry;
                if (i < digits.size()) sum += digits[i];
                if (i < other.digits.size()) sum += other.digits[i];

                if (i >= result.digits.size()) {
                    result.digits.push_back(0);
                }
                result.digits[i] = sum & 0xFFFFFFFF;
                carry = sum >> 32;
            }
            result.Trim();
            return result;
        }
        return *this - (-other);
    }

    // Negation operator
    hxcBigInteger operator-() const {
        hxcBigInteger result = *this;
        if (*this != 0) {
            result.negative = !negative;
        }
        return result;
    }

    // Subtraction operator
    hxcBigInteger operator-(const hxcBigInteger& other) const {
        if (negative != other.negative) {
            return *this + (-other);
        }
        if (*this == other) return hxcBigInteger(0);

        bool swapNeeded = (*this < other);
        const hxcBigInteger& larger = swapNeeded ? other : *this;
        const hxcBigInteger& smaller = swapNeeded ? *this : other;

        hxcBigInteger result;
        result.negative = swapNeeded ? !negative : negative;

        int64_t carry = 0;
        for (size_t i = 0; i < larger.digits.size() || carry; ++i) {
            int64_t diff = larger.digits[i] - carry - (i < smaller.digits.size() ? smaller.digits[i] : 0);
            carry = (diff < 0);
            if (carry) diff += 0x100000000LL;
            result.digits.push_back(static_cast<uint32_t>(diff));
        }
        result.Trim();
        return result;
    }

    // Multiplication operator
    hxcBigInteger operator*(const hxcBigInteger& other) const {
        hxcBigInteger result;
        result.digits.resize(digits.size() + other.digits.size());
        result.negative = (negative != other.negative);

        for (size_t i = 0; i < digits.size(); ++i) {
            uint64_t carry = 0;
            for (size_t j = 0; j < other.digits.size() || carry; ++j) {
                uint64_t current = result.digits[i + j] +
                                   carry +
                                   static_cast<uint64_t>(digits[i]) * (j < other.digits.size() ? other.digits[j] : 0);
                result.digits[i + j] = current & 0xFFFFFFFF;
                carry = current >> 32;
            }
        }
        result.Trim();
        return result;
    }

    // Division operator
    hxcBigInteger operator/(const hxcBigInteger& other) const {
        // To be implemented: Long division algorithm
        return hxcBigInteger(0);
    }

    // Modulus operator
    hxcBigInteger operator%(const hxcBigInteger& other) const {
        // To be implemented: Modulus calculation
        return hxcBigInteger(0);
    }

    // Modular exponentiation
    hxcBigInteger ModExp(const hxcBigInteger& exponent, const hxcBigInteger& modulus) const {
        hxcBigInteger base = *this % modulus;
        hxcBigInteger result(1);
        hxcBigInteger exp = exponent;

        while (exp != 0) {
            if (exp.digits[0] & 1) {
                result = (result * base) % modulus;
            }
            base = (base * base) % modulus;
            exp >>= 1;
        }
        return result;
    }

    // ToString method for output
    std::string ToString() const {
        if (*this == 0) return "0";
        hxcBigInteger temp = *this;
        std::string result;
        while (temp != 0) {
            result += '0' + (temp % 10).digits[0];
            temp /= 10;
        }
        if (negative) result += '-';
        std::reverse(result.begin(), result.end());
        return result;
    }
};
struct uint128_t {
    uint64_t high;  // 상위 64비트
    uint64_t low;   // 하위 64비트

    uint128_t(uint64_t h = 0, uint64_t l = 0) : high(h), low(l) {}

    // 덧셈
    uint128_t operator+(const uint128_t& other) const {
        uint128_t result;
        result.low = low + other.low;
        result.high = high + other.high + (result.low < low);  // carry 처리
        return result;
    }
};

struct uint256_t {
    uint128_t high;
    uint128_t low;

    uint256_t(uint128_t h = {}, uint128_t l = {}) : high(h), low(l) {}

    uint256_t operator+(const uint256_t& other) const {
        uint256_t result;
        result.low = low + other.low;
        result.high = high + other.high + (result.low.low < low.low);  // carry 처리
        return result;
    }
};

 

class ByteArray {
private:
    std::vector<uint8_t> buffer;

public:
    ByteArray(size_t size = 0) : buffer(size) {}

    // 데이터 삽입
    void Append(const uint8_t* data, size_t len) {
        buffer.insert(buffer.end(), data, data + len);
    }

    // 바이트 접근
    uint8_t& operator[](size_t index) {
        return buffer[index];
    }

    // 길이 반환
    size_t Length() const {
        return buffer.size();
    }

    const uint8_t* Data() const {
        return buffer.data();
    }
};
class BitVector {
private:
    std::vector<uint8_t> bits;

public:
    BitVector(size_t bit_size = 0) : bits((bit_size + 7) / 8, 0) {}

    // 특정 비트 설정
    void SetBit(size_t index, bool value) {
        size_t byte_index = index / 8;
        size_t bit_index = index % 8;
        if (value) {
            bits[byte_index] |= (1 << bit_index);
        } else {
            bits[byte_index] &= ~(1 << bit_index);
        }
    }

    // 특정 비트 읽기
    bool GetBit(size_t index) const {
        size_t byte_index = index / 8;
        size_t bit_index = index % 8;
        return (bits[byte_index] & (1 << bit_index)) != 0;
    }

    // 비트 개수 반환
    size_t Size() const {
        return bits.size() * 8;
    }
};
int main() {
    BigInteger a("123456789123456789123456789");
    BigInteger b("987654321987654321987654321");
    BigInteger c = a + b;

    std::cout << "Sum: " << c.ToString() << std::endl;

    uint256_t x({0xFFFFFFFFFFFFFFFF}, {0x1});
    uint256_t y({0}, {0x1});
    uint256_t z = x + y;

    std::cout << "256-bit Addition: " << z.low.low << std::endl;

    ByteArray buffer(16);
    buffer.Append((uint8_t*)"Hello", 5);
    std::cout << "Buffer Length: " << buffer.Length() << std::endl;

    return 0;
}
728x90