import torch
import torch.nn as nn
import torch.optim as optim
import os
import wandb
from torch.optim.lr_scheduler import OneCycleLR, StepLR
from tqdm import tqdm

from dataloader import create_data_loader
from model import PretrainedWithCustomTransformer

from utils import calculate_accuracy, AstroNormWithLayerNorm


def train_epoch(model, dataloader, criterion, optimizer, scheduler, device, args):
    """Train ``model`` for one epoch, processing every batch segment by segment.

    Each batch's token sequence is cut into ``args.num_segments`` equal slices
    (any remainder tokens past ``num_segments * segment_len`` are ignored). The
    slices are fed to the model in order, threading the returned memory from
    one slice into the next. Iteration over slices stops at the first slice
    whose attention mask is entirely zero.

    Two training modes are supported:

    * ``args.memory_replay_backprop`` is falsy: every slice is run with
      autograd enabled and its loss, scaled by ``(i + 1) / num_segments``, is
      back-propagated immediately.
    * ``args.memory_replay_backprop`` is truthy: slices are first run under
      ``torch.no_grad()`` to collect memories, then replayed in reverse order
      with autograd enabled, passing the gradient of each replayed memory on
      to the previously processed slice.

    Args:
        model: Callable as ``model(input_ids, attention_mask,
            current_segment=..., memory=...)`` returning ``(logits, memory)``;
            must expose ``hidden_size``.
        dataloader: Iterable of batches; the batch layout depends on
            ``args.dataset`` (see the unpacking code below).
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional LR scheduler stepped once per batch, or ``None``.
        device: Device the batch tensors are moved to.
        args: Namespace providing ``dataset``, ``num_segments``,
            ``memory_replay_backprop``, ``astro_mem`` and ``mem_sum``.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` averaged over ``len(dataloader)``.
    """
    model.train()
    epoch_loss = 0
    epoch_acc = 0
    num_batches = len(dataloader)

    # NOTE(review): created on the default device and rebuilt every epoch; if
    # AstroNormWithLayerNorm owns parameters they are neither moved to
    # ``device`` nor registered with ``optimizer`` here -- confirm intended.
    astro_norm = AstroNormWithLayerNorm(
        d_model=model.hidden_size, total_segments=args.num_segments)

    # The context manager guarantees the bar is closed even if a batch raises.
    with tqdm(total=num_batches, desc="Training", unit="batch") as pbar:
        for data in dataloader:
            if args.dataset in ['imdb', 'imdb_long']:
                # Batch is a single dict that also carries the 'labels' key.
                input_ids = data['input_ids'].to(device)
                attention_mask = data['attention_mask'].to(device)
                labels = data['labels'].to(device)
            elif args.dataset in ['aan']:
                # Batch is an (inputs_dict, labels_tensor) pair.
                inputs, labels = data
                input_ids = inputs['input_ids'].to(device)
                attention_mask = inputs['attention_mask'].to(device)
                labels = labels.to(device)
            else:
                # e.g. cifar10, listops, imdb_lra: (inputs_dict, labels) pair
                # whose tensors carry an extra dimension at index 1.
                inputs, labels = data
                inputs = {key: value.squeeze(1).to(device) for key, value in inputs.items()}
                input_ids = inputs['input_ids']
                attention_mask = inputs['attention_mask']
                labels = labels.squeeze(1).to(device)

            optimizer.zero_grad()

            _, seq_len = input_ids.size()
            segment_len = seq_len // args.num_segments
            segment_memories = []
            segment_losses = []
            all_logits = []
            num_processed = 0  # number of slices actually fed to the model

            memory = None
            for i in range(args.num_segments):
                start_idx = i * segment_len
                end_idx = start_idx + segment_len
                input_segment = input_ids[:, start_idx:end_idx]
                attention_segment = attention_mask[:, start_idx:end_idx]

                # Stop at the first slice that is pure padding for the batch.
                if torch.all(attention_segment == 0):
                    break
                num_processed += 1
                current_segment = torch.tensor(i, device=device)

                if args.memory_replay_backprop:
                    # Gradient-free pass: only the memories are kept.
                    with torch.no_grad():
                        _, memory = model(input_segment, attention_segment,
                                          current_segment=current_segment, memory=memory)
                        if args.astro_mem:
                            # The call's return value is unused; the result is
                            # read back from attributes on ``astro_norm``.
                            astro_norm(memory, i)
                            memory = astro_norm.memory_retention_sum if args.mem_sum else astro_norm.memory_retention
                        segment_memories.append(memory.detach().clone().requires_grad_())
                else:
                    logits, memory = model(input_segment, attention_segment,
                                           current_segment=current_segment, memory=memory)
                    if args.astro_mem:
                        astro_norm(memory, i)
                        memory = astro_norm.memory_retention_sum if args.mem_sum else astro_norm.memory_retention
                    # Logits/losses are only needed as values for metrics, so
                    # store detached copies instead of graph-holding tensors.
                    all_logits.append(logits.detach())

                    loss = criterion(logits, labels)
                    weighted_loss = loss * (i + 1) / args.num_segments
                    segment_losses.append(weighted_loss.detach())
                    # The graph is retained because later slices back-propagate
                    # through this slice via the carried memory.
                    weighted_loss.backward(retain_graph=True)

            # A batch with no processable slice contributes no update and no
            # metrics (previously this crashed on the empty lists below).
            if num_processed > 0:
                if args.memory_replay_backprop:
                    memory_grads = torch.zeros_like(segment_memories[-1])

                    # NOTE(review): segment_memories[k] is the memory produced
                    # *after* slice k in the pass above, and it is fed back
                    # here as the *input* memory of slice k; astro_norm is
                    # also not re-applied during replay. Confirm both are
                    # intentional.
                    for seg_idx in reversed(range(num_processed)):
                        segment_memory = segment_memories[seg_idx]
                        start_idx = seg_idx * segment_len
                        end_idx = start_idx + segment_len
                        input_segment = input_ids[:, start_idx:end_idx]
                        attention_segment = attention_mask[:, start_idx:end_idx]

                        logits, next_memory = model(input_segment, attention_segment,
                                                    current_segment=torch.tensor(seg_idx, device=device),
                                                    memory=segment_memory)
                        all_logits.append(logits.detach())

                        # Replay losses are intentionally unweighted.
                        loss = criterion(logits, labels)
                        segment_losses.append(loss.detach())
                        # retain_graph: the same graph is traversed again by
                        # the memory backward call right below.
                        loss.backward(retain_graph=True)
                        next_memory.backward(memory_grads)

                        # Hand this slice's memory gradient to the next
                        # (earlier) slice in the replay.
                        memory_grads.copy_(segment_memory.grad)

                optimizer.step()

                epoch_loss += torch.stack(segment_losses).sum().item() / args.num_segments

                # Prediction = argmax over classes of the element-wise maximum
                # logit across all processed slices.
                preds = torch.argmax(torch.stack(all_logits).max(dim=0)[0], dim=1)
                epoch_acc += calculate_accuracy(preds, labels)

            # Step per batch (also for skipped batches) so the schedule stays
            # aligned with the batch count.
            if scheduler is not None:
                scheduler.step()

            pbar.update(1)
            pbar.set_postfix({'loss': epoch_loss / pbar.n, 'acc': epoch_acc / pbar.n})

    return epoch_loss / num_batches, epoch_acc / num_batches


def evaluate(model, dataloader, criterion, device, args):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Every batch is split into ``args.num_segments`` equal slices that are fed
    to the model in order while carrying the returned memory forward.
    Iteration over slices stops at the first slice whose attention mask is
    entirely zero. The batch prediction is the argmax of the element-wise
    maximum logit across the processed slices; the batch loss is the sum of
    the per-slice losses divided by ``args.num_segments``.

    Args:
        model: Callable as ``model(input_ids, attention_mask,
            current_segment=..., memory=...)`` returning ``(logits, memory)``;
            must expose ``hidden_size``.
        dataloader: Iterable of batches; the batch layout depends on
            ``args.dataset`` (see the unpacking code below).
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        device: Device the batch tensors are moved to.
        args: Namespace providing ``dataset``, ``num_segments``, ``astro_mem``
            and ``mem_sum``.

    Returns:
        Tuple ``(mean_loss, mean_accuracy)`` averaged over ``len(dataloader)``.
    """
    model.eval()
    epoch_loss = 0
    epoch_acc = 0
    num_batches = len(dataloader)

    # NOTE(review): a fresh AstroNormWithLayerNorm is built for each call and
    # is not moved to ``device`` -- confirm this matches the training setup.
    astro_norm = AstroNormWithLayerNorm(
        d_model=model.hidden_size, total_segments=args.num_segments)

    # The progress bar is managed by the ``with`` so it closes on errors too.
    with torch.no_grad(), tqdm(total=num_batches, desc="Evaluating", unit="batch") as pbar:
        for data in dataloader:
            if args.dataset in ['imdb', 'imdb_long']:
                # Batch is a single dict that also carries the 'labels' key.
                input_ids = data['input_ids'].to(device)
                attention_mask = data['attention_mask'].to(device)
                labels = data['labels'].to(device)
            elif args.dataset in ['aan']:
                # Batch is an (inputs_dict, labels_tensor) pair.
                inputs, labels = data
                input_ids = inputs['input_ids'].to(device)
                attention_mask = inputs['attention_mask'].to(device)
                labels = labels.to(device)
            else:
                # e.g. cifar10, listops, imdb_lra, pathfinder32:
                # (inputs_dict, labels) pair with an extra dimension at index 1.
                inputs, labels = data
                inputs = {key: value.squeeze(1).to(device) for key, value in inputs.items()}
                input_ids = inputs['input_ids']
                attention_mask = inputs['attention_mask']
                labels = labels.squeeze(1).to(device)

            _, seq_len = input_ids.size()
            segment_len = seq_len // args.num_segments
            segment_losses = []
            all_logits = []
            memory = None

            for i in range(args.num_segments):
                start_idx = i * segment_len
                end_idx = start_idx + segment_len
                input_segment = input_ids[:, start_idx:end_idx]
                attention_segment = attention_mask[:, start_idx:end_idx]

                # Stop at the first slice that is pure padding for the batch.
                if torch.all(attention_segment == 0):
                    break

                logits, memory = model(input_segment, attention_segment,
                                       current_segment=torch.tensor(i, device=device),
                                       memory=memory)
                if args.astro_mem:
                    # The call's return value is unused; the result is read
                    # back from attributes on ``astro_norm``.
                    astro_norm(memory, i)
                    memory = astro_norm.memory_retention_sum if args.mem_sum else astro_norm.memory_retention
                all_logits.append(logits)

                # Evaluation losses are intentionally unweighted.
                segment_losses.append(criterion(logits, labels))

            # A batch with no processable slice contributes nothing to the
            # metrics (previously torch.stack raised on the empty lists).
            if all_logits:
                epoch_loss += torch.stack(segment_losses).sum().item() / args.num_segments

                # (num_processed_slices, batch, classes) -> (batch, classes)
                final_logits = torch.stack(all_logits).max(dim=0)[0]
                preds = torch.argmax(final_logits, dim=1)
                epoch_acc += calculate_accuracy(preds, labels)

            pbar.update(1)
            pbar.set_postfix({'loss': epoch_loss / pbar.n, 'acc': epoch_acc / pbar.n})

    return epoch_loss / num_batches, epoch_acc / num_batches


def train(args):
    """Build data loaders, model, optimizer and scheduler, then run training.

    For every epoch the model is trained with ``train_epoch`` and validated
    with ``evaluate``. Metrics are printed and, when ``args.wandb`` is truthy,
    logged to Weights & Biases. Whenever the validation accuracy improves, the
    model's ``state_dict`` is written to
    ``<args.model_save_path>/best_model_<dataset>_<wandb_run_name>.pth``.

    Args:
        args: Namespace with the options declared by this script's argument
            parser (dataset, model, optimisation and logging settings).
    """
    if args.wandb:
        wandb.init(project="Long-Range-Arena", config=args, name=args.wandb_run_name)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    dataset_name = args.dataset
    # 'imdb_long' always uses the CANINE checkpoint, overriding --model_name.
    model_name = 'google/canine-c' if dataset_name == 'imdb_long' else args.model_name

    # Create the checkpoint directory up front: torch.save does not create
    # missing parent directories, so without this the first save fails only
    # after a full epoch has already been trained.
    if args.model_save_path:
        os.makedirs(args.model_save_path, exist_ok=True)
    model_save_path = os.path.join(
        args.model_save_path, f'best_model_{dataset_name}_{args.wandb_run_name}.pth')

    # Create data loaders
    train_dataloader, num_labels = create_data_loader(
        model_name, dataset_name, args.batch_size, args.max_seq_len, split='train', shuffle=True,
        sample_percentage=args.sample_percentage)
    val_dataloader, _ = create_data_loader(
        model_name, dataset_name, args.batch_size, args.max_seq_len, split='eval', shuffle=False,
        sample_percentage=args.sample_percentage)

    # Define the model
    layers_to_replace = (
        [int(x) for x in args.layers_to_replace.split(",")] if args.layers_to_replace else None
    )
    model = PretrainedWithCustomTransformer(
        model_name=model_name,
        num_labels=num_labels,
        num_heads=args.num_heads,
        hidden_dim=args.hidden_dim,
        num_layers=args.num_layers,
        dropout=args.dropout,
        scaleD=args.scaleD,
        alpha=args.alpha,
        num_memory_tokens=args.num_memory_tokens,
        pooling=args.pooling,
        replace_attention=args.replace_attention,
        layers_to_replace=layers_to_replace,
        use_only_embeddings=args.use_only_embeddings,
        add_Hrel=args.add_Hrel,
        astro_sigmoid_nonlinearity=args.astro_sigmoid_nonlinearity,
        attention_type=args.attention_type,
        dataset_name=dataset_name,
    ).to(device)

    if args.freeze_pretrained:
        for param in model.pretrained.parameters():
            param.requires_grad = False

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

    # The scheduler is stepped once per batch inside train_epoch, hence
    # steps_per_epoch is the number of training batches.
    if args.use_scheduler:
        scheduler = OneCycleLR(optimizer, max_lr=args.learning_rate, epochs=args.num_epochs,
                               steps_per_epoch=len(train_dataloader))
    else:
        scheduler = None

    # Print the number of learnable parameters
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Number of learnable parameters: {total_params}")

    best_val_acc = 0
    for epoch in range(args.num_epochs):
        print(
            f"Epoch: {epoch + 1}/{args.num_epochs}, LR: {scheduler.get_last_lr() if args.use_scheduler else args.learning_rate}")

        train_loss, train_acc = train_epoch(model, train_dataloader, criterion, optimizer, scheduler, device, args)
        val_loss, val_acc = evaluate(model, val_dataloader, criterion, device, args)

        train_log = {
            f"train_loss_{dataset_name}": train_loss,
            f"train_acc_{dataset_name}": train_acc,
        }

        val_log = {
            f"val_loss_{dataset_name}": val_loss,
            f"val_acc_{dataset_name}": val_acc,
        }

        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

        if args.wandb:
            wandb.log({**train_log, **val_log})

        # Keep only the checkpoint with the best validation accuracy so far.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), model_save_path)
            print(f"Best model saved with validation accuracy: {best_val_acc:.4f}")

    if args.wandb:
        wandb.finish()


def str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` treats every non-empty string as ``True``
    (so ``--wandb False`` would enable wandb). This converter accepts the
    usual spellings of true/false instead and rejects anything else.

    Args:
        value: The raw option value (a string from the command line, or an
            actual ``bool``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean
            spelling; argparse reports this as a normal usage error.
    """
    import argparse

    if isinstance(value, bool):
        return value
    normalized = str(value).strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # The empty string stays False, matching the old ``bool('')`` behaviour.
    if normalized in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Train a model for various NLP tasks.")

    # Dataset options
    datasets = ['imdb', 'imdb_long', 'imdb_lra', 'cifar10', 'listops', 'pathfinder32', 'aan']

    # Model options
    models = ['bert-base-uncased', 'bert-large-uncased', 'bert-base-cased', 'bert-large-cased',
              'roberta-base', 'roberta-large', 'google/canine-c']

    parser.add_argument('--dataset', type=str, default='aan', choices=datasets,
                        help='Dataset to use for training and evaluation')
    parser.add_argument('--model_name', type=str, default='bert-base-uncased', choices=models,
                        help='Pretrained model to use')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size for training')
    parser.add_argument('--max_seq_len', type=int, default=1024, help='Maximum sequence length for input texts')
    parser.add_argument('--num_epochs', type=int, default=15, help='Number of training epochs')
    parser.add_argument('--learning_rate', type=float, default=0.0001, help='Learning rate')
    parser.add_argument('--num_heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--hidden_dim', type=int, default=1024, help='Dimension of the feedforward network model')
    parser.add_argument('--num_layers', type=int, default=1, help='Number of transformer encoder layers')
    # Boolean options take an explicit value, e.g. ``--wandb true`` / ``--wandb false``.
    parser.add_argument('--wandb', type=str2bool, default=False, help='Use Weights & Biases for logging')
    parser.add_argument('--wandb_run_name', type=str, default='pretrained_with_custom_transformer',
                        help='Name for the Weights & Biases run')
    parser.add_argument('--model_save_path', type=str, default='./models', help='Path to save the model and plots')
    parser.add_argument('--memory_replay_backprop', type=str2bool, default=True,
                        help='Use memory replay backpropagation')
    parser.add_argument('--num_segments', type=int, default=4,
                        help='Number of segments to split input sequences for memory replay backpropagation')
    parser.add_argument('--num_memory_tokens', type=int, default=32, help='Number of memory tokens')
    parser.add_argument('--scaleD', type=float, default=100.0, help='Scaling factor for AstroAttention')
    parser.add_argument('--alpha', type=float, default=0.25, help='Alpha parameter for AstroAttention')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout factor of the model')
    parser.add_argument('--weight_decay', type=float, default=0.01, help='Weight decay factor of the models optimizer')
    parser.add_argument('--pooling', type=str, default='cls', choices=['cls', 'average'], help='Pooling method')
    parser.add_argument('--sample_percentage', type=float, default=1,
                        help='Percentage of dataset to use for training')
    parser.add_argument('--freeze_pretrained', type=str2bool, default=False,
                        help='Freeze pretrained model parameters')
    parser.add_argument('--use_only_embeddings', type=str2bool, default=True,
                        help='Use only embeddings from pretrained model')
    parser.add_argument('--replace_attention', type=str2bool, default=False,
                        help='Replace attention layers with AstroAttention')
    parser.add_argument('--layers_to_replace', type=str, default=None,
                        help='Comma-separated list of layer indices to replace attention (e.g., "0,1,2"). If None, replace all layers.')
    parser.add_argument('--add_Hrel', type=str2bool, default=True, help='Flag to add Hrel/scaling_factor')
    parser.add_argument('--astro_sigmoid_nonlinearity', type=str2bool, default=False,
                        help='Flag to use sigmoid nonlinearity over the H_neurona nd H_astro')
    parser.add_argument('--mem_sum', type=str2bool, default=False,
                        help='Flag to use the summed astrocytic memory or the memory of that segment only')
    parser.add_argument('--astro_mem', type=str2bool, default=True, help='Use inherent Astrocytic memory')
    parser.add_argument('--use_scheduler', type=str2bool, default=True, help='Use learning rate scheduler')
    # NOTE(review): --clip is parsed but not forwarded anywhere in train(); confirm whether it is still needed.
    parser.add_argument('--clip', type=int, default=10, help='Clip value for relative positional encoding')
    parser.add_argument('--attention_type', type=str, default='astro', choices=['astro', 'softmax'],
                        help='Type of attention to use')

    args = parser.parse_args()

    train(args)
