import numpy as np

import torch

from sklearn.decomposition import randomized_svd
from tqdm import tqdm

from explicit import ExplicitSiren, ExplicitFourierNet, ExplicitGaussNet, ExplicitPositionalReLUNet

def loss_func(input, target):
    """Return half of the mean squared difference between ``input`` and ``target``.

    Both arguments are used as torch tensors (``-`` and ``.pow`` are applied
    to them); the result is a 0-dim tensor.
    """
    residual = input - target
    return 0.5 * torch.mean(residual.pow(2))

def compute_ntk(hs, fgrads, skip_first=False):
    """Sum per-pair kernel terms built from ``hs`` and ``fgrads``.

    For every pair ``(h, fz)`` taken in lockstep from ``hs`` and ``fgrads``
    the term ``(fz @ fz.T) * (1 + h @ h.T) / len(h)`` is added to the total
    (``*`` is elementwise, ``len(h)`` is the size of ``h``'s first axis).
    Presumably ``hs`` are per-layer inputs and ``fgrads`` the matching
    output gradients -- confirm against the caller.

    Args:
        hs: iterable of 2-D tensors.
        fgrads: iterable of 2-D tensors paired with ``hs``.
        skip_first: when true, the first pair contributes nothing.

    Returns:
        The accumulated tensor, or the integer ``0`` when no pair
        contributes.
    """
    pairs = zip(hs, fgrads)
    if skip_first:
        # Drop the first pair (if there is one) without evaluating its term.
        next(pairs, None)

    total = 0
    for acts, grads in pairs:
        grad_gram = grads @ grads.T
        act_gram = acts @ acts.T
        total += grad_gram * (1 + act_gram) / len(acts)
    return total

def fast_fit_q(x, y, nbr_scale=2):
    """Weighted sum of ``y / x**2`` over the entries where ``x`` is non-zero.

    The weights are ``exp(-nbr_scale * x**2)`` normalised by their sum over
    *all* entries (including those with ``x == 0``); entries with ``x == 0``
    are then left out of the final sum so the division is always defined.

    Args:
        x: tensor of abscissae.
        y: tensor of the same shape as ``x``.
        nbr_scale: decay rate of the weighting in ``x**2``.

    Returns:
        A 0-dim tensor holding the weighted sum.
    """
    x_sq = x ** 2
    nonzero = x != 0

    weights = torch.exp(-nbr_scale * x_sq)
    weights = weights / weights.sum()

    ratios = y[nonzero] / x_sq[nonzero]
    return torch.sum(ratios * weights[nonzero])

# Sorted step indices, keyed by regime. Each list is dense over the first 20
# steps and progressively sparser afterwards. The values coincide with the
# collection schedules built in ExperimentRunner._make_train_step
# ('deterministic': 10_000-step runs, 'stochastic': 1_000-step runs).
DYN_SAMPLE_PTS = {
    'deterministic': sorted({
        *range(0, 20, 1),
        *range(20, 200, 10),
        *range(200, 1000, 50),
        *range(1000, 10_000, 1000),
    }),
    'stochastic': sorted({
        *range(0, 20, 1),
        *range(20, 1_000, 20),
    }),
}

class ExperimentRunner():
    """Train a model on a fixed dataset while recording NTK diagnostics.

    The runner keeps the training/validation tensors on ``device``, builds an
    optimizer-specific ``train_step`` from ``config`` and, at selected steps,
    stores losses, the NTK, its leading spectrum and related statistics in
    ``exp_dict``.

    Args:
        config: mapping read for ``'optimizer.lr'`` and
            ``'optimizer.choice'`` (one of ``'gd'``, ``'adam'``, ``'sgd'``,
            ``'lbfgs'``).
        model: module moved to ``device``. It is called as ``model(x)`` and,
            in ``log_step``, as ``model(x, keep_cache=True)`` and
            ``model.reverse(g0, cache)``.
        dset: object exposing ``coords`` and ``pixels`` tensors (training
            data). ``coords`` is indexed as a 2-D tensor.
        validation_dset: object exposing ``coords`` and ``pixels`` tensors.
        device: device the model and data are moved to.
        skip_first_ntk: forwarded to ``compute_ntk`` as ``skip_first``.
        train_step: only compared against ``None``; its value is otherwise
            unused. Any non-``None`` value builds the training step from
            ``config`` immediately. With ``None`` the step is built on the
            first call to ``run_exp``.
    """

    def __init__(
            self, config, model, dset, validation_dset, device, 
            skip_first_ntk=False, train_step=None
        ):
        self.config = config
        self.device = device
        self.model = model.to(device)
        self.skip_first_ntk = skip_first_ntk

        # Pairwise Euclidean distances between the rows of dset.coords.
        coords = dset.coords
        sq_dists = (coords[:, :, None] - coords.T[None, :, :]).square().sum(1)
        self.distances = torch.sqrt(sq_dists).detach().to(device)

        self.exp_dict = {
            # losses for marking phase transition
            'train_losses': [], 'val_losses': [],
            'corr_len': [],
            'ntk_min': [],
            # residuals
            'loss_field': [],
            # ntk
            'ntk': [],
            # spectral
            's': [], 'U': [],
            # components of the ntk decomposition
            'ntk_comp': [],
        }

        self.x, self.y = dset.coords.to(device), dset.pixels.to(device)
        # Seed passed to model.reverse in log_step.
        self.g0 = torch.ones_like(self.y)
        self.val_x = validation_dset.coords.to(device)
        self.val_y = validation_dset.pixels.to(device)

        if train_step is not None:
            self._make_train_step()

    def loss_func(self, input, target):
        """Return half of the mean squared difference of the two tensors."""
        return 0.5 * (input - target).pow(2).mean()

    @staticmethod
    def _dense_to_sparse_schedule(n_steps):
        """Step indices for full-batch runs: every step below 20, then every
        10 up to 200, every 50 up to 1000 and every 1000 up to ``n_steps``."""
        return (
            set(range(0, 20, 1)) | set(range(20, 200, 10))
            | set(range(200, 1000, 50)) | set(range(1000, n_steps, 1000))
        )

    def _full_batch_step(self):
        """One optimizer update on the whole training set (gd / adam)."""
        train_out = self.model(self.x)
        loss = self.loss_func(train_out, self.y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def _minibatch_epoch(self):
        """One pass over the shuffled training loader, updating per batch."""
        # Batches are slices of self.x / self.y, which already sit on
        # self.device, so no further transfer is needed.
        for batch_x, batch_y in self._train_loader:
            out = self.model(batch_x)
            loss = self.loss_func(out, batch_y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

    def _lbfgs_step(self):
        """One L-BFGS step; the closure re-evaluates the full-batch loss."""
        def closure():
            self.optimizer.zero_grad()
            out = self.model(self.x)
            loss = self.loss_func(out, self.y)
            loss.backward()
            return loss
        self.optimizer.step(closure)

    def _make_train_step(self):
        """Create the optimizer, ``train_step``, ``n_steps`` and
        ``steps_to_collect`` according to ``config['optimizer.choice']``.

        Raises:
            ValueError: if the optimizer choice is not recognized.
        """
        self.lr = self.config['optimizer.lr']
        choice = self.config['optimizer.choice']

        if choice == 'gd':
            self.optimizer = torch.optim.SGD(lr=self.lr, params=self.model.parameters())
            self.train_step = self._full_batch_step
        elif choice == 'adam':
            self.optimizer = torch.optim.Adam(lr=self.lr, params=self.model.parameters())
            self.train_step = self._full_batch_step
        elif choice == 'sgd':
            self.optimizer = torch.optim.SGD(lr=self.lr, params=self.model.parameters())
            training_dset = torch.utils.data.TensorDataset(self.x, self.y)
            self._train_loader = torch.utils.data.DataLoader(
                training_dset, batch_size=256, shuffle=True
            )
            self.train_step = self._minibatch_epoch
        elif choice == 'lbfgs':
            self.optimizer = torch.optim.LBFGS(lr=self.lr, params=self.model.parameters())
            self.train_step = self._lbfgs_step
        else:
            raise ValueError(f"{self.config['optimizer.choice']} is not a recognized optimizer")

        if choice == 'sgd':
            # A "step" here is a full pass over the loader, so fewer are run.
            self.n_steps = 1_000
            self.steps_to_collect = (
                set(range(0, 20, 1)) | set(range(20, self.n_steps, 20))
            )
        else:
            self.n_steps = 10_000
            self.steps_to_collect = self._dense_to_sparse_schedule(self.n_steps)

    def _ensure_train_step(self):
        """Build the training step from ``config`` if none exists yet."""
        if not hasattr(self, 'train_step'):
            self._make_train_step()

    def eval_step(self):
        """Return ``(train_loss, validation_loss)`` as Python floats."""
        with torch.no_grad():
            validation_out = self.model(self.val_x)
            validation_loss = self.loss_func(validation_out, self.val_y).item()

            train_out = self.model(self.x)
            train_loss = self.loss_func(train_out, self.y).item()

        return train_loss, validation_loss

    def log_step(self):
        """Append the current NTK diagnostics to ``exp_dict``.

        Records the NTK factors (``ntk_comp``), the NTK itself, its minimum
        entry, a correlation-length estimate, a rank-10 randomized SVD
        (``s``, ``U``) and the per-sample loss field.
        """
        with torch.no_grad():
            out, cache = self.model(self.x, keep_cache=True)

            _, dfdzs = self.model.reverse(self.g0, cache)
            layer_ins = [self.x] + cache['hs'][:-1]
            layer_outs = dfdzs

            # Cloned so later in-place changes to the tensors cannot alter
            # the stored snapshots.
            self.exp_dict['ntk_comp'].append({
                'ins': [t.detach().clone().cpu().numpy() for t in layer_ins],
                'outs': [t.detach().clone().cpu().numpy() for t in layer_outs],
            })

            ntk = compute_ntk(layer_ins, layer_outs, skip_first=self.skip_first_ntk)
            self.exp_dict['ntk_min'].append(ntk.min().item())

            # Correlation length: fit on log|cosine-normalised NTK| against
            # pairwise distance; 1e-5 keeps the log finite at zero entries.
            diag = torch.diag(ntk)
            cos_ntk = ntk / torch.sqrt(diag[None, :] * diag[:, None])
            flat_dist = self.distances.flatten()
            flat_log = torch.log(abs(cos_ntk) + 1e-5).flatten()
            self.exp_dict['corr_len'].append(fast_fit_q(flat_dist, flat_log).item())

            ntk = ntk.cpu().detach().numpy()

        # Leading 10 singular values and left singular vectors.
        U, s, _ = randomized_svd(ntk, 10, n_iter=3)
        self.exp_dict['s'].append(s)
        self.exp_dict['U'].append(U)

        self.exp_dict['ntk'].append(ntk)

        # Per-sample residual energy (not averaged).
        loss_field = 0.5 * (self.y - out) ** 2
        self.exp_dict['loss_field'].append(loss_field.detach().clone().cpu().numpy())

    def run_exp(self):
        """Run the full training loop and finalise the recorded arrays.

        Each iteration evaluates the losses first (so entry ``i`` describes
        the model before its ``i``-th update), logs diagnostics when ``i`` is
        in ``steps_to_collect`` and then performs one training step.
        Afterwards the spectral/statistic lists become NumPy arrays,
        ``steps_to_collect`` becomes a sorted list and the model is moved to
        the CPU.
        """
        # Previously a runner built with the default train_step=None failed
        # here with AttributeError because no step had been configured.
        self._ensure_train_step()

        for step in tqdm(range(self.n_steps)):
            train_loss, validation_loss = self.eval_step()
            self.exp_dict['train_losses'].append(train_loss)
            self.exp_dict['val_losses'].append(validation_loss)

            if step in self.steps_to_collect:
                self.log_step()

            self.train_step()

        for key in ('s', 'U', 'corr_len', 'ntk_min'):
            self.exp_dict[key] = np.array(self.exp_dict[key])

        self.steps_to_collect = sorted(self.steps_to_collect)

        self.model = self.model.to('cpu')

        
def build_model(config):
    """Instantiate the network selected by ``config['model.architecture']``.

    Recognized architectures are ``'SIREN'``, ``'Fourier'``, ``'Gauss'`` and
    ``'ReluPE'``. Every model is built with 2 input features, 1 output
    feature and the hidden width/depth given by
    ``config['model.hidden_features']`` / ``config['model.hidden_layers']``.

    Raises:
        ValueError: if the architecture name is not recognized.
    """
    width = config['model.hidden_features']
    depth = config['model.hidden_layers']
    arch = config['model.architecture']

    # Arguments common to every architecture.
    shared = dict(
        in_features=2, hidden_features=width, hidden_layers=depth, out_features=1
    )

    if arch == 'SIREN':
        return ExplicitSiren(
            **shared, outermost_linear=True, first_omega_0=30, hidden_omega_0=30
        )
    if arch == 'Fourier':
        return ExplicitFourierNet(**shared, sigma=10)
    if arch == 'Gauss':
        return ExplicitGaussNet(**shared, sigma=10)
    if arch == 'ReluPE':
        return ExplicitPositionalReLUNet(**shared, sidelength=64)

    raise ValueError(f"{config['model.architecture']} is not a recognized architecture")


