from utils import *

def sequential_train_dbat(num_models, train_dl, valid_dl, valid_rm_dl, valid_rc_dl, test_dl, test_rm_dl,
                     test_rc_dl, perturb_dl, perturb_dl_test, model_fn, alpha=10, max_epoch=100, opt='SGD',
                     use_diversity_reg=True, reg_model_weights=None, lr_max=0.2, weight_decay=1e-5, use_scheduler=True):
  """Train `num_models` single-logit models one after another with a disagreement penalty.

  Every model minimises binary cross-entropy (with logits) on `train_dl`. From the
  second model on, and only when `use_diversity_reg` is true, an extra term is
  computed on batches drawn from `perturb_dl`: for each earlier model with sigmoid
  output ``p_prev`` and the current model's sigmoid output ``p_cur`` the term is
  ``-log((1 - p_prev) * p_cur + p_prev * (1 - p_cur) + 1e-7)`` averaged over the
  batch. The terms are combined as a weighted average using `reg_model_weights`
  and added to the ERM loss scaled by `alpha`.

  Args:
    num_models: number of models to build and train.
    train_dl: iterable of ``(x, y)`` training batches; must support ``len()``.
      Batches are fed to the model as-is (assumed to already live on DEVICE —
      TODO confirm against the data-loading code).
    valid_dl, valid_rm_dl, valid_rc_dl: loaders passed to `get_acc` every 200 steps.
    test_dl, test_rm_dl, test_rc_dl: loaders passed to `get_acc` / `get_ens_acc`
      once a model has finished training.
    perturb_dl: loader wrapped by `dl_to_sampler`; each call of the sampler yields
      an ``(x_tilde, _)`` pair used for the disagreement term.
    perturb_dl_test: loader passed to `get_acc` for the "perturb" accuracy.
    model_fn: factory called as ``model_fn(num_classes=1)``.
    alpha: multiplier of the disagreement term in the total loss.
    max_epoch: number of passes over `train_dl` per model.
    opt: ``'SGD'`` selects SGD with momentum 0.9; any other value selects Adam.
    use_diversity_reg: enable the disagreement term for models after the first.
    reg_model_weights: per-previous-model weights of the disagreement terms;
      defaults to 1.0 for every model.
    lr_max: optimizer learning rate and peak of the cyclic schedule.
    weight_decay: weight decay passed to the optimizer.
    use_scheduler: if true, use a triangular `CyclicLR` from 0 to `lr_max` whose
      ramp-up lasts half of the total number of steps.

  Returns:
    dict keyed ``"m1"``, ``"m2"``, ... Each entry holds lists of
    ``(iteration, value)`` tuples under "acc", "rm-acc", "rc-acc", "perturb-acc",
    "loss" and "adv-loss", plus the scalars "test-acc", "test-rm-acc",
    "test-rc-acc", "ens-test-rm-acc" and "final-perturb-acc".
  """

  models = [model_fn(num_classes=1).to(DEVICE) for _ in range(num_models)]
  set_train_mode(models)

  stats = {f"m{i+1}": {"acc": [], "rm-acc": [], "rc-acc": [], "perturb-acc": [], "loss": [], "adv-loss": []} for i in range(len(models))}

  if reg_model_weights is None:
    reg_model_weights = [1.0 for _ in range(num_models)]

  steps_per_epoch = len(train_dl)

  for m_idx, m in enumerate(models):
    m_stats = stats[f"m{m_idx+1}"]

    # Keep the optimizer in its own variable: rebinding `opt` would make the
    # `opt == 'SGD'` test fail for every model after the first one and silently
    # switch them to Adam.
    if opt == 'SGD':
      optimizer = torch.optim.SGD(m.parameters(), lr=lr_max, momentum=0.9, weight_decay=weight_decay)
    else:
      optimizer = torch.optim.Adam(m.parameters(), lr=lr_max, weight_decay=weight_decay)
    if use_scheduler:
      scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, 0, lr_max, step_size_up=(steps_per_epoch*max_epoch)//2,
                                                    mode='triangular', cycle_momentum=False)
    else:
      scheduler = None
    perturb_sampler = dl_to_sampler(perturb_dl)

    for epoch in range(max_epoch):
      for itr, (x, y) in enumerate(train_dl):
        # The sampler is advanced on every step, even when the penalty is unused.
        (x_tilde, _) = perturb_sampler()
        erm_loss = F.binary_cross_entropy_with_logits(m(x), y)

        if use_diversity_reg and m_idx != 0:
          # Earlier models only provide targets: evaluate them without gradients
          # and in eval mode, then restore train mode for the current model.
          with torch.no_grad():
            set_eval_mode(models)
            ps = [torch.sigmoid(m_(x_tilde)) for m_ in models[:m_idx]]
            set_train_mode(models)
          psm = torch.sigmoid(m(x_tilde))
          weighted_terms = []
          for i, p_prev in enumerate(ps):
            al = - ((1.-p_prev) * psm + p_prev * (1.-psm) + 1e-7).log().mean()
            weighted_terms.append(al*reg_model_weights[i])
          adv_loss = sum(weighted_terms)/sum(reg_model_weights[:len(weighted_terms)])
        else:
          # No penalty: use an explicit zero instead of 0 / reg_model_weights[0],
          # which would turn the loss into NaN when that weight is 0.
          adv_loss = torch.zeros(1, device=DEVICE)

        loss = erm_loss + alpha * adv_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None: scheduler.step()

        itr_ = itr + epoch * steps_per_epoch  # global step index for this model
        if itr_ % 200 == 0:
          set_eval_mode(models)
          print_str = f"[m{m_idx+1}] {epoch}/{itr_} [train] loss: {erm_loss.item():.2f} adv-loss: {adv_loss.item():.2f} "
          if itr_ != 0 and scheduler is not None:
            print_str += f"[lr] {scheduler.get_last_lr()[0]:.5f} "
          m_stats["loss"].append((itr_, erm_loss.item()))
          m_stats["adv-loss"].append((itr_, adv_loss.item()))
          acc = get_acc(m, valid_dl)
          acc_rm = get_acc(m, valid_rm_dl)
          acc_rc = get_acc(m, valid_rc_dl)
          acc_on_perturb = get_acc(m, perturb_dl_test)
          m_stats["acc"].append((itr_, acc))
          m_stats["rm-acc"].append((itr_, acc_rm))
          m_stats["rc-acc"].append((itr_, acc_rc))
          m_stats["perturb-acc"].append((itr_, acc_on_perturb))
          print_str += f" acc: {acc:.2f}, {Fore.BLUE} r0/1-acc: {acc_rm:.2f} {Style.RESET_ALL}"
          set_train_mode(models)
          print(print_str)

    # Final metrics are measured in eval mode, like the periodic validation above,
    # and train mode is restored before the next model starts training.
    set_eval_mode(models)
    test_acc = get_acc(m, test_dl)
    test_rm_acc = get_acc(m, test_rm_dl)
    test_rc_acc = get_acc(m, test_rc_dl)
    acc_on_perturb = get_acc(m, perturb_dl_test)
    # NOTE(review): the ensemble is evaluated over ALL models, including the ones
    # that have not been trained yet at this point — confirm this is intended.
    ensemble_acc = get_ens_acc(models, test_rm_dl)
    set_train_mode(models)
    m_stats["test-acc"] = test_acc
    m_stats["test-rm-acc"] = test_rm_acc
    m_stats["test-rc-acc"] = test_rc_acc
    m_stats["ens-test-rm-acc"] = ensemble_acc
    m_stats["final-perturb-acc"] = acc_on_perturb
    print(f"[m{m_idx+1}] [test] acc: {test_acc:.3f}, r-acc: {test_rm_acc:.3f}, r-acc-ens: {ensemble_acc:.3f}, acc-on-perturb: {acc_on_perturb:.3f}")

  return stats


def simultaneous_train_divdis(num_models, train_dl, valid_dl, valid_rm_dl, valid_rc_dl, test_dl, test_rm_dl,
                     test_rc_dl, perturb_dl, perturb_dl_test, model_fn, alpha=10, max_epoch=100, opt='SGD',
                     use_diversity_reg=True, reg_model_weights=None, lr_max=0.2, weight_decay=1e-5, use_scheduler=True):
  """Train one multi-head model whose `num_models` heads are optimised jointly.

  A single network is built with ``model_fn(num_classes=num_models)``; its output
  is split along the last dimension into `num_models` chunks and each chunk is
  trained with binary cross-entropy (with logits) against the same targets. When
  `use_diversity_reg` is true, ``DivDisLoss(heads=num_models, mode="mi",
  reduction="mean")`` evaluated on the model output for a batch drawn from
  `perturb_dl` is added, scaled by `alpha`.

  Args:
    num_models: number of output heads.
    train_dl: iterable of ``(x, y)`` training batches; must support ``len()``.
    valid_dl, valid_rm_dl, valid_rc_dl: loaders passed to `get_acc` (with
      ``idx=<head>``) every 200 steps.
    test_dl, test_rm_dl, test_rc_dl: loaders used for the final evaluation.
    perturb_dl: loader wrapped by `dl_to_sampler`; each call yields ``(x_tilde, _)``.
    perturb_dl_test: loader passed to `get_acc` for the "perturb" accuracy.
    model_fn: factory called as ``model_fn(num_classes=num_models)``.
    alpha: multiplier of the diversity loss.
    max_epoch: number of passes over `train_dl`.
    opt: ``'SGD'`` selects SGD with momentum 0.9; any other value selects Adam.
    use_diversity_reg: enable the DivDis term.
    reg_model_weights: not used by this function.
    lr_max: optimizer learning rate and peak of the cyclic schedule.
    weight_decay: weight decay passed to the optimizer.
    use_scheduler: if true, use a triangular `CyclicLR` from 0 to `lr_max` whose
      ramp-up lasts half of the total number of steps.

  Returns:
    dict keyed ``"m1"``, ``"m2"``, ... (one entry per head) with lists of
    ``(iteration, value)`` tuples under "acc", "rm-acc", "rc-acc", "perturb-acc",
    "loss" and "adv-loss", plus the scalars "test-acc", "test-rm-acc",
    "test-rc-acc", "ens-test-rm-acc" and "final-perturb-acc".
  """

  from divdis import DivDisLoss
  loss_fn = DivDisLoss(heads=num_models, mode="mi", reduction="mean")

  m = model_fn(num_classes=num_models).to(DEVICE)
  m.train()

  stats = {f"m{i+1}": {"acc": [], "rm-acc": [], "rc-acc": [], "perturb-acc": [], "loss": [], "adv-loss": []} for i in range(num_models)}

  # The optimizer gets its own name so the `opt` argument is not clobbered.
  if opt == 'SGD':
    optimizer = torch.optim.SGD(m.parameters(), lr=lr_max, momentum=0.9, weight_decay=weight_decay)
  else:
    optimizer = torch.optim.Adam(m.parameters(), lr=lr_max, weight_decay=weight_decay)

  steps_per_epoch = len(train_dl)
  if use_scheduler:
    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, 0, lr_max, step_size_up=(steps_per_epoch*max_epoch)//2,
                                                  mode='triangular', cycle_momentum=False)
  else:
    scheduler = None
  perturb_sampler = dl_to_sampler(perturb_dl)

  for epoch in range(max_epoch):
    for itr, (x, y) in enumerate(train_dl):
      # The sampler is advanced on every step, even when the DivDis term is off.
      (x_tilde,_) = perturb_sampler()
      logits = m(x)
      logits_chunked = torch.chunk(logits, num_models, dim=-1)
      erm_losses = [F.binary_cross_entropy_with_logits(logit, y) for logit in logits_chunked]
      erm_loss = sum(erm_losses)

      if use_diversity_reg:
        adv_loss = loss_fn(m(x_tilde))
        loss = erm_loss + alpha * adv_loss
      else:
        # Define a zero placeholder so the logging below does not hit an
        # unbound `adv_loss` when the regulariser is disabled.
        adv_loss = torch.tensor([0])
        loss = erm_loss

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
      if scheduler is not None: scheduler.step()

      itr_ = itr + epoch * steps_per_epoch  # global step index
      if itr_ % 200 == 0:
        m.eval()
        for m_idx in range(num_models):
          head_stats = stats[f"m{m_idx+1}"]
          head_loss = erm_losses[m_idx].item()
          print_str = f"[m{m_idx+1}] {epoch}/{itr_} [train] loss: {head_loss:.2f} adv-loss: {adv_loss.item():.2f} "
          if itr_ != 0 and scheduler is not None:
            print_str += f"[lr] {scheduler.get_last_lr()[0]:.5f} "
          # Record the loss of this head (the value that is printed), not the
          # sum over all heads.
          head_stats["loss"].append((itr_, head_loss))
          head_stats["adv-loss"].append((itr_, adv_loss.item()))
          acc = get_acc(m, valid_dl, idx=m_idx)
          acc_rm = get_acc(m, valid_rm_dl, idx=m_idx)
          acc_rc = get_acc(m, valid_rc_dl, idx=m_idx)
          acc_on_perturb = get_acc(m, perturb_dl_test, idx=m_idx)
          head_stats["acc"].append((itr_, acc))
          head_stats["rm-acc"].append((itr_, acc_rm))
          head_stats["rc-acc"].append((itr_, acc_rc))
          head_stats["perturb-acc"].append((itr_, acc_on_perturb))
          print_str += f" acc: {acc:.2f}, {Fore.BLUE} r0/1-acc: {acc_rm:.2f} {Style.RESET_ALL}"
          print(print_str)
        m.train()

  m.eval()
  for m_idx in range(num_models):
    head_stats = stats[f"m{m_idx+1}"]
    test_acc = get_acc(m, test_dl, idx=m_idx)
    test_rm_acc = get_acc(m, test_rm_dl, idx=m_idx)
    test_rc_acc = get_acc(m, test_rc_dl, idx=m_idx)
    acc_on_perturb = get_acc(m, perturb_dl_test, idx=m_idx)
    # NOTE(review): get_ens_acc receives the single multi-head module here, while
    # the sequential trainers in this file pass a list of models — confirm the
    # helper supports both.
    ensemble_acc = get_ens_acc(m, test_rm_dl)
    head_stats["test-acc"] = test_acc
    head_stats["test-rm-acc"] = test_rm_acc
    head_stats["test-rc-acc"] = test_rc_acc
    head_stats["ens-test-rm-acc"] = ensemble_acc
    head_stats["final-perturb-acc"] = acc_on_perturb
    print(f"[m{m_idx+1}] [test] acc: {test_acc:.3f}, r-acc: {test_rm_acc:.3f}, r-acc-ens: {ensemble_acc:.3f}, acc-on-perturb: {acc_on_perturb:.3f}")

  return stats


def sequential_train_divdis(num_models, train_dl, valid_dl, valid_rm_dl, valid_rc_dl, test_dl, test_rm_dl,
                     test_rc_dl, perturb_dl, perturb_dl_test, model_fn, alpha=10, max_epoch=100, opt='SGD',
                     use_diversity_reg=True, reg_model_weights=None, lr_max=0.2, weight_decay=1e-5, use_scheduler=True):
  """Train `num_models` single-logit models one after another with a DivDis penalty.

  Every model minimises binary cross-entropy (with logits) on `train_dl`. From the
  second model on, and only when `use_diversity_reg` is true, the logits of all
  earlier models (computed without gradients) and of the current model on a batch
  drawn from `perturb_dl` are concatenated along the last dimension and passed to
  a `DivDisLoss` (``mode="mi"``, ``reduction="mean"``); the result is added to the
  ERM loss scaled by `alpha`.

  Args:
    num_models: number of models to build and train.
    train_dl: iterable of ``(x, y)`` training batches; must support ``len()``.
    valid_dl, valid_rm_dl, valid_rc_dl: loaders passed to `get_acc` every 200 steps.
    test_dl, test_rm_dl, test_rc_dl: loaders passed to `get_acc` / `get_ens_acc`
      once a model has finished training.
    perturb_dl: loader wrapped by `dl_to_sampler`; each call yields ``(x_tilde, _)``.
    perturb_dl_test: loader passed to `get_acc` for the "perturb" accuracy.
    model_fn: factory called as ``model_fn(num_classes=1)``.
    alpha: multiplier of the diversity loss.
    max_epoch: number of passes over `train_dl` per model.
    opt: ``'SGD'`` selects SGD with momentum 0.9; any other value selects Adam.
    use_diversity_reg: enable the DivDis term for models after the first.
    reg_model_weights: not used by this function.
    lr_max: optimizer learning rate and peak of the cyclic schedule.
    weight_decay: weight decay passed to the optimizer.
    use_scheduler: if true, use a triangular `CyclicLR` from 0 to `lr_max` whose
      ramp-up lasts half of the total number of steps.

  Returns:
    dict keyed ``"m1"``, ``"m2"``, ... Each entry holds lists of
    ``(iteration, value)`` tuples under "acc", "rm-acc", "rc-acc", "perturb-acc",
    "loss" and "adv-loss", plus the scalars "test-acc", "test-rm-acc",
    "test-rc-acc", "ens-test-rm-acc" and "final-perturb-acc".
  """

  from divdis import DivDisLoss

  models = [model_fn(num_classes=1).to(DEVICE) for _ in range(num_models)]
  set_train_mode(models)

  stats = {f"m{i+1}": {"acc": [], "rm-acc": [], "rc-acc": [], "perturb-acc": [], "loss": [], "adv-loss": []} for i in range(len(models))}

  steps_per_epoch = len(train_dl)

  for m_idx, m in enumerate(models):
    m_stats = stats[f"m{m_idx+1}"]
    apply_reg = use_diversity_reg and m_idx != 0

    # The loss sees m_idx previous logits plus the current one, so it is built
    # with that many heads rather than with num_models.
    # NOTE(review): assumes DivDisLoss requires `heads` to equal the number of
    # logit columns it receives (as in simultaneous_train_divdis) — confirm
    # against the divdis module. For num_models == 2 this is unchanged.
    loss_fn = DivDisLoss(heads=m_idx + 1, mode="mi", reduction="mean") if apply_reg else None

    # Keep the optimizer in its own variable: rebinding `opt` would make the
    # `opt == 'SGD'` test fail for every model after the first one and silently
    # switch them to Adam.
    if opt == 'SGD':
      optimizer = torch.optim.SGD(m.parameters(), lr=lr_max, momentum=0.9, weight_decay=weight_decay)
    else:
      optimizer = torch.optim.Adam(m.parameters(), lr=lr_max, weight_decay=weight_decay)
    if use_scheduler:
      scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, 0, lr_max, step_size_up=(steps_per_epoch*max_epoch)//2,
                                                    mode='triangular', cycle_momentum=False)
    else:
      scheduler = None
    perturb_sampler = dl_to_sampler(perturb_dl)

    for epoch in range(max_epoch):
      for itr, (x, y) in enumerate(train_dl):
        # The sampler is advanced on every step, even when the penalty is unused.
        (x_tilde,_) = perturb_sampler()
        erm_loss = F.binary_cross_entropy_with_logits(m(x), y)

        if apply_reg:
          # Earlier models only provide fixed logits: evaluate them without
          # gradients and in eval mode, then restore train mode.
          with torch.no_grad():
            set_eval_mode(models)
            prev_logits = [m_(x_tilde) for m_ in models[:m_idx]]
            set_train_mode(models)
          curr_logits = m(x_tilde)
          target_logits = torch.cat(prev_logits + [curr_logits], dim=-1)
          adv_loss = loss_fn(target_logits)
          loss = erm_loss + alpha * adv_loss
        else:
          adv_loss = torch.tensor([0])  # placeholder used only for logging
          loss = erm_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if scheduler is not None: scheduler.step()

        itr_ = itr + epoch * steps_per_epoch  # global step index for this model
        if itr_ % 200 == 0:
          set_eval_mode(models)
          print_str = f"[m{m_idx+1}] {epoch}/{itr_} [train] loss: {erm_loss.item():.2f} adv-loss: {adv_loss.item():.2f} "
          if itr_ != 0 and scheduler is not None:
            print_str += f"[lr] {scheduler.get_last_lr()[0]:.5f} "
          m_stats["loss"].append((itr_, erm_loss.item()))
          m_stats["adv-loss"].append((itr_, adv_loss.item()))
          acc = get_acc(m, valid_dl)
          acc_rm = get_acc(m, valid_rm_dl)
          acc_rc = get_acc(m, valid_rc_dl)
          acc_on_perturb = get_acc(m, perturb_dl_test)
          m_stats["acc"].append((itr_, acc))
          m_stats["rm-acc"].append((itr_, acc_rm))
          m_stats["rc-acc"].append((itr_, acc_rc))
          m_stats["perturb-acc"].append((itr_, acc_on_perturb))
          print_str += f" acc: {acc:.2f}, {Fore.BLUE} r0/1-acc: {acc_rm:.2f} {Style.RESET_ALL}"
          set_train_mode(models)
          print(print_str)

    # Final metrics are measured in eval mode, like the periodic validation above,
    # and train mode is restored before the next model starts training.
    set_eval_mode(models)
    test_acc = get_acc(m, test_dl)
    test_rm_acc = get_acc(m, test_rm_dl)
    test_rc_acc = get_acc(m, test_rc_dl)
    acc_on_perturb = get_acc(m, perturb_dl_test)
    # NOTE(review): the ensemble is evaluated over ALL models, including the ones
    # that have not been trained yet at this point — confirm this is intended.
    ensemble_acc = get_ens_acc(models, test_rm_dl)
    set_train_mode(models)
    m_stats["test-acc"] = test_acc
    m_stats["test-rm-acc"] = test_rm_acc
    m_stats["test-rc-acc"] = test_rc_acc
    m_stats["ens-test-rm-acc"] = ensemble_acc
    m_stats["final-perturb-acc"] = acc_on_perturb
    print(f"[m{m_idx+1}] [test] acc: {test_acc:.3f}, r-acc: {test_rm_acc:.3f}, r-acc-ens: {ensemble_acc:.3f}, acc-on-perturb: {acc_on_perturb:.3f}")

  return stats