# Linear MMSE Limitations

When E[x|y] is nonlinear, linear estimators are suboptimal.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
from sklearn.ensemble import RandomForestRegressor
from sklearn.neural_network import MLPRegressor
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline

# Plot styling: start from matplotlib's defaults, then apply notebook-wide overrides
plt.style.use('default')
plt.rcParams.update({
    'figure.facecolor': 'white',
    'axes.spines.top': False,
    'axes.spines.right': False,
    'font.size': 10,
})

# Seed NumPy's global RNG so the sampled data is reproducible between runs
np.random.seed(42)

## Nonlinear Test Cases

Linear MMSE achieves optimality only when $\mathbb{E}[\mathbf{x}|\mathbf{y}]$ is linear in $\mathbf{y}$. This occurs for jointly Gaussian variables but fails for many practical relationships.

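Symmetric relationships such as $x = y^2 - 1$, with $y$ distributed symmetrically about zero, give $\text{Cov}(x,y) \approx 0$ despite strong dependence, causing linear estimators to predict $\hat{x} \approx \mu_x$. For a general quadratic dependency $x = ay^2 + by$ with such $y$, $\text{Cov}(x,y) = b\,\text{Var}(y)$, so the linear estimator captures only the $by$ term and misses the curvature entirely.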

In [2]:
class NonlinearRelationship:
    """Common interface for synthetic relationships between x and y.

    Subclasses supply ``true_function`` (the noise-free mapping from y to x)
    and ``description``; this base class provides the sampling logic.
    """

    def __init__(self, noise_std=0.2):
        # Standard deviation of the Gaussian noise added during sampling
        self.noise_std = noise_std

    def generate_data(self, n_samples=1000, y_range=(-3, 3)):
        """Draw ``n_samples`` pairs and return them as the tuple ``(x, y)``.

        y is sampled uniformly between ``y_range[0]`` and ``y_range[1]``;
        x is ``true_function(y)`` plus standard-normal noise scaled by
        ``noise_std``.
        """
        low, high = y_range[0], y_range[1]
        y = np.random.uniform(low, high, n_samples)
        noise_free = self.true_function(y)
        noise = np.random.randn(n_samples)
        x = noise_free + self.noise_std * noise
        return x, y

    def true_function(self, y):
        """Return the true conditional expectation E[x|y]; subclasses override."""
        raise NotImplementedError

    def description(self):
        """Return a human-readable label for the relationship; subclasses override."""
        raise NotImplementedError

class QuadraticRelationship(NonlinearRelationship):
    """Relationship whose noise-free mean is the quadratic a*y^2 + b*y."""

    def __init__(self, a=0.5, b=0.1, noise_std=0.2):
        super().__init__(noise_std)
        # Coefficients of the squared and linear terms, respectively
        self.a, self.b = a, b

    def true_function(self, y):
        """Evaluate a*y^2 + b*y at the given y."""
        squared_term = self.a * y**2
        linear_term = self.b * y
        return squared_term + linear_term

    def description(self):
        """Return a label showing both coefficients to one decimal place."""
        a, b = self.a, self.b
        return f"Quadratic: x = {a:.1f}y² + {b:.1f}y + noise"

class SinusoidalRelationship(NonlinearRelationship):
    """Relationship whose noise-free mean is amplitude * sin(freq * y)."""

    def __init__(self, freq=1.0, amplitude=1.0, noise_std=0.2):
        super().__init__(noise_std)
        # Multiplier applied to y inside the sine, and the scale of the sine
        self.freq, self.amplitude = freq, amplitude

    def true_function(self, y):
        """Evaluate amplitude * sin(freq * y) at the given y."""
        phase = self.freq * y
        return self.amplitude * np.sin(phase)

    def description(self):
        """Return a label showing amplitude and frequency to one decimal place."""
        amplitude, freq = self.amplitude, self.freq
        return f"Sinusoidal: x = {amplitude:.1f}sin({freq:.1f}y) + noise"

class SymmetricRelationship(NonlinearRelationship):
    """Even relationship x = y^2 - 1.

    x depends on y, yet Cov(x, y) is zero when y is distributed symmetrically
    about zero (as with the default sampling range).
    """

    def __init__(self, noise_std=0.2):
        super().__init__(noise_std)

    def true_function(self, y):
        """Evaluate y^2 - 1, which is symmetric around y = 0."""
        y_squared = y**2
        return y_squared - 1

    def description(self):
        """Return the fixed label for this relationship."""
        label = "Symmetric: x = y² - 1 + noise (Cov(x,y) ≈ 0!)"
        return label

class PhaseRetrievalRelationship(NonlinearRelationship):
    """Phase-retrieval setup: the observation y is a noisy magnitude of the signal x."""

    def generate_data(self, n_samples=1000, y_range=(-3, 3)):
        """Sample the signal first, then derive the observation from it.

        Here ``y_range`` bounds the uniformly sampled signal x; y is |x| plus
        standard-normal noise scaled by ``noise_std``. Returns ``(x, y)``.
        """
        low, high = y_range[0], y_range[1]
        # The true signal
        x = np.random.uniform(low, high, n_samples)
        # The magnitude observation, corrupted by additive noise
        magnitude = np.abs(x)
        y = magnitude + self.noise_std * np.random.randn(n_samples)
        return x, y

    def true_function(self, y):
        """Return zeros shaped like y: with y = |x|, E[x|y] = 0 by symmetry."""
        return np.zeros_like(y)

    def description(self):
        """Return the fixed label for this relationship."""
        label = "Phase Retrieval: y = |x| + noise (linear MMSE fails)"
        return label

## Estimator Comparison

Linear MMSE gives $\hat{x} = \mu_x + \frac{\text{Cov}(x,y)}{\text{Var}(y)}(y - \mu_y)$ - the best linear predictor for any relationship.

When $\mathbb{E}[x|y]$ is nonlinear, polynomial regression, neural networks, and random forests can achieve lower MSE by approximating the true conditional expectation.

In [3]:
class EstimatorComparison:
    """Fit several estimators of x from y on the same data and compare test MSE.

    ``self.estimators`` maps a display name to a fitting function. Every
    fitting function takes ``(x_train, y_train)`` and returns a
    ``predictor(y)`` callable that maps observations to estimates of x.
    """

    def __init__(self):
        self.estimators = {
            'Linear MMSE': self._linear_mmse,
            'Polynomial (deg=2)': self._polynomial_estimator(2),
            'Polynomial (deg=3)': self._polynomial_estimator(3),
            'Neural Network': self._neural_network,
            'Random Forest': self._random_forest
        }

    def _linear_mmse(self, x_train, y_train):
        """Fit the linear MMSE estimator x_hat = mu_x + Cov(x,y)/Var(y) * (y - mu_y).

        Returns a predictor callable. When the training observations have
        (numerically) zero variance, the predictor returns the training mean
        of x for every input.
        """
        x_mean, y_mean = np.mean(x_train), np.mean(y_train)

        # Use the same normalisation (divide by N) for the covariance and the
        # variance. np.cov defaults to N-1 while np.var defaults to N; mixing
        # the two would scale the slope by N/(N-1).
        cov_xy = np.cov(x_train, y_train, bias=True)[0, 1]
        var_y = np.var(y_train)

        if var_y < 1e-10:
            # Degenerate case: y carries no information, so predict the mean.
            def predictor(y):
                # Build a float array explicitly so that an integer-typed y
                # does not truncate the mean.
                return np.full(np.shape(y), x_mean, dtype=float)
        else:
            # Linear MMSE formula rewritten as slope * y + intercept
            a = cov_xy / var_y
            b = x_mean - a * y_mean

            def predictor(y):
                return a * y + b

        return predictor

    def _polynomial_estimator(self, degree):
        """Return a fitting function for ridge regression on polynomial features of y.

        The returned function has the same ``(x_train, y_train) -> predictor``
        contract as the other estimators; ``degree`` is passed to
        ``PolynomialFeatures``.
        """
        def estimator(x_train, y_train):
            poly_reg = Pipeline([
                ('poly', PolynomialFeatures(degree=degree)),
                ('ridge', Ridge(alpha=0.01))
            ])
            # scikit-learn expects a 2-D feature matrix, hence the reshape
            poly_reg.fit(y_train.reshape(-1, 1), x_train)

            def predictor(y):
                return poly_reg.predict(y.reshape(-1, 1))

            return predictor
        return estimator

    def _neural_network(self, x_train, y_train):
        """Fit an MLP regressor with hidden layers (20, 10) and return its predictor."""
        mlp = MLPRegressor(hidden_layer_sizes=(20, 10), max_iter=1000, random_state=42)
        mlp.fit(y_train.reshape(-1, 1), x_train)

        def predictor(y):
            return mlp.predict(y.reshape(-1, 1))

        return predictor

    def _random_forest(self, x_train, y_train):
        """Fit a 50-tree random forest regressor and return its predictor."""
        rf = RandomForestRegressor(n_estimators=50, random_state=42)
        rf.fit(y_train.reshape(-1, 1), x_train)

        def predictor(y):
            return rf.predict(y.reshape(-1, 1))

        return predictor

    def compare_estimators(self, relationship, n_samples=1000, test_fraction=0.3):
        """Fit and score every estimator in ``self.estimators`` on fresh data.

        Data comes from ``relationship.generate_data(n_samples)`` and is split
        at random into training and test parts; ``test_fraction`` is the share
        held out for testing.

        Returns the tuple ``(fitted_estimators, mse_results, (x_train,
        y_train), (x_test, y_test))``. ``mse_results`` has an entry for every
        estimator, set to ``np.inf`` when fitting or evaluation raised.
        ``fitted_estimators`` contains only the estimators that were fitted
        and evaluated successfully.
        """
        # Generate data
        x, y = relationship.generate_data(n_samples)

        # Random train/test split
        n_train = int(n_samples * (1 - test_fraction))
        indices = np.random.permutation(n_samples)
        train_idx, test_idx = indices[:n_train], indices[n_train:]

        x_train, y_train = x[train_idx], y[train_idx]
        x_test, y_test = x[test_idx], y[test_idx]

        fitted_estimators = {}
        mse_results = {}

        for name, estimator_func in self.estimators.items():
            # Best effort: one failing estimator must not abort the comparison.
            try:
                predictor = estimator_func(x_train, y_train)
                x_pred = predictor(y_test)
                mse = np.mean((x_test - x_pred)**2)
            except Exception as e:
                print(f"Error fitting {name}: {e}")
                mse_results[name] = np.inf
            else:
                # Register the predictor only once it is known to work on the
                # test set, so callers never receive a broken predictor.
                fitted_estimators[name] = predictor
                mse_results[name] = mse

        return fitted_estimators, mse_results, (x_train, y_train), (x_test, y_test)

## Performance Analysis

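For $x = y^2 - 1$ with $y$ symmetric about zero: $\mathbb{E}[y] = 0$ and $\mathbb{E}[xy] = \mathbb{E}[y^3 - y] = 0$, so $\text{Cov}(x,y) = \mathbb{E}[xy] - \mathbb{E}[x]\mathbb{E}[y] = 0$ and linear MMSE predicts $\hat{x} = \mu_x$ even though $x$ is (up to noise) a deterministic function of $y$.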

Mutual information $I(X;Y)$ measures total dependence. Linear information $-\frac{1}{2}\log(1-\rho^2)$ captures only correlation. Large gaps indicate where nonlinear methods excel.

In [4]:
def interactive_nonlinear_demo():
    """Display an interactive comparison of linear MMSE and nonlinear estimators.

    Builds a dropdown (relationship) and two sliders (noise level, sample
    count) and wires them to a callback. On every change the callback
    regenerates data, refits every estimator in ``EstimatorComparison``,
    draws a 2x3 grid of diagnostic plots and prints a text summary.
    Returns nothing; the widget is rendered with ``display``.
    """

    # Available relationships, keyed by the label shown in the dropdown
    relationships = {
        'Quadratic': QuadraticRelationship(a=0.5, b=0.1),
        'Sinusoidal': SinusoidalRelationship(freq=1.0, amplitude=1.5),
        'Symmetric (Cov=0)': SymmetricRelationship(),
        'Phase Retrieval': PhaseRetrievalRelationship()
    }

    # Widget controls
    relationship_dropdown = widgets.Dropdown(
        options=list(relationships.keys()),
        value='Quadratic',
        description='Relationship:',
        style={'description_width': 'initial'}
    )

    noise_slider = widgets.FloatSlider(
        value=0.3, min=0.1, max=1.0, step=0.1,
        description='Noise Level:',
        style={'description_width': 'initial'}
    )

    n_samples_slider = widgets.IntSlider(
        value=500, min=200, max=1000, step=100,
        description='# Samples:',
        style={'description_width': 'initial'}
    )

    def update_comparison(relationship_name, noise_level, n_samples):
        """Regenerate data, refit all estimators and redraw every panel."""
        # Get selected relationship and set noise level. The relationship
        # objects are shared across calls, so the slider value is written
        # onto the instance each time.
        relationship = relationships[relationship_name]
        relationship.noise_std = noise_level

        # Compare estimators
        comparator = EstimatorComparison()
        estimators, mse_results, train_data, test_data = comparator.compare_estimators(
            relationship, n_samples
        )

        x_train, y_train = train_data
        x_test, y_test = test_data

        # Create visualization
        _, axes = plt.subplots(2, 3, figsize=(18, 12))

        # Evaluation grid, extended slightly beyond the training observations
        y_grid = np.linspace(np.min(y_train) - 0.5, np.max(y_train) + 0.5, 200)

        # Evaluate the true function once. A relationship that does not
        # implement it simply gets no reference curve.
        try:
            x_true = relationship.true_function(y_grid)
        except NotImplementedError:
            x_true = None

        # 1. Data scatter plot
        ax = axes[0, 0]
        ax.scatter(y_train, x_train, alpha=0.5, s=20, label='Training data')
        ax.scatter(y_test, x_test, alpha=0.5, s=20, color='orange', label='Test data')

        if x_true is not None:
            ax.plot(y_grid, x_true, 'k-', linewidth=3, label='True E[x|y]', alpha=0.8)

        ax.set_xlabel('y')
        ax.set_ylabel('x')
        ax.set_title(f'Data: {relationship.description()}')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # 2. Estimator predictions
        ax = axes[0, 1]
        colors = ['red', 'blue', 'green', 'purple', 'brown']

        for i, (name, estimator) in enumerate(estimators.items()):
            if name in mse_results and mse_results[name] < np.inf:
                # Best effort: a predictor that fails on the grid is reported
                # and skipped rather than aborting the whole figure.
                try:
                    x_pred_grid = estimator(y_grid)
                    ax.plot(y_grid, x_pred_grid, color=colors[i % len(colors)],
                            linewidth=2, label=f'{name} (MSE: {mse_results[name]:.3f})')
                except Exception as exc:
                    print(f"Error plotting {name}: {exc}")

        # True function
        if x_true is not None:
            ax.plot(y_grid, x_true, 'k-',
                    linewidth=3, label='True E[x|y]', alpha=0.8)

        ax.set_xlabel('y')
        ax.set_ylabel('x')
        ax.set_title('Estimator Comparison')
        ax.legend()
        ax.grid(True, alpha=0.3)

        # 3. MSE comparison bar chart (only estimators that were evaluated)
        ax = axes[0, 2]
        valid_results = {k: v for k, v in mse_results.items() if v < np.inf}

        if valid_results:
            names = list(valid_results.keys())
            mses = list(valid_results.values())

            bars = ax.bar(range(len(names)), mses, alpha=0.7)
            ax.set_xticks(range(len(names)))
            ax.set_xticklabels(names, rotation=45, ha='right')
            ax.set_ylabel('Test MSE')
            ax.set_title('MSE Comparison')

            # Highlight linear MMSE
            if 'Linear MMSE' in names:
                linear_idx = names.index('Linear MMSE')
                bars[linear_idx].set_color('red')
                bars[linear_idx].set_alpha(0.9)

            ax.grid(True, alpha=0.3, axis='y')

        # 4. Linear MMSE failure analysis: predicted versus true x on the test set
        ax = axes[1, 0]
        if 'Linear MMSE' in estimators:
            linear_pred = estimators['Linear MMSE'](y_test)
            ax.scatter(x_test, linear_pred, alpha=0.6, s=20)

            # Diagonal spanning the joint range of true and predicted values
            lims = [min(np.min(x_test), np.min(linear_pred)),
                    max(np.max(x_test), np.max(linear_pred))]
            ax.plot(lims, lims, 'r--', alpha=0.8, label='Perfect prediction')

            ax.set_xlabel('True x')
            ax.set_ylabel('Linear MMSE prediction')
            ax.set_title('Linear MMSE: True vs Predicted')
            ax.legend()
            ax.grid(True, alpha=0.3)

        # 5. Correlation analysis
        ax = axes[1, 1]

        # Sample correlation coefficient of the training data
        corr_xy = np.corrcoef(x_train, y_train)[0, 1]

        if relationship_name == 'Symmetric (Cov=0)':
            # Show the per-sample centred products, whose mean is the sample
            # covariance. For x = y² - 1 with symmetric y the positive and
            # negative products cancel (E[xy] = E[y³ - y] = 0).
            y_centered = y_train - np.mean(y_train)
            x_centered = x_train - np.mean(x_train)
            centered_product = x_centered * y_centered

            ax.scatter(y_centered, centered_product, alpha=0.5, s=20)
            ax.axhline(0, color='red', linestyle='--')
            ax.set_xlabel('y - μ_y')
            ax.set_ylabel('(x - μ_x)(y - μ_y)')
            ax.set_title(
                f'Why Cov(x,y) ≈ 0: E[(x-μ_x)(y-μ_y)] = {np.mean(centered_product):.4f}'
            )
        else:
            # Regular correlation plot
            ax.scatter(y_train, x_train, alpha=0.5, s=20)
            ax.set_xlabel('y')
            ax.set_ylabel('x')
            ax.set_title(f'Correlation: ρ = {corr_xy:.3f}')

        ax.grid(True, alpha=0.3)

        # 6. Information analysis
        ax = axes[1, 2]

        # Mutual information estimation (simplified)
        from sklearn.feature_selection import mutual_info_regression

        mi = mutual_info_regression(y_train.reshape(-1, 1), x_train)[0]

        # Linear information (based on correlation). The expression diverges
        # as |ρ| approaches 1, so it is replaced by the fixed value 5 once
        # |ρ| reaches 0.99.
        linear_info = -0.5 * np.log(1 - corr_xy**2) if abs(corr_xy) < 0.99 else 5

        info_types = ['Linear Info\n(-½log(1-ρ²))', 'Mutual Info\nI(X;Y)']
        info_values = [linear_info, mi]

        bars = ax.bar(info_types, info_values, alpha=0.7, color=['red', 'blue'])
        ax.set_ylabel('Information (nats)')
        ax.set_title('Information Content Analysis')
        ax.grid(True, alpha=0.3, axis='y')

        # Add the numeric value just above each bar
        for bar, val in zip(bars, info_values):
            ax.text(bar.get_x() + bar.get_width()/2, val + 0.05,
                    f'{val:.3f}', ha='center', fontsize=10)

        plt.tight_layout()
        plt.show()

        # Summary statistics
        print("\n=== Analysis Summary ===")
        print(f"Relationship: {relationship.description()}")
        print(f"Correlation ρ(x,y): {corr_xy:.4f}")
        print(f"Mutual Information I(X;Y): {mi:.4f} nats")
        print("\nMSE Results:")

        best_mse = min(valid_results.values()) if valid_results else np.inf
        linear_mse = mse_results.get('Linear MMSE', np.inf)

        for name, mse in sorted(valid_results.items(), key=lambda item: item[1]):
            # An estimator with zero MSE is necessarily the best one; report
            # 100% rather than evaluating 0/0.
            efficiency = best_mse / mse * 100 if mse > 0 else 100.0
            print(f"  {name:20s}: {mse:.4f} ({efficiency:.1f}% efficiency)")

        if linear_mse < np.inf and best_mse > 0:
            linear_loss = (linear_mse - best_mse) / best_mse * 100
            print(f"\nLinear MMSE Performance Loss: {linear_loss:.1f}%")

    # Create interactive widget: the callback is invoked with the current
    # control values
    interactive_plot = widgets.interactive(
        update_comparison,
        relationship_name=relationship_dropdown,
        noise_level=noise_slider,
        n_samples=n_samples_slider
    )

    display(interactive_plot)

# Run the interactive demo: builds the widget controls and displays them
interactive_nonlinear_demo()

interactive(children=(Dropdown(description='Relationship:', options=('Quadratic', 'Sinusoidal', 'Symmetric (Co…

## Linear MMSE Failure Modes

1. **Symmetric case**: Cov(x,y) ≈ 0 despite strong dependence
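2. **Phase retrieval**: Many-to-one mapping $y = |x|$ discards the sign of $x$ and destroys the linear relationship; for a sign-symmetric signal $\mathbb{E}[x|y] = 0$, so here even nonlinear estimators cannot improve on predicting the mean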
3. **Information loss**: Linear estimators capture only correlation-based information

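Trade-off: Linear MMSE needs only second-order statistics and O(m²) computation for an m-dimensional observation, whereas nonlinear methods must be trained, at a cost that depends on the model and the amount of data.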