In [1]:
from glob import glob
from utils import *
from riftnet import RiftNet
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
c:\Users\user\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\_pytree.py:185: FutureWarning: optree is installed but the version is too old to support PyTorch Dynamo in C++ pytree. C++ pytree support is disabled. Please consider upgrading optree using `python3 -m pip install --upgrade 'optree>=0.13.0'`. warnings.warn(
In [2]:
file_paths = glob("Bluetooth_Datasets/Dataset 250 Msps/*/*/*/*")
unknown_users = {"4s_013004004984503_oguz_guler", "Note2_356261053336200_ismet_buyukkilic", "XperiaM5_354188070809491_firat_vural"}
file_paths_filtered = []
user_names_filtered = []
for file_path in file_paths:
user = format_model_name(file_path.split("\\")[2:4])
if user not in unknown_users:
file_paths_filtered.append(file_path)
user_names_filtered.append(user)
# Create dataset with filtered data
dataset = RFFDataset(file_paths=file_paths_filtered, labels=user_names_filtered)
In [3]:
def train_riftnet(dataset, analytic_fn, num_classes, epochs):
encoded_dataset, le = encode_labels(dataset)
# Stratified split: 80% train, 20% test
train_data, test_data = train_test_split(
encoded_dataset, test_size=0.2, stratify=[d[1] for d in encoded_dataset], random_state=0
)
train_dataset = IQDataset(train_data, analytic_fn)
test_dataset = IQDataset(test_data, analytic_fn)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = RiftNet(num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
losses, accuracies, test_accuracies = [], [], []
for epoch in range(epochs):
model.train()
total_loss, correct, total = 0.0, 0, 0
with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", ncols=100) as pbar:
for x_long, x_short, y in pbar:
x_long, x_short, y = x_long.to(device), x_short.to(device), y.to(device)
optimizer.zero_grad()
out = model(x_long, x_short)
loss = criterion(out, y)
loss.backward()
optimizer.step()
total_loss += loss.item() * x_long.size(0)
_, preds = out.max(1)
correct += (preds == y).sum().item()
total += y.size(0)
avg_loss = total_loss / total
acc = correct / total
pbar.set_postfix(loss=avg_loss, acc=acc)
losses.append(avg_loss)
accuracies.append(acc)
# ----- Evaluate on test set -----
model.eval()
correct_test, total_test = 0, 0
y_true, y_pred = [], [] # To store true and predicted labels for confusion matrix
with torch.no_grad():
for x_long, x_short, y in test_loader:
x_long, x_short, y = x_long.to(device), x_short.to(device), y.to(device)
out = model(x_long, x_short)
_, preds = out.max(1)
correct_test += (preds == y).sum().item()
total_test += y.size(0)
# Collect true and predicted labels
y_true.extend(y.cpu().numpy())
y_pred.extend(preds.cpu().numpy())
test_acc = correct_test / total_test
test_accuracies.append(test_acc)
print(f"Test acc: {test_acc:.2f} %")
# Save best model
if epoch == 0 or test_acc > max(test_accuracies[:-1]):
torch.save(model.state_dict(), 'best_model_state_dict.pth')
print(f"✅ Saved best model at epoch {epoch+1} with test acc: {test_acc:.2%}")
model.load_state_dict(torch.load('best_model_state_dict.pth'))
# ----- Evaluate on test set -----
model.eval()
correct_test, total_test = 0, 0
y_true, y_pred = [], [] # To store true and predicted labels for confusion matrix
with torch.no_grad():
for x_long, x_short, y in test_loader:
x_long, x_short, y = x_long.to(device), x_short.to(device), y.to(device)
out = model(x_long, x_short)
_, preds = out.max(1)
correct_test += (preds == y).sum().item()
total_test += y.size(0)
# Collect true and predicted labels
y_true.extend(y.cpu().numpy())
y_pred.extend(preds.cpu().numpy())
test_acc = correct_test / total_test
print(f"Last test acc: {test_acc:.2f} %")
# Confusion Matrix
cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=le.classes_)
fig, ax = plt.subplots(figsize=(16, 12))
disp.plot(ax=ax, cmap='Blues', colorbar=True)
plt.xticks(rotation=90)
plt.title(f"Confusion Matrix at Epoch {epoch+1}, Test Accuracy: {test_acc:.2%}")
plt.tight_layout()
plt.show()
# ----- Plot -----
plt.figure(figsize=(15, 4))
plt.subplot(1, 3, 1)
plt.plot(losses, label='Train Loss')
plt.xlabel("Epoch"); plt.ylabel("Loss")
plt.title("Training Loss")
plt.subplot(1, 3, 2)
plt.plot(accuracies, label='Train Acc')
plt.xlabel("Epoch"); plt.ylabel("Accuracy")
plt.title("Training Accuracy")
plt.subplot(1, 3, 3)
plt.plot(test_accuracies, label='Test Acc')
plt.xlabel("Epoch"); plt.ylabel("Accuracy")
plt.title("Test Accuracy")
plt.tight_layout()
plt.show()
return model, le
# ====== USAGE ======
model, le = train_riftnet(dataset, analytic_signal, num_classes=30, epochs=100)
Epoch 1/100: 100%|██████████████████████████| 113/113 [00:12<00:00, 9.13it/s, acc=0.242, loss=3.28]
Test acc: 0.29 % ✅ Saved best model at epoch 1 with test acc: 29.11%
Epoch 2/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.66it/s, acc=0.408, loss=3.11]
Test acc: 0.52 % ✅ Saved best model at epoch 2 with test acc: 52.22%
Epoch 3/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.75it/s, acc=0.528, loss=2.96]
Test acc: 0.55 % ✅ Saved best model at epoch 3 with test acc: 55.00%
Epoch 4/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.72it/s, acc=0.601, loss=2.89]
Test acc: 0.64 % ✅ Saved best model at epoch 4 with test acc: 64.33%
Epoch 5/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.75it/s, acc=0.648, loss=2.84]
Test acc: 0.62 %
Epoch 6/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.73it/s, acc=0.742, loss=2.76]
Test acc: 0.70 % ✅ Saved best model at epoch 6 with test acc: 69.89%
Epoch 7/100: 100%|███████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.796, loss=2.7]
Test acc: 0.79 % ✅ Saved best model at epoch 7 with test acc: 79.00%
Epoch 8/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.828, loss=2.67]
Test acc: 0.82 % ✅ Saved best model at epoch 8 with test acc: 82.44%
Epoch 9/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.74it/s, acc=0.858, loss=2.63]
Test acc: 0.83 % ✅ Saved best model at epoch 9 with test acc: 83.11%
Epoch 10/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.71it/s, acc=0.89, loss=2.59]
Test acc: 0.90 % ✅ Saved best model at epoch 10 with test acc: 89.67%
Epoch 11/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.75it/s, acc=0.91, loss=2.57]
Test acc: 0.76 %
Epoch 12/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.75it/s, acc=0.913, loss=2.56]
Test acc: 0.86 %
Epoch 13/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.924, loss=2.55]
Test acc: 0.90 % ✅ Saved best model at epoch 13 with test acc: 89.89%
Epoch 14/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.924, loss=2.54]
Test acc: 0.90 %
Epoch 15/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.913, loss=2.56]
Test acc: 0.91 % ✅ Saved best model at epoch 15 with test acc: 90.78%
Epoch 16/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.931, loss=2.54]
Test acc: 0.90 %
Epoch 17/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.929, loss=2.54]
Test acc: 0.73 %
Epoch 18/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.79it/s, acc=0.925, loss=2.54]
Test acc: 0.91 % ✅ Saved best model at epoch 18 with test acc: 91.33%
Epoch 19/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.79it/s, acc=0.935, loss=2.53]
Test acc: 0.82 %
Epoch 20/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.79it/s, acc=0.937, loss=2.53]
Test acc: 0.92 % ✅ Saved best model at epoch 20 with test acc: 92.44%
Epoch 21/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.72it/s, acc=0.938, loss=2.53]
Test acc: 0.89 %
Epoch 22/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.936, loss=2.53]
Test acc: 0.80 %
Epoch 23/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.73it/s, acc=0.938, loss=2.53]
Test acc: 0.91 %
Epoch 24/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.75it/s, acc=0.953, loss=2.51]
Test acc: 0.90 %
Epoch 25/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.977, loss=2.49]
Test acc: 0.96 % ✅ Saved best model at epoch 25 with test acc: 95.89%
Epoch 26/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.978, loss=2.49]
Test acc: 0.91 %
Epoch 27/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.977, loss=2.49]
Test acc: 0.88 %
Epoch 28/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.982, loss=2.48]
Test acc: 0.93 %
Epoch 29/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.978, loss=2.49]
Test acc: 0.96 % ✅ Saved best model at epoch 29 with test acc: 96.33%
Epoch 30/100: 100%|█████████████████████████| 113/113 [00:12<00:00, 9.41it/s, acc=0.983, loss=2.48]
Test acc: 0.92 %
Epoch 31/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.44it/s, acc=0.984, loss=2.48]
Test acc: 0.94 %
Epoch 32/100: 100%|██████████████████████████| 113/113 [00:12<00:00, 9.38it/s, acc=0.98, loss=2.48]
Test acc: 0.95 %
Epoch 33/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.57it/s, acc=0.984, loss=2.48]
Test acc: 0.95 %
Epoch 34/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.985, loss=2.48]
Test acc: 0.96 %
Epoch 35/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.984, loss=2.48]
Test acc: 0.92 %
Epoch 36/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.45it/s, acc=0.981, loss=2.48]
Test acc: 0.90 %
Epoch 37/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.63it/s, acc=0.987, loss=2.47]
Test acc: 0.94 %
Epoch 38/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.68it/s, acc=0.986, loss=2.48]
Test acc: 0.94 %
Epoch 39/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.986, loss=2.48]
Test acc: 0.95 %
Epoch 40/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.985, loss=2.48]
Test acc: 0.96 % ✅ Saved best model at epoch 40 with test acc: 96.44%
Epoch 41/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.988, loss=2.47]
Test acc: 0.90 %
Epoch 42/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.79it/s, acc=0.985, loss=2.48]
Test acc: 0.92 %
Epoch 43/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.986, loss=2.48]
Test acc: 0.85 %
Epoch 44/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.79it/s, acc=0.988, loss=2.47]
Test acc: 0.96 %
Epoch 45/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.75it/s, acc=0.991, loss=2.47]
Test acc: 0.96 %
Epoch 46/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.985, loss=2.48]
Test acc: 0.96 %
Epoch 47/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.79it/s, acc=0.988, loss=2.47]
Test acc: 0.63 %
Epoch 48/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.79it/s, acc=0.986, loss=2.47]
Test acc: 0.90 %
Epoch 49/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.983, loss=2.48]
Test acc: 0.78 %
Epoch 50/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.79it/s, acc=0.989, loss=2.47]
Test acc: 0.89 %
Epoch 51/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.985, loss=2.48]
Test acc: 0.76 %
Epoch 52/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.68it/s, acc=0.986, loss=2.47]
Test acc: 0.66 %
Epoch 53/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.988, loss=2.47]
Test acc: 0.88 %
Epoch 54/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.73it/s, acc=0.981, loss=2.48]
Test acc: 0.93 %
Epoch 55/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.74it/s, acc=0.984, loss=2.48]
Test acc: 0.94 %
Epoch 56/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.75it/s, acc=0.988, loss=2.47]
Test acc: 0.86 %
Epoch 57/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.989, loss=2.47]
Test acc: 0.92 %
Epoch 58/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.991, loss=2.47]
Test acc: 0.92 %
Epoch 59/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.989, loss=2.47]
Test acc: 0.91 %
Epoch 60/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.988, loss=2.47]
Test acc: 0.60 %
Epoch 61/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.986, loss=2.47]
Test acc: 0.77 %
Epoch 62/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.988, loss=2.47]
Test acc: 0.79 %
Epoch 63/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.75it/s, acc=0.986, loss=2.47]
Test acc: 0.90 %
Epoch 64/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.984, loss=2.47]
Test acc: 0.79 %
Epoch 65/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.989, loss=2.47]
Test acc: 0.86 %
Epoch 66/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.983, loss=2.48]
Test acc: 0.84 %
Epoch 67/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.99, loss=2.47]
Test acc: 0.91 %
Epoch 68/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.991, loss=2.47]
Test acc: 0.77 %
Epoch 69/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.987, loss=2.47]
Test acc: 0.58 %
Epoch 70/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.988, loss=2.47]
Test acc: 0.77 %
Epoch 71/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.988, loss=2.47]
Test acc: 0.92 %
Epoch 72/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.989, loss=2.47]
Test acc: 0.53 %
Epoch 73/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.989, loss=2.47]
Test acc: 0.94 %
Epoch 74/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.70it/s, acc=0.989, loss=2.47]
Test acc: 0.69 %
Epoch 75/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.73it/s, acc=0.992, loss=2.47]
Test acc: 0.95 %
Epoch 76/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.992, loss=2.47]
Test acc: 0.95 %
Epoch 77/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.989, loss=2.47]
Test acc: 0.96 %
Epoch 78/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.992, loss=2.47]
Test acc: 0.93 %
Epoch 79/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.991, loss=2.47]
Test acc: 0.96 %
Epoch 80/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.992, loss=2.47]
Test acc: 0.89 %
Epoch 81/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.988, loss=2.47]
Test acc: 0.96 %
Epoch 82/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.989, loss=2.47]
Test acc: 0.67 %
Epoch 83/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.99, loss=2.47]
Test acc: 0.86 %
Epoch 84/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.75it/s, acc=0.988, loss=2.47]
Test acc: 0.95 %
Epoch 85/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.989, loss=2.47]
Test acc: 0.88 %
Epoch 86/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.988, loss=2.47]
Test acc: 0.83 %
Epoch 87/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.71it/s, acc=0.989, loss=2.47]
Test acc: 0.93 %
Epoch 88/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.72it/s, acc=0.988, loss=2.47]
Test acc: 0.91 %
Epoch 89/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.73it/s, acc=0.987, loss=2.47]
Test acc: 0.94 %
Epoch 90/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.71it/s, acc=0.991, loss=2.47]
Test acc: 0.92 %
Epoch 91/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.74it/s, acc=0.985, loss=2.47]
Test acc: 0.88 %
Epoch 92/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.99, loss=2.47]
Test acc: 0.96 %
Epoch 93/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.993, loss=2.47]
Test acc: 0.91 %
Epoch 94/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.76it/s, acc=0.993, loss=2.47]
Test acc: 0.74 %
Epoch 95/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.989, loss=2.47]
Test acc: 0.92 %
Epoch 96/100: 100%|██████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.99, loss=2.47]
Test acc: 0.58 %
Epoch 97/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.991, loss=2.47]
Test acc: 0.88 %
Epoch 98/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.78it/s, acc=0.989, loss=2.47]
Test acc: 0.78 %
Epoch 99/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.75it/s, acc=0.993, loss=2.47]
Test acc: 0.68 %
Epoch 100/100: 100%|█████████████████████████| 113/113 [00:11<00:00, 9.77it/s, acc=0.99, loss=2.47]
Test acc: 0.36 % Last test acc: 0.96 %