diff --git a/src/maths/diagrams/sets.py b/src/maths/diagrams/sets.py index d6effbb..a92546a 100644 --- a/src/maths/diagrams/sets.py +++ b/src/maths/diagrams/sets.py @@ -66,18 +66,24 @@ class Function(Generic[T1,T2]): return Functions(self).draw(); class Functions: + name: str; fcts: list[Function]; def __init__(self, *f: Function): self.fcts = list(f); + self.name = r' \circ '.join([ fct.name[0] for fct in self.fcts ][::-1]); def draw(self, show_labels: bool = True) -> Figure: N = len(self.fcts); obj = mplot.subplots(1, 1, constrained_layout=True); fig: Figure = obj[0]; axs: Axes = obj[1]; - axs.tick_params(axis='both', which='both', left=False, right=False, top=False, bottom=False, labelbottom=False, labelleft=False); - mplot.title(''); + axs.tick_params(axis='both', which='both', labelbottom=False, labelleft=False); + mplot.title(f'Darstellung von ${self.name}$', fontdict={ + 'fontsize': 16, + 'horizontalalignment': 'center', + 'color': 'forestgreen', + }); mplot.xlabel(''); mplot.ylabel(''); mplot.margins(x=MARGIN, y=MARGIN); @@ -127,7 +133,7 @@ class Functions: for j, p in enumerate(p_codomain): y = f.codomain[j]; - marker = 'o' if (y in comp_range_next) else 'x'; + marker = 'o' if (y in comp_range_next) else '.'; axs.scatter([p[0]], [p[1]], label='', color='black', marker=marker); y_name = f.codomain[j]; if show_labels: @@ -138,16 +144,16 @@ class Functions: q = p_codomain[j]; x = f.domain[i]; if k == 0 or (x in comp_range): - axs.plot([p[0], q[0]], [p[1], q[1]], label='', color='g', linewidth=2); + axs.plot([p[0], q[0]], [p[1], q[1]], label='', color='black', linewidth=1); else: - axs.plot([p[0], q[0]], [p[1], q[1]], label='', color='g', linestyle='--', linewidth=1); + axs.plot([p[0], q[0]], [p[1], q[1]], label='', color='red', linestyle='--', linewidth=1); anchor = anchors[k]; fct_name, X_name, Y_name = f.name; - axs.annotate(text=f'{fct_name}', xy=anchor[0], ha='center', size=FONTSIZE_FCT); + axs.annotate(text=f'${fct_name}$', xy=anchor[0], ha='center', size=FONTSIZE_FCT); if k == 0: - axs.annotate(text=f'{X_name}', xy=anchor[1], ha='center', size=FONTSIZE_FCT); - axs.annotate(text=f'{Y_name}', xy=anchor[2], ha='center', size=FONTSIZE_FCT); + axs.annotate(text=f'${X_name}$', xy=anchor[1], ha='center', size=FONTSIZE_FCT); + axs.annotate(text=f'${Y_name}$', xy=anchor[2], ha='center', size=FONTSIZE_FCT); axs.add_patch(FancyArrowPatch( anchor[3], anchor[4], connectionstyle = 'arc3,rad=0.5', @@ -176,8 +182,6 @@ def oval( P[-1, :] = P[0, :]; return P; - - def random_points( nr_points: int, scale: tuple[float, float] = (1., 1.), diff --git a/src/maths/sets/random.py b/src/maths/sets/random.py index f0c4561..c3ac611 100644 --- a/src/maths/sets/random.py +++ b/src/maths/sets/random.py @@ -34,16 +34,19 @@ T2 = TypeVar('T2'); # METHODS # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ -def randomset_integers(low: int, high: int) -> list[int]: - N = random.randint(low, high); +def randomset_integers(N: int = -1, low: int = 1, high: int = 1) -> list[int]: + if N == -1: + N = random.randint(low, high); return list(range(1, N+1)); -def randomset_alphabet(low: int, high: int) -> list[int]: - N = random.randint(low, high); +def randomset_alphabet(N: int = -1, low: int = 1, high: int = 1) -> list[int]: + if N == -1: + N = random.randint(low, high); return list([a for k, a in enumerate(ALPHA) if k < N]); -def randomset_greek(low: int, high: int) -> list[int]: - N = random.randint(low, high); +def randomset_greek(N: int = -1, low: int = 1, high: int = 1) -> list[int]: + if N == -1: + N = random.randint(low, high); return list([a for k, a in enumerate(GREEK) if k < N]); def random_function( @@ -52,11 +55,39 @@ def random_function( injective: Optional[bool] = None, surjective: Optional[bool] = None, ) -> list[tuple[T1, T2]]: - # TODO: add feature to force injectivity/surjectivity, if possible. - # m = len(X); - # n = len(Y); - # if m > n: - # injective = False; - # if m < n: - # surjective = False; - return [ (x, random.choice(Y)) for x in X ]; + m = len(X); + n = len(Y); + if m == 0: + return []; + if n == 0: + raise Exception(f'Impossible to create a function with {m} elements in the domain and {n} in the codomain.'); + match (injective, surjective): + case (True, _): + assert m <= n, f'Impossible to create an injective function with {m} elements in the domain and {n} in the codomain.'; + Y = random.sample(Y, m); + return [(x, y) for x, y in zip(X, Y)]; + case (_, True): + assert m >= n, f'Impossible to create an surjective function with {m} elements in the domain and {n} in the codomain.'; + indexes = random.sample(list(range(m)), n); + g = [ (indexes[j], Y[j]) for j in range(n) ] \ + + [ + (i, random.choice(Y)) + for i in range(m) + if not i in indexes + ]; + g = sorted(g, key=lambda o: o[0]); + return [ (X[i], y) for (i, y) in g ]; + case (False, _): + assert m > 1, f'Impossible to create a non-injective function with {m} elements in the domain.'; + indexes = random.sample(list(range(m)), m); + g = random_function(indexes, Y); + [(i0, y0), (i1, y1)] = g[:2]; + g[0] = (i0, y1); + g = sorted(g, key=lambda o: o[0]); + return [ (X[i], y) for (i, y) in g ]; + case (_, False): + assert n > 1, f'Impossible to create a non-surjective function with {n} elements in the codomain.'; + Y = random.sample(Y, n-1); + return random_function(X, Y); + case _: + return [ (x, random.choice(Y)) for x in X ];