import random
import wandb
import hydra
import torch
import torch.nn as nn
import torch.optim as optim
from src.mamba_model import AlteranteMambaTrans, TransformerOnly
from omegaconf import DictConfig, OmegaConf
from config import settings


# Vocabulary for the copying task: ids 0-2 are the special tokens and
# ids 3-28 are the lowercase letters 'a'..'z', giving 29 token ids in total.
n_tokens = 29
input_map = dict(zip('abcdefghijklmnopqrstuvwxyz', range(3, 29)))
input_map.update(begin_seq=0, copy_or_eos=1, ignore=2)


def run_sampler(max_length=300, min_length=1, context_len=620):
    """Sample one (input, target) pair of token-id lists for the copying task.

    The input is a begin token followed by random-letter segments, each closed
    by a copy/eos token, and is padded with ignore tokens so that it always
    holds ``context_len + 1`` ids. The target is the input without its begin
    token, shifted right by the position of the first copy token (the shifted-in
    slots hold the ignore token) and cut to ``context_len + 1`` ids.

    Args:
        max_length: Largest segment length (letters plus the closing copy token).
        min_length: Smallest segment length; must be at least 1 because every
            segment contains its closing copy token.
        context_len: Number of ids that follow the begin token.

    Returns:
        A ``(final_input, output)`` tuple of equally long lists of token ids.

    Raises:
        ValueError: If ``min_length`` is below 1, if ``max_length`` is below
            ``min_length`` (raised by ``random.randint``), or if the first
            sampled segment does not fit into ``context_len`` so the sequence
            contains no copy token.
    """
    if min_length < 1:
        # A zero/negative length would append a copy token without consuming
        # any of the context budget, breaking the fixed-length guarantee (and
        # looping forever when max_length is 0 as well).
        raise ValueError(f'min_length must be at least 1, got {min_length}')

    begin_token = input_map['begin_seq']
    copy_token = input_map['copy_or_eos']
    ignore_token = input_map['ignore']
    # Letter ids are contiguous, so a segment letter is a uniform draw in [a, z]
    first_letter, last_letter = input_map['a'], input_map['z']

    # Final sample always begins with a begin sequence token
    final_input = [begin_token]
    curr_len = 0
    while True:
        # NOTE(review): a fresh length is drawn for every segment; the original
        # author was unsure whether it should instead be drawn once per sample.
        length = random.randint(min_length, max_length)
        if curr_len + length > context_len:
            # The segment does not fit: pad what is left of the context with
            # ignore tokens and stop.
            final_input += [ignore_token] * (context_len - curr_len)
            break
        curr_len += length
        # The last slot of a segment is the copy token, so only length - 1 letters
        final_input += [random.randint(first_letter, last_letter) for _ in range(length - 1)]
        final_input.append(copy_token)

    if curr_len == 0:
        # Not a single segment fitted, so there is no copy token to align on
        raise ValueError(
            f'first sampled segment is longer than context_len={context_len}; '
            'the sample contains no copy token'
        )

    # Create the corresponding output for the input
    copy_idx = final_input.index(copy_token)
    output = [ignore_token] * copy_idx + final_input[1:]
    return final_input, output[:context_len + 1]


def _sample_batch(batch_size, category, device):
    """Draw ``batch_size`` copy-task samples and stack them into tensors on ``device``."""
    inputs, targets = [], []
    for _ in range(batch_size):
        if category == 'test':
            # Test samples use fixed-length segments: max_length == min_length == 300
            sample_input, sample_output = run_sampler(300, 300)
        else:
            # Everything else uses run_sampler's defaults
            sample_input, sample_output = run_sampler()
        inputs.append(sample_input)
        targets.append(sample_output)
    # Every sample has the same length, so the nested lists stack directly
    # into (batch_size, seq_len) integer tensors.
    return torch.tensor(inputs, device=device), torch.tensor(targets, device=device)


def train_loop(model, iter_batches, batch_size, category, device, optimizer, lossfn, scheduler):
    """Run ``iter_batches`` freshly sampled batches through ``model``.

    With ``category == 'train'`` the model is put in train mode and every batch
    performs one optimizer step and one scheduler step on the masked
    next-token loss. For any other category the model is put in eval mode, the
    forward pass runs without gradient tracking and no loss is computed (the
    reported loss stays 0). Accuracy is computed for every batch, and the
    per-batch stats are logged to wandb when a run is active.

    Args:
        model: Module called as ``model(batch_input)``; its output is viewed
            as ``(-1, n_tokens)`` for the loss.
        iter_batches: Number of batches to sample and process.
        batch_size: Number of samples per batch.
        category: ``'train'`` to optimise, ``'test'`` to evaluate on
            fixed-length samples; any other value evaluates on default samples.
        device: Device the sampled batches are created on.
        optimizer: Optimizer stepped once per training batch.
        lossfn: Loss that returns one value per token (the pipeline builds it
            with ``reduction='none'``) so ignored positions can be masked out.
        scheduler: Learning-rate scheduler stepped once per training batch.

    Returns:
        The masked loss of the last processed batch, or 0 when no loss was
        computed (non-train category or ``iter_batches == 0``).
    """
    is_train = category == 'train'
    if is_train:
        model.train()  # Turn on the train mode
    else:
        model.eval()

    # Defined up front so the summary below still works when iter_batches == 0
    total_loss, accuracy = 0, 0.0

    for _ in range(iter_batches):
        # The reported loss is per batch, so it restarts from 0 every iteration
        total_loss = 0
        batch_input, batch_target = _sample_batch(batch_size, category, device)

        # Only track gradients when they are needed for the backward pass
        with torch.set_grad_enabled(is_train):
            output = model(batch_input)

        if is_train:
            flat_target = batch_target.view(-1)
            # Per-token next-token loss
            token_loss = lossfn(output.view(-1, n_tokens), flat_target)
            # Positions whose target is the ignore token do not contribute
            mask = flat_target != input_map['ignore']
            # Mean over the kept positions; the clamp avoids 0 / 0 -> NaN when
            # every target in the batch is ignored
            masked_loss = (mask * token_loss).sum() / mask.sum().clamp(min=1)

            # Backpropagation
            optimizer.zero_grad()
            masked_loss.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            scheduler.step()

            total_loss += masked_loss.item()

        accuracy = get_accuracy(output, batch_target)

        # wandb.log fails when no run has been initialised (e.g. wandb disabled
        # in the config), so only log while a run is active
        if wandb.run is not None:
            wandb.log({
                'category': category,
                'loss': total_loss,
                'accuracy': accuracy
            })
    print(f'category {category} loss {total_loss:5.2f} accuracy {accuracy}')
    return total_loss


def get_accuracy(outputs, targets, ignore_index=None):
    """
    Computes the accuracy of predictions, ignoring the positions with padding.

    Args:
    outputs (torch.Tensor): The raw model outputs of shape (batch_size, seq_len, num_classes).
    targets (torch.Tensor): The ground truth target labels of shape (batch_size, seq_len).
    ignore_index (int, optional): Target id to exclude from the computation.
        Defaults to ``input_map['ignore']``.

    Returns:
    float: The accuracy of the model on the given batch, expressed as a percentage.
        0.0 is returned when every target position is ignored.
    """
    if ignore_index is None:
        ignore_index = input_map['ignore']

    # Predicted class = argmax over the last dimension (num_classes)
    preds = outputs.argmax(dim=-1)

    # Positions that actually count towards the accuracy
    mask = targets != ignore_index
    total_non_padding = mask.sum()

    # Nothing to score: return a plain float instead of dividing by zero
    # (the fallback used to be an int, on which `.item()` does not exist)
    if total_non_padding.item() == 0:
        return 0.0

    correct_predictions = (preds[mask] == targets[mask]).sum()
    accuracy = correct_predictions.float() / total_non_padding.float() * 100
    return accuracy.item()


@hydra.main(config_path='configs', config_name="defaults", version_base=None)
def run_pipeline(cfg: DictConfig) -> None:
    """Build the model from the Hydra config, then train and test it.

    Steps: optionally start a wandb run, read the hyper-parameters from
    ``cfg``, construct the model, optimizer and scheduler, run
    ``cfg.train.epochs`` training rounds of ``cfg.dataset.iter_batches``
    batches each, then ``cfg.dataset.test_epochs`` single-batch test rounds,
    and finally close the wandb run if one was started.

    Args:
        cfg: Hydra/OmegaConf configuration; the ``basic``, ``dataset``,
            ``train``, ``model``, ``optimizer`` and ``scheduler`` groups are read.
    """
    use_wandb = cfg.basic.use_wandb
    if use_wandb is True:
        # Login into the wandb system and store the resolved config with the run
        cfg_copy = OmegaConf.to_container(cfg, resolve=True)
        wandb.login(key=settings.WANDB_API_KEY)
        wandb.init(project=cfg.dataset.name, entity=settings.WANDB_TEAM,
                   name=cfg.basic.wandb_run, config=cfg_copy)

    batch_size, num_epochs, iter_batches = cfg.dataset.batch_size, cfg.train.epochs, cfg.dataset.iter_batches
    test_batch_size, test_epochs = cfg.dataset.test_batch_size, cfg.dataset.test_epochs
    # Per-token losses are required because train_loop masks ignored positions itself
    lossfn = nn.CrossEntropyLoss(reduction='none')
    # The configured device is only honoured when CUDA is available; otherwise CPU
    device = torch.device(cfg.train.device if torch.cuda.is_available() else "cpu")

    # Get the parameters to run the hybrid model
    positional_mask, num_attn_heads = cfg.model.positional_mask, cfg.model.num_attn_heads
    d_conv, d_expand, layers = cfg.model.d_conv, cfg.model.d_expand, cfg.model.num_layers
    d_state, d_channels, model_debug = cfg.model.d_state, cfg.model.d_channels, cfg.model.debug
    # Initialise the model with all the parameters from the Yaml file
    model = TransformerOnly(n_tokens, n_tokens, d_channels, d_state, layers,
                            positional_mask, num_attn_heads, d_conv, d_expand, device, model_debug)
    # model = AlteranteMambaTrans(n_tokens, n_tokens, d_channels, d_state, layers,
    #                             positional_mask, num_attn_heads, d_conv, d_expand, device, model_debug)

    # Move the model first so the optimizer is constructed over parameters
    # that already live on the target device
    model.to(device)

    # The learning rate is added to the optimizer by default, model parameters added manually
    optimizer = hydra.utils.instantiate(cfg.optimizer, params=model.parameters())
    scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=optimizer)
    # The scheduler class comes from the config and not every scheduler has a
    # `step_size`, so only report it when it exists
    step_size = getattr(scheduler, 'step_size', None)
    if step_size is not None:
        print(step_size)

    # Train loop
    for _ in range(num_epochs):
        train_loop(model, iter_batches, batch_size, 'train', device, optimizer, lossfn, scheduler)

    # Test loop: one batch per test round
    for _ in range(test_epochs):
        train_loop(model, 1, test_batch_size, 'test', device, optimizer, lossfn, scheduler)

    if use_wandb is True:
        wandb.finish()


# Script entry point. `run_pipeline` is called without arguments because its
# `@hydra.main` decorator is responsible for building `cfg` and passing it in.
if __name__ == "__main__":
    run_pipeline()
