import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils import cosine_schedule

class GMRBM(nn.Module):
    """ Gaussian-Multinoulli Restricted Boltzmann Machine (GMRBM) """

    def __init__(self,
                 visible_size,
                 hidden_size,
                 num_potts_states=4,
                 CD_step=1,
                 CD_burnin=0,
                 init_var=1e-0,                 
                 inference_method='Gibbs',
                 Langevin_step=10,
                 Langevin_eta=1.0,
                 is_anneal_Langevin=True,
                 Langevin_adjust_step=0) -> None:
        super().__init__()
        assert CD_burnin >= 0 and CD_burnin <= CD_step
        assert inference_method in ['Gibbs', 'Langevin', 'Gibbs-Langevin']

        self.visible_size = visible_size
        self.hidden_size = hidden_size
        self.num_potts_states = num_potts_states
        self.CD_step = CD_step
        self.CD_burnin = CD_burnin
        self.init_var = init_var
        self.inference_method = inference_method
        self.Langevin_step = Langevin_step
        self.Langevin_eta = Langevin_eta
        self.is_anneal_Langevin = is_anneal_Langevin
        self.Langevin_adjust_step = Langevin_adjust_step

        # 3D weight tensor [visible, hidden, states]
        self.W = nn.Parameter(torch.Tensor(visible_size, hidden_size, num_potts_states))
        # 2D bias tensor [hidden, states]
        self.b = nn.Parameter(torch.Tensor(hidden_size, num_potts_states))
        self.mu = nn.Parameter(torch.Tensor(visible_size))
        self.log_var = nn.Parameter(torch.Tensor(visible_size))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.normal_(self.W, std=1.0*self.init_var/np.sqrt(self.visible_size+self.hidden_size))
        nn.init.constant_(self.b, 0.0)
        nn.init.constant_(self.mu, 0.0)
        nn.init.constant_(self.log_var, np.log(self.init_var))

    def get_var(self):
        return self.log_var.exp().clip(min=1e-8)

    @torch.no_grad()
    def energy(self, v, h_onehot):
        var = self.get_var()
        eng = 0.5 * ((v - self.mu)**2 / var).sum(dim=1)
        eng -= torch.einsum('bi,ihk,bhk->b', v/var, self.W, h_onehot)
        eng -= (h_onehot * self.b).sum(dim=(1,2))
        return eng / v.size(0)

    @torch.no_grad()
    def _vectorized_h_sample(self, probs):
        shape = probs.shape  # [batch, hidden, states]
        probs_flat = probs.view(-1, shape[-1])
        samples_flat = torch.multinomial(probs_flat, 1).squeeze()
        return samples_flat.view(shape[0], shape[1])

    @torch.no_grad()
    def prob_h_given_v(self, v, var):
        logits = torch.einsum('bi,ihk->bhk', v/var, self.W) + self.b
        return F.softmax(logits, dim=2)

    @torch.no_grad()
    def prob_v_given_h(self, h_onehot):
        return torch.einsum('bhk,ihk->bi', h_onehot, self.W) + self.mu

    @torch.no_grad()
    def Gibbs_sampling_vh(self, v, num_steps=10, burn_in=0):
        samples, var = [], self.get_var()
        std = var.sqrt()
        
        # Initial hidden sampling
        probs = self.prob_h_given_v(v, var)
        h_idx = self._vectorized_h_sample(probs)
        h_onehot = F.one_hot(h_idx, self.num_potts_states).float()

        for ii in range(num_steps):
            # Visible update
            mu = self.prob_v_given_h(h_onehot)
            v = mu + torch.randn_like(mu) * std
            
            # Hidden update
            probs = self.prob_h_given_v(v, var)
            h_idx = self._vectorized_h_sample(probs)
            h_onehot = F.one_hot(h_idx, self.num_potts_states).float()

            if ii >= burn_in:
                samples.append((v, h_onehot))

        return samples

    @torch.no_grad()
    def energy_grad_param(self, v, h_onehot):
        var = self.get_var()
        grad = {
            'W': -torch.einsum('bi,bhk->ihk', v/var, h_onehot) / v.size(0),
            'b': -h_onehot.mean(dim=0),
            'mu': ((self.mu - v) / var).mean(dim=0),
            'log_var': (-0.5*(v-self.mu)**2/var + 
                       (v/var)*torch.einsum('bhk,ihk->bi', h_onehot, self.W)).mean(dim=0)
        }
        return grad

    @torch.no_grad()
    def marginal_energy(self, v):
        B = v.shape[0]
        var = self.get_var()
        eng = 0.5 * ((v - self.mu)**2 / var).sum(dim=1)
        logits = torch.einsum('bi,ihk->bhk', v/var, self.W) + self.b
        eng -= torch.logsumexp(logits, dim=2).sum(dim=1)
        return eng / B

    @torch.no_grad()
    def reconstruction(self, v):
        v, var = v.view(v.shape[0], -1), self.get_var()
        probs_h = self.prob_h_given_v(v, var)
        v_bar = self.prob_v_given_h(probs_h)
        return F.mse_loss(v, v_bar)


    @torch.no_grad()
    def sampling(self, v_init, num_steps=1, save_gap=1):
        if num_steps < 2:
            print("[WARNING] sampling() called with num_steps < 2. Adjusting to num_steps = 2 to avoid empty sampling.")
            num_steps = 2

        v_shape = v_init.shape
        v = v_init.view(v_shape[0], -1)
        var = self.get_var()
        var_mean = var.mean().item()

        samples = []
        if self.inference_method == 'Gibbs':
            raw_samples = self.Gibbs_sampling_vh(v, num_steps=num_steps - 1)
            if len(raw_samples) == 0:
                print("[WARNING] Gibbs_sampling_vh returned empty list. Returning initial sample only.")
                samples = [v]
            else:
                samples = [xx[0] for xx in raw_samples]

        elif self.inference_method == 'Langevin':
            samples = self.Langevin_sampling_v(v,
                                            num_steps=num_steps - 1,
                                            eta=self.Langevin_eta * var_mean,
                                            is_anneal=self.is_anneal_Langevin,
                                            adjust_step=self.Langevin_adjust_step)

        if len(samples) == 0:
            # Fall back to using initial sample only
            v_list = [(0, v_init), (1, v_init)]
            return v_list

        # Final conditional mean
        probs_h = self.prob_h_given_v(samples[-1], var)
        mu = self.prob_v_given_h(probs_h)

        # Format output
        v_list = [(0, v_init)] + [(ii+1, samples[ii].view(v_shape))
                                for ii in range(len(samples)) if (ii+1) % save_gap == 0] + \
                [(num_steps, mu.view(v_shape))]

        return v_list


    @torch.no_grad()
    def positive_grad(self, v):
        probs = self.prob_h_given_v(v, self.get_var())
        h_idx = self._vectorized_h_sample(probs)
        h_onehot = F.one_hot(h_idx, self.num_potts_states).float()
        return self.energy_grad_param(v, h_onehot)

    @torch.no_grad()
    def CD_grad(self, v, data_ratio=0.0):
        v = v.view(v.shape[0], -1)
        grad_pos = self.positive_grad(v)

        B = v.size(0)
        num_data = int(data_ratio * B)
        idx = torch.randperm(B)
        data_idx, noise_idx = idx[:num_data], idx[num_data:]

        v_neg = torch.empty_like(v)
        v_neg[data_idx] = v[data_idx]                     # init from data
        v_neg[noise_idx] = torch.randn_like(v[noise_idx]) # init from noise

        grad_neg = self.negative_grad(v_neg)

        for name, param in self.named_parameters():
            param.grad = grad_pos[name] - grad_neg[name]


    @torch.no_grad()
    def negative_grad(self, v):
        var = self.get_var()
        var_mean = var.mean().item()
        
        if self.inference_method == 'Gibbs':
            samples = self.Gibbs_sampling_vh(v,
                                            num_steps=self.CD_step,
                                            burn_in=self.CD_burnin)
            v_neg = torch.cat([xx[0] for xx in samples], dim=0)
            h_neg = torch.cat([xx[1] for xx in samples], dim=0)
            grad = self.energy_grad_param(v_neg, h_neg)
            
        elif self.inference_method == 'Langevin':
            samples = self.Langevin_sampling_v(v,
                                            num_steps=self.CD_step,
                                            burn_in=self.CD_burnin,
                                            eta=self.Langevin_eta * var_mean,
                                            is_anneal=self.is_anneal_Langevin,
                                            adjust_step=self.Langevin_adjust_step)
            v_neg = torch.cat(samples, dim=0)
            grad = self.marginal_energy_grad_param(v_neg)
            
        elif self.inference_method == 'Gibbs-Langevin':
            samples = self.Gibbs_Langevin_sampling_vh(
                v,
                num_steps=self.CD_step,
                burn_in=self.CD_burnin,
                num_steps_Langevin=self.Langevin_step,
                eta=self.Langevin_eta * var_mean,
                is_anneal=self.is_anneal_Langevin,
                adjust_step=self.Langevin_adjust_step)
            v_neg = torch.cat([xx[0] for xx in samples], dim=0)
            h_neg = torch.cat([xx[1] for xx in samples], dim=0)
            grad = self.energy_grad_param(v_neg, h_neg)

        return grad
    
    def set_Langevin_adjust_step(self, step):
        """Update ``Langevin_adjust_step``.

        The new value is passed as ``adjust_step`` to the Langevin samplers on
        subsequent ``sampling`` / ``negative_grad`` calls.
        """
        self.Langevin_adjust_step = step

