1 /******
2  * Handling amounts of money safely and efficiently.
3  *
4  * An amount of money is a number tagged with a currency id like "EUR"
5  * or "USD". Precision and rounding mode can be chosen as template
6  * parameters.
7  *
8  * If you write code which handles money, you have to choose a data type
9  * for it. Out of the box, D offers you floating point, integer, and
10  * std.bigint. All of these have their problems.
11  *
 * Floating point is inherently imprecise. If your dollar numbers become
 * too big, you start getting too many or too few cents. This
 * is not acceptable, as the errors accumulate. Also, floating point has
 * values like "infinity" and "not a number", and if those show up,
 * things usually break unless you prepared for them. Debugging then
 * means working backwards to find out how this happened, which is tedious and hard.
18  *
 * Integer numbers do not suffer from imprecision, but they cannot
 * represent numbers as big as floating point can. Worse, if your numbers
 * become too big, your CPU silently wraps them around into negative
 * numbers. As with the imprecision of floating point, your data is
 * now corrupted without anyone noticing it yet. Also, fixed point
 * arithmetic with integers is easy to get wrong, and you need it to get a
 * fractional part, for example to represent cents.
26  *
 * As a third option, there is std.bigint, which provides integers
 * with arbitrary precision. Like floating point, the arithmetic is easy.
 * Like fixed-size integers, it is exact. The downside is performance.
 * Nevertheless, of the three options, this is the safest one.
31  *
32  * Can we do even better?
 * If we design a custom data type for money, we can improve safety
 * even more. For example, certain arithmetic operations can be forbidden.
 * What does it mean to multiply two money amounts? There is no
 * such thing as $² which makes any sense. However, you can certainly
 * multiply a money amount by a unitless number. A custom data type
 * can precisely allow and forbid these operations.
39  *
 * Here the design decision is to use an integer for the internal
 * representation. This limits the amounts you can use. For example,
 * if you decide to use 4 digits after the decimal point, the maximum amount
 * is 922,337,203,685,477.5807, or roughly 922 trillion. The US debt is
 * currently in the trillions, so there are certainly cases where
 * this representation is not applicable. However, we can detect overflow,
 * so if it happens, an exception is thrown and you notice it
 * right away. The upside of using an integer is performance and
 * deterministic arithmetic that all programmers are familiar with.
49  *
50  * License: $(LINK2 http://www.boost.org/LICENSE_1_0.txt, Boost License 1.0)
51  * Authors: Andreas Zwinkau
52  */
53 module money;
54 
55 import std.math : floor, ceil, lrint, abs, FloatingPointControl;
56 import std.conv : to;
57 import core.checkedint : adds, subs, muls, negs;
58 import std.format : FormatSpec, formattedWrite;
59 import std.traits : hasMember;
60 
/// Returns 10 raised to the power `x`; any `x <= 0` yields 1.
private long pow10(int x) @safe pure nothrow @nogc
{
    long result = 1;
    // An empty range for non-positive x leaves the result at 1.
    foreach (_; 0 .. x)
        result *= 10;
    return result;
}
67 
/** Holds an amount of money.

    The amount is stored as a fixed point `long`, scaled by 10^^dec_places.

    Params:
        currency   = currency id such as "EUR"; it is part of the type, so
                     amounts in different currencies cannot be mixed
        dec_places = number of decimal places kept in the representation
        rmode      = rounding mode used when converting from floating point
                     and when formatting with a reduced precision
 */
struct money(string currency, int dec_places = 4, roundingMode rmode = roundingMode.HALF_UP)
{
    alias T = typeof(this);
    enum __currency = currency;
    enum __dec_places = dec_places;
    enum __rmode = rmode;
    /// The amount multiplied by 10^^dec_places.
    long amount;

    /** Usual constructor. Scales x by 10^^dec_places and rounds with rmode.

        Exceptions raised by `round` (e.g. `ForbiddenRounding`) and by
        `std.conv.to` (scaled value does not fit into a `long`) propagate.
     */
    this(double x)
    {
        amount = to!long(round(x * pow10(dec_places), rmode));
    }

    private static T fromLong(long a)
    {
        T ret = void;
        ret.amount = a;
        return ret;
    }

    // The default value is the built-in `T.init`, whose amount is zero (the
    // default of `amount`). It is deliberately not redeclared: a member named
    // `init` is rejected by current D compilers.

    /// maximum amount depends on dec_places
    static immutable max = fromLong(long.max);
    /// minimum amount depends on dec_places
    static immutable min = fromLong(long.min);

    private static immutable dec_mask = pow10(dec_places);

    /** Can add and subtract money amounts of the same type.
        Throws: OverflowException if the result does not fit.
     */
    T opBinary(string op)(const T rhs) const
    {
        static if (op == "+")
        {
            bool overflow;
            auto ret = fromLong(adds(amount, rhs.amount, overflow));
            if (overflow)
                throw new OverflowException();
            return ret;
        }
        else static if (op == "-")
        {
            bool overflow;
            auto ret = fromLong(subs(amount, rhs.amount, overflow));
            if (overflow)
                throw new OverflowException();
            return ret;
        }
        else
            static assert(0, "Operator " ~ op ~ " not implemented");
    }

    /** Can multiply, divide, and modulo with integer values.

        "/" truncates towards zero (rmode is not applied). "%" is computed on
        the full amount, so the remainder keeps its fractional part.
        Throws: OverflowException if the result does not fit.
        NOTE: a zero `rhs` for "/" and "%" is not checked.
     */
    T opBinary(string op)(const long rhs) const
    {
        static if (op == "*")
        {
            bool overflow;
            auto ret = fromLong(muls(amount, rhs, overflow));
            if (overflow)
                throw new OverflowException();
            return ret;
        }
        else static if (op == "/")
        {
            // long.min / -1 is not representable (and traps on x86).
            if (rhs == -1 && amount == long.min)
                throw new OverflowException();
            return fromLong(amount / rhs);
        }
        else static if (op == "%")
        {
            // Scale the divisor to the fixed point representation, so that
            // e.g. 42.5 % 4 == 2.5, consistent with the floating point overload.
            bool overflow;
            const divisor = muls(rhs, dec_mask, overflow);
            if (overflow)
                return fromLong(amount); // |divisor| exceeds any |amount|
            if (divisor == -1)
                return fromLong(0); // avoid long.min % -1, which traps on x86
            return fromLong(amount % divisor);
        }
        else
            static assert(0, "Operator " ~ op ~ " not implemented");
    }

    /** Can multiply, divide, and modulo with floating point numbers.

        `rhs` is first converted to this money type, i.e. rounded to
        dec_places decimals with rmode; "*" and "/" truncate the result.
        Throws: OverflowException if an intermediate product overflows.
        NOTE: a `rhs` that rounds to zero makes "/" and "%" divide by zero.
     */
    T opBinary(string op)(const real rhs) const
    {
        static if (op == "*")
        {
            const converted = T(rhs);
            bool overflow = false;
            const result = muls(amount, converted.amount, overflow);
            if (overflow)
                throw new OverflowException();
            return fromLong(result / pow10(dec_places));
        }
        else static if (op == "/")
        {
            const converted = T(rhs);
            bool overflow = false;
            auto mult = muls(amount, pow10(dec_places), overflow);
            if (overflow)
                throw new OverflowException();
            return fromLong(mult / converted.amount);
        }
        else static if (op == "%")
        {
            const converted = T(rhs);
            return fromLong(amount % converted.amount);
        }
        else
            static assert(0, "Operator " ~ op ~ " not implemented");
    }

    /// Can check equality with money amounts of the same currency and decimal places.
    bool opEquals(OT)(auto ref const OT other) const if (isMoney!OT
            && OT.__currency == currency && OT.__dec_places == dec_places)
    {
        return other.amount == amount;
    }

    /** Can compare with money amounts of the same currency.

        Only matches when `other` has at least as many decimal places; this
        amount is then scaled up (throws OverflowException if that does not
        fit). For the mirrored case the compiler uses `other.opCmp(this)`,
        which keeps overload resolution unambiguous.
     */
    int opCmp(OT)(const OT other) const
            if (isMoney!OT && OT.__currency == currency && OT.__dec_places >= dec_places)
    {
        static if (OT.__dec_places == dec_places)
            const lhs = amount;
        else
            const lhs = (this * pow10(OT.__dec_places - dec_places)).amount;

        if (lhs < other.amount)
            return -1;
        if (lhs > other.amount)
            return 1;
        return 0;
    }

    /** Can convert to string.

        - `%s`, `%f`: amount with decimals and currency, e.g. "3.1416EUR".
          A precision (e.g. `%.2f`) reduces the decimals, rounding with rmode;
          `%.0f` prints no decimal point.
        - `%d`: amount rounded to whole units with rmode, e.g. "4EUR".
        Throws: Exception for any other format specifier.
     */
    void toString(scope void delegate(const(char)[]) sink, FormatSpec!char fmt) const
    {
        switch (fmt.spec)
        {
        case 's': /* default e.g. for writeln */
            goto case;
        case 'f':
        {
            // Print all stored decimals unless a smaller precision was requested.
            int digits = dec_places;
            if (fmt.precision < digits)
                digits = fmt.precision;
            if (digits < 0)
                digits = 0;
            // Round the signed amount before splitting it, so that a carry
            // reaches the integer part (3.999 at "%.2f" prints "4.00").
            const rounded = round!rmode(amount, dec_places - digits);
            // Split into sign and magnitude; ulong avoids overflow for long.min.
            const negative = rounded < 0;
            const ulong magnitude = negative ? 0UL - cast(ulong) rounded : cast(ulong) rounded;
            const ulong scale = cast(ulong) dec_mask;
            if (negative)
                sink("-");
            formattedWrite(sink, "%d", magnitude / scale);
            if (digits > 0)
            {
                sink(".");
                // Drop the digits cleared by rounding and zero-pad the rest
                // (3.05 must print as "3.0500", not "3.500").
                const ulong dropped = cast(ulong) pow10(dec_places - digits);
                formattedWrite(sink, "%0*d", digits, (magnitude % scale) / dropped);
            }
            sink(currency);
            break;
        }
        case 'd':
        {
            const ra = round!rmode(amount, dec_places);
            formattedWrite(sink, "%d", (ra / dec_mask));
            sink(currency);
            break;
        }
        default:
            throw new Exception("Unknown format specifier: %" ~ fmt.spec);
        }
    }
}

// Regression tests for formatting, modulo, division overflow and ordering.
unittest
{
    import std.exception : assertThrown;
    import std.format : format;

    alias EUR = money!("EUR");
    // decimals are zero padded and the sign is kept
    assert(format("%s", EUR(3.05)) == "3.0500EUR");
    assert(format("%s", EUR(-3.1416)) == "-3.1416EUR");
    assert(format("%s", EUR(-0.5)) == "-0.5000EUR");
    // rounding to a lower precision carries into the integer part
    assert(format("%.2f", EUR(3.999)) == "4.00EUR");
    assert(format("%.0f", EUR(3.6)) == "4EUR");
    // integer modulo keeps the fractional part
    assert(EUR(42.5) % 4 == EUR(2.5));
    // division overflow is detected instead of trapping
    assertThrown!OverflowException(EUR.min / -1);

    // ordering works in both directions across decimal places
    alias EURa = money!("EUR", 2);
    alias EURb = money!("EUR", 4);
    assert(EURb(1.0001) < EURa(1.01));
    assert(EURa(1.01) > EURb(1.0001));
}
241 
/// Basic usage
unittest
{
    import std.format : format;

    alias EUR = money!("EUR");

    // construction rounds to the 4 stored decimal places
    assert(EUR(100.0001) == EUR(100.00009));

    // amounts of the same type can be added and subtracted
    assert(EUR(3.10) + EUR(1.40) == EUR(4.50));
    assert(EUR(3.10) - EUR(1.40) == EUR(1.70));

    // and scaled by a unitless factor
    assert(EUR(10.01) * 1.1 == EUR(11.011));

    // "%d" rounds to whole units, as in writefln("%d", EUR(3.6))
    assert(format("%d", EUR(3.6)) == "4EUR");
    assert(format("%d", EUR(3.1)) == "3EUR");

    // "%f" prints the decimals, as in writefln("%f", EUR(3.141592))
    assert(format("%f", EUR(3.141592)) == "3.1416EUR");
    assert(format("%.2f", EUR(3.145)) == "3.15EUR");
}
260 
/// Overflow is an error, since silent corruption is worse
unittest
{
    import std.exception : assertThrown;

    alias EUR = money!("EUR");
    const unit = EUR(1);
    // stepping past either end of the representable range throws
    assertThrown!OverflowException(EUR.min - unit);
    assertThrown!OverflowException(EUR.max + unit);
}
271 
/// Arithmetic ignores rounding mode: division truncates even with roundingMode.UP
unittest
{
    alias EUR = money!("EUR", 2, roundingMode.UP);
    auto one = EUR(1);
    assert(one != one / 3);
    // 1.00 / 3 is truncated to 0.33, not rounded up to 0.34
    assert((one / 3).amount == 33);
}
279 
/// Generic equality and order
unittest
{
    alias USD = money!("USD", 2);
    alias EURa = money!("EUR", 2);
    alias EURb = money!("EUR", 4);
    alias EURc = money!("EUR", 4, roundingMode.DOWN);

    // different currencies can neither be checked for equality nor ordered
    static assert(!__traits(compiles, EURa(1) == USD(1)));
    static assert(!__traits(compiles, EURa(1) < USD(1)));
    // different decimal places cannot be checked for equality
    static assert(!__traits(compiles, EURa(1) == EURb(1)));
    // a differing rounding mode alone does not prevent equality checks
    assert(EURb(1.01) == EURc(1.01));
}
296 
297 
298 // TODO Using negative dec_places for big numbers?
299 //unittest
300 //{
301 //    alias USD = money!("USD", -6);
302 //    assert(USD(1_000_000.00) == USD(1_100_000.));
303 //}
304 
/// True if `T` provides every member the operator overloads of `money`
/// rely on (including `__currency`, used by the opEquals/opCmp constraints).
enum isMoney(T) = (hasMember!(T, "amount") && hasMember!(T, "__currency")
        && hasMember!(T, "__dec_places") && hasMember!(T, "__rmode"));
static assert(isMoney!(money!"EUR"));
308 
unittest
{
    import std.exception : assertThrown;
    import std.format : format;

    alias EUR = money!("EUR");
    // "%s" prints like "%f"
    assert(format("%s", EUR(3.1)) == "3.1000EUR");
    // unsupported specifiers are rejected
    assertThrown!Exception(format("%x", EUR(3.1)));
}
317 
unittest
{
    import std.exception : assertThrown;

    alias EUR = money!("EUR");
    const five = EUR(5);
    const six = EUR(6);
    assert(five < six);
    assert(six > five);
    assert(five >= five);
    assert(five == five);
    assert(six != five);

    // scaling beyond the representable range throws, for int and float factors
    assertThrown!OverflowException(EUR.max * 2);
    assertThrown!OverflowException(EUR.max * 2.0);
}
331 
unittest
{
    alias EUR = money!("EUR");
    const x = EUR(42);
    // money can be scaled by numbers but never multiplied by money
    static assert(!__traits(compiles, x * x));
    assert(x * 2 == EUR(84));
    assert(x / 2 == EUR(21));
    assert(x % 4 == EUR(2));
}
341 
unittest
{
    alias EURa = money!("EUR", 2);
    alias EURb = money!("EUR", 4);
    const cents = EURa(1.01);
    // ordering against a finer-grained amount of the same currency
    assert(cents <= EURb(1.01));
    assert(cents < EURb(1.0101));
    assert(cents > EURb(1.0001));
}
351 
/** Specifies rounding behavior.

    The HALF_* modes round to the nearest value; they differ only in how a
    value exactly half way between two candidates is resolved.
 */
enum roundingMode
{
    // dfmt off
    /** Round upwards, e.g. 3.1 up to 4. */
    UP = 0,
    /** Round downwards, e.g. 3.9 down to 3. */
    DOWN = 1,
    /** Round to nearest; ties round up, e.g. 3.5 to 4. */
    HALF_UP = 2,
    /** Round to nearest; ties round down, e.g. 3.5 to 3. */
    HALF_DOWN = 3,
    /** Round to nearest; ties round to the even neighbor, e.g. 3.5 to 4. */
    HALF_EVEN = 4,
    /** Round to nearest; ties round to the odd neighbor, e.g. 3.5 to 3. */
    HALF_ODD = 5,
    /** Round to nearest; ties round towards zero, e.g. -3.5 to -3. */
    HALF_TO_ZERO = 6,
    /** Round to nearest; ties round away from zero, e.g. -3.5 to -4. */
    HALF_FROM_ZERO = 7,
    /** Throw ForbiddenRounding if rounding would be necessary. */
    UNNECESSARY = 8
    // dfmt on
}
376 
/** Round an integer to a certain decimal place according to rounding mode.

    Rounding is applied to the signed value: `UP` moves towards positive
    infinity and `DOWN` towards negative infinity, also for negative `x`.
    The HALF_* modes pick the nearest multiple and only differ on exact ties.

    Params:
        m         = rounding mode
        x         = value to round
        dec_place = number of trailing decimal digits to clear; values <= 0
                    leave `x` unchanged
    Returns: a multiple of 10^^dec_place selected according to `m`.
    Throws: `ForbiddenRounding` if `m` is `UNNECESSARY` and `x` is not already
            such a multiple.

    No overflow check: results outside the `long` range wrap around.
 */
long round(roundingMode m)(long x, int dec_place)
out (result)
{
    assert((result % pow10(dec_place)) == 0);
}
body
{
    const zeros = pow10(dec_place);
    // D's `%` truncates towards zero; remember the remainder and shift a
    // negative one into [0, zeros) so `lower` is the floor multiple.
    long rem = x % zeros;
    /* short cut, also removes edge cases */
    if (rem == 0)
        return x;
    if (rem < 0)
        rem += zeros;

    const lower = x - rem;       // nearest multiple of zeros below x
    const upper = lower + zeros; // nearest multiple of zeros above x
    // zeros >= 10 here (zeros == 1 always takes the short cut), so half is exact
    const half = zeros / 2;

    with (roundingMode)
    {
        static if (m == UP)
            return upper;
        else static if (m == DOWN)
            return lower;
        else static if (m == UNNECESSARY)
            throw new ForbiddenRounding();
        else
        {
            if (rem < half)
                return lower;
            if (rem > half)
                return upper;
            // x lies exactly half way between lower and upper
            static if (m == HALF_UP)
                return upper;
            else static if (m == HALF_DOWN)
                return lower;
            else static if (m == HALF_EVEN)
                return ((lower / zeros) % 2 == 0) ? lower : upper;
            else static if (m == HALF_ODD)
                return ((lower / zeros) % 2 == 0) ? upper : lower;
            else static if (m == HALF_TO_ZERO)
                return (x < 0) ? upper : lower;
            else static if (m == HALF_FROM_ZERO)
                return (x < 0) ? lower : upper;
            else
                static assert(0, "unsupported rounding mode");
        }
    }
}

// Regression tests: negative values and non-tie values in parity modes.
unittest
{
    // dfmt off
    assert(round!(roundingMode.UP)            (-15, 1) == -10);
    assert(round!(roundingMode.DOWN)          (-15, 1) == -20);
    assert(round!(roundingMode.HALF_UP)       (-16, 1) == -20);
    assert(round!(roundingMode.HALF_DOWN)     (-14, 1) == -10);
    assert(round!(roundingMode.HALF_TO_ZERO)  ( -7, 1) == -10);
    assert(round!(roundingMode.HALF_FROM_ZERO)( -5, 1) == -10);
    assert(round!(roundingMode.HALF_EVEN)     ( 11, 1) ==  10);
    assert(round!(roundingMode.HALF_EVEN)     ( 29, 1) ==  30);
    assert(round!(roundingMode.HALF_ODD)      ( 11, 1) ==  10);
    assert(round!(roundingMode.HALF_ODD)      (-15, 1) == -10);
    // dfmt on
}
489 
// dfmt off
/// Rounding away the last decimal digit
unittest
{
    alias R = roundingMode;
    assert(round!(R.DOWN)     (1009, 1) == 1000);
    assert(round!(R.UP)       (1001, 1) == 1010);
    assert(round!(R.HALF_UP)  (1005, 1) == 1010);
    assert(round!(R.HALF_DOWN)(1005, 1) == 1000);
}
// dfmt on
500 
@safe pure @nogc nothrow unittest
{
    alias R = roundingMode;
    // dfmt off
    // directed rounding
    assert(round!(R.UP)            ( 11, 1) ==  20);
    assert(round!(R.DOWN)          ( 19, 1) ==  10);
    // ties up / down
    assert(round!(R.HALF_UP)       ( 10, 1) ==  10);
    assert(round!(R.HALF_UP)       ( 15, 1) ==  20);
    assert(round!(R.HALF_UP)       (-15, 1) == -10);
    assert(round!(R.HALF_DOWN)     ( 15, 1) ==  10);
    assert(round!(R.HALF_DOWN)     ( 16, 1) ==  20);
    // ties by parity
    assert(round!(R.HALF_EVEN)     ( 15, 1) ==  20);
    assert(round!(R.HALF_EVEN)     ( 25, 1) ==  20);
    assert(round!(R.HALF_ODD)      ( 15, 1) ==  10);
    assert(round!(R.HALF_ODD)      ( 25, 1) ==  30);
    // ties relative to zero
    assert(round!(R.HALF_TO_ZERO)  ( 25, 1) ==  20);
    assert(round!(R.HALF_TO_ZERO)  ( 26, 1) ==  30);
    assert(round!(R.HALF_TO_ZERO)  (-25, 1) == -20);
    assert(round!(R.HALF_TO_ZERO)  (-26, 1) == -30);
    assert(round!(R.HALF_FROM_ZERO)( 25, 1) ==  30);
    assert(round!(R.HALF_FROM_ZERO)( 24, 1) ==  20);
    assert(round!(R.HALF_FROM_ZERO)(-25, 1) == -30);
    assert(round!(R.HALF_FROM_ZERO)(-24, 1) == -20);
    // dfmt on
}
525 
unittest
{
    import std.exception : assertThrown;

    alias R = roundingMode;
    // exact multiples pass through unchanged, anything else is refused
    assert(round!(R.UNNECESSARY)(10, 1) == 10);
    assertThrown!ForbiddenRounding(round!(R.UNNECESSARY)(12, 1));
}
533 
/** Round a float to an integer according to rounding mode.

    The HALF_* modes round to the nearest integer and only differ in how an
    exact tie (fractional part 0.5) is resolved, matching the integer
    `round!m` overload. `UP`/`DOWN` are `ceil`/`floor`.

    Params:
        x = value to round
        m = rounding mode
    Returns: the rounded value as a `real`.
    Throws: `ForbiddenRounding` if `m` is `UNNECESSARY` and `x` is not
            already an integer.
 */
real round(real x, roundingMode m)
{
    const lower = floor(x);
    // Compare the fractional part directly instead of using floor(x + 0.5),
    // whose addition can round values just below .5 up to the next integer.
    const frac = x - lower;

    final switch (m) with (roundingMode)
    {
    case UP:
        return ceil(x);
    case DOWN:
        return lower;
    case UNNECESSARY:
        if (lower == x)
            return x;
        throw new ForbiddenRounding();
    case HALF_UP:
    case HALF_DOWN:
    case HALF_EVEN:
    case HALF_ODD:
    case HALF_TO_ZERO:
    case HALF_FROM_ZERO:
        if (frac < 0.5)
            return lower;
        if (frac > 0.5)
            return lower + 1;
        break; // exactly half way, resolved below
    }

    const upper = lower + 1;
    const lowerIsEven = floor(lower / 2) * 2 == lower;
    switch (m) with (roundingMode)
    {
    case HALF_UP:
        return upper;
    case HALF_DOWN:
        return lower;
    case HALF_EVEN:
        return lowerIsEven ? lower : upper;
    case HALF_ODD:
        return lowerIsEven ? upper : lower;
    case HALF_TO_ZERO:
        return (x < 0) ? upper : lower;
    case HALF_FROM_ZERO:
        return (x < 0) ? lower : upper;
    default:
        assert(0, "non-HALF rounding modes are handled above");
    }
}

unittest
{
    with (roundingMode)
    {
        assert(round(3.1, UP) == 4.0);
        assert(round(3.9, DOWN) == 3.0);
        // values that are not ties go to the nearest integer in every HALF_* mode
        assert(round(3.2, HALF_UP) == 3.0);
        assert(round(3.8, HALF_DOWN) == 4.0);
        assert(round(3.8, HALF_TO_ZERO) == 4.0);
        // ties
        assert(round(2.5, HALF_UP) == 3.0);
        assert(round(-2.5, HALF_UP) == -2.0);
        assert(round(3.5, HALF_DOWN) == 3.0);
        assert(round(2.5, HALF_EVEN) == 2.0);
        assert(round(3.5, HALF_EVEN) == 4.0);
        assert(round(2.5, HALF_ODD) == 3.0);
        assert(round(-3.5, HALF_TO_ZERO) == -3.0);
        assert(round(-3.5, HALF_FROM_ZERO) == -4.0);
        // integers need no rounding
        assert(round(3.0, UNNECESSARY) == 3.0);
    }

    import std.exception : assertThrown;
    assertThrown!ForbiddenRounding(round(3.1, roundingMode.UNNECESSARY));
}
571 
/** Thrown when a value would have to be rounded although
    roundingMode.UNNECESSARY forbids it. */
class ForbiddenRounding : Exception
{
    /// Creates the exception with the fixed message "Rounding is forbidden".
    this(string file = __FILE__, size_t line = __LINE__, Throwable next = null) @safe pure nothrow
    {
        super("Rounding is forbidden", file, line, next);
    }
}
584 
/** Thrown when money arithmetic produces a result that does not fit
    into the underlying representation. */
class OverflowException : Exception
{
    /// Creates the exception with the fixed message "Overflow".
    this(string file = __FILE__, size_t line = __LINE__, Throwable next = null) @safe pure nothrow
    {
        super("Overflow", file, line, next);
    }
}