import torch
import copy
import pytorch_lightning as pl
import numpy as np

class ContNet(pl.LightningModule):
    def __init__(self, logcontvar0: float, cont_lr: float, cont_reg: float, warmup_epochs: int):
        """Set up the log "continuation variance" and its hyperparameters.

        Args:
            logcontvar0: Initial value of ``logcontvar``. If NaN, the feature is
                disabled: ``logcontvar`` is stored as a plain (non-trainable)
                tensor and ``perturb_params`` / ``contvar_grad`` skip their work.
            cont_lr: Stored as ``self.cont_lr``; presumably the learning rate for
                ``logcontvar``, used outside this excerpt (TODO confirm).
            cont_reg: Weight of the regularisation term that ``contvar_grad``
                adds to the ``logcontvar`` gradient.
            warmup_epochs: ``contvar_grad`` only assigns the ``logcontvar``
                gradient once ``current_epoch`` exceeds this value.
        """
        super().__init__()

        # Cast to float: torch.tensor() of a Python int is int64, and an integer
        # tensor cannot be wrapped in a Parameter (it cannot require grad).
        logcontvar0 = float(logcontvar0)
        # Record the initial value in both branches so the attribute always exists.
        self.logcontvar0 = logcontvar0

        if np.isnan(logcontvar0):
            # Disabled: plain tensor, not registered as a parameter or buffer.
            self.logcontvar = torch.tensor(logcontvar0)
        else:
            self.logcontvar = torch.nn.Parameter(torch.tensor(logcontvar0))

        self.cont_lr = cont_lr
        self.warmup_epochs = warmup_epochs
        self.cont_reg = cont_reg

        # Manual optimisation: the logcontvar gradient is assigned by hand in contvar_grad.
        self.automatic_optimization = False
    
    def gen_perturbed_params(self, ref_params: dict, contvar: 'torch.Tensor | float'):
        """Build a Gaussian perturbation of ``ref_params``.

        Each floating-point entry ``p`` whose name does not start with
        ``'logcontvar'`` becomes ``p - sqrt(contvar) * eps``, where
        ``eps ~ N(0, I)`` has the same shape, dtype and device as ``p``.
        ``logcontvar*`` entries and non-floating-point entries (e.g. integer
        buffers such as step counters) are passed through unchanged and receive
        no noise sample.

        Args:
            ref_params: Mapping of name -> tensor, e.g. ``self.state_dict()``.
            contvar: Perturbation variance (tensor or Python float).

        Returns:
            ``(rand_samp, perturb_params)``: the standard-normal samples keyed by
            name (perturbed entries only) and the complete perturbed mapping.
        """
        # Loop-invariant: compute the standard deviation once.
        stdev = torch.sqrt(torch.as_tensor(contvar))

        rand_samp = {}
        perturb_params = {}
        for pname, value in ref_params.items():
            if str(pname).startswith('logcontvar') or not torch.is_floating_point(value):
                # Adding float noise to an integer tensor would be truncated on load.
                perturb_params[pname] = value
                continue
            rand_samp[pname] = torch.randn_like(value)
            perturb_params[pname] = value - stdev * rand_samp[pname]

        return rand_samp, perturb_params

    def perturb_params(self):
        """Load a randomly perturbed copy of the current state into the module.

        Returns:
            ``(rand_samp, ref_params)``. ``ref_params`` is an independent snapshot
            of ``self.state_dict()`` taken before perturbing, so it can be passed
            to ``load_state_dict`` to restore the original values. ``rand_samp``
            holds the noise samples from ``gen_perturbed_params``, or ``None``
            when ``logcontvar`` is NaN (nothing is perturbed in that case).
        """
        # state_dict() returns detached tensors that share storage with the live
        # parameters/buffers, and load_state_dict() copies into that storage in
        # place. Without a deep copy, loading the perturbed values below would
        # silently overwrite ref_params too.
        ref_params = copy.deepcopy(self.state_dict())

        if torch.isnan(self.logcontvar):
            return None, ref_params

        # Variance, with the log-variance capped at logcontvar0 + 2.
        contvar = torch.exp(torch.clamp(self.logcontvar, max=self.logcontvar0 + 2.0))

        self.log('logcontvar', self.logcontvar)
        self.log('contvar', contvar)

        # Generate and load the perturbed parameters. contvar is detached so the
        # perturbed tensors do not build an autograd graph back to logcontvar;
        # its gradient is assigned manually in contvar_grad.
        rand_samp, perturb_params = self.gen_perturbed_params(ref_params, contvar.detach())
        self.load_state_dict(perturb_params)

        return rand_samp, ref_params

    def contvar_grad(self, rand_samp: dict, loss: torch.Tensor):
        """Manually assign the gradient of ``logcontvar``.

        The estimate is ``(-d/2 + ||eps||^2 / 2) * loss``, where ``eps`` are the
        noise samples in ``rand_samp`` and ``d`` is their total number of
        entries. Once ``current_epoch > warmup_epochs``, the term
        ``cont_reg * n_params * exp(logcontvar - logcontvar0)`` is added and the
        sum is written to ``self.logcontvar.grad``. The raw estimate is logged
        every call. Nothing happens when ``logcontvar`` is NaN.

        Args:
            rand_samp: Noise samples returned by ``perturb_params``.
            loss: Loss evaluated at the perturbed parameters; only its value is used.
        """
        if not torch.isnan(self.logcontvar):
            # Parameter count excluding the scalar logcontvar itself (regulariser scale).
            n_params = sum(p.numel() for p in self.parameters()) - 1
            # Number of noise entries actually drawn. The -d/2 term must match the
            # dimension of the sum of squares below; this differs from n_params
            # when the state dict also contains (perturbed) buffers.
            n_noise = sum(eps.numel() for eps in rand_samp.values())

            # compute gradient of logcontvar
            logcontvar_grad = -n_noise/2.0
            for eps in rand_samp.values():
                logcontvar_grad += eps.pow(2.0).sum()/2.0
            logcontvar_grad *= loss.detach()
            # NOTE(review): this ignores the clamp in perturb_params; once
            # logcontvar exceeds logcontvar0 + 2 the sampled variance no longer
            # depends on it.

            # Epochs 0..warmup_epochs (inclusive) leave the gradient unassigned.
            if self.current_epoch > self.warmup_epochs:
                with torch.no_grad():
                    # no_grad: the assigned .grad must not carry an autograd graph.
                    reg_grad = self.cont_reg*n_params*torch.exp(self.logcontvar - self.logcontvar0)
                    self.logcontvar.grad = (logcontvar_grad + reg_grad).to(self.logcontvar)
            self.log('logcontvar_grad', logcontvar_grad)