import os
import math
import json
import random as rnd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
import matplotlib.pyplot as plt
import pandas as  pd
import torchvision.utils as vision_utils
from PIL import Image
import torchvision
from colorama import Fore, Back, Style
from matplotlib.ticker import NullFormatter
import sys
sys.path.insert(1, os.path.join(sys.path[0], ".."))

# Device every dataset tensor is moved to. Falls back to the CPU when CUDA is
# unavailable so the data utilities remain usable on GPU-less machines.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_mc_dataset(mnist_root, cifar_root, degree_of_balance, split_seed=None):
    """Build every MNIST/CIFAR-10 ("MC") data loader used by the experiments.

    The MNIST training set (60000 images) is split into perturb-base / train /
    valid parts of 10000 / 45000 / 5000 images. The CIFAR-10 training set
    (50000 images) is split into 10000 / 35000 / 5000. The official test sets
    are used for testing. From these the function builds:

      * train / valid / test loaders of ``build_mc_dataset`` samples;
      * valid / test loaders with the MNIST half shuffled across samples ("rm");
      * valid / test loaders with the CIFAR-10 half shuffled across samples ("rc");
      * an 80/20 train/test split of the perturbation ("OOD") dataset built by
        ``build_mc_perturb_dataset`` from the perturb-base parts.

    Args:
        mnist_root: root directory of the torchvision MNIST dataset
            (downloaded if missing).
        cifar_root: root directory of the torchvision CIFAR-10 dataset
            (downloaded if missing).
        degree_of_balance: forwarded to ``build_mc_perturb_dataset``.
        split_seed: if not None, each training-set split uses a fresh
            ``torch.Generator`` seeded with this value, so the
            perturb/train/valid partition is reproducible. ``None`` (default)
            uses torch's global RNG, as before.

    Returns:
        (train_dl, valid_dl, test_dl, valid_rm_dl, test_rm_dl,
         valid_rc_dl, test_rc_dl, perturb_dl, perturb_dl_test)
    """
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

    def _split(dataset, lengths):
        # Exposes the formerly commented-out
        # ``generator=torch.Generator().manual_seed(42)`` option as a parameter.
        if split_seed is None:
            return random_split(dataset, lengths)
        return random_split(dataset, lengths, generator=torch.Generator().manual_seed(split_seed))

    mnist_train = torchvision.datasets.MNIST(mnist_root, train=True, download=True, transform=transform)
    cifar_train = torchvision.datasets.CIFAR10(cifar_root, train=True, download=True, transform=transform)
    mnist_perturb_base, mnist_train, mnist_valid = _split(mnist_train, [10000, 45000, 5000])
    cifar_perturb_base, cifar_train, cifar_valid = _split(cifar_train, [10000, 35000, 5000])

    mnist_test = torchvision.datasets.MNIST(mnist_root, train=False, download=True, transform=transform)
    cifar_test = torchvision.datasets.CIFAR10(cifar_root, train=False, download=True, transform=transform)

    # Training / valid / test datasets
    data_train = build_mc_dataset(mnist_train, cifar_train)
    data_valid = build_mc_dataset(mnist_valid, cifar_valid)
    data_test = build_mc_dataset(mnist_test, cifar_test)

    train_dl = torch.utils.data.DataLoader(data_train, batch_size=256, shuffle=True)
    valid_dl = torch.utils.data.DataLoader(data_valid, batch_size=1024, shuffle=False)
    test_dl = torch.utils.data.DataLoader(data_test, batch_size=1024, shuffle=False)

    # MNIST randomized test / valid datasets
    data_test_rm = build_mc_dataset(mnist_test, cifar_test, randomize_m=True, randomize_c=False)
    data_valid_rm = build_mc_dataset(mnist_valid, cifar_valid, randomize_m=True, randomize_c=False)

    test_rm_dl = torch.utils.data.DataLoader(data_test_rm, batch_size=1024, shuffle=False)
    valid_rm_dl = torch.utils.data.DataLoader(data_valid_rm, batch_size=1024, shuffle=False)

    # CIFAR-10 randomized test / valid datasets
    data_test_rc = build_mc_dataset(mnist_test, cifar_test, randomize_m=False, randomize_c=True)
    data_valid_rc = build_mc_dataset(mnist_valid, cifar_valid, randomize_m=False, randomize_c=True)

    test_rc_dl = torch.utils.data.DataLoader(data_test_rc, batch_size=1024, shuffle=False)
    valid_rc_dl = torch.utils.data.DataLoader(data_valid_rc, batch_size=1024, shuffle=False)

    print(f"Train length: {len(train_dl.dataset)}")
    print(f"Test length: {len(test_dl.dataset)}")
    print(f"Test length randomized mnist: {len(test_rm_dl.dataset)}")
    print(f"Test length randomized cifar10: {len(test_rc_dl.dataset)}")

    data_perturb = build_mc_perturb_dataset(mnist_perturb_base, cifar_perturb_base, including_labels=True, degree_of_balance=degree_of_balance)
    n_perturb_train = int(0.8 * len(data_perturb))
    data_perturb, data_perturb_test = random_split(data_perturb, [n_perturb_train, len(data_perturb) - n_perturb_train])

    perturb_dl = torch.utils.data.DataLoader(data_perturb, batch_size=256, shuffle=True)
    perturb_dl_test = torch.utils.data.DataLoader(data_perturb_test, batch_size=256, shuffle=False)

    print(f"OOD dataset size: {len(perturb_dl.dataset)}")
    print(f"OOD dataset test size: {len(perturb_dl_test.dataset)}")

    return train_dl, valid_dl, test_dl, valid_rm_dl, test_rm_dl, valid_rc_dl, test_rc_dl, perturb_dl, perturb_dl_test

def get_mf_dataset(mnist_root, fashm_root, degree_of_balance, split_seed=None):
    """Build every MNIST/Fashion-MNIST ("MF") data loader used by the experiments.

    The MNIST and Fashion-MNIST training sets (60000 images each) are split
    into perturb-base / train / valid parts of 10000 / 45000 / 5000 images.
    The official test sets are used for testing. From these the function
    builds:

      * train / valid / test loaders of ``build_mf_dataset`` samples;
      * valid / test loaders with the MNIST half shuffled across samples ("rm");
      * valid / test loaders with the Fashion-MNIST half shuffled across samples ("rf");
      * an 80/20 train/test split of the perturbation ("OOD") dataset built by
        ``build_mf_perturb_dataset`` from the perturb-base parts.

    Args:
        mnist_root: root directory of the torchvision MNIST dataset
            (downloaded if missing).
        fashm_root: root directory of the torchvision Fashion-MNIST dataset
            (downloaded if missing).
        degree_of_balance: forwarded to ``build_mf_perturb_dataset``.
        split_seed: if not None, each training-set split uses a fresh
            ``torch.Generator`` seeded with this value, so the
            perturb/train/valid partition is reproducible. ``None`` (default)
            uses torch's global RNG, as before.

    Returns:
        (train_dl, valid_dl, test_dl, valid_rm_dl, test_rm_dl,
         valid_rf_dl, test_rf_dl, perturb_dl, perturb_dl_test)
    """
    transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

    def _split(dataset, lengths):
        # Exposes the formerly commented-out
        # ``generator=torch.Generator().manual_seed(42)`` option as a parameter.
        if split_seed is None:
            return random_split(dataset, lengths)
        return random_split(dataset, lengths, generator=torch.Generator().manual_seed(split_seed))

    mnist_train = torchvision.datasets.MNIST(mnist_root, train=True, download=True, transform=transform)
    fashm_train = torchvision.datasets.FashionMNIST(fashm_root, train=True, download=True, transform=transform)
    mnist_perturb_base, mnist_train, mnist_valid = _split(mnist_train, [10000, 45000, 5000])
    fashm_perturb_base, fashm_train, fashm_valid = _split(fashm_train, [10000, 45000, 5000])

    mnist_test = torchvision.datasets.MNIST(mnist_root, train=False, download=True, transform=transform)
    fashm_test = torchvision.datasets.FashionMNIST(fashm_root, train=False, download=True, transform=transform)

    # Training / valid / test datasets
    data_train = build_mf_dataset(mnist_train, fashm_train)
    data_valid = build_mf_dataset(mnist_valid, fashm_valid)
    data_test = build_mf_dataset(mnist_test, fashm_test)

    train_dl = torch.utils.data.DataLoader(data_train, batch_size=256, shuffle=True)
    valid_dl = torch.utils.data.DataLoader(data_valid, batch_size=1024, shuffle=False)
    test_dl = torch.utils.data.DataLoader(data_test, batch_size=1024, shuffle=False)

    # MNIST randomized test / valid datasets
    data_test_rm = build_mf_dataset(mnist_test, fashm_test, randomize_m=True, randomize_f=False)
    data_valid_rm = build_mf_dataset(mnist_valid, fashm_valid, randomize_m=True, randomize_f=False)

    test_rm_dl = torch.utils.data.DataLoader(data_test_rm, batch_size=1024, shuffle=False)
    valid_rm_dl = torch.utils.data.DataLoader(data_valid_rm, batch_size=1024, shuffle=False)

    # F-MNIST randomized test / valid datasets
    data_test_rf = build_mf_dataset(mnist_test, fashm_test, randomize_m=False, randomize_f=True)
    data_valid_rf = build_mf_dataset(mnist_valid, fashm_valid, randomize_m=False, randomize_f=True)

    test_rf_dl = torch.utils.data.DataLoader(data_test_rf, batch_size=1024, shuffle=False)
    valid_rf_dl = torch.utils.data.DataLoader(data_valid_rf, batch_size=1024, shuffle=False)

    print(f"Train length: {len(train_dl.dataset)}")
    print(f"Test length: {len(test_dl.dataset)}")
    print(f"Test length randomized mnist: {len(test_rm_dl.dataset)}")
    # This loader randomizes the Fashion-MNIST half (the message previously said "cifar10").
    print(f"Test length randomized fashion-mnist: {len(test_rf_dl.dataset)}")

    data_perturb = build_mf_perturb_dataset(mnist_perturb_base, fashm_perturb_base, including_labels=True, degree_of_balance=degree_of_balance)
    n_perturb_train = int(0.8 * len(data_perturb))
    data_perturb, data_perturb_test = random_split(data_perturb, [n_perturb_train, len(data_perturb) - n_perturb_train])

    perturb_dl = torch.utils.data.DataLoader(data_perturb, batch_size=256, shuffle=True)
    perturb_dl_test = torch.utils.data.DataLoader(data_perturb_test, batch_size=256, shuffle=False)

    print(f"OOD dataset size: {len(perturb_dl.dataset)}")
    print(f"OOD dataset test size: {len(perturb_dl_test.dataset)}")

    return train_dl, valid_dl, test_dl, valid_rm_dl, test_rm_dl, valid_rf_dl, test_rf_dl, perturb_dl, perturb_dl_test


def build_mc_dataset(mnist_data, cifar_data, randomize_m=False, randomize_c=False):
  """Build the MNIST-over-CIFAR-10 binary classification TensorDataset.

  Label 0 samples stack an MNIST "0" above a CIFAR-10 class-1 image. Label 1
  samples stack an MNIST "1" above a CIFAR-10 class-9 image. Each label keeps
  as many pairs as the smaller of its two sources provides.

  Args:
    mnist_data: iterable of (image, label) MNIST pairs.
    cifar_data: iterable of (image, label) CIFAR-10 pairs.
    randomize_m: if True, shuffle the MNIST halves across all samples, so the
      digit no longer carries label information.
    randomize_c: if True, do the same for the CIFAR-10 halves.

  Returns:
    A TensorDataset on DEVICE: images concatenated along dim 2 (height, for
    NCHW tensors) in shuffled order, and float labels of shape (N, 1).
  """
  def shuffled(t):
    return t[torch.randperm(len(t))]

  digit_0, _ = keep_only_lbls(mnist_data, lbls=[0])
  digit_1, _ = keep_only_lbls(mnist_data, lbls=[1])
  # Convert the digits to 3-channel padded images so they can be stacked with CIFAR-10.
  digit_0 = format_mnist(digit_0.view(-1, 1, 28, 28))
  digit_1 = format_mnist(digit_1.view(-1, 1, 28, 28))
  digit_0, digit_1 = shuffled(digit_0), shuffled(digit_1)

  cifar_1, _ = keep_only_lbls(cifar_data, lbls=[1])
  cifar_9, _ = keep_only_lbls(cifar_data, lbls=[9])
  cifar_1, cifar_9 = shuffled(cifar_1), shuffled(cifar_9)

  n_zero = min(len(digit_0), len(cifar_1))
  n_one = min(len(digit_1), len(cifar_9))
  top = torch.cat([digit_0[:n_zero], digit_1[:n_one]])
  bottom = torch.cat([cifar_1[:n_zero], cifar_9[:n_one]])
  if randomize_m:
    top = shuffled(top)
  if randomize_c:
    bottom = shuffled(bottom)

  images = torch.cat([top, bottom], dim=2)
  labels = torch.cat([torch.zeros(n_zero), torch.ones(n_one)])
  order = torch.randperm(len(images))
  images = images[order]
  labels = labels[order].float().view(-1, 1)
  return torch.utils.data.TensorDataset(images.to(DEVICE), labels.to(DEVICE))


def build_mf_dataset(mnist_data, fashm_data, randomize_m=False, randomize_f=False):
  """Build the MNIST-over-Fashion-MNIST binary classification TensorDataset.

  Label 0 samples stack an MNIST "0" above a Fashion-MNIST class-4 image.
  Label 1 samples stack an MNIST "1" above a Fashion-MNIST class-3 image.
  Each label keeps as many pairs as the smaller of its two sources provides.

  Args:
    mnist_data: iterable of (image, label) MNIST pairs.
    fashm_data: iterable of (image, label) Fashion-MNIST pairs.
    randomize_m: if True, shuffle the MNIST halves across all samples, so the
      digit no longer carries label information.
    randomize_f: if True, do the same for the Fashion-MNIST halves.

  Returns:
    A TensorDataset on DEVICE: images concatenated along dim 2 (height, for
    NCHW tensors) in shuffled order, and float labels of shape (N, 1).
  """
  def shuffled(t):
    return t[torch.randperm(len(t))]

  digit_0, _ = keep_only_lbls(mnist_data, lbls=[0])
  digit_1, _ = keep_only_lbls(mnist_data, lbls=[1])
  digit_0, digit_1 = shuffled(digit_0), shuffled(digit_1)

  fashion_4, _ = keep_only_lbls(fashm_data, lbls=[4])
  fashion_3, _ = keep_only_lbls(fashm_data, lbls=[3])
  fashion_4, fashion_3 = shuffled(fashion_4), shuffled(fashion_3)

  n_zero = min(len(digit_0), len(fashion_4))
  n_one = min(len(digit_1), len(fashion_3))
  top = torch.cat([digit_0[:n_zero], digit_1[:n_one]])
  bottom = torch.cat([fashion_4[:n_zero], fashion_3[:n_one]])
  if randomize_m:
    top = shuffled(top)
  if randomize_f:
    bottom = shuffled(bottom)

  images = torch.cat([top, bottom], dim=2)
  labels = torch.cat([torch.zeros(n_zero), torch.ones(n_one)])
  order = torch.randperm(len(images))
  images = images[order]
  labels = labels[order].float().view(-1, 1)
  return torch.utils.data.TensorDataset(images.to(DEVICE), labels.to(DEVICE))


def build_mc_perturb_dataset(mnist_test, cifar_test, including_labels=True, degree_of_balance=None):
  """Build the MNIST/CIFAR-10 perturbation ("OOD") dataset.

  Labels follow the CIFAR-10 half only: CIFAR class 1 -> 0, class 9 -> 1.
  A fraction ``1 - degree_of_balance`` of the samples is "spurious": the MNIST
  digit is swapped relative to the training pairing (MNIST 1 over CIFAR 1,
  MNIST 0 over CIFAR 9). The rest keep the training pairing (MNIST 0 over
  CIFAR 1, MNIST 1 over CIFAR 9). All four source groups are truncated to
  the size of the smallest one.

  Args:
    mnist_test: iterable of (image, label) MNIST pairs.
    cifar_test: iterable of (image, label) CIFAR-10 pairs.
    including_labels: if False, the dataset holds only the images.
    degree_of_balance: number in [0, 1]. 1 keeps the training pairing for
      every sample; 0 swaps every sample.

  Returns:
    A TensorDataset on DEVICE with the images (concatenated along dim 2) and,
    if ``including_labels``, float labels of shape (N, 1), in shuffled order.

  Raises:
    ValueError: if ``degree_of_balance`` is None or not within [0, 1].
  """
  # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
  # and the None default used to surface as an opaque TypeError.
  if degree_of_balance is None or not 0 <= degree_of_balance <= 1:
    raise ValueError(f"degree_of_balance must be in [0, 1], got {degree_of_balance!r}")
  spurious_ratio = 1 - degree_of_balance

  # Filter classes 0 and 1 of the given MNIST data, as 3-channel padded images.
  X_m_0, _ = keep_only_lbls(mnist_test, lbls=[0])
  X_m_1, _ = keep_only_lbls(mnist_test, lbls=[1])
  X_m_0 = format_mnist(X_m_0.view(-1, 1, 28, 28))
  X_m_1 = format_mnist(X_m_1.view(-1, 1, 28, 28))
  # Filter the car (1) and truck (9) classes of the given CIFAR-10 data.
  X_c_1, _ = keep_only_lbls(cifar_test, lbls=[1])
  X_c_9, _ = keep_only_lbls(cifar_test, lbls=[9])

  # Shuffle each group.
  X_c_1 = X_c_1[torch.randperm(len(X_c_1))]
  X_c_9 = X_c_9[torch.randperm(len(X_c_9))]
  X_m_0 = X_m_0[torch.randperm(len(X_m_0))]
  X_m_1 = X_m_1[torch.randperm(len(X_m_1))]

  # Training pairs m0c1 | m1c9; truncate all groups to the smallest one.
  min_group = min(len(X_m_0), len(X_c_1), len(X_m_1), len(X_c_9))
  X_c_1 = X_c_1[:min_group]
  X_c_9 = X_c_9[:min_group]
  X_m_0 = X_m_0[:min_group]
  X_m_1 = X_m_1[:min_group]

  sp_len = int(spurious_ratio * min_group)
  norm_len = min_group - sp_len

  # Spurious part: digits swapped (m1 over c1, m0 over c9).
  sp_top = torch.cat((X_m_1[:sp_len], X_m_0[:sp_len]), dim=0)
  sp_bottom = torch.cat((X_c_1[:sp_len], X_c_9[:sp_len]), dim=0)
  # Normal part: training pairing (m0 over c1, m1 over c9).
  norm_top = torch.cat((X_m_0[sp_len:], X_m_1[sp_len:]), dim=0)
  norm_bottom = torch.cat((X_c_1[sp_len:], X_c_9[sp_len:]), dim=0)

  X_sp = torch.cat((sp_top, sp_bottom), dim=2)
  X_norm = torch.cat((norm_top, norm_bottom), dim=2)
  X = torch.cat((X_sp, X_norm), dim=0)
  # c1 --> 0, c9 --> 1 (in both parts)
  Y = torch.cat((torch.zeros((sp_len,)), torch.ones((sp_len,)), torch.zeros((norm_len,)), torch.ones((norm_len,))))

  shuffle = torch.randperm(len(X))
  X = X[shuffle]
  Y = Y[shuffle].float().view(-1, 1)
  if including_labels:
    return torch.utils.data.TensorDataset(X.to(DEVICE), Y.to(DEVICE))
  return torch.utils.data.TensorDataset(X.to(DEVICE))


def build_mf_perturb_dataset(mnist_test, fashm_test, including_labels=True, degree_of_balance=None):
  """Build the MNIST/Fashion-MNIST perturbation ("OOD") dataset.

  Labels follow the Fashion-MNIST half only: class 4 (coat) -> 0, class 3
  (dress) -> 1. A fraction ``1 - degree_of_balance`` of the samples is
  "spurious": the MNIST digit is swapped relative to the training pairing
  (MNIST 1 over F4, MNIST 0 over F3). The rest keep the training pairing
  (MNIST 0 over F4, MNIST 1 over F3). All four source groups are truncated
  to the size of the smallest one.

  Args:
    mnist_test: iterable of (image, label) MNIST pairs.
    fashm_test: iterable of (image, label) Fashion-MNIST pairs.
    including_labels: if False, the dataset holds only the images.
    degree_of_balance: number in [0, 1]. 1 keeps the training pairing for
      every sample; 0 swaps every sample.

  Returns:
    A TensorDataset on DEVICE with the images (concatenated along dim 2) and,
    if ``including_labels``, float labels of shape (N, 1), in shuffled order.

  Raises:
    ValueError: if ``degree_of_balance`` is None or not within [0, 1].
  """
  # Explicit check instead of ``assert``: asserts are stripped under ``python -O``,
  # and the None default used to surface as an opaque TypeError.
  if degree_of_balance is None or not 0 <= degree_of_balance <= 1:
    raise ValueError(f"degree_of_balance must be in [0, 1], got {degree_of_balance!r}")
  spurious_ratio = 1 - degree_of_balance

  # Filter classes 0 and 1 of the given MNIST data (kept as 1-channel images).
  X_m_0, _ = keep_only_lbls(mnist_test, lbls=[0])
  X_m_1, _ = keep_only_lbls(mnist_test, lbls=[1])
  # Filter the dresses (3) and coats (4) of the given Fashion-MNIST data.
  X_f_3, _ = keep_only_lbls(fashm_test, lbls=[3])
  X_f_4, _ = keep_only_lbls(fashm_test, lbls=[4])

  # Shuffle each group.
  X_f_3 = X_f_3[torch.randperm(len(X_f_3))]
  X_f_4 = X_f_4[torch.randperm(len(X_f_4))]
  X_m_0 = X_m_0[torch.randperm(len(X_m_0))]
  X_m_1 = X_m_1[torch.randperm(len(X_m_1))]

  # Training pairs m0f4 | m1f3; truncate all groups to the smallest one.
  min_group = min(len(X_m_0), len(X_f_3), len(X_m_1), len(X_f_4))
  X_f_3 = X_f_3[:min_group]
  X_f_4 = X_f_4[:min_group]
  X_m_0 = X_m_0[:min_group]
  X_m_1 = X_m_1[:min_group]

  sp_len = int(spurious_ratio * min_group)
  norm_len = min_group - sp_len

  # Spurious part: digits swapped (m1 over f4, m0 over f3).
  sp_top = torch.cat((X_m_1[:sp_len], X_m_0[:sp_len]), dim=0)
  sp_bottom = torch.cat((X_f_4[:sp_len], X_f_3[:sp_len]), dim=0)
  # Normal part: training pairing (m0 over f4, m1 over f3).
  norm_top = torch.cat((X_m_0[sp_len:], X_m_1[sp_len:]), dim=0)
  norm_bottom = torch.cat((X_f_4[sp_len:], X_f_3[sp_len:]), dim=0)

  X_sp = torch.cat((sp_top, sp_bottom), dim=2)
  X_norm = torch.cat((norm_top, norm_bottom), dim=2)
  X = torch.cat((X_sp, X_norm), dim=0)
  # f4 --> 0, f3 --> 1 (in both parts)
  Y = torch.cat((torch.zeros((sp_len,)), torch.ones((sp_len,)), torch.zeros((norm_len,)), torch.ones((norm_len,))))

  shuffle = torch.randperm(len(X))
  X = X[shuffle]
  Y = Y[shuffle].float().view(-1, 1)
  if including_labels:
    return torch.utils.data.TensorDataset(X.to(DEVICE), Y.to(DEVICE))
  return torch.utils.data.TensorDataset(X.to(DEVICE))

def plot_samples(dataset, nrow=13, figsize=(10,7)):
  """Display the first ``nrow`` images of ``dataset`` as a single-row grid.

  Args:
    dataset: a TensorDataset (its first tensor is used as the images,
      whatever the number of tensors) or an indexable batch of images.
    nrow: number of images to show.
    figsize: matplotlib figure size.
  """
  # Use a targeted attribute lookup instead of bare ``except:`` clauses.
  # Those also swallowed KeyboardInterrupt, and sent TensorDatasets with
  # three or more tensors down the wrong path.
  tensors = getattr(dataset, 'tensors', None)
  X = tensors[0] if tensors else dataset

  plt.figure(figsize=figsize, dpi=130)
  grid_img = vision_utils.make_grid(X[:nrow].cpu(), nrow=nrow, normalize=True, padding=1)
  plt.imshow(grid_img.permute(1, 2, 0), interpolation='nearest')
  # Hide ticks and tick labels on both axes.
  plt.tick_params(axis='both', which='both', length=0)
  ax = plt.gca()
  ax.xaxis.set_major_formatter(NullFormatter())
  ax.yaxis.set_major_formatter(NullFormatter())
  plt.show()


def keep_only_lbls(dataset, lbls):
  """Select the samples of ``dataset`` whose label is in ``lbls`` and relabel them.

  The label at position i of ``lbls`` is mapped to i (if a label occurs more
  than once, its last position wins).

  Args:
    dataset: iterable of (tensor, label) pairs.
    lbls: sequence of labels to keep.

  Returns:
    (X, Y): X stacks the kept inputs with ``torch.stack``; Y holds the new
    labels as a float tensor of shape (N, 1). ``torch.stack`` raises when
    nothing matches.
  """
  remap = {}
  for new_label, old_label in enumerate(lbls):
    remap[old_label] = new_label

  kept = [(x, remap[y]) for x, y in dataset if y in remap]
  inputs = [x for x, _ in kept]
  targets = [y for _, y in kept]
  return torch.stack(inputs), torch.tensor(targets).float().view(-1, 1)


def format_mnist(imgs, pad=2):
  """Turn a batch of single-channel images into 3-channel, zero-padded images.

  Args:
    imgs: batch indexable as (N, C, H, W) (callers pass an (N, 1, 28, 28)
      tensor). Only channel 0 is used.
    pad: zero pixels added on each side of H and W. The default 2 turns
      28x28 digits into 32x32 images, the size the CIFAR-10 halves have
      when stacked with them in build_mc_dataset.

  Returns:
    A torch tensor of shape (N, 3, H + 2*pad, W + 2*pad) whose three
    channels are identical copies of the padded channel 0.
  """
  # Vectorized over the whole batch (previously one np.pad call per image),
  # which also makes an empty batch return an empty result instead of failing
  # in np.stack.
  arr = np.asarray(imgs)[:, :1]  # keep channel 0 while retaining the channel axis
  arr = np.pad(arr, ((0, 0), (0, 0), (pad, pad), (pad, pad)), constant_values=0)
  return torch.tensor(np.repeat(arr, 3, axis=1))


@torch.no_grad()
def get_acc(model, dl, idx=None):
  """Return the binary accuracy of ``model`` on the (X, y) batches of ``dl``.

  Logits are passed through a sigmoid and thresholded at 0.5.

  Args:
    model: module producing logits.
    dl: iterable of (X, y) batches; y is compared element-wise with the
      (batch, 1) predictions (presumably float 0/1 labels of shape (batch, 1)).
    idx: None for a single-output model (D-BAT), whose output is compared
      with y directly. An int selects column ``idx`` of a multi-head
      (batch, n_heads) output (DivDis).

  Returns:
    Accuracy as a Python float.

  The model's previous train/eval mode is restored afterwards. (It used to be
  forced into train mode unconditionally.)
  """
  was_training = model.training
  model.eval()
  try:
    hits = []
    for X, y in dl:
      probs = torch.sigmoid(model(X))
      if idx is not None:
        # Keep a (batch, 1) column so it compares element-wise with y.
        probs = probs[:, idx][:, None]
      hits.append((probs > 0.5) == y)
  finally:
    model.train(was_training)
  hits = torch.cat(hits)
  return (torch.sum(hits) / len(hits)).item()


@torch.no_grad()
def get_ens_acc(ensemble, dl):
  """Return the accuracy of an ensemble's averaged sigmoid prediction on ``dl``.

  Two ensemble layouts are supported:
    * DivDis: a single module whose output has one column per head, shape
      (batch, n_heads). Head probabilities are averaged per sample.
    * D-BAT: an iterable of models (a list, or an ``nn.ModuleList``), each
      returning (batch, 1) logits. Probabilities are averaged across models.

  The averaged probability is thresholded at 0.5 and compared with the
  labels y of each (X, y) batch (presumably shape (batch, 1)).

  Returns:
    Accuracy as a Python float. Each model's previous train/eval mode is
    restored afterwards.
  """
  if isinstance(ensemble, nn.Module) and not isinstance(ensemble, nn.ModuleList):  # DivDis
    models = [ensemble]

    def predict(X):
      # Average over the head axis (dim=1), not the batch axis. Averaging over
      # dim 0 produced one value per head, which was then broadcast against y
      # and could give "accuracies" above 1.
      return torch.sigmoid(ensemble(X)).mean(dim=1, keepdim=True)
  else:  # D-BAT
    models = list(ensemble)

    def predict(X):
      return torch.stack([torch.sigmoid(model(X)) for model in models]).mean(dim=0)

  was_training = [model.training for model in models]
  for model in models:
    model.eval()
  try:
    hits = torch.cat([(predict(X) > 0.5) == y for X, y in dl])
  finally:
    for model, mode in zip(models, was_training):
      model.train(mode)
  return (torch.sum(hits) / len(hits)).item()


def dl_to_sampler(dl):
  """Wrap an iterable (e.g. a DataLoader) in a zero-argument batch sampler.

  Each call returns the next element. When the current pass is exhausted, a
  fresh iterator is obtained with ``iter(dl)``, so elements cycle forever.
  The first iterator is created immediately. If ``dl`` is empty, calling the
  sampler raises StopIteration.
  """
  current = [iter(dl)]

  def sample():
    try:
      return next(current[0])
    except StopIteration:
      current[0] = iter(dl)
      return next(current[0])

  return sample


def print_stats(stats):
  """Plot per-model training curves on a row of five subplots.

  ``stats`` maps a model id to a dict of metric histories. Only ids whose
  first character is 'm' are plotted. The histories under 'loss',
  'adv-loss', 'acc', 'rm-acc' and 'rc-acc' are sequences of
  (iteration, value) pairs; the x-axis iterations are taken from 'loss'.
  The three accuracy panels share the y-range [0.45, 1.05]. The figure is
  neither shown nor returned.
  """
  panels = (
    ("ERM loss", 'loss'),
    ("Adv Loss", 'adv-loss'),
    ("Acc", 'acc'),
    ("Randomized MNIST Acc", 'rm-acc'),
    ("Randomized CIFAR Acc", 'rc-acc'),
  )
  _, axes = plt.subplots(1, len(panels), figsize=(16, 3), dpi=110)
  for ax, (title, _key) in zip(axes, panels):
    ax.grid()
    ax.set_title(title)
    ax.set_xlabel("iterations")

  for m_id, m_stats in stats.items():
    if m_id[0] != 'm':
      continue
    iterations = [point[0] for point in m_stats['loss']]
    for ax, (_title, key) in zip(axes, panels):
      ax.plot(iterations, [point[1] for point in m_stats[key]], label=m_id)

  for ax in axes[2:]:
    ax.set_ylim(0.45, 1.05)
  
  
class LeNet_MC(nn.Module):
    """LeNet-style CNN for the stacked MNIST/CIFAR-10 images.

    The layer sizes are consistent with a 3x64x32 input: conv1 (padded) keeps
    64x32, average pooling by 2 gives 32x16, conv2 gives 28x12, average
    pooling by 3 gives 9x4, and 56 * 9 * 4 = 2016 is fc1's input size.

    Args:
        num_classes: number of output logits.
        dropout_p: dropout probability applied after conv1, conv2, fc1 and
            fc2. It is stored as ``self.droput_p`` (the attribute name is kept
            unchanged for compatibility).
    """

    def __init__(self, num_classes=10, dropout_p=0.0) -> None:
        super().__init__()
        self.droput_p = dropout_p
        # Creation order is kept so parameter registration and init RNG order are unchanged.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=5, padding=2)
        self.conv2 = nn.Conv2d(32, 56, kernel_size=5)
        self.fc1 = nn.Linear(56 * 9 * 4, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()
        self.avgpool_2 = nn.AvgPool2d(kernel_size=2)
        self.avgpool_3 = nn.AvgPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, dropout=True) -> torch.Tensor:
        """Return the logits for ``x``.

        Dropout is active whenever ``dropout`` is True, independently of
        ``self.training``, so ``eval()`` alone does not disable it. With the
        default ``dropout_p=0.0`` it is a no-op.
        """
        def drop(t):
            return F.dropout(t, p=self.droput_p, training=dropout)

        h = self.avgpool_2(drop(self.relu(self.conv1(x))))
        h = self.avgpool_3(drop(self.relu(self.conv2(h))))
        h = torch.flatten(h, start_dim=1)
        for layer in (self.fc1, self.fc2):
            h = drop(self.relu(layer(h)))
        return self.fc3(h)

class LeNet_MF(nn.Module):
  """LeNet-5-style CNN for the stacked MNIST/Fashion-MNIST images.

  fc1's 960 input features are consistent with a 1x56x28 input: conv1
  (padded) keeps 56x28, max pooling gives 28x14, conv2 gives 24x10, max
  pooling gives 12x5, and 16 * 12 * 5 = 960.
  """

  def __init__(self, num_classes=10) -> None:
    super().__init__()
    # Creation order is kept so parameter registration and init RNG order are unchanged.
    self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2)
    self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
    self.fc1 = nn.Linear(16 * 12 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, num_classes)
    self.maxPool = nn.MaxPool2d(2, 2)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Return the logits for ``x``."""
    for conv in (self.conv1, self.conv2):
      x = self.maxPool(F.relu(conv(x)))
    x = x.flatten(start_dim=1)
    for fc in (self.fc1, self.fc2):
      x = F.relu(fc(x))
    return self.fc3(x)

    
def set_train_mode(models):
  """Switch every model in ``models`` to training mode via ``.train()``."""
  for model in models:
    model.train()


def set_eval_mode(models):
  """Switch every model in ``models`` to evaluation mode via ``.eval()``."""
  for model in models:
    model.eval()

def set_seeds(seed):
    """Seed PyTorch (CPU and CUDA), NumPy and Python's ``random`` RNGs with ``seed``."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)  # silently ignored when CUDA is unavailable
    np.random.seed(seed)
    # The stdlib module is imported as ``rnd`` in this file; the previous
    # ``random.seed`` call raised NameError.
    rnd.seed(seed)