import os
import time
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from gensim.models import Word2Vec
from gmrbm import GMRBM
from utils import setup_logging

# --- Experiment configuration -------------------------------------------------
# Most of these values are forwarded verbatim to the GMRBM constructor; their
# exact semantics are defined in the `gmrbm` module, which is not visible here.

# Values of k swept by the experiment; each is passed as `num_potts_states`.
POTTS_STATES = [2]
# Budget used to size the hidden layer:
# hidden_size = TOTAL_CAPACITY // (VISIBLE_SIZE * k).
TOTAL_CAPACITY = 800_000
# Visible layer width. Matches two concatenated word vectors at the dataset's
# default embedding_size of 200 (stimulus + response).
VISIBLE_SIZE = 400
# Mini-batch size given to the DataLoader.
BATCH_SIZE = 64
# Upper bound on training epochs; the early-stopping checks may end sooner.
EPOCHS = 10000
# Learning rate for the Adam optimizer.
LR = 1e-4
# Passed to GMRBM as CD_step / CD_burnin (presumably contrastive-divergence
# chain length and burn-in -- confirm against gmrbm).
CD_STEP = 10
CD_BURNIN = 2
# Passed to GMRBM as init_var (presumably the initial visible variance -- confirm).
INIT_VAR = 0.1
# Passed to GMRBM as inference_method.
INFERENCE_METHOD = "Gibbs"
# Passed to GMRBM as Langevin_step / Langevin_eta / Langevin_adjust_step.
# NOTE(review): whether these are used when INFERENCE_METHOD is "Gibbs" depends
# on gmrbm -- confirm.
LANGEVIN_STEP = 10
LANGEVIN_ETA = 0.1
LANGEVIN_ADJUST = 5
# Numbers of word pairs to train on; one experiment run per entry (per k).
DATASET_SIZES = [1500]

class WordPairsDataset(Dataset):
    """Word pairs encoded as standardized, concatenated Word2Vec vectors.

    A Word2Vec model is trained on the pairs themselves (each pair is treated
    as a two-word sentence). Every word vector is standardized per dimension
    with the mean/std computed over all word occurrences in the pairs, and
    each sample is the concatenation ``[first_word_vec, second_word_vec]``.

    Attributes set by ``__init__``:
        word_pairs: the pairs as given by the caller.
        embedding_size: dimensionality of a single word vector.
        model: the trained Word2Vec model.
        mean, std: per-dimension statistics used for standardization.
        data: float32 tensor with one row per pair
            (``2 * embedding_size`` columns).
    """

    def __init__(self, word_pairs, embedding_size=200):
        """Train the embedding model and precompute the sample tensor.

        Args:
            word_pairs: sequence of ``(word, word)`` pairs.
            embedding_size: size of each word vector.
        """
        self.word_pairs = word_pairs
        self.embedding_size = embedding_size
        self.model, self.mean, self.std = self._train_word2vec()
        self.data = self._preprocess_data()

    def _train_word2vec(self):
        """Train Word2Vec on the pairs and compute standardization statistics.

        Returns:
            ``(model, mean, std)`` where ``mean`` and ``std`` are per-dimension
            arrays over every word occurrence in ``word_pairs``.
        """
        sentences = [[s, r] for s, r in self.word_pairs]
        common = {"window": 5, "min_count": 1, "sg": 0}
        try:
            # Keyword names used by gensim < 4.0 (what this file was written for).
            model = Word2Vec(sentences, size=self.embedding_size, iter=100, **common)
        except TypeError:
            # gensim >= 4.0 renamed `size` -> `vector_size` and `iter` -> `epochs`
            # and rejects the old names with a TypeError.
            model = Word2Vec(
                sentences, vector_size=self.embedding_size, epochs=100, **common
            )
        vectors = np.array([model.wv[w] for pair in self.word_pairs for w in pair])
        mean = vectors.mean(axis=0)
        # Small epsilon keeps the division safe for zero-variance dimensions.
        std = vectors.std(axis=0) + 1e-8
        return model, mean, std

    def _preprocess_data(self):
        """Build the float32 sample tensor from the standardized pair vectors."""
        processed = []
        for s, r in self.word_pairs:
            s_vec = (self.model.wv[s] - self.mean) / self.std
            r_vec = (self.model.wv[r] - self.mean) / self.std
            processed.append(np.concatenate([s_vec, r_vec]))
        return torch.tensor(np.array(processed), dtype=torch.float32)

    def __len__(self):
        """Return the number of word pairs."""
        return len(self.word_pairs)

    def __getitem__(self, idx):
        """Return ``(sample, 0)``; the constant 0 is a placeholder label."""
        return self.data[idx], 0

def create_dataset(file_path, num_samples=None):
    """Load word pairs from a header-less two-column CSV into a WordPairsDataset.

    Rows with missing values are dropped, and so is any row in which either
    word contains an underscore. When ``num_samples`` is truthy and smaller
    than the number of remaining rows, a reproducible random subset
    (``random_state=42``) of that size is used.

    Args:
        file_path: path of the CSV file.
        num_samples: optional cap on the number of pairs.

    Returns:
        A ``WordPairsDataset`` built from the selected pairs.
    """
    frame = pd.read_csv(file_path, header=None, names=["word1", "word2"])
    frame = frame.dropna().astype(str)

    first_has_underscore = frame["word1"].str.contains("_")
    second_has_underscore = frame["word2"].str.contains("_")
    # Keep only rows where neither word contains an underscore.
    frame = frame[~(first_has_underscore | second_has_underscore)]

    if num_samples:
        if num_samples < len(frame):
            frame = frame.sample(n=num_samples, random_state=42)

    pairs = list(zip(frame["word1"], frame["word2"]))
    return WordPairsDataset(pairs)

def associate(stimulus_word, model, w2v_model, mean, std, steps=100):
    """Recall the word associated with ``stimulus_word`` by clamped sampling.

    The standardized stimulus vector fills the first half of the visible
    vector and the second half starts at zero. For ``steps`` iterations the
    hidden units are sampled from ``model.prob_h_given_v`` (one categorical
    draw per row of the last axis), the visible vector is rebuilt with
    ``model.prob_v_given_h``, and the stimulus half is overwritten with the
    original stimulus. The remaining half is de-standardized and the nearest
    word in ``w2v_model`` is returned.

    Args:
        stimulus_word: word whose association is requested.
        model: object exposing ``parameters()``, ``get_var()``,
            ``prob_h_given_v(v, var)`` and ``prob_v_given_h(h)``.
        w2v_model: embedding model exposing ``wv[word]`` and
            ``wv.most_similar``.
        mean, std: per-dimension statistics used to standardize embeddings.
        steps: number of sampling iterations.

    Returns:
        The single most similar word to the recalled response vector.

    Raises:
        KeyError: propagated from ``w2v_model.wv`` for an unknown word.
    """
    stim_vec = (w2v_model.wv[stimulus_word] - mean) / std
    stim_dim = stim_vec.shape[0]
    input_vec = np.concatenate([stim_vec, np.zeros_like(stim_vec)])

    # Always build a float32 tensor on the model's own device. Previously the
    # CPU path kept numpy's dtype (float64 when mean/std are float64) and the
    # GPU path always used the default CUDA device.
    device = next(model.parameters()).device
    input_tensor = torch.tensor(input_vec, dtype=torch.float32, device=device)

    v = input_tensor.clone()
    # Pure inference: no autograd graph is needed, and disabling it keeps the
    # in-place clamp and the later numpy conversion valid.
    with torch.no_grad():
        for _ in range(steps):
            h = model.prob_h_given_v(v.unsqueeze(0), model.get_var())
            # One categorical draw per row of the last axis, re-encoded one-hot.
            h_idx = torch.multinomial(h.view(-1, h.shape[-1]), 1).squeeze()
            h = torch.nn.functional.one_hot(h_idx, num_classes=h.shape[-1]).float().view_as(h)
            v_recon = model.prob_v_given_h(h)
            # Clamp the stimulus half so only the response half evolves.
            v_recon[:, :stim_dim] = input_tensor[:stim_dim]
            v = v_recon.squeeze()

    # Undo the standardization before the nearest-neighbour lookup.
    resp_vec = v[stim_dim:].detach().cpu().numpy() * std + mean
    return w2v_model.wv.most_similar([resp_vec], topn=1)[0][0]

def test_associations(model, dataset):
    """Return the fraction of pairs whose second word is recalled from the first.

    The model is switched to eval mode and ``associate`` is called for every
    pair in ``dataset.word_pairs``. Pairs whose lookup raises ``KeyError`` are
    skipped but still count in the denominator, so they lower the accuracy.

    Args:
        model: trained model handed to ``associate``.
        dataset: object exposing ``word_pairs``, ``model`` (the embedding
            model), ``mean`` and ``std``.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` when the dataset has no pairs.
    """
    model.eval()
    pairs = dataset.word_pairs
    total = len(pairs)
    if total == 0:
        # Nothing to evaluate; avoid dividing by zero.
        return 0.0

    correct = 0
    # Evaluation only: gradients are never needed here.
    with torch.no_grad():
        for s, r in pairs:
            try:
                pred = associate(s, model, dataset.model, dataset.mean, dataset.std)
            except KeyError:
                # Word missing from the embedding vocabulary: counts as a miss.
                continue
            correct += int(pred == r)
    return correct / total

def train_one_epoch(model, train_loader, optimizer, config):
    """Run one pass over ``train_loader`` and return the mean reconstruction value.

    For every batch: gradients are cleared, ``model.CD_grad`` is called on the
    batch inputs, the optimizer takes a step, and ``model.reconstruction`` is
    evaluated (without gradient tracking) on the same inputs.

    Args:
        model: model exposing ``train()``, ``CD_grad(data)`` and
            ``reconstruction(data)``.
        train_loader: iterable of batches; element 0 of each batch is the input.
        optimizer: object exposing ``zero_grad()`` and ``step()``.
        config: mapping whose ``'cuda'`` entry selects moving inputs to the GPU.

    Returns:
        Sum of per-batch reconstruction values divided by ``len(train_loader)``.
    """
    model.train()
    batch_losses = []
    for batch in train_loader:
        inputs = batch[0]
        if config['cuda']:
            inputs = inputs.cuda()

        optimizer.zero_grad()
        model.CD_grad(inputs)
        optimizer.step()

        with torch.no_grad():
            recon = model.reconstruction(inputs)
            batch_losses.append(recon.item())

    return sum(batch_losses) / len(train_loader)

def run_experiment_sweep_words(file_path, dataset_sizes, output_dir):
    """Train and evaluate one GMRBM per (Potts state count, dataset size) pair.

    For each combination a fresh dataset and model are built, the model is
    trained for up to ``EPOCHS`` epochs, and recall accuracy on the training
    pairs is measured every 50 epochs. Training stops early when the accuracy
    reaches 0.98, when the last 20 measurements have a standard deviation
    below 0.01, or when the last 10 measurements contain no new maximum.

    Args:
        file_path: CSV file of word pairs.
        dataset_sizes: iterable of dataset sizes to evaluate.
        output_dir: directory for the log file (created if missing).

    Returns:
        A DataFrame with one row per run: Potts states, number of samples,
        hidden units, final accuracy, wall-clock seconds and stopping epoch.
    """
    os.makedirs(output_dir, exist_ok=True)
    log_path = os.path.join(output_dir, 'benchmark.log')
    logger = setup_logging('INFO', log_path)
    rows = []

    for k in POTTS_STATES:
        for num_samples in dataset_sizes:
            dataset = create_dataset(file_path, num_samples=num_samples)
            # Fixed budget: more Potts states means fewer hidden units.
            hidden_size = TOTAL_CAPACITY // (VISIBLE_SIZE * k)

            logger.info(f"Potts States: {k}, Dataset size: {num_samples}, Hidden Units: {hidden_size}")
            model = GMRBM(
                visible_size=VISIBLE_SIZE,
                hidden_size=hidden_size,
                num_potts_states=k,
                CD_step=CD_STEP,
                CD_burnin=CD_BURNIN,
                init_var=INIT_VAR,
                inference_method=INFERENCE_METHOD,
                Langevin_step=LANGEVIN_STEP,
                Langevin_eta=LANGEVIN_ETA,
                Langevin_adjust_step=LANGEVIN_ADJUST,
            )
            use_cuda = torch.cuda.is_available()
            if use_cuda:
                model.cuda()
            optimizer = torch.optim.Adam(model.parameters(), lr=LR)
            train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
            train_config = {"cuda": use_cuda}

            accuracy_history = []
            best_accuracy = 0.0
            stopping_epoch = EPOCHS
            start_time = time.time()

            for epoch in range(1, EPOCHS + 1):
                train_one_epoch(model, train_loader, optimizer, config=train_config)
                if epoch % 50 != 0:
                    continue

                # Periodic evaluation (every 50 epochs).
                val_acc = test_associations(model, dataset)
                accuracy_history.append(val_acc)
                logger.info(f"Epoch {epoch}, Validation Accuracy: {val_acc:.4f}")
                best_accuracy = max(best_accuracy, val_acc)

                # The three stopping rules are checked in priority order.
                stop_message = None
                if val_acc >= 0.98:
                    stop_message = f"Early stopping at epoch {epoch} (accuracy reached 0.98)"
                elif len(accuracy_history) >= 20 and np.std(accuracy_history[-20:]) < 0.01:
                    stop_message = f"Early stopping at epoch {epoch} (std dev < 0.01 in last 20 epochs)"
                elif len(accuracy_history) > 10 and max(accuracy_history[-10:]) <= max(accuracy_history[:-10]):
                    stop_message = f"Early stopping at epoch {epoch} (no improvement in last 10 epochs)"

                if stop_message is not None:
                    stopping_epoch = epoch
                    logger.info(stop_message)
                    break

            time_to_converge = time.time() - start_time
            # Final accuracy is measured again after training ends.
            accuracy = test_associations(model, dataset)
            logger.info(f"Final Accuracy @ {num_samples} samples, k={k}: {accuracy:.4f}, Best: {best_accuracy:.4f}")

            rows.append({
                "PottsStates": k,
                "NumSamples": num_samples,
                "HiddenUnits": hidden_size,
                "Accuracy": accuracy,
                "TimeToConvergeSeconds": time_to_converge,
                "StoppingEpoch": stopping_epoch,
            })

    return pd.DataFrame(rows)

def plot_word_sweep_results(results_df, output_dir):
    """Plot accuracy against dataset size, one line per Potts state count.

    Args:
        results_df: DataFrame with ``PottsStates``, ``NumSamples`` and
            ``Accuracy`` columns.
        output_dir: directory in which ``accuracy_vs_words_multi_k.png`` is
            written.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    for k in POTTS_STATES:
        rows_for_k = results_df[results_df['PottsStates'] == k]
        ax.plot(rows_for_k['NumSamples'], rows_for_k['Accuracy'], marker='o', label=f'k={k}')

    ax.set_title("Association Accuracy vs Number of Word Pairs for Varying Potts States")
    ax.set_xlabel("Number of Word Pairs")
    ax.set_ylabel("Recall Accuracy")
    ax.legend()
    ax.grid(True)

    target = os.path.join(output_dir, "accuracy_vs_words_multi_k.png")
    fig.savefig(target)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

if __name__ == "__main__":
    # Script entry point: run the sweep, then write the CSV summary and the plot
    # into the same output directory.
    output_dir = "word_sweep_results_multi_k"
    csv_path = os.path.join(output_dir, "word_sweep_results_multi_k.csv")

    print("Running sweep on dataset sizes across multiple Potts states...")
    sweep_results = run_experiment_sweep_words("word_relationships.csv", DATASET_SIZES, output_dir)

    sweep_results.to_csv(csv_path, index=False)
    plot_word_sweep_results(sweep_results, output_dir)
    print("Done. Results saved to:", output_dir)
