import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from aggregate import build, qd, Distortion, Portfolio #, knobble_fonts

# knobble_fonts(True)

def my_ff(x):
    """Format a number compactly for pandas display.

    Integer-valued numbers (including floats such as ``3.0``) are shown as
    integers with thousands separators; magnitudes below 0.005 with five
    significant figures; magnitudes below 100 with three decimals; anything
    larger rounded to a whole number with thousands separators. Values that
    cannot be handled numerically (NaN, +/-inf, non-numbers) fall back to
    ``str(x)``.

    Parameters
    ----------
    x : number or object
        The value to format.

    Returns
    -------
    str
        The formatted value.
    """
    try:
        if x == int(x):
            # Convert first: the ``d`` presentation type rejects floats, so
            # formatting 3.0 directly would raise and fall through to str().
            return f'{int(x):,d}'
        elif abs(x) < 0.005:
            return f'{x:.5g}'
        elif abs(x) < 100:
            return f'{x:.3f}'
        else:
            return f'{x:,.0f}'
    except (TypeError, ValueError, OverflowError):
        # int() raises ValueError for NaN, OverflowError for +/-inf, and
        # TypeError/ValueError for non-numeric input.
        return str(x)


pd.options.display.float_format = my_ff

from IPython.display import HTML, display
from greater_tables import GT

# Module-level GT (greater_tables) instance; it is not referenced elsewhere in
# the visible code.
__gt_global = GT()

def nothing(df):
    """Identity formatter: hand *df* back untouched.

    Serves as a stand-in for the table-display helpers, so calling
    ``sd(frame)`` or ``sdp(frame)`` simply yields ``frame`` itself.
    """
    return df


# Both display helpers are currently no-op aliases of ``nothing``.
sdp = nothing
sd = sdp

#from greater_tables import qd

# try:
#     from greater_tables import qd
# except ModuleNotFoundError:
#     print('Greater tables not found...using IPython.display.')
#     qd = display

#sd = qd
#sdp = qd

#def sd(df):
#    # return df
#    bit =  df.style.format(formatter=lambda x: #f'{x:.4g}').set_properties(**{'text-align': 'right'})
#    print(bit.to_html())
#    return bit
#
#def sdp(df):
#    # return df
#    bit = df.style.format(formatter=lambda x: #f'{x:.1%}').set_properties(**{'text-align': 'right'})
#    print(bit.to_html())
#    return bit

# Index relabelling applied by ``gcn``: the X2/X3 rows become the net/ceded
# pieces and the combined row ``'X'`` becomes ``'X2'``.
# NOTE(review): 'X4' also maps to 'X2'; a frame with an 'X4' row would end up
# with a duplicate 'X2' label.
gcn_namer = {'X2': 'X2 net', 'X3': 'X2 ceded', 'X': 'X2', 'X4': 'X2'}


def gcn(df):
    """Return a copy of *df* with a combined row for units X2 + X3.

    Rows ``'X2'`` and ``'X3'`` are summed column-wise into a new row. The ratio
    columns of that row are then recomputed from the summed totals, because
    ratios do not add: ``LR = L / P``, ``PQ = P / Q`` and ``COC = M / Q``.
    Finally the index is relabelled with ``gcn_namer`` (X2 -> 'X2 net',
    X3 -> 'X2 ceded', combined row -> 'X2') and sorted.

    Parameters
    ----------
    df : DataFrame
        Must contain rows ``'X2'`` and ``'X3'`` and numeric columns ``L``,
        ``P``, ``Q`` and ``M``. It is not modified.

    Returns
    -------
    DataFrame
        The relabelled, sorted table including the combined row.
    """
    # Work on a copy so the caller's frame does not gain the temporary 'X' row.
    df = df.copy()
    bit = df.loc[['X2', 'X3']]
    df.loc['X', :] = bit.sum(0)
    # Ratios are not additive: recompute them from the summed components.
    df.loc['X', 'LR'] = df.loc['X', 'L'] / df.loc['X', 'P']
    df.loc['X', 'PQ'] = df.loc['X', 'P'] / df.loc['X', 'Q']
    df.loc['X', 'COC'] = df.loc['X', 'M'] / df.loc['X', 'Q']
    df = df.rename(index=gcn_namer)
    df = df.sort_index()
    return df

# Sample of ten outcomes (one per row) for three units X1, X2, X3. The scalar
# p_total = 1/10 is broadcast, so every row carries probability 0.1.
wgs = pd.DataFrame(
    {
        'X1': [36, 40, 28, 22, 33, 32, 31, 45, 25, 25],
        'X2': [ 0,  0,  0,  0,  7,  8,  9, 10, 40, 40],
        'X3': [ 0,  0,  0,  0,  0,  0,  0,  0,  0, 35],
        'p_total': 1/10
    }
)

# Build the portfolio from the sample (bs=1, log2=8), then calibrate the
# distortions with Ps=[1] and COCs=[.15]; the COC column of distortion_df
# shows 0.150 for every method. Next, price with price(1, allocation='linear').
# `ans` is not referenced in the visible code.
wport = Portfolio.create_from_sample('WGS', wgs, bs=1, log2=8)
wport.calibrate_distortions(Ps=[1], COCs=[.15])
ans = wport.price(1, allocation='linear')

# Light colour scheme: white plotting area on a light (#e9e9f2) figure
# background, with all text, ticks and axis edges in black.
PLOT_FACE_COLOR = 'white'
FIGURE_BG_COLOR = '#e9e9f2'

plt.rcParams.update({
    # figure
    "figure.dpi": 150,
    "figure.facecolor": FIGURE_BG_COLOR,
    # fonts and text
    "font.family": "sans-serif",
    "font.size": 10,
    "text.color": "black",
    # axes and ticks
    "axes.facecolor": PLOT_FACE_COLOR,
    "axes.edgecolor": "black",
    "axes.labelcolor": "black",
    "xtick.color": "black",
    "ytick.color": "black",
    # legend
    "legend.facecolor": PLOT_FACE_COLOR,
    "legend.edgecolor": "none",
    "legend.labelcolor": "black",
})
# Display the calibrated distortions, keyed by name (recorded output follows).
wport.dists
{'ccoc': ccoc, 'ph': ph, 'wang': wang, 'dual': dual, 'tvar': tvar}
# One panel per calibrated distortion. Each panel plots g' evaluated at
# S = 1 - p (xs reversed) against percentile p, labelled "Weight adjustment".
fig, axs = plt.subplots(1, 5, figsize=(5*2, 2), constrained_layout=True, sharey=True)
xs = np.linspace(0, 1, 101)
for (k, v), ax in zip(wport.dists.items(), axs.flat):
    if k == 'ccoc':
        # ccoc is drawn by hand: a flat line at 1/1.15 plus a red star at
        # (1, 2), presumably marking the mass at the top percentile (confirm).
        # NOTE(review): 1.15 is hard-coded to match COCs=[.15] in the
        # calibrate_distortions call; update it if the COC changes.
        ax.plot([0,1], [1/1.15]*2)
        ax.plot(1, 2, 'r*')
    else:
        ax.plot(xs, v.g_prime(xs[::-1]))
    # Reference line at weight 1 (no adjustment).
    ax.plot([0, 1], [1,1], lw=.5, c='k')
    ax.set(title=k, ylim=[-0.05, 2.5], xlabel='Percentile')
    if ax is axs[0]:
        ax.set(ylabel='Weight adjustment')

# NOTE: `colors` is only built in the plotting cell that follows. Displaying it
# here, before it exists, raises NameError on a top-to-bottom run. The dict
# shown below is stale output from an earlier kernel state.
{0: 'Cccoc', 1: 'Cph', 2: 'Cwang', 3: 'Cdual', 4: 'Ctvar'}
# Left panel: g(p) for every calibrated distortion.
# Right panel: ccoc, tvar and Distortion.minimum of the two (dashed black).
fig, axs = plt.subplots(1, 2, figsize=(2*3, 3), constrained_layout=True, sharey=True)
ax0, ax1 = axs.flat
ps = np.linspace(0, 1, 1001)

# Fixed colour per distortion name (C0, C1, ... in dict order) so both panels agree.
colors = {k: f'C{i}' for i, k in enumerate(wport.dists.keys())}

ax = ax0
for (k, v) in wport.dists.items():
    ax.plot(ps, v.g(ps), lw=1, label=k, c=colors[k])
ax.set(aspect='equal')  
ax.legend(loc='upper left')

ccoc = wport.dists['ccoc']
tvar = wport.dists['tvar']
# Presumably the pointwise minimum of the two distortions (confirm in aggregate).
min_g = Distortion.minimum([ccoc, tvar])
colors[min_g.name] = 'k'

ax = ax1
# NOTE(review): the loop variable `g` stays bound to min_g after this loop;
# later cells that use a bare `g` pick it up.
for g in [ccoc, tvar, min_g]:
    ax.plot(ps, g.g(ps), lw=2 if g is min_g else 1, 
            ls='--' if g is min_g else '-',
            label=g.name, c=colors[g.name])
ax.set(aspect='equal')  
ax.legend(loc='upper left')

# Name generated for the combined distortion (recorded output follows).
min_g.name
'minimum(2)'  # recorded output of the expression above
# Portfolio summary display (recorded output follows).
wport

Portfolio object: WGS

Portfolio contains 3 aggregate components. Updated with bucket size 1, log2 = 8, validation: not unreasonable
E[X] Est E[X] Err E[X] CV(X) Est CV(X) Err CV(X) Skew(X) Est Skew(X)
unit X
X1 Freq 1.0 0.0
Sev 31.700 31.700 0.0 0.215 0.215 2.2204e-15 0.456 0.456
Agg 31.700 31.700 0.0 0.215 0.215 2.2204e-15 0.456 0.456
X2 Freq 1.0 0.0
Sev 11.400 11.400 2.2204e-16 1.299 1.299 -2.2204e-16 1.253 1.253
Agg 11.400 11.400 2.2204e-16 1.299 1.299 -2.2204e-16 1.253 1.253
X3 Freq 1.0 0.0
Sev 3.500 3.500 -2.2204e-16 3.0 3.000 2.2204e-16 2.667 2.667
Agg 3.500 3.500 -2.2204e-16 3.0 3.000 2.2204e-16 2.667 2.667
total Freq 3.0 0.0
Sev 15.533 15.533 2.2204e-16 1.051 0.404
Agg 46.600 46.600 -1.1102e-16 0.416 0.416 6.6613e-16 1.001 1.001
from aggregate.extensions.pir_figures import fig_10_5
# Figure helper from aggregate.extensions.pir_figures, applied to this
# portfolio with x=60.
fig_10_5(port=wport, x=60)

# Calibration summary for each distortion method (recorded output follows:
# S, L, P, PQ, Q, COC, param and error).
wport.distortion_df
S L P PQ Q COC param error
a LR method
100.0 0.870 ccoc 0.0 46.600 53.565 1.154 46.435 0.150 0.150 0.0
ph 0.0 46.600 53.565 1.154 46.435 0.150 0.720 3.2978e-10
wang 0.0 46.600 53.565 1.154 46.435 0.150 0.343 1.2525e-08
dual 0.0 46.600 53.565 1.154 46.435 0.150 1.595 -3.3927e-07
tvar 0.0 46.600 53.565 1.154 46.435 0.150 0.271 7.6102e-06

# NOTE: `bit` is only created further down. Displaying it here raises NameError
# on a top-to-bottom run. The table below (columns p_total, S, F) is stale
# output from an earlier kernel state.
p_total S F
loss
22.0 0.100 0.900 0.100
28.0 0.100 0.800 0.200
36.0 0.100 0.700 0.300
40.0 0.400 0.300 0.700
55.0 0.100 0.200 0.800
65.0 0.100 0.100 0.900
100.0 0.100 0.0 1.0
# Shape parameter of the 'ph' distortion. The recorded value matches the ph
# `param` entry of distortion_df above.
wport.dists['ph'].shape
0.7204792831878889
# Tabulate the discrete loss distribution and its distorted probabilities for
# one calibrated distortion. The distortion is selected explicitly rather than
# relying on whatever `g` an earlier cell left bound. The recorded table
# (gS = 0.975 at S = 0.9) and g.shape = 1.595 match the calibrated dual distortion.
g = wport.dists['dual']
bit = wport.density_df.query('p_total > 0')[['p_total', 'S']]
bit.index.name = 'loss'
bit = bit.rename(columns={'p_total': 'p'})
bit['gS'] = g.g(bit.S)                     # distorted survival g(S)
bit['q'] = -np.diff(bit.gS, prepend=1)     # first differences of gS, starting from 1
bit['dx'] = np.diff(bit.index, prepend=0)  # width of each loss step
bit0 = bit.copy()
# 'Total' row:
#   p  -> sum x p(x)
#   S  -> sum S(previous point) dx, with S = 1 before the first point
#   gS -> the same integral of g(S)
#   q  -> sum x q(x)
#   dx -> total width
bit.loc['Total', :] = [
    (bit0.p * bit0.index).sum(),
    (bit0.S.shift(1, fill_value=1) * bit0.dx).sum(),
    (bit0.gS.shift(1, fill_value=1) * bit0.dx).sum(),
    (bit0.q * bit0.index).sum(),
    bit0.dx.sum(),
]
bit
p S gS q dx
loss
22.0 0.100 0.900 0.975 0.025 22.0
28.0 0.100 0.800 0.923 0.051 6.0
36.0 0.100 0.700 0.853 0.070 8.0
40.0 0.400 0.300 0.434 0.420 4.0
55.0 0.100 0.200 0.299 0.134 15.0
65.0 0.100 0.100 0.155 0.145 10.0
100.0 0.100 0.0 0.0 0.155 35.0
Total 46.600 46.600 53.565 53.565 100.0
# S moved down one row with 1 filled at the top, i.e. S at the previous loss
# point; this is the factor used in the 'Total' S integral.
bit0.S.shift(1, fill_value=1)
loss
22.0      1.0
28.0    0.900
36.0    0.800
40.0    0.700
55.0    0.300
65.0    0.200
100.0   0.100
Name: S, dtype: float64
# Distorted expectation: sum x * q(x). This uses `bit0`, the table before the
# 'Total' row was appended. `bit` now has the string label 'Total' in its
# index, so multiplying by bit.index fails.
(bit0.q * bit0.index).sum()
53.565217052032935
# Shape parameter of whichever distortion `g` is bound to at this point. The
# recorded value (1.595...) equals the dual distortion's calibrated param.
g.shape
1.5951514670652984