import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from dogs_cats_ds import DogCatDataset
from model import DogCatClassifier
from consts import TEST_DATA
import torch.optim as optim

def test(
    model: nn.Module,
    test_loader: DataLoader,
    criterion,
    device,
    *,
    threshold: float = 0.5,
):
    """Evaluate a binary classifier on ``test_loader`` and print loss/accuracy.

    Model outputs are compared against ``threshold`` to produce 0/1
    predictions, so they are expected to be probabilities (e.g. a sigmoid
    head paired with ``nn.BCELoss``) rather than raw logits.

    Args:
        model: Network to evaluate. It is put in eval mode for the pass and
            restored to whatever train/eval mode it had on entry.
        test_loader: Yields ``(images, labels)`` batches; labels are cast to
            float and reshaped to ``(batch, 1)`` before use.
        criterion: Loss callable ``criterion(output, target)``; assumed to
            return the mean loss over the batch (the default ``reduction``
            of ``nn.BCELoss``).
        device: Device the batches are moved to. The model is assumed to
            already live on this device.
        threshold: Output value above which a prediction counts as class 1.

    Returns:
        ``(avg_loss, accuracy)``: the per-sample mean loss and the accuracy
        as a percentage in ``[0, 100]``. Both are ``nan`` when the loader
        yields no samples.
    """
    was_training = model.training
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    try:
        with torch.no_grad():
            for img, lab in test_loader:
                img = img.to(device)
                # Float targets shaped (B, 1), as BCE-style losses require.
                lab = lab.to(device).float().view(-1, 1)
                out = model(img)
                batch_size = lab.size(0)
                # Weight each batch's mean loss by its size so a short final
                # batch does not skew the overall average.
                loss_sum += criterion(out, lab).item() * batch_size
                pred = (out > threshold).float()
                correct += (pred == lab).sum().item()
                total += batch_size
    finally:
        # Restore the caller's mode rather than unconditionally forcing train().
        model.train(was_training)

    if total == 0:
        avg_loss = accuracy = float("nan")
    else:
        avg_loss = loss_sum / total
        accuracy = 100 * correct / total
    print(f'test loss: {avg_loss:.4f}, test_acc: {accuracy:.2f}%')
    return avg_loss, accuracy


# A function named ``test`` matches pytest's default collection pattern; opt
# out so importing it into a test module doesn't create a failing test case.
test.__test__ = False


if __name__ == "__main__":
    # Evaluate the saved checkpoint on the test split.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    dog_test_dataset = DogCatDataset(TEST_DATA)
    # Aggregate metrics don't depend on sample order; a fixed order keeps
    # evaluation runs reproducible.
    dog_test_loader = DataLoader(dog_test_dataset, batch_size=32, shuffle=False)

    model = DogCatClassifier()
    model.load_state_dict(
        torch.load('dog_cat_classifier.pth', map_location=device, weights_only=True)
    )
    # load_state_dict copies the loaded values into the module's existing
    # parameters (created on the default device, normally CPU), so the model
    # must still be moved to the device the batches are sent to in test().
    model.to(device)

    criterion = nn.BCELoss()
    test(model, dog_test_loader, criterion, device)