/*
** 2020-06-22
**
** The author disclaims copyright to this source code.  In place of
** a legal notice, here is a blessing:
**
**    May you do good and not evil.
**    May you find forgiveness for yourself and forgive others.
**    May you share freely, never taking more than you give.
**
******************************************************************************
**
** Routines to implement arbitrary-precision decimal math.
**
** The focus here is on simplicity and correctness, not performance.
*/
#include "sqlite3ext.h"
SQLITE_EXTENSION_INIT1
#include <assert.h>
#include <string.h>
#include <ctype.h>
#include <stdlib.h>

/* Mark a function parameter as unused, to suppress nuisance compiler
** warnings. */
#ifndef UNUSED_PARAMETER
# define UNUSED_PARAMETER(X)  (void)(X)
#endif

/*
** A decimal object.
**
** The digits are stored one per array element (values 0 through 9, not
** ASCII), most significant first.  The last nFrac of the nDigit digits
** lie to the right of the decimal point.
*/
typedef struct Decimal Decimal;
struct Decimal {
  char sign;        /* 0 for positive, 1 for negative */
  char oom;         /* True if an OOM is encountered */
  char isNull;      /* True if holds a NULL rather than a number */
  char isInit;      /* True upon initialization */
  int nDigit;       /* Total number of digits */
  int nFrac;        /* Number of digits to the right of the decimal point */
  signed char *a;   /* Array of digits.  Most significant first. */
};

/*
** Release memory held by a Decimal, but do not free the object itself.
** The digit pointer is reset so that a repeated clear is harmless.
*/
static void decimal_clear(Decimal *p){
  sqlite3_free(p->a);
  p->a = 0;
}

/*
** Destroy a Decimal object.  A NULL pointer is a harmless no-op.
*/
static void decimal_free(Decimal *p){
  if( p ){
    decimal_clear(p);
    sqlite3_free(p);
  }
}

/*
** Allocate a new Decimal object.  Initialize it to the number given
** by the input string.
*/ static Decimal *decimal_new( sqlite3_context *pCtx, sqlite3_value *pIn, int nAlt, const unsigned char *zAlt ){ Decimal *p; int n, i; const unsigned char *zIn; int iExp = 0; p = sqlite3_malloc( sizeof(*p) ); if( p==0 ) goto new_no_mem; p->sign = 0; p->oom = 0; p->isInit = 1; p->isNull = 0; p->nDigit = 0; p->nFrac = 0; if( zAlt ){ n = nAlt, zIn = zAlt; }else{ if( sqlite3_value_type(pIn)==SQLITE_NULL ){ p->a = 0; p->isNull = 1; return p; } n = sqlite3_value_bytes(pIn); zIn = sqlite3_value_text(pIn); } p->a = sqlite3_malloc64( n+1 ); if( p->a==0 ) goto new_no_mem; for(i=0; isspace(zIn[i]); i++){} if( zIn[i]=='-' ){ p->sign = 1; i++; }else if( zIn[i]=='+' ){ i++; } while( i='0' && c<='9' ){ p->a[p->nDigit++] = c - '0'; }else if( c=='.' ){ p->nFrac = p->nDigit + 1; }else if( c=='e' || c=='E' ){ int j = i+1; int neg = 0; if( j>=n ) break; if( zIn[j]=='-' ){ neg = 1; j++; }else if( zIn[j]=='+' ){ j++; } while( j='0' && zIn[j]<='9' ){ iExp = iExp*10 + zIn[j] - '0'; } j++; } if( neg ) iExp = -iExp; break; } i++; } if( p->nFrac ){ p->nFrac = p->nDigit - (p->nFrac - 1); } if( iExp>0 ){ if( p->nFrac>0 ){ if( iExp<=p->nFrac ){ p->nFrac -= iExp; iExp = 0; }else{ iExp -= p->nFrac; p->nFrac = 0; } } if( iExp>0 ){ p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 ); if( p->a==0 ) goto new_no_mem; memset(p->a+p->nDigit, 0, iExp); p->nDigit += iExp; } }else if( iExp<0 ){ int nExtra; iExp = -iExp; nExtra = p->nDigit - p->nFrac - 1; if( nExtra ){ if( nExtra>=iExp ){ p->nFrac += iExp; iExp = 0; }else{ iExp -= nExtra; p->nFrac = p->nDigit - 1; } } if( iExp>0 ){ p->a = sqlite3_realloc64(p->a, p->nDigit + iExp + 1 ); if( p->a==0 ) goto new_no_mem; memmove(p->a+iExp, p->a, p->nDigit); memset(p->a, 0, iExp); p->nDigit += iExp; p->nFrac += iExp; } } return p; new_no_mem: if( pCtx ) sqlite3_result_error_nomem(pCtx); sqlite3_free(p); return 0; } /* ** Make the given Decimal the result. 
*/ static void decimal_result(sqlite3_context *pCtx, Decimal *p){ char *z; int i, j; int n; if( p==0 || p->oom ){ sqlite3_result_error_nomem(pCtx); return; } if( p->isNull ){ sqlite3_result_null(pCtx); return; } z = sqlite3_malloc( p->nDigit+4 ); if( z==0 ){ sqlite3_result_error_nomem(pCtx); return; } i = 0; if( p->nDigit==0 || (p->nDigit==1 && p->a[0]==0) ){ p->sign = 0; } if( p->sign ){ z[0] = '-'; i = 1; } n = p->nDigit - p->nFrac; if( n<=0 ){ z[i++] = '0'; } j = 0; while( n>1 && p->a[j]==0 ){ j++; n--; } while( n>0 ){ z[i++] = p->a[j] + '0'; j++; n--; } if( p->nFrac ){ z[i++] = '.'; do{ z[i++] = p->a[j] + '0'; j++; }while( jnDigit ); } z[i] = 0; sqlite3_result_text(pCtx, z, i, sqlite3_free); } /* ** SQL Function: decimal(X) ** ** Convert input X into decimal and then back into text */ static void decimalFunc( sqlite3_context *context, int argc, sqlite3_value **argv ){ Decimal *p = decimal_new(context, argv[0], 0, 0); UNUSED_PARAMETER(argc); decimal_result(context, p); decimal_free(p); } /* ** Compare to Decimal objects. Return negative, 0, or positive if the ** first object is less than, equal to, or greater than the second. ** ** Preconditions for this routine: ** ** pA!=0 ** pA->isNull==0 ** pB!=0 ** pB->isNull==0 */ static int decimal_cmp(const Decimal *pA, const Decimal *pB){ int nASig, nBSig, rc, n; if( pA->sign!=pB->sign ){ return pA->sign ? -1 : +1; } if( pA->sign ){ const Decimal *pTemp = pA; pA = pB; pB = pTemp; } nASig = pA->nDigit - pA->nFrac; nBSig = pB->nDigit - pB->nFrac; if( nASig!=nBSig ){ return nASig - nBSig; } n = pA->nDigit; if( n>pB->nDigit ) n = pB->nDigit; rc = memcmp(pA->a, pB->a, n); if( rc==0 ){ rc = pA->nDigit - pB->nDigit; } return rc; } /* ** SQL Function: decimal_cmp(X, Y) ** ** Return negative, zero, or positive if X is less then, equal to, or ** greater than Y. 
*/
/*
** Details for decimalCmpFunc():
**
** The integer result is always exactly -1, 0, or +1.  If either argument
** is NULL, or if an argument cannot be converted because memory ran out,
** no integer result is set by this routine.
*/
static void decimalCmpFunc(
  sqlite3_context *context,
  int argc,
  sqlite3_value **argv
){
  Decimal *pA = 0, *pB = 0;
  int rc;

  UNUSED_PARAMETER(argc);
  pA = decimal_new(context, argv[0], 0, 0);
  if( pA==0 || pA->isNull ) goto cmp_done;
  pB = decimal_new(context, argv[1], 0, 0);
  if( pB==0 || pB->isNull ) goto cmp_done;
  rc = decimal_cmp(pA, pB);
  if( rc<0 ) rc = -1;
  else if( rc>0 ) rc = +1;
  sqlite3_result_int(context, rc);
cmp_done:
  decimal_free(pA);
  decimal_free(pB);
}

/*
** Expand the Decimal so that it has at least nDigit digits and nFrac
** digits to the right of the decimal point.
**
** Zeros are added on the left to grow the integer part and on the right
** to grow the fraction, so the numeric value does not change.  A NULL p
** is a no-op.  If memory runs out, p->oom is set and the existing digits
** are left untouched.
*/
static void decimal_expand(Decimal *p, int nDigit, int nFrac){
  int nAddSig;
  int nAddFrac;
  signed char *aNew;
  if( p==0 ) return;
  nAddFrac = nFrac - p->nFrac;
  nAddSig = (nDigit - p->nDigit) - nAddFrac;
  if( nAddFrac==0 && nAddSig==0 ) return;
  /* Keep the old pointer until the resize is known to have succeeded.
  ** Otherwise a failed resize would leak the original array. */
  aNew = sqlite3_realloc64(p->a, nDigit+1);
  if( aNew==0 ){
    p->oom = 1;
    return;
  }
  p->a = aNew;
  if( nAddSig ){
    memmove(p->a+nAddSig, p->a, p->nDigit);
    memset(p->a, 0, nAddSig);
    p->nDigit += nAddSig;
  }
  if( nAddFrac ){
    memset(p->a+p->nDigit, 0, nAddFrac);
    p->nDigit += nAddFrac;
    p->nFrac += nAddFrac;
  }
}

/*
** Add the value pB into pA.
**
** Both pA and pB might become denormalized by this routine.
*/ static void decimal_add(Decimal *pA, Decimal *pB){ int nSig, nFrac, nDigit; int i, rc; if( pA==0 ){ return; } if( pA->oom || pB==0 || pB->oom ){ pA->oom = 1; return; } if( pA->isNull || pB->isNull ){ pA->isNull = 1; return; } nSig = pA->nDigit - pA->nFrac; if( nSig && pA->a[0]==0 ) nSig--; if( nSignDigit-pB->nFrac ){ nSig = pB->nDigit - pB->nFrac; } nFrac = pA->nFrac; if( nFracnFrac ) nFrac = pB->nFrac; nDigit = nSig + nFrac + 1; decimal_expand(pA, nDigit, nFrac); decimal_expand(pB, nDigit, nFrac); if( pA->oom || pB->oom ){ pA->oom = 1; }else{ if( pA->sign==pB->sign ){ int carry = 0; for(i=nDigit-1; i>=0; i--){ int x = pA->a[i] + pB->a[i] + carry; if( x>=10 ){ carry = 1; pA->a[i] = x - 10; }else{ carry = 0; pA->a[i] = x; } } }else{ signed char *aA, *aB; int borrow = 0; rc = memcmp(pA->a, pB->a, nDigit); if( rc<0 ){ aA = pB->a; aB = pA->a; pA->sign = !pA->sign; }else{ aA = pA->a; aB = pB->a; } for(i=nDigit-1; i>=0; i--){ int x = aA[i] - aB[i] - borrow; if( x<0 ){ pA->a[i] = x+10; borrow = 1; }else{ pA->a[i] = x; borrow = 0; } } } } } /* ** Compare text in decimal order. */ static int decimalCollFunc( void *notUsed, int nKey1, const void *pKey1, int nKey2, const void *pKey2 ){ const unsigned char *zA = (const unsigned char*)pKey1; const unsigned char *zB = (const unsigned char*)pKey2; Decimal *pA = decimal_new(0, 0, nKey1, zA); Decimal *pB = decimal_new(0, 0, nKey2, zB); int rc; UNUSED_PARAMETER(notUsed); if( pA==0 || pB==0 ){ rc = 0; }else{ rc = decimal_cmp(pA, pB); } decimal_free(pA); decimal_free(pB); return rc; } /* ** SQL Function: decimal_add(X, Y) ** decimal_sub(X, Y) ** ** Return the sum or difference of X and Y. 
*/ static void decimalAddFunc( sqlite3_context *context, int argc, sqlite3_value **argv ){ Decimal *pA = decimal_new(context, argv[0], 0, 0); Decimal *pB = decimal_new(context, argv[1], 0, 0); UNUSED_PARAMETER(argc); decimal_add(pA, pB); decimal_result(context, pA); decimal_free(pA); decimal_free(pB); } static void decimalSubFunc( sqlite3_context *context, int argc, sqlite3_value **argv ){ Decimal *pA = decimal_new(context, argv[0], 0, 0); Decimal *pB = decimal_new(context, argv[1], 0, 0); UNUSED_PARAMETER(argc); if( pB ){ pB->sign = !pB->sign; decimal_add(pA, pB); decimal_result(context, pA); } decimal_free(pA); decimal_free(pB); } /* Aggregate funcion: decimal_sum(X) ** ** Works like sum() except that it uses decimal arithmetic for unlimited ** precision. */ static void decimalSumStep( sqlite3_context *context, int argc, sqlite3_value **argv ){ Decimal *p; Decimal *pArg; UNUSED_PARAMETER(argc); p = sqlite3_aggregate_context(context, sizeof(*p)); if( p==0 ) return; if( !p->isInit ){ p->isInit = 1; p->a = sqlite3_malloc(2); if( p->a==0 ){ p->oom = 1; }else{ p->a[0] = 0; } p->nDigit = 1; p->nFrac = 0; } if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return; pArg = decimal_new(context, argv[0], 0, 0); decimal_add(p, pArg); decimal_free(pArg); } static void decimalSumInverse( sqlite3_context *context, int argc, sqlite3_value **argv ){ Decimal *p; Decimal *pArg; UNUSED_PARAMETER(argc); p = sqlite3_aggregate_context(context, sizeof(*p)); if( p==0 ) return; if( sqlite3_value_type(argv[0])==SQLITE_NULL ) return; pArg = decimal_new(context, argv[0], 0, 0); if( pArg ) pArg->sign = !pArg->sign; decimal_add(p, pArg); decimal_free(pArg); } static void decimalSumValue(sqlite3_context *context){ Decimal *p = sqlite3_aggregate_context(context, 0); if( p==0 ) return; decimal_result(context, p); } static void decimalSumFinalize(sqlite3_context *context){ Decimal *p = sqlite3_aggregate_context(context, 0); if( p==0 ) return; decimal_result(context, p); decimal_clear(p); } /* ** SQL 
Function: decimal_mul(X, Y) ** ** Return the product of X and Y. ** ** All significant digits after the decimal point are retained. ** Trailing zeros after the decimal point are omitted as long as ** the number of digits after the decimal point is no less than ** either the number of digits in either input. */ static void decimalMulFunc( sqlite3_context *context, int argc, sqlite3_value **argv ){ Decimal *pA = decimal_new(context, argv[0], 0, 0); Decimal *pB = decimal_new(context, argv[1], 0, 0); signed char *acc = 0; int i, j, k; int minFrac; UNUSED_PARAMETER(argc); if( pA==0 || pA->oom || pA->isNull || pB==0 || pB->oom || pB->isNull ){ goto mul_end; } acc = sqlite3_malloc64( pA->nDigit + pB->nDigit + 2 ); if( acc==0 ){ sqlite3_result_error_nomem(context); goto mul_end; } memset(acc, 0, pA->nDigit + pB->nDigit + 2); minFrac = pA->nFrac; if( pB->nFracnFrac; for(i=pA->nDigit-1; i>=0; i--){ signed char f = pA->a[i]; int carry = 0, x; for(j=pB->nDigit-1, k=i+j+3; j>=0; j--, k--){ x = acc[k] + f*pB->a[j] + carry; acc[k] = x%10; carry = x/10; } x = acc[k] + carry; acc[k] = x%10; acc[k-1] += x/10; } sqlite3_free(pA->a); pA->a = acc; acc = 0; pA->nDigit += pB->nDigit + 2; pA->nFrac += pB->nFrac; pA->sign ^= pB->sign; while( pA->nFrac>minFrac && pA->a[pA->nDigit-1]==0 ){ pA->nFrac--; pA->nDigit--; } decimal_result(context, pA); mul_end: sqlite3_free(acc); decimal_free(pA); decimal_free(pB); } #ifdef _WIN32 __declspec(dllexport) #endif int sqlite3_decimal_init( sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi ){ int rc = SQLITE_OK; static const struct { const char *zFuncName; int nArg; void (*xFunc)(sqlite3_context*,int,sqlite3_value**); } aFunc[] = { { "decimal", 1, decimalFunc }, { "decimal_cmp", 2, decimalCmpFunc }, { "decimal_add", 2, decimalAddFunc }, { "decimal_sub", 2, decimalSubFunc }, { "decimal_mul", 2, decimalMulFunc }, }; unsigned int i; (void)pzErrMsg; /* Unused parameter */ SQLITE_EXTENSION_INIT2(pApi); for(i=0; 
i<(int)(sizeof(aFunc)/sizeof(aFunc[0])) && rc==SQLITE_OK; i++){ rc = sqlite3_create_function(db, aFunc[i].zFuncName, aFunc[i].nArg, SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC, 0, aFunc[i].xFunc, 0, 0); } if( rc==SQLITE_OK ){ rc = sqlite3_create_window_function(db, "decimal_sum", 1, SQLITE_UTF8|SQLITE_INNOCUOUS|SQLITE_DETERMINISTIC, 0, decimalSumStep, decimalSumFinalize, decimalSumValue, decimalSumInverse, 0); } if( rc==SQLITE_OK ){ rc = sqlite3_create_collation(db, "decimal", SQLITE_UTF8, 0, decimalCollFunc); } return rc; }