#!/usr/bin/env python
"""
SGD-based rotation learning for ARQ
"""

import torch
import torch.nn as nn
import torch.optim as optim
from typing import Optional, Dict, Any
from tqdm import tqdm
import math

from .losses import MultiObjectiveLoss


class RotationLearnerSGD:
    """
    Rotation learner using SGD with momentum.

    Optimizes the parameters of ``transform`` so that the rotated calibration
    activations minimise ``loss_fn``.

    Args:
        transform: Module whose parameters define the rotation. It is called
            on each batch and must expose ``get_matrix()`` returning the
            current rotation matrix.
        loss_fn: Objective called as
            ``loss_fn(x, x_rot, Q, transform=..., step=...)`` that returns a
            mapping with at least ``'total'`` and ``'quantization'`` tensors
            (optionally ``'uniformity'``). Defaults to a
            ``MultiObjectiveLoss`` with only the quantization term enabled
            (``bits=4``, ``sym=True``).
        lr: Base (peak) learning rate.
        momentum: SGD momentum factor.
        steps: Number of optimisation steps; also the schedule horizon.
        lr_scheduler: One of ``'none'``, ``'cosine'``, ``'warmup_cosine'``,
            ``'linear'``.
        lr_min: Learning rate reached once ``step >= steps`` for the decaying
            schedules.
        warmup_steps: Length of the linear warmup (``'warmup_cosine'`` only).

    Raises:
        ValueError: If ``lr_scheduler`` is not one of the names above.
    """

    def __init__(self,
                 transform: nn.Module,
                 loss_fn: Optional[MultiObjectiveLoss] = None,
                 lr: float = 10.0,
                 momentum: float = 0.9,
                 steps: int = 100,
                 lr_scheduler: str = 'none',
                 lr_min: float = 0.01,
                 warmup_steps: int = 0):
        self.transform = transform
        # Explicit None check: truthiness of an nn.Module is not a reliable
        # "was it provided" test (containers define __len__).
        self.loss_fn = loss_fn if loss_fn is not None else MultiObjectiveLoss(
            lambda_quant=1.0,
            lambda_ortho=0.0,
            lambda_entropy=0.0,
            bits=4,
            sym=True
        )
        self.lr = lr
        self.momentum = momentum
        self.steps = steps
        self.lr_scheduler = lr_scheduler
        self.lr_min = lr_min
        self.warmup_steps = warmup_steps

        # Per-step loss / gradient-norm curves from the most recent learn().
        self.history: Dict[str, list] = {'loss': [], 'grad_norm': []}

        # SGD optimizer with momentum
        self.optimizer = optim.SGD(
            self.transform.parameters(),
            lr=lr,
            momentum=momentum
        )

        # Create learning rate scheduler (None when lr_scheduler == 'none')
        self.scheduler = self._create_scheduler()

    def _min_lr_ratio(self) -> float:
        """Return ``lr_min / lr``, the final LambdaLR multiplier.

        LambdaLR multiplies the base lr by this factor, so with ``lr == 0``
        every effective lr is 0 regardless; return 1.0 instead of dividing
        by zero.
        """
        if self.lr == 0:
            return 1.0
        return self.lr_min / self.lr

    def _create_scheduler(self):
        """Create the learning rate scheduler selected by ``lr_scheduler``.

        Returns:
            ``None`` for ``'none'``; otherwise a ``LambdaLR`` whose multiplier
            starts at 1.0 (after any warmup) and equals ``lr_min / lr`` once
            ``step >= steps``.

        Raises:
            ValueError: For an unknown scheduler name.
        """
        if self.lr_scheduler == 'none':
            return None

        def cosine(progress: float) -> float:
            # Cosine interpolation from 1.0 (progress=0) to lr_min/lr (progress=1).
            ratio = self._min_lr_ratio()
            return ratio + (1 - ratio) * (1 + math.cos(math.pi * progress)) / 2

        if self.lr_scheduler == 'cosine':
            # Cosine annealing from lr to lr_min
            def lr_lambda(step):
                if step >= self.steps:
                    return self._min_lr_ratio()
                return cosine(step / self.steps)
        elif self.lr_scheduler == 'warmup_cosine':
            # Linear warmup then cosine annealing
            def lr_lambda(step):
                if step < self.warmup_steps:
                    # NOTE: the multiplier is 0 at step 0, so with warmup the
                    # first optimizer step applies no parameter update.
                    return step / self.warmup_steps
                if step >= self.steps:
                    return self._min_lr_ratio()
                # Here warmup_steps <= step < steps, so the denominator is > 0.
                return cosine((step - self.warmup_steps) / (self.steps - self.warmup_steps))
        elif self.lr_scheduler == 'linear':
            # Linear decay from lr to lr_min
            def lr_lambda(step):
                if step >= self.steps:
                    return self._min_lr_ratio()
                return 1 - (1 - self._min_lr_ratio()) * step / self.steps
        else:
            raise ValueError(f"Unknown scheduler type: {self.lr_scheduler}")
        return torch.optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda)

    @staticmethod
    def _next_batch(data_loader, data_iter):
        """Return ``(batch, data_iter)``, restarting ``data_loader`` when exhausted.

        Element ``[0]`` of each yielded item is taken as the batch (e.g. a
        ``TensorDataset`` yields tuples).

        Raises:
            ValueError: If a fresh iterator over ``data_loader`` yields nothing.
        """
        try:
            return next(data_iter)[0], data_iter
        except StopIteration:
            pass
        data_iter = iter(data_loader)
        try:
            return next(data_iter)[0], data_iter
        except StopIteration:
            raise ValueError("data_loader yielded no batches") from None

    def _grad_norm(self) -> float:
        """Return the global L2 norm over all current parameter gradients."""
        total = 0.0
        for p in self.transform.parameters():
            if p.grad is not None:
                total += p.grad.norm().item() ** 2
        return total ** 0.5

    def learn(self, data_loader) -> torch.Tensor:
        """
        Learn rotation using SGD.

        Runs ``self.steps`` optimisation steps, cycling through
        ``data_loader`` as often as needed. Per-step loss and gradient norm
        are recorded in ``self.history``.

        Args:
            data_loader: Iterable of calibration batches; element ``[0]`` of
                each item is the activation tensor passed to ``transform``.

        Returns:
            Final rotation matrix Q as float32, detached from autograd.

        Raises:
            ValueError: If ``data_loader`` yields no batches.
        """
        scheduler_info = ""
        if self.lr_scheduler != 'none':
            scheduler_info = f", scheduler={self.lr_scheduler}, lr_min={self.lr_min}"
            if self.lr_scheduler == 'warmup_cosine':
                scheduler_info += f", warmup={self.warmup_steps}"
        print(f"\nLearning rotation with SGD (lr={self.lr}, momentum={self.momentum}{scheduler_info})")

        history = {
            'loss': [],
            'grad_norm': []
        }
        self.history = history
        data_iter = iter(data_loader)

        # Context manager closes the bar even if a step raises.
        with tqdm(range(self.steps), desc="Learning") as pbar:
            for step in pbar:
                batch, data_iter = self._next_batch(data_loader, data_iter)

                # Forward pass
                x_rot = self.transform(batch)
                Q = self.transform.get_matrix()

                # Compute loss (pass transform for sparsity loss and step for curriculum)
                losses = self.loss_fn(batch, x_rot, Q, transform=self.transform, step=step)
                total_loss = losses['total']

                # Backward pass
                self.optimizer.zero_grad()
                total_loss.backward()
                grad_norm = self._grad_norm()

                # Read the lr *before* stepping the scheduler so the value
                # reported is the one this update actually used.
                step_lr = self.optimizer.param_groups[0]['lr']
                self.optimizer.step()
                if self.scheduler is not None:
                    self.scheduler.step()

                loss_value = total_loss.item()  # one host sync per step
                history['loss'].append(loss_value)
                history['grad_norm'].append(grad_norm)

                pbar.set_postfix({
                    'loss': f"{loss_value:.6f}",
                    'grad': f"{grad_norm:.6f}",
                    'lr': f"{step_lr:.6f}"
                })

                # Debug output every 50 steps (incl. step 0) and on the last step.
                # tqdm.write keeps the progress bar from being garbled.
                if step % 50 == 0 or step == self.steps - 1:
                    report = [
                        f"\nStep {step}:",
                        f"  Loss: {loss_value:.6f}",
                        f"  Grad norm: {grad_norm:.6f}",
                        f"  Quant loss: {losses['quantization'].item():.6f}",
                        f"  Learning rate: {step_lr:.6f}",
                    ]
                    # Show uniformity loss if present
                    if 'uniformity' in losses and losses['uniformity'].item() > 0:
                        report.append(f"  Uniformity loss: {losses['uniformity'].item():.6f}")
                    tqdm.write("\n".join(report))

        # Final rotation matrix (float32). Computed without autograd and
        # detached so the result does not keep the parameter graph alive.
        with torch.no_grad():
            Q_final = self.transform.get_matrix().detach().to(torch.float32)

        # Show improvement
        if len(history['loss']) > 1:
            initial_loss = history['loss'][0]
            final_loss = history['loss'][-1]
            print("\nTraining complete:")
            print(f"  Initial loss: {initial_loss:.6f}")
            print(f"  Final loss: {final_loss:.6f}")
            if initial_loss != 0:
                improvement = (initial_loss - final_loss) / initial_loss * 100
                print(f"  Improvement: {improvement:.1f}%")
            else:
                print("  Improvement: n/a (initial loss is 0)")

        return Q_final