import itertools
import dolphin
import wandb
from datetime import datetime
import os
import time
import random
from typing import *

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from argparse import ArgumentParser
from tqdm import tqdm

from dolphin.distribution import PreImage
from imbalanced_datasets.mipll_datasets import MIPLL_Dataset
from torchql import Database
from purification import purification

from dolphin import Distribution
from dolphin.provenances import get_provenance
from dolphin import func

from imbalanced_datasets.load_imbalanced_arithmetic import load_arithmetic_n

import numpy as np
def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy (in percent) for every k listed in ``topk``.

    Args:
        output: 2-D score tensor; ``topk`` is taken along dimension 1.
        target: 1-D tensor of ground-truth class indices, one per row of ``output``.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per entry of ``topk`` holding the
        percentage of rows whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the `largest_k` best classes per sample, best first.
        top_indices = output.topk(largest_k, 1, True, True)[1]
        # Transpose so that row r holds the rank-r prediction of every sample.
        ranked = top_indices.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        to_percent = 100.0 / n_samples
        return [
            hits[:k].reshape((-1,)).float().sum(0, keepdim=True).mul_(to_percent)
            for k in topk
        ]

class AccurracyShot(object):
    """Per-group accuracy for many-shot, median-shot and low-shot classes.

    A class is "many-shot" when its training count is strictly greater than
    ``many_shot_thr`` and "low-shot" when it is strictly smaller than
    ``low_shot_thr``; every other class is "median-shot". Both thresholds are
    read from the ascending-sorted training counts.
    """

    def __init__(self, train_class_count, num_class, many_shot_num=3, low_shot_num=3):
        """Store the training counts and derive the two group thresholds.

        Args:
            train_class_count: per-class training counts; must provide a
                ``sort()`` whose first result is the ascending-sorted values
                (as ``torch.Tensor.sort`` does).
            num_class: number of classes.
            many_shot_num: offset used to pick the many-shot threshold.
            low_shot_num: index used to pick the low-shot threshold.
        """
        self.train_class_count = train_class_count
        # Filled lazily on the first call to get_shot_acc and reused afterwards.
        self.test_class_count = None
        self.num_class = num_class
        ascending_counts = train_class_count.sort()[0]
        self.many_shot_thr = ascending_counts[num_class - many_shot_num - 1]
        self.low_shot_thr = ascending_counts[low_shot_num]

    def get_shot_acc(self, preds, labels, acc_per_cls=False):
        """Compute mean accuracy (in percent) of each class group.

        Args:
            preds: predicted class indices.
            labels: ground-truth class indices, aligned with ``preds``.
            acc_per_cls: when true, also return the list of per-class accuracies.

        Returns:
            ``(many, median, low)`` or ``(many, median, low, class_accs)``.
            A group without any class contributes ``0``.
        """
        if self.test_class_count is None:
            # Per-class test-set sizes are taken from the first labels seen.
            self.test_class_count = [
                len(labels[labels == cls]) for cls in range(self.num_class)
            ]

        class_correct = []
        for cls in range(self.num_class):
            class_correct.append((preds[labels == cls] == labels[labels == cls]).sum())

        many_shot, median_shot, low_shot = [], [], []
        for cls in range(self.num_class):
            if self.train_class_count[cls] > self.many_shot_thr:
                bucket = many_shot
            elif self.train_class_count[cls] < self.low_shot_thr:
                bucket = low_shot
            else:
                bucket = median_shot
            bucket.append(class_correct[cls] / float(self.test_class_count[cls]))

        # Keep np.mean well-defined for groups that received no class.
        for bucket in (many_shot, median_shot, low_shot):
            if not bucket:
                bucket.append(0)

        if not acc_per_cls:
            return (
                np.mean(many_shot) * 100,
                np.mean(median_shot) * 100,
                np.mean(low_shot) * 100,
            )

        class_accs = [
            hits / total for hits, total in zip(class_correct, self.test_class_count)
        ]
        return (
            np.mean(many_shot) * 100,
            np.mean(median_shot) * 100,
            np.mean(low_shot) * 100,
            class_accs,
        )

class MNISTNet(nn.Module):
  """Two-conv / two-linear classifier returning softmax probabilities over 10 classes.

  The flatten step hard-codes 1024 features, which is what two 5x5 convolutions
  each followed by 2x2 max-pooling yield for a 1x28x28 input (64 * 4 * 4).
  """

  def __init__(self):
    super().__init__()
    # Feature extractor.
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
    # Classifier head.
    self.fc1 = nn.Linear(in_features=1024, out_features=1024)
    self.fc2 = nn.Linear(in_features=1024, out_features=10)

  def forward(self, x):
    """Return per-class probabilities (each row sums to 1) for the image batch ``x``."""
    features = F.max_pool2d(self.conv1(x), kernel_size=2)
    features = F.max_pool2d(self.conv2(features), kernel_size=2)
    flat = features.view(-1, 1024)
    hidden = F.relu(self.fc1(flat))
    # Dropout is only active while the module is in training mode.
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    logits = self.fc2(hidden)
    return F.softmax(logits, dim=1)

class MNISTSumNNet(nn.Module):
  """Classifies each input image with ``MNISTNet`` and folds the digit distributions.

  The per-image ``Distribution`` objects are combined left to right with
  ``self.lam``: ``max`` when ``dataset == "mmax"``, addition otherwise.
  """

  def __init__(self, dataset = "msum"):
    super(MNISTSumNNet, self).__init__()

    # MNIST Digit Recognition Network
    self.mnist_net = MNISTNet()
    self.dataset = dataset
    # Binary operation used to merge two digit distributions.
    if dataset == "mmax":
      self.lam = lambda x, y : max(x, y)
    else:
      self.lam = lambda x, y : x + y

  def forward(self, x: Tuple[torch.Tensor, torch.Tensor], labels=None, pre_image = None):
    """Run the digit classifier on every element of ``x`` and combine the results.

    Args:
      x: sequence of image batches, one per operand. While
        ``PreImage.record_preimage`` is set only ``len(x)`` is used and the
        elements are never passed to the classifier.
      labels: optional per-operand digit labels, indexed as ``labels[:, i]``;
        when given, the number of correctly classified digits is also returned.
      pre_image: optional per-sample collection of candidate digit tuples.
        When given, operand ``i`` of sample ``idx`` is restricted to the digits
        occurring at position ``i`` of ``pre_image[idx]``.

    Returns:
      ``res.preimage`` while a preimage is being recorded; otherwise the
      probabilities over the symbols ``0 .. 9 * len(x)``, plus the
      correct-digit count when ``labels`` is given.
    """
    if labels is not None:
      labels = torch.tensor(labels).to(x[0].device)
      correct = 0

    # Placeholder probabilities must live on the model's device. A hard-coded
    # "mps" device fails on machines without Apple's Metal backend.
    model_device = next(self.mnist_net.parameters()).device

    for i in range(len(x)):
      if PreImage.record_preimage:
        # Random values are enough here: x[i] is never read in this mode.
        probs = torch.rand((1, 10), device=model_device)
        print("Processing preimage for sums up to", i+1, "digits")
      else:
        probs = self.mnist_net(x[i])
        if labels is not None:
          preds = probs.data.max(1, keepdim=True)[1]
          correct += preds.eq(labels[:,i].data.view_as(preds)).sum()

      if pre_image is None:
        a = Distribution(probs, range(10))
      else:
        # One distribution per sample, restricted to the digits that appear at
        # position i of that sample's candidate tuples.
        per_sample = []
        for idx in range(len(pre_image)):
          digits = sorted({ label[i] for label in pre_image[idx] })
          per_sample.append(Distribution(probs[idx, :].view(1, -1), range(10)).map_symbols(digits))
        a = Distribution.stack(per_sample)

      # The first operand seeds the result; later ones are folded in.
      res = a if i == 0 else res.apply(a, self.lam)

    if PreImage.record_preimage:
      return res.preimage

    # Tensor of shape b x (9 * len(x) + 1), e.g. b x 19 for two operands.
    output = res.map_symbols(range(len(x) * 9 + 1)).get_probabilities()
    if labels is None:
      return output
    return output, correct

def bce_loss(output, ground_truth):
  """Binary cross-entropy between predicted probabilities and one-hot targets.

  Args:
    output: probabilities with one column per class (batch x num_classes).
    ground_truth: integer class indices, one per row of ``output``.

  Returns:
    The scalar loss from ``F.binary_cross_entropy`` with its default reduction.
  """
  # Take the class count from the prediction itself instead of the script-level
  # global `M`, which is undefined when this module is imported.
  num_classes = output.shape[-1]
  gt = torch.nn.functional.one_hot(ground_truth, num_classes=num_classes).float()
  return F.binary_cross_entropy(output, gt)

def nll_loss(output, ground_truth):
  """Apply ``F.nll_loss`` to ``output`` with ``ground_truth`` as the target indices.

  ``output`` is passed through unchanged; no logarithm is applied here.
  """
  loss = F.nll_loss(input=output, target=ground_truth)
  return loss

class Trainer():
  """Trains and evaluates an ``MNISTSumNNet``, optionally with preimages and purification.

  Construction installs the requested provenance on the global ``Distribution``
  class, builds the network and an Adam optimizer, and selects the loss
  function. Accuracy is measured on the inner digit classifier
  (``self.network.mnist_net``) via ``test``.
  """

  def __init__(self, train_loader, test_loader, model_dir, learning_rate, loss, provenance, device, k, M, acc_shot, with_purification=False, with_preimage=False, synthetic=None, dataset="msum"):
    """Set up the model, optimizer, loss and bookkeeping.

    Args:
      train_loader: loader yielding training batches (see ``train_epoch``).
      test_loader: loader yielding ``(images, labels)`` batches for ``test``.
      model_dir: stored on the instance; not otherwise used in this class.
      learning_rate: learning rate for Adam.
      loss: ``"nll"`` or ``"bce"``.
      provenance: provenance name passed to ``get_provenance``.
      device: torch device the network and batches are moved to.
      k: value assigned to ``Distribution.provenance.k``.
      M: number of input images per training sample.
      acc_shot: object providing ``get_shot_acc`` (see ``test``).
      with_purification: rebuild the training set with ``purification`` before
        each post-warm-up epoch.
      with_preimage: accepted for interface compatibility; unused here.
      synthetic: forwarded to ``purification`` as ``synthetic``.
      dataset: ``"msum"`` or ``"mmax"``.

    Raises:
      Exception: if ``loss`` is not one of the known names.
    """
    self.with_purification = with_purification
    self.device = device
    self.model_dir = model_dir
    # The provenance is class-level state shared by every Distribution.
    Distribution.provenance = get_provenance(provenance)
    Distribution.provenance.k = k
    self.db = Database()
    self.network = MNISTSumNNet(dataset=dataset).to(self.device)
    self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
    self.train_loader = train_loader
    self.test_loader = test_loader
    self.best_loss = 10000000000
    self.provenance = provenance
    self.M = M
    self.best_acc = 0
    self.acc_shot = acc_shot
    self.max_acc = 0
    self.max_acc5 = 0
    self.synthetic = synthetic
    self.dataset = dataset

    # Store chosen loss function
    if loss == "nll":
      self.loss = nll_loss
    elif loss == "bce":
      self.loss = bce_loss
    else:
      raise Exception(f"Unknown loss function `{loss}`")

    # Wall-clock duration of every training epoch, used for the average in train().
    self.epoch_times = []

  @staticmethod
  def _snapshot_weights(module):
    """Return an independent copy of ``module.state_dict()``.

    ``state_dict()`` hands back tensors that share storage with the live
    parameters, so keeping it without copying does not preserve the weights
    once training continues.
    """
    import copy  # local import: only needed for checkpointing
    return copy.deepcopy(module.state_dict())

  def get_preimage(self):
    """Build the mapping from each output value to the digit tuples producing it.

    For ``"msum"`` the mapping is recorded by running the network inside a
    ``PreImage`` context. For ``"mmax"`` it is enumerated directly over all
    ``10 ** M`` digit tuples. Any other dataset name yields ``None``.
    """
    if self.dataset == "msum":
      with PreImage() as record_preimage:
        print("Record Preimage:", record_preimage.record_preimage)
        # Only the number of operands matters while recording.
        preimage = self.network(list(range(self.M)))
        for key, value in preimage.items():
          preimage[key] = [ list(v.values()) for v in value ]
        return preimage
    elif self.dataset == "mmax":
      # Group every digit combination by its maximum.
      preimage = {}
      for comb in itertools.product(range(10), repeat=self.M):
        preimage.setdefault(max(comb), []).append(list(comb))
      return preimage

  def train_epoch(self, epoch, dataloader=None):
    """Run one optimisation pass over ``dataloader`` (default: ``self.train_loader``).

    Batches are either ``(data, target)`` or
    ``(data, target, labels, pre_image)``; only in the latter case is a
    preimage forwarded to the network. The epoch duration is logged to wandb,
    printed and appended to ``self.epoch_times``.
    """
    if dataloader is None:
      dataloader = self.train_loader
    self.network.train()
    progress = tqdm(dataloader, total=len(dataloader))
    t_begin_epoch = time.time()
    for batch in progress:
      if len(batch) == 4:
        data, target, _labels, pre_image = batch
      else:
        data, target = batch
        pre_image = None
      imgs = tuple(data[x].to(self.device) for x in range(self.M))
      target = target.to(self.device)
      self.optimizer.zero_grad()

      output = self.network(imgs, pre_image=pre_image)
      loss = self.loss(output, target)
      loss.backward()
      self.optimizer.step()
      progress.set_description(f"[Train Epoch {epoch}] Loss: {loss.item():.4f}")

    total_epoch_time = time.time() - t_begin_epoch
    wandb.log(
      {
        "epoch": epoch,
        "total_epoch_time": total_epoch_time,
      }
    )
    print(f"Total Epoch Time: {total_epoch_time}")
    print("Max memory allocated:", torch.cuda.max_memory_allocated() / 1024 / 1024)

    # Remember this epoch's duration for the average reported by train().
    self.epoch_times.append(total_epoch_time)

  def test_epoch(self, epoch):
    """Evaluate the full network on ``self.test_loader`` and track best loss/accuracy.

    Accuracy is a running value: correct predictions so far divided by the
    size of the whole test set. Whenever it exceeds 97% a line is appended to
    a provenance-specific log file.
    """
    self.network.eval()
    num_items = len(self.test_loader.dataset)
    test_loss = 0
    correct = 0
    perc = 0.0  # defined up front so an empty loader cannot leave it unbound
    with torch.no_grad():
      progress = tqdm(self.test_loader, total=len(self.test_loader))
      for (data, target) in progress:
        # Use the instance's operand count, not a module-level global.
        imgs = tuple(data[x].to(self.device) for x in range(self.M))
        target = target.to(self.device)

        output = self.network(imgs)
        test_loss += self.loss(output, target).item()
        pred = output.data.max(1, keepdim=True)[1]
        correct += pred.eq(target.data.view_as(pred)).sum()
        perc = 100. * correct / num_items
        if perc > 97.00:
          # record M + epoch number combination when accuracy is high
          file_path = f'torch_mnist_sum_n_{self.provenance}_epoch_count.log'
          formatted_timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
          # Append mode creates the file when it does not exist yet.
          with open(file_path, 'a') as file:
            file.write(f'sum n={self.M}, epoch num={epoch}, {formatted_timestamp}\n')
        progress.set_description(f"[Test Epoch {epoch}] Total loss: {test_loss:.4f}, Accuracy: {correct}/{num_items} ({perc:.2f}%)")
        wandb.log(
          {
            "epoch": epoch,
            "accuracy": perc,
          }
        )
      if test_loss < self.best_loss:
        self.best_loss = test_loss
      if perc > self.best_acc:
        self.best_acc = perc
      print(f"Best loss: {self.best_loss:.4f}")
      print(f"Best acc: {self.best_acc:.2f}%")

  def test(self,acc_shot, network, test_loader, verbose=True):
    """Evaluate ``network`` on ``test_loader`` and update ``max_acc`` / ``max_acc5``.

    Args:
      acc_shot: object whose ``get_shot_acc(preds, labels, acc_per_cls=True)``
        returns many/median/few-shot accuracies and per-class accuracies.
      network: module applied to each image batch.
      test_loader: iterable of ``(images, labels)`` batches.
      verbose: print the summary lines when true.

    Returns:
      ``(top1, many_shot, median_shot, few_shot)`` accuracies as floats.
    """
    network.eval()
    with torch.no_grad():
      if verbose:
        print("==> Evaluation...")
      pred_list = []
      true_list = []
      for images, labels in test_loader:
        images = images.to(self.device)
        outputs = network(images)
        # Softmax is monotonic per row, so the class ranking used below is
        # the same as that of `outputs`.
        pred = F.softmax(outputs, dim=1)
        pred_list.append(pred.cpu())
        true_list.append(labels)

      pred_list = torch.cat(pred_list, dim=0)
      true_list = torch.cat(true_list, dim=0)

      acc1, acc5 = accuracy(pred_list, true_list, topk=(1, 5))
      acc_many, acc_med, acc_few, class_accs = acc_shot.get_shot_acc(
        pred_list.max(dim=1)[1], true_list, acc_per_cls=True
      )
      if verbose:
        print(
          "==> Test Accuracy is %.2f%% (%.2f%%), [%.2f%%, %.2f%%, %.2f%%]"
          % (acc1, acc5, acc_many, acc_med, acc_few)
        )
        print("==> Class-specific accuracies are", class_accs)
      if acc1 > self.max_acc:
        self.max_acc = acc1
        self.max_acc5 = acc5

    return float(acc1), float(acc_many), float(acc_med), float(acc_few)

  def train(self, n_epochs):
    """Train for ``n_epochs`` epochs, evaluating the digit classifier before and after each."""
    self.test(self.acc_shot,  self.network.mnist_net, self.test_loader)
    for epoch in range(1, n_epochs + 1):
      self.train_epoch(epoch)
      self.test(self.acc_shot,  self.network.mnist_net, self.test_loader)

    # Report the average epoch time once all epochs are complete.
    if len(self.epoch_times) > 0:
      average_epoch_time = sum(self.epoch_times) / len(self.epoch_times)
      print(f"Average Epoch Time: {average_epoch_time}")

  def train_w_purification(self, n_epochs, warmup_epochs=1):
    """Warm up on the preimage dataset, then continue training with optional purification.

    Both phases stop once top-1 accuracy has not improved for more than three
    consecutive epochs. After warm-up the weights of the best warm-up epoch
    are restored.

    Args:
      n_epochs: purification is skipped in the epoch whose number equals
        ``n_epochs``; it does not bound the number of epochs.
      warmup_epochs: accepted for interface compatibility; unused.
    """
    # getting the preimage
    preimage = self.get_preimage()

    # instantiate the preimage dataset
    train_dataset = self.train_loader.dataset
    train_w_preimage = MIPLL_Dataset(train_dataset.x, train_dataset.y,
                                     train_dataset.s, train_dataset.l,
                                     train_dataset.weak_transform,
                                     preimage=preimage)

    train_w_preimage_loader = torch.utils.data.DataLoader(
      train_w_preimage,
      collate_fn=MIPLL_Dataset.collate_fn,
      batch_size=self.train_loader.batch_size,
      shuffle=True
    )

    self.test(self.acc_shot,  self.network.mnist_net, self.test_loader)

    eps = 0.999
    eps_max = 1
    eps_step = 0.01

    wandb.log(
      {
        "eps": eps,
        "eps_max": eps_max,
        "eps_step": eps_step,
      }
    )

    # ---- Warm-up phase -------------------------------------------------
    warmup_epoch = 1
    converged_epoch = 0
    best_acc = 0
    best_model = None
    best_acc_class = 0
    while True:
      print(f"Warmup Epoch {warmup_epoch}")
      self.train_epoch(warmup_epoch, dataloader=train_w_preimage_loader)
      self.test(self.acc_shot,  self.network.mnist_net, self.test_loader)
      if self.max_acc > best_acc:
        best_acc = self.max_acc
        best_acc_class = self.max_acc5
        converged_epoch = warmup_epoch
        # Copy the weights: the raw state dict aliases the live parameters and
        # would silently follow later optimizer steps.
        best_model = self._snapshot_weights(self.network)
      if warmup_epoch - converged_epoch > 3:
        break
      warmup_epoch += 1
    print("MODEL CONVERGED AT EPOCH", warmup_epoch, "WITH ACCURACY", best_acc, "AND ACCURACY CLASS", best_acc_class)
    # No snapshot exists if accuracy never rose above zero; keep current weights then.
    if best_model is not None:
      self.network.load_state_dict(best_model)

    # ---- Main phase ----------------------------------------------------
    epoch = 1
    max_acc_epoch = 0
    max_acc = 0
    while True:
      # Purification
      if self.with_purification and epoch != n_epochs:
        train_w_preimage_purified = purification(train_w_preimage, eps, eps_max, MIPLL_Dataset.collate_fn, self.train_loader.batch_size, self.M, self.network.mnist_net, self.device, synthetic=self.synthetic)
        train_w_preimage_loader = torch.utils.data.DataLoader(
          train_w_preimage_purified,
          collate_fn=MIPLL_Dataset.collate_fn,
          batch_size=self.train_loader.batch_size,
          shuffle=True
        )
      self.train_epoch(epoch, dataloader=train_w_preimage_loader)
      self.test(self.acc_shot,  self.network.mnist_net, self.test_loader)
      if self.max_acc > max_acc:
        max_acc = self.max_acc
        max_acc_epoch = epoch
      if epoch - max_acc_epoch > 3:
        break
      epoch += 1

    print("Converged at epoch", converged_epoch, "with accuracy", f"{float(best_acc):.2f} ({float(best_acc_class):.2f})")
    print("After purification, the model achieved", self.max_acc, "accuracy")

if __name__ == "__main__":
  # Script entry point: parse arguments, build the data loaders and the
  # trainer, then run the selected training mode.

  def _parse_bool(text):
    """Convert a command-line token such as "True" or "false" to a bool.

    argparse hands option values over as strings, and any non-empty string
    (including "False") is truthy, so the value has to be parsed explicitly.
    Raising ValueError makes argparse report an invalid value.
    """
    lowered = str(text).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
      return True
    if lowered in ("0", "false", "f", "no", "n"):
      return False
    raise ValueError(f"expected a boolean value, got {text!r}")

  # Argument parser
  parser = ArgumentParser()
  parser.add_argument("--n-epochs", type=int, default=15)
  parser.add_argument("--batch-size-train", type=int, default=64)
  parser.add_argument("--batch-size-test", type=int, default=64)
  parser.add_argument("--learning-rate", type=float, default=0.001)
  parser.add_argument("--loss-fn", type=str, default="bce")
  parser.add_argument("--seed", type=int, default=3576)
  parser.add_argument("--provenance", type=str, default="damp", choices=['damp', 'dmmp', 'dtkp-am'])
  parser.add_argument("--device", type=str, default="mps", choices=["cpu", "cuda", "mps"])
  parser.add_argument("--topk", type=int, default=3)
  parser.add_argument("--size", default=1000, type=int, help="number of partial training samples")
  parser.add_argument("--M", default=5, type=int, help="number of input instances per training sample")
  # number of inputs per instance

  parser.add_argument("--dataset", default="msum",type=str, choices=["mmax", "msum"])
  # try both msum and mmax

  parser.add_argument("--num-class", default=10, type=int, help="number of class")
  parser.add_argument("--imb_type", default="exp", choices=["exp", "expr", "step", "original"], help="imbalance data type")
  # try both exp and expr

  parser.add_argument("--imb_ratio", default=15, type=float, help="imbalance ratio for long-tailed dataset generation")
  # try 5, 15, 50

  # Accepts `--imb_test`, `--imb_test True` and `--imb_test False`; the value
  # is parsed so that "False" really disables the option.
  parser.add_argument("--imb_test", nargs="?", const=True, default=False, type=_parse_bool, help="use imbalanced test set")
  # try both True and False

  parser.add_argument("--data-dir", default="../../data/pre-processed-data", type=str, help="experiment directory for loading pre-generated data")

  parser.add_argument("--only_preimage", action="store_true")

  parser.add_argument("--with_preimage", action="store_true")

  parser.add_argument("--with_purification", action="store_true")

  parser.add_argument("--num-synthetic-proofs", type=int, default=None)


  args = parser.parse_args()
  args.imb_factor = 1.0 / args.imb_ratio
  print(args)

  # Parameters
  # NOTE: `M` stays a module-level name because other code in this file may read it.
  M = args.M
  n_epochs = args.n_epochs
  batch_size_train = args.batch_size_train
  batch_size_test = args.batch_size_test
  learning_rate = args.learning_rate
  loss_fn = args.loss_fn
  provenance = args.provenance
  torch.manual_seed(args.seed)
  random.seed(args.seed)

  #TODO Update config object
  config = {
    "M": M,
    "n_epochs": n_epochs,
    "batch_size_train": batch_size_train,
    "batch_size_test": batch_size_test,
    "provenance": provenance,
    "seed": args.seed,
    "experiment_type": "dolphin",
  }

  # Data
  data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data"))
  model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../model/mnist_sum_{M}'))
  os.makedirs(model_dir, exist_ok=True)

  # Dataloaders
  train_loader, test_loader, train_label_cnt = load_arithmetic_n(args)

  best_acc = 0
  many_shot_num = 3
  low_shot_num = 3

  acc_shot = AccurracyShot(
      train_label_cnt, args.num_class, many_shot_num, low_shot_num
  )

  # Fall back to the CPU when the requested accelerator is not available.
  if args.device == "cuda" and torch.cuda.is_available():
    device_name = "cuda"
  elif args.device == "mps" and torch.backends.mps.is_available():
    device_name = "mps"
  else:
    device_name = "cpu"

  device = torch.device(device_name)
  timestamp = datetime.now()
  # Named `run_id` so the builtin `id` is not shadowed.
  run_id = f'dolphin_sum{M}_{args.seed}_{provenance}_{timestamp.strftime("%Y-%m-%d %H-%M-%S")}'

  wandb.init(
    project="WIP", config=config, id=run_id, mode="disabled"
  )
  # Create trainer and train
  trainer = Trainer(train_loader, test_loader, model_dir, learning_rate, loss_fn, provenance, device, args.topk, M, acc_shot, with_purification=args.with_purification, with_preimage=args.with_preimage, synthetic=args.num_synthetic_proofs, dataset=args.dataset)
  if not args.only_preimage:
    if args.with_purification or args.with_preimage:
      trainer.train_w_purification(n_epochs, warmup_epochs=-1)
    else:
      trainer.train(n_epochs)
  else:
    trainer.get_preimage()

  print("Best Accuracy:", f"{float(trainer.max_acc):.2f} ({float(trainer.max_acc5):.2f})")
