import os
import json
import numpy as np
import torch
from torch.nn import functional as F
import torch.nn as nn
import pdb
import wandb
from copy import deepcopy
import argparse

# ---------------------------------------------------------------------------
# Command-line options and run configuration.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--pretrain', type=str, default='NTP', help='NTP or state')
parser.add_argument('--architecture', type=str, default='transformer', help='rnn lstm transformer mamba mamba2')
args = parser.parse_args()

pretrain = args.pretrain  # 'NTP' or 'state'
architecture = args.architecture  # mamba or mamba2 or transformer or rnn or lstm

wandb_log = True

# The data directory is derived from the dataset name and the two CLI options.
dataset = 'gridworld/5-states-restrict-moves'
data_dir = f'data/{dataset}/white-noise-{pretrain}-{architecture}'

wandb_project = 'gridworld'

name = f'*{architecture}-{pretrain}'

batch_size = 600
num_iters = 5000 # 5000
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dtype = 'bfloat16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Mixed-precision context shared by every forward pass in this script.
# The device type follows `device` instead of being hard-coded to 'cuda', and
# autocast is only enabled on CUDA: on a CPU-only host the previous hard-coded
# 'cuda' context merely emitted a warning and disabled itself, so the numerics
# are unchanged but the warning is gone.
ctx = torch.amp.autocast(device_type=device.type, dtype=ptdtype,
                         enabled=(device.type == 'cuda'))


# Dataset metadata. Other parts of this file read the keys
# 'num_{split}_examples', 'num_state_dimensions', 'num_heads' and 'num_states'.
config_path = os.path.join(data_dir, 'config.json')
with open(config_path, 'r') as f:
    config = json.load(f)


def get_batch(data_dir, split, batch_size):
    """Sample a random batch of (states, targets) from a memory-mapped split.

    Args:
        data_dir: Directory holding ``{split}_states.bin`` and ``{split}_y.bin``.
        split: Split name; selects the files and the ``num_{split}_examples``
            entry of the module-level ``config``.
        batch_size: Number of rows to draw (uniformly, with replacement).

    Returns:
        Tuple ``(x, y)`` of tensors on the module-level ``device``:
        ``x`` is int64 with shape ``(batch_size, config['num_state_dimensions'])``
        and ``y`` is float16 with shape ``(batch_size, config['num_heads'])``.
    """
    num_examples = config[f'num_{split}_examples']
    # Set up memmap files (read-only views; nothing is loaded until indexed).
    states = np.memmap(os.path.join(data_dir, f'{split}_states.bin'),
                       dtype=np.uint8, mode='r',
                       shape=(num_examples, config['num_state_dimensions']))
    targets = np.memmap(os.path.join(data_dir, f'{split}_y.bin'),
                        dtype=np.float16, mode='r',
                        shape=(num_examples, config['num_heads']))
    # Sample random row indices.
    ix = torch.randint(len(targets), (batch_size,))
    rows = ix.numpy()
    # Gather each array with a single fancy-indexed read instead of a Python
    # loop that touches the memmap once per sample and then stacks the pieces.
    x = torch.from_numpy(np.asarray(states[rows]).astype(np.int64))
    y = torch.from_numpy(np.array(targets[rows], dtype=np.float16))
    # (A label-shuffling control, y = y[torch.randperm(y.size(0))], is disabled.)
    # Move to device and use pin_memory for faster transfer to GPU
    if device.type == 'cuda':
        x = x.pin_memory().to(device, non_blocking=True)
        y = y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y


class GridworldNet(nn.Module):
    """Embedding + ReLU MLP that regresses ``config['num_heads']`` targets.

    Integer inputs are looked up in an embedding table with
    ``config['num_states']`` rows, passed through ``num_layers`` square
    ``Linear`` + ReLU layers, and projected to ``config['num_heads']`` outputs.
    """

    def __init__(self, hidden_dim=64, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(config['num_states'], hidden_dim)
        # Built in the same order as before so parameter initialisation draws
        # from the RNG in the same sequence.
        stack = []
        for _ in range(num_layers):
            stack.append(nn.Linear(hidden_dim, hidden_dim))
        self.hidden_layers = nn.ModuleList(stack)
        self.output_layer = nn.Linear(hidden_dim, config['num_heads'])

    def forward(self, x):
        """Return ``(predictions, last_hidden_activations)`` for index tensor ``x``."""
        hidden = self.embedding(x)
        for linear in self.hidden_layers:
            hidden = linear(hidden)
            hidden = F.relu(hidden)
        return self.output_layer(hidden), hidden


def train_model(l1_penalty):
    """Train a fresh ``GridworldNet`` with an L1 penalty on its hidden activations.

    The model is optimised with Adam for ``num_iters`` steps on random training
    batches. Every 100 steps it is evaluated on a random validation batch of
    10000 rows, and the model with the highest validation R² is kept. R² is
    computed as ``1 - val_mse / var_y`` where ``var_y`` is a module-level value.

    Args:
        l1_penalty: Coefficient applied to the entry-wise L1 norm of the last
            hidden activations (summed over the batch, not averaged).

    Returns:
        Tuple ``(best_model, best_val_r2)``. ``best_model`` is a deep copy of
        the model at its best evaluation and ``best_val_r2`` is that R² as a
        Python float. If no evaluation ever improved on the initial value
        (e.g. ``num_iters < 100`` or every R² was NaN), the final model is
        returned together with ``-inf``.
    """
    model = GridworldNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Start below any attainable R² so that a run whose R² never exceeds 0
    # still yields a model instead of None.
    best_val_r2 = float('-inf')
    best_model = None

    for step in range(num_iters):
        model.train()
        x, y = get_batch(data_dir, 'train', batch_size)
        optimizer.zero_grad()
        with ctx:
            preds, representation = model(x.view(-1))
            train_loss = torch.mean((preds - y) ** 2)
            l1_loss = l1_penalty * torch.norm(representation, 1)
            total_loss = train_loss + l1_loss
        total_loss.backward()
        optimizer.step()

        if (step + 1) % 100 == 0:  # Evaluate every 100 iters
            model.eval()
            x, y = get_batch(data_dir, 'val', 10000)
            with torch.no_grad():
                with ctx:
                    preds, representation = model(x.view(-1))
                    val_loss = torch.mean((preds - y) ** 2)
                    val_l1_loss = l1_penalty * torch.norm(representation, 1)
                    val_total_loss = val_loss + val_l1_loss

            # Keep R² as a plain float so comparisons, formatting and logging
            # do not carry a device tensor around.
            val_r2 = 1.0 - val_loss.item() / float(var_y)

            if val_r2 > best_val_r2:
                best_val_r2 = val_r2
                best_model = deepcopy(model)

            # Metrics are only printed; wandb logging for this loop is disabled.
            print(f"Iter {step+1}, Train Loss: {train_loss.item():.4f}, Val Loss: {val_loss.item():.4f}, Val Total Loss: {val_total_loss.item():.4f}, Val R²: {val_r2:.4f}")

    if best_model is None:
        # No checkpoint was ever selected; fall back to the final weights.
        best_model = model
    return best_model, best_val_r2

if wandb_log:
    wandb.init(project=wandb_project, name=name)

# Variance of the validation targets (per-column mean removed, then averaged
# over all entries). train_model uses it as the denominator of R².
y = np.memmap(os.path.join(data_dir, 'val_y.bin'),
              dtype=np.float16, mode='r',
              shape=(config['num_val_examples'], config['num_heads']))

# Upcast before centring and squaring: doing this arithmetic in float16 loses
# precision and overflows once a squared deviation exceeds the float16 range.
y_val = np.asarray(y, dtype=np.float32)
var_y = float(np.mean((y_val - y_val.mean(axis=0)) ** 2, dtype=np.float64))
del y, y_val

# Sweep over L1 penalties and keep the model with the highest validation R².
l1_penalties = [0.0, 0.0001, 0.1, 1.0]  # Add or modify penalties as needed
best_overall_model = None
# Start at -inf rather than 0 so that a model is still selected when every
# penalty ends with a non-positive R².
best_overall_r2 = float('-inf')
best_penalty = None

for penalty in l1_penalties:
    print(f"\nTraining with L1 penalty: {penalty}")
    model, val_r2 = train_model(penalty)
    # Accept either a Python number or a 0-d tensor from train_model.
    val_r2 = float(val_r2)

    if model is not None and val_r2 > best_overall_r2:
        best_overall_r2 = val_r2
        best_overall_model = model
        best_penalty = penalty

if best_overall_model is None:
    # Fail here with a clear message instead of later on `None.eval()`.
    raise RuntimeError("L1 sweep finished without producing any trained model")

print(f"\nBest model achieved with L1 penalty: {best_penalty}")
print(f"Best R²: {best_overall_r2:.4f}")

if wandb_log:
    wandb.log({
        "best_val_r2": best_overall_r2,
        "best_l1_penalty": best_penalty
    })
    wandb.finish()


def get_all_data(data_dir, split, batch_size=None, start_ind=None, end_ind=None):
    """Load states from a memory-mapped split, either a slice or a random batch.

    Args:
        data_dir: Directory holding ``{split}_states.bin``.
        split: Split name; selects the file and the ``num_{split}_examples``
            entry of the module-level ``config``.
        batch_size: Number of rows to sample (uniformly, with replacement).
            Only used when ``start_ind`` is None.
        start_ind: If given, return the contiguous rows ``[start_ind:end_ind]``
            instead of a random sample.
        end_ind: End of the slice (exclusive); None means "to the end".

    Returns:
        Tuple ``(x, ix)``. ``x`` is an int64 tensor on the module-level
        ``device``. ``ix`` is the CPU tensor of sampled row indices, or None
        when a slice was requested.

    Raises:
        ValueError: If neither ``batch_size`` nor ``start_ind`` is provided.
    """
    if start_ind is None and batch_size is None:
        raise ValueError("get_all_data requires either batch_size or start_ind")
    states = np.memmap(os.path.join(data_dir, f'{split}_states.bin'),
                       dtype=np.uint8, mode='r',
                       shape=(config[f'num_{split}_examples'],
                              config['num_state_dimensions']))
    if start_ind is not None:
        ix = None
        x = torch.from_numpy(states[start_ind:end_ind].astype(np.int64))
    else:
        ix = torch.randint(len(states), (batch_size,))
        # Single fancy-indexed read instead of one memmap access per sample.
        x = torch.from_numpy(np.asarray(states[ix.numpy()]).astype(np.int64))
    if device.type == 'cuda':
        x = x.pin_memory().to(device, non_blocking=True)
    else:
        x = x.to(device)
    return x, ix


# Run the selected model over the full train and val splits, chunk by chunk,
# and collect its last hidden activations for the probe trained below.
model = best_overall_model
model.eval()
num_train_examples = config['num_train_examples']
num_val_examples = config['num_val_examples']
chunk_size = 10000  # Adjust this based on your memory constraints
train_representations = []
valid_representations = []

with torch.no_grad():
    # Same procedure for both splits; only the split name, size and output
    # list differ.
    for split, total, collected in (
        ('train', num_train_examples, train_representations),
        ('val', num_val_examples, valid_representations),
    ):
        for start_ind in range(0, total, chunk_size):
            end_ind = min(start_ind + chunk_size, total)
            chunk_x, _ = get_all_data(data_dir, split, start_ind=start_ind, end_ind=end_ind)
            chunk_x = chunk_x.to(device)
            with ctx:
                _, batch_repr = model(chunk_x.view(-1))
            # Park each chunk on the CPU until everything has been processed.
            collected.append(batch_repr.cpu())
            print(f"Processed {end_ind}/{total} examples")

# Concatenate the chunks and move the full matrices to the compute device.
train_representations = torch.cat(train_representations, dim=0).to(device)
valid_representations = torch.cat(valid_representations, dim=0).to(device)


class StatePredictor(nn.Module):
    """MLP probe mapping a representation vector to ``config['num_states']`` logits.

    The network is an input ``Linear`` + ReLU, followed by ``num_layers - 1``
    square ``Linear`` + ReLU layers, followed by a linear output layer.
    """

    def __init__(self, input_dim, hidden_dim, num_layers):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        extra_layers = [
            nn.Linear(hidden_dim, hidden_dim) for _ in range(num_layers - 1)
        ]
        self.hidden_layers = nn.ModuleList(extra_layers)
        self.output_layer = nn.Linear(hidden_dim, config['num_states'])

    def forward(self, x):
        """Return unnormalised class logits for the representations ``x``."""
        hidden = self.input_layer(x)
        hidden = F.relu(hidden)
        for linear in self.hidden_layers:
            hidden = F.relu(linear(hidden))
        logits = self.output_layer(hidden)
        return logits

# ---------------------------------------------------------------------------
# Train a probe that predicts the input state from the learned representation.
# ---------------------------------------------------------------------------
num_layers = 1
input_dim = train_representations.shape[1]
hidden_dim = 512 # 128

state_predictor = StatePredictor(input_dim, hidden_dim, num_layers).to(device)
optimizer = torch.optim.Adam(state_predictor.parameters(), lr=0.001)

if wandb_log:
    wandb.init(project=wandb_project, name=name + '-reconstruction')

# The validation inputs and targets never change, so load them once instead of
# re-reading the memmap and re-uploading on every evaluation.
val_states, _ = get_all_data(data_dir, 'val', start_ind=0, end_ind=len(valid_representations))
val_targets = val_states.view(-1)
val_reps = valid_representations.to(device)

best_val_loss = float('inf')
for step in range(num_iters):
    state_predictor.train()
    states, ix = get_all_data(data_dir, 'train', batch_size)
    # NOTE(review): indexing representations by example index assumes one
    # representation row per example (i.e. num_state_dimensions == 1) - confirm.
    reps = train_representations[ix].to(device)
    optimizer.zero_grad()
    with ctx:
        predicted_states = state_predictor(reps)
        train_loss = F.cross_entropy(predicted_states, states.view(-1))
    train_loss.backward()
    optimizer.step()

    if (step + 1) % 100 == 0:  # Print every 100 iters
        state_predictor.eval()
        with torch.no_grad():
            with ctx:
                predicted_states = state_predictor(val_reps)
                # Convert to a Python float so the tracked best value and the
                # logged metrics are plain numbers rather than device tensors.
                val_loss = F.cross_entropy(predicted_states, val_targets).item()
        print(f"Iter {step+1}, Train Loss: {train_loss.item():.4f}, Validation Loss: {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
        if wandb_log:
            wandb.log({
                "iter": step + 1,
                "train_reconstruction_loss": train_loss.item(),
                "val_reconstruction_loss": val_loss,
                "best_reconstruction_loss": best_val_loss,
            }, step=step)

if wandb_log:
    wandb.finish()

print(f"Results for {name}")
print(f"Best R²: {best_overall_r2:.4f}")
print(f"Best reconstruction loss: {best_val_loss:.4f}")
