# -*- coding: utf-8 -*-


import os
from os.path import exists
import random
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, pad
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
import altair as alt
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
import torchtext.datasets as datasets
import spacy
#import GPUtil
import warnings
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import pdb

# Global switch read by show_example / execute_example; those helpers also
# require this module to be running as __main__ before they call anything.
RUN_EXAMPLES = True

def is_interactive_notebook():
  '''Return True when this module is executing as the top-level script.'''
  running_as_main = __name__ == "__main__"
  return running_as_main

def show_example(fn, args = ()):
  '''
  Call ``fn(*args)`` and return its result, but only when this module runs
  as the top-level script and RUN_EXAMPLES is true; otherwise return None.

  ``args`` defaults to an immutable empty tuple so no list object is shared
  between calls.
  '''
  if __name__ == '__main__' and RUN_EXAMPLES:
    return fn(*args)

def execute_example(fn, args = ()):
  '''
  Call ``fn(*args)`` for its side effects (result discarded), but only when
  this module runs as the top-level script and RUN_EXAMPLES is true.

  ``args`` defaults to an immutable empty tuple so no list object is shared
  between calls.
  '''
  if __name__ == '__main__' and RUN_EXAMPLES:
    fn(*args)

class DummyOptimizer(torch.optim.Optimizer):
  '''
  No-op optimizer for evaluation passes.

  Exposes one parameter group with ``lr == 0`` so code reading
  ``param_groups[0]['lr']`` still works.  ``Optimizer.__init__`` is not
  called, so the instance holds no parameters.
  '''

  def __init__(self):
    self.param_groups = [{'lr': 0}]

  def step(self):
    '''Do nothing.'''
    return None

  def zero_grad(self, set_to_none = False):
    '''Do nothing (there are no gradients to clear).'''
    return None


class DummyScheduler:
  '''No-op learning-rate scheduler: ``step()`` does nothing.'''

  def step(self):
    return None

# --- Embedding ---

class Embedding(nn.Module):
  '''
  Fixed (non-learned) one-hot embedding with optional cyclic positional bits.

  Token ``t`` at position ``l`` becomes a ``d_model``-dimensional vector with
  a 1 at index ``t`` and, when ``d_pos > 0``, an additional 1 at index
  ``vocab + l % d_pos``.
  '''

  def __init__(self, vocab, d_model, d_pos = 0):
    super(Embedding, self).__init__()
    self.vocab = vocab
    self.d_pos = d_pos
    self.d_model = d_model

  def forward(self, x):
    '''
    x: LongTensor of token ids, Size([*, L]); one_hot requires values in
       [0, d_model).

    Output: float one-hot encoding, Size([*, d_model, L])
    '''
    # Token slots [0, vocab) and positional slots [vocab, vocab + d_pos)
    # must both fit inside d_model.
    assert self.vocab + self.d_pos <= self.d_model
    y = nn.functional.one_hot(x, num_classes=self.d_model).to(torch.float)  # Size([*, L, d_model])

    if self.d_pos != 0:
      L = x.size(-1)
      positions = torch.arange(L, device=x.device)
      # For every position l (in every leading index), set slot
      # vocab + l % d_pos -- one vectorized write instead of a B*L Python loop.
      y[..., positions, self.vocab + positions % self.d_pos] = 1.0

    return y.transpose(-2, -1)

class Transformer(nn.Module):
  '''
  One-layer, single-head attention model with trainable W_kq and W_ffv.

  ``W_kq`` (d_model x d_model) starts at zero, i.e. uniform attention, and
  ``W_ffv`` (vocab x d_model) starts as the rectangular identity.
  '''

  def __init__(self, embedding, vocab, d_model):
    super(Transformer, self).__init__()
    self.d_model = d_model
    self.vocab = vocab
    self.embedding = embedding
    # Registration order (W_kq, then W_ffv) fixes the order of parameters().
    self.W_kq = nn.Parameter(torch.zeros(d_model, d_model, requires_grad=True))
    self.W_ffv = nn.Parameter(torch.eye(vocab, d_model, requires_grad=True))

  def decoder(self, x):
    '''
    Map token ids ``x`` to next-token logits of Size([Batch, vocab]).

    The last position is the query; every position is a key and a value.
    '''
    h = self.embedding(x)                                   # [Batch, d_model, L]
    n = h.size(0)
    query = h[:, :, -1].reshape(n, self.d_model, 1)         # [Batch, d_model, 1]
    scores = (h.transpose(-2, -1) @ self.W_kq @ query).reshape(n, -1)  # [Batch, L]
    weights = scores.softmax(dim=-1).reshape(n, -1, 1)      # [Batch, L, 1]
    context = h @ weights                                   # [Batch, d_model, 1]
    logits = self.W_ffv @ context                           # [Batch, vocab, 1]
    return logits.reshape(n, self.vocab)

  def forward(self, x):
    '''Return next-token probabilities, Size([Batch, vocab]).'''
    return self.decoder(x).softmax(-1)

class Transformer_FixV(nn.Module):
  '''
  Attention model whose output/value matrix ``W_ffv`` is supplied by the
  caller; only ``W_kq`` is created here as a Parameter.

  ``W_kq`` is initialised uniformly in [-0.1, 0.1).  ``W_ffv`` is stored as
  given (a plain tensor is not registered as a parameter).
  '''

  def __init__(self, embedding, vocab, d_model, W_ffv):
    super(Transformer_FixV, self).__init__()
    self.d_model = d_model
    self.vocab = vocab
    self.embedding = embedding
    bound = 0.1
    init = 2 * bound * torch.rand(d_model, d_model, requires_grad=True) - bound
    self.W_kq = nn.Parameter(init)
    self.W_ffv = W_ffv

  def decoder(self, x):
    '''Map token ids ``x`` to next-token logits of Size([Batch, vocab]).'''
    h = self.embedding(x)                                   # [Batch, d_model, L]
    n = h.size(0)
    query = h[:, :, -1].reshape(n, self.d_model, 1)         # [Batch, d_model, 1]
    scores = (h.transpose(-2, -1) @ self.W_kq @ query).reshape(n, -1)  # [Batch, L]
    weights = scores.softmax(dim=-1).reshape(n, -1, 1)      # [Batch, L, 1]
    context = h @ weights                                   # [Batch, d_model, 1]
    logits = self.W_ffv @ context                           # [Batch, vocab, 1]
    return logits.reshape(n, self.vocab)

  def forward(self, x):
    '''Return next-token probabilities, Size([Batch, vocab]).'''
    return self.decoder(x).softmax(-1)

class Transformer_FixKQ(nn.Module):
  '''
  Attention model with a fixed key-query matrix ``W_kq``; only ``W_ffv`` is
  created here as a Parameter (uniform in [-0.1, 0.1)).

  ``W_KQ``:
    * None -> a frozen all-zero matrix, i.e. uniform attention;
    * a torch.Tensor -> stored as given (note: an nn.Parameter would be
      registered and therefore trained);
    * anything else array-like (e.g. the numpy array from optimal_KQ) ->
      converted to a float32 tensor.
  '''

  def __init__(self, embedding, vocab, d_model, W_KQ=None):
    super(Transformer_FixKQ, self).__init__()
    self.d_model = d_model
    self.vocab = vocab
    self.embedding = embedding
    # `W_KQ == None` is elementwise for numpy arrays and then raises in `if`;
    # an identity check is the correct test.
    if W_KQ is None:
      self.W_kq = torch.zeros(self.d_model, self.d_model, requires_grad=False)
    elif isinstance(W_KQ, torch.Tensor):
      self.W_kq = W_KQ
    else:
      # Match the float32 dtype produced by the one-hot embedding.
      self.W_kq = torch.as_tensor(W_KQ, dtype=torch.float)
    b = 0.1
    self.W_ffv = nn.Parameter(
        2*b*torch.rand(self.vocab, self.d_model, requires_grad=True)-b
        )

  def decoder(self, x):
    '''Map token ids ``x`` to next-token logits of Size([Batch, vocab]).'''
    x = self.embedding(x)  # [Batch, d_model, L]

    Batch = x.size(0)
    scores = torch.matmul(
        torch.matmul(x.transpose(-2,-1), self.W_kq), x[:,:,-1].reshape(Batch, self.d_model, 1)
        ).reshape(Batch, -1)                                          # [Batch, L]

    p_attn = scores.softmax(dim=-1).reshape(Batch, -1, 1)           # [Batch, L, 1]
    representation = torch.matmul(x, p_attn)                        # [Batch, d_model, 1]
    output = torch.matmul(self.W_ffv, representation)               # [Batch, vocab, 1]
    return output.reshape(Batch, self.vocab)

  def forward(self, x):
    '''Return next-token probabilities, Size([Batch, vocab]).'''
    return self.decoder(x).softmax(-1)

class Batch:
  '''
  Split a batch of sequences into model input and prediction target.

  For ``tgt`` of Size([Batch, L]):
    * ``self.tgt``     -- the first L-1 tokens (model input), Size([Batch, L-1])
    * ``self.tgt_y``   -- the last token (target), Size([Batch])
    * ``self.ntokens`` -- number of targets different from ``pad`` (0-d tensor)
  With ``tgt=None`` no attributes are set.
  '''

  def __init__(self, tgt=None, pad=-1):
    if tgt is None:
      return
    self.tgt = tgt[:, :-1]
    self.tgt_y = tgt[:, -1]
    self.ntokens = self.tgt_y.ne(pad).sum()

def data_gen(vocab, vocab_book, order_book, batch_size=1, nbatches=40, L_max=10):
  '''
  Yield ``nbatches`` Batch objects of synthetic sequences.

  Each batch wraps ``batch_size`` sequences drawn by generate_training_dataset
  over a vocabulary of size ``vocab``.

  NOTE(review): those sequences have random lengths in [2, L_max], so with
  batch_size > 1 the rows usually differ in length and building the
  rectangular tensor fails; batch_size=1 is the reliable setting.
  '''
  for _ in range(nbatches):
    sequences = generate_training_dataset(vocab, batch_size, L_max, vocab_book, order_book)
    # Token ids never need gradients; torch.tensor already returns a fresh,
    # non-grad tensor, so no extra clone/detach is needed.
    tgt = torch.tensor(sequences, dtype=torch.long)
    yield Batch(tgt)

class NegativeLogLikelihoodLoss:
  '''
  Summed cross-entropy (log-softmax + NLL) on logits, divided by ``norm``.

  Calling the object returns ``(total, normalized)``: ``normalized`` is the
  autograd node to back-propagate, ``total`` the detached summed loss
  (``normalized * norm``).
  '''

  def __init__(self, generator=None):
    self.generator = generator
    self.criterion = nn.CrossEntropyLoss(reduction='sum')

  def __call__(self, x, y, norm):
    logits = x if self.generator is None else self.generator(x)
    flat_logits = logits.contiguous().view(-1, logits.size(-1))
    flat_targets = y.contiguous().view(-1)
    normalized = self.criterion(flat_logits, flat_targets) / norm
    return normalized.data * norm, normalized

def rate(step, model_size, warmup, factor=1):
  """
  Learning-rate multiplier for LambdaLR.

  Currently a constant ``factor * model_size ** -0.5``.  The warm-up term
  ``min(step ** (-0.5), step * warmup ** (-1.5))`` is disabled, so ``step``
  and ``warmup`` do not affect the result (they are kept for the LambdaLR
  call signature).
  """
  return factor * model_size ** (-0.5)


# --- Training loop ---

class TrainState:
  '''
  Counters updated by run_epoch / run_iter_trainKQ.

  The class attributes are the starting values; ``state.step += 1`` reads
  the class default and creates the per-instance attribute on first update.
  '''

  step: int = 0        # incremented once per batch in training modes
  accum_step: int = 0  # incremented once per optimizer.step() call
  samples: int = 0     # number of sequences processed
  tokens: int = 0      # number of target tokens processed

def run_epoch(
    data_iter,
    model,
    loss_compute,
    optimizer,
    scheduler,
    mode='train',
    accum_iter=1,
    train_state=None
  ):
  '''
  Run one pass over ``data_iter`` and return the loss per target token.

  Args:
    data_iter: iterable of Batch-like objects with ``tgt``, ``tgt_y`` and
      ``ntokens``.
    model: object whose ``decoder(batch.tgt)`` returns logits.
    loss_compute: callable ``(out, tgt_y, norm) -> (loss, loss_node)`` where
      ``loss`` is a detached value and ``loss_node`` supports backward().
    optimizer, scheduler: stepped only in training modes.
    mode: 'train' or 'train+log' update parameters; any other value (e.g.
      'eval') only evaluates, without recording an autograd graph.
    accum_iter: the optimizer steps on batches whose index is a multiple of
      this value.
    train_state: TrainState to accumulate counters into; a fresh one is
      created per call when omitted.

  Returns:
    (total_loss / total_tokens, train_state)
  '''
  if train_state is None:
    # A `TrainState()` default is evaluated once at definition time and would
    # be shared (and accumulated into) by every call omitting the argument.
    train_state = TrainState()
  is_train = mode in ('train', 'train+log')
  start = time.time()
  total_tokens = 0
  total_loss = 0
  tokens = 0
  n_accum = 0
  for i, batch in enumerate(data_iter):
    # Only record the autograd graph when it will be back-propagated.
    with torch.set_grad_enabled(is_train):
      out = model.decoder(batch.tgt)  # logits, not probabilities
      loss, loss_node = loss_compute(out, batch.tgt_y, batch.ntokens)

    if is_train:
      loss_node.backward()
      train_state.step += 1
      train_state.samples += batch.tgt.shape[0]
      train_state.tokens += batch.ntokens
      if i % accum_iter == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        n_accum += 1
        train_state.accum_step += 1
      scheduler.step()

    total_loss += loss
    total_tokens += batch.ntokens
    tokens += batch.ntokens
    if i % 40 == 1 and is_train:
      lr = optimizer.param_groups[0]['lr']
      # Guard against a zero interval from a coarse clock.
      elapsed = max(time.time() - start, 1e-9)
      print(
          (
              'Mode: %s| Epoch Step: %6d | Loss: %6.6f'
              + '| Tokens / Sec: %7.1f | Learning Rate: %6.1e'
          )
          % (mode, i, float(loss / batch.ntokens), float(tokens / elapsed), lr)
      )
      start = time.time()
      tokens = 0
    if i % 40 == 1 and (mode == 'eval'):
      elapsed = max(time.time() - start, 1e-9)
      print(
          (
              'Mode: %s | Epoch Step: %6d | Accumlation Step %3d | Loss: %6.6f'
              + '| Tokens / Sec: %7.1f '
          )
          % (mode, i, n_accum, float(loss / batch.ntokens), float(tokens / elapsed))
      )
      start = time.time()
      tokens = 0

    del loss
    del loss_node
  return total_loss / total_tokens, train_state

def run_iter_trainKQ(
    data_iter,
    model,
    loss_compute,
    optimizer,
    scheduler,
    mode='train',
    accum_iter=1,
    train_state=None
  ):
  '''
  Take one full-batch step on the summed loss over ``data_iter``.

  Each item of ``data_iter`` is a token sequence (list of ints): its prefix
  ``sentence[:-1]`` is the model input and ``sentence[-1]`` the target.  The
  per-sequence losses are summed (not averaged).  In training modes
  ('train' / 'train+log') the sum is back-propagated once, then the
  optimizer and scheduler each step once.

  ``accum_iter`` is accepted for signature parity with run_epoch and unused.

  Returns:
    (summed loss as a detached 0-d tensor, train_state)

  Raises:
    ValueError: if ``data_iter`` yields no sequences.
  '''
  if train_state is None:
    # Avoid a definition-time TrainState() shared by every call.
    train_state = TrainState()
  is_train = mode in ('train', 'train+log')
  total_loss = 0
  n_sequences = 0
  # Only record the autograd graph when it will be back-propagated.
  with torch.set_grad_enabled(is_train):
    for sentence in data_iter:
      tgt = torch.tensor([sentence[:-1]])   # Size([1, len - 1])
      tgt_y = torch.tensor([sentence[-1]])  # Size([1])
      out = model.decoder(tgt)              # logits, not probabilities
      _, loss_node = loss_compute(out, tgt_y, 1)
      total_loss = total_loss + loss_node
      n_sequences += 1
  if n_sequences == 0:
    raise ValueError('run_iter_trainKQ: data_iter yielded no sequences')

  if is_train:
    total_loss.backward()
    train_state.step += 1
    train_state.samples += tgt.shape[0]
    train_state.tokens += 1

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    train_state.accum_step += 1
    scheduler.step()

  # The loss is intentionally a sum over sequences (the previous code divided
  # by a counter that was always 1).
  return total_loss.detach(), train_state

# Dataset generation
import random


def create_vocab_book(V):
    """Return a uniformly random permutation of ``range(V)`` as a list.

    Entry ``t`` is used elsewhere in this file as the successor token of
    token ``t``.
    """
    permutation = list(range(V))
    random.shuffle(permutation)  # in place, driven by the global `random` state
    return permutation


def create_order_for_query(V):
    """Return ``V`` independent random permutations of ``range(V)``.

    Row ``q`` is the ranking used when token ``q`` is the query; an earlier
    position means a higher rank.
    """
    def shuffled_tokens():
        tokens = list(range(V))
        random.shuffle(tokens)
        return tokens

    return [shuffled_tokens() for _ in range(V)]


def find_largest_according_to_order(order, elements):
    """Return the element of ``elements`` that appears earliest in ``order``.

    "Largest" means highest-ranked: ``order`` lists tokens from highest to
    lowest rank.  Every element must occur in ``order`` (``order.index``
    raises ValueError otherwise).

    Raises:
        IndexError: if ``elements`` is empty (same as the former
            ``sorted(...)[0]`` implementation).
    """
    if not elements:
        raise IndexError('find_largest_according_to_order: elements is empty')
    # min() is a single O(n) pass; sorting only to take the first item was O(n log n).
    return min(elements, key=order.index)

import random

def create_training_sets(V, vocab_book):
    """Build collocation pairs and one training sequence per start token.

    For every start token ``v`` the sequence ``s`` begins as
    ``[v, vocab_book[v]]`` and is extended one token at a time.  At each step
    the last token ``q = s[-1]`` acts as the query and one element ``r`` of
    ``s`` is selected for it:

    * if ``s`` already contains an element of ``M[q]``, such an element is
      reused;
    * otherwise ``r`` is drawn with ``random.choice`` from the elements of
      ``s`` that are in neither ``M[q]`` nor ``m[q]`` and is added to
      ``M[q]``; if no such element exists, extension of ``s`` stops.

    All other elements of ``s`` are then added to ``m[q]`` and
    ``vocab_book[r]`` is appended.  Extension also stops as soon as the last
    token already occurs earlier in ``s``.

    Args:
        V: vocabulary size; tokens are ``0 .. V-1``.
        vocab_book: indexable map token -> successor token (presumably the
            permutation from create_vocab_book -- confirm at call sites).

    Returns:
        (collocation, dataset, M, m): ``collocation`` is the list of
        ``[v, vocab_book[v]]`` pairs, ``dataset`` the list of sequences, and
        ``M[q]`` / ``m[q]`` the sets of tokens selected / not selected for
        query ``q``.

    NOTE(review): when ``M[q] & set(s)`` has several elements, ``r`` is the
    first one in set-iteration order and the others are added to ``m[q]``,
    so a token can end up in both ``M[q]`` and ``m[q]``.
    """
    collocation = []
    dataset = []
    M = [set() for _ in range(V)]  # M[q]: tokens selected for query q
    m = [set() for _ in range(V)]  # m[q]: tokens seen but not selected for query q

    for v in range(V):
        collocation.append([v, vocab_book[v]])
        s = [v, vocab_book[v]]
        # Extend until the newest token already appeared earlier in s.
        while s[-1] not in set(s[:-1]):
            last_index = s[-1]  # current query token q
            if M[last_index] & set(s):
                # Reuse a token already selected for this query.
                r = list(M[last_index] & set(s))[0]
            else:
                # Choose r from s such that r is neither in M[last_index] nor in m[last_index]
                possible_choices = [x for x in s if x not in M[last_index] and x not in m[last_index]]
                if not possible_choices:
                    break  # No valid choice: stop extending this sequence.
                r = random.choice(possible_choices)
                M[last_index].add(r)

            # Every other element of s was not selected for this query.
            m[last_index].update([x for x in s if x != r])
            s.append(vocab_book[r])
        dataset.append(s)

    return collocation, dataset, M, m


def generate_training_dataset(V, N, L_max, vocab_book, order_book):
    """Sample ``N`` token sequences from a noisy "follow the top-ranked token" process.

    Each sequence:
      * has a length drawn uniformly from [2, L_max];
      * starts with a uniformly random token in [0, V);
      * at every further step, takes the previous token as the query, finds
        the prefix element ranked highest by ``order_book[query]``, and with
        probability 0.8 appends ``vocab_book[that element]``; otherwise it
        appends a uniformly random token.

    Returns a list of ``N`` lists of ints (lengths vary).
    """
    samples = []
    for _ in range(N):
        length = random.randint(2, L_max)
        seq = [random.randint(0, V - 1)]
        while len(seq) < length:
            query_order = order_book[seq[-1]]
            winner = find_largest_according_to_order(query_order, seq)
            successor = vocab_book[winner]
            if random.random() < 0.8:
                seq.append(successor)
            else:
                seq.append(random.randint(0, V - 1))
        samples.append(seq)
    return samples

# Example usage
def Example_dataset():
  """Print a small sample dataset: V=100 tokens, 20 sequences of length 2..4."""
  V = 100      # vocabulary size
  N = 20       # number of sequences
  L_max = 4    # maximum sequence length
  # The order book is drawn before the vocab book; both consume the global
  # `random` state, so this order matters for every later draw.
  order_book = create_order_for_query(V)
  vocab_book = create_vocab_book(V)
  data = generate_training_dataset(V, N, L_max, vocab_book, order_book)
  print("Generated dataset with constraints:", data)

Example_dataset()

def gen_sentence(model, in_sentence, L=1):
  '''
  Greedily extend a batch of token sequences by ``L`` tokens and print it.

  Args:
    model: object with a ``decoder`` method mapping token ids Size([Batch, len])
      to next-token logits Size([Batch, vocab]), as the Transformer* classes
      in this file do.
    in_sentence: token ids, Size([Batch, len]) (tensor or nested list); it is
      not modified.
    L: number of tokens to generate.

  Returns:
    (sentence, sen_with_confidence): the extended LongTensor and, for each
    generated step, ``[token, probability]`` for the first sequence of the
    batch.  (Callers that ignored the former None return are unaffected.)
  '''
  sentence = torch.as_tensor(in_sentence).clone()
  sen_with_confidence = []

  with torch.no_grad():
    for _ in range(L):
      # decoder already returns logits for the *next* token only: [Batch, vocab]
      prob = model.decoder(sentence).softmax(dim=-1)
      p, next_word = torch.max(prob, dim=1, keepdim=True)
      sentence = torch.cat((sentence, next_word), dim=1)
      sen_with_confidence.append([next_word[0, 0], p[0, 0].detach()])

  print(sentence)
  print(sen_with_confidence)
  return sentence, sen_with_confidence

def example_simple_model(
    vocab_book,
    order_book,
    vocab=10,
    d_model=10,
    d_pos=0,
    L_max=5,
    batch_size=1
    ):
  '''
  Train a fully trainable Transformer (W_kq and W_ffv) for 300 epochs.

  Each epoch runs 40 training batches from data_gen, then 40 evaluation
  batches, and prints the model's output distribution for the probe [1, 2].

  Returns:
    The trained model.
  '''
  model = Transformer(
      Embedding(vocab, d_model, d_pos),
      vocab,
      d_model
      )
  optimizer = torch.optim.Adam(
      model.parameters(), lr=0.5, betas=(0.9, 0.98), eps=1e-9
  )
  lr_scheduler = LambdaLR(
      optimizer = optimizer,
      lr_lambda=lambda step: rate(
          step, model_size=d_model, warmup=100
      )
  )

  for epoch in range(300):
    print('\n')
    print('Epoch:', epoch)
    model.train()
    run_epoch(
        data_gen(vocab, vocab_book, order_book, batch_size, nbatches=40, L_max= L_max),
        model,
        NegativeLogLikelihoodLoss(),
        optimizer,
        lr_scheduler,
        mode='train'
    )
    model.eval()
    # Pass instances, not the classes: calling .step() on the class itself
    # would raise TypeError (missing self).
    run_epoch(
        data_gen(vocab, vocab_book, order_book, batch_size, nbatches=40, L_max = L_max),
        model,
        NegativeLogLikelihoodLoss(),
        DummyOptimizer(),
        DummyScheduler(),
        mode='eval'
    )

    src = torch.LongTensor([[1,2]])
    print(model.forward(src))

  return model

def example_simple_model_FixKQ(
    vocab_book,
    order_book,
    KQ,
    vocab=10,
    d_model=10,
    d_pos=0,
    L_max=5,
    batch_size=1
    ):
  '''
  Train W_ffv of a Transformer_FixKQ with the key-query matrix fixed to KQ.

  Runs 80 epochs of 40 training and 40 evaluation batches from data_gen and
  prints the output distribution for the probe [1, 2] after each epoch.

  Args:
    KQ: fixed key-query matrix (tensor or array-like such as the numpy array
      from optimal_KQ), or None for an all-zero matrix.

  Returns:
    (model, loss_list) where loss_list holds each epoch's per-token eval loss.
  '''
  if KQ is not None:
    # optimal_KQ returns a float64 numpy array; the model needs a float32 tensor.
    KQ = torch.as_tensor(KQ, dtype=torch.float)
  model = Transformer_FixKQ(
      Embedding(vocab, d_model, d_pos),
      vocab,
      d_model,
      KQ
      )
  optimizer = torch.optim.Adam(
      model.parameters(), lr=0.5, betas=(0.9, 0.98), eps=1e-9
  )
  lr_scheduler = LambdaLR(
      optimizer = optimizer,
      lr_lambda=lambda step: rate(
          step, model_size=d_model, warmup=100
      )
  )

  # Previously never defined here, so the return raised NameError (or
  # silently returned an unrelated module-level `loss_list`).
  loss_list = []
  for epoch in range(80):
    print('\n')
    print('Epoch:', epoch)
    model.train()
    run_epoch(
        data_gen(vocab, vocab_book, order_book, batch_size, nbatches=40, L_max= L_max),
        model,
        NegativeLogLikelihoodLoss(),
        optimizer,
        lr_scheduler,
        mode='train'
    )
    model.eval()
    eval_loss = run_epoch(
        data_gen(vocab, vocab_book, order_book, batch_size, nbatches=40, L_max = L_max),
        model,
        NegativeLogLikelihoodLoss(),
        DummyOptimizer(),
        DummyScheduler(),
        mode='eval'
    )[0]
    loss_list.append(eval_loss)

    src = torch.LongTensor([[1,2]])
    print(model.forward(src))

  return model, loss_list

def example_simple_model_FixV(
    vocab_book,
    order_book,
    V,
    vocab=10,
    d_model=10,
    d_pos=0,
    L_max=5,
    batch_size=1
    ):
  '''
  Train W_kq of a Transformer_FixV whose output/value matrix is fixed to V.

  Runs 80 epochs of 40 training and 40 evaluation batches from data_gen and
  prints the output distribution for the probe [1, 2] after each epoch.

  Returns:
    The trained model.
  '''
  model = Transformer_FixV(
      Embedding(vocab, d_model, d_pos),
      vocab,
      d_model,
      V
      )
  optimizer = torch.optim.Adam(
      model.parameters(), lr=0.5, betas=(0.9, 0.98), eps=1e-9
  )
  lr_scheduler = LambdaLR(
      optimizer = optimizer,
      lr_lambda=lambda step: rate(
          step, model_size=d_model, warmup=100
      )
  )

  for epoch in range(80):
    print('\n')
    print('Epoch:', epoch)
    model.train()
    run_epoch(
        data_gen(vocab, vocab_book, order_book, batch_size, nbatches=40, L_max= L_max),
        model,
        NegativeLogLikelihoodLoss(),
        optimizer,
        lr_scheduler,
        mode='train'
    )
    model.eval()
    # Pass instances, not the classes: calling .step() on the class itself
    # would raise TypeError (missing self).
    run_epoch(
        data_gen(vocab, vocab_book, order_book, batch_size, nbatches=40, L_max = L_max),
        model,
        NegativeLogLikelihoodLoss(),
        DummyOptimizer(),
        DummyScheduler(),
        mode='eval'
    )

    src = torch.LongTensor([[1,2]])
    print(model.forward(src))

  return model

def PreprocessV(vocab, d_model, vocab_book, Delta = 2):
  '''
  Build the reference output/value matrix, Size([vocab, d_model]).

  Column ``j < vocab`` holds ``d_model - 1`` at row ``vocab_book[j]`` and
  ``-1`` elsewhere, so applying it to the one-hot of token ``j`` favours
  output token ``vocab_book[j]``.  Columns ``j >= vocab`` are all ``-1``.
  ``Delta`` is accepted but unused.
  '''
  rows = torch.tensor([vocab_book[j] for j in range(vocab)], dtype=torch.long)
  cols = torch.arange(vocab)
  target = torch.zeros(vocab, d_model)
  target[rows, cols] = d_model
  return target - 1.0

from torch.optim.optimizer import Optimizer, required

class NormalizedGD(Optimizer):
    """Normalized gradient descent.

    Each parameter tensor ``p`` with gradient ``g`` is updated as
    ``p <- p - lr * g / ||g||_{norm_type}``, i.e. it moves a distance ``lr``
    (in that norm) regardless of the gradient's magnitude.  Normalization is
    per tensor, not over all parameters jointly.  Tensors whose gradient norm
    is at most ``eps`` are left unchanged to avoid dividing by ~0.
    """

    def __init__(self, params, lr=required, norm_type=2, eps=1e-6):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if norm_type <= 0:
            raise ValueError("Invalid norm type: {}".format(norm_type))
        if eps < 0.0:
            raise ValueError("Invalid eps: {}".format(eps))

        defaults = dict(lr=lr, norm_type=norm_type, eps=eps)
        super(NormalizedGD, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it runs with gradients enabled.

        Returns:
            The closure's loss, or None when no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            norm_type = group['norm_type']
            # Groups restored from an older state_dict may lack 'eps'.
            eps = group.get('eps', 1e-6)
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                norm = grad.norm(norm_type)
                if norm > eps:
                    # In-place update keeps the same storage instead of
                    # rebinding p.data to a freshly allocated tensor.
                    p.add_(grad / norm, alpha=-lr)
        return loss

def example_simple_model_FixKQ_Collocation(
    collocation,
    dataset,
    vocab=10,
    d_model=10,
    d_pos=0,
    L_max=5,
    batch_size=1,
    n_epochs=900
    ):
  '''
  Two-stage experiment: learn W_ffv with W_kq frozen, then learn W_kq with
  that W_ffv frozen.

  Stage 1 trains a Transformer_FixKQ (W_kq = 0) by normalized GD on the
  ``collocation`` pairs, one full batch per epoch.  Stage 2 copies the learned
  W_ffv into a Transformer_FixV and trains its W_kq by normalized GD on every
  prefix (length >= 2) of every sequence in ``dataset``, one full-batch step
  per epoch.

  ``L_max`` and ``batch_size`` are accepted for signature parity with the
  other example_* functions and are unused.  ``n_epochs`` is the number of
  epochs of each stage.

  Returns:
    (model, loss_list, V_list, model_kq, loss_kq_list, KQ_list): the stage-1
    model, its per-epoch loss, numpy snapshots of W_ffv after each epoch, the
    stage-2 model, its per-epoch summed loss, and numpy snapshots of W_kq.
  '''

  def gen_collocation(collocation_):
    # All collocation pairs form a single batch of [input, target] rows.
    tgt = torch.tensor(collocation_, dtype=torch.long)
    yield Batch(tgt)

  def get_training_data_for_KQ(dataset):
    # Every prefix of length >= 2: prefix[:-1] is the input, prefix[-1] the target.
    for sentence in dataset:
      for end in range(2, len(sentence) + 1):
        yield sentence[:end]

  def snapshot(param):
    # Independent numpy copy of the parameter's current values (no reliance
    # on a module-level numpy import that appears after this definition).
    return param.detach().cpu().numpy().copy()

  # Stage 1: W_kq frozen at zero, learn W_ffv.
  model = Transformer_FixKQ(
      Embedding(vocab, d_model, d_pos),
      vocab,
      d_model,
      W_KQ=None
      )
  optimizer = NormalizedGD(
      model.parameters(), lr=0.2
  )
  lr_scheduler = LambdaLR(
      optimizer = optimizer,
      lr_lambda=lambda step: rate(
          step, model_size=d_model, warmup=100
      )
  )

  loss_list = []
  V_list = []
  for epoch in range(n_epochs):
    model.train()
    loss, _  = run_epoch(
        gen_collocation(collocation),
        model,
        NegativeLogLikelihoodLoss(),
        optimizer,
        lr_scheduler,
        mode='train'
    )
    loss_list.append(loss.detach())
    V_list.append(snapshot(model.W_ffv))

  model.eval()

  print('after train V', model.W_ffv)

  # Stage 2: freeze the learned W_ffv (a non-trainable Parameter copy) and learn W_kq.
  W_v = copy.deepcopy(model.W_ffv).requires_grad_(False)
  model_kq = Transformer_FixV(
      Embedding(vocab, d_model, d_pos),
      vocab,
      d_model,
      W_v
  )
  optimizer_kq = NormalizedGD(
      model_kq.parameters(), lr=0.05
  )
  lr_scheduler_kq = LambdaLR(
      optimizer = optimizer_kq,
      lr_lambda=lambda step: rate(
          step, model_size=d_model, warmup=100
      )
  )
  loss_kq_list = []
  KQ_list = []
  for epoch in range(n_epochs):
    model_kq.train()
    loss, _  = run_iter_trainKQ(
        get_training_data_for_KQ(dataset),
        model_kq,
        NegativeLogLikelihoodLoss(),
        optimizer_kq,
        lr_scheduler_kq,
        mode='train'
    )
    loss_kq_list.append(loss.detach())
    KQ_list.append(snapshot(model_kq.W_kq))
  print('after train KQ', model_kq.W_kq)
  print(dataset)
  model_kq.eval()
  return model, loss_list, V_list, model_kq, loss_kq_list, KQ_list

import numpy as np
# --- Experiment configuration ---
vocab = 20      # number of distinct tokens
d_model = 20    # embedding width; Embedding asserts vocab + d_pos <= d_model
d_pos = 0       # no positional one-hot features
# Both draw from Python's global `random`; their order affects every later draw.
# NOTE(review): order_book is not passed to the experiment run below.
order_book = create_order_for_query(vocab)
vocab_book = create_vocab_book(vocab)


# collocation: [v, vocab_book[v]] pairs; dataset: one sequence per start token;
# M / m: per-query sets of selected / non-selected tokens.
collocation, dataset, M, m = create_training_sets(vocab, vocab_book)

# Reference output/value matrix that the W_ffv snapshots are compared against.
V = PreprocessV(vocab, d_model, vocab_book)

def optimal_KQ(vocab, d_model, M, m):
  '''
  Build the reference key-query matrix as a (d_model, d_model) numpy array.

  For each query token q (column) and each token i < vocab (row):
    * i in m[q]           -> KQ[i, q] = -len(M[q])
    * i in M[q] (only)    -> KQ[i, q] = len(m[q])
    * otherwise           -> 0
  Membership in m[q] takes precedence when i is in both sets.
  '''
  KQ = np.zeros([d_model, d_model])
  for q in range(vocab):
    chosen, rejected = M[q], m[q]
    for i in range(vocab):
      if i in rejected:
        KQ[i, q] = -len(chosen)
      elif i in chosen:
        KQ[i, q] = len(rejected)
  return KQ

# Reference key-query matrix; the stage-2 W_kq snapshots are compared against it.
KQ = optimal_KQ(vocab, d_model, M, m)

# Two-stage experiment: stage 1 learns W_ffv (W_kq frozen), stage 2 learns W_kq
# (W_ffv frozen).  Returns both models, their per-epoch losses and weight snapshots.
model, loss_list, V_list, model_kq, loss_kq_list, KQ_list = example_simple_model_FixKQ_Collocation(
    collocation,
    dataset,
    vocab,
    d_model,
    d_pos
    )

import numpy as np

def matrix_inner_product(A,B):
  '''
  Cosine similarity of two same-shape arrays viewed as flat vectors:
  <A, B> / (||A||_F * ||B||_F).  Zero inputs give nan/inf (numpy division).
  '''
  cross = np.sum(A * B)
  norm_sq_a = np.sum(A * A)
  norm_sq_b = np.sum(B * B)
  return cross / np.sqrt(norm_sq_a * norm_sq_b)

import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator

# --- Plot styling ---
linew = 4          # curve line width
grid_linw = 3      # grid line width
figsize = (6,4)    # NOTE(review): unused; the figures below use figsize=(6, 5)
ftsz = 50          # axis-label font size
lbsz = 40          # tick-label font size

loss_list = np.array(loss_list)
loss_kq_list = np.array(loss_kq_list)
# Convert the reference matrices once instead of once per snapshot.
V_ref = np.array(V)
KQ_ref = np.array(KQ)
# Cosine similarity of each snapshot to its reference matrix.
IB_list = [matrix_inner_product(x, V_ref) for x in V_list]
IB_kq_list = [matrix_inner_product(x, KQ_ref) for x in KQ_list]
# Frobenius norm of each snapshot.
norm_list = [np.sqrt(np.sum(x*x)) for x in V_list]
norm_kq_list = [np.sqrt(np.sum(x*x)) for x in KQ_list]

# One x value per recorded epoch, instead of a hard-coded 900.
epochs = range(len(loss_list))

def _plot_curve(values, ylabel, yticks=None, ylim=None):
  '''
  Plot ``values`` against their index t in a new figure and show it.

  All figures share size, line width, grid, font sizes and a major x tick
  every 300 iterations.  The x range follows ``len(values)``, so the plots
  remain valid if the number of recorded epochs changes.
  '''
  plt.figure(figsize=(6,5))
  plt.plot(range(len(values)), values, linewidth=linew)
  plt.xlabel(r'$t$', fontsize=ftsz)
  plt.ylabel(ylabel, fontsize=ftsz)
  plt.grid(True, linewidth=grid_linw)
  plt.gca().xaxis.set_major_locator(MultipleLocator(300))  # x-axis major grid spacing
  plt.tick_params(axis='both', which='major', labelsize=lbsz)
  if yticks is not None:
    plt.gca().set_yticks(yticks)
  if ylim is not None:
    plt.ylim(*ylim)
  plt.show()


# Stage-1 loss.
_plot_curve(loss_list, r'$\mathcal{L}_0(W_{\mathrm{ov}}^{(t)})$')
# Stage-2 loss (summed over training prefixes).
_plot_curve(loss_kq_list, r'$\mathcal{L}(\theta^{(t)})$', yticks=[0,10,20,30])
# Cosine similarity of the W_ffv snapshots to the reference V.
_plot_curve(
    IB_list,
    r'$ \left\langle \bar{W}_{\mathrm{ov}}^{(t)}, \bar{W}_{\mathrm{ov}}^* \right\rangle $',
    yticks=[1.0, 0.5, 0],
)
# Cosine similarity of the W_kq snapshots to the reference KQ.
_plot_curve(
    IB_kq_list,
    r'$ \left\langle \bar{W}_{\mathrm{kq}}^{(t)}, \bar{W}_{\mathrm{kq}}^* \right\rangle $',
    yticks=[0,0.5,1.0],
    ylim=(0, 1.05),
)
# Frobenius norms of the snapshots.
_plot_curve(norm_list, r'$\|W_{\mathrm{ov}}^{(t)}\|$')
_plot_curve(norm_kq_list, r'$\|W_{\mathrm{kq}}^{(t)}\|$')