import copy
import sys
import torch
import wandb

from pyprojroot import here as project_root

sys.path.insert(0, str(project_root()))

from src.train_utils.utils import get_device

@torch.no_grad()
def log_enc_dec_upd_eff_dim(model):
  """Log the effective dimension of the encoder/decoder projection gradients.

  Reads the ``.grad`` of ``model.pre_quant_proj[0].weight`` (encoder side) and
  ``model.post_quant_proj.weight`` (decoder side), so it must be called after
  ``backward()`` and before gradients are zeroed. A DDP-wrapped model is
  unwrapped through its ``.module`` attribute first.

  The two values are sent to wandb under the keys 'encoder_upd_eff_dim' and
  'decoder_upd_eff_dim'.

  Raises:
    AssertionError: if either gradient is missing or is not a 2-D matrix
      (only while assertions are enabled).
  """
  net = getattr(model, "module", model)  # DDP artifact: unwrap if wrapped

  enc_grad = net.pre_quant_proj[0].weight.grad
  dec_grad = net.post_quant_proj.weight.grad

  # Sanity checks: both gradients must exist and be matrices.
  assert enc_grad is not None, "Encoder gradient is None"
  assert dec_grad is not None, "Decoder gradient is None"
  assert enc_grad.dim() == 2, "Encoder gradient shape is not 2D"
  assert dec_grad.dim() == 2, "Decoder gradient shape is not 2D"

  def _as_tall(mat):
    # get_effective_dim expects rows >= columns, so flip wide matrices.
    n_rows, n_cols = mat.shape
    if n_cols > n_rows:
      return mat.T
    return mat

  from src.eval_utils.pca import get_effective_dim
  metrics = {
    'encoder_upd_eff_dim': get_effective_dim(_as_tall(enc_grad)),
    'decoder_upd_eff_dim': get_effective_dim(_as_tall(dec_grad)),
  }

  from src.train_utils.wandb_utils import wandb_log
  wandb_log(metrics)

def train_loop(args, loader, model, world_size, optimizer, scaler, scheduler, logger, max_steps=10000):
  """Run one training pass over ``loader`` with bf16 autocast and gradient accumulation.

  For each batch:
    * run the forward pass under bf16 autocast;
    * sum the mean 'rec_loss', the mean 'commit_loss' and, if present, the mean
      'fp2_rec_loss';
    * backpropagate through ``scaler.scale(loss)``;
    * log the per-batch losses with ``wandb_log``.

  Every ``args.n_accumulate`` batches:
    * unscale the gradients and clip them to a global norm of 1.0;
    * apply ``scaler.step(optimizer)``, then step the scaler and the scheduler;
    * log the gradient effective dimensions via ``log_enc_dec_upd_eff_dim``;
    * zero the gradients.

  ``trainer.global_step`` is incremented once per batch, not once per
  optimizer step.

  Args:
    args: run config. Reads ``batch_size``, ``model``, ``double_fp``,
      ``commit_weight`` and ``n_accumulate``. ``model`` is substring-checked
      for 'vhp', 'dfp' and 'rot'; presumably it is a model-name string.
    loader: iterable of batches. Each batch is either an input tensor or a
      list/tuple whose first element is the input tensor, e.g. (image, label).
    model: model, possibly DDP-wrapped, that returns a dict with 'rec_loss',
      'commit_loss' and optionally 'fp2_rec_loss'.
    world_size: number of processes; used only to rescale ``max_steps``.
    optimizer: optimizer stepped through ``scaler.step``.
    scaler: GradScaler-like object (``get_scale``/``scale``/``unscale_``/
      ``step``/``update``).
    scheduler: LR scheduler, stepped once per optimizer step.
    logger: logger used for the end-of-pass summary.
    max_steps: batch budget expressed relative to a global batch size of 16.
      It is rescaled by ``args.batch_size * world_size / 16``.

  Returns:
    (avg_rec_loss, avg_commit_loss): means over the processed batches. The
    commit loss is the value returned by the model. It is not divided by
    ``args.commit_weight``, unlike the value logged to wandb.

  Raises:
    ValueError: if ``loader`` yields no batches.
  """
  import src.train_utils.trainer as trainer
  from src.train_utils.wandb_utils import wandb_log

  # Hack to ensure backwards consistency w.r.t. lr schedule :/
  max_steps = int(max_steps / (args.batch_size * world_size / 16))
  print(f'Max steps: {max_steps}')
  model.train()
  device = get_device(model)
  optimizer.zero_grad()

  # The model flags depend only on the config, so compute them once.
  use_vhp = 'vhp' in args.model
  use_double_fp = 'dfp' in args.model or args.double_fp
  use_rot = 'rot' in args.model

  avg_rec_loss = avg_commit_loss = 0.0
  accumulate_count = 0
  num_batches = 0
  for i, x in enumerate(loader):
    # Datasets such as ImageNet yield (image, label); keep only the images.
    # Checking the container type, rather than len(x) == 2, avoids mistaking a
    # plain tensor batch of 2 images for an (image, label) pair.
    if isinstance(x, (list, tuple)):
      x = x[0]
    x = x.to(device)

    # Autocast wraps only the forward pass and the loss computation. Backward
    # runs outside it, as the PyTorch AMP docs recommend.
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
      ret = model(x,
                  vhp=use_vhp,
                  double_fp=use_double_fp,
                  rot=use_rot,
                  loss_scale=scaler.get_scale())
      rec_loss = torch.mean(ret['rec_loss'])
      vq_loss = torch.mean(ret['commit_loss'])
      loss = rec_loss + vq_loss
      fp2_rec_loss = None
      if 'fp2_rec_loss' in ret:
        fp2_rec_loss = torch.mean(ret['fp2_rec_loss'])
        loss = loss + fp2_rec_loss

    # NOTE(review): the loss is not divided by args.n_accumulate, so the
    # accumulated gradients are summed rather than averaged. Confirm this is
    # intended.
    scaler.scale(loss).backward()
    accumulate_count += 1

    # One host sync per scalar; the values are reused for logging and averages.
    rec_val = rec_loss.item()
    commit_val = vq_loss.item()
    log_dict = {
      'rec_loss': rec_val,
      'commit_loss': commit_val / args.commit_weight,
    }
    if fp2_rec_loss is not None:
      log_dict['fp2_rec_loss'] = fp2_rec_loss.item()
    wandb_log(log_dict)

    if accumulate_count == args.n_accumulate:
      accumulate_count = 0
      scaler.unscale_(optimizer)  # Clip the real (unscaled) gradients.
      torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
      scaler.step(optimizer)  # Update parameters w.r.t. optimizer values.
      scaler.update()  # Update the scale factor for the next iteration.
      scheduler.step()  # Update the scheduler.
      log_enc_dec_upd_eff_dim(model)  # Reads .grad, so it must run before zeroing.
      optimizer.zero_grad()  # Zero out gradient attributes for all parameters.

    if i % (5 * args.n_accumulate) == 0:
      print(f'Train loss at step {i // args.n_accumulate}: {rec_val + commit_val:.3f}')

    trainer.global_step += 1
    avg_rec_loss += rec_val
    avg_commit_loss += commit_val
    num_batches += 1
    # NOTE(review): breaking only once i > max_steps processes max_steps + 2
    # batches. Left unchanged to avoid altering the step counts of existing runs.
    if i > max_steps:
      break

  if num_batches == 0:
    raise ValueError('train_loop: loader yielded no batches')
  avg_rec_loss /= num_batches
  avg_commit_loss /= num_batches
  logger.info(f'average train loss: {avg_rec_loss + avg_commit_loss:.3f}')
  return avg_rec_loss, avg_commit_loss
