solver.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. from sympy import *
  2. from itertools import product, combinations
  3. import plotly.graph_objects as go
  4. import numpy as np
  5. import math
  6. class solver:
  7. corners = (-100, 100)
  8. data: list[str]
  9. equalations: list[Equality]
  10. sequance = None
  11. solutions: list
  12. points: list
  13. ndims: int
  14. __X = [*symbols('x1 x2 x3')]
  15. @staticmethod
  16. def toEq(data):
  17. data = data[:]
  18. for i,linEx in enumerate(data):
  19. data[i] = Eq(*[simplify(side) for side in linEx.split('=')])
  20. return data
  21. def solve(self):
  22. result = []
  23. for Eq in self.equalations:
  24. lin = []
  25. for prod in product([-100, 100], repeat=self.ndims-1):
  26. subEq = Eq.copy()
  27. X = self.__X[:]
  28. high_sym = sorted(list(subEq.free_symbols), key=lambda x: x.name)[0]
  29. X.remove(high_sym)
  30. values = [(sym,corner) for sym, corner in zip(X, prod)]
  31. subEq = subEq.subs(values)
  32. solution = int(solve(subEq, high_sym)[0])
  33. values.append((high_sym, solution))
  34. lin.append(sorted(values, key=lambda x: x[0].name))
  35. result.append([[dot[dim][1] for dot in lin] for dim in range(self.ndims)])
  36. return result
  37. def right_dote(self, dote):
  38. flag = True
  39. for line in self.data:
  40. for sym, val in zip(self.__X, dote): line = line.replace(sym.name, str(val))
  41. flag *= eval(line)
  42. return flag
  43. def get_dots(self):
  44. result = []
  45. for Eqs in combinations(self.equalations, r=2):
  46. if Eqs[0] == Eqs[1]: continue
  47. solution = list(solve(Eqs, Eqs[0].free_symbols | Eqs[1].free_symbols, set=True))[1]
  48. if len(solution) == 0: continue
  49. dot = list(solve(Eqs, Eqs[0].free_symbols | Eqs[1].free_symbols, set=True)[1])[0]
  50. if self.right_dote(dot): result.append(dot)
  51. reference_point = result[0]
  52. sorted_coordinates = sorted(result, key=lambda point: math.atan2(point[1] - reference_point[1], point[0] - reference_point[0]))
  53. return [[float(val[dim]) for val in sorted_coordinates] for dim in range(self.ndims)]
  54. def show(self):
  55. fig = go.Figure()
  56. for line, names in zip(self.solutions, self.data):
  57. fig.add_trace(go.Scatter({dim:val for val, dim in zip(line, ('x','y','z'))}, name=str(names)))
  58. fig.add_trace(go.Scatter({dim:val for val, dim in zip(self.get_dots(), ('x','y','z'))}, mode='markers', fill='toself', fillpattern=dict(fillmode='replace', shape='x')))
  59. fig.add_trace(go.Scatter(x=[0, self.gradient[0]], y=[0, self.gradient[1]],
  60. marker=dict(color='black', symbol='arrow', size=16, angleref="previous"),
  61. line = dict(width=4, dash='dot', color='black')))
  62. touch = len(fig.data)
  63. for step in np.arange(0, self.count, self.step):
  64. k = ((self.gradient[1]-0) * (step-0) - (self.gradient[1]-0) * (0-0)) / ((self.gradient[1]-0)**2 + (self.gradient[0]-0)**2)
  65. x4 = step - k * (self.gradient[1]-0)
  66. y4 = 0 + k * (self.gradient[0]-0)
  67. y5 = y4+y4
  68. x5 = x4+(x4-step)
  69. fig.add_trace(
  70. go.Scatter(visible=False, line=dict(color='black', width=2),
  71. x=[step, x4, x5], y=[0, y4, y5])
  72. )
  73. fig.data[touch].visible = True
  74. steps = []
  75. for i in range(len(fig.data[touch:])):
  76. step = dict(
  77. method="update",
  78. args=[{"visible": [True]*touch + [False] * (len(fig.data)-touch)},
  79. {"title": "Slider switched to step: " + str(i)}], # layout attribute
  80. )
  81. step["args"][0]["visible"][i] = True # Toggle i'th trace to "visible"
  82. steps.append(step)
  83. sliders = [dict(
  84. active=10,
  85. currentvalue={"prefix": "Frequency: "},
  86. pad={"t": 50},
  87. steps=steps
  88. )]
  89. fig.update_layout(
  90. sliders=sliders
  91. )
  92. fig.update_xaxes(title_text='x1', gridwidth=1)
  93. fig.update_yaxes(title_text='x2', gridwidth=1)
  94. fig.show()
  95. def __init__(self, seq: str, data: list[str], ndims=2, step=0.01, count=10):
  96. self.data = data
  97. self.gradient = list(map(int,Poly(simplify(seq)).coeffs()))
  98. self.equalations = solver.toEq([lin.replace('>','').replace('<', '') for lin in data])
  99. self.ndims = ndims
  100. self.__X = self.__X[:ndims]
  101. self.solutions = self.solve()
  102. self.count = count
  103. self.step = step
  104. if __name__ == '__main__':
  105. # solver( seq='3*x1 + 4*x2',
  106. # data=['4*x1 + x2 <= 8',
  107. # 'x1 >= 0',
  108. # 'x1 - x2 >= -3',
  109. # 'x2 >= 0'], ndims=2).show()
  110. # solver( seq='3*x1 + 2*x2',
  111. # data=['2*x1 + 3*x2 <= 6',
  112. # 'x1 <= 2', 'x1 >= 0',
  113. # '2*x1 - x2 >= 0',
  114. # 'x2 >= 0', 'x2 <= 1'], ndims=2).show()
  115. # solver( seq='x1 + 3*x2',
  116. # data=['2*x1 + 3*x2 <= 24',
  117. # 'x1 >= 0',
  118. # 'x1 - x2 <= 7',
  119. # 'x2 >= 0', 'x2 <= 6'], ndims=2, step=0.1, count=25).show()
  120. # solver( seq='x1 - 1.1*x2 + 7.4',
  121. # data=['x1 >= 0',
  122. # 'x2 >= 0',
  123. # 'x1 + x2 <= 10',
  124. # '10 - x1 >= 0', '10 - x2 >= 0'], ndims=2, step=0.1, count=15).show()
  125. pass