import hydra
from omegaconf import OmegaConf

import sys
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random

from transformers import AutoModelForCausalLM, LlamaForCausalLM, AutoTokenizer, DynamicCache
#sys.path.append('src')
from lm import conditional_nn_generate, nn_intermediates
from value_functions import ValueFunction, NormalizedValueFunction
from torch.utils.data import TensorDataset, DataLoader
from utils import get_free_device
import pickle
import argparse
from dataset_utils import load_dataset


def compute_rep_from_seq(token_ids, model, layer, which_tokens, mean_range, device):
    """Compute the pooled hidden representation of a single token sequence.

    Wraps ``nn_intermediates`` for the one-sequence case: the sequence is
    submitted as a batch of size one and the first (only) result is returned
    as a NumPy array.

    Args:
        token_ids: Token ids of the sequence to encode.
        model: Language model handed to ``nn_intermediates``.
        layer: Forwarded as ``layer=``.
        which_tokens: Forwarded as ``positions=``.
        mean_range: Forwarded as ``mean_range=``.
        device: Forwarded as ``device=``.

    Returns:
        ``np.ndarray`` built from the first element of the
        ``nn_intermediates`` result.
    """
    # NOTE(review): the exact semantics of positions/layer/mean_range are
    # defined by lm.nn_intermediates and are not visible here -- confirm there.
    batch_reps = nn_intermediates(
        model,
        [token_ids],
        batch_size=1,
        positions=which_tokens,
        layer=layer,
        mean_range=mean_range,
        device=device,
    )
    first_rep = batch_reps[0]
    return np.array(first_rep)


def compute_pool_range(prompt_ids, all_ids, which_tokens):
    """Return the range of positions in ``all_ids`` to pool over.

    The window always ends at the last position of ``all_ids`` and never
    reaches back into the prompt: its start is clamped to ``len(prompt_ids)``.

    Args:
        prompt_ids: Token ids of the prompt (only its length is used).
        all_ids: Token ids of the prompt followed by the generation so far
            (only its length is used).
        which_tokens: Pooling spec. ``'last'`` selects a window of one
            position, ``'mean'`` a window as long as ``all_ids``, and
            ``'mean<K>'`` (e.g. ``'mean8'``) a window of at most ``K``
            positions.

    Returns:
        ``range(low, high)`` with ``high == len(all_ids)``.

    Raises:
        ValueError: If ``which_tokens`` is not one of the supported specs, or
            the suffix of ``'mean<K>'`` is not an integer.
    """
    total_length = len(all_ids)

    if which_tokens == 'last':
        max_window_length = 1
    elif which_tokens == 'mean':
        max_window_length = total_length
    elif isinstance(which_tokens, str) and which_tokens.startswith('mean'):
        max_window_length = int(which_tokens[4:])
    else:
        # An explicit raise (rather than `assert False`) keeps the check alive
        # under `python -O`, where asserts are stripped and the code would
        # otherwise fail later with an unrelated UnboundLocalError.
        raise ValueError(
            f"Unsupported pooling spec {which_tokens!r}; "
            "expected 'last', 'mean' or 'mean<K>'"
        )

    low = max(len(prompt_ids), total_length - max_window_length)
    return range(low, total_length)

def _load_or_compute_reps(model, raw_data, reps_save_path, cfg_rep, device):
    """Return ``(X, Y)`` rep/reward data, loading the cache at ``reps_save_path`` if present."""
    if os.path.exists(reps_save_path):
        print(f"Rep data already exists; loading rep data from {reps_save_path}")
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # reps_save_path at files produced by this script.
        with open(reps_save_path, 'rb') as f:
            X, Y = pickle.load(f)

        # Sanity check on (up to) the first three samples: recompute the rep at
        # every candidate position and print its distance to the cached rep.
        # Bounded by the data actually available so tiny datasets do not crash.
        for i in range(min(3, len(raw_data), len(X))):
            d = raw_data[i]
            print("Index", i)
            prompt_tokens = d['prompt_ids']
            all_tokens = d['output_ids']
            assert all_tokens[:len(prompt_tokens)] == prompt_tokens
            for pos in range(len(prompt_tokens) + 1, len(all_tokens) + 1):
                pool_range = compute_pool_range(prompt_tokens, all_tokens[:pos], cfg_rep.tokens)
                hidden_rep = compute_rep_from_seq(all_tokens[:pos], model, cfg_rep.layer_index, cfg_rep.tokens, pool_range, device)
                print(f"Position {pos}, norm of difference in reps {np.linalg.norm(hidden_rep - np.array(X[i]))}")
        return X, Y

    print("Rep data does not already exist; computing rep data from raw data")
    X = []
    Y = []
    for d in raw_data:
        prompt_tokens = d['prompt_ids']
        all_tokens = d['output_ids']
        assert all_tokens[:len(prompt_tokens)] == prompt_tokens
        # One uniformly random truncation point per sample, always keeping at
        # least one generated token after the prompt.
        pos = random.choice(range(len(prompt_tokens) + 1, len(all_tokens) + 1))
        mean_range = compute_pool_range(prompt_tokens, all_tokens[:pos], cfg_rep.tokens)
        hidden_rep = compute_rep_from_seq(all_tokens[:pos], model, cfg_rep.layer_index, cfg_rep.tokens, mean_range, device)

        X.append(hidden_rep)
        Y.append(d['reward'])
        if len(X) % 100 == 0:
            print(f"Computed reps for {len(X)} generations", flush=True)

    X = torch.tensor(np.array(X), dtype=torch.float32)
    Y = torch.tensor(np.array(Y), dtype=torch.float32)

    print(f"Saving rep data to {reps_save_path}")
    reps_dir = os.path.dirname(reps_save_path)
    if reps_dir:
        # dirname is '' for a bare filename, and os.makedirs('') raises.
        os.makedirs(reps_dir, exist_ok=True)
    with open(reps_save_path, 'wb') as f:
        pickle.dump((X, Y), f)
    return X, Y


def _evaluate_value_function(vf, loader, loss_fn, device):
    """Return ``(mean loss over the loader, MSE restricted to targets equal to 1)``."""
    total_loss = 0.0
    cond_sq_err_sum = 0.0
    cond_count = 0
    vf.eval()
    with torch.no_grad():
        for Xb, Yb in loader:
            targets = Yb.to(device)
            preds = vf(Xb.to(device))
            total_loss += loss_fn(preds, targets).item() * Xb.size(0)
            # Accumulate squared errors and their count instead of per-batch
            # means: a batch without any reward == 1 sample has an empty
            # selection whose mean is NaN, which would poison the whole epoch.
            mask = (targets == 1)
            sq_err = (preds[mask] - targets[mask]) ** 2
            cond_sq_err_sum += sq_err.sum().item()
            cond_count += sq_err.numel()
    vf.train()
    avg_loss = total_loss / len(loader.dataset)
    cond_mse = cond_sq_err_sum / cond_count if cond_count else float('nan')
    return avg_loss, cond_mse


def _save_vf_checkpoint(vf, vf_save_dir, epoch_index, label):
    """Pickle ``vf`` to ``<vf_save_dir>/e<epoch_index>.pkl``, creating the directory if needed."""
    save_path = vf_save_dir + "/" + "e" + str(epoch_index) + ".pkl"
    print(f"Saving {label} to {save_path}", flush=True)
    os.makedirs(vf_save_dir, exist_ok=True)
    with open(save_path, 'wb') as f:
        pickle.dump(vf, f)


def train_value_function_from_reps(model, tokenizer, raw_data, reps_save_path, vf_save_dir, model_dim, cfg_rep, cfg_vf, device="cuda"):
    """Train a value function that regresses reward from LM hidden representations.

    Representations are loaded from ``reps_save_path`` when that file exists;
    otherwise they are computed from ``raw_data`` (one random truncation point
    per sample) and cached there. The data is split in order, without
    shuffling, into a training prefix of
    ``batch_size * int(0.9 * ceil(n / batch_size))`` samples and a held-out
    remainder. The value function is trained with AdamW on MSE loss and pickled
    to ``<vf_save_dir>/e<k>.pkl`` before training (``k == 0``) and after every
    epoch (``k == epoch + 1``).

    Args:
        model: Language model used to compute hidden representations.
        tokenizer: Unused; kept for interface compatibility.
        raw_data: Sequence of dicts with keys ``'prompt_ids'``,
            ``'output_ids'`` (which must start with ``prompt_ids``) and
            ``'reward'``.
        reps_save_path: Pickle file holding the cached ``(X, Y)`` pair.
        vf_save_dir: Directory that receives the checkpoints.
        model_dim: First constructor argument of the value-function class.
        cfg_rep: Config providing ``tokens`` and ``layer_index``.
        cfg_vf: Config providing ``class_name``, ``hidden_dim``,
            ``num_layers``, ``lr``, ``wd``, ``batch_size`` and ``num_epochs``.
        device: Torch device for the value function and the batches.

    Raises:
        ValueError: If there are not more than ``batch_size`` samples, which
            would leave the training split empty.
    """
    print(cfg_rep)
    print(cfg_vf)
    # Guard against mixing "testing" and real artifacts across the two paths.
    assert ("testing" in reps_save_path) == ("testing" in vf_save_dir)
    # NOTE(security): eval() runs whatever expression the config supplies; this
    # is only acceptable because the config is trusted. Consider an explicit
    # name -> class mapping if configs can ever come from outside.
    ValueFunctionClass = eval(cfg_vf.class_name)
    vf = ValueFunctionClass(model_dim, cfg_vf.hidden_dim, cfg_vf.num_layers).to(device)  # BOS + A + h
    optimizer = optim.AdamW(vf.parameters(), lr=cfg_vf.lr, weight_decay=cfg_vf.wd)

    X, Y = _load_or_compute_reps(model, raw_data, reps_save_path, cfg_rep, device)

    batch_size = cfg_vf.batch_size
    num_epochs = cfg_vf.num_epochs

    # Size the split from the representations actually in hand: a loaded cache
    # need not have the same length as raw_data.
    num_samples = len(X)
    print("Num samples:", num_samples)
    num_batches = num_samples // batch_size
    if num_batches * batch_size < num_samples:
        num_batches += 1
    num_train_batches = int(0.9 * num_batches)
    num_train_samples = batch_size * num_train_batches
    if num_train_samples == 0:
        raise ValueError(
            f"Cannot build a training split from {num_samples} samples with "
            f"batch_size={batch_size}; more than batch_size samples are required"
        )

    train_dataset = TensorDataset(X[:num_train_samples], Y[:num_train_samples])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    test_dataset = TensorDataset(X[num_train_samples:], Y[num_train_samples:])
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    _save_vf_checkpoint(vf, vf_save_dir, 0, "initial checkpoint")

    loss_fn = nn.MSELoss()
    for epoch in range(num_epochs):
        total_training_loss = 0
        for Xb, Yb in train_loader:
            optimizer.zero_grad()
            preds = vf(Xb.to(device))
            loss = loss_fn(preds, Yb.to(device))
            # Weight by batch size so the epoch average is per-sample even
            # when the final batch is smaller.
            total_training_loss += loss.item() * Xb.size(0)
            loss.backward()
            optimizer.step()
        avg_training_loss = total_training_loss / len(train_loader.dataset)
        print(f"Epoch {epoch}: training loss: {avg_training_loss:.4f}")

        avg_test_loss, avg_test_cond = _evaluate_value_function(vf, test_loader, loss_fn, device)
        print(f"Epoch {epoch}: validation loss: {avg_test_loss:.4f}, conditional MSE: {avg_test_cond}", flush=True)

        _save_vf_checkpoint(vf, vf_save_dir, epoch + 1, "checkpoint")

def resolve_config_name(argv, default="main_codellama"):
    """Return the Hydra config name requested in ``argv``, or ``default``.

    Recognises the spellings ``--config-name NAME``, ``--config-name=NAME``,
    ``-cn NAME`` and ``-cn=NAME``. The first occurrence wins. A flag given
    without a value yields ``default`` so that Hydra, not this module, reports
    the malformed command line.

    Args:
        argv: Command-line arguments, typically ``sys.argv``.
        default: Name returned when no config-name flag is present.
    """
    flags = ("--config-name", "-cn")
    for index, arg in enumerate(argv):
        if arg in flags:
            if index + 1 < len(argv):
                return argv[index + 1]
            return default
        for flag in flags:
            if arg.startswith(flag + "="):
                return arg[len(flag) + 1:]
    return default


# Resolved at import time because the @hydra.main decorator below needs the
# name before Hydra has parsed the command line itself.
CONFIG_NAME = resolve_config_name(sys.argv)

@hydra.main(config_path='../hydra_configs', config_name=CONFIG_NAME, version_base=None)
def main(cfg):
    """Hydra entry point: load the LM and dataset, then train the value function.

    Augments ``cfg`` with a ``meta`` node (original working directory, run
    directory, chosen device), seeds ``torch`` and ``random`` from
    ``cfg.seed``, loads the model named by ``cfg.model.name`` and the
    generations at ``cfg.fs.generation_save_path``, and hands everything to
    ``train_value_function_from_reps``.
    """
    # Struct mode must be off before new keys can be attached to cfg.
    OmegaConf.set_struct(cfg, False)
    cfg.meta = OmegaConf.create({})
    cfg.meta.original_dir = hydra.utils.get_original_cwd()
    cfg.meta.run_dir = os.getcwd()

    if not torch.cuda.is_available():
        cfg.meta.device = "cpu"
    else:
        # Choose the GPU reporting the most free memory (first one on ties).
        free_bytes = [
            torch.cuda.mem_get_info(gpu_index)[0]
            for gpu_index in range(torch.cuda.device_count())
        ]
        roomiest_gpu = free_bytes.index(max(free_bytes))
        cfg.meta.device = f"cuda:{roomiest_gpu}"
    print(cfg)

    torch.manual_seed(cfg.seed)
    random.seed(cfg.seed)

    # 'hp' selects float16 weights; any other precision value means float32.
    dtype = torch.float16 if cfg.model.precision == 'hp' else torch.float32

    # Config names containing "llama" use the dedicated Llama class.
    model_cls = LlamaForCausalLM if "llama" in CONFIG_NAME else AutoModelForCausalLM
    model = model_cls.from_pretrained(cfg.model.name, torch_dtype=dtype).to(cfg.meta.device)
    model.eval()

    # NOTE(review): torch_dtype is a model-loading option; it is unclear
    # whether the tokenizer does anything with it -- confirm before removing.
    tokenizer = AutoTokenizer.from_pretrained(cfg.model.name, torch_dtype=dtype)

    raw_data = load_dataset(cfg.fs.generation_save_path)

    train_value_function_from_reps(
        model,
        tokenizer,
        raw_data,
        cfg.fs.reps_save_path,
        cfg.fs.vf_save_dir,
        cfg.model.dim,
        cfg.rep,
        cfg.vf,
        device=cfg.meta.device,
    )





# Run the Hydra-decorated entry point only when executed as a script, so that
# importing this module does not start a training run.
if __name__ == "__main__":
    main()



