import numpy as np
import pandas as pd

import torch

from tqdm import tqdm

import sklearn.metrics as skmetrics

#TODO: replace with hydra
from explicit import ExplicitSiren, ExplicitFourierNet, ExplicitGaussNet, ExplicitPositionalReLUNet
from order_params import CorrLenComputer, sqrt_trace_covar_loss_grad, norm_mean_loss_grad
from approximation import gaussian_approx, mag_hessian
import linalg 
import data_utils


def loss_func(input, target):
    """Return half of the mean squared error between ``input`` and ``target``."""
    squared_error = (input - target) ** 2
    return squared_error.mean() * 0.5

def compute_dfdx_mag(x, model):
    """Return the squared magnitude of the model's input gradient per sample.

    A tensor of ones with one column per model output is pushed back through
    ``model.reverse`` using the cache produced by ``model.forward``; the first
    value returned by ``reverse`` is squared and summed over its last axis.

    Args:
        x: Input tensor; its first dimension sets the number of rows of the
            seed gradient.
        model: Object exposing ``net[-1].out_features``,
            ``forward(coords, keep_cache=True)`` and ``reverse(g0, cache)``.

    Returns:
        A NumPy array holding the summed squares, moved to the CPU.
    """
    n_outputs = model.net[-1].out_features
    # Seed gradient of ones with the same dtype/device as the inputs.
    seed_grad = x.new(x.shape[0], n_outputs).fill_(1.)

    _, cache = model.forward(x, keep_cache=True)
    input_grad, _ = model.reverse(seed_grad, cache)

    squared_mag = (input_grad ** 2).sum(-1)
    return squared_mag.cpu().data.numpy()

def log_log_grad(xt, t):
    """Return the numerical derivative of ``log(1 + xt)`` w.r.t. ``log(1 + t)``.

    Both arguments are shifted by one before taking the logarithm, so zero
    entries are finite. The derivative is evaluated with ``np.gradient``.
    """
    log_values = np.log(1 + xt)
    log_times = np.log(1 + t)
    return np.gradient(log_values, log_times)

def compute_ntk(hs, fgrads, skip_first=False):
    """Accumulate a kernel matrix from paired per-layer arrays.

    For every ``(h, g)`` pair taken in lockstep from ``hs`` and ``fgrads`` the
    term ``(g @ g.T) * (1 + h @ h.T) / len(h)`` is added to the running total.

    Args:
        hs: Iterable of 2-D arrays/tensors (one per layer).
        fgrads: Iterable of 2-D arrays/tensors paired with ``hs``.
        skip_first: When true, the first pair contributes nothing.

    Returns:
        The summed kernel, or the integer ``0`` when no pair contributes.
    """
    layer_pairs = zip(hs, fgrads)
    if skip_first:
        # Discard the first layer's pair (no-op when there are no layers).
        next(layer_pairs, None)

    kernel = 0
    for activations, out_grads in layer_pairs:
        grad_gram = out_grads @ out_grads.T
        act_gram = 1 + activations @ activations.T
        kernel += grad_gram * act_gram / len(activations)
    return kernel

def extract_weight_norms(net):
    """Return the norm of ``layer.weight`` for every layer in ``net`` as floats."""
    return [layer.weight.data.norm().cpu().item() for layer in net]

def extract_ntk(model, x, y):
    """Compute the NTK of ``model`` on ``x`` together with its ingredients.

    Args:
        model: Callable as ``model(x, keep_cache=True)`` returning
            ``(output, cache)`` and exposing ``reverse(g0, cache)``; the cache
            must contain a list under ``'hs'``.
        x: Model inputs.
        y: Only its shape/dtype/device are used, to build the all-ones seed
            gradient.

    Returns:
        ``(ntk, comp)`` where ``ntk`` is the result of ``compute_ntk`` and
        ``comp`` holds NumPy copies of the per-layer inputs (``'ins'``) and of
        the second value returned by ``model.reverse`` (``'outs'``).
    """
    seed_grad = torch.ones_like(y)

    with torch.no_grad():
        _, cache = model(x, keep_cache=True)
        _, layer_out_grads = model.reverse(seed_grad, cache)

        # Each layer's input is the raw input followed by all cached
        # entries of 'hs' except the last one.
        layer_inputs = [x] + cache['hs'][:-1]

        ntk = compute_ntk(layer_inputs, layer_out_grads)

        comp = {
            'ins': [t.data.clone().cpu().numpy() for t in layer_inputs],
            'outs': [t.data.clone().cpu().numpy() for t in layer_out_grads],
        }

    return ntk, comp

def make_steps_to_collect(n_epoch=10_000, n_steps_to_collect=100):
    """Build a strictly increasing, roughly log-spaced schedule of epochs.

    ``n_steps_to_collect`` points are spaced evenly in log10 between 1 and
    ``n_epoch``; the gaps between them are rounded up to whole epochs and
    accumulated. Values not below ``n_epoch - 1`` are dropped, and the
    schedule is then bracketed with ``0`` and ``n_epoch``.

    Returns:
        Integer NumPy array starting at ``0`` and ending at ``n_epoch``.
    """
    max_exponent = np.log(n_epoch) / np.log(10)
    exponents = np.linspace(0, max_exponent, n_steps_to_collect)
    log_spaced = 10 ** exponents

    # Rounding every gap up to at least one epoch keeps the cumulative
    # schedule strictly increasing.
    whole_gaps = np.ceil(np.diff(log_spaced))
    interior = whole_gaps.cumsum().astype(int)
    interior = interior[interior < n_epoch - 1]

    return np.concatenate(([0], interior, [n_epoch]))

def compute_extra_stats(ntk, cos_ntk, CLC, R, dA):
    """Collect correlation-length and gradient-noise statistics for one snapshot.

    Args:
        ntk: NTK matrix, passed to ``CLC.fit_fort_params`` and to the
            gradient-noise helpers.
        cos_ntk: Normalized NTK tensor; binned by distance via
            ``CLC._groupby_dist`` and thresholded for the correlation area.
        CLC: Correlation-length helper providing ``dist_bins``,
            ``_groupby_dist``, ``fit_fort_params`` and ``fit_exp_params``.
        R: Residual tensor.
        dA: Area represented by one entry of a ``cos_ntk`` row.

    Returns:
        Dict mapping stat names to values. The fitted parameters must include
        a ``C_inf`` entry, which sets the correlation-area threshold.
    """
    dist_bins = CLC.dist_bins
    binned_cos_ntk, _ = CLC._groupby_dist(cos_ntk)

    stats = dict(CLC.fit_fort_params(ntk))
    stats.update(CLC.fit_exp_params(dist_bins, binned_cos_ntk))

    # Threshold halfway between the maximum (1) and the fitted plateau.
    half_max = 0.5 * (1 + stats[r'$C_{\infty}$'])

    # Gradient noise
    stats[r'$\mu_{\theta}$'] = norm_mean_loss_grad(R, ntk)
    stats[r'$\sigma_{\theta}$'] = sqrt_trace_covar_loss_grad(R, ntk)

    # Residual moments
    stats[r'$\mu_{r}$'] = R.mean().item()
    stats[r'$\sigma_{r}$'] = R.std().item()

    # Correlation area: count entries above the threshold in each row, turn
    # the count into an area, and report the mean radius of the equal-area
    # disc.
    above_half_max = (cos_ntk >= half_max).float()
    corr_areas = above_half_max.sum(-1) * dA
    stats[r'$\xi_{FWHM}$'] = torch.sqrt(corr_areas / np.pi).mean().item()

    return stats

def cos_x_y(x, y):
    """Return the cosine of the angle between ``x`` and ``y`` as a Python float.

    Norms are taken over all elements while the dot product reduces only the
    last axis, so the result is a scalar only for 1-D inputs.
    """
    norm_product = torch.sqrt((x ** 2).sum()) * torch.sqrt((y ** 2).sum())
    inner = (x * y).sum(-1)
    return (inner / norm_product).item()

class Trainer():
    """Train a model with full-batch SGD while logging NTK diagnostics.

    At a roughly log-spaced set of epochs (``self.steps_to_collect``) the
    trainer snapshots the empirical NTK on the training coordinates, appends
    derived scalars to ``self.stats`` and raw arrays to ``self.exp_dict``.
    ``run`` drives the whole procedure and returns a small summary report.

    Required config keys: ``seed``, ``dset``, ``optimizer.lr``, ``n_epoch``,
    ``n_steps_to_collect``, ``n_components``, ``model.architecture``,
    ``model.hidden_features``, ``model.hidden_layers``.
    Optional keys (default): ``sidelength`` (64), ``do_gauss_approximation``
    (True), ``log_ntk`` (True), ``model.architecture.scale`` (10).
    """

    def __init__(self, config, device, kernel_sigma_2=0.1):
        """Build the model, dataset, optimizer and empty logging containers.

        Args:
            config: Mapping supporting ``[]`` and ``.get`` with the keys listed
                in the class docstring. ``config['dset']`` is either an image
                name (a dataset is built from it) or a ready dataset object.
            device: Torch device the model and data tensors are moved to.
            kernel_sigma_2: Denominator of the exponent in the reference
                kernels ``K_X`` and ``K_Y``.
        """
        # Seed before the model is built so the initialization is reproducible.
        torch.manual_seed(config['seed'])
        self.config = config
        self.device = device
        self.model = self.build_model(config)
        self.model = self.model.to(self.device)

        sidelength = config.get('sidelength', 64)
        self.do_gauss_approximation = config.get('do_gauss_approximation', True)
        self.log_ntk = config.get('log_ntk', True)

        if isinstance(config['dset'], str):
            self.dset = data_utils.ImageDataset(im_name=config['dset'], sidelength=sidelength)
        else:
            self.dset = config['dset']

        self.optimizer = torch.optim.SGD(lr=config['optimizer.lr'], params=self.model.parameters())
        self.loss_func = loss_func
        self.n_epoch = config['n_epoch']
        self.steps_to_collect = make_steps_to_collect(n_epoch=self.n_epoch, n_steps_to_collect=config['n_steps_to_collect'])

        # Training split.
        self.xt = self.dset.dset.coords.to(device)
        self.yt = self.dset.dset.pixels.to(device)

        # Validation split.
        self.xv = self.dset.super_dset.coords.to(device)
        self.yv = self.dset.super_dset.pixels.to(device)

        self.CLC = CorrLenComputer(self.xt, 50)

        # Raw per-snapshot arrays.
        self.exp_dict = {
            'H' : [], 'a2' : [], 'D' : [], 'res' : [], 'U' : [],
            'ntk' : [], 'ntk_comp' : [],
            #r"$||dfdx||^2$" : [], 'H2' : []
        }

        self.U = []

        # Scalar per-snapshot statistics; turned into a DataFrame by train_end.
        self.stats = {
            r'$\min C_{NTK}$' : [],
            r'$\text{CKA}(K_X, K_{NTK})$' : [],
            r'$\text{CKA}(K_Y, K_{NTK})$' : [],
            r'$L_{\text{train}}$' : [],
            r'$L_{\text{eval}}$' : [],
            r'epoch' : [],
            r'$\mu_{\theta}$' : [],
            r'$\sigma_{\theta}$' : [],
            r'$\mu_{r}$' : [],
            r'$\sigma_{r}$' : [],
            r'$\xi_{corr}$' : [],
            r'$\xi_{FWHM}$' : [],
            r'$C_{\infty}$' : [],
            r'$\xi_{FORT}$' : [],
            r'$\alpha_{FORT}$' : [],
            r'$\beta_{FORT}$' : [],
            'edge_alignment' : [],
            'residual_alignment' : [],
            'alignment_with_init' : [],
            'cos_K1_u0' : [],
        }
        # One column per tracked NTK singular value. The raw f-string keeps
        # the backslash literal without relying on an invalid escape sequence.
        for i in range(config['n_components']):
            self.stats[rf'$\lambda_{i}$'] = []

        #for i in range(config['model.hidden_layers'] + 2):
        #    self.stats[f'$||w_{i}||$'] = []

        self.n_components = config['n_components']

        # Reference kernels on the targets (K_Y) and on the pairwise
        # distances held by the correlation-length helper (K_X).
        self.kernel_sigma_2 = kernel_sigma_2
        self.K_Y = torch.exp( -(self.yt - self.yt.T) ** 2 / (self.kernel_sigma_2) )
        self.K_X = torch.exp(-self.CLC.distances ** 2 / (self.kernel_sigma_2) )

        # Area element used by the correlation-area statistic.
        deltax = 2 / self.dset.sidelength
        self.dA = deltax ** 2

    def build_model(self, config):
        """Instantiate the architecture named by ``config['model.architecture']``.

        Supported names are ``'SIREN'``, ``'Fourier'``, ``'Gauss'`` and
        ``'ReluPE'``. Every model takes 2 input features and produces 1 output.

        Raises:
            ValueError: If the architecture name is not recognized.
        """
        #TODO:
        #Add sigma to config!!
        hidden_features = config['model.hidden_features']
        hidden_layers = config['model.hidden_layers']
        scale = config.get('model.architecture.scale', 10)
        architecture = config['model.architecture']

        if architecture == 'SIREN':
            model = ExplicitSiren(
                in_features=2, hidden_features=hidden_features, hidden_layers=hidden_layers,
                out_features=1, outermost_linear=True, first_omega_0=3 * scale,
            )
        elif architecture == 'Fourier':
            model = ExplicitFourierNet(
                in_features=2, hidden_features=hidden_features, hidden_layers=hidden_layers,
                out_features=1, sigma=scale
            )
        elif architecture == 'Gauss':
            model = ExplicitGaussNet(
                in_features=2, hidden_features=hidden_features, hidden_layers=hidden_layers,
                out_features=1, sigma=scale
            )
        elif architecture == 'ReluPE':
            # NOTE(review): sidelength is hard-coded to 64 here while the
            # dataset honours config['sidelength'] -- confirm this is intended.
            model = ExplicitPositionalReLUNet(
                in_features=2, hidden_features=hidden_features, hidden_layers=hidden_layers,
                out_features=1, sidelength=64
            )
        else:
            raise ValueError(f"{config['model.architecture']} is not a recognized architecture")

        return model

    def train_end(self):
        """Finalize ``self.stats`` into a DataFrame and add derived columns.

        Adds the log-log time derivatives of the train and eval losses and,
        when Gaussian-approximation snapshots were recorded, the
        ``norm_spatial_mean_log_D`` column.
        """
        for k, v in self.stats.items():
            self.stats[k] = np.array(v)

        self.stats[r'$\dot{L}_{\text{train}}$'] = log_log_grad(
            self.stats[r'$L_{\text{train}}$'], self.stats['epoch']
        )

        self.stats[r'$\dot{L}_{\text{eval}}$'] = log_log_grad(
            self.stats[r'$L_{\text{eval}}$'], self.stats['epoch']
        )

        self.stats = pd.DataFrame(self.stats)

        # 'D' and 'a2' are only filled when the Gaussian approximation is
        # enabled; with empty lists the indexing below would raise, so skip
        # the derived column in that case.
        D_snapshots = self.exp_dict['D']
        a2_snapshots = self.exp_dict['a2']
        if D_snapshots and a2_snapshots:
            D = np.array(D_snapshots)
            a2 = np.array(a2_snapshots)
            # Norm spatial mean log D
            # (assumes D stacks to 3-D and a2 to 2-D -- TODO confirm)
            spatial_mean_log_D = (D / a2[:, :, None]).mean(1)
            norm_spatial_mean_log_D = (spatial_mean_log_D ** 2).sum(-1)
            self.stats['norm_spatial_mean_log_D'] = norm_spatial_mean_log_D

    def collect(self, epoch):
        """Snapshot NTK-based statistics for the current model state.

        Expects the residual for this snapshot to already be the last entry
        of ``self.exp_dict['res']``. The first call must use ``epoch == 0``
        so that the reference eigenvector ``self.u0`` is defined.

        Args:
            epoch: Value recorded in the ``epoch`` column for this snapshot.
        """
        _ntk, _comp = extract_ntk(self.model, self.xt, self.yt)
        U, s = linalg.randomized_svd(_ntk, n_components=self.n_components)
        _u0 = U[:, 0]
        if epoch == 0:
            # Keep the leading singular vector at initialization as reference.
            self.u0 = _u0
        self.exp_dict['U'].append(U.cpu().data.numpy())
        R = torch.FloatTensor(self.exp_dict['res'][-1]).to(self.device)

        if self.log_ntk:
            self.exp_dict['ntk'].append(_ntk.cpu().data.numpy())
            self.exp_dict['ntk_comp'].append(_comp)
            #self.exp_dict[r"$||dfdx||^2$"].append(compute_dfdx_mag(self.xt, self.model))
            #self.exp_dict["H2"].append(mag_hessian(self.xt, self.model))

        if self.do_gauss_approximation:
            infos = gaussian_approx(self.xt, self.model)
            for k, v in infos.items():
                self.exp_dict[k].append(v)

        #weight_norms = extract_weight_norms(self.model.net)

        _cos_ntk = linalg.cos_normalize_arr(_ntk)

        #Update stats
        self.stats[r'epoch'].append(epoch)
        self.stats[r'$\min C_{NTK}$'].append(_cos_ntk.min().item())
        self.stats[r'$\text{CKA}(K_X, K_{NTK})$'].append(linalg.CKA(self.K_X, _ntk).item())
        self.stats[r'$\text{CKA}(K_Y, K_{NTK})$'].append(linalg.CKA(self.K_Y, _ntk).item())
        self.stats['edge_alignment'].append(skmetrics.roc_auc_score(self.dset.edges, _u0.abs().cpu().numpy()))
        # NOTE(review): this compares the residual with the epoch-0
        # eigenvector (self.u0), not the current one (_u0) -- confirm intent.
        self.stats['residual_alignment'].append(abs(cos_x_y(self.u0, R)))

        # Eigenvector change
        self.stats['alignment_with_init'].append(abs(cos_x_y(self.u0, _u0)))

        # Mean alignment
        self.stats['cos_K1_u0'].append(abs(cos_x_y(_u0, _ntk.sum(-1))))

        for i in range(self.n_components):
            self.stats[rf'$\lambda_{i}$'].append(s[i].item())

        #for i in range(self.config['model.hidden_layers'] + 2):
        #    self.stats[f'$||w_{i}||$'].append(weight_norms[i])

        #This needs to be fixed.
        corr_stats = compute_extra_stats(_ntk, _cos_ntk, self.CLC, R, self.dA)
        for k, v in corr_stats.items():
            self.stats[k].append(v)

    def run(self):
        """Train for ``self.n_epoch`` epochs and return the summary report.

        On collection epochs the residuals, NTK statistics and train loss are
        taken from the forward pass *before* the optimizer step, while the
        eval loss is measured *after* it. A final snapshot, labelled
        ``n_epoch + 1``, is taken once training has finished.

        Returns:
            The dict produced by ``gen_report``.
        """
        # A set gives O(1) membership; testing against the ndarray would
        # rescan the whole schedule on every epoch.
        collect_steps = set(np.asarray(self.steps_to_collect).tolist())

        for epoch in tqdm(range(self.n_epoch)):
            should_collect = epoch in collect_steps
            train_out = self.model(self.xt)

            if should_collect:
                self.exp_dict['res'].append((self.yt - train_out).cpu().data.numpy().squeeze(-1))
                self.collect(epoch)

            loss = self.loss_func(train_out, self.yt)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if should_collect:
                self.stats[r'$L_{\text{train}}$'].append(loss.item())
                with torch.no_grad():
                    validation_out = self.model(self.xv)
                    validation_loss = self.loss_func(validation_out, self.yv).item()
                    self.stats[r'$L_{\text{eval}}$'].append(validation_loss)

        # End of training: evaluate with the same loss as the training loop.
        with torch.no_grad():
            validation_out = self.model(self.xv)
            validation_loss = self.loss_func(validation_out, self.yv).item()
            self.stats[r'$L_{\text{eval}}$'].append(validation_loss)

            train_out = self.model(self.xt)
            train_loss = self.loss_func(train_out, self.yt).item()
            self.stats[r'$L_{\text{train}}$'].append(train_loss)

            self.exp_dict['res'].append((self.yt - train_out).cpu().data.numpy().squeeze(-1))
        self.collect(self.n_epoch + 1)

        self.train_end()
        return self.gen_report()

    def gen_report(self):
        """Summarize the run; call after ``train_end`` has built the stats table.

        Returns:
            Dict with ``min_eval_loss`` (lowest recorded eval loss),
            ``t_transition`` (epoch at which the log-log eval-loss slope is
            most negative) and ``best_epoch`` (epoch of the lowest eval loss).
        """
        eval_loss = self.stats[r'$L_{\text{eval}}$']
        epochs = self.stats['epoch']

        steepest_row = self.stats[r'$\dot{L}_{\text{eval}}$'].argmin()
        # argmin yields a row position; map it through the epoch column so
        # that 'best_epoch' is an epoch, like 't_transition'.
        best_row = eval_loss.argmin()

        report = {
            'min_eval_loss' : float(eval_loss.min()),
            't_transition' : int(epochs[steepest_row]),
            'best_epoch' : int(epochs[best_row])
        }

        return report