import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

#sys.path.append('src')
from dyck import RandomWalkDyck
from lm import conditional_nn_generate, nn_intermediates
from utils import dyck_reward, one_hot_encode
from value_functions import LightValueFunction, ValueFunction
from config_for_dyck import config

import pickle
import argparse

# Size of a single feature vector handed to LightValueFunction in
# train_value_functions_from_reps (multiplied by B+1 when all positions are
# used).  Presumably the hidden width of the language model -- TODO confirm.
NN_FEATURE_SIZE = 512 # TODO: consider moving this into config

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train_value_functions_from_reps(model, train_data, A, B, save_dir, checkpoints, batch_size=32, all_pos=False, steps=None, layer = -1, h_test=16):
    """Fit one ``LightValueFunction`` per horizon on representations from ``model``.

    For every horizon ``h`` in ``steps`` the sequences of ``train_data[h-1]``
    are passed through ``nn_intermediates`` and the resulting features are
    regressed onto the paired rewards with an MSE loss (AdamW, lr=0.003,
    weight_decay=0.1, 40 epochs).  The first ``int(0.9 * num_batches)``
    batches are used for training and the remaining samples for validation.
    After every epoch the loss and the accuracy of the predictions
    thresholded at 0.5 are printed for the monitored horizon.

    Args:
        model: Forwarded unchanged to ``nn_intermediates``.
        train_data: Indexable collection; ``train_data[h-1]`` is an iterable
            of ``(sequence, reward)`` pairs.  The sample count is taken from
            ``train_data[0]`` and is assumed to be the same for every horizon.
        A: Number of leading positions dropped from each representation when
            ``all_pos`` is set.
        B: Largest horizon; the default ``steps`` is ``range(1, B+1)`` and the
            ``all_pos`` input width is ``NN_FEATURE_SIZE * (B+1)``.
        save_dir: Currently unused (checkpoint writing is disabled).
        checkpoints: Currently unused (checkpoint writing is disabled).
        batch_size: Mini-batch size.
        all_pos: If true, use the flattened representations of all positions
            from index ``A`` on; otherwise use what ``nn_intermediates``
            returns by default.
        steps: Horizons to train.  Defaults to ``range(1, B+1)``.
        layer: Forwarded to ``nn_intermediates``.
        h_test: Horizon whose metrics are printed after every epoch.  When it
            is not one of ``steps`` the largest trained horizon is used.

    Returns:
        dict mapping each horizon in ``steps`` to its trained value function.
    """
    steps = list(range(1, B + 1)) if steps is None else list(steps)
    if not steps:
        return {}
    if h_test not in steps:
        # A fixed monitoring horizon would fail with KeyError only after all
        # features had been extracted; pick a horizon that is actually trained.
        fallback = max(steps)
        print(f"Horizon {h_test} is not trained; monitoring horizon {fallback} instead", flush=True)
        h_test = fallback

    input_size = NN_FEATURE_SIZE * (B + 1) if all_pos else NN_FEATURE_SIZE
    vf_list = {}
    optimizer_list = {}
    for h in steps:
        vf_list[h] = LightValueFunction(input_size).to(device)
        optimizer_list[h] = optim.AdamW(vf_list[h].parameters(), lr=0.003, weight_decay=0.1)

    num_samples = len(train_data[0])
    num_batches = (num_samples + batch_size - 1) // batch_size  # ceil division

    # Features and targets stay on the CPU; batches are moved to `device`
    # only when they are consumed.
    Xtrain = {}
    Ytrain = {}
    for h in steps:
        sequences, rewards = zip(*train_data[h-1])
        if all_pos:
            intermediate_reps = nn_intermediates(model, sequences, positions='all', layer=layer)
            print(np.array(intermediate_reps).shape)
            flattened_reps = [np.array(rep)[A:].flatten() for rep in intermediate_reps]
            Xtrain[h] = torch.tensor(np.array(flattened_reps), dtype=torch.float32)
            print(Xtrain[h].shape, flush=True)
        else:
            Xtrain[h] = torch.tensor(np.array(nn_intermediates(model, sequences, layer=layer)), dtype=torch.float32)
        Ytrain[h] = torch.tensor(np.array(rewards), dtype=torch.float32)

    # 90/10 split at batch granularity: whole batches go to training, the
    # remainder (including a trailing partial batch) is held out.
    num_train_batches = int(0.9 * num_batches)
    num_train_samples = batch_size * num_train_batches
    num_val_samples = num_samples - num_train_samples

    loss_fn = nn.MSELoss()
    monitored = vf_list[h_test]

    def report(epoch, label, inputs, targets, count):
        # Print loss and 0.5-thresholded accuracy of the monitored horizon.
        targets = targets.to(device)
        preds = monitored(inputs.to(device))
        loss = loss_fn(preds, targets)
        # Tensor reduction along dim 0 matches builtin sum() over the tensor
        # without a Python-level loop over every element.
        accuracy = ((preds >= 0.5) == targets).sum(dim=0) / count
        print(f"Epoch {epoch}, horizon {h_test}: {label} loss: {loss.item():.4f}, accuracy: {accuracy}", flush=True)

    for epoch in range(40):
        for i in range(num_train_batches):
            start_idx = i * batch_size
            end_idx = min((i+1) * batch_size, num_train_samples)
            for h in steps:
                optimizer_list[h].zero_grad()
                preds = vf_list[h](Xtrain[h][start_idx:end_idx].to(device))
                loss = loss_fn(preds, Ytrain[h][start_idx:end_idx].to(device))
                loss.backward()
                optimizer_list[h].step()

        monitored.eval()
        with torch.no_grad():
            report(epoch, "training", Xtrain[h_test][:num_train_samples], Ytrain[h_test][:num_train_samples], num_train_samples)
            report(epoch, "validation", Xtrain[h_test][num_train_samples:], Ytrain[h_test][num_train_samples:], num_val_samples)
        monitored.train()

    return vf_list


def evaluate_value_functions_denoised(model, train_data, A, B, load_dir, vocab_size, batch_size=32):
    """Compare the epoch-1 value functions against the epoch-40 ones.

    Loads ``e1.pkl`` and ``e40.pkl`` from ``load_dir``, evaluates every
    horizon ``1..B+1`` on the training split (the first
    ``batch_size * int(0.9 * num_batches)`` samples) and reports, per
    horizon, the MSE between the two checkpoints' predictions.  The list of
    results is pickled to ``<load_dir>/error_to_last.pkl``.

    Args:
        model: Unused; kept for signature parity with the other entry points.
        train_data: Indexable collection; ``train_data[h-1]`` is an iterable
            of ``(sequence, reward)`` pairs.  The sample count is taken from
            ``train_data[0]``.
        A: Prefix length; sequences of horizon ``h`` are one-hot encoded with
            length ``1 + A + h``.
        B: Horizons ``1..B+1`` are evaluated.
        load_dir: Directory holding the checkpoints and receiving the output.
        vocab_size: Forwarded to ``one_hot_encode``.
        batch_size: Batch size used to derive the train/validation boundary.
    """
    first_ckpt, last_ckpt = 'e1.pkl', 'e40.pkl'
    horizons = range(1, B + 2)

    def encode_training_split():
        # The inputs do not depend on the checkpoint, so build them once
        # instead of re-encoding the dataset for every file.
        num_samples = len(train_data[0])
        num_batches = (num_samples + batch_size - 1) // batch_size  # ceil division
        num_train_samples = batch_size * int(0.9 * num_batches)
        inputs = {}
        for h in horizons:
            sequences, _ = zip(*train_data[h-1])
            encoded = [one_hot_encode(seq, 1+A+h, vocab_size) for seq in sequences]
            inputs[h] = torch.tensor(np.array(encoded), dtype=torch.float32).to(device)[:num_train_samples]
        return inputs

    Xtrain = None
    preds = {}
    for fname in (first_ckpt, last_ckpt):
        load_path = load_dir + "/" + fname
        print(f"Evaluating {load_path}")
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # checkpoints produced by a trusted training run.
        with open(load_path, 'rb') as f:
            vf_list = pickle.load(f)
        if Xtrain is None:
            # Deferred until the first checkpoint opened successfully so a
            # missing file still fails before the encoding work is done.
            Xtrain = encode_training_split()
        preds[fname] = {}
        with torch.no_grad():
            for h in horizons:
                vf_list[h].eval()
                preds[fname][h] = vf_list[h](Xtrain[h]).detach().cpu()

    mse = nn.MSELoss()
    errors = []
    for h in horizons:
        error_to_last = mse(preds[first_ckpt][h], preds[last_ckpt][h])
        print(f"Horizon {h}, error {error_to_last:.4f}")
        errors.append({'epoch': first_ckpt, 'horizon': h, 'error_to_last': error_to_last})

    output_path = load_dir + "/" + "error_to_last.pkl"
    print(f"Saving errors to {output_path}")
    with open(output_path, "wb") as f:
        pickle.dump(errors, f)

def evaluate_value_functions(model, train_data, A, B, load_dir, vocab_size, batch_size=32, num_epochs=10):
    """Compute MSE and positive-conditioned MSE for saved epoch checkpoints.

    For every checkpoint ``e1.pkl`` .. ``e{num_epochs}.pkl`` in ``load_dir``
    and every horizon ``1..B+1`` this computes, on both the training split
    (first ``batch_size * int(0.9 * num_batches)`` samples) and the
    validation split (the rest), the mean squared error and the mean squared
    error restricted to samples whose reward equals 1.  Validation numbers
    are printed; all four numbers are collected and pickled to
    ``<load_dir>/cond_mse.pkl``.

    Args:
        model: Unused; kept for signature parity with the other entry points.
        train_data: Indexable collection; ``train_data[h-1]`` is an iterable
            of ``(sequence, reward)`` pairs.  The sample count is taken from
            ``train_data[0]``.
        A: Prefix length; sequences of horizon ``h`` are one-hot encoded with
            length ``1 + A + h``.
        B: Horizons ``1..B+1`` are evaluated.
        load_dir: Directory holding the checkpoints and receiving the output.
        vocab_size: Forwarded to ``one_hot_encode``.
        batch_size: Batch size used to derive the train/validation boundary.
        num_epochs: Number of epoch checkpoints to evaluate (default 10).
    """
    horizons = range(1, B + 2)

    def build_splits():
        # The encoded dataset is identical for every checkpoint, so it is
        # built once rather than once per epoch file.
        num_samples = len(train_data[0])
        num_batches = (num_samples + batch_size - 1) // batch_size  # ceil division
        num_train_samples = batch_size * int(0.9 * num_batches)
        train_split, val_split = {}, {}
        for h in horizons:
            sequences, rewards = zip(*train_data[h-1])
            encoded = [one_hot_encode(seq, 1+A+h, vocab_size) for seq in sequences]
            inputs = torch.tensor(np.array(encoded), dtype=torch.float32).to(device)
            targets = torch.tensor(np.array(rewards), dtype=torch.float32).to(device)
            train_split[h] = (inputs[:num_train_samples], targets[:num_train_samples])
            val_split[h] = (inputs[num_train_samples:], targets[num_train_samples:])
        return train_split, val_split

    def mse_and_cond_mse(vf, inputs, targets):
        # Returns (MSE over all samples, MSE over samples with reward == 1).
        # The conditional value is NaN when no sample has reward 1.
        preds = vf(inputs).detach().cpu()
        true_rewards = targets.cpu()
        mse = ((preds - true_rewards)**2).mean()
        mask = (true_rewards==1)
        cond_mse = ((preds[mask] - true_rewards[mask])**2).mean()
        return mse, cond_mse

    splits = None
    losses = []
    for e in range(1, num_epochs + 1):
        fname = f"e{e}.pkl"
        load_path = load_dir + "/" + fname
        print(f"Evaluating {load_path}")
        # NOTE: pickle.load can execute arbitrary code; only point this at
        # checkpoints produced by a trusted training run.
        with open(load_path, 'rb') as f:
            vf_list = pickle.load(f)
        if splits is None:
            # Deferred until the first checkpoint opened successfully so a
            # missing file still fails before the encoding work is done.
            splits = build_splits()
        train_split, val_split = splits
        with torch.no_grad():
            for h in horizons:
                vf_list[h].eval()
                train_mse, train_cond_mse = mse_and_cond_mse(vf_list[h], *train_split[h])
                val_mse, val_cond_mse = mse_and_cond_mse(vf_list[h], *val_split[h])
                print(f"Epoch {fname}, horizon {h}: validation loss: {val_mse:.4f}, validation cond loss: {val_cond_mse:.4f}",flush=True)
                losses.append({'epoch': fname, 'horizon': h, 'train_mse': train_mse, 'train_cond_mse': train_cond_mse, 'val_mse': val_mse, 'val_cond_mse': val_cond_mse})

    output_path = load_dir + "/" + "cond_mse.pkl"
    print(f"Saving MSEs to {output_path}")
    with open(output_path, "wb") as f:
        pickle.dump(losses, f)

        

def train_value_functions_from_scratch(model, train_data, A, B, save_dir, num_epochs, vocab_size, batch_size=32, intermediate_checkpoint_period=10000, eval_acc = True):
    """Train one ``ValueFunction`` per horizon on one-hot encoded sequences.

    Horizons ``1..B+1`` are trained jointly, batch by batch, with an MSE loss
    (AdamW, lr=0.003, weight_decay=0.1).  The first ``int(0.9 * num_batches)``
    batches form the training split; the remaining samples are held out for
    validation.  The dict of all value functions is pickled to
    ``<save_dir>/e<k>.pkl`` after epoch ``k`` (``e0.pkl`` holds the untrained
    models) and to ``<save_dir>/e<epoch>n<end_idx>.pkl`` whenever the number
    of processed samples crosses a multiple of
    ``intermediate_checkpoint_period``.

    Training resumes from the last checkpoint of the unbroken run
    ``e0.pkl, e1.pkl, ...`` found in ``save_dir``.  Only the model weights are
    restored; the optimizers start from a fresh state.

    Args:
        model: Unused; kept for signature parity with the other entry points.
        train_data: Indexable collection; ``train_data[h-1]`` is an iterable
            of ``(sequence, reward)`` pairs.  The sample count is taken from
            ``train_data[0]``.
        A: Prefix length; sequences of horizon ``h`` are one-hot encoded with
            length ``1 + A + h``.
        B: Horizons ``1..B+1`` are trained.
        save_dir: Directory for checkpoints (created on demand).
        num_epochs: Total number of epochs to reach.
        vocab_size: Forwarded to ``one_hot_encode`` and ``ValueFunction``.
        batch_size: Mini-batch size.
        intermediate_checkpoint_period: Sample interval between
            intra-epoch checkpoints.
        eval_acc: If true, print loss and 0.5-thresholded accuracy on both
            splits after every epoch.
    """
    horizons = range(1, B + 2)

    def save_checkpoint(path, intermediate=False):
        # Pickle the complete horizon -> model dict to `path`.
        label = "intermediate checkpoint" if intermediate else "checkpoint"
        print(f"Saving {label} to {path}", flush=True)
        os.makedirs(save_dir, exist_ok=True)
        with open(path, 'wb') as f:
            pickle.dump(vf_list, f)

    # Find the last epoch checkpoint of the unbroken run e0, e1, ...
    last_saved_epoch = -1
    for epoch in range(num_epochs + 1):
        if not os.path.exists(f"{save_dir}/e{epoch}.pkl"):
            break
        last_saved_epoch = epoch

    vf_list = {}
    if last_saved_epoch > -1:
        # NOTE: pickle.load can execute arbitrary code; only resume from
        # checkpoints written by a trusted run.
        with open(f"{save_dir}/e{last_saved_epoch}.pkl", 'rb') as f:
            vf_list = pickle.load(f)
        for h in horizons:
            # The data below always lives on `device`; make sure restored
            # models do too (e.g. checkpoint written on a CPU-only machine).
            vf_list[h] = vf_list[h].to(device)
    else:
        for h in horizons:
            vf_list[h] = ValueFunction(1+A+h, vocab_size).to(device)  # input length: BOS + A + h

    optimizer_list = {
        h: optim.AdamW(vf_list[h].parameters(), lr=0.003, weight_decay=0.1)
        for h in horizons
    }

    num_samples = len(train_data[0])
    num_batches = (num_samples + batch_size - 1) // batch_size  # ceil division
    Xtrain = {}
    Ytrain = {}
    for h in horizons:
        sequences, rewards = zip(*train_data[h-1])
        encoded = [one_hot_encode(seq, 1+A+h, vocab_size) for seq in sequences]
        Xtrain[h] = torch.tensor(np.array(encoded), dtype=torch.float32).to(device)
        Ytrain[h] = torch.tensor(np.array(rewards), dtype=torch.float32).to(device)

    # 90/10 split at batch granularity: whole batches go to training, the
    # remainder (including a trailing partial batch) is held out.
    num_train_batches = int(0.9 * num_batches)
    num_train_samples = batch_size * num_train_batches
    num_val_samples = num_samples - num_train_samples

    loss_fn = nn.MSELoss()

    def report(epoch, h, label, inputs, targets, count):
        # Print loss and 0.5-thresholded accuracy of horizon `h` on one split.
        preds = vf_list[h](inputs)
        loss = loss_fn(preds, targets)
        # Tensor reduction along dim 0 matches builtin sum() over the tensor
        # without a Python-level loop (and device sync) per element.
        accuracy = ((preds >= 0.5) == targets).sum(dim=0) / count
        print(f"Epoch {epoch}, horizon {h}: {label} loss: {loss.item():.4f}, accuracy: {accuracy}", flush=True)

    if last_saved_epoch == -1:
        # Save the untrained models as the epoch-0 checkpoint.
        save_checkpoint(f"{save_dir}/e0.pkl")

    for epoch in range(max(last_saved_epoch, 0), num_epochs):
        last_intermediate_checkpoint = 0
        for i in range(num_train_batches):
            start_idx = i * batch_size
            end_idx = min((i+1) * batch_size, num_train_samples)
            for h in horizons:
                optimizer_list[h].zero_grad()
                preds = vf_list[h](Xtrain[h][start_idx:end_idx])
                loss = loss_fn(preds, Ytrain[h][start_idx:end_idx])
                loss.backward()
                optimizer_list[h].step()
            # Checkpoint whenever the sample counter enters a new period.
            if end_idx // intermediate_checkpoint_period > last_intermediate_checkpoint // intermediate_checkpoint_period:
                save_checkpoint(f"{save_dir}/e{epoch}n{end_idx}.pkl", intermediate=True)
                last_intermediate_checkpoint = end_idx

        if eval_acc:
            with torch.no_grad():
                for h in horizons:
                    vf_list[h].eval()
                    report(epoch, h, "training", Xtrain[h][:num_train_samples], Ytrain[h][:num_train_samples], num_train_samples)
                    report(epoch, h, "validation", Xtrain[h][num_train_samples:], Ytrain[h][num_train_samples:], num_val_samples)
                    vf_list[h].train()

        save_checkpoint(f"{save_dir}/e{epoch+1}.pkl")

if __name__ == "__main__":
    # Command-line entry point: train value functions on a pickled dataset,
    # or (with --evaluate_existing) score checkpoints of an earlier run.
    parser = argparse.ArgumentParser()
    # The four arguments below are always consumed, so fail at parse time
    # with a usage message instead of crashing later on a None value.
    parser.add_argument("--model_path", type=str, required=True,
                        help="pickled model file")
    parser.add_argument("--train_path", type=str, required=True,
                        help="pickled training data; also determines the save directory")
    parser.add_argument("--prefix_length", type=int, required=True,
                        help="prefix length (passed as A)")
    parser.add_argument("--total_length", type=int, required=True,
                        help="total length; total_length - prefix_length is passed as B")
    parser.add_argument("--num_epochs", type=int, default=10)
    parser.add_argument("--checkpoint_period", type=int, default=10000)
    parser.add_argument("--evaluate_existing", action='store_true')
    args = parser.parse_args()

    # NOTE: pickle.load can execute arbitrary code; only pass files that
    # come from a trusted source.
    with open(args.model_path, 'rb') as f:
        model = pickle.load(f)

    with open(args.train_path, 'rb') as f:
        training_data = pickle.load(f)

    # Checkpoints live next to the dataset: swap the "value_datasets" path
    # component for "trained_values" and drop the file extension.
    save_dir = args.train_path.replace("value_datasets", "trained_values")
    save_dir = os.path.splitext(save_dir)[0]

    print("Save directory",save_dir)

    vocab_size = 2 * config['num_types'] + 4
    horizon_length = args.total_length - args.prefix_length

    if args.evaluate_existing:
        print("Evaluating existing value functions...")
        evaluate_value_functions(model, training_data, args.prefix_length, horizon_length, save_dir, vocab_size)
    else:
        print("Training new value functions...")
        train_value_functions_from_scratch(model, training_data, args.prefix_length, horizon_length, save_dir, args.num_epochs, vocab_size, intermediate_checkpoint_period = args.checkpoint_period)



