import matplotlib.pyplot as plt
import torch
import torch.nn as nn
# Example input: a 1-D float32 tensor holding the values 1..4. A leading
# batch dimension is added later, right before the forward pass.
input_tensor = torch.arange(1, 5, dtype=torch.float32)  # tensor([1., 2., 3., 4.])
# NOTE(review): `Net` is neither defined nor imported in this file; presumably
# it comes from another module or notebook cell. Confirm before running.
model = Net()
model.eval()  # inference mode; matters only if Net has e.g. dropout/batch-norm layers
# Filled by the forward hook with every output captured from the hooked layer.
outputs = []
# Forward pass and store intermediate outputs
def hook(module, input, output):
outputs.append(output)
# Register the hook to capture intermediate outputs
hook_handle = model.fc3.register_forward_hook(hook)
model(input_tensor.unsqueeze(0))
hook_handle.remove()
# Plot the intermediate outputs
for i, output in enumerate(outputs):
print(i)
plt.figure()
plt.title(f'Layer {i+1} Output')
plt.bar(range(output.size(-1)), output.squeeze().detach().numpy())
plt.show()