import torch
import numpy as np
import gzip
import pickle
import pathlib
from tqdm import tqdm
import torch.nn.functional as F

from config import Hs, alphas, num_reps, Ns, alphabet
from models.policies import UniformPolicy
from models.value_function import ValueFunction, EstimatedValue
from samplers.js_sampler import JSSampler
from samplers.token_sampler import sample_edge_noise, sample_tokenwise
from utils import reward, ExpectedReward, generate_sample, kl_divergence, binomial_reference, kl_bcount, one_hot_encode
import argparse

def _train_value_functions(H, N, device, *, num_steps=100, lr=0.01):
    """Fit one ``ValueFunction`` per prefix length h = 1, ..., H-1.

    For every h a fresh dataset of N draws from ``generate_sample(H)`` is
    built: the length-h prefix of each sequence is one-hot encoded and paired
    with that sequence's reward. ``ValueFunction(h)`` is then fit to the pairs
    with full-batch Adam on binary cross-entropy.

    Args:
        H: horizon passed to ``generate_sample``.
        N: number of training samples drawn per prefix length.
        device: torch device the models and data are moved to.
        num_steps: number of full-batch optimizer steps per model.
        lr: Adam learning rate.

    Returns:
        dict mapping prefix length h -> trained model (left on ``device``).
    """
    value_functions = {}
    for h in range(1, H):
        # Model/optimizer are created before data generation, matching the
        # original RNG consumption order.
        model = ValueFunction(h).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        X, y = [], []
        for _ in range(N):
            full_seq, r = generate_sample(H)
            X.append(one_hot_encode(full_seq[:h], h))
            y.append(r)
        X = torch.stack(X).to(device)
        y = torch.tensor(y, dtype=torch.float).to(device)

        # NOTE(review): F.binary_cross_entropy requires model outputs in [0, 1]
        # and targets in [0, 1] of matching shape — assumes ValueFunction ends
        # in a sigmoid and rewards are in [0, 1]; confirm.
        for _ in range(num_steps):
            optimizer.zero_grad()
            loss = F.binary_cross_entropy(model(X), y)
            loss.backward()
            optimizer.step()

        value_functions[h] = model
    return value_functions


def _draw(sample_fn, sample_size):
    """Call ``sample_fn()`` ``sample_size`` times and split the results.

    Each call is expected to return an indexable whose element 0 is the
    sample and element 1 is the step count.

    Returns:
        (samples, steps) as two lists of length ``sample_size``.
    """
    draws = [sample_fn() for _ in range(sample_size)]
    return [d[0] for d in draws], [d[1] for d in draws]


def _result_row(H, N, algorithm, samples, steps):
    """Build one results record: mean step count and KL of ``samples``."""
    return {
        'H': H,
        'N': N,
        'Algorithm': algorithm,
        'avg_steps': np.mean(steps),
        'kl': kl_divergence(samples, H),
    }


def main(H):
    """Run the ABC-task experiment for horizon ``H`` and pickle the results.

    For each of ``num_reps`` repetitions (from ``config``), one edge-noise
    draw is shared across all N in ``Ns``. For each N, value functions are
    trained on N samples per prefix length. Then ``10 * 2**H`` samples are
    drawn from tokenwise sampling ("ActionLevel") and from ``JSSampler``
    ("VGB"). For each algorithm the mean step count and ``kl_divergence`` of
    the samples are recorded.

    The list of result dicts is written to ``data/abc_results_H{H}.pkl``.
    The directory is created if needed.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device {device}", flush=True)
    data_dir = pathlib.Path("data")
    data_dir.mkdir(exist_ok=True, parents=True)

    sample_size = 10 * (2 ** H)
    # NOTE(review): alphabet is hard-coded here, while sample_edge_noise below
    # receives config.alphabet — presumably the same ['a','b','c']; confirm.
    piref = UniformPolicy(alphabet=['a', 'b', 'c'], horizon=H)

    print(f"=== H = {H} ===", flush=True)

    results = []
    for rep in range(num_reps):
        print(f"    Repetition {rep+1}/{num_reps}", flush=True)

        # One noise draw per repetition, shared by every N and both samplers.
        edge_noise = sample_edge_noise(alphabet, H, alpha=1)

        for N in Ns:
            print(f"Training with N = {N}", flush=True)
            value_functions = _train_value_functions(H, N, device)
            est_val = EstimatedValue(value_functions, reward, H, device)

            print("Evaluating ActionLevelSampling", flush=True)
            tw_samples, tw_steps = _draw(
                lambda: sample_tokenwise(est_val, edge_noise, H), sample_size
            )

            print("Evaluating VGB", flush=True)
            js_sampler = JSSampler(piref, H, est_val, edge_noise)
            js_samples, js_steps = _draw(js_sampler.sample, sample_size)

            results.append(_result_row(H, N, 'ActionLevel', tw_samples, tw_steps))
            results.append(_result_row(H, N, 'VGB', js_samples, js_steps))
            print(results[-2], flush=True)
            print(results[-1], flush=True)

    fname = data_dir / f"abc_results_H{H}.pkl"
    with open(fname, "wb") as f:
        pickle.dump(results, f)

    print(f"Saved results → {fname}")

if __name__ == "__main__":
    # Command-line entry point: the only option is the horizon H.
    cli = argparse.ArgumentParser(description="ABC task")
    cli.add_argument('--H', type=int, default=5, help='Horizon (default: 5)')
    main(cli.parse_args().H)
