import os
import json
import numpy as np
import torch
from torch.nn import functional as F
import torch.nn as nn
import pdb
import wandb
from copy import deepcopy
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--pretrain', type=str, default='NTP', help='state or NTP')
parser.add_argument('--architecture', type=str, default='transformer', help='transformer mamba mamba2 rnn lstm')
# `choices` rejects unknown values at parse time; otherwise `dataset` below would be left undefined.
parser.add_argument('--othello_type', type=str, default='synthetic', choices=['synthetic', 'championship'],
                    help='synthetic championship')
parser.add_argument('--num_conv_layers', type=int, default=2, help='number of 3x3 conv layers in OthelloCNN')
parser.add_argument('--num_hidden_units', type=int, default=64,
                    help='width of the fully-connected hidden layer (fc1) of OthelloCNN')
args = parser.parse_args()

wandb_log = True

pretrain = args.pretrain  # 'NTP' or 'state'
architecture = args.architecture  # mamba or mamba2 or transformer or rnn or lstm
othello_type = args.othello_type  # 'championship' or 'synthetic'
# 5 selects the default 'white-noise-...' data directory; any other value selects '{n}-examples-white-noise-...'.
num_examples = 5

# Dataset root per --othello_type (argparse `choices` guarantees the key exists).
dataset = {
    'synthetic': 'othello/synthetic-othello-1M',
    'championship': 'othello/championship-othello',
}[othello_type]

if num_examples == 5:
    data_dir = f'data/{dataset}/white-noise-{pretrain}-{architecture}'
else:
    data_dir = f'data/{dataset}/{num_examples}-examples-white-noise-{pretrain}-{architecture}'

# Run name used for the wandb runs and the final summary printout.
name = f'{args.num_conv_layers}-layers-{args.num_hidden_units}-units-{architecture}-{pretrain}-{othello_type}'

batch_size = 600
num_iters = 5000
num_iters_reconstruction = 10  # training steps for the board-reconstruction MLP
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

dtype = 'bfloat16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Mixed precision on CUDA only. On CPU the context is created disabled (a no-op) rather than
# requesting CUDA autocast on a machine that has no CUDA device.
ctx = torch.amp.autocast(device_type=device.type, dtype=ptdtype, enabled=device.type == 'cuda')

# Dataset metadata: example counts, state dimensionality and number of regression heads.
config_path = os.path.join(data_dir, 'config.json')
with open(config_path, 'r') as f:
    config = json.load(f)


def get_batch(data_dir, split, batch_size):
    """Sample ``batch_size`` random (state, target) rows from ``split``, with replacement.

    Reads ``{split}_states.bin`` (uint8) and ``{split}_y.bin`` (float16) as memmaps whose
    shapes come from the global ``config``.

    Returns:
        (x, y): ``x`` is an int64 tensor of shape (batch_size, config['num_state_dimensions']);
        ``y`` is a float16 tensor of shape (batch_size, config['num_heads']). Both live on the
        global ``device``, and row k of ``x`` and ``y`` come from the same example.
    """
    n_rows = config[f'num_{split}_examples']
    states = np.memmap(os.path.join(data_dir, f'{split}_states.bin'),
                       dtype=np.uint8, mode='r',
                       shape=(n_rows, config['num_state_dimensions']))
    targets = np.memmap(os.path.join(data_dir, f'{split}_y.bin'),
                        dtype=np.float16, mode='r',
                        shape=(n_rows, config['num_heads']))
    ix = torch.randint(len(targets), (batch_size,))
    idx = ix.numpy()
    # One fancy-indexing gather per array instead of a per-row Python loop + torch.stack.
    x = torch.from_numpy(np.array(states[idx], dtype=np.int64))
    y = torch.from_numpy(np.array(targets[idx], dtype=np.float16))
    if device.type == 'cuda':
        # Pinned host memory allows the asynchronous (non_blocking) host-to-device copy.
        x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    else:
        x, y = x.to(device), y.to(device)
    return x, y


class OthelloCNN(nn.Module):
    """CNN mapping a board of cell codes {0, 1, 2} to ``num_heads`` regression outputs.

    ``forward`` returns ``(pred, representation)``. ``representation`` is the post-ReLU ``fc1``
    activation; it is what ``pred`` is computed from via ``fc2``.

    Arguments left as ``None`` fall back to the script-level settings
    (``args.num_hidden_units``, ``args.num_conv_layers``, ``config['num_heads']``), so
    ``OthelloCNN()`` behaves as before; pass them explicitly to use the class standalone.
    """
    def __init__(self, hidden_channels=16, fc_hidden_dim=None, embedding_dim=8,
                 num_conv_layers=None, num_heads=None, board_size=8):
        super().__init__()
        if fc_hidden_dim is None:
            fc_hidden_dim = args.num_hidden_units
        if num_conv_layers is None:
            num_conv_layers = args.num_conv_layers
        if num_heads is None:
            num_heads = config['num_heads']
        self.embedding = nn.Embedding(num_embeddings=3, embedding_dim=embedding_dim)
        self.conv_layers = nn.ModuleList()
        in_channels = embedding_dim
        for _ in range(num_conv_layers):
            # 3x3 kernel with padding=1 keeps the spatial size board_size x board_size.
            self.conv_layers.append(nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1))
            in_channels = hidden_channels
        # Use the channel count the conv stack actually produces: with zero conv layers it is
        # embedding_dim, not hidden_channels.
        conv_output_size = board_size * board_size * in_channels
        self.fc1 = nn.Linear(conv_output_size, fc_hidden_dim)
        self.fc2 = nn.Linear(fc_hidden_dim, num_heads)
    #
    def forward(self, x):
        """Return ``(pred, representation)`` for an integer board tensor ``x``.

        Assumes ``x`` is (batch, board_size, board_size) with values in {0, 1, 2}; the script's
        callers reshape to (-1, 8, 8) before calling.
        """
        x = self.embedding(x)      # (B, H, W) -> (B, H, W, E)
        x = x.permute(0, 3, 1, 2)  # channels-first (B, E, H, W) for Conv2d
        for conv in self.conv_layers:
            x = F.relu(conv(x))
        x = x.reshape(x.size(0), -1)
        x = F.relu(self.fc1(x))
        pred = self.fc2(x)
        return pred, x


def train_model(l1_penalty):
    """Train a fresh OthelloCNN and keep the checkpoint with the best validation R².

    Targets are divided per head by ``sqrt(var_ys)``. The mean squared error on those scaled
    targets is therefore the mean over heads of MSE_h / Var_h, and ``1 - val_loss`` is the
    mean per-head R².

    Args:
        l1_penalty: weight on the L1 norm of the model's hidden representation, added to the
            training loss.

    Returns:
        (best_model, best_val_r2): a deep copy of the model at the evaluation with the highest
        validation R², and that R² as a Python float. Evaluation runs every 100 iterations and
        on the final iteration. If ``num_iters`` is 0 the result is ``(None, -inf)``.
    """
    model = OthelloCNN().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    best_val_r2 = float('-inf')
    best_model = None
    # var_ys is fixed during training, so compute the per-head target scale once.
    target_std = torch.sqrt(var_ys)
    for step in range(num_iters):
        model.train()
        x, y = get_batch(data_dir, 'train', batch_size)
        optimizer.zero_grad()
        with ctx:
            preds, representation = model(x.reshape(-1, 8, 8))
            train_loss = torch.mean((preds - y / target_std) ** 2)
            l1_loss = l1_penalty * torch.norm(representation, 1)
            total_loss = train_loss + l1_loss
        total_loss.backward()
        optimizer.step()
        # Evaluate every 100 iterations, and on the last one so short runs still select a model.
        if (step + 1) % 100 != 0 and step + 1 != num_iters:
            continue
        model.eval()
        x, y = get_batch(data_dir, 'val', 10000)
        with torch.no_grad():
            with ctx:
                preds, representation = model(x.reshape(-1, 8, 8))
                val_loss = torch.mean((preds - y / target_std) ** 2)
                val_l1_loss = l1_penalty * torch.norm(representation, 1)
                val_total_loss = val_loss + val_l1_loss
        # A plain float keeps comparisons, the return value and later logging free of device tensors.
        val_r2 = 1.0 - val_loss.item()
        if val_r2 > best_val_r2:
            best_val_r2 = val_r2
            best_model = deepcopy(model)
        print(f"Iter {step+1}, Train Loss: {train_loss.item():.4f}, Val Loss: {val_loss.item():.4f}, Val Total Loss: {val_total_loss.item():.4f}, Val R²: {val_r2:.4f}")
    return best_model, best_val_r2

if wandb_log:
    wandb.init(project="othello-cnn", name=name)

# Per-head variance of the validation targets; train_model divides targets by its square root
# so that 1 - MSE is an R².
orig_y = np.memmap(os.path.join(data_dir, 'val_y.bin'),
                   dtype=np.float16, mode='r',
                   shape=(config['num_val_examples'], config['num_heads']))
# Accumulate in float64. In float16, squared deviations overflow to inf once |y - mean| exceeds
# ~256 and lose precision well before that.
# NOTE(review): a constant head yields variance 0, i.e. division by zero in train_model.
var_ys = torch.from_numpy(np.var(orig_y, axis=0, dtype=np.float64)).float().to(device)

l1_penalties = [0.0]  # sweep e.g. [0.0, 0.0001, 0.1, 1.0]
best_overall_model = None
best_overall_r2 = float('-inf')
best_penalty = None

for penalty in l1_penalties:
    print(f"\nTraining with L1 penalty: {penalty}")
    model, val_r2 = train_model(penalty)
    if val_r2 > best_overall_r2:
        best_overall_r2 = val_r2
        best_overall_model = model
        best_penalty = penalty

if best_overall_model is None:
    raise RuntimeError("train_model never produced a model (no validation evaluation ran); check num_iters")

# float() accepts a Python float or a 0-dim tensor, so wandb always receives a plain number.
best_overall_r2 = float(best_overall_r2)

print(f"\nBest model achieved with L1 penalty: {best_penalty}")
print(f"Best validation R²: {best_overall_r2:.4f}")

if wandb_log:
    wandb.log({
        "best_val_r2": best_overall_r2,
        "best_l1_penalty": best_penalty
    })
    wandb.finish()

model = best_overall_model


## Go through data, get predictions

def get_all_data(data_dir, split, batch_size=None, start_ind=None, end_ind=None):
    """Load board states from ``{split}_states.bin``, as a contiguous slice or a random sample.

    Slice mode (``start_ind`` given): returns rows ``start_ind:end_ind`` (to the end when
    ``end_ind`` is None) and ``ix=None``.
    Sampling mode (``start_ind`` None): draws ``batch_size`` row indices with ``torch.randint``
    (independent draws, so with replacement). Returns them as ``ix``, a CPU LongTensor, so
    callers can gather matching rows from other arrays.

    Returns:
        (x, ix): ``x`` is an int64 tensor of shape (rows, config['num_state_dimensions']) on the
        global ``device``.

    Raises:
        ValueError: if neither ``start_ind`` nor ``batch_size`` is given.
    """
    states = np.memmap(os.path.join(data_dir, f'{split}_states.bin'),
                       dtype=np.uint8, mode='r',
                       shape=(config[f'num_{split}_examples'],
                              config['num_state_dimensions']))
    if start_ind is not None:
        ix = None
        rows = states[start_ind:end_ind]
    elif batch_size is not None:
        ix = torch.randint(len(states), (batch_size,))
        # One fancy-indexing gather instead of a per-row Python loop + torch.stack.
        rows = states[ix.numpy()]
    else:
        raise ValueError("get_all_data needs start_ind (slice mode) or batch_size (sampling mode)")
    x = torch.from_numpy(np.array(rows, dtype=np.int64))
    if device.type == 'cuda':
        # Pinned host memory allows the asynchronous (non_blocking) host-to-device copy.
        x = x.pin_memory().to(device, non_blocking=True)
    else:
        x = x.to(device)
    return x, ix


# Baseline for the reconstruction R²: the cross-entropy of predicting every square with its
# empirical class frequencies over the validation set, ignoring the representation entirely.
valid_states = np.array(np.memmap(os.path.join(data_dir, 'val_states.bin'), dtype=np.uint8, mode='r',
                                  shape=(config['num_val_examples'], config['num_state_dimensions'])))
# Per-square class frequencies, shape (num_state_dimensions, 3). This equals averaging an
# (N, D, 3) one-hot array over examples, without the Python double loop or the one-hot allocation.
baseline_pred = np.stack([(valid_states == c).mean(axis=0) for c in range(3)], axis=-1)
baseline_pred = np.clip(baseline_pred, 1e-7, 1 - 1e-7)  # avoid log(0) for classes never seen on a square
# -log p[j, s_ij] for every example i and square j, averaged over both axes.
square_idx = np.arange(valid_states.shape[1])
baseline_cross_entropy = -np.mean(np.log(baseline_pred)[square_idx, valid_states])
del baseline_pred, valid_states, square_idx

model.eval()
num_train_examples = config['num_train_examples']
num_val_examples = config['num_val_examples']
chunk_size = 10000  # rows per forward pass; adjust to fit memory


def _extract_representations(split, num_rows):
    """Run `model` over every row of `split` in chunks; return its second output, concatenated on CPU."""
    chunks = []
    with torch.no_grad():
        for start_ind in range(0, num_rows, chunk_size):
            end_ind = min(start_ind + chunk_size, num_rows)
            # get_all_data already places the batch on `device`.
            batch_x, _ = get_all_data(data_dir, split, start_ind=start_ind, end_ind=end_ind)
            with ctx:
                _, batch_repr = model(batch_x.reshape(-1, 8, 8))
            chunks.append(batch_repr.cpu())
            print(f"Processed {end_ind}/{num_rows} examples")
    return torch.cat(chunks, dim=0)


train_representations = _extract_representations('train', num_train_examples).to(device)
valid_representations = _extract_representations('val', num_val_examples).to(device)


class StatePredictor(nn.Module):
    """MLP producing 3-way logits for each of ``output_dim`` board cells from a feature vector.

    The output is flat, shape (batch, output_dim * 3); callers reshape it with ``.view(-1, 3)``.
    ``num_layers`` counts the Linear+ReLU layers before the output layer. Values below 1
    behave like 1, because the input layer always exists.
    """
    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.input_layer = nn.Linear(input_dim, hidden_dim)
        extra_layers = max(num_layers - 1, 0)
        self.hidden_layers = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim) for _ in range(extra_layers))
        self.output_layer = nn.Linear(hidden_dim, 3 * output_dim)
    #
    def forward(self, x):
        h = x
        # Input layer and every hidden layer share the same Linear -> ReLU step.
        for linear in (self.input_layer, *self.hidden_layers):
            h = F.relu(linear(h))
        return self.output_layer(h)

num_layers = 1
input_dim = train_representations.shape[1]
hidden_dim = 512
output_dim = 8 * 8  # one 3-way prediction per square of the 8x8 board

state_predictor = StatePredictor(input_dim, hidden_dim, output_dim, num_layers).to(device)
optimizer = torch.optim.Adam(state_predictor.parameters(), lr=0.001)

if wandb_log:
    wandb.init(project="othello-cnn", name=name + '-reconstruction')

# The validation split is fixed, so load its states once instead of at every evaluation.
val_states, _ = get_all_data(data_dir, 'val', start_ind=0, end_ind=len(valid_representations))
val_reps = valid_representations.to(device)

# Start at -inf so a negative R² is reported as such rather than masked as 0.
best_reconstruction_r2 = float('-inf')
for step in range(num_iters_reconstruction):
    state_predictor.train()
    states, ix = get_all_data(data_dir, 'train', batch_size)
    reps = train_representations[ix].to(device)
    optimizer.zero_grad()
    with ctx:
        predicted_states = state_predictor(reps)
        train_loss = F.cross_entropy(predicted_states.view(-1, 3), states.view(-1))
    train_loss.backward()
    optimizer.step()
    # Evaluate every 100 steps and on the final step; with fewer than 100 steps the
    # every-100 cadence alone would never evaluate.
    if (step + 1) % 100 != 0 and step + 1 != num_iters_reconstruction:
        continue
    state_predictor.eval()
    with torch.no_grad():
        with ctx:
            predicted_states = state_predictor(val_reps)
            val_loss = F.cross_entropy(predicted_states.view(-1, 3), val_states.view(-1))
    # R² analogue: 1 - CE(reconstruction MLP) / CE(per-square frequency baseline), as a plain float.
    val_reconstruction_r2 = float(1 - val_loss.item() / baseline_cross_entropy)
    print(f"Iter {step+1}, Train Loss: {train_loss.item():.4f}, Validation R2: {val_reconstruction_r2:.4f}")
    if val_reconstruction_r2 > best_reconstruction_r2:
        best_reconstruction_r2 = val_reconstruction_r2
    if wandb_log:
        wandb.log({
            "iter": step + 1,
            "train_reconstruction_loss": train_loss.item(),
            "val_reconstruction_r2": val_reconstruction_r2,
            "best_reconstruction_r2": best_reconstruction_r2,
        }, step=step)

if wandb_log:
    wandb.finish()

print(f"Results for {name}")
print(f"Best IB R²: {best_overall_r2:.4f}")
print(f"Best reconstruction R²: {best_reconstruction_r2:.4f}")
      