import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# Toy dataset: two Gaussian blobs for binary classification.
# Each class has 100 points drawn from an isotropic Gaussian with std 0.5,
# centred at (0, 0) for class 0 and (2, 2) for class 1.
np.random.seed(42)
class0 = np.random.randn(100, 2) * 0.5 + np.array([0, 0])
class1 = np.random.randn(100, 2) * 0.5 + np.array([2, 2])
data = np.vstack([class0, class1])
labels = np.hstack([np.zeros(100), np.ones(100)])
data_tensor = torch.tensor(data, dtype=torch.float32)
# CrossEntropyLoss expects integer class indices, hence the long dtype.
labels_tensor = torch.tensor(labels, dtype=torch.long)
# Neural ODE dynamics: parameterizes dz/dt = f(t, z) with a small MLP.
class ODEFunc(nn.Module):
    """Learned vector field f(t, z) for a Neural ODE.

    The dynamics are autonomous: ``t`` is accepted to match the usual
    ``func(t, z)`` solver calling convention but is not used by the network.
    """

    def __init__(self, dim, hidden_dim=50):
        """
        Args:
            dim: Size of the state ``z`` (both input and output width).
            hidden_dim: Width of the single hidden layer (default 50).
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, t, z):
        """Return dz/dt evaluated at state ``z``; ``t`` is ignored."""
        return self.net(z)
# Euler method for forward ODE solve (N steps approximate continuous integration)
def euler_solve(func, z0, t0=0.0, t1=1.0, N=50):
    """Integrate dz/dt = func(t, z) from ``t0`` to ``t1`` with N explicit Euler steps.

    Args:
        func: Callable ``func(t, z)`` returning dz/dt with the same shape as ``z``.
        z0: Initial state at time ``t0`` (tensor or scalar).
        t0: Start time.
        t1: End time.
        N: Number of Euler steps; must be >= 1.

    Returns:
        The approximate state at time ``t1``.

    Raises:
        ValueError: If ``N`` is less than 1.
    """
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N}")
    dt = (t1 - t0) / N
    z = z0
    for step in range(N):
        # Derive t from the step index rather than accumulating t += dt,
        # so the time grid does not drift through repeated float addition.
        t = t0 + step * dt
        # Out-of-place update (not ``z += ...``) so tensors autograd may
        # need for the backward pass are never modified in place.
        z = z + dt * func(t, z)
    return z
# Full model: evolve inputs along the learned ODE, then classify the end state.
class NeuralODEClassifier(nn.Module):
    """Map ``x`` to z(1) with a Neural ODE, then apply a linear classifier.

    The ODE state has the same width as the input, so the linear head
    operates on ``input_dim`` features.
    """

    def __init__(self, input_dim, num_classes, num_steps=50):
        """
        Args:
            input_dim: Number of input features (also the ODE state size).
            num_classes: Number of output logits.
            num_steps: Euler steps used to integrate from t=0 to t=1 (default 50).
        """
        super().__init__()
        self.func = ODEFunc(input_dim)
        self.classifier = nn.Linear(input_dim, num_classes)
        self.num_steps = num_steps

    def forward(self, x):
        """Return unnormalized class logits for inputs of shape (..., input_dim)."""
        z_T = euler_solve(self.func, x, N=self.num_steps)
        return self.classifier(z_T)
# Training: full-batch Adam on the toy dataset.
# Seed torch as well as numpy so the model's weight initialisation (and hence
# the logged loss/accuracy) is reproducible between runs.
torch.manual_seed(42)
num_epochs = 100
log_every = 20
model = NeuralODEClassifier(2, 2)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
for epoch in range(num_epochs):
    optimizer.zero_grad()
    outputs = model(data_tensor)
    loss = criterion(outputs, labels_tensor)
    loss.backward()
    optimizer.step()
    # Also log the final epoch so the end-of-training state is reported.
    if epoch % log_every == 0 or epoch == num_epochs - 1:
        # Metrics reflect the parameters *before* this epoch's update,
        # because ``outputs`` was computed prior to optimizer.step().
        acc = (outputs.argmax(1) == labels_tensor).float().mean().item()
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}, Accuracy = {acc:.4f}")
# Visualization (optional, for demo): decision regions over a grid around the data.
# The grid spans the data's bounding box plus a margin, rather than a fixed
# window that can clip training points in the Gaussian tails.
grid_margin = 0.5
x_min, y_min = data.min(axis=0) - grid_margin
x_max, y_max = data.max(axis=0) + grid_margin
xx, yy = np.meshgrid(np.linspace(x_min, x_max, 50), np.linspace(y_min, y_max, 50))
grid = torch.tensor(np.c_[xx.ravel(), yy.ravel()], dtype=torch.float32)
# Inference only: no autograd graph is needed for the grid predictions.
with torch.no_grad():
    preds = model(grid).argmax(1).reshape(xx.shape)
plt.contourf(xx, yy, preds.numpy(), alpha=0.8)
plt.scatter(data[:, 0], data[:, 1], c=labels)
plt.title("Decision Boundary via Neural ODE Trajectory")
plt.show()