/******
 * Handling amounts of money safely and efficiently.
 *
 * An amount of money is a number tagged with a currency id like "EUR"
 * or "USD". Precision and rounding mode can be chosen as template
 * parameters.
 *
 * If you write code which handles money, you have to choose a data type
 * for it. Out of the box, D offers you floating point, integer, and
 * std.bigint. All of these have their problems.
 *
 * Floating point is inherently imprecise. If your dollar numbers become
 * too big, then you start getting too many or too few cents. This
 * is not acceptable as the errors accumulate. Also, floating point has
 * values like "infinity" and "not a number", and if those show up,
 * things usually break unless you prepared for them. Debugging then
 * means working backwards to find out how this happened, which is
 * tedious and hard.
 *
 * Integer numbers do not suffer from imprecision, but they cannot
 * represent numbers as big as floating point. Worse, if your numbers
 * become too big, then your CPU silently wraps them into negative
 * numbers. Like the imprecision with floating point, your data is
 * now corrupted without anyone noticing it yet. Also, you need a
 * fractional part to represent cents, for example, and fixed point
 * arithmetic with integers is easy to get wrong.
 *
 * As a third option, there is std.bigint, which provides numbers
 * with arbitrary precision. Like floating point, the arithmetic is easy.
 * Like integer, precision is fine. The downside is performance.
 * Nevertheless, of the three options, this is the safest one.
 *
 * Can we do even better?
 * If we design a custom data type for money, we can improve safety
 * even more. For example, certain arithmetic operations can be forbidden.
 * What does it mean to multiply two money amounts, for example? A unit
 * like $² makes no sense. However, you can certainly multiply a money
 * amount with a unitless number. A custom data type can precisely allow
 * and forbid these operations.
 *
 * Here the design decision is to use an integer for the internal
 * representation. This limits the amounts you can use. For example,
 * if you decide to use 4 digits after the decimal point, the maximum
 * number is 922,337,203,685,477.5807 or roughly 922 trillion. The US debt
 * is currently in the trillions, so there are certainly cases where
 * this representation is not applicable. However, we can check overflow,
 * so if it happens, you get an exception thrown and notice it
 * right away. The same holds for division by zero. The upside of using
 * an integer is performance and a deterministic arithmetic all
 * programmers are familiar with.
 *
 * License: $(LINK2 http://www.boost.org/LICENSE_1_0.txt, Boost License 1.0)
 * Authors: Andreas Zwinkau
 */
module money;

import std.math : floor, ceil, lrint, abs, FloatingPointControl;
import std.conv : to;
import core.checkedint : adds, subs, muls, negs;
import std.format : FormatSpec, formattedWrite;
import std.traits : hasMember;

/* 10^x as a long, or 1 for x <= 0. 10^19 and above do not fit into a long. */
@nogc pure @safe nothrow private long pow10(int x)
{
    if (x <= 0)
        return 1;
    return 10 * pow10(x - 1);
}

/** Holds an amount of money.
 *
 * The amount is stored as a long counting units of 10^-dec_places of the
 * currency, e.g. with the default of 4 decimal places EUR(1.5) stores 15000.
 * rmode is applied when floating point numbers are converted into amounts
 * and when formatting with fewer digits than dec_places; arithmetic with
 * integers truncates.
 */
struct money(string currency, int dec_places = 4, roundingMode rmode = roundingMode.HALF_UP)
{
    alias T = typeof(this);
    enum __currency = currency;
    enum __dec_places = dec_places;
    enum __rmode = rmode;
    long amount;

    /** Usual constructor. Converts x to dec_places decimals, rounding with rmode.
     *
     * Throws: ForbiddenRounding if rmode is UNNECESSARY and x has more
     * decimals than dec_places; to!long throws if the scaled x is NaN or
     * does not fit into a long.
     */
    this(double x)
    {
        amount = to!long(round(x * pow10(dec_places), rmode));
    }

    private static T fromLong(long a)
    {
        T ret;
        ret.amount = a;
        return ret;
    }

    /* No `init` member is declared: the built-in T.init already is a zero
       amount (amount defaults to 0), and redefining `init` is deprecated in D. */
    /// maximum amount depends on dec_places
    static immutable max = fromLong(long.max);
    /// minimum amount depends on dec_places
    static immutable min = fromLong(long.min);

    private static immutable dec_mask = pow10(dec_places);

    /* Checked long arithmetic: throw OverflowException instead of silently
       wrapping around. */
    private static long addChecked(long a, long b)
    {
        bool overflow;
        const result = adds(a, b, overflow);
        if (overflow)
            throw new OverflowException();
        return result;
    }

    private static long subChecked(long a, long b)
    {
        bool overflow;
        const result = subs(a, b, overflow);
        if (overflow)
            throw new OverflowException();
        return result;
    }

    private static long mulChecked(long a, long b)
    {
        bool overflow;
        const result = muls(a, b, overflow);
        if (overflow)
            throw new OverflowException();
        return result;
    }

    private static long negChecked(long a)
    {
        bool overflow;
        const result = negs(a, overflow);
        if (overflow)
            throw new OverflowException();
        return result;
    }

    /* a / b that throws instead of hitting a hardware trap
       (division by zero, long.min / -1). */
    private static long divChecked(long a, long b)
    {
        if (b == 0)
            throw new Exception("Division by zero");
        if (b == -1)
            return negChecked(a);
        return a / b;
    }

    /* a % b that throws on division by zero instead of trapping. */
    private static long modChecked(long a, long b)
    {
        if (b == 0)
            throw new Exception("Division by zero");
        /* the remainder is 0 anyway, but long.min % -1 can trap (idiv on x86) */
        if (b == -1)
            return 0;
        return a % b;
    }

    /* Sign of (lhs * factor - rhs), computed without forming the product,
       so it cannot overflow. factor must be positive. */
    private static int cmpScaled(long lhs, long rhs, long factor) @safe pure nothrow @nogc
    {
        /* floor division: rhs == q * factor + r with 0 <= r < factor */
        long q = rhs / factor;
        long r = rhs % factor;
        if (r < 0)
        {
            q -= 1;
            r += factor;
        }
        if (lhs != q)
            return lhs < q ? -1 : 1;
        return r == 0 ? 0 : -1;
    }

    /// Can add and subtract money amounts of the same type.
    /// Throws: OverflowException if the result does not fit.
    T opBinary(string op)(const T rhs) const
    {
        static if (op == "+")
            return fromLong(addChecked(amount, rhs.amount));
        else static if (op == "-")
            return fromLong(subChecked(amount, rhs.amount));
        else
            static assert(0, "Operator " ~ op ~ " not implemented");
    }

    /** Can multiply, divide, and modulo with integer values.
     *
     * Division truncates towards zero. Modulo is taken of the whole amount,
     * e.g. EUR(42.5) % 4 == EUR(2.5), consistent with the floating point overload.
     * Throws: OverflowException if the result does not fit;
     * Exception on division or modulo by zero.
     */
    T opBinary(string op)(const long rhs) const
    {
        static if (op == "*")
        {
            return fromLong(mulChecked(amount, rhs));
        }
        else static if (op == "/")
        {
            return fromLong(divChecked(amount, rhs));
        }
        else static if (op == "%")
        {
            bool overflow;
            const modulus = muls(rhs, dec_mask, overflow);
            /* if rhs * dec_mask does not fit into a long, it exceeds |amount| */
            if (overflow)
                return fromLong(amount);
            return fromLong(modChecked(amount, modulus));
        }
        // TODO support * / % ? Might be useful for taxes etc.
        else
            static assert(0, "Operator " ~ op ~ " not implemented");
    }

    /** Can multiply, divide, and modulo floating point numbers.
     *
     * rhs is first converted to dec_places decimals using rmode; the result
     * is then truncated towards zero.
     * Throws: OverflowException if the result does not fit;
     * Exception if "/" or "%" is used with an rhs that converts to zero.
     */
    T opBinary(string op)(const real rhs) const
    {
        /* rhs as a fixed point number with dec_places decimals */
        const factor = T(rhs).amount;
        static if (op == "*")
        {
            /* amount * factor / dec_mask, split as q * factor + r * factor / dec_mask
               (amount == q * dec_mask + r) so the intermediate product does not
               overflow when the result fits. q and r have the same sign, so the
               truncation matches the unsplit formula. */
            const q = amount / dec_mask;
            const r = amount % dec_mask;
            return fromLong(addChecked(mulChecked(q, factor), mulChecked(r, factor) / dec_mask));
        }
        else static if (op == "/")
        {
            /* amount * dec_mask / factor, split the same way */
            const q = divChecked(amount, factor);
            const r = modChecked(amount, factor);
            return fromLong(addChecked(mulChecked(q, dec_mask), mulChecked(r, dec_mask) / factor));
        }
        else static if (op == "%")
        {
            return fromLong(modChecked(amount, factor));
        }
        // TODO support * / % ? Might be useful for taxes etc.
        else
            static assert(0, "Operator " ~ op ~ " not implemented");
    }

    /// Can check equality with money amounts of the same currency and decimal places.
    bool opEquals(OT)(auto ref const OT other) const if (isMoney!OT
            && other.__currency == currency && other.__dec_places == dec_places)
    {
        return other.amount == amount;
    }

    /** Can compare with money amounts of the same currency, also if the
     * number of decimal places differs. The comparison is exact and does
     * not overflow.
     */
    int opCmp(OT)(const OT other) const if (isMoney!OT && other.__currency == currency)
    {
        static if (dec_places == OT.__dec_places)
            return cmpScaled(amount, other.amount, 1);
        else static if (dec_places < OT.__dec_places)
            return cmpScaled(amount, other.amount, pow10(OT.__dec_places - dec_places));
        else
            return -cmpScaled(other.amount, amount, pow10(dec_places - OT.__dec_places));
    }

    /** Can convert to string.
     *
     * %s and %f print the amount with dec_places decimals followed by the
     * currency, e.g. "3.0500EUR"; %.Nf with N < dec_places rounds with rmode
     * to N decimals (no decimal point for N == 0); %d rounds with rmode to
     * an integer. Other specifiers throw an Exception.
     */
    void toString(scope void delegate(const(char)[]) sink, FormatSpec!char fmt) const
    {
        switch (fmt.spec)
        {
        case 's': /* default e.g. for writeln */
            goto case;
        case 'f':
            {
                const int digits = fmt.precision < dec_places ? fmt.precision : dec_places;
                const int dropped = dec_places - digits;
                /* round the whole amount, so a carry like 3.999 -> 4.00
                   reaches the integer part */
                const long shown = round!rmode(amount, dropped) / pow10(dropped);
                /* magnitude as ulong, so long.min does not overflow */
                const ulong magnitude = shown < 0 ? 0UL - cast(ulong) shown : cast(ulong) shown;
                const ulong unit = cast(ulong) pow10(digits);
                if (shown < 0)
                    sink("-");
                formattedWrite(sink, "%d", magnitude / unit);
                if (digits > 0)
                {
                    sink(".");
                    /* zero-padded: 3.05 must print as 3.0500, not 3.500 */
                    formattedWrite(sink, "%0*d", digits, magnitude % unit);
                }
                sink(currency);
                break;
            }
        case 'd':
            {
                auto ra = round!rmode(amount, dec_places);
                formattedWrite(sink, "%d", (ra / dec_mask));
                sink(currency);
                break;
            }
        default:
            throw new Exception("Unknown format specifier: %" ~ fmt.spec);
        }
    }
}

/// Basic usage
unittest
{
    alias EUR = money!("EUR");
    assert(EUR(100.0001) == EUR(100.00009));
    assert(EUR(3.10) + EUR(1.40) == EUR(4.50));
    assert(EUR(3.10) - EUR(1.40) == EUR(1.70));
    assert(EUR(10.01) * 1.1 == EUR(11.011));

    import std.format : format;

    // for writefln("%d", EUR(3.6));
    assert(format("%d", EUR(3.6)) == "4EUR");
    assert(format("%d", EUR(3.1)) == "3EUR");
    // for writefln("%f", EUR(3.141592));
    assert(format("%f", EUR(3.141592)) == "3.1416EUR");
    assert(format("%.2f", EUR(3.145)) == "3.15EUR");
}

/// Overflow is an error, since silent corruption is worse
unittest
{
    import std.exception : assertThrown;

    alias EUR = money!("EUR");
    auto one = EUR(1);
    assertThrown!OverflowException(EUR.max + one);
    assertThrown!OverflowException(EUR.min - one);
}

/// Arithmetic ignores rounding mode
unittest
{
    alias EUR = money!("EUR", 2, roundingMode.UP);
    auto one = EUR(1);
    assert(one != one / 3);
}

/// Generic equality and order
unittest
{
    alias USD = money!("USD", 2);
    alias EURa = money!("EUR", 2);
    alias EURb = money!("EUR", 4);
    alias EURc = money!("EUR", 4, roundingMode.DOWN);
    // cannot compile with different currencies
    static assert(!__traits(compiles, EURa(1) == USD(1)));
    // cannot compile with different dec_places
    static assert(!__traits(compiles, EURa(1) == EURb(1)));
    // can check equality if only rounding mode differs
    assert(EURb(1.01) == EURc(1.01));
    // cannot compare with different currencies
    static assert(!__traits(compiles, EURa(1) < USD(1)));
}

// TODO Using negative dec_places for big numbers?
//unittest
//{
//    alias USD = money!("USD", -6);
//    assert(USD(1_000_000.00) == USD(1_100_000.));
//}

enum isMoney(T) = (hasMember!(T, "amount") && hasMember!(T, "__dec_places")
            && hasMember!(T, "__rmode"));
static assert(isMoney!(money!"EUR"));

unittest
{
    alias EUR = money!("EUR");
    import std.format : format;
    assert(format("%s", EUR(3.1)) == "3.1000EUR");

    import std.exception : assertThrown;
    assertThrown!Exception(format("%x", EUR(3.1)));
}

unittest
{
    alias EUR = money!("EUR");
    assert(EUR(5) < EUR(6));
    assert(EUR(6) > EUR(5));
    assert(EUR(5) >= EUR(5));
    assert(EUR(5) == EUR(5));
    assert(EUR(6) != EUR(5));

    import std.exception : assertThrown;
    assertThrown!OverflowException(EUR.max * 2);
    assertThrown!OverflowException(EUR.max * 2.0);
}

unittest
{
    alias EUR = money!("EUR");
    auto x = EUR(42);
    assert(EUR(84) == x * 2);
    static assert(!__traits(compiles, x * x));
    assert(EUR(21) == x / 2);
    assert(EUR(2) == x % 4);
}

unittest
{
    alias EURa = money!("EUR", 2);
    alias EURb = money!("EUR", 4);
    auto x = EURa(1.01);
    assert(x > EURb(1.0001));
    assert(x < EURb(1.0101));
    assert(x <= EURb(1.01));
}

/// Comparison across decimal places works in both directions and never overflows
unittest
{
    alias EURa = money!("EUR", 2);
    alias EURb = money!("EUR", 4);
    assert(EURb(1.0001) < EURa(1.01));
    assert(EURb(1.01) >= EURa(1.01));
    assert(EURa.max > EURb(1));
    assert(EURa.min < EURb(-1));
}

/// Formatting pads the fraction, handles negative amounts and carries rounding
unittest
{
    import std.format : format;

    alias EUR = money!("EUR");
    assert(format("%f", EUR(3.0078125)) == "3.0078EUR");
    assert(format("%f", EUR(-3.5)) == "-3.5000EUR");
    assert(format("%.2f", EUR(-0.25)) == "-0.25EUR");
    assert(format("%.2f", EUR(3.998046875)) == "4.00EUR");
    assert(format("%.0f", EUR(3.625)) == "4EUR");
    assert(format("%d", EUR(-3.625)) == "-4EUR");
}

/// Arithmetic avoids spurious overflow and reports division by zero
unittest
{
    import std.exception : assertThrown;

    alias EUR = money!("EUR");
    const big = EUR(100_000_000_000.0);
    assert(big * 1.0 == EUR(100_000_000_000.0));
    assert(big * 1.5 == EUR(150_000_000_000.0));
    assert(big / 2.0 == EUR(50_000_000_000.0));
    assert(EUR(42.5) % 4 == EUR(2.5));
    assert(EUR(42.5) % 4.0 == EUR(2.5));
    assertThrown(big / 0);
    assertThrown(big / 0.0);
    assertThrown(big % 0);
    assertThrown!OverflowException(EUR.min / -1);
}

/// Every rounding mode can be used to construct amounts
unittest
{
    import std.exception : assertThrown;

    assert(money!("EUR", 2, roundingMode.HALF_EVEN)(0.125).amount == 12);
    assert(money!("EUR", 2, roundingMode.HALF_EVEN)(0.375).amount == 38);
    assert(money!("EUR", 2, roundingMode.UNNECESSARY)(1.5).amount == 150);
    assertThrown!ForbiddenRounding(money!("EUR", 2, roundingMode.UNNECESSARY)(0.125));
}

/** Specifies rounding behavior **/
enum roundingMode
{
    // dfmt off
    /** Round upwards (towards positive infinity), e.g. 3.1 up to 4 and -3.9 up to -3. */
    UP,
    /** Round downwards (towards negative infinity), e.g. 3.9 down to 3 and -3.1 down to -4. */
    DOWN,
    /** Round to nearest number; half way rounds up, e.g. 3.5 to 4 and -3.5 to -3. */
    HALF_UP,
    /** Round to nearest number; half way rounds down, e.g. 3.5 to 3 and -3.5 to -4. */
    HALF_DOWN,
    /** Round to nearest number; half way rounds to the even number, e.g. 3.5 to 4 and 2.5 to 2. */
    HALF_EVEN,
    /** Round to nearest number; half way rounds to the odd number, e.g. 3.5 to 3 and 2.5 to 3. */
    HALF_ODD,
    /** Round to nearest number; half way rounds towards zero, e.g. -3.5 to -3. */
    HALF_TO_ZERO,
    /** Round to nearest number; half way rounds away from zero, e.g. -3.5 to -4. */
    HALF_FROM_ZERO,
    /** Throw ForbiddenRounding if rounding would be necessary */
    UNNECESSARY
    // dfmt on
}

/** Round an integer to a multiple of 10^dec_place according to rounding mode m.
 *
 * Works for negative x as well: the neighbouring multiples are found with
 * floor division, so UP/DOWN are ceiling/floor and ties are detected on the
 * distance to those multiples (x % 10^dec_place alone takes the sign of x in D).
 *
 * Params:
 *   x = the number to round, e.g. a raw money amount
 *   dec_place = how many trailing decimal digits of x to round away
 * Returns: a multiple of 10^dec_place; x itself if it already is one.
 * Throws: ForbiddenRounding if m is UNNECESSARY and x is not such a multiple.
 */
long round(roundingMode m)(long x, int dec_place)
out (result)
{
    assert((result % pow10(dec_place)) == 0);
}
body
{
    const zeros = pow10(dec_place);
    /* floor division: x == lower * zeros + rest with 0 <= rest < zeros */
    long lower = x / zeros;
    long rest = x % zeros;
    /* short cut, also removes edge cases */
    if (rest == 0)
        return x;
    if (rest < 0)
    {
        lower -= 1;
        rest += zeros;
    }
    const below = lower * zeros; // nearest multiple below x
    const above = below + zeros; // nearest multiple above x
    const toAbove = zeros - rest; // distance to above; rest is the distance to below
    with (roundingMode)
    {
        static if (m == UP)
            return above;
        else static if (m == DOWN)
            return below;
        else static if (m == UNNECESSARY)
            throw new ForbiddenRounding();
        else
        {
            if (rest < toAbove)
                return below;
            if (rest > toAbove)
                return above;
            /* exactly half way */
            static if (m == HALF_UP)
                return above;
            else static if (m == HALF_DOWN)
                return below;
            else static if (m == HALF_EVEN)
                return lower % 2 == 0 ? below : above;
            else static if (m == HALF_ODD)
                return lower % 2 == 0 ? above : below;
            else static if (m == HALF_TO_ZERO)
                return x > 0 ? below : above;
            else static if (m == HALF_FROM_ZERO)
                return x > 0 ? above : below;
            else
                static assert(0, "Unhandled rounding mode");
        }
    }
}

// dfmt off
///
unittest
{
    assert (round!(roundingMode.DOWN)     (1009, 1) == 1000);
    assert (round!(roundingMode.UP)       (1001, 1) == 1010);
    assert (round!(roundingMode.HALF_UP)  (1005, 1) == 1010);
    assert (round!(roundingMode.HALF_DOWN)(1005, 1) == 1000);
}
// dfmt on

@safe pure @nogc nothrow unittest
{
    // dfmt off
    assert (round!(roundingMode.HALF_UP)       ( 10, 1) ==  10);
    assert (round!(roundingMode.UP)            ( 11, 1) ==  20);
    assert (round!(roundingMode.DOWN)          ( 19, 1) ==  10);
    assert (round!(roundingMode.HALF_UP)       ( 15, 1) ==  20);
    assert (round!(roundingMode.HALF_UP)       (-15, 1) == -10);
    assert (round!(roundingMode.HALF_DOWN)     ( 15, 1) ==  10);
    assert (round!(roundingMode.HALF_DOWN)     ( 16, 1) ==  20);
    assert (round!(roundingMode.HALF_EVEN)     ( 15, 1) ==  20);
    assert (round!(roundingMode.HALF_EVEN)     ( 25, 1) ==  20);
    assert (round!(roundingMode.HALF_ODD)      ( 15, 1) ==  10);
    assert (round!(roundingMode.HALF_ODD)      ( 25, 1) ==  30);
    assert (round!(roundingMode.HALF_TO_ZERO)  ( 25, 1) ==  20);
    assert (round!(roundingMode.HALF_TO_ZERO)  ( 26, 1) ==  30);
    assert (round!(roundingMode.HALF_TO_ZERO)  (-25, 1) == -20);
    assert (round!(roundingMode.HALF_TO_ZERO)  (-26, 1) == -30);
    assert (round!(roundingMode.HALF_FROM_ZERO)( 25, 1) ==  30);
    assert (round!(roundingMode.HALF_FROM_ZERO)( 24, 1) ==  20);
    assert (round!(roundingMode.HALF_FROM_ZERO)(-25, 1) == -30);
    assert (round!(roundingMode.HALF_FROM_ZERO)(-24, 1) == -20);
    // negative numbers and values that are not half way
    assert (round!(roundingMode.UP)            (-15, 1) == -10);
    assert (round!(roundingMode.DOWN)          (-15, 1) == -20);
    assert (round!(roundingMode.HALF_UP)       (-16, 1) == -20);
    assert (round!(roundingMode.HALF_DOWN)     (-15, 1) == -20);
    assert (round!(roundingMode.HALF_EVEN)     ( 11, 1) ==  10);
    assert (round!(roundingMode.HALF_EVEN)     (-15, 1) == -20);
    assert (round!(roundingMode.HALF_ODD)      ( 19, 1) ==  20);
    assert (round!(roundingMode.HALF_TO_ZERO)  ( -6, 1) == -10);
    assert (round!(roundingMode.HALF_FROM_ZERO)( -5, 1) == -10);
    // dfmt on
}

unittest
{
    import std.exception : assertThrown;

    assert(round!(roundingMode.UNNECESSARY)(10, 1) == 10);
    assertThrown!ForbiddenRounding(round!(roundingMode.UNNECESSARY)(12, 1) == 10);
}

/** Round a float to an integral value according to rounding mode m.
 *
 * Values that are already integral (including infinities) are returned
 * unchanged in every mode, so UNNECESSARY only throws when rounding is
 * actually needed. Which half of [floor(x), floor(x) + 1) x lies in is
 * decided on 2 * x, so the tie test is exact.
 *
 * Throws: ForbiddenRounding if m is UNNECESSARY and x is not integral.
 */
//pure nothrow @nogc @trusted
real round(real x, roundingMode m)
{
    const lower = floor(x);
    if (lower == x)
        return x;
    const upper = lower + 1;
    /* 0 if x is below the midpoint of [lower, upper), 1 if at or above it */
    const half = floor(2 * x) - 2 * lower;
    /* x is not integral here, so an integral 2 * x means exactly half way */
    const isTie = 2 * x == floor(2 * x);
    const lowerIsEven = floor(lower / 2) * 2 == lower;

    /* nearest of lower and upper; onTie if x lies exactly half way */
    real nearest(real onTie)
    {
        if (half == 0)
            return lower;
        return isTie ? onTie : upper;
    }

    final switch (m) with (roundingMode)
    {
    case UP:
        return upper;
    case DOWN:
        return lower;
    case HALF_UP:
        return nearest(upper);
    case HALF_DOWN:
        return nearest(lower);
    case HALF_EVEN:
        return nearest(lowerIsEven ? lower : upper);
    case HALF_ODD:
        return nearest(lowerIsEven ? upper : lower);
    case HALF_TO_ZERO:
        return nearest(x > 0 ? lower : upper);
    case HALF_FROM_ZERO:
        return nearest(x > 0 ? upper : lower);
    case UNNECESSARY:
        throw new ForbiddenRounding();
    }
}

unittest
{
    assert(round(3.5, roundingMode.HALF_DOWN) == 3.0);
    assert(round(3.8, roundingMode.HALF_TO_ZERO) == 4.0);
    assert(round(3.5, roundingMode.HALF_TO_ZERO) == 3.0);
    assert(round(-3.5, roundingMode.HALF_TO_ZERO) == -3.0);
    assert(round(2.5, roundingMode.HALF_UP) == 3.0);
    assert(round(2.5, roundingMode.HALF_EVEN) == 2.0);
    assert(round(3.5, roundingMode.HALF_EVEN) == 4.0);
    assert(round(2.5, roundingMode.HALF_ODD) == 3.0);
    assert(round(-3.5, roundingMode.HALF_FROM_ZERO) == -4.0);
    assert(round(-3.1, roundingMode.UP) == -3.0);
    assert(round(-3.1, roundingMode.DOWN) == -4.0);
    assert(round(3.0, roundingMode.UNNECESSARY) == 3.0);

    import std.exception : assertThrown;
    assertThrown!ForbiddenRounding(round(3.1, roundingMode.UNNECESSARY));
}

/** Exception is thrown if rounding would have to happen,
    but roundingMode.UNNECESSARY is specified. */
class ForbiddenRounding : Exception
{
    public
    {
        @safe pure nothrow this(string file = __FILE__, size_t line = __LINE__, Throwable next = null)
        {
            super("Rounding is forbidden", file, line, next);
        }
    }
}

/** Overflow can happen with money arithmetic. */
class OverflowException : Exception
{
    public
    {
        @safe pure nothrow this(string file = __FILE__, size_t line = __LINE__, Throwable next = null)
        {
            super("Overflow", file, line, next);
        }
    }
}