In [1]:
from glob import glob
from utils import *
from riftnet import RiftNet

import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
c:\Users\user\AppData\Local\Programs\Python\Python310\lib\site-packages\torch\utils\_pytree.py:185: FutureWarning: optree is installed but the version is too old to support PyTorch Dynamo in C++ pytree. C++ pytree support is disabled. Please consider upgrading optree using `python3 -m pip install --upgrade 'optree>=0.13.0'`.
  warnings.warn(
In [2]:
file_paths = glob("Bluetooth_Datasets/Dataset 250 Msps/*/*/*/*")

unknown_users = {"4s_013004004984503_oguz_guler", "Note2_356261053336200_ismet_buyukkilic", "XperiaM5_354188070809491_firat_vural"}

file_paths_filtered = []
user_names_filtered = []

for file_path in file_paths:
    user = format_model_name(file_path.split("\\")[2:4])
    if user not in unknown_users:
        file_paths_filtered.append(file_path)
        user_names_filtered.append(user)

# Create dataset with filtered data
dataset = RFFDataset(file_paths=file_paths_filtered, labels=user_names_filtered)
In [3]:
def train_riftnet(dataset, analytic_fn, num_classes, epochs):
    encoded_dataset, le = encode_labels(dataset)

    # Stratified split: 80% train, 20% test
    train_data, test_data = train_test_split(
        encoded_dataset, test_size=0.2, stratify=[d[1] for d in encoded_dataset], random_state=0
    )

    train_dataset = IQDataset(train_data, analytic_fn)
    test_dataset = IQDataset(test_data, analytic_fn)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = RiftNet(num_classes=num_classes).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    losses, accuracies, test_accuracies = [], [], []

    for epoch in range(epochs):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        with tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", ncols=100) as pbar:
            for x_long, x_short, y in pbar:
                x_long, x_short, y = x_long.to(device), x_short.to(device), y.to(device)
                optimizer.zero_grad()
                out = model(x_long, x_short)
                loss = criterion(out, y)
                loss.backward()
                optimizer.step()

                total_loss += loss.item() * x_long.size(0)
                _, preds = out.max(1)
                correct += (preds == y).sum().item()
                total += y.size(0)

                avg_loss = total_loss / total
                acc = correct / total
                pbar.set_postfix(loss=avg_loss, acc=acc)

        losses.append(avg_loss)
        accuracies.append(acc)

        # ----- Evaluate on test set -----  
        model.eval()
        correct_test, total_test = 0, 0
        y_true, y_pred = [], []  # To store true and predicted labels for confusion matrix
        with torch.no_grad():
            for x_long, x_short, y in test_loader:
                x_long, x_short, y = x_long.to(device), x_short.to(device), y.to(device)
                out = model(x_long, x_short)
                _, preds = out.max(1)
                correct_test += (preds == y).sum().item()
                total_test += y.size(0)
                
                # Collect true and predicted labels
                y_true.extend(y.cpu().numpy())
                y_pred.extend(preds.cpu().numpy())
        
        test_acc = correct_test / total_test
        test_accuracies.append(test_acc)
        print(f"Test acc: {test_acc:.2f} %")
    
        # Save best model
        if epoch == 0 or test_acc > max(test_accuracies[:-1]):
            torch.save(model.state_dict(), 'best_model_state_dict.pth')
            print(f"✅ Saved best model at epoch {epoch+1} with test acc: {test_acc:.2%}")

    model.load_state_dict(torch.load('best_model_state_dict.pth'))
    # ----- Evaluate on test set -----  
    model.eval()
    correct_test, total_test = 0, 0
    y_true, y_pred = [], []  # To store true and predicted labels for confusion matrix
    with torch.no_grad():
        for x_long, x_short, y in test_loader:
            x_long, x_short, y = x_long.to(device), x_short.to(device), y.to(device)
            out = model(x_long, x_short)
            _, preds = out.max(1)
            correct_test += (preds == y).sum().item()
            total_test += y.size(0)
            
            # Collect true and predicted labels
            y_true.extend(y.cpu().numpy())
            y_pred.extend(preds.cpu().numpy())
    
    test_acc = correct_test / total_test
    print(f"Last test acc: {test_acc:.2f} %")

    # Confusion Matrix
    cm = confusion_matrix(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=le.classes_)
    fig, ax = plt.subplots(figsize=(16, 12))
    disp.plot(ax=ax, cmap='Blues', colorbar=True)
    plt.xticks(rotation=90)
    plt.title(f"Confusion Matrix at Epoch {epoch+1}, Test Accuracy: {test_acc:.2%}")
    plt.tight_layout()
    plt.show()

    # ----- Plot -----  
    plt.figure(figsize=(15, 4))

    plt.subplot(1, 3, 1)
    plt.plot(losses, label='Train Loss')
    plt.xlabel("Epoch"); plt.ylabel("Loss")
    plt.title("Training Loss")

    plt.subplot(1, 3, 2)
    plt.plot(accuracies, label='Train Acc')
    plt.xlabel("Epoch"); plt.ylabel("Accuracy")
    plt.title("Training Accuracy")

    plt.subplot(1, 3, 3)
    plt.plot(test_accuracies, label='Test Acc')
    plt.xlabel("Epoch"); plt.ylabel("Accuracy")
    plt.title("Test Accuracy")

    plt.tight_layout()
    plt.show()

    return model, le

# ====== USAGE ======
model, le = train_riftnet(dataset, analytic_signal, num_classes=30, epochs=100)
Epoch 1/100: 100%|██████████████████████████| 113/113 [00:12<00:00,  9.13it/s, acc=0.242, loss=3.28]
Test acc: 0.29 %
✅ Saved best model at epoch 1 with test acc: 29.11%
Epoch 2/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.66it/s, acc=0.408, loss=3.11]
Test acc: 0.52 %
✅ Saved best model at epoch 2 with test acc: 52.22%
Epoch 3/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.75it/s, acc=0.528, loss=2.96]
Test acc: 0.55 %
✅ Saved best model at epoch 3 with test acc: 55.00%
Epoch 4/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.72it/s, acc=0.601, loss=2.89]
Test acc: 0.64 %
✅ Saved best model at epoch 4 with test acc: 64.33%
Epoch 5/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.75it/s, acc=0.648, loss=2.84]
Test acc: 0.62 %
Epoch 6/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.73it/s, acc=0.742, loss=2.76]
Test acc: 0.70 %
✅ Saved best model at epoch 6 with test acc: 69.89%
Epoch 7/100: 100%|███████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.796, loss=2.7]
Test acc: 0.79 %
✅ Saved best model at epoch 7 with test acc: 79.00%
Epoch 8/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.828, loss=2.67]
Test acc: 0.82 %
✅ Saved best model at epoch 8 with test acc: 82.44%
Epoch 9/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.74it/s, acc=0.858, loss=2.63]
Test acc: 0.83 %
✅ Saved best model at epoch 9 with test acc: 83.11%
Epoch 10/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.71it/s, acc=0.89, loss=2.59]
Test acc: 0.90 %
✅ Saved best model at epoch 10 with test acc: 89.67%
Epoch 11/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.75it/s, acc=0.91, loss=2.57]
Test acc: 0.76 %
Epoch 12/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.75it/s, acc=0.913, loss=2.56]
Test acc: 0.86 %
Epoch 13/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.924, loss=2.55]
Test acc: 0.90 %
✅ Saved best model at epoch 13 with test acc: 89.89%
Epoch 14/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.924, loss=2.54]
Test acc: 0.90 %
Epoch 15/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.913, loss=2.56]
Test acc: 0.91 %
✅ Saved best model at epoch 15 with test acc: 90.78%
Epoch 16/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.931, loss=2.54]
Test acc: 0.90 %
Epoch 17/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.929, loss=2.54]
Test acc: 0.73 %
Epoch 18/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.79it/s, acc=0.925, loss=2.54]
Test acc: 0.91 %
✅ Saved best model at epoch 18 with test acc: 91.33%
Epoch 19/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.79it/s, acc=0.935, loss=2.53]
Test acc: 0.82 %
Epoch 20/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.79it/s, acc=0.937, loss=2.53]
Test acc: 0.92 %
✅ Saved best model at epoch 20 with test acc: 92.44%
Epoch 21/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.72it/s, acc=0.938, loss=2.53]
Test acc: 0.89 %
Epoch 22/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.936, loss=2.53]
Test acc: 0.80 %
Epoch 23/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.73it/s, acc=0.938, loss=2.53]
Test acc: 0.91 %
Epoch 24/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.75it/s, acc=0.953, loss=2.51]
Test acc: 0.90 %
Epoch 25/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.977, loss=2.49]
Test acc: 0.96 %
✅ Saved best model at epoch 25 with test acc: 95.89%
Epoch 26/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.978, loss=2.49]
Test acc: 0.91 %
Epoch 27/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.977, loss=2.49]
Test acc: 0.88 %
Epoch 28/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.982, loss=2.48]
Test acc: 0.93 %
Epoch 29/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.978, loss=2.49]
Test acc: 0.96 %
✅ Saved best model at epoch 29 with test acc: 96.33%
Epoch 30/100: 100%|█████████████████████████| 113/113 [00:12<00:00,  9.41it/s, acc=0.983, loss=2.48]
Test acc: 0.92 %
Epoch 31/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.44it/s, acc=0.984, loss=2.48]
Test acc: 0.94 %
Epoch 32/100: 100%|██████████████████████████| 113/113 [00:12<00:00,  9.38it/s, acc=0.98, loss=2.48]
Test acc: 0.95 %
Epoch 33/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.57it/s, acc=0.984, loss=2.48]
Test acc: 0.95 %
Epoch 34/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.985, loss=2.48]
Test acc: 0.96 %
Epoch 35/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.984, loss=2.48]
Test acc: 0.92 %
Epoch 36/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.45it/s, acc=0.981, loss=2.48]
Test acc: 0.90 %
Epoch 37/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.63it/s, acc=0.987, loss=2.47]
Test acc: 0.94 %
Epoch 38/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.68it/s, acc=0.986, loss=2.48]
Test acc: 0.94 %
Epoch 39/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.986, loss=2.48]
Test acc: 0.95 %
Epoch 40/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.985, loss=2.48]
Test acc: 0.96 %
✅ Saved best model at epoch 40 with test acc: 96.44%
Epoch 41/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.988, loss=2.47]
Test acc: 0.90 %
Epoch 42/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.79it/s, acc=0.985, loss=2.48]
Test acc: 0.92 %
Epoch 43/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.986, loss=2.48]
Test acc: 0.85 %
Epoch 44/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.79it/s, acc=0.988, loss=2.47]
Test acc: 0.96 %
Epoch 45/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.75it/s, acc=0.991, loss=2.47]
Test acc: 0.96 %
Epoch 46/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.985, loss=2.48]
Test acc: 0.96 %
Epoch 47/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.79it/s, acc=0.988, loss=2.47]
Test acc: 0.63 %
Epoch 48/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.79it/s, acc=0.986, loss=2.47]
Test acc: 0.90 %
Epoch 49/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.983, loss=2.48]
Test acc: 0.78 %
Epoch 50/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.79it/s, acc=0.989, loss=2.47]
Test acc: 0.89 %
Epoch 51/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.985, loss=2.48]
Test acc: 0.76 %
Epoch 52/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.68it/s, acc=0.986, loss=2.47]
Test acc: 0.66 %
Epoch 53/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.988, loss=2.47]
Test acc: 0.88 %
Epoch 54/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.73it/s, acc=0.981, loss=2.48]
Test acc: 0.93 %
Epoch 55/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.74it/s, acc=0.984, loss=2.48]
Test acc: 0.94 %
Epoch 56/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.75it/s, acc=0.988, loss=2.47]
Test acc: 0.86 %
Epoch 57/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.989, loss=2.47]
Test acc: 0.92 %
Epoch 58/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.991, loss=2.47]
Test acc: 0.92 %
Epoch 59/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.989, loss=2.47]
Test acc: 0.91 %
Epoch 60/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.988, loss=2.47]
Test acc: 0.60 %
Epoch 61/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.986, loss=2.47]
Test acc: 0.77 %
Epoch 62/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.988, loss=2.47]
Test acc: 0.79 %
Epoch 63/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.75it/s, acc=0.986, loss=2.47]
Test acc: 0.90 %
Epoch 64/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.984, loss=2.47]
Test acc: 0.79 %
Epoch 65/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.989, loss=2.47]
Test acc: 0.86 %
Epoch 66/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.983, loss=2.48]
Test acc: 0.84 %
Epoch 67/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.99, loss=2.47]
Test acc: 0.91 %
Epoch 68/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.991, loss=2.47]
Test acc: 0.77 %
Epoch 69/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.987, loss=2.47]
Test acc: 0.58 %
Epoch 70/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.988, loss=2.47]
Test acc: 0.77 %
Epoch 71/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.988, loss=2.47]
Test acc: 0.92 %
Epoch 72/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.989, loss=2.47]
Test acc: 0.53 %
Epoch 73/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.989, loss=2.47]
Test acc: 0.94 %
Epoch 74/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.70it/s, acc=0.989, loss=2.47]
Test acc: 0.69 %
Epoch 75/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.73it/s, acc=0.992, loss=2.47]
Test acc: 0.95 %
Epoch 76/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.992, loss=2.47]
Test acc: 0.95 %
Epoch 77/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.989, loss=2.47]
Test acc: 0.96 %
Epoch 78/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.992, loss=2.47]
Test acc: 0.93 %
Epoch 79/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.991, loss=2.47]
Test acc: 0.96 %
Epoch 80/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.992, loss=2.47]
Test acc: 0.89 %
Epoch 81/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.988, loss=2.47]
Test acc: 0.96 %
Epoch 82/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.989, loss=2.47]
Test acc: 0.67 %
Epoch 83/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.99, loss=2.47]
Test acc: 0.86 %
Epoch 84/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.75it/s, acc=0.988, loss=2.47]
Test acc: 0.95 %
Epoch 85/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.989, loss=2.47]
Test acc: 0.88 %
Epoch 86/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.988, loss=2.47]
Test acc: 0.83 %
Epoch 87/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.71it/s, acc=0.989, loss=2.47]
Test acc: 0.93 %
Epoch 88/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.72it/s, acc=0.988, loss=2.47]
Test acc: 0.91 %
Epoch 89/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.73it/s, acc=0.987, loss=2.47]
Test acc: 0.94 %
Epoch 90/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.71it/s, acc=0.991, loss=2.47]
Test acc: 0.92 %
Epoch 91/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.74it/s, acc=0.985, loss=2.47]
Test acc: 0.88 %
Epoch 92/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.99, loss=2.47]
Test acc: 0.96 %
Epoch 93/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.993, loss=2.47]
Test acc: 0.91 %
Epoch 94/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.76it/s, acc=0.993, loss=2.47]
Test acc: 0.74 %
Epoch 95/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.989, loss=2.47]
Test acc: 0.92 %
Epoch 96/100: 100%|██████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.99, loss=2.47]
Test acc: 0.58 %
Epoch 97/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.991, loss=2.47]
Test acc: 0.88 %
Epoch 98/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.78it/s, acc=0.989, loss=2.47]
Test acc: 0.78 %
Epoch 99/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.75it/s, acc=0.993, loss=2.47]
Test acc: 0.68 %
Epoch 100/100: 100%|█████████████████████████| 113/113 [00:11<00:00,  9.77it/s, acc=0.99, loss=2.47]
Test acc: 0.36 %
Last test acc: 0.96 %
No description has been provided for this image
No description has been provided for this image