from src.datasets.associative_recall import AssociativeRecallDataLoader
from src.datasets.input_copying import InputCopyingDataLoader
from src.datasets.regression import RegressionDataloader
from src.scheduler.scheduler import TransformerLR
from src.models.task_head import HyenaWithTaskHead, GPT2WithTaskHead
from src.metrics.masked_metrics import masked_cross_entropy, masked_accuracy

import random
import numpy as np
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from typing import Literal, Optional, Any
import sys, os
import yaml
import time
import pickle
import matplotlib.pyplot as plt

@dataclass
class TaskConfig:
    """Which synthetic task to train on and how its data is sized.

    Populated from the ``task`` section of the YAML config.
    """

    # Selects the dataloader and the loss ("regression" uses MSE, the
    # other tasks use masked cross-entropy).
    task_name: Literal["associative_recall", "input_copying", "regression"]
    # Forwarded to the task dataloader and to the model constructor.
    vocab_size: Optional[int] = field(default=None)
    # Sequence length given to the dataloader; the "seq-len" curriculum
    # scales it down for all but the last loop of an epoch.
    max_seq_length: Optional[int] = field(default=None)

@dataclass
class TrainConfig:
    """Optimisation and logging settings, populated from the ``train`` YAML section.

    ``curriculum`` is a dict with a ``"type"`` key (``None`` or ``"seq-len"``);
    when the type is not ``None`` it must also provide ``"num_loops"``.
    """

    batch_size: int
    # Total batches per epoch, split evenly across the curriculum loops.
    num_batch: int = 10**9
    num_epochs: int = 1
    # Forwarded to the LR scheduler as its warm-up length.
    num_warmup_steps: Optional[int] = None
    lr: Optional[float] = None
    # Log every ``logging_steps`` batches (the last batch is always logged).
    logging_steps: Optional[int] = None
    # default_factory so each instance gets its own dict.
    curriculum: dict[str, Any] = field(default_factory=lambda: {"type": None})
    do_eval: bool = False

    def __post_init__(self) -> None:
        # PyYAML follows YAML 1.1, whose float pattern requires a '.', so a
        # config value such as ``lr: 1e-4`` is loaded as the *string* "1e-4".
        # Coerce it so the optimizer receives a number.
        if isinstance(self.lr, str):
            self.lr = float(self.lr)

@dataclass
class ModelConfig:
    """Architecture hyper-parameters, populated from the ``model`` YAML section."""

    # Backbone selector used by ``train``.
    model_name: Literal["hyena", "gpt2"]
    # Number of layers (both backbones).
    n_layer: int
    # Model width (both backbones).
    d_model: int
    # Inner / feed-forward width (both backbones).
    d_inner: int
    # Only forwarded to the Hyena model.
    mlp_depth: int
    # Only forwarded to the Hyena model, as ``layer["emb_dim"]``.
    time_emb_dim: int
    # Only forwarded to the Hyena model.
    front_mlp: bool

def manual_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs and request deterministic kernels.

    Each library gets its own offset of ``seed`` (``seed``, ``seed + 1``,
    ``seed + 2``) so their streams are not identical.

    Args:
        seed: Base seed.
    """
    random.seed(seed)
    np.random.seed(seed + 1)
    torch.manual_seed(seed + 2)
    torch.cuda.manual_seed(seed + 2)
    torch.backends.cudnn.deterministic = True
    # cuDNN autotuning may choose different algorithms between runs.
    torch.backends.cudnn.benchmark = False
    # This is a function and must be *called*: assigning to it would replace
    # it with a bool without enabling anything. ``warn_only=True`` makes ops
    # that lack a deterministic implementation warn instead of raising, so
    # training is not aborted mid-run.
    torch.use_deterministic_algorithms(True, warn_only=True)

def training_loop(epoch_idx: int, loop_idx: int, num_loops: int,
                  log_step_history: list, loss_history: list, acc_history: list,
                  model: nn.Module, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler._LRScheduler,
                  task_config: TaskConfig, train_config: TrainConfig, device: torch.device):
    """Train ``model`` on one curriculum loop (a slice of one epoch).

    The epoch's ``train_config.num_batch`` batches are divided (integer
    division) across ``num_loops`` loops. This call builds a new dataloader for
    the task and trains on its batches, stepping the optimizer and the LR
    scheduler once per batch.

    With ``curriculum['type'] == 'seq-len'`` the sequence length used for loop
    ``loop_idx`` is ``round(max_seq_length * 2 ** (loop_idx - num_loops + 1))``:
    the last loop uses the full length and each earlier loop halves it.

    Args:
        epoch_idx: Index of the current epoch (logging and global-step offset).
        loop_idx: Index of this loop within the epoch.
        num_loops: Number of loops per epoch.
        log_step_history: Appended in place with the global step of each log.
        loss_history: Appended in place with the current batch loss at each log.
        acc_history: Appended in place with the accuracy averaged over the
            steps since the previous log (always 0 for "regression").
        model: Called as ``model(input_ids, labels)``; updated in place.
        optimizer: Stepped once per batch.
        scheduler: Stepped once per batch.
        task_config: Task selection and data sizing.
        train_config: Batch size/count, logging interval and curriculum.
            A ``logging_steps`` of ``None`` or ``0`` logs only the last step.
        device: Device the batch tensors are moved to.

    Raises:
        NotImplementedError: Unknown ``curriculum['type']``.
        ValueError: Unknown ``task_config.task_name``.
    """
    model.train()

    # load the settings
    batch_size = train_config.batch_size
    num_batch = train_config.num_batch // num_loops
    vocab_size = task_config.vocab_size
    max_seq_length = task_config.max_seq_length
    is_reg = task_config.task_name == "regression"

    # update the setting based on the curriculum
    curriculum = train_config.curriculum
    if curriculum['type'] is None:
        pass
    elif curriculum['type'] == 'seq-len':
        max_seq_length = round(max_seq_length * (2 ** (loop_idx - num_loops + 1)))
    else:
        raise NotImplementedError(f"Invalid curriculum type: {curriculum['type']}")

    # print primary settings
    settings_txt = f"Loop {loop_idx}/{num_loops}: "
    settings_txt += f"batch_size = {batch_size}, num_batch = {num_batch}, "
    settings_txt += f"vocab_size = {vocab_size}, max_seq_length = {max_seq_length}"
    print(settings_txt, flush=True)

    # define the dataloader
    if task_config.task_name == "associative_recall":
        dataloader = AssociativeRecallDataLoader(batch_size=batch_size,
                                                 num_batch=num_batch,
                                                 vocab_size=vocab_size,
                                                 max_seq_length=max_seq_length)
    elif task_config.task_name == "input_copying":
        dataloader = InputCopyingDataLoader(batch_size=batch_size,
                                            num_batch=num_batch,
                                            vocab_size=vocab_size,
                                            bos_id=0, copy_id=1, start_id=2,
                                            max_seq_length=max_seq_length)
    elif is_reg:
        dataloader = RegressionDataloader(batch_size=batch_size, num_batch=num_batch,
                                          vocab_size=vocab_size, max_seq_length=max_seq_length)
    else:
        raise ValueError(f"Invalid task name: {task_config.task_name}")

    # define the loss function (the MSE module is built once, not per batch)
    if not is_reg:
        loss_fn = masked_cross_entropy
    else:
        mse = nn.MSELoss()

        def loss_fn(out, labels, masks):
            # Regression ignores the masks: every label position contributes.
            return mse(out.squeeze(-1), labels)

    num_steps = len(dataloader)
    # Global step of this loop's first batch, assuming every loop has the same length.
    step_offset = (epoch_idx * num_loops + loop_idx) * num_steps
    # ``logging_steps`` is Optional; ``step % None`` (or ``% 0``) would crash.
    log_every = train_config.logging_steps or 0

    # train the model
    acc_sum, last_step = 0.0, -1
    for step, batch in enumerate(dataloader):
        # load the batch
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        masks = batch["masks"]
        masks = masks.to(device) if masks is not None else None

        # compute the loss and update the model
        optimizer.zero_grad()
        out = model(input_ids, labels)
        loss = loss_fn(out, labels, masks)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # accuracy is a metric only, so no autograd graph is needed for it
        if not is_reg:
            with torch.no_grad():
                acc_sum += masked_accuracy(out, labels, masks).item()

        # logging: every ``log_every`` steps, and always on the last step
        if (log_every and step % log_every == 0) or step == num_steps - 1:
            # average accuracy over the steps since the previous log
            acc = acc_sum / (step - last_step)
            acc_sum, last_step = 0.0, step
            loss_value = loss.item()
            # get_last_lr() is the public accessor for the lr set by the last step()
            lr = scheduler.get_last_lr()[0]

            log_txt = f"Epoch = {epoch_idx}, LoopIdx = {loop_idx}/{num_loops}, Step = {step}/{num_steps}: "
            log_txt += f"Loss = {loss_value:.8f}, Accuracy = {acc:.4f}"
            # scientific notation: typical learning rates print as 0.0000 with '.4f'
            log_txt += f", lr = {lr:.4e}"
            print(log_txt, flush=True)

            # record the training history
            log_step_history.append(step + step_offset)
            loss_history.append(loss_value)
            acc_history.append(acc)

@torch.no_grad()
def get_pred_example(model: nn.Module, task_config: TaskConfig, train_config: TrainConfig, device: torch.device,
                     num_examples: int = 4):
    """Print a few predictions next to their labels for one evaluation batch.

    Builds a single-batch dataloader for the task, using the configured
    (un-curriculated) ``max_seq_length``. It then runs one forward pass with
    the model in eval mode and prints up to ``num_examples`` samples. For
    "regression" the raw outputs are printed; for the other tasks the arg-max
    over the last dimension is printed.

    Args:
        model: Called as ``model(input_ids, labels)``; left in eval mode.
        task_config: Task selection and data sizing.
        train_config: Supplies the batch size.
        device: Device the batch tensors are moved to.
        num_examples: Maximum number of samples to print (clamped to the batch size).

    Raises:
        ValueError: Unknown ``task_config.task_name``.
    """
    model.eval()

    # load the settings
    task_name = task_config.task_name
    batch_size = train_config.batch_size
    num_batch = 1
    vocab_size = task_config.vocab_size
    max_seq_length = task_config.max_seq_length

    # Every task is supported, so the unconditional call at the end of
    # train() does not fail for the classification tasks.
    if task_name == "associative_recall":
        dataloader = AssociativeRecallDataLoader(batch_size=batch_size,
                                                 num_batch=num_batch,
                                                 vocab_size=vocab_size,
                                                 max_seq_length=max_seq_length)
    elif task_name == "input_copying":
        dataloader = InputCopyingDataLoader(batch_size=batch_size,
                                            num_batch=num_batch,
                                            vocab_size=vocab_size,
                                            bos_id=0, copy_id=1, start_id=2,
                                            max_seq_length=max_seq_length)
    elif task_name == "regression":
        dataloader = RegressionDataloader(batch_size=batch_size, num_batch=num_batch,
                                          vocab_size=vocab_size, max_seq_length=max_seq_length)
    else:
        raise ValueError(f"Invalid task name: {task_config.task_name}")

    # get one batch
    batch = next(iter(dataloader))
    input_ids = batch["input_ids"].to(device)
    labels = batch["labels"].to(device)

    # forward pass only: no loss and no parameter update
    out = model(input_ids, labels)

    # print the prediction examples; never index past the batch
    for i in range(min(num_examples, out.shape[0])):
        if task_name == "regression":
            pred = out[i].squeeze(-1)
        else:
            pred = torch.argmax(out[i], dim=-1)
        print("Prediction: ", pred.cpu().numpy())
        print("Label: ", labels[i].cpu().numpy())

def train(task_config: TaskConfig, train_config: TrainConfig, model_config: ModelConfig):
    """Build model, optimizer and scheduler, run all epochs, then print examples.

    The chosen backbone is wrapped in ``nn.DataParallel`` and moved to CUDA
    when available (otherwise CPU). Each epoch is split into curriculum loops
    (one loop when ``curriculum["type"]`` is ``None``), each handled by
    ``training_loop``.

    Returns:
        ``(log_step_history, loss_history, acc_history)`` -- three lists filled
        in place by ``training_loop`` at every logging step.

    Raises:
        ValueError: Unknown ``model_config.model_name``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    is_reg = task_config.task_name == "regression"

    # pick the backbone; both are wrapped identically afterwards
    name = model_config.model_name
    if name == "hyena":
        backbone = HyenaWithTaskHead(is_reg=is_reg,
                                     n_layer=model_config.n_layer,
                                     d_model=model_config.d_model,
                                     d_inner=model_config.d_inner,
                                     mlp_depth=model_config.mlp_depth,
                                     layer={"l_max": 8192, "emb_dim": model_config.time_emb_dim},
                                     front_mlp=model_config.front_mlp,
                                     vocab_size=task_config.vocab_size,
                                     use_head=False)
    elif name == "gpt2":
        backbone = GPT2WithTaskHead(is_reg=is_reg,
                                    d_model=model_config.d_model,
                                    vocab_size=task_config.vocab_size,
                                    n_layer=model_config.n_layer,
                                    d_inner=model_config.d_inner)
    else:
        raise ValueError(f"Invalid model name: {model_config.model_name}")

    model = nn.DataParallel(backbone)
    model.to(device)
    print(model.module.param_count(), flush=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=train_config.lr)
    scheduler = TransformerLR(optimizer, warmup_epochs=train_config.num_warmup_steps)

    # (log_step_history, loss_history, acc_history), mutated by training_loop
    histories = ([], [], [])
    for epoch in range(train_config.num_epochs):
        curriculum = train_config.curriculum
        num_loops = 1 if curriculum["type"] is None else curriculum["num_loops"]
        for loop_idx in range(num_loops):
            training_loop(epoch, loop_idx, num_loops, *histories,
                          model, optimizer, scheduler, task_config, train_config, device)

    get_pred_example(model, task_config, train_config, device)

    return histories
    
if __name__ == "__main__":
    # Usage: python <script> [config.yaml] [save_dir]
    # Set the random seed
    manual_seed(0)

    # Check the number of available GPUs
    print("Number of available GPUs: ", torch.cuda.device_count(), flush=True)

    # read the config file
    config_path = sys.argv[1] if len(sys.argv) > 1 else "configs/config.yaml"
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # parse the config
    task_config = TaskConfig(**config["task"])
    train_config = TrainConfig(**config["train"])
    model_config = ModelConfig(**config["model"])

    # train the model (perf_counter is monotonic, suited to measuring durations)
    start = time.perf_counter()
    log_step_history, loss_history, acc_history = train(task_config, train_config, model_config)
    end = time.perf_counter()
    print(f"Training time: {end - start:.2f} seconds")

    # save and plot the training history
    if len(sys.argv) > 2:
        save_dir = sys.argv[2]
        # create the output directory instead of failing inside open()
        os.makedirs(save_dir, exist_ok=True)

        # save the training history, including the step axis so the
        # curves can be re-plotted from the pickles alone
        histories = {
            "log_step_history": log_step_history,
            "loss_history": loss_history,
            "acc_history": acc_history,
        }
        for hist_name, history in histories.items():
            with open(os.path.join(save_dir, f"{hist_name}.pkl"), "wb") as f:
                pickle.dump(history, f)

        # plot the training history
        fig, axes = plt.subplots(2, 1, figsize=(8, 6))
        axes[0].plot(log_step_history, loss_history)
        axes[0].set_title("Loss history")
        if task_config.task_name == "regression":
            axes[0].set_yscale("log")
        axes[1].plot(log_step_history, acc_history)
        axes[1].set_title("Accuracy history")
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, "training_history.png"))
        plt.close(fig)