Decision Map

Date: Jun 20, 2023 05:25 AM
Field: Machine Learning
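What is it
A decision map shows the class a trained classifier predicts (the argmax of the network's outputs) at every point of a 2-D input grid. This makes the region assigned to each class, and the boundaries between those regions, visible.
When to use it
To visualize how a classifier with two input features partitions the input plane, and to compare that partition with where the labeled test points actually lie.

How to implement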

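```python
import torch
import matplotlib.pyplot as plt


def sample_grid(M=500, x_max=2.0):
    """
    Helper function to simulate sample meshgrid

    Args:
      M: int
        Number of points along each axis of the meshgrid
      x_max: float
        Defines the range [-x_max, x_max] for the set of points

    Returns:
      X_all: torch.tensor
        Concatenated meshgrid tensor of shape (M * M, 2)
    """
    ii, jj = torch.meshgrid(torch.linspace(-x_max, x_max, M),
                            torch.linspace(-x_max, x_max, M),
                            indexing='ij')
    X_all = torch.cat([ii.unsqueeze(-1), jj.unsqueeze(-1)], dim=-1).view(-1, 2)
    return X_all


def plot_decision_map(X_all, y_pred, X_test, y_test, M=500, x_max=2.0, eps=1e-3):
    """
    Helper function to plot decision map

    Args:
      X_all: torch.tensor
        Concatenated meshgrid tensor from sample_grid
      y_pred: torch.tensor
        Network outputs for X_all, one column per class
      X_test: torch.tensor
        Test data
      y_test: torch.tensor
        Labels of the test data
      M: int
        Size of the constructed tensor with meshgrid (same value as in sample_grid)
      x_max: float
        Defines range for the set of points (same value as in sample_grid)
      eps: float
        Squared-distance threshold: grid points closer than this to a test
        point are painted with that point's label

    Returns:
      Nothing
    """
    K = y_pred.shape[1]  # number of classes
    # Predicted class (0 .. K-1) at every grid point
    decision_map = torch.argmax(y_pred.detach(), dim=1)
    # Paint a small disc around each test point with K + label, so test
    # points get colors distinct from the predicted regions
    for i in range(len(X_test)):
        indices = (X_all[:, 0] - X_test[i, 0])**2 + (X_all[:, 1] - X_test[i, 1])**2 < eps
        decision_map[indices] = K + int(y_test[i])
    decision_map = decision_map.view(M, M)
    # Rows of decision_map follow feature 0; transpose and use origin='lower'
    # so feature 0 runs along the x-axis and feature 1 up the y-axis
    plt.imshow(decision_map.T.cpu().numpy(), extent=[-x_max, x_max, -x_max, x_max],
               origin='lower', cmap='jet')
    plt.show()
```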
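The key step in plot_decision_map is evaluating the classifier on every grid point and taking the argmax. Below is a minimal sketch of just that step, without plotting. It assumes any callable that maps an (N, 2) tensor to (N, K) class scores. The names decision_map_grid and toy_classifier are hypothetical, and the toy classifier stands in for a trained network. The sketch also makes explicit the grid orientation that imshow with origin='lower' expects.

```python
import torch


def decision_map_grid(classify, M=500, x_max=2.0):
    """Predicted class at every point of an M x M grid on [-x_max, x_max]^2.

    Hypothetical helper. Returns an (M, M) tensor whose rows follow feature 1
    and whose columns follow feature 0, both increasing. This is the layout
    plt.imshow(..., origin='lower') draws with feature 0 on the x-axis.
    """
    xs = torch.linspace(-x_max, x_max, M)
    # indexing='ij': ii[r, c] = xs[r] (feature 0), jj[r, c] = xs[c] (feature 1)
    ii, jj = torch.meshgrid(xs, xs, indexing='ij')
    X_all = torch.stack([ii, jj], dim=-1).view(-1, 2)   # (M*M, 2) grid points
    with torch.no_grad():
        scores = classify(X_all)                        # (M*M, K) class scores
    labels = torch.argmax(scores, dim=1).view(M, M)     # labels[r, c] at (xs[r], xs[c])
    return labels.T                                     # rows -> feature 1, cols -> feature 0


def toy_classifier(X):
    """Hypothetical stand-in for a trained network: class 1 if x0 > 0, else class 0."""
    W = torch.tensor([[-1.0, 1.0],
                      [0.0, 0.0]])   # scores = (-x0, x0)
    return X @ W


grid = decision_map_grid(toy_classifier, M=4)
print(grid)
```

For M=4 the grid coordinates are about -2, -0.67, 0.67 and 2. The two left columns of the 4 by 4 map are therefore 0 and the two right columns are 1, because the toy classifier only looks at the sign of feature 0.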
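Call sample_grid and plot_decision_map, given a trained network net and test data X_test, y_test:
```python
X_all = sample_grid()
with torch.no_grad():  # gradients are not needed just to plot
    y_pred = net(X_all)
plot_decision_map(X_all, y_pred, X_test, y_test)
```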
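The call above uses the defaults M=500, x_max=2.0 and eps=1e-3. A grid point is painted with a test label when its squared distance to the test point is below eps, so each test point appears as a disc of radius sqrt(eps). The small sketch below converts that radius into grid cells; the helper name marker_radius_in_pixels is hypothetical.

```python
import math


def marker_radius_in_pixels(M=500, x_max=2.0, eps=1e-3):
    """Approximate radius, in grid cells, of the disc painted around each test point.

    Hypothetical helper. Assumes the grid from sample_grid: M points evenly
    spaced on [-x_max, x_max] along each axis.
    """
    spacing = 2 * x_max / (M - 1)  # distance between neighboring grid points
    radius = math.sqrt(eps)        # a grid point is painted if squared distance < eps
    return radius / spacing


print(round(marker_radius_in_pixels(), 2))      # → 3.94 (defaults)
print(round(marker_radius_in_pixels(M=50), 2))  # → 0.39 (coarser grid)
```

At the defaults each test point covers a disc of about 4 cells radius. With M=50 the radius drops below half a cell. The nearest grid point can be up to about 0.71 cells from a test point, so on a coarse grid eps has to be increased or some test points will not show up on the map.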
A full example: