import numpy as np
import pandas as pd
import torch
from joblib import Parallel, delayed
from scipy.spatial.distance import pdist
import pytorch_lightning as pl
from sklearn.datasets import make_spd_matrix
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score


def gen_params(A_dim, Z_dim, seed=1):
    """Draw a reproducible random parameter set.

    Args:
        A_dim: number of rows of the coefficient matrix ``M``.
        Z_dim: number of columns of ``M`` and size of ``cov_ez``.
        seed: seed for a private ``np.random.RandomState`` used for all draws.

    Returns:
        dict with ``M``, an ``(A_dim, Z_dim)`` matrix drawn uniformly from
        ``[-2, 2)``, and ``cov_ez``, a ``(Z_dim, Z_dim)`` symmetric positive
        definite matrix produced by ``make_spd_matrix``.
    """
    # A dedicated RandomState (shared by both draws) keeps the result
    # reproducible independently of the global NumPy RNG.
    random_state = np.random.RandomState(seed=seed)
    M = random_state.uniform(-2, 2, size=(A_dim, Z_dim))
    cov_ez = make_spd_matrix(Z_dim, random_state=random_state)

    return dict(M=M, cov_ez=cov_ez)


def compute_r2(predZ, Z):
    """Return how well ``Z`` is linearly explained by ``predZ``.

    Fits an ordinary least-squares map from ``predZ`` to ``Z`` and returns
    the coefficient of determination (``r2_score``) of its in-sample
    predictions.
    """
    fitted = LinearRegression().fit(X=predZ, y=Z)
    return r2_score(Z, fitted.predict(X=predZ))


def compute_mse(model, X):
    """Return the mean squared error between ``model(X)`` and ``X``.

    Args:
        model: callable applied to ``X``; its output is compared element-wise
            with ``X`` (presumably a reconstruction model — confirm with callers).
        X: ``torch.Tensor`` or ``np.ndarray``. NumPy arrays (including
            subclasses) are converted with ``to_torch`` to float32 tensors.

    Returns:
        Python float ``mean((model(X) - X) ** 2)``.
    """
    if isinstance(X, np.ndarray):
        X = to_torch(X)
    # Evaluation only: building an autograd graph for a scalar metric is wasted work.
    with torch.no_grad():
        X_pred = model(X)
    return torch.mean((X_pred - X) ** 2).item()


def sample_cov_matrix(d, k, seed, ind=False, var_diff=True):
    """Sample a random ``d x d`` covariance matrix.

    Args:
        d: dimension of the matrix.
        k: number of columns of the random factor ``W`` (only used when
            ``ind`` is False).
        seed: seed for a private ``np.random.RandomState``.
        ind: if True, start from the identity. Otherwise start from the
            unit-diagonal rescaling of ``W W^T + diag(u)``, with ``W`` drawn
            from ``randn(d, k)`` and ``u`` from ``rand(d)``.
        var_diff: if True, rescale so the diagonal holds variances drawn from
            ``uniform(0.5, 3)``; otherwise return the unit-diagonal matrix.

    Returns:
        ``(d, d)`` ndarray.
    """
    rng = np.random.RandomState(seed)

    if not ind:
        factor = rng.randn(d, k)
        raw = factor.dot(factor.T) + np.diag(rng.rand(d))
        # Normalise so that every diagonal entry equals one.
        scale = np.diag(1. / np.sqrt(np.diag(raw)))
        cov = scale @ raw @ scale
    else:
        cov = np.eye(d)

    if not var_diff:
        return cov

    variances = rng.uniform(0.5, 3, size=(d,))
    sd = np.diag(np.sqrt(variances))
    return sd @ cov @ sd


def compute_V(predZ, A):
    """Residualise ``predZ`` on ``A`` with ordinary least squares.

    Returns:
        Tuple ``(V, lr)``: ``V = predZ - lr.predict(A)`` are the residuals and
        ``lr`` is the ``LinearRegression`` fitted from ``A`` to ``predZ``.
    """
    reg = LinearRegression().fit(A, predZ)
    residuals = predZ - reg.predict(A)
    return residuals, reg


def med_sigma(x):
    """Return the median-heuristic kernel bandwidth of ``x``.

    Computes ``sqrt(median(||x_i - x_j||^2) / 2)`` over all distinct pairs of
    rows of ``x``.

    Args:
        x: array-like of shape ``(n,)`` or ``(n, p)``. Lists, pandas Series
            and other array-likes are accepted. A 1-D input is treated as
            ``n`` one-dimensional samples.

    Returns:
        The bandwidth as a NumPy float. With fewer than two samples there are
        no pairs, so the median (and the result) is NaN.
    """
    # asarray lets non-ndarray inputs take the 1-D path: lists have no
    # ``.ndim``, and pandas >= 2 rejects ``series[:, np.newaxis]``.
    x = np.asarray(x)
    if x.ndim == 1:
        x = x[:, np.newaxis]

    return np.sqrt(np.median(pdist(x, 'sqeuclidean')) * .5)


def get_med_sigma(hsic_net, data, s_z=False):
    """Median-heuristic bandwidth of the residuals ``Y - hsic_net(X)``.

    Args:
        hsic_net: callable whose output on ``X`` is subtracted from ``Y``;
            the difference is assumed to be a ``torch.Tensor``.
        data: ``(X, Y, Z)`` triple.
        s_z: if True, also return the bandwidth of ``Z``.

    Returns:
        ``med_sigma(residuals)``, or the tuple
        ``(med_sigma(residuals), med_sigma(Z))`` when ``s_z`` is True.
    """
    X, Y, Z = data
    # No gradients are needed to measure a bandwidth. ``.cpu()`` lets the
    # NumPy conversion work when the tensors live on an accelerator.
    with torch.no_grad():
        res = (Y - hsic_net(X)).detach().cpu().numpy()
    if not s_z:
        return med_sigma(res)
    if isinstance(Z, torch.Tensor):
        # Same conversion as for the residuals (handles grad/accelerator tensors).
        Z = Z.detach().cpu().numpy()
    return med_sigma(res), med_sigma(Z)


def rnorm(n):
    """Return ``n`` i.i.d. standard normal draws from the global NumPy RNG."""
    return np.random.normal(loc=0.0, scale=1.0, size=n)


def rnorm_d(n, d, mean=None, cov=None):
    """Draw ``n`` samples of a ``d``-dimensional Gaussian from the global NumPy RNG.

    Args:
        n: number of samples (rows).
        d: dimension (columns).
        mean: length-``d`` mean vector. Defaults to zeros when only ``cov``
            is given.
        cov: ``(d, d)`` covariance matrix. Defaults to the identity when only
            ``mean`` is given.

    Returns:
        ndarray of shape ``(n, d)``. With neither ``mean`` nor ``cov`` the
        samples are i.i.d. standard normal draws from ``np.random.normal``.

    Raises:
        ValueError: from ``np.random.multivariate_normal`` if the shapes of
            ``mean`` and ``cov`` disagree.
    """
    if mean is None and cov is None:
        return np.random.normal(size=(n, d))
    # Fill in whichever moment is missing so that a supplied ``cov`` is never
    # silently ignored and a lone ``mean`` does not crash.
    if mean is None:
        mean = np.zeros(d)
    if cov is None:
        cov = np.eye(d)
    return np.random.multivariate_normal(mean, cov, size=n)


def to_torch(arr):
    """Convert ``arr`` to a new float32 ``torch.Tensor``.

    ``arr`` may be anything ``np.array`` accepts (lists, ndarrays, ...). The
    data is always copied, so the returned tensor never shares memory with
    ``arr``.
    """
    # One convert-and-copy step instead of np.array(...) followed by
    # .astype(...), which copied twice. np.array (not np.asarray) is kept
    # deliberately: torch.from_numpy shares memory with its argument.
    return torch.from_numpy(np.array(arr, dtype=np.float32))


class MedianHeuristicMMR(pl.Callback):
    """Set the module's ``kernel_a`` bandwidth with the median heuristic.

    The bandwidth is computed once, when training starts, from the ``A``
    tensor of the training dataset. It is passed to
    ``pl_module.kernel_a.set_kernel_param`` and is not updated afterwards.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): not read or updated in this class; kept in case other
        # code inspects it.
        self.epoch = 0

    def on_train_start(self, trainer, pl_module):
        # Assumes the training dataset exposes ``.tensors`` as an (X, A) pair
        # (e.g. a two-tensor TensorDataset). Confirm with callers.
        _, A = trainer.train_dataloader.dataset.tensors
        s_a = med_sigma(A.cpu().detach().numpy())
        pl_module.kernel_a.set_kernel_param(s_a)


def get_trainer(callback=None, accelerator='mps', max_epochs=1):
    """Build a ``pl.Trainer`` with the progress bar on and checkpointing and model summary off.

    Args:
        callback: optional single callback to register with the trainer.
        accelerator: Lightning accelerator name. NOTE(review): the default
            ``'mps'`` only exists on Apple-silicon machines.
        max_epochs: number of training epochs.

    Returns:
        The configured ``pl.Trainer``.
    """
    # ``None`` is the Trainer's own default for ``callbacks``, so the
    # no-callback configuration is unchanged.
    callbacks = None if callback is None else [callback]
    return pl.Trainer(max_epochs=max_epochs, callbacks=callbacks,
                      enable_progress_bar=True,
                      enable_checkpointing=False,
                      enable_model_summary=False,
                      accelerator=accelerator)


def choose_lambda(model_MMR, model_baseline, n_iter, cut_off=0.5, accelerator='cpu',
                  lmbs=None):
    """Pick the largest penalty ``lmd`` whose relative MSE increase is below ``cut_off``.

    First ``model_baseline`` is trained. Then, for every candidate ``lmd``,
    ``model_MMR`` is loaded with the baseline weights, trained in its own
    joblib job, and scored by ``(mse - mse_baseline) / mse_baseline``.

    Args:
        model_MMR: Lightning module exposing ``lmd`` and ``X`` attributes and
            a ``state_dict`` compatible with ``model_baseline``.
        model_baseline: Lightning module exposing an ``X`` attribute.
        n_iter: ``max_epochs`` for every trainer.
        cut_off: maximum acceptable relative MSE increase.
        accelerator: Lightning accelerator name.
        lmbs: candidate penalties. Defaults to
            ``[1e1, 5e1, 1e2, 5e2, 1e3, 5e3]``.

    Returns:
        ``(final_lmd, lmd_df)``. ``lmd_df`` has columns ``lmd`` and
        ``rel_mse`` sorted by ``lmd`` descending. If no candidate meets the
        cut-off, ``final_lmd`` is the smallest candidate.

    Raises:
        ValueError: if ``lmbs`` is empty.
    """
    if lmbs is None:
        lmbs = [1e+1, 5e+1, 1e+2, 5e+2, 1e+3, 5e+3]
    lmbs = list(lmbs)
    if not lmbs:
        raise ValueError("lmbs must contain at least one candidate")

    trainer = get_trainer(max_epochs=n_iter, callback=MedianHeuristicMMR(), accelerator=accelerator)
    trainer.fit(model_baseline)
    mse_baseline = compute_mse(model_baseline, model_baseline.X)

    def get_rel_mse(lmd):
        # Every job mutates model_MMR. This relies on joblib's default
        # process-based backend giving each job its own copy; with a thread
        # backend the jobs would race on the shared module.
        model_MMR.load_state_dict(model_baseline.state_dict())
        model_MMR.lmd = lmd

        trainer = get_trainer(max_epochs=n_iter, callback=MedianHeuristicMMR(), accelerator=accelerator)
        trainer.fit(model_MMR)

        mse = compute_mse(model_MMR, model_MMR.X)
        return (mse - mse_baseline) / mse_baseline

    rel_mse = Parallel(n_jobs=len(lmbs))(
        delayed(get_rel_mse)(lmd) for lmd in lmbs
    )

    res_df = pd.DataFrame({'lmd': lmbs, 'rel_mse': rel_mse})
    return _select_lambda(res_df, cut_off)


def _select_lambda(res_df, cut_off):
    """Return the largest ``lmd`` with ``rel_mse < cut_off`` (else the smallest ``lmd``) and the table sorted by ``lmd`` descending."""
    lmd_df = res_df.sort_values(by='lmd', ascending=False).reset_index(drop=True)
    # NaN rel_mse values compare False, so they are never selected.
    accepted = lmd_df.lmd[lmd_df.rel_mse < cut_off]
    final_lmd = accepted.iloc[0] if len(accepted) else lmd_df.lmd.min()
    return final_lmd, lmd_df
