diff options
Diffstat (limited to 'logistic_regression/binary.py')
| -rwxr-xr-x | logistic_regression/binary.py | 102 |
1 files changed, 102 insertions, 0 deletions
diff --git a/logistic_regression/binary.py b/logistic_regression/binary.py new file mode 100755 index 0000000..a8a11e2 --- /dev/null +++ b/logistic_regression/binary.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python + +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.animation import FuncAnimation +from mpl_toolkits.mplot3d import Axes3D + + +# Load the data +csv="../data/binary_logistic.csv" +data=pd.read_csv(csv) +x_1=np.array(data[data.columns[0]]) +x_2=np.array(data[data.columns[1]]) +y=np.array(data[data.columns[2]]) + +w1=w2=w3=-8 + +# Define our model +def h(x_1,x_2): + global w1,w2,w3 + model=w1+w2*x_1+w3*x_2 + return(1/(1+np.exp(-model))) + + +def dw1(): + global x_1,x_2,y + return(1/len(x_1)*(sum(h(x_1,x_2)-y))) +def dw2(): + global x_1,x_2,y + return(1/len(x_1)*sum(x_1*(h(x_1,x_2)-y))) +def dw3(): + global x_1,x_2,y + return(1/len(x_1)*sum(x_2*(h(x_1,x_2)-y))) + + +# Perform the gradient decent +#fig, ax = plt.subplots(dpi=300) +alpha=0.01 # Proportion of the gradient to take into account +accuracy=0.0001 # Accuracy of the decent +done=False +def decent(): + global w1,w2,w3,x,y + skip_frame=0 # Current frame (plot animation) + while True: + w1_old=w1 + w1_new=w1-alpha*dw1() + w2_old=w2 + w2_new=w2-alpha*dw2() + w3_old=w3 + w3_new=w3-alpha*dw3() + w1=w1_new + w2=w2_new + w3=w3_new + + if abs(w1_new-w1_old) <= accuracy and abs(w2_new-w2_old) <= accuracy and abs(w2_new-w2_old) <= accuracy: + break + skip_frame+=1 + + + + + + + +decent() +fig=plt.figure() + +#print(np.round(h(x_1,x_2))) +#pred=np.round(h(x_1,x_2)) + +# Plot data +ax = fig.add_subplot(2,2,1) +ax.set_title("Original Data") +ax.set_xlabel("X") +ax.set_ylabel("Y") +scatter=plt.scatter(x_1,x_2,c=y,marker="o") +handles, labels = scatter.legend_elements(prop="colors", alpha=0.6) +legend = ax.legend(handles, ["Class A","Class B"], loc="upper right", title="Legend") + + +# Plot model +ax = fig.add_subplot(2,2,2,projection='3d') +ax.set_title("Model") +X,Y= np.meshgrid(np.sort(x_1), np.sort(x_2)) +ax.set_xlabel("X") +ax.set_ylabel("Y") +ax.set_zlabel("Probability") +surf = ax.plot_wireframe(X,Y, h(X,Y),rstride=10,cstride=10) + +# Plot prediction +ax = fig.add_subplot(2,1,2) +ax.set_title("Predictions") +ax.set_xlabel("X") +ax.set_ylabel("Y") +scatter=plt.scatter(x_1,x_2,c=np.round(h(x_1,x_2)),marker="o") +handles, labels = scatter.legend_elements(prop="colors", alpha=0.6) +legend = ax.legend(handles, ["Class A","Class B"], loc="upper right", title="Legend") + +# Save +plt.tight_layout() +plt.savefig("binary.png",dpi=300) |
