# Standard-library and third-party imports
import argparse
import gc
import os
from pathlib import Path
from tqdm import tqdm
from sys import exit
import wandb
from matplotlib import pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.nn.utils
from torch.nn.utils import weight_norm
from functools import partial

# Our imports
import utils
from data import get_batch_simple, batch_copy_selective, copy_selective

# =======
# METRICS
# =======

def mse_loss(pred, label, args, device):
    """Return the mean-squared error between ``pred`` and ``label``.

    ``label`` is moved to ``device`` before comparison. ``args`` is not used;
    it is accepted so all entries in ``loss_register`` share one signature.
    """
    target_on_device = label.to(device)
    criterion = nn.MSELoss()
    return criterion(pred, target_on_device)
def ce_loss(pred, label, args, device):
    """Return the cross-entropy between logits ``pred`` and class targets ``label``.

    Every axis of ``pred`` except the last is merged, so the last axis is the
    class dimension. ``label`` is flattened to the matching 1-D target vector
    and moved to ``device``. The result is the mean over all flattened
    positions (``nn.CrossEntropyLoss`` default reduction). ``args`` is unused;
    it keeps the shared loss signature.
    """
    criterion = nn.CrossEntropyLoss()
    logits = torch.flatten(pred, end_dim=-2)
    targets = torch.flatten(label).to(device)
    return criterion(logits, targets)

def curvature_loss(pred, target, args, device, weight=1e5):
    """Return MSE plus a weighted penalty on mismatched second differences along dim 0.

    The curvature term compares the discrete second differences
    ``x[t+1] - 2*x[t] + x[t-1]`` of ``pred`` and ``target`` along dim 0. Dim 0 is
    presumably the time axis, per the (T, B, N) RNN convention; confirm for
    other models. The result is ``weight * MSE(second differences) + MSE(pred,
    target)``, both with mean reduction.

    Args:
        pred: model output; must already live on ``device``.
        target: reference values; moved to ``device`` here.
        args: unused; keeps the shared loss signature.
        device: device used for the comparison.
        weight: scale of the curvature term (previously hard-coded as 1e5).

    Returns:
        Scalar loss tensor. Sequences shorter than 3 steps have no second
        difference, so only the plain MSE is returned for them; an empty
        difference tensor would otherwise yield NaN.
    """
    # Move target first so BOTH terms compare tensors on the same device;
    # previously only the plain-MSE term did.
    target = target.to(device)
    fit = F.mse_loss(pred, target)  # identical to nn.MSELoss() (mean reduction)
    if pred.shape[0] < 3:
        return fit
    pred_diff = pred[2:] - 2 * pred[1:-1] + pred[:-2]
    target_diff = target[2:] - 2 * target[1:-1] + target[:-2]
    return weight * F.mse_loss(pred_diff, target_diff) + fit

# Maps the ``args.loss`` name to a loss callable sharing the
# ``(pred, label, args, device)`` signature.
loss_register = dict(
    mse=mse_loss,
    crossentropy=ce_loss,
    curvature=curvature_loss,
)

# Selective-copy batch generator: ``batch_copy_selective`` with the
# ``copy_selective`` task function bound as ``func``.
batch_copy_selective_task = partial(batch_copy_selective, func=copy_selective)

# Maps the ``args.dataset`` name to a batch generator. Only these two are
# implemented so far. An earlier mapping also listed 'lorenz' (get_lorenz),
# 'ordered' (batch_copy_ordered), 'tf_copy' (get_batch_simple_tf) and
# 'copy_eos' (copy_with_eos_token), which are not available here.
data_fn_dict = dict(copy=get_batch_simple, selective=batch_copy_selective_task)

# =======
# METHODS
# =======

def test(model, data_fn, objective, args, device):
    """Evaluate ``model`` on one freshly generated batch and return the loss.

    A batch of ``args.batch_test`` sequences is drawn from ``data_fn`` with the
    same ``sample_len``/``memory_len`` keywords that ``train`` uses. It is run
    through the model under ``torch.no_grad()`` and scored with ``objective``.
    The model's train/eval mode is restored afterwards.

    Args:
        model: network under evaluation. When ``args.model == 'wavernn'`` it is
            called as ``model(x, get_seq=False)`` and returns ``(out, seq)``;
            otherwise ``model(x).logits`` is used.
        data_fn: batch generator (an entry of ``data_fn_dict``).
        objective: loss callable with the ``(pred, label, args, device)`` signature.
        args: parsed options; reads ``batch_test``, ``T``, ``mem_len``, ``model``,
            ``only_last`` and ``loss``.
        device: device the inputs are moved to.

    Returns:
        float: the scalar test loss. Unlike ``train``, no extra division by the
        batch size is applied for wavernn + crossentropy.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Pass memory_len explicitly (as ``train`` does) so evaluation uses
            # the memory length that the ``only_last`` slicing below assumes.
            data, label = data_fn(num_samples=args.batch_test, sample_len=args.T,
                                  memory_len=args.mem_len)
            if args.model in ['mamba', 'T_alibi', 'T_nope', 'T_rope']:
                # mamba/transformers expect batch-first integer tokens.
                # squeeze(-1) drops only the trailing feature axis, so the
                # batch axis survives when batch_test == 1.
                data = data.permute(1, 0, 2).squeeze(-1).long()
                label = label.permute(1, 0, 2).squeeze(-1).long()

            only_last = utils.str_to_bool(args.only_last)
            if args.model == 'wavernn':
                # wavernn output is time-first, so keep the last mem_len steps of dim 0.
                out, _ = model(data.to(device), get_seq=False)
                if only_last:
                    out = out[-args.mem_len:]
                    label = label[-args.mem_len:]
            else:
                out = model(data.to(device)).logits
                if only_last:
                    out = out[:, -args.mem_len:]
                    label = label[:, -args.mem_len:]

            target = label.float() if args.loss == 'mse' else label
            loss = objective(out, target, args, device)
    finally:
        # Restore the caller's mode instead of silently leaving eval mode on.
        model.train(was_training)

    return loss.item()

def train(model, args, device):
    """Train ``model`` with Adam on freshly generated batches.

    Each step draws ``args.batch`` sequences of length ``args.T`` (memory length
    ``args.mem_len``) from ``data_fn_dict[args.dataset]``. It computes the loss
    selected by ``args.loss`` and takes one optimizer step, with gradient-norm
    clipping when ``args.grad_clip > 0``.

    Every ``args.log_interval`` steps (step 0 excluded) it:
      * logs the train loss;
      * evaluates and logs the test loss via ``test``;
      * draws diagnostic plots;
      * stops early once the test loss drops below 1e-5.

    Args:
        model: network to train. When ``args.model == 'wavernn'`` it is called as
            ``model(x, get_seq=...)`` and returns ``(out, seq)``; otherwise
            ``model(x).logits`` is used.
        args: parsed options. Reads ``dataset``, ``loss``, ``lr``, ``max_steps``,
            ``batch``, ``batch_test``, ``T``, ``mem_len``, ``model``,
            ``only_last``, ``grad_clip``, ``log_interval`` and ``use_wandb``.
        device: device the inputs are moved to.

    Returns:
        tuple: ``(data, label, out, y_seq)`` from the last executed step. Entries
        are ``None`` if no step ran (``max_steps == 0``). ``y_seq`` stays
        ``None`` unless a wavernn logging step produced it.

    Raises:
        SystemExit: with a message (non-zero exit status) if the training loss
            becomes NaN.
    """
    data = label = out = y_seq = None
    data_fn = data_fn_dict[args.dataset]
    objective = loss_register[args.loss]
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Loop invariants, evaluated once.
    is_wavernn = args.model == 'wavernn'
    batch_first = args.model in ['mamba', 'T_alibi', 'T_nope', 'T_rope']
    only_last = utils.str_to_bool(args.only_last)
    use_wandb = utils.str_to_bool(args.use_wandb)

    model.train()
    pbar = tqdm(range(args.max_steps))
    for i in pbar:
        data, label = data_fn(num_samples=args.batch, sample_len=args.T,
                              memory_len=args.mem_len)

        if batch_first:
            # mamba/transformers expect batch-first integer tokens. squeeze(-1)
            # drops only the trailing feature axis, so batch == 1 keeps its axis.
            data = data.permute(1, 0, 2).squeeze(-1).long()
            label = label.permute(1, 0, 2).squeeze(-1).long()

        optimizer.zero_grad()

        if is_wavernn:
            # wavernn output is time-first (T, B, N): slice the time axis (dim 0).
            out, _ = model(data.to(device), get_seq=False)
            if only_last:
                out = out[-args.mem_len:]
                label = label[-args.mem_len:]
        else:
            out = model(data.to(device)).logits
            if only_last:
                out = out[:, -args.mem_len:]
                label = label[:, -args.mem_len:]

        target = label.float() if args.loss == 'mse' else label
        loss = objective(out, target, args, device)

        # Rescaling for the rnn's (T, B, N) convention. float() replaces the
        # former ``args.batch.float()``, which raised AttributeError for an int
        # batch size. NOTE(review): ce_loss already averages over every
        # flattened (time, batch) position, so this division may
        # double-normalize; confirm intent.
        if args.loss == 'crossentropy' and is_wavernn:
            loss = loss / float(args.batch)

        loss_value = loss.item()
        pbar.set_description(f"Loss: {loss_value}")
        # Abort before back-propagating a NaN. The old check ran only at log
        # steps, after the update, and exited with status 0.
        if np.isnan(loss_value):
            exit(f"Training loss became NaN at step {i}; aborting.")

        loss.backward()
        if args.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if i == 0 or i % args.log_interval != 0:
            continue

        print(" ")
        utils.log('Train Loss:', loss_value)
        test_loss = test(model, data_fn, objective, args, device)
        utils.log('Test Loss:', test_loss)
        print(" ")

        if is_wavernn:
            # Heatmap of the input weights Wx. The wandb upload (and plt.close)
            # that consumed this figure is disabled for the public release.
            # NOTE(review): the figure is therefore never closed, so one
            # accumulates per logging step.
            plt.figure(figsize=(6, 4))
            plt.imshow(model.cell.Wx.weight.detach().cpu().numpy(), cmap="viridis", aspect="auto")
            plt.colorbar()
            plt.title("V")

            model.eval()
            with torch.no_grad():
                _, y_seq = model(data.to(device), get_seq=True)
                if len(y_seq) > 0:
                    utils.Plot_Seq(y_seq.detach().cpu()[:, 0], step=i, log=use_wandb)

        if test_loss < 1e-5:
            break

        if is_wavernn:
            utils.plot_output(data[:, 0], out[:, 0], label[:, 0],
                              log=use_wandb, step=i)
        else:
            utils.plot_output(data[0, :].unsqueeze(-1), F.softmax(out[0, :], -1),
                              label[0, :].unsqueeze(-1), log=use_wandb, step=i)

        model.train()

    return data, label, out, y_seq
