import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import random
from collections import Counter
from scipy.stats import entropy
import pandas as pd
import itertools
import time
import argparse
import pickle
import signal
import sys
import os
from datetime import datetime

# Seed every RNG the script draws from so that runs are reproducible.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Binary token alphabet, plus a token -> slot lookup used by one_hot_encode.
alphabet = [0, 1]
alphabet_idx = {symbol: slot for slot, symbol in enumerate(alphabet)}

def one_hot_encode(seq, h, device=None):
    """Flatten a token sequence into a one-hot vector of length ``len(alphabet) * h``.

    Token ``seq[i]`` sets slot ``len(alphabet) * i + alphabet_idx[seq[i]]``;
    slots belonging to positions ``len(seq) .. h-1`` stay zero. The vector is
    filled on the CPU and, if ``device`` is given, moved there in one transfer.

    Raises:
        KeyError: if a token is not in ``alphabet``.
        (Indexing also fails if ``len(seq) > h``.)
    """
    width = len(alphabet)
    encoded = torch.zeros(width * h)
    for position, token in enumerate(seq):
        encoded[width * position + alphabet_idx[token]] = 1
    return encoded if device is None else encoded.to(device)

class ParityPolicy:
    """Reference policy over sequences built from ``n_chunks`` chunks.

    Within each chunk of ``chunk_size`` tokens, the first ``chunk_size - 1``
    ("free") tokens are drawn uniformly from ``alphabet`` and the last token
    is the parity (sum mod 2) of the chunk's free tokens.
    """
    def __init__(self, chunk_size, n_chunks):
        self.alphabet = alphabet
        self.chunk_size = chunk_size
        self.n_chunks = n_chunks
        self.horizon = chunk_size * n_chunks

    def next_token(self, sequence, condition_on_good = False):
        """Sample the token at position ``len(sequence)``.

        Args:
            sequence: prefix generated so far; must be shorter than ``horizon``.
            condition_on_good: if True, the last free token of each chunk is
                chosen so the chunk's free tokens have even parity, which makes
                the following parity token 0.

        Returns:
            The next token (parity tokens are computed, others sampled uniformly).
        """
        i = len(sequence)
        assert(i < self.horizon)
        # Index of the first token of the chunk containing position i. Slicing
        # from an explicit, non-negative start (rather than a negative offset
        # like ``1 - chunk_size``) keeps the slice confined to the current
        # chunk even when chunk_size is 1 or 2, where the negative offset would
        # be 0 and select the whole prefix.
        chunk_start = i - i % self.chunk_size
        if (i+1) % self.chunk_size == 0:
            # Parity position: parity of the chunk's free tokens.
            return sum(sequence[chunk_start:]) % 2
        elif condition_on_good and ((i+2) % self.chunk_size == 0):
            # Last free token: choose it so the free tokens sum to an even number.
            return sum(sequence[chunk_start:]) % 2
        else:
            return np.random.choice(self.alphabet)

    def sequence_prob(self, sequence):
        """Probability of ``sequence`` (a prefix) under this policy.

        Returns 0 if any parity position disagrees with the parity of its
        chunk's free tokens, otherwise ``(1 / len(alphabet)) ** k`` where ``k``
        is the number of free positions in ``sequence``.
        """
        choices = 0
        for i in range(len(sequence)):
            if (i+1)%self.chunk_size != 0:
                choices += 1
            else:
                if sum(sequence[i+1-self.chunk_size:i])%2 != sequence[i]:
                    return 0
        return (1 / len(self.alphabet)) ** choices

    def generate(self, condition_on_good = False):
        """Roll out a full sequence of length ``horizon`` via ``next_token``."""
        sequence = []
        for _ in range(self.horizon):
            # Append in place instead of rebuilding the list every step.
            sequence.append(self.next_token(sequence, condition_on_good))
        return sequence

def reward(seq, chunk_size):
    """Return 1 if no parity slot holds a 1, else 0.

    Parity slots are indices ``chunk_size - 1, 2 * chunk_size - 1, ...`` that
    fall inside ``seq``.
    """
    parity_positions = range(chunk_size - 1, len(seq), chunk_size)
    violated = any(seq[pos] == 1 for pos in parity_positions)
    return 0 if violated else 1

class ValueFunction(nn.Module):
    """Small MLP mapping a one-hot encoded length-``h`` prefix to a value in [0, 1].

    Architecture: Linear(len(alphabet) * h -> 128) -> ReLU -> Linear(128 -> 1) -> Sigmoid.
    """

    def __init__(self, h):
        super().__init__()
        hidden = 128
        layers = [
            nn.Linear(len(alphabet) * h, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        ]
        # Kept under the name ``fc`` so parameter names (fc.0.*, fc.2.*) are unchanged.
        self.fc = nn.Sequential(*layers)

    def forward(self, x_onehot):
        """Return sigmoid outputs with the trailing size-1 dimension removed."""
        out = self.fc(x_onehot)
        return out.squeeze(-1)

class EstimatedValue:
    """Prefix value estimate: the true reward at full length, a learned model otherwise."""

    def __init__(self, value_functions, reward_fn, H, device=None):
        self.value_functions = value_functions  # mapping: prefix length -> model
        self.reward = reward_fn                 # reward of a complete sequence
        self.H = H                              # full sequence length
        self.device = device

    def __call__(self, seq):
        """Return ``reward(seq)`` when ``len(seq) == H``, else the model's prediction."""
        length = len(seq)
        if length != self.H:
            encoded = one_hot_encode(seq, length, self.device)
            with torch.no_grad():
                return self.value_functions[length](encoded[None]).item()
        return self.reward(seq)

class EstimatedValueUncorrected:
    """Prefix value estimate that always queries the learned model, even at length ``H``."""

    def __init__(self, value_functions, H, device=None):
        self.value_functions = value_functions  # mapping: prefix length -> model
        self.H = H                              # stored, but not consulted by __call__
        self.device = device

    def __call__(self, seq):
        """Return ``value_functions[len(seq)]`` evaluated on the one-hot encoding of ``seq``."""
        length = len(seq)
        features = one_hot_encode(seq, length, self.device)
        model = self.value_functions[length]
        with torch.no_grad():
            return model(features.unsqueeze(0)).item()

class Sampler:
    """Base class for sequence samplers guided by an estimated value function.

    Subclasses implement ``sample()`` returning ``(sequence, steps)``;
    ``test()`` benchmarks repeated calls to it.
    """

    def __init__(self, piref, estimated_value, horizon):
        self.piref = piref                      # reference policy
        self.estimated_value = estimated_value  # callable: prefix -> estimated value
        self.horizon = horizon                  # length of a complete sequence

    def sample(self):
        """Draw one sequence and return ``(sequence, steps)``. Must be overridden."""
        raise NotImplementedError(f"{self.__class__.__name__} must implement sample()")

    def timeout_handler(self, signum, frame):
        """SIGALRM handler that aborts the running ``sample()`` call."""
        raise TimeoutError("Sample operation timed out")

    def test(self, N, timeout_seconds=60):
        """Run ``sample()`` ``N`` times, each bounded by a wall-clock timeout.

        Relies on ``signal.SIGALRM``/``signal.alarm``, so it only works on Unix
        and when called from the main thread.

        Args:
            N: number of trials.
            timeout_seconds: per-trial limit in whole seconds (default 60).

        Returns:
            List of ``(elapsed_seconds, steps)`` per trial. A trial that times
            out or raises any ``Exception`` is recorded as
            ``(timeout_seconds, np.inf)``.
        """
        t_list = []

        print(f"Testing {self.__class__.__name__}")
        for i in range(N):
            print(f"Test {i}", flush=True)
            t_start = time.perf_counter()

            # Install our handler and arm the alarm for this trial.
            old_handler = signal.signal(signal.SIGALRM, self.timeout_handler)
            signal.alarm(timeout_seconds)

            try:
                _, steps = self.sample()
                signal.alarm(0)  # Cancel the alarm before doing anything else
                t_stop = time.perf_counter()
                t_list.append((t_stop - t_start, steps))
            except TimeoutError:
                signal.alarm(0)
                print(f"Sample {i} timed out after {timeout_seconds} seconds")
                t_list.append((timeout_seconds, np.inf))
            except Exception as e:
                # Cancel first so a pending alarm cannot fire inside this handler.
                signal.alarm(0)
                print(f"Sample {i} failed with error: {e}")
                # Failures are penalized exactly like a timeout.
                t_list.append((timeout_seconds, np.inf))
            finally:
                # Restore original signal handler
                signal.signal(signal.SIGALRM, old_handler)

        return t_list

class TokenwiseSampler(Sampler):
    """Left-to-right sampler.

    Each next token is drawn with probability proportional to
    ``piref.sequence_prob(prefix + [token]) * estimated_value(prefix + [token])``.
    """

    def sample(self, repeat_if_fail = True):
        """Generate one sequence of length ``horizon``.

        Args:
            repeat_if_fail: when a step has zero or NaN total weight, restart
                from the empty prefix if True; otherwise give up.

        Returns:
            ``(sequence, steps)``. ``sequence`` is ``None`` when sampling gave
            up. ``steps`` counts token decisions across all attempts, including
            abandoned ones.
        """
        def weight(prefix):
            return self.piref.sequence_prob(prefix) * self.estimated_value(prefix)

        steps = 0
        while True:
            seq = []
            stuck = False
            while len(seq) < self.horizon:
                steps += 1
                weights = np.array(
                    [weight(seq + [token]) for token in self.piref.alphabet],
                    dtype=np.float64,
                )
                total = weights.sum()
                if total == 0.0 or np.isnan(total):
                    stuck = True
                    break
                seq.append(np.random.choice(self.piref.alphabet, p=weights / total))
            if not stuck:
                return seq, steps
            if not repeat_if_fail:
                return None, steps

class JSSampler(Sampler):
    """Sampler that random-walks over prefixes, one token at a time in either direction.

    From the current prefix it may drop the last token (if non-empty) or append
    any alphabet token (if shorter than ``horizon``). Each move is weighted by
    ``piref.sequence_prob(s) * get_weight(s)``. Here ``s`` is the current
    prefix for the backward move and the extended prefix for forward moves.
    """

    def get_weight(self, seq):
        """Value weight of ``seq`` (delegates to ``estimated_value``)."""
        return self.estimated_value(seq)

    def _move_weight(self, seq):
        # Unnormalized weight piref(seq) * V(seq) shared by all moves.
        return self.piref.sequence_prob(seq) * self.get_weight(seq)

    def forward(self, sequence):
        """Take one step of the walk and return the new prefix.

        If all moves have zero (or NaN) total weight, ``sequence`` is returned
        unchanged.
        """
        # Each entry pairs a candidate next prefix with its unnormalized weight.
        # NOTE(review): a "lazy" variant (also staying at ``sequence`` with
        # weight 0.5) existed as commented-out code and is not enabled.
        moves = []
        if len(sequence) > 0:
            moves.append((sequence[:-1], self._move_weight(sequence)))
        if len(sequence) < self.horizon:
            for token in self.piref.alphabet:
                extended = sequence + [token]
                moves.append((extended, self._move_weight(extended)))

        weights = np.array([w for _, w in moves], dtype=np.float64)
        total = weights.sum()
        if total == 0 or np.isnan(total):
            return sequence
        chosen = np.random.choice(len(moves), p=weights / total)
        return moves[chosen][0]

    def sample(self):
        """Walk from the empty prefix until it reaches ``horizon`` tokens.

        Returns ``(sequence, steps)``, where ``steps`` counts calls to ``forward``.
        """
        sequence, steps = [], 0
        while len(sequence) < self.horizon:
            sequence = self.forward(sequence)
            steps += 1
        return sequence, steps




def generate_batch(piref, reward_func, val_func, h, batch_size, condition_on_good=False, device=None):
    """Build one regression batch for the length-``h`` value model.

    Each of ``batch_size`` rollouts comes from ``piref.generate(condition_on_good)``.

    Returns:
        ``(X, y, y_true)``:
        - ``X``: stacked one-hot encodings of the length-``h`` prefixes,
          shape ``(batch_size, len(alphabet) * h)``;
        - ``y``: ``reward_func`` of each full rollout (float);
        - ``y_true``: ``val_func`` of each prefix (float).
        All three are moved to ``device`` when it is given.
    """
    features, targets, exact_values = [], [], []
    for _ in range(batch_size):
        rollout = piref.generate(condition_on_good)
        outcome = reward_func(rollout)
        prefix = rollout[:h]
        features.append(one_hot_encode(prefix, h, device))
        targets.append(outcome)
        exact_values.append(val_func(prefix))

    X = torch.stack(features)
    y = torch.tensor(targets, dtype=torch.float)
    y_true = torch.tensor(exact_values, dtype=torch.float)

    if device is not None:
        X, y, y_true = X.to(device), y.to(device), y_true.to(device)

    return X, y, y_true

def estimate_deviation(piref, chunk_size, models, h, batch_size, device=None):
    """Measure how much value ``models[h]`` assigns to parity-violating prefixes.

    Draws ``batch_size`` rollouts with ``piref.generate(condition_on_good=True)``.
    Each length-``h`` prefix is taken as the "good" prefix, and a "bad" twin is
    built by flipping its last token(s):

    * ``(h+1) % chunk_size == 0``: the prefix ends on the last free token of a
      chunk. Flipping it makes the chunk's free-token parity odd.
    * ``h % chunk_size == 0``: the prefix ends on a parity token. The last free
      token and the parity token are both flipped, so the prefix stays
      consistent with ``piref`` but the chunk's parity token becomes 1.

    Returns:
        Mean over the batch of ``V(bad) / (V(good) + V(bad))``. Lower means the
        model separates bad twins more sharply.

    Raises:
        ValueError: if ``h`` is not adjacent to a parity position as described.
    """
    # Validate up front instead of asserting inside the sampling loop.
    if (h + 1) % chunk_size == 0:
        n_flip = 1  # last free token of the chunk
    elif h % chunk_size == 0:
        n_flip = 2  # last free token + parity token
    else:
        raise ValueError(
            f"h={h} is not adjacent to a parity position for chunk_size={chunk_size}"
        )

    model = models[h]
    deviation_probs = []
    # Inference only: avoid building autograd graphs inside the training loop.
    with torch.no_grad():
        for _ in range(batch_size):
            seq = piref.generate(condition_on_good = True)
            good_prefix = seq[:h]
            bad_prefix = good_prefix[:-n_flip] + [1 - tok for tok in good_prefix[-n_flip:]]
            good_p = model(one_hot_encode(good_prefix, h, device)).item()
            bad_p = model(one_hot_encode(bad_prefix, h, device)).item()
            deviation_probs.append(bad_p / (good_p + bad_p))
    return np.mean(deviation_probs)

def evaluate_model(piref, H, models, reward_func, batch_size, epoch, device=None):
    """Estimate the success rate of tokenwise sampling guided by the raw value models.

    Uses ``EstimatedValueUncorrected``, which queries the learned model even
    at full length ``H``. Samples are drawn with ``repeat_if_fail=False``. A
    sample where the sampler dead-ends (returns ``None``) counts as a failure
    (reward 0).

    Returns:
        Fraction of ``batch_size`` samples whose ``reward_func`` is 1.
    """
    estimated_value = EstimatedValueUncorrected(models, H, device)
    tw_sampler = TokenwiseSampler(piref, estimated_value, H)
    successes = 0
    for _ in range(batch_size):
        # sample() returns (sequence, steps). Scoring the whole tuple made
        # reward_func see a length-2 object and always report success.
        seq, _steps = tw_sampler.sample(False)
        r = 0 if seq is None else reward_func(seq)
        if epoch == 299:
            print(f"Epoch 299 sample - Reward: {r}, Sequence: {seq}")
        successes += r
    success_rate = successes / batch_size
    print(f"Model evaluation at epoch {epoch}: Success rate = {success_rate:.4f}", flush=True)
    return success_rate


def evaluate_conditional_loss(pred_value, value):
    """Mean squared error over the pairs whose target is strictly positive.

    Pairs are formed with ``zip``, so extra elements of the longer input are
    ignored. If no target is positive, numpy returns NaN for the empty mean.
    """
    kept = [(pred, target) for pred, target in zip(pred_value, value) if target > 0]
    return np.mean([(pred - target) ** 2 for pred, target in kept])

def _print_work_summary(label, work):
    """Print the mean elapsed time and mean step count of ``Sampler.test`` results."""
    avg_time = np.mean([t[0] for t in work])
    avg_steps = np.mean([t[1] for t in work])
    print(f"{label} average time: {avg_time:.4f} seconds; average steps: {avg_steps:.4f}")


def run_sampling_eval(piref, models, reward_func, H, n_eval_samples, device=None):
    """Benchmark the JS and tokenwise samplers using the learned value models.

    Both samplers share an ``EstimatedValue`` (true reward at full length,
    model prediction otherwise). The JS sampler is run first, then the
    tokenwise sampler, each for ``n_eval_samples`` trials.

    Returns:
        ``(js_work, tw_work)``: per-trial ``(seconds, steps)`` lists from
        ``Sampler.test``. Failed trials carry ``np.inf`` steps, so the printed
        average steps become inf whenever any trial failed.
    """
    estimated_value = EstimatedValue(models, reward_func, H, device)
    tw_sampler = TokenwiseSampler(piref, estimated_value, H)
    js_sampler = JSSampler(piref, estimated_value, H)

    js_work = js_sampler.test(n_eval_samples)
    _print_work_summary("JS Sampler", js_work)

    tw_work = tw_sampler.test(n_eval_samples)
    _print_work_summary("Tokenwise Sampler", tw_work)

    return js_work, tw_work
    

def train(chunk_size, n_chunks, batch_size = 128, epochs = 5000, learning_rate = 0.001, n_eval_samples = 100, eval_frequency = 100, first_eval_epoch = 0):
    """Fit one value network per prefix length on the parity task.

    Each prefix length ``h`` in ``1..H`` (``H = chunk_size * n_chunks``) gets
    its own ``ValueFunction`` and Adam optimizer. The network is regressed
    (MSE) onto the 0/1 reward of full rollouts drawn from ``ParityPolicy``.

    Every 10 epochs, ``estimate_deviation`` is recorded for each ``h`` next to
    a parity position. Every ``eval_frequency`` epochs (from
    ``first_eval_epoch`` on), both samplers are benchmarked with
    ``run_sampling_eval``.

    Args:
        chunk_size: tokens per chunk (the last one is the parity token).
        n_chunks: number of chunks.
        batch_size: rollouts per prefix length per epoch.
        epochs: number of training epochs.
        learning_rate: Adam learning rate.
        n_eval_samples: trials per sampler in each sampling evaluation.
        eval_frequency: evaluate sampling when ``epoch % eval_frequency == 0``.
        first_eval_epoch: skip sampling evaluation before this epoch.

    Returns:
        ``(epoch_list, js_work_list, tw_work_list, fig_x_data, fig_y_data)``:
        - ``epoch_list``: epochs at which sampling evaluation ran.
        - ``js_work_list`` / ``tw_work_list``: dicts ``epoch -> Sampler.test``
          results for the JS and tokenwise samplers.
        - ``fig_x_data``: epochs at which deviation was recorded.
        - ``fig_y_data``: dict ``h -> list of deviation values`` (empty for
          ``h`` not next to a parity position).
    """
    # Check if GPU is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}",flush=True)

    def reward_func(seq):
        return reward(seq, chunk_size)
    piref = ParityPolicy(chunk_size, n_chunks)
    H = chunk_size * n_chunks

    def val_func(prefix):
        """Exact value of ``prefix`` under ``piref``.

        Returns 0 if some chunk whose free tokens are all in ``prefix`` has odd
        parity. Otherwise returns ``2 ** -k``, where ``k`` is the number of
        chunks whose free tokens are not yet all present.
        """
        rem_choices = n_chunks
        for i in range(chunk_size-1, H, chunk_size):
            # i is a parity position; i <= len(prefix) means all of that
            # chunk's free tokens are already in the prefix.
            if i <= len(prefix):
                if sum(prefix[i+1-chunk_size:i])%2 == 1:
                    return 0
                rem_choices -= 1
        val = 2 ** (-rem_choices)
        return val

    models = {}
    optimizers = {}
    fig_x_data = []
    fig_y_data = {}
    for h in range(1, H+1):
        models[h] = ValueFunction(h).to(device)
        optimizers[h] = torch.optim.Adam(models[h].parameters(), lr=learning_rate)
        fig_y_data[h] = []
    epoch_list = []
    js_work_list = {}
    tw_work_list = {}

    print(f"Starting training with {epochs} epochs, chunk_size={chunk_size}, n_chunks={n_chunks}")
    print(f"Horizon: {H}, Batch size: {batch_size}, Learning rate: {learning_rate}")

    for epoch in range(epochs):
        print(f"Training epoch {epoch}", flush=True)

        # Monitor memory usage every 1000 epochs
        if epoch % 1000 == 0:
            if torch.cuda.is_available():
                gpu_memory_allocated = torch.cuda.memory_allocated() / 1024**3
                gpu_memory_cached = torch.cuda.memory_reserved() / 1024**3
                print(f"Epoch {epoch}: GPU memory allocated: {gpu_memory_allocated:.2f} GB")
                print(f"Epoch {epoch}: GPU memory cached: {gpu_memory_cached:.2f} GB")

        for h in range(1, H+1):
            # y_true (exact value from val_func) is not used for training here.
            X, y, y_true = generate_batch(piref, reward_func, val_func, h, batch_size, device=device)
            optimizers[h].zero_grad()
            preds = models[h](X)
            assert preds.requires_grad
            loss = F.mse_loss(preds, y)
            loss.backward()
            optimizers[h].step()
            # (h+1) % chunk_size == 1 is equivalent to h % chunk_size == 0:
            # h ends just before, or exactly on, a parity position.
            if (h+1)%chunk_size in [0,1]:
                if epoch%10 == 0:
                    fig_y_data[h].append(estimate_deviation(piref, chunk_size, models, h, 128, device))
        if epoch%10 == 0:
            fig_x_data.append(epoch)

            # Periodic memory cleanup to prevent fragmentation
            if torch.cuda.is_available() and epoch % 1000 == 0:
                torch.cuda.empty_cache()
                print(f"Epoch {epoch}: Memory cache cleared")
        if epoch%eval_frequency == 0 and epoch >= first_eval_epoch:
            print(f"Running evaluation at epoch {epoch}")
            js_work, tw_work = run_sampling_eval(piref, models, reward_func, H, n_eval_samples, device=device)
            js_work_list[epoch] = js_work
            tw_work_list[epoch] = tw_work
            epoch_list.append(epoch)
            print(f"Evaluation at epoch {epoch} completed")

    print("Training completed successfully")
    return epoch_list, js_work_list, tw_work_list, fig_x_data, fig_y_data



def save_training_results(filename, epoch_list, js_work_list, tw_work_list, train_progress_x, train_progress_y, chunk_size, n_chunks, epochs, first_eval_epoch):
    """Pickle the training artefacts together with the run configuration.

    The file holds a single dict with keys ``epoch_list``, ``js_work_list``,
    ``tw_work_list``, ``train_progress_x``, ``train_progress_y``,
    ``chunk_size``, ``n_chunks``, ``epochs`` and ``first_eval_epoch``.
    """
    payload = dict(
        epoch_list=epoch_list,
        js_work_list=js_work_list,
        tw_work_list=tw_work_list,
        train_progress_x=train_progress_x,
        train_progress_y=train_progress_y,
        chunk_size=chunk_size,
        n_chunks=n_chunks,
        epochs=epochs,
        first_eval_epoch=first_eval_epoch,
    )
    with open(filename, 'wb') as handle:
        pickle.dump(payload, handle)
    print(f"Training results saved to {filename}")

def load_training_results(filename):
    """Read back the pickled results dict written by ``save_training_results``.

    Security note: unpickling can execute arbitrary code, so only load files
    produced by a trusted run.
    """
    with open(filename, 'rb') as handle:
        loaded = pickle.load(handle)
    print(f"Training results loaded from {filename}")
    return loaded

if __name__ == "__main__":

    # ---- Command-line interface ----
    parser = argparse.ArgumentParser(description="Train Parity RL Model")
    parser.add_argument('--chunk_size', type=int, default=5, help='Size of each chunk (default: 5)')
    parser.add_argument('--n_chunks', type=int, default=8, help='Number of chunks (default: 8)')
    parser.add_argument('--epochs', type=int, default=10000, help='Number of training epochs (default: 10000)')
    parser.add_argument('--first_eval_epoch', type=int, default=1000, help='First evaluation epoch (default: 1000)')
    parser.add_argument('--eval_frequency', type=int, default=1000, help='How often to evaluate during training (default: 1000)')
    parser.add_argument('--output_name', type=str, help='Output filename for training results (auto-generated if not specified)')
    # NOTE(review): --verbose is parsed but not referenced anywhere in the visible code.
    parser.add_argument('--verbose', action='store_true', help='Enable verbose logging')
    args = parser.parse_args()

    # Derive the output stem from the problem size unless one was given explicitly.
    if args.output_name is not None:
        output_name = args.output_name
    else:
        output_name = f'parity_results_c{args.chunk_size}_n{args.n_chunks}'

    # Log command line arguments
    print(f"Command line arguments: {vars(args)}")
    print(f"Output name: {output_name}")

    results_path = output_name + '.pkl'
    try:
        (epoch_list, js_work_list, tw_work_list,
         train_progress_x, train_progress_y) = train(
            args.chunk_size,
            args.n_chunks,
            epochs=args.epochs,
            first_eval_epoch=args.first_eval_epoch,
            eval_frequency=args.eval_frequency,
        )

        save_training_results(
            results_path, epoch_list, js_work_list, tw_work_list,
            train_progress_x, train_progress_y,
            args.chunk_size, args.n_chunks, args.epochs, args.first_eval_epoch,
        )

        print("=== Training completed successfully ===")
        print(f"Results saved to: {results_path}")

    except Exception as e:
        import traceback
        print(f"Training failed with error: {e}")
        print(f"Error details: {type(e).__name__}: {str(e)}")
        print(f"Traceback: {traceback.format_exc()}")
        raise
