aboutsummaryrefslogtreecommitdiff
path: root/logistic_regression/binary.py
diff options
context:
space:
mode:
Diffstat (limited to 'logistic_regression/binary.py')
-rwxr-xr-xlogistic_regression/binary.py102
1 files changed, 102 insertions, 0 deletions
diff --git a/logistic_regression/binary.py b/logistic_regression/binary.py
new file mode 100755
index 0000000..a8a11e2
--- /dev/null
+++ b/logistic_regression/binary.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python
+
+import pandas as pd
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.animation import FuncAnimation
+from mpl_toolkits.mplot3d import Axes3D
+
+
+# Load the data
+csv="../data/binary_logistic.csv"
+data=pd.read_csv(csv)
+x_1=np.array(data[data.columns[0]])
+x_2=np.array(data[data.columns[1]])
+y=np.array(data[data.columns[2]])
+
+w1=w2=w3=-8
+
+# Define our model
+def h(x_1,x_2):
+ global w1,w2,w3
+ model=w1+w2*x_1+w3*x_2
+ return(1/(1+np.exp(-model)))
+
+
+def dw1():
+ global x_1,x_2,y
+ return(1/len(x_1)*(sum(h(x_1,x_2)-y)))
+def dw2():
+ global x_1,x_2,y
+ return(1/len(x_1)*sum(x_1*(h(x_1,x_2)-y)))
+def dw3():
+ global x_1,x_2,y
+ return(1/len(x_1)*sum(x_2*(h(x_1,x_2)-y)))
+
+
+# Perform the gradient decent
+#fig, ax = plt.subplots(dpi=300)
+alpha=0.01 # Proportion of the gradient to take into account
+accuracy=0.0001 # Accuracy of the decent
+done=False
+def decent():
+ global w1,w2,w3,x,y
+ skip_frame=0 # Current frame (plot animation)
+ while True:
+ w1_old=w1
+ w1_new=w1-alpha*dw1()
+ w2_old=w2
+ w2_new=w2-alpha*dw2()
+ w3_old=w3
+ w3_new=w3-alpha*dw3()
+ w1=w1_new
+ w2=w2_new
+ w3=w3_new
+
+ if abs(w1_new-w1_old) <= accuracy and abs(w2_new-w2_old) <= accuracy and abs(w2_new-w2_old) <= accuracy:
+ break
+ skip_frame+=1
+
+
+
+
+
+
+
+decent()
+fig=plt.figure()
+
+#print(np.round(h(x_1,x_2)))
+#pred=np.round(h(x_1,x_2))
+
+# Plot data
+ax = fig.add_subplot(2,2,1)
+ax.set_title("Original Data")
+ax.set_xlabel("X")
+ax.set_ylabel("Y")
+scatter=plt.scatter(x_1,x_2,c=y,marker="o")
+handles, labels = scatter.legend_elements(prop="colors", alpha=0.6)
+legend = ax.legend(handles, ["Class A","Class B"], loc="upper right", title="Legend")
+
+
+# Plot model
+ax = fig.add_subplot(2,2,2,projection='3d')
+ax.set_title("Model")
+X,Y= np.meshgrid(np.sort(x_1), np.sort(x_2))
+ax.set_xlabel("X")
+ax.set_ylabel("Y")
+ax.set_zlabel("Probability")
+surf = ax.plot_wireframe(X,Y, h(X,Y),rstride=10,cstride=10)
+
+# Plot prediction
+ax = fig.add_subplot(2,1,2)
+ax.set_title("Predictions")
+ax.set_xlabel("X")
+ax.set_ylabel("Y")
+scatter=plt.scatter(x_1,x_2,c=np.round(h(x_1,x_2)),marker="o")
+handles, labels = scatter.legend_elements(prop="colors", alpha=0.6)
+legend = ax.legend(handles, ["Class A","Class B"], loc="upper right", title="Legend")
+
+# Save
+plt.tight_layout()
+plt.savefig("binary.png",dpi=300)