import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.functional import normalize
import os
import math

# Module-level switch captured by each optimizer below as
# ``self.convert_cuda_to_cpu`` when it is constructed. When True, RSAM and
# RSAM_OCNN copy CUDA parameters/gradients to CPU before calling the
# user-supplied ``proj``/``retr`` callables (SAM stores the flag as well).
convert_cuda_to_cpu = True
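'''
Sharpness-aware optimizers: SAM (Euclidean, wraps a base optimizer) and
RSAM / RSAM_OCNN (Riemannian, for manifold-constrained weight matrices and
for convolution filters flattened into matrices).
----------- CONSTRAINED SGD ---------------
'''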

class SAM(optim.Optimizer):
    """Sharpness-Aware Minimization wrapper around a Euclidean base optimizer.

    Two-phase use per iteration: compute gradients at ``w`` and call
    ``ascent_step()`` to move to ``w + e(w)``; recompute gradients there and
    call ``descent_step()``, which restores ``w`` and lets ``base_optimizer``
    update it with the gradients taken at the perturbed point.

    Args:
        params: iterable of parameters or param-group dicts.
        base_optimizer: optimizer *class* (e.g. ``torch.optim.SGD``); it is
            instantiated on this optimizer's param groups with ``**kwargs``.
        adaptive: if True, scale the perturbation element-wise by parameter
            magnitude (``|w|`` inside the norm, ``w**2`` in the step).
        rho: perturbation radius; must be non-negative.
        **kwargs: forwarded to ``base_optimizer`` and stored in the defaults.

    Raises:
        ValueError: if ``rho`` is negative.
    """

    def __init__(self, params, base_optimizer, adaptive=False, rho=0.05, **kwargs):
        if rho < 0.0:
            raise ValueError("Invalid rho, should be non-negative: {}".format(rho))
        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)
        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Use the base optimizer's group list so both optimizers see the same groups.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)
        # Kept for API parity with RSAM/RSAM_OCNN; SAM itself never moves tensors to CPU.
        self.convert_cuda_to_cpu = convert_cuda_to_cpu

    def __setstate__(self, state):
        super(SAM, self).__setstate__(state)

    @torch.no_grad()
    def ascent_step(self, zero_grad=False):
        """Perturb every parameter that has a gradient to ``w + e(w)``.

        ``e(w) = rho * T(w) * grad / ||T(w) * grad||`` with the norm taken over
        all parameters jointly. ``T(w) = |w|`` when ``adaptive``, otherwise 1,
        and the step applies ``T(w)**2``. The unperturbed value is saved in
        ``self.state[p]["old_p"]`` for ``descent_step``.

        Args:
            zero_grad: if True, clear gradients afterwards.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            scale = group['rho'] / (grad_norm + 1e-12)
            for p in group['params']:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to w + e(w)
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def descent_step(self, zero_grad=False):
        """Restore the saved ``w`` and apply the base optimizer's update.

        The update uses the gradients currently stored on the parameters, which
        should have been computed at the perturbed point.

        Args:
            zero_grad: if True, clear gradients afterwards.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"
        self.base_optimizer.step()  # do the actual "sharpness-aware" update
        if zero_grad:
            self.zero_grad()

    def _grad_norm(self):
        """Return the joint L2 norm of all gradients (each scaled by ``|p|`` if adaptive).

        The result is a 0-dim tensor on the first parameter's device. It is
        zero when no parameter has a gradient.
        """
        # put everything on the same device, in case of model parallelism
        shared_device = self.param_groups[0]["params"][0].device
        norms = [
            ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        if not norms:
            # torch.stack rejects an empty list; with no gradients nothing gets perturbed.
            return torch.zeros((), device=shared_device)
        return torch.norm(torch.stack(norms), p=2)

    def load_state_dict(self, state_dict):
        """Load state and re-link the base optimizer to the restored param groups."""
        super().load_state_dict(state_dict)
        self.base_optimizer.param_groups = self.param_groups


class RSAM(optim.Optimizer):
    """Riemannian sharpness-aware minimisation with a plain retraction descent step.

    Two-phase use per iteration: compute gradients at ``theta`` and call
    ``ascent_step()`` (moves to a perturbed point and saves ``theta``);
    recompute gradients there and call ``descent_step()`` (restores ``theta``
    and retracts along the new projected gradient).

    The geometry is supplied by the caller through per-group callables:

    * ``proj(theta, v)`` -- applied to Euclidean gradients to obtain what this
      code treats as the Riemannian gradient (presumably the projection onto
      the tangent space at ``theta``).
    * ``retr(theta, v)`` -- ``descent_step`` calls ``retr(theta, -lr * grad)``,
      so it is expected to move from ``theta`` along ``v`` and return a tensor
      of ``theta``'s shape.

    Args:
        params: parameters or param-group dicts.
        lr: descent step size; must be non-negative.
        momentum, dampening, nesterov, transp: stored in the param groups but
            not read by this implementation.
        weight_decay: L2 coefficient; the gradient becomes
            ``grad + weight_decay * theta`` before projection.
        rho: perturbation radius used by ``ascent_step``.
        manifold: only checked against ``None``. A group whose parameters have
            gradients must set it, otherwise ``NotImplementedError`` is raised.
        proj, retr: see above.

    Raises:
        ValueError: if ``lr`` is negative.
    """

    def __init__(
                self,
                params,
                lr=1e-3,
                momentum = 0,
                weight_decay = 0,
                dampening = 0,
                rho=0,
                nesterov = False,
                manifold = None,
                proj = None,
                retr = None,
                transp = None
                ):

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening = dampening,
            rho=rho,
            nesterov=nesterov,
            weight_decay=weight_decay,
            manifold=manifold,
            proj=proj,
            retr=retr,
            transp=transp)

        super(RSAM, self).__init__(params, defaults)
        self.convert_cuda_to_cpu = convert_cuda_to_cpu

    def __setstate__(self, state):
        super(RSAM, self).__setstate__(state)

    def _work_copies(self, p, weight_decay):
        """Return ``(theta, grad)`` clones of ``p``, with weight decay folded into grad.

        Both clones are moved to CPU when the gradient is on CUDA and
        ``self.convert_cuda_to_cpu`` is set.
        """
        cur_theta = p.data.clone()
        d_p = p.grad.data.clone()
        if d_p.is_cuda and self.convert_cuda_to_cpu:
            d_p = d_p.cpu()
            cur_theta = cur_theta.cpu()
        if weight_decay != 0:
            d_p.add_(cur_theta, alpha=weight_decay)
        return cur_theta, d_p

    @torch.no_grad()
    def ascent_step(self, zero_grad=False):
        """Move every parameter with a gradient to its sharpness-aware perturbation.

        With ``g = proj(theta, grad)`` and ``Z = sqrt(|theta|)``, the direction
        is ``proj(theta, |theta| * g / ||Z * g||)``, normalised per parameter.
        The parameter becomes ``retr(theta, rho * direction)``. The unperturbed
        value is saved in ``self.state[p]["old_p"]``.

        Args:
            zero_grad: if True, clear gradients afterwards.

        Raises:
            NotImplementedError: if a group with gradients has ``manifold=None``.
        """
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            rho = group['rho']
            proj, retr = group['proj'], group['retr']
            for p in group['params']:
                if p.grad is None:
                    continue
                if group['manifold'] is None:
                    raise NotImplementedError('Supposed to use Manifold Here')
                cur_theta, d_p = self._work_copies(p, weight_decay)
                riemann_grad = proj(cur_theta, d_p)
                self.state[p]["old_p"] = p.data.clone()

                z = cur_theta.abs().pow(0.5)
                vect = z * riemann_grad
                u = vect / (torch.norm(vect) + 1e-12)
                retract_dir = proj(cur_theta, z * u)
                # '+rho': the perturbation must climb the loss, as SAM.ascent_step
                # does. descent_step passes -lr*grad, so retr moves along its
                # argument; the former '-rho' therefore moved downhill.
                theta_hat = retr(cur_theta, rho * retract_dir)
                p.data.copy_(theta_hat)  # copy_ also handles a CPU result for a CUDA p
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def descent_step(self, zero_grad=False):
        """Restore the saved parameters and retract along ``-lr * proj(theta, grad)``.

        The gradient used is the one currently stored on each parameter, which
        should have been computed at the perturbed point.

        Args:
            zero_grad: if True, clear gradients afterwards.

        Raises:
            NotImplementedError: if a group with gradients has ``manifold=None``.
        """
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            proj, retr = group['proj'], group['retr']
            for p in group["params"]:
                if p.grad is None:
                    continue
                if group['manifold'] is None:
                    raise NotImplementedError('Supposed to use Manifold Here')
                p.data = self.state[p]["old_p"]  # back to theta from the perturbed point
                cur_theta, d_p = self._work_copies(p, weight_decay)
                riemann_grad = proj(cur_theta, d_p)
                new_theta = retr(cur_theta, -group['lr'] * riemann_grad)
                # copy_ transfers to p's own device; the former .cuda() targeted
                # the current default CUDA device instead.
                p.data.copy_(new_theta)
        if zero_grad:
            self.zero_grad()


class RSAM_OCNN(optim.Optimizer):
    """Riemannian SAM for convolution weights, constrained filter-wise via a matrix view.

    Each parameter of shape ``(num_filters, *rest)`` is viewed as a
    ``(prod(rest), num_filters)`` matrix whose column ``j`` is filter ``j``
    flattened. For example, ``(num_filters, depth, height, width)`` becomes
    ``(depth*height*width, num_filters)``. ``proj``/``retr`` act on that
    matrix, and results are mapped back to the original shape.

    Two-phase use per iteration: compute gradients and call ``ascent_step()``;
    recompute gradients at the perturbed point and call ``descent_step()``.

    Args:
        params: parameters or param-group dicts.
        lr: descent step size; must be non-negative.
        momentum, dampening, nesterov, transp: stored in the param groups but
            not read by this implementation.
        weight_decay: L2 coefficient; the gradient becomes
            ``grad + weight_decay * theta`` before projection.
        rho: perturbation radius used by ``ascent_step``.
        manifold: only checked against ``None``. A group whose parameters have
            gradients must set it, otherwise ``NotImplementedError`` is raised.
        proj: ``proj(theta, v)`` maps a matrix-shaped gradient to what this code
            treats as the Riemannian gradient (presumably a tangent projection).
        retr: ``retr(theta, v)``. ``descent_step`` passes ``-lr * grad``, so it
            is expected to move from ``theta`` along ``v``.

    Raises:
        ValueError: if ``lr`` is negative.
    """

    def __init__(
                self,
                params,
                lr=1e-3,
                momentum = 0,
                weight_decay = 0,
                dampening = 0,
                rho=0,
                nesterov = False,
                manifold = None,
                proj = None,
                retr = None,
                transp = None
                ):

        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening = dampening,
            rho=rho,
            nesterov=nesterov,
            weight_decay=weight_decay,
            manifold=manifold,
            proj=proj,
            retr=retr,
            transp=transp)

        super(RSAM_OCNN, self).__init__(params, defaults)
        self.convert_cuda_to_cpu = convert_cuda_to_cpu

    def __setstate__(self, state):
        super(RSAM_OCNN, self).__setstate__(state)

    @staticmethod
    def _as_matrix(t):
        """View ``(num_filters, *rest)`` as ``(prod(rest), num_filters)``, column j = filter j."""
        # A plain reshape to (prod(rest), num_filters) keeps flat memory order,
        # so its columns would NOT be filters; the transpose is required.
        return t.reshape(t.shape[0], -1).t().contiguous()

    @staticmethod
    def _from_matrix(m, shape):
        """Invert ``_as_matrix``: map a ``(prod(rest), num_filters)`` matrix back to ``shape``."""
        return m.t().reshape(shape)

    def _work_copies(self, p, weight_decay):
        """Return matrix-view clones ``(theta, grad)`` of ``p``, with weight decay folded in.

        Both clones are moved to CPU when the gradient is on CUDA and
        ``self.convert_cuda_to_cpu`` is set.
        """
        # Clone before viewing so the matrices never alias p.data / p.grad.
        cur_theta = self._as_matrix(p.data.clone())
        d_p = self._as_matrix(p.grad.data.clone())
        if d_p.is_cuda and self.convert_cuda_to_cpu:
            d_p = d_p.cpu()
            cur_theta = cur_theta.cpu()
        if weight_decay != 0:
            d_p.add_(cur_theta, alpha=weight_decay)
        return cur_theta, d_p

    @torch.no_grad()
    def ascent_step(self, zero_grad=False):
        """Move every parameter with a gradient to its sharpness-aware perturbation.

        On the matrix view, with ``g = proj(theta, grad)`` and
        ``Z = sqrt(|theta|)``, the direction is
        ``proj(theta, |theta| * g / ||Z * g||)``, normalised per parameter. The
        parameter becomes ``retr(theta, rho * direction)``. The unperturbed
        value is saved in ``self.state[p]["old_p"]``.

        Args:
            zero_grad: if True, clear gradients afterwards.

        Raises:
            NotImplementedError: if a group with gradients has ``manifold=None``.
        """
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            rho = group['rho']
            proj, retr = group['proj'], group['retr']
            for p in group['params']:
                if p.grad is None:
                    continue
                if group['manifold'] is None:
                    raise NotImplementedError('Supposed to use Manifold Here')
                cur_theta, d_p = self._work_copies(p, weight_decay)
                riemann_grad = proj(cur_theta, d_p)
                self.state[p]["old_p"] = p.data.clone()

                z = cur_theta.abs().pow(0.5)
                vect = z * riemann_grad
                u = vect / (torch.norm(vect) + 1e-12)
                retract_dir = proj(cur_theta, z * u)
                # '+rho': the perturbation must climb the loss; retr moves along its
                # argument (descent_step passes -lr*grad), so '-rho' went downhill.
                theta_hat = retr(cur_theta, rho * retract_dir)
                p.data.copy_(self._from_matrix(theta_hat, p.shape))
        if zero_grad:
            self.zero_grad()

    @torch.no_grad()
    def descent_step(self, zero_grad=False):
        """Restore the saved parameters and retract along ``-lr * proj(theta, grad)``.

        The gradient used is the one currently stored on each parameter, which
        should have been computed at the perturbed point.

        Args:
            zero_grad: if True, clear gradients afterwards.

        Raises:
            NotImplementedError: if a group with gradients has ``manifold=None``.
        """
        for group in self.param_groups:
            weight_decay = group['weight_decay']
            proj, retr = group['proj'], group['retr']
            for p in group["params"]:
                if p.grad is None:
                    continue
                if group['manifold'] is None:
                    raise NotImplementedError('Supposed to use Manifold Here')
                p.data = self.state[p]["old_p"]  # back to theta from the perturbed point
                cur_theta, d_p = self._work_copies(p, weight_decay)
                riemann_grad = proj(cur_theta, d_p)
                new_theta = retr(cur_theta, -group['lr'] * riemann_grad)
                # copy_ transfers to p's own device; the former .cuda() targeted
                # the current default CUDA device instead.
                p.data.copy_(self._from_matrix(new_theta, p.shape))
        if zero_grad:
            self.zero_grad()
