from bitsandbytes.optim.optimizer import Optimizer2State

import torch

from .galore_projector import GaLoreProjector


class AdamW8bit(Optimizer2State):
    """Adam-style optimizer on top of bitsandbytes ``Optimizer2State`` with GaLore support.

    The base class is always constructed with the ``"adam"`` optimizer name and
    8 optimizer bits.  Parameter groups that contain a ``"rank"`` key are
    treated as GaLore groups: their gradients are projected with a
    ``GaLoreProjector`` before the update and the resulting update is projected
    back afterwards.  Such groups must also provide ``"update_proj_gap"``,
    ``"scale"`` and ``"proj_type"``.

    When ``new_optimizer`` is true, a CPU copy of every parameter is stored in
    ``state[p]["backup"]`` at construction time.  For GaLore groups, every
    ``update_proj_gap`` optimizer steps (as counted by ``state["step"]``) the
    difference between the current weights and that backup is truncated to
    its leading ``rank`` singular components.
    """

    def __init__(
        self,
        params,
        lr=1e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=1e-2,
        amsgrad=False,
        optim_bits=32,
        args=None,
        min_8bit_size=4096,
        percentile_clipping=100,
        block_wise=True,
        is_paged=False,
        new_optimizer=False,
        rank=None,
        update_proj_gap=200,
    ):
        """Build the optimizer.

        Arguments:
            params: parameters or parameter groups, forwarded to the base class.
            lr, betas, eps, weight_decay: forwarded to the base class.
            amsgrad: accepted for signature compatibility; not used here.
            optim_bits: accepted for signature compatibility; the base class
                is always given 8.
            args, min_8bit_size, percentile_clipping, block_wise, is_paged:
                forwarded to the base class.
            new_optimizer (bool): if true, keep a CPU backup of every
                parameter and enable the periodic low-rank truncation of the
                accumulated weight change in :meth:`step`.
            rank (int, optional): number of singular components kept by that
                truncation.  When ``None``, the group's own ``"rank"`` is used.
            update_proj_gap (int): the truncation runs whenever
                ``state["step"]`` is a multiple of this value.
        """
        super().__init__(
            "adam",
            params,
            lr,
            betas,
            eps,
            weight_decay,
            8,
            args,
            min_8bit_size,
            percentile_clipping,
            block_wise,
            is_paged=is_paged,
        )
        # Always define these so that step() can rely on them whenever a
        # "backup" entry is present in the state (e.g. after loading a state).
        self.rank = rank
        self.update_proj_gap = update_proj_gap
        if new_optimizer:
            cpu = torch.device("cpu")
            for group in self.param_groups:
                for p in group["params"]:
                    self.state[p]["backup"] = p.data.clone().detach().to(cpu)

    def _truncate_delta_to_rank(self, p, state, group):
        """Replace ``p.data`` by ``backup + rank-truncated(p.data - backup)``.

        Does nothing unless ``state["step"]`` is a multiple of
        ``self.update_proj_gap``.  The SVD is computed on the device of the
        backup tensor in float32; the result is moved back to the parameter's
        device before being added to the backup.
        """
        if state["step"] % self.update_proj_gap != 0:
            return

        backup = state["backup"]
        param_device = p.data.device
        # Slicing with ``None`` would zero *every* singular value, so fall
        # back to the group's GaLore rank when no explicit rank was given.
        rank = self.rank if self.rank is not None else group["rank"]

        delta = p.data.detach().to(torch.device("cpu")) - backup
        delta_dtype = delta.dtype

        # torch.linalg.svd is run in float32 regardless of the weight dtype.
        U, s, Vh = torch.linalg.svd(delta.float(), full_matrices=False)
        s[rank:] = 0
        # Scaling the rows of Vh by s is the same as diag(s) @ Vh.
        low_rank = torch.matmul(U, s.unsqueeze(-1) * Vh)

        # Always return to the parameter's device/dtype; leaving a float32
        # result on the CPU would fail when added to a tensor on another device.
        low_rank = low_rank.to(device=param_device, dtype=delta_dtype)
        p.data = low_rank + backup.to(param_device)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        if not self.initialized:
            self.check_overrides()
            self.to_gpu()  # needed for fairseq pure fp16 training
            self.initialized = True

        cpu = torch.device("cpu")
        for gindex, group in enumerate(self.param_groups):
            use_galore = "rank" in group
            for pindex, p in enumerate(group["params"]):
                if p.grad is None:
                    continue
                state = self.state[p]

                # NOTE(review): the base-class init_state/update_step may also
                # modify state["step"] - confirm against bitsandbytes.
                state["step"] = state.get("step", 0) + 1

                full_weight = None
                saved_weight_decay = None

                # GaLore Projection
                if use_galore:
                    if "projector" not in state:
                        state["projector"] = GaLoreProjector(
                            group["rank"],
                            update_proj_gap=group["update_proj_gap"],
                            scale=group["scale"],
                            proj_type=group["proj_type"],
                        )

                    if group.get("weight_decay", 0) > 0:
                        # The update below runs on a zeroed stand-in tensor,
                        # so weight decay is disabled for it and applied to
                        # the real weights after projecting back.
                        saved_weight_decay = group["weight_decay"]
                        group["weight_decay"] = 0

                    grad = state["projector"].project(p.grad, state["step"])

                    # suboptimal implementation: swap the weights for a zero
                    # tensor shaped like the projected gradient so that the
                    # update step writes the raw update into p.data.
                    full_weight = p.data.clone()
                    p.data = grad.clone().to(p.data.dtype).to(p.data.device)
                    p.data.zero_()
                    p.grad = grad

                if "state1" not in state:
                    self.init_state(group, p, gindex, pindex)

                self.prefetch_state(p)
                self.update_step(group, p, gindex, pindex)
                torch.cuda.synchronize()

                if "backup" in state:
                    # Keep the backup on the CPU (to_gpu() above may have
                    # moved state tensors - TODO confirm).
                    state["backup"] = state["backup"].to(cpu)

                if not use_galore:
                    continue

                # GaLore Projection Back
                p.data = full_weight.add_(state["projector"].project_back(p.data))

                # apply weight decay
                if saved_weight_decay is not None:
                    p.data.add_(p.data, alpha=-group["lr"] * saved_weight_decay)
                    group["weight_decay"] = saved_weight_decay

                if "backup" in state:
                    self._truncate_delta_to_rank(p, state, group)

        if self.is_paged:
            # all paged operation are asynchronous, we need
            # to sync to make sure all tensors are in the right state
            torch.cuda.synchronize()

        return loss