# DISABLE_SMART_IME
import matplotlib.pyplot as plt
import hydra
from omegaconf import OmegaConf
import os
from market.specifications.base import CMEBase
from market.specifications.neural_phi import NeuralPhi
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset
import torch
from utils import set_seed
from datasets import Dataset

class NCME(CMEBase):
    """Conditional mean embedding whose input features come from a ``NeuralPhi`` network.

    The attributes ``X``, ``Y``, ``n``, ``lambd``, ``device``, ``cfg`` and
    ``kernel_y`` are read here but not defined in this class; presumably they
    are set up by ``CMEBase.__init__`` -- confirm against the base class.
    """

    def __init__(self, cfg, X, Y, **kwargs):
        """Initialise the base estimator and build the feature network from the same arguments."""
        super().__init__(cfg, X, Y, **kwargs)
        self.phi = NeuralPhi(cfg, **kwargs)

    def generate(self):
        """Cache the output Gram matrix and either fit or directly estimate.

        Returns:
            A generator (see ``_generate_training``) when ``cfg['fit']`` is
            truthy; otherwise ``None`` after a single call to ``estimate``.
        """
        self.KY = self.kernel_y(self.Y, self.Y)
        if not self.cfg['fit']:
            # No training requested: solve once with the current features.
            self.estimate()
            return None
        return self._generate_training()

    def estimate(self):
        """Solve the ridge system for the current features and store it in ``self.Q``.

        With ``Phi`` the transposed feature matrix, this stores
        ``Q = Phi^T (Phi Phi^T + n * lambd * I)^{-1}``.
        """
        features = self.phi(self.X)   # presumably (n, d) -- one feature row per sample
        feature_dim = features.T.shape[0]
        self.BX = 1
        ridge = self.n * self.lambd * torch.eye(feature_dim, device=self.device)
        regularised_gram_inv = torch.linalg.inv(features.T @ features + ridge)
        self.Q = features @ regularised_gram_inv

    def _generate_training(self):
        """Yield once per training state after re-estimating ``Q`` for that state.

        Each state produced by ``phi.fit_yield`` is loaded back into ``phi``
        before ``estimate`` runs, so the caller sees ``self.Q`` matching that
        state whenever control is handed back.
        """
        for snapshot in self.phi.fit_yield(self.X, self.KY):
            self.phi.load_state_dict(snapshot)
            self.estimate()
            yield

class CME(CMEBase):
    """Kernel conditional mean embedding with a fixed ridge parameter.

    ``X``, ``Y``, ``n``, ``device``, ``kernel_x`` and ``kernel_y`` are read
    here but not defined in this class; presumably ``CMEBase`` provides them
    -- confirm against the base class.
    """

    def __init__(self, cfg, X, Y, **kwargs):
        """Initialise the base estimator, then override the regulariser with 0.001."""
        super().__init__(cfg, X, Y, **kwargs)
        self.lambd = 0.001

    def generate(self, *args, **kwargs):
        """Compute and cache the Gram matrices and the regularised inverse.

        Stores ``KX``, ``KY`` and ``W = (KX + n * lambd * I)^{-1}`` on the
        instance. Positional and keyword arguments are accepted but ignored.
        Returns ``None``.
        """
        self.KX = self.kernel_x(self.X, self.X)
        self.KY = self.kernel_y(self.Y, self.Y)
        ridge = self.n * self.lambd * torch.eye(self.n, device=self.device)
        self.W = torch.linalg.inv(self.KX + ridge)
        # A trace norm, tr(W KY W KX), used to be computed here; it is disabled.

@torch.no_grad()
def MMD_NCME(X, Y, ncme: NCME):
    """Score a fitted NCME on held-out pairs with an MMD-style discrepancy.

    Writing hy_j for the training outputs stored on the model, q_j for row j
    of ``ncme.Q``, phi for the feature network and psi for the feature map
    behind ``kernel_y``, the quantity expanded here is

        || sum_i ( sum_j psi(hy_j) * q_j phi(x_i) )  -  sum_i psi(y_i) ||^2
          =   sum_{j,j'} kY(hy_j, hy_j') * a_j * a_j'     (prediction vs prediction)
            + sum_{i,i'} kY(y_i, y_i')                     (target vs target)
            - 2 * sum_{i'} sum_j kY(hy_j, y_i') * a_j      (prediction vs target)

    where a_j = sum_i q_j phi(x_i).

    Args:
        X: held-out inputs, (n2, dx).
        Y: held-out outputs, (n2, dy).
        ncme: model providing ``Y``, ``Q``, ``KY``, ``phi``, ``kernel_y``
            and ``device``.

    Returns:
        The expanded squared norm as a Python float, divided by n2.
    """
    anchors = ncme.Y                                  # (n1, dy)
    test_x = X.to(ncme.device)                        # (n2, dx)
    test_y = Y.to(ncme.device)                        # (n2, dy)

    # (n1, d) @ (d, n2) -> (n1, n2); summing over held-out points leaves the
    # aggregated coefficient a_j for every training anchor.
    anchor_weights = (ncme.Q @ ncme.phi(test_x).T).sum(1)

    cross_gram = ncme.kernel_y(anchors, test_y)       # (n1, n2)

    pred_pred = anchor_weights @ ncme.KY @ anchor_weights
    true_true = ncme.kernel_y(test_y, test_y).sum()
    pred_true = (cross_gram.T @ anchor_weights).sum()

    squared_gap = pred_pred + true_true - 2 * pred_true
    return squared_gap.item() / test_x.shape[0]

@torch.no_grad()
def MMD_CME(X, Y, cme: CME):
    """Score a fitted kernel CME on held-out pairs with an MMD-style discrepancy.

    With hx_j / hy_j the training points stored on the model, W the cached
    regularised inverse and psi the feature map behind ``kernel_y``, this
    expands

        || sum_i ( sum_j psi(hy_j) * (W kX(hX, x_i))_j )  -  sum_i psi(y_i) ||^2

    into three Gram-matrix terms (prediction/prediction, target/target and
    the cross term).

    Runs under ``torch.no_grad()`` because it is evaluation only and returns
    a plain float, so no autograd graph is needed.

    Args:
        X: held-out inputs, (n2, dx).
        Y: held-out outputs, (n2, dy).
        cme: model providing ``X``, ``Y``, ``W``, ``KY``, ``kernel_x``,
            ``kernel_y`` and ``device``.

    Returns:
        The expanded squared norm as a Python float, divided by n2.
    """
    hatX = cme.X                    # (n1, dx)
    hatY = cme.Y                    # (n1, dy)
    W = cme.W                       # (n1, n1), (K_X + n1 * lambda * I)^{-1}
    X = X.to(cme.device)            # (n2, dx)
    Y = Y.to(cme.device)            # (n2, dy)
    KX_ = cme.kernel_x(hatX, X)     # (n1, n2), (KX_)_{ij} = kX(hatx_i, x_j)
    KY_ = cme.kernel_y(hatY, Y)     # (n1, n2), (KY_)_{ij} = kY(haty_i, y_j)
    WPhix = W @ KX_                 # (n1, n2)
    # Summing over held-out points leaves one aggregated weight per training anchor.
    WPhix1 = WPhix.sum(1)           # (n1, )
    term1 = WPhix1 @ cme.KY @ WPhix1      # prediction vs prediction
    term2 = cme.kernel_y(Y, Y).sum()      # target vs target
    term3 = (KY_.T @ WPhix1).sum()        # prediction vs target (counted twice below)
    mmd = term1 + term2 - 2 * term3
    return mmd.item() / X.shape[0]

def plot_training_curve(mmds, cme_mmd):
    """Plot the NCME MMD curve against the kernel-CME baseline and save it.

    Draws the per-step MMD values, marks the first (untrained) value with a
    star, and adds horizontal reference lines for the lowest value on the
    curve and for the baseline. The figure is written to ``train.pdf`` in the
    current working directory and then shown.

    Args:
        mmds: non-empty sequence of MMD values, one per training step.
        cme_mmd: MMD of the kernel CME baseline.

    Note:
        ``plt.rcParams`` is updated globally, so the font and line settings
        stay in effect for any figure created afterwards.
    """
    plt.rcParams.update({
        'font.size': 30,
        'axes.labelsize': 30,
        'axes.titlesize': 30,
        'legend.fontsize': 30,
        'xtick.labelsize': 30,
        'ytick.labelsize': 30,
        'axes.linewidth': 0.8,
        'grid.linewidth': 0.5,
        'grid.alpha': 0.5,
        'savefig.dpi': 300,
        'savefig.bbox': 'tight'
    })
    plt.figure(figsize=(15, 10))
    n = len(mmds)
    steps = list(range(n))

    untrained = mmds[0]
    # The legend calls this "Converged", but it is the minimum over the whole
    # curve, which need not be the last step.
    final = min(mmds)

    plt.plot(steps, mmds, linewidth=4, color='#1F77B4')
    plt.scatter(0, untrained, color='#D62728', marker='*', s=500, zorder=5, label=f'Untrained MMD = {untrained:.4f}')

    plt.axhline(y=final, color='#9467BD', linestyle='--', linewidth=4, label=f'Converged MMD = {final:.4f}', alpha=0.7)
    plt.axhline(y=cme_mmd, color='#FF7F0E', linestyle='--', linewidth=4, label=f'RBF-kernel CME MMD = {cme_mmd:.4f}')

    plt.xlabel('Training Step')
    plt.ylabel('MMD Value')
    plt.xticks(range(0, n, 10))
    plt.legend(loc='upper right')

    # Apply the layout BEFORE saving: calling tight_layout() after savefig()
    # only affects the on-screen figure, not the file already written.
    plt.tight_layout()
    plt.savefig('train.pdf')
    plt.show()

@hydra.main(version_base=None, config_path='configs', config_name='config')
def main(cfg=None):
    """Compare the neural CME against the kernel CME baseline and plot the result.

    Steps:
      1. Flatten the Hydra config: the ``task`` and ``specification`` groups
         are merged into the top-level dict.
      2. Seed, load the train/test splits, and fit the kernel CME baseline.
      3. Step through NCME training, recording the test MMD after each state.
      4. Plot the curve together with the baseline value.

    ``Dataset`` here is whichever import of that name came last at the top of
    the file (the ``datasets`` one, which shadows ``torch.utils.data``).
    """
    cfg = OmegaConf.to_container(cfg, resolve=True)
    cfg.update(cfg['task'])
    cfg.update(cfg['specification'])
    set_seed(cfg['seed'])
    dataset = Dataset(cfg)
    X_train, y_train = dataset.get_data('train')
    X_test,  y_test  = dataset.get_data('test')
    # Prefer the GPU, but do not crash on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    cme = CME(cfg, X_train, y_train, path='./complex/trained_cme', device=device, phi_path='./complex/phi_cme')
    cme.generate()
    cme_mmd = MMD_CME(X_test, y_test, cme)
    mmd_curve = []
    trained_ncme = NCME(cfg, X_train, y_train, path='./synthetic/trained', device=device, phi_path='./synthetic/phi')
    # Call generate() exactly once: it returns a generator when cfg['fit'] is
    # set, and None (after estimating in place) otherwise.
    training_steps = trained_ncme.generate()
    if training_steps is None:
        # No training requested: the model is already estimated, so evaluate
        # it a single time instead of trying to iterate over None.
        training_steps = [None]
    for _ in training_steps:
        cur_mmd = MMD_NCME(X_test, y_test, trained_ncme)
        mmd_curve.append(cur_mmd)
        print(cur_mmd)

    # Presumably values recorded from an earlier run, kept so the figure can
    # be re-plotted without retraining -- uncomment to use them.
    # cme_mmd = 0.8136499924858799
    # mmd_curve = [1.3694059720493388, 1.3316093552990351, 1.2958545682581608, 1.2621368245058693, 1.2303705691378564, 1.20046267305559, 1.172338354371721, 1.1459270347494166, 1.121162625660305, 1.097985213314183, 1.0763368263801094, 1.05615819031687, 1.0373875904855085, 1.0199618051736616, 1.0038172119506634, 0.9888903541477629, 0.9751190882341471, 0.9624441438321956, 0.9508098044230137, 0.9401635150108486, 0.9304550843969919, 0.9216360329454765, 0.9136593192941509, 0.9064792123561493, 0.9000512523025973, 0.8943324437910923, 0.8892815116761486, 0.8848589259047294, 0.8810267538258341, 0.8777485665061977, 0.8749894920374499, 0.8727162845886778, 0.8708973279861966, 0.8695026488776785, 0.8685038820854388, 0.8678740652151173, 0.8675873673176393, 0.8676189568213886, 0.8679450897803763, 0.8685433396514273, 0.8693928391512018]
    plot_training_curve(mmd_curve, cme_mmd)


# Script entry point. main() is called without arguments; the @hydra.main
# decorator on it is what supplies the cfg parameter.
if __name__ == '__main__':
    main()
