import os
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict

class NeuralPhi(nn.Module):
    """Single-layer feature map ``phi(x) = sigmoid(W x + b)`` in float64.

    The map is trained by minimising
    ``-trace(KY @ Phi.T @ inv(Phi @ Phi.T + n * lambd * I) @ Phi)``,
    where ``Phi`` is the transposed output of the network on ``X`` and
    ``n = X.shape[0]``.

    Args:
        cfg: Mapping read for the keys ``'input_dim'``, ``'phi_output_dim'``,
            ``'phi_lr'``, ``'phi_steps'`` and ``'lambd'``.
        **kwargs: Must contain ``'device'`` (where the module is moved) and
            ``'phi_path'`` (file used by :meth:`save` / :meth:`load`).

    Raises:
        KeyError: If a required ``cfg`` or ``kwargs`` key is missing.
    """

    def __init__(self, cfg, **kwargs):
        super().__init__()
        self.cfg = cfg
        input_dim = cfg['input_dim']
        output_dim = cfg['phi_output_dim']
        self.phi = nn.Sequential(
            nn.Linear(input_dim, output_dim),
            nn.Sigmoid()
        ).double()
        self.device = kwargs['device']
        self.path = kwargs['phi_path']
        self.to(self.device)

    def forward(self, x):
        """Return ``phi(x)`` as float64; ``x`` is cast to double first."""
        return self.phi(x.double()).double()

    def _snapshot(self):
        """Return a copy of the state dict that does not alias live weights.

        ``state_dict()`` hands back tensors sharing storage with the
        parameters, and the optimizer updates parameters in place, so a
        shallow ``dict.copy()`` would silently change after later steps.
        """
        return OrderedDict(
            (key, value.detach().clone())
            for key, value in self.state_dict().items()
        )

    def fit(self, X, KY):
        """Run ``cfg['phi_steps']`` Adam steps on ``(X, KY)``.

        The module is put in train mode for the duration and switched to
        eval mode afterwards. Returns ``None``.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.cfg['phi_lr'])
        self.train()
        print(f'Training Neural Phi {self.cfg["phi_steps"]} steps...')
        for _ in range(self.cfg['phi_steps']):
            self.fit_astep(X, KY, optimizer)
        self.eval()

    def fit_yield(self, X, KY):
        """Generator variant of :meth:`fit` that yields parameter snapshots.

        Yields the initial state followed by the state after every step
        (``cfg['phi_steps'] + 1`` items in total). Each item is an
        independent copy, so earlier snapshots stay valid after later steps.
        The module is switched to eval mode only once the generator has been
        fully consumed.
        """
        optimizer = optim.Adam(self.parameters(), lr=self.cfg['phi_lr'])
        self.train()
        yield self._snapshot()
        for _ in range(self.cfg['phi_steps']):
            self.fit_astep(X, KY, optimizer)
            yield self._snapshot()
        self.eval()

    def fit_astep(self, X, KY, optimizer):
        """Perform one optimisation step on the trace objective.

        Args:
            X: Input batch; ``X.shape[0]`` is used as the sample count ``n``.
            KY: Matrix multiplied with ``Phi.T`` from the left, so it needs
                ``n`` columns (presumably an ``(n, n)`` kernel matrix on the
                targets -- confirm against the caller). It is cast to the
                dtype of the network output before use.
            optimizer: Optimizer over ``self.parameters()``.
        """
        # Gradients accumulate in ``.grad`` across backward() calls; clear
        # them so this step only uses the gradient of the current loss.
        optimizer.zero_grad()
        Phi = self(X).T
        KY = KY.to(dtype=Phi.dtype)
        # Build the ridge term in the same dtype/device as Phi rather than
        # relying on implicit float32 -> float64 promotion.
        I = torch.eye(Phi.shape[0], dtype=Phi.dtype, device=Phi.device)
        W = torch.linalg.inv(Phi @ Phi.T + X.shape[0] * self.cfg['lambd'] * I)
        loss = -torch.trace(KY @ Phi.T @ W @ Phi)
        loss.backward()
        optimizer.step()

    def save(self):
        """Write a CPU copy of the state dict to ``self.path``.

        Missing parent directories are created. A path without a directory
        component is written relative to the current working directory.
        """
        state_dict = OrderedDict(
            (key, value.cpu()) for key, value in self.state_dict().items()
        )
        directory = os.path.dirname(self.path)
        # os.makedirs('') raises FileNotFoundError, so skip it for bare names.
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(state_dict, self.path)

    def load(self):
        """Load weights from ``self.path`` if that file exists.

        Returns:
            ``True`` when a checkpoint was found and loaded, else ``False``.
        """
        if os.path.exists(self.path):
            # NOTE(review): torch.load unpickles the file; only load
            # checkpoints from trusted locations (consider weights_only=True
            # where the installed torch version supports it).
            self.load_state_dict(torch.load(self.path, map_location=self.device))
            self.to(self.device)
            return True
        return False