2022-10-20 16:21:50 +02:00

206 lines
7.0 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# IMPORTS
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
from __future__ import annotations;
from src.thirdparty.code import *;
from src.thirdparty.types import *;
from src.thirdparty.maths import *;
from src.thirdparty.plots import *;
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# EXPORTS
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# public API of this module (the plotting helpers below are not exported):
__all__ = [
'Function',
'Functions',
];
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# CONSTANTS / VARIABLES
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# type variables for the element types of a function's domain (T1) and codomain (T2):
T1 = TypeVar('T1');
T2 = TypeVar('T2');
# (x, y) scale of the oval drawn for each set; also passed as the scale for placing its element points:
SCALE = (1., 4.);
# displacement between the centres of consecutive sets in the diagram:
OFFSET = (3., 0.);
# x/y margins passed to matplotlib's `margins` around the plotted data:
MARGIN = 0.1;
# number of points used to draw the outline of each oval:
N_RESOLUTION = 100;
# offset (in typographic points) of an element's label from its marker:
ANNOTATE_OFFSET = (0, 10);
# font size of the element labels:
FONTSIZE_PTS = 10;
# font size of the function names:
FONTSIZE_FCT = 14;
# font size intended for the set (domain/codomain) names:
FONTSIZE_SETS = 14;
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Classes
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@dataclass
class Function(Generic[T1, T2]):
    '''
    A function between finite sets, given explicitly by its graph.

    Attributes:
        name: triple `(function name, domain name, codomain name)` used as labels when drawing.
        domain: the elements of the domain (their order fixes their indexes).
        codomain: the elements of the codomain (their order fixes their indexes).
        fct: the graph of the function as a list of pairs `(x, y)`.
    '''
    name: tuple[str, str, str] = field();
    domain: list[T1] = field();
    codomain: list[T2] = field();
    fct: list[tuple[T1, T2]] = field();

    @property
    def range(self) -> list[T2]:
        '''
        The images `y` of all pairs `(x, y)` of the graph,
        in the order of `fct` and with repetitions kept.
        '''
        return [y for x, y in self.fct];

    @property
    def indexes(self) -> list[tuple[int, int]]:
        '''
        For each pair `(x, y)` of the graph, the pair
        `(position of x in domain, position of y in codomain)`,
        using the first occurrence in each list.

        NOTE: the result is computed once and cached on the instance,
        so later changes to `domain`, `codomain` or `fct` are not reflected.

        Raises:
            ValueError: if some `x` is not in `domain` or some `y` is not in `codomain`.
        '''
        # prevent repeated computation:
        if not hasattr(self, '_indexes'):
            self._indexes = [
                (self.domain.index(x), self.codomain.index(y))
                for x, y in self.fct
            ];
        return getattr(self, '_indexes');

    def draw(self, show_labels: bool = True) -> Figure:
        '''
        Draws this function on its own as an arrow diagram.

        Arguments:
            show_labels: whether to annotate each point with the element's name
                (forwarded to `Functions.draw`).

        Returns:
            The created figure.
        '''
        return Functions(self).draw(show_labels=show_labels);
class Functions:
    '''
    A chain of functions `f_0, f_1, ..., f_{N-1}` drawn side by side
    as a single arrow diagram.
    '''
    fcts: list[Function];

    def __init__(self, *f: Function):
        self.fcts = list(f);

    def draw(self, show_labels: bool = True) -> Figure:
        '''
        Draws the chain of functions as an arrow diagram and returns the figure.

        One oval is drawn for the domain of `f_0` and one for the codomain of each `f_k`.
        Elements are drawn as randomly placed points, so repeated calls give different layouts.
        - An edge of `f_k` is drawn solid if its source lies in the image of
          `f_0.domain` under `f_{k-1} ∘ ... ∘ f_0` (all edges of `f_0` are solid);
          otherwise it is drawn dashed.
        - A codomain point of `f_k` is marked `o` if it lies in the image of
          `f_0.domain` under `f_k ∘ ... ∘ f_0`, otherwise `x`.

        NOTE: for `k > 0` the domain points of `f_k` reuse the positions of the
        codomain points of `f_{k-1}` by index, so `f_k.domain` is assumed to list
        the same elements in the same order as `f_{k-1}.codomain`.

        Arguments:
            show_labels: whether to annotate each point with the element's name.
                Function and set names are always shown.

        Returns:
            The created figure.
        '''
        N = len(self.fcts);
        obj = mplot.subplots(1, 1, constrained_layout=True);
        fig: Figure = obj[0];
        axs: Axes = obj[1];
        # the diagram has no meaningful coordinates: hide all ticks and tick labels.
        axs.tick_params(axis='both', which='both', left=False, right=False, top=False, bottom=False, labelbottom=False, labelleft=False);
        # configure the axes just created, rather than pyplot's implicit "current" axes:
        axs.set_title('');
        axs.set_xlabel('');
        axs.set_ylabel('');
        axs.margins(x=MARGIN, y=MARGIN);

        origin = np.asarray((0., 0.));
        offset = np.asarray(OFFSET);

        # one oval per set: the domain of f_0 followed by the codomain of each f_k.
        p_set = oval(nr_points=N_RESOLUTION, scale=SCALE, centre=origin);
        for k in range(N+1):
            axs.plot(p_set[:, 0] + k*offset[0], p_set[:, 1] + k*offset[1], label='', color='blue');

        p_domain = [];
        p_codomain = [];
        comp_range = [];
        # per function k: positions of
        # [function name, domain name, codomain name, arrow start, arrow end]
        anchors = [
            [
                # function name
                origin + (k + 0.5)*offset + (0, -1.1*SCALE[1]),
                # sets
                origin + k * offset + (0, 1.1 * SCALE[1]),
                origin + (k + 1) * offset + (0, 1.1 * SCALE[1]),
                # arrow start -> end
                origin + (k + 0.05) * offset + (0, -1.1 * SCALE[1]),
                origin + (k + 1 - 0.05) * offset + (0, -1.1 * SCALE[1]),
            ]
            for k in range(N)
        ];

        for k, f in enumerate(self.fcts):
            if k == 0:
                # initially the "range of the composition so far" is the whole domain of f_0:
                comp_range = f.domain;
                p_domain = random_points(nr_points=len(f.domain), scale=SCALE, centre=origin + k*offset);
            else:
                # the domain of f_k is drawn on the codomain points of f_{k-1}:
                p_domain = p_codomain;
            p_codomain = random_points(nr_points=len(f.codomain), scale=SCALE, centre=origin + (k+1)*offset);
            # range of composition so far:
            comp_range_next = [y for x, y in f.fct if x in comp_range];

            # domain points (and their labels) are drawn only once, for f_0:
            if k == 0:
                axs.scatter(p_domain[:, 0], p_domain[:, 1], label='', color='black', marker='o');
                if show_labels:
                    for i, p in enumerate(p_domain):
                        x_name = f.domain[i];
                        axs.annotate(text=f'{x_name}', xy = p, textcoords='offset points', xytext=ANNOTATE_OFFSET, ha='center', size=FONTSIZE_PTS);

            # codomain points: 'o' if hit by the composition so far, else 'x':
            for j, p in enumerate(p_codomain):
                y = f.codomain[j];
                marker = 'o' if (y in comp_range_next) else 'x';
                axs.scatter([p[0]], [p[1]], label='', color='black', marker=marker);
                if show_labels:
                    axs.annotate(text=f'{y}', xy=p, textcoords='offset points', xytext=ANNOTATE_OFFSET, ha='center', size=FONTSIZE_PTS);

            # edges of the graph: solid if the source lies in the range of the composition so far:
            for i, j in f.indexes:
                p = p_domain[i];
                q = p_codomain[j];
                x = f.domain[i];
                if k == 0 or (x in comp_range):
                    axs.plot([p[0], q[0]], [p[1], q[1]], label='', color='g', linewidth=2);
                else:
                    axs.plot([p[0], q[0]], [p[1], q[1]], label='', color='g', linestyle='--', linewidth=1);

            anchor = anchors[k];
            fct_name, X_name, Y_name = f.name;
            axs.annotate(text=f'{fct_name}', xy=anchor[0], ha='center', size=FONTSIZE_FCT);
            # the domain name is only drawn for f_0 (later domains carry the previous codomain's name):
            if k == 0:
                axs.annotate(text=f'{X_name}', xy=anchor[1], ha='center', size=FONTSIZE_SETS);
            axs.annotate(text=f'{Y_name}', xy=anchor[2], ha='center', size=FONTSIZE_SETS);
            # curved arrow from set k to set k+1, drawn below the ovals:
            axs.add_patch(FancyArrowPatch(
                anchor[3], anchor[4],
                connectionstyle = 'arc3,rad=0.5',
                arrowstyle = 'Simple, tail_width=0.5, head_width=4, head_length=8',
                color = 'black',
            ));
            # update range of composition:
            comp_range = comp_range_next;
        return fig;
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# AUXILIARY
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def oval(
    nr_points: int,
    scale: tuple[float, float] = (1., 1.),
    centre: tuple[float, float] = (0., 0.),
) -> NDArray[Shape['*, 2'], Float]:
    '''
    Samples a closed ellipse (oval) as a polyline.

    Arguments:
        nr_points: number of points to return, including the closing point.
        scale: factors `(a, b)` applied to `cos(t)` and `sin(t)` respectively.
        centre: centre of the ellipse.

    Returns:
        An array of shape `(nr_points, 2)` whose rows are
        `centre + (a*cos(t), b*sin(t))` for `t` evenly spaced over `[0, 2π]`.
        The last row is set exactly equal to the first, so the curve closes.
        For `nr_points == 0` an empty `(0, 2)` array is returned.
    '''
    theta = np.linspace(start=0, stop=2*np.pi, num=nr_points, endpoint=True);
    P = np.zeros(shape=(nr_points, 2), dtype=float);
    P[:, 0] = centre[0] + scale[0] * np.cos(theta);
    P[:, 1] = centre[1] + scale[1] * np.sin(theta);
    # cos/sin at 2π do not exactly reproduce the values at 0, so close the curve explicitly
    # (guarded: an empty array has no first point to copy):
    if nr_points > 0:
        P[-1, :] = P[0, :];
    return P;
def random_points(
    nr_points: int,
    scale: tuple[float, float] = (1., 1.),
    centre: tuple[float, float] = (0., 0.),
    force: bool = False,
    tol: float = 0.2,
    r_min: float = 0.25,
    r_max: float = 1.,
    rng: np.random.Generator | None = None,
) -> NDArray[Shape['*, 2'], Float]:
    '''
    Places points at evenly spaced angles around `centre`, each at a random radius.

    Point `k` lies at angle `2πk/nr_points` with normalised radius
    `r_k = r_min + (r_max - r_min) * u_k`, and has coordinates
    `centre + (scale[0]*r_k*cos, scale[1]*r_k*sin)`. Here `u` is a random vector
    rescaled by its maximum:
    - `force=False`: `u` is scaled so its maximum is `1 - tol`, so for `tol > 0`
      every point stays below `r_max`;
    - `force=True`: `u` is scaled so its maximum is `1 + tol`, then capped at 1,
      so at least one point lies exactly at `r_max`.

    Arguments:
        nr_points: number of points to generate.
        scale: factors applied to the x- and y-coordinates.
        centre: centre around which the points are placed.
        force: whether to push the outermost point(s) onto `r_max`.
        tol: relative amount by which the largest random draw is shrunk/stretched.
        r_min: smallest normalised radius.
        r_max: largest normalised radius.
        rng: optional numpy random generator, for reproducible layouts;
            if omitted, numpy's global random state is used.

    Returns:
        An array of shape `(nr_points, 2)`; empty `(0, 2)` if `nr_points == 0`.
    '''
    if nr_points == 0:
        return np.zeros(shape=(0, 2), dtype=float);
    draw = np.random.random if rng is None else rng.random;
    theta = np.linspace(start=0, stop=2*np.pi, num=nr_points, endpoint=False);
    while True:
        u = draw(size=(nr_points,));
        u_max = u.max();
        # redraw in the (practically impossible) case that every draw is 0, to avoid dividing by 0:
        if u_max == 0.:
            continue;
        break;
    if force:
        u = np.minimum((1 + tol) * u / u_max, 1);
    else:
        u = (1 - tol) * u / u_max;
    r = r_min + (r_max - r_min) * u;
    P = np.zeros(shape=(nr_points, 2), dtype=float);
    P[:, 0] = centre[0] + scale[0] * r * np.cos(theta);
    P[:, 1] = centre[1] + scale[1] * r * np.sin(theta);
    return P;