Claude-skill-registry complex-tensor-handler
Handle complex-valued tensors in PyTorch for astronomical imaging applications. This skill should be used when working with Fourier transforms, phase/amplitude representations, and complex arithmetic in PRISM.
install
source · Clone the upstream repo
git clone https://github.com/majiayu000/claude-skill-registry
Claude Code · Install into ~/.claude/skills/
T=$(mktemp -d) && git clone --depth=1 https://github.com/majiayu000/claude-skill-registry "$T" && mkdir -p ~/.claude/skills && cp -r "$T/skills/data/complex-tensor-handler" ~/.claude/skills/majiayu000-claude-skill-registry-complex-tensor-handler && rm -rf "$T"
manifest:
skills/data/complex-tensor-handler/SKILL.mdsource content
Complex Tensor Handler
Work with complex-valued tensors in PyTorch for astronomical imaging, including FFT operations, phase/amplitude conversions, and complex arithmetic.
Purpose
PRISM deals with complex-valued images (phase + amplitude or real + imaginary). This skill provides patterns for correctly handling complex tensors in PyTorch.
When to Use
Use this skill when:
- Working with FFT/IFFT operations
- Converting between phase/amplitude and real/imaginary
- Performing complex arithmetic
- Dealing with complex-valued neural networks
Complex Tensor Basics
Creating Complex Tensors
import torch # From real and imaginary parts real = torch.randn(1, 1, 256, 256) imag = torch.randn(1, 1, 256, 256) complex_tensor = torch.complex(real, imag) # or: complex_tensor = real + 1j * imag # From magnitude and phase magnitude = torch.rand(1, 1, 256, 256) phase = torch.rand(1, 1, 256, 256) * 2 * np.pi complex_tensor = magnitude * torch.exp(1j * phase) # Convert real to complex real_tensor = torch.randn(1, 1, 256, 256) complex_tensor = real_tensor.to(torch.complex64)
Accessing Components
# Real and imaginary parts real_part = complex_tensor.real imag_part = complex_tensor.imag # Magnitude and phase magnitude = complex_tensor.abs() phase = complex_tensor.angle() # Conjugate conjugate = complex_tensor.conj()
Common Operations
FFT and IFFT
def fft(image: Tensor, norm: str = 'ortho') -> Tensor: """2D FFT of image tensor. Parameters ---------- image : Tensor Spatial domain image [B, C, H, W], real or complex norm : str Normalization: 'ortho', 'forward', or 'backward' Returns ------- Tensor Frequency domain, complex-valued [B, C, H, W] """ # PyTorch FFT handles real input automatically freq = torch.fft.fft2(image, norm=norm) # Optionally shift zero-frequency to center freq = torch.fft.fftshift(freq, dim=(-2, -1)) return freq def ifft(freq: Tensor, norm: str = 'ortho') -> Tensor: """Inverse 2D FFT. Parameters ---------- freq : Tensor Frequency domain [B, C, H, W], complex-valued norm : str Normalization mode Returns ------- Tensor Spatial domain [B, C, H, W], complex-valued """ # Unshift if needed freq = torch.fft.ifftshift(freq, dim=(-2, -1)) # Inverse FFT image = torch.fft.ifft2(freq, norm=norm) return image
Phase and Amplitude
def to_phase_amplitude(complex_tensor: Tensor) -> tuple[Tensor, Tensor]: """Convert complex tensor to phase and amplitude. Parameters ---------- complex_tensor : Tensor Complex-valued tensor [B, C, H, W] Returns ------- phase : Tensor Phase in radians [-π, π], shape [B, C, H, W] amplitude : Tensor Amplitude (magnitude), shape [B, C, H, W] """ phase = complex_tensor.angle() amplitude = complex_tensor.abs() return phase, amplitude def from_phase_amplitude(phase: Tensor, amplitude: Tensor) -> Tensor: """Create complex tensor from phase and amplitude. Parameters ---------- phase : Tensor Phase in radians, shape [B, C, H, W] amplitude : Tensor Amplitude, shape [B, C, H, W] Returns ------- Tensor Complex-valued tensor [B, C, H, W] """ return amplitude * torch.exp(1j * phase)
Real and Imaginary
def to_real_imag(complex_tensor: Tensor) -> tuple[Tensor, Tensor]: """Split complex tensor into real and imaginary parts. Parameters ---------- complex_tensor : Tensor Complex-valued tensor [B, C, H, W] Returns ------- real : Tensor Real part, shape [B, C, H, W] imag : Tensor Imaginary part, shape [B, C, H, W] """ return complex_tensor.real, complex_tensor.imag def from_real_imag(real: Tensor, imag: Tensor) -> Tensor: """Create complex tensor from real and imaginary parts. Parameters ---------- real : Tensor Real part, shape [B, C, H, W] imag : Tensor Imaginary part, shape [B, C, H, W] Returns ------- Tensor Complex-valued tensor [B, C, H, W] """ return torch.complex(real, imag)
PRISM-Specific Patterns
Generate Complex Image
class ComplexImageGenerator(nn.Module): """Generate complex-valued images (phase + amplitude).""" def __init__(self, latent_dim: int = 128, output_size: int = 256): super().__init__() self.decoder = build_decoder(latent_dim, output_channels=2) self.output_size = output_size def forward(self, latent: Optional[Tensor] = None) -> Tensor: """Generate complex image. Returns ------- Tensor Complex-valued image [1, 1, H, W] """ if latent is None: latent = self.latent.expand(1, -1, 1, 1) # Generate phase and amplitude as separate channels output = self.decoder(latent) # [1, 2, H, W] # Split into phase and amplitude phase = output[:, 0:1] # [1, 1, H, W] amplitude = output[:, 1:2].exp() # [1, 1, H, W], always positive # Create complex tensor complex_image = from_phase_amplitude(phase, amplitude) return complex_image
Telescope Measurement with Complex Values
class Telescope(nn.Module): """Telescope with complex-valued measurements.""" def forward( self, complex_image: Tensor, centers: list[tuple[float, float]] ) -> list[Tensor]: """Take measurements of complex image. Parameters ---------- complex_image : Tensor Complex-valued image [1, 1, H, W] centers : list[tuple[float, float]] Measurement positions Returns ------- list[Tensor] Complex-valued measurements """ measurements = [] for center in centers: # Create aperture mask mask = self.create_mask(center) # [H, W], real # Apply mask (broadcasts to complex) masked = complex_image * mask # [1, 1, H, W], complex # FFT to measurement plane measurement = fft(masked) # [1, 1, H, W], complex # Add noise (complex noise) if self.snr is not None: noise_real = torch.randn_like(measurement.real) / self.snr noise_imag = torch.randn_like(measurement.imag) / self.snr noise = torch.complex(noise_real, noise_imag) measurement = measurement + noise measurements.append(measurement) return measurements
Complex Loss Functions
def complex_mse_loss(pred: Tensor, target: Tensor) -> Tensor: """MSE loss for complex-valued tensors. Parameters ---------- pred : Tensor Predicted complex tensor [B, C, H, W] target : Tensor Target complex tensor [B, C, H, W] Returns ------- Tensor Scalar loss """ # Separate real and imaginary parts loss_real = F.mse_loss(pred.real, target.real) loss_imag = F.mse_loss(pred.imag, target.imag) return loss_real + loss_imag def phase_amplitude_loss(pred: Tensor, target: Tensor) -> Tensor: """Loss in phase-amplitude space. Parameters ---------- pred : Tensor Predicted complex tensor target : Tensor Target complex tensor Returns ------- Tensor Scalar loss """ # Convert to phase and amplitude pred_phase, pred_amp = to_phase_amplitude(pred) target_phase, target_amp = to_phase_amplitude(target) # Loss on amplitude amp_loss = F.mse_loss(pred_amp, target_amp) # Loss on phase (handle wrapping) phase_diff = pred_phase - target_phase # Wrap to [-π, π] phase_diff = torch.atan2(torch.sin(phase_diff), torch.cos(phase_diff)) phase_loss = (phase_diff ** 2).mean() return amp_loss + phase_loss
Visualization
def visualize_complex_image(complex_tensor: Tensor, title: str = ""): """Visualize complex image as phase and amplitude. Parameters ---------- complex_tensor : Tensor Complex image [1, 1, H, W] title : str Plot title """ # Convert to numpy [H, W] img = complex_tensor[0, 0].detach().cpu() phase = img.angle().numpy() amplitude = img.abs().numpy() fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5)) # Phase (use hsv colormap for [-π, π]) im1 = ax1.imshow(phase, cmap='hsv', vmin=-np.pi, vmax=np.pi) ax1.set_title(f'{title} - Phase') plt.colorbar(im1, ax=ax1, label='Phase (radians)') # Amplitude im2 = ax2.imshow(amplitude, cmap='gray') ax2.set_title(f'{title} - Amplitude') plt.colorbar(im2, ax=ax2, label='Amplitude') plt.tight_layout() return fig
Common Pitfalls
Pitfall 1: Forgetting dtype
# Wrong - may lose imaginary part complex_tensor = torch.tensor([1+2j, 3+4j]) # dtype inferred incorrectly # Correct - specify dtype complex_tensor = torch.tensor([1+2j, 3+4j], dtype=torch.complex64)
Pitfall 2: Operations Not Complex-Safe
# Some operations don't support complex complex_tensor = torch.randn(10, dtype=torch.complex64) # Error: relu not defined for complex # output = F.relu(complex_tensor) # Solution: Apply to real and imag separately output = torch.complex( F.relu(complex_tensor.real), F.relu(complex_tensor.imag) )
Pitfall 3: Phase Wrapping
# Phase differences can wrap around phase1 = torch.tensor([3.0]) # Near π phase2 = torch.tensor([-3.0]) # Near -π # Direct difference gives large value diff = phase1 - phase2 # 6.0 radians # Correct: wrap to [-π, π] diff_wrapped = torch.atan2(torch.sin(diff), torch.cos(diff)) # Near 0
Type Hints
from torch import Tensor from typing import TypeAlias # Type aliases for clarity ComplexTensor: TypeAlias = Tensor # Complex-valued tensor PhaseTensor: TypeAlias = Tensor # Phase in radians AmplitudeTensor: TypeAlias = Tensor # Non-negative amplitude def process_complex( image: ComplexTensor ) -> tuple[PhaseTensor, AmplitudeTensor]: """Process complex image.""" phase = image.angle() amplitude = image.abs() return phase, amplitude
Testing Complex Operations
def test_fft_inverse(): """Test FFT followed by IFFT returns original.""" image = torch.randn(1, 1, 256, 256) freq = fft(image) reconstructed = ifft(freq) assert torch.allclose(reconstructed.real, image, rtol=1e-5) assert torch.allclose(reconstructed.imag, torch.zeros_like(image), atol=1e-7) def test_phase_amplitude_roundtrip(): """Test phase/amplitude conversion roundtrip.""" original = torch.randn(1, 1, 256, 256, dtype=torch.complex64) phase, amplitude = to_phase_amplitude(original) reconstructed = from_phase_amplitude(phase, amplitude) assert torch.allclose(reconstructed, original, rtol=1e-5)