import torch
import torch.nn as nn
import torch.nn.functional as F
import LSA_layer
import noisy_lr_data_generation
from tqdm import tqdm
from dotmap import DotMap

"""
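train_config is expected to contain the following fields:
- lsa_config (itself a config holding the settings passed to the LSA layer, including `device`)

Training Process
- num_training_steps
- lr
- do_grad_clip (whether to clip the gradient norm after each backward pass)
- grad_clip (maximum gradient norm, used only when do_grad_clip is set)

Training Data
- batch_size
- dimension
- output_variance
- sequence_length (number of context tokens + the query token)

Other
- eval_batch_size (number of tasks sampled for each periodic evaluation)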
"""

# When True, lsa_training_loop wraps its range of training steps in a tqdm
# progress bar; when False it iterates the plain range.
TQDM = False

def _sample_prediction(lsa_layer, train_config, batch_size):
    """Sample a fresh batch of noisy linear-regression tasks and run the layer on it.

    Returns ``(predicted_y, last_y)``. ``predicted_y`` is the layer output at the
    last sequence position and last feature (``[:, -1, -1]``). ``last_y`` is the
    second value returned by ``task_batch.evaluate`` and is used as the target.
    """
    task_batch = noisy_lr_data_generation.NoisyLinearRegressionTask(
        dimension=train_config.dimension,
        output_variance=train_config.output_variance,
        batch_size=batch_size,
        device=train_config.lsa_config.device,
    )
    xs = task_batch.sample_xs(sequence_length=train_config.sequence_length)
    context, last_y, _ = task_batch.evaluate(xs)
    predicted_y = lsa_layer(context)[:, -1, -1]
    return predicted_y, last_y


def lsa_training_loop(train_config, eval_every=100):
    """Train a LinearSelfAttention layer on freshly sampled noisy linear-regression tasks.

    Each step draws a new batch of ``train_config.batch_size`` tasks, computes the
    MSE between the layer's prediction and ``last_y``, and takes one Adam step
    (optionally with gradient-norm clipping). Every ``eval_every`` steps the
    layer is evaluated on a fresh batch of ``train_config.eval_batch_size`` tasks
    and the evaluation loss is printed.

    Args:
        train_config: DotMap providing lsa_config, num_training_steps, lr,
            do_grad_clip, grad_clip, batch_size, dimension, output_variance,
            sequence_length and eval_batch_size.
        eval_every: evaluation period in training steps. Must be positive.
            Defaults to 100, the previously hard-coded value.

    Returns:
        ``(lsa_layer, train_config)``: the trained layer and the config it was
        trained with.

    Raises:
        ValueError: if ``eval_every`` is not positive.
    """
    if eval_every <= 0:
        raise ValueError(f"eval_every must be positive, got {eval_every}")

    lsa_layer = LSA_layer.LinearSelfAttention(train_config.lsa_config)
    # NOTE(review): the layer is never moved to lsa_config.device here;
    # presumably LinearSelfAttention places its parameters itself — confirm.
    optimizer = torch.optim.Adam(lsa_layer.parameters(), lr=train_config.lr)

    steps = range(train_config.num_training_steps)
    if TQDM:
        steps = tqdm(steps)

    for step in steps:
        # Training step on a fresh batch of tasks.
        predicted_y, last_y = _sample_prediction(
            lsa_layer, train_config, train_config.batch_size
        )
        optimizer.zero_grad()
        loss = F.mse_loss(predicted_y, last_y)
        loss.backward()
        if train_config.do_grad_clip:
            nn.utils.clip_grad_norm_(
                lsa_layer.parameters(), max_norm=train_config.grad_clip
            )
        optimizer.step()

        # Periodic evaluation. The eval loss is never back-propagated, so
        # no_grad avoids building an autograd graph for the (large) eval batch.
        if step % eval_every == 0:
            with torch.no_grad():
                eval_pred, eval_y = _sample_prediction(
                    lsa_layer, train_config, train_config.eval_batch_size
                )
                eval_loss = F.mse_loss(eval_pred, eval_y)
            print("Training step ", step, " Evaluation Loss: ", eval_loss.item())

    return lsa_layer, train_config

# Hard-coded config
# Similar hyperparameters to https://arxiv.org/pdf/2212.07677.pdf
def icl_gd_config():
    """Return the hard-coded DotMap configuration consumed by lsa_training_loop.

    The device is CUDA when ``torch.cuda.is_available()`` reports it, otherwise
    CPU. The LSA point/query/head dimensions are all ``dimension + 1``.
    """
    cfg = DotMap()

    # Data sampling.
    cfg.batch_size = 2048
    cfg.dimension = 10
    cfg.output_variance = 0.5
    cfg.sequence_length = 11  # context tokens plus the single query token

    # Optimisation (the reference work trained for 4000 steps).
    cfg.num_training_steps = 12000
    cfg.lr = 0.0001
    cfg.do_grad_clip = True
    cfg.grad_clip = 10

    # Settings handed to the LSA layer. Insertion order is kept the same as
    # before so that printing the config yields identical output.
    token_dim = cfg.dimension + 1  # presumably x plus one y coordinate — confirm
    lsa = DotMap()
    lsa.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    lsa.num_heads = 1
    lsa.point_dim = token_dim
    lsa.query_dim = token_dim
    lsa.head_dim = token_dim
    cfg.lsa_config = lsa

    # Evaluation.
    cfg.eval_batch_size = 10000
    return cfg

if __name__ == "__main__":
    # Build the config once so the printed config is exactly the object trained.
    config = icl_gd_config()
    print(config)
    lsa_layer, _ = lsa_training_loop(config)