import torch
from torch.optim import Optimizer
import numpy as np
import math

from typing import List, Tuple, Dict, Optional, Callable, Union, Any, Iterable
from typing_extensions import ParamSpec, Self, TypeAlias
from torch import Tensor


class AdamPaLM2Beta(torch.optim.Optimizer):
    """
    Adam with PaLM2-style beta2 scheduling.

    Runs Adam (weight decay is added to the gradient, L2-style) with
    ``betas[1]`` as beta2, then switches every parameter group to
    ``beta2_final`` once the optimizer's own step counter reaches
    ``warmup_iters``. The switch happens once, at the start of that step, so
    step number ``warmup_iters`` already uses ``beta2_final``.

    The step counter and the "already switched" flag are stored in
    ``state_dict()`` so that resuming from a checkpoint continues the schedule
    instead of restarting it.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: learning rate (>= 0).
        betas: ``(beta1, initial beta2)``, each in ``[0, 1)``.
        eps: term added to ``sqrt(v_hat)`` in the denominator (>= 0).
        weight_decay: coefficient of ``p`` added to the gradient (>= 0).
        beta2_final: beta2 used from step ``warmup_iters`` on, in ``[0, 1)``.
        warmup_iters: step at which beta2 is switched (>= 0).

    Raises:
        ValueError: if any hyperparameter is outside its valid range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.95), eps=1e-8, weight_decay=0,
                 beta2_final=0.99, warmup_iters=10000):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta_1 parameter: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid initial beta_2 parameter: {betas[1]}")
        # beta2 == 1 would make the bias correction 1 - beta2**t zero -> inf/NaN updates.
        if not 0.0 <= beta2_final < 1.0:
            raise ValueError(f"Invalid final beta_2 parameter: {beta2_final}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if warmup_iters < 0:
            raise ValueError(f"Invalid warmup_iters value: {warmup_iters}")

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # Scheduler configuration and progress.
        self.beta2_initial = betas[1]
        self.beta2_final = beta2_final
        self.schedule_start_step = warmup_iters
        self.global_step = 0
        self._schedule_triggered = False

    def state_dict(self):
        """Return the optimizer state plus the beta2 schedule progress."""
        state = super().state_dict()
        state['global_step'] = self.global_step
        state['schedule_triggered'] = self._schedule_triggered
        return state

    def load_state_dict(self, state_dict):
        """Load optimizer state and restore the beta2 schedule progress.

        Checkpoints without the schedule keys (e.g. written by an older
        version of this class) leave the current schedule progress untouched.
        """
        state_dict = dict(state_dict)  # do not mutate the caller's dict
        global_step = state_dict.pop('global_step', None)
        triggered = state_dict.pop('schedule_triggered', None)
        super().load_state_dict(state_dict)
        if global_step is not None:
            self.global_step = global_step
        if triggered is not None:
            self._schedule_triggered = triggered

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step, switching beta2 first if it is due.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        self.global_step += 1

        # Switch beta2 exactly once, the first time the counter reaches the threshold.
        if self.global_step >= self.schedule_start_step and not self._schedule_triggered:
            print(f"Step {self.global_step}: Switching beta2 from {self.beta2_initial} to {self.beta2_final}")
            for group in self.param_groups:
                beta1, _ = group['betas']
                group['betas'] = (beta1, self.beta2_final)
            self._schedule_triggered = True

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr, (beta1, beta2), eps, weight_decay = group['lr'], group['betas'], group['eps'], group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # Lazy per-parameter state initialisation.
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)

                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # NOTE: after the switch, 1 - beta2**t uses the *current* beta2;
                # the exact correction for a time-varying beta2 would be
                # 1 - prod(beta2_i). The difference is negligible once t is large.
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                m_hat = exp_avg / bias_correction1
                v_hat = exp_avg_sq / bias_correction2

                update = m_hat / (v_hat.sqrt().add_(eps))
                p.add_(update, alpha=-lr)

        return loss


class AdamBeta2Schedule(torch.optim.Optimizer):
    """
    Adam with a linear schedule for beta2.

    For every parameter group, beta2 is linearly interpolated from that
    group's beta1 (at step 0) to the group's configured beta2 (reached at
    step ``warmup_iters``) and then held constant; beta1 is never changed.
    ``warmup_iters=0`` disables the warmup (configured beta2 from step 1).

    The step counter is stored in ``state_dict()`` so that resuming from a
    checkpoint continues the schedule instead of restarting it.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: learning rate (>= 0).
        betas: ``(beta1, final beta2)``, each in ``[0, 1)``.
        eps: term added to ``sqrt(v_hat)`` in the denominator (>= 0).
        weight_decay: coefficient of ``p`` added to the gradient (>= 0).
        warmup_iters: number of steps over which beta2 is ramped (>= 0).

    Raises:
        ValueError: if any hyperparameter is outside its valid range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0,
                 warmup_iters=10000):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta_1 parameter: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta_2 parameter: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if warmup_iters < 0:
            raise ValueError(f"Invalid warmup_iters value: {warmup_iters}")

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # Constructor-level schedule endpoints. Kept for backward compatibility;
        # beta2_final is also the fallback target for groups restored from old
        # checkpoints that lack a per-group 'beta2_final'.
        self.beta1_val = betas[0]
        self.beta2_final = betas[1]
        self.beta2_start = self.beta1_val
        self.total_steps = warmup_iters
        self.global_step = 0

    def add_param_group(self, param_group):
        """Add a parameter group and remember its target beta2.

        ``step`` overwrites ``group['betas']`` with the scheduled value every
        step, so the group's configured beta2 is saved as ``'beta2_final'``
        before that happens.
        """
        super().add_param_group(param_group)
        group = self.param_groups[-1]
        group.setdefault('beta2_final', group['betas'][1])

    def state_dict(self):
        """Return the optimizer state plus the beta2 schedule progress."""
        state = super().state_dict()
        state['global_step'] = self.global_step
        return state

    def load_state_dict(self, state_dict):
        """Load optimizer state and restore the beta2 schedule progress.

        Checkpoints without ``'global_step'`` leave the current counter as is.
        """
        state_dict = dict(state_dict)  # do not mutate the caller's dict
        global_step = state_dict.pop('global_step', None)
        super().load_state_dict(state_dict)
        if global_step is not None:
            self.global_step = global_step

    @torch.no_grad()
    def step(self, closure=None):
        """Update every group's beta2 per the schedule, then take an Adam step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        self.global_step += 1

        # Fraction of the warmup completed; a zero-length warmup is "done".
        if self.total_steps > 0:
            progress = min(self.global_step / self.total_steps, 1.0)
        else:
            progress = 1.0

        for group in self.param_groups:
            beta1 = group['betas'][0]
            beta2_final = group.get('beta2_final', self.beta2_final)
            if progress >= 1.0:
                # Land exactly on the configured value (no float drift).
                current_beta2 = beta2_final
            else:
                # Ramp starts at this group's beta1.
                current_beta2 = beta1 + progress * (beta2_final - beta1)
            group['betas'] = (beta1, current_beta2)

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr, (beta1, beta2), eps, weight_decay = group['lr'], group['betas'], group['eps'], group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # Lazy per-parameter state initialisation.
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)

                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # NOTE: 1 - beta2**t uses the *current* beta2; the exact
                # correction for a time-varying beta2 would be 1 - prod(beta2_i).
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                m_hat = exp_avg / bias_correction1
                v_hat = exp_avg_sq / bias_correction2

                update = m_hat / (v_hat.sqrt().add_(eps))
                p.add_(update, alpha=-lr)

        return loss


class AdamEpsilonSchedule(torch.optim.Optimizer):
    """
    Adam with a linear schedule for epsilon.

    Epsilon is linearly interpolated from ``eps1`` (at step 0) to ``eps2``
    (reached at step ``warmup_iters``) and then held at ``eps2``; all
    parameter groups receive the same value. ``warmup_iters=0`` disables the
    ramp (``eps2`` from step 1).

    Unlike standard Adam, epsilon acts as a *floor on the bias-corrected
    second moment* rather than being added to its square root: the update
    denominator is ``sqrt(max(v_hat, eps))``. ``eps1``/``eps2`` are therefore
    on the scale of squared gradients.

    The step counter is stored in ``state_dict()`` so that resuming from a
    checkpoint continues the schedule instead of restarting it.

    Args:
        params: iterable of parameters or dicts defining parameter groups.
        lr: learning rate (>= 0).
        betas: ``(beta1, beta2)``, each in ``[0, 1)``.
        weight_decay: coefficient of ``p`` added to the gradient (>= 0).
        eps1: epsilon at the start of the schedule (>= 0).
        eps2: epsilon at and after step ``warmup_iters`` (>= 0).
        warmup_iters: number of steps over which epsilon is ramped (>= 0).

    Raises:
        ValueError: if any hyperparameter is outside its valid range.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0,
                 eps1=1e-12, eps2=1e-0, warmup_iters=10000):
        if not 0.0 <= eps1: raise ValueError(f"Invalid eps1 value: {eps1}")
        if not 0.0 <= eps2: raise ValueError(f"Invalid eps2 value: {eps2}")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta_1 parameter: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta_2 parameter: {betas[1]}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if warmup_iters < 0:
            raise ValueError(f"Invalid warmup_iters value: {warmup_iters}")

        # The group-level epsilon starts at eps1 and is rewritten every step.
        defaults = dict(lr=lr, betas=betas, eps=eps1, weight_decay=weight_decay)
        super().__init__(params, defaults)

        # Scheduler configuration and progress.
        self.eps1 = eps1
        self.eps2 = eps2
        self.total_steps = warmup_iters
        self.global_step = 0

    def state_dict(self):
        """Return the optimizer state plus the epsilon schedule progress."""
        state = super().state_dict()
        state['global_step'] = self.global_step
        return state

    def load_state_dict(self, state_dict):
        """Load optimizer state and restore the epsilon schedule progress.

        Checkpoints without ``'global_step'`` leave the current counter as is.
        """
        state_dict = dict(state_dict)  # do not mutate the caller's dict
        global_step = state_dict.pop('global_step', None)
        super().load_state_dict(state_dict)
        if global_step is not None:
            self.global_step = global_step

    @torch.no_grad()
    def step(self, closure=None):
        """Set epsilon per the schedule, then take an Adam step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        self.global_step += 1

        # Fraction of the warmup completed; a zero-length warmup is "done".
        if self.total_steps > 0:
            progress = min(self.global_step / self.total_steps, 1.0)
        else:
            progress = 1.0

        if progress >= 1.0:
            current_eps = self.eps2  # land exactly on the configured value
        else:
            current_eps = self.eps1 + progress * (self.eps2 - self.eps1)

        for group in self.param_groups:
            group['eps'] = current_eps

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr, (beta1, beta2), eps, weight_decay = group['lr'], group['betas'], group['eps'], group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # Lazy per-parameter state initialisation.
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                if weight_decay != 0:
                    grad = grad.add(p, alpha=weight_decay)

                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                m_hat = exp_avg / bias_correction1
                v_hat = exp_avg_sq / bias_correction2

                # Floor v_hat at eps before the square root. clamp_min gives the
                # same result as torch.maximum(v_hat, eps) without allocating a
                # new scalar tensor for every parameter on every step.
                denominator = v_hat.clamp_min(eps).sqrt()
                update = m_hat / denominator
                p.add_(update, alpha=-lr)

        return loss
