import itertools
import os
import json
import random
from argparse import ArgumentParser
from datetime import datetime
import time
from torchvision import transforms
from tqdm import tqdm
import math

import torch
from torch import nn, optim
import torch.nn.functional as F
import torchvision
from torchvision import models
from PIL import Image
from mipll_pruning_algorithm import structural_pruning

import scallopy
import math
import wandb 
import sys
import logging
import traceback

def exception_handler(exc_type, exc_value, exc_traceback):
    """Report an uncaught exception via logging and stderr.

    Installed as ``sys.excepthook`` below. The message lists the exception
    type name, its value, and the formatted traceback frames. It does not
    delegate to ``sys.__excepthook__``.

    NOTE(review): if logging has no handlers configured, Python's last-resort
    handler also writes ERROR records to stderr, so the message may appear
    twice on stderr.
    """
    pieces = [
        f"An uncaught {exc_type.__name__} exception occurred:\n",
        f"{exc_value}\n",
        "Traceback:\n",
        *traceback.format_tb(exc_traceback),
    ]
    report = "".join(pieces)

    logging.error(report)
    print(report, file=sys.stderr)

sys.excepthook = exception_handler

class HWFDataset(torch.utils.data.Dataset):
  """HWF (handwritten formula) dataset.

  Each item is ``(img_seq, img_seq_len, res, gt)``:
    - ``img_seq``: list of transformed grayscale ("L") symbol images, one per
      entry of the sample's "img_paths";
    - ``img_seq_len``: number of images in the sequence;
    - ``res``: the sample's "res" value from the metadata JSON;
    - ``gt``: the sample's "expr" converted to a tuple (one element per character
      when "expr" is a string).

  Args:
    root: data directory containing ``HWF/hwf_{max_length}_{split}.json`` and the
      image folder ``HWF/Handwritten_Math_Symbols``.
    prefix: unused; kept for call compatibility.
    split: split name used in the metadata filename (e.g. "train", "test").
    max_length: used in the metadata filename.
    exclude_lengths: samples whose number of images is in this collection are
      dropped. ``None`` (default) excludes nothing.
    num_train_samples: if given and ``split == "train"``, keep a random subset of
      this many samples (``random.sample`` raises ValueError if it is larger than
      the filtered dataset).
  """
  def __init__(self, root: str, prefix: str, split: str, max_length: int, exclude_lengths: list = None, num_train_samples: int = None):
    super(HWFDataset, self).__init__()
    self.root = root
    self.split = split
    # Fresh list per instance: avoids a shared mutable default argument.
    self.exclude_lengths = list(exclude_lengths) if exclude_lengths is not None else []

    metadata_path = os.path.join(root, f"HWF/hwf_{max_length}_{split}.json")
    with open(metadata_path) as metadata_file:
      self.metadata_og = json.load(metadata_file)
    self.metadata = [
      sample for sample in self.metadata_og
      if len(sample["img_paths"]) not in self.exclude_lengths
    ]

    if num_train_samples is not None and split == "train":
      self.metadata = random.sample(self.metadata, num_train_samples)

    self.img_transform = torchvision.transforms.Compose([
      torchvision.transforms.ToTensor(),
      torchvision.transforms.Normalize((0.5,), (1,))
    ])

  def __getitem__(self, index):
    sample = self.metadata[index]

    # Input is a sequence of images
    img_seq = []
    for img_path in sample["img_paths"]:
      img_full_path = os.path.join(self.root, "HWF/Handwritten_Math_Symbols", img_path)
      # Context manager releases the underlying file handle deterministically.
      with Image.open(img_full_path) as raw_img:
        img = raw_img.convert("L")
      img_seq.append(self.img_transform(img))
    img_seq_len = len(img_seq)

    # Output is the "res" in the sample of metadata
    res = sample["res"]

    # GT is the ground truth label for each image
    gt = tuple(sample["expr"])

    # Return (input, output) pair
    return (img_seq, img_seq_len, res, gt)

  def __len__(self):
    return len(self.metadata)

  @staticmethod
  def collate_fn(batch):
    """Collate ``(img_seq, img_seq_len, res, gt)`` items into a batch.

    Shorter image sequences are right-padded with zero images (shaped like the
    first image of the first sample) up to the longest sequence in the batch.

    Returns:
      (img_seqs, img_seq_len, results, gts) where ``img_seqs`` stacks to
      (batch, max_len, *image_shape), ``img_seq_len`` is a long tensor of lengths,
      ``results`` stacks ``torch.tensor(res)`` per sample, and ``gts`` is a list.
    """
    max_len = max([img_seq_len for (_, img_seq_len, _, _) in batch])
    zero_img = torch.zeros_like(batch[0][0][0])
    pad_zero = lambda img_seq: img_seq + [zero_img] * (max_len - len(img_seq))
    img_seqs = torch.stack([torch.stack(pad_zero(img_seq)) for (img_seq, _, _, _) in batch])
    img_seq_len = torch.stack([torch.tensor(img_seq_len).long() for (_, img_seq_len, _, _) in batch])
    results = torch.stack([torch.tensor(res) for (_, _, res, _) in batch])
    gts = [gt for (_, _, _, gt) in batch]
    return (img_seqs, img_seq_len, results, gts)


def hwf_loader(data_dir, batch_size, prefix, formula_length, num_train_samples=None):
  """Build the shuffled (train_loader, test_loader) pair for HWF.

  The training split drops formulas whose image count is 1, 3 or 5 and may be
  subsampled to ``num_train_samples``; the test split keeps every length. Both
  loaders use ``HWFDataset.collate_fn`` for zero-padding.
  """
  shared_opts = dict(collate_fn=HWFDataset.collate_fn, batch_size=batch_size, shuffle=True)

  train_set = HWFDataset(data_dir, prefix, "train", formula_length, exclude_lengths=[1,3,5], num_train_samples=num_train_samples)
  train_loader = torch.utils.data.DataLoader(train_set, **shared_opts)

  test_set = HWFDataset(data_dir, prefix, "test", formula_length)
  test_loader = torch.utils.data.DataLoader(test_set, **shared_opts)

  return train_loader, test_loader


class SymbolNet(nn.Module):
  """Small CNN that maps a 1-channel image to a probability vector over 14 classes.

  Both convolutions keep the spatial size (3x3, stride 1, padding 1). A single 2x2
  max-pool follows, and fc1 expects 64 * 22 * 22 = 30976 flattened features, so
  inputs must be 44x44 or 45x45 pixels. The output is softmax-normalized
  (not logits).
  """
  def __init__(self):
    super(SymbolNet, self).__init__()
    # Creation order is kept (conv1, conv2, fc1, fc2) so parameter order and
    # random initialization are unchanged.
    self.conv1 = nn.Conv2d(1, 32, 3, stride=1, padding=1)
    self.conv2 = nn.Conv2d(32, 64, 3, stride=1, padding=1)
    self.fc1 = nn.Linear(30976, 128)
    self.fc2 = nn.Linear(128, 14)

  def forward(self, x):
    # Feature extractor. Note: no activation between conv2 and the max-pool.
    feats = F.relu(self.conv1(x))
    feats = F.max_pool2d(self.conv2(feats), 2)
    feats = F.dropout(feats, p=0.25, training=self.training)

    # Classifier head.
    hidden = F.relu(self.fc1(feats.flatten(1)))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    return self.fc2(hidden).softmax(dim=1)


class HWFNet(nn.Module):
  """Neurosymbolic HWF model: a SymbolNet classifies each image position and a
  Scallop program (loaded from ``scl/<scallop_file>``) turns the per-position
  symbol distributions into a distribution over ``result`` values.

  Args:
    no_sample_k: if truthy, pass the full per-position distribution to Scallop;
      otherwise sample ``sample_k`` candidate symbols per position.
    sample_k: number of samples drawn per position in sampling mode.
    provenance, k: forwarded to ``scallopy.ScallopContext``.
    max_length: number of positions declared in the ``symbol`` input mapping.
    debug: build the forward function with ``dispatch="single"`` and
      ``debug_provenance=True``.
    use_hash, jit, recompile: Scallop options. When ``None`` (default) they are
      taken from the module-level CLI namespace ``args`` (``not
      args.do_not_use_hash``, ``args.jit``, ``args.recompile``) exactly as
      before; if no ``args`` global exists they default to True/False/False
      instead of raising NameError.
  """
  def __init__(self, no_sample_k, sample_k, provenance, k, max_length, debug=False, use_hash=None, jit=None, recompile=None):
    super(HWFNet, self).__init__()
    self.no_sample_k = no_sample_k
    self.sample_k = sample_k
    self.provenance = provenance
    self.debug = debug
    self.max_length = max_length

    # Symbol embedding
    self.symbol_cnn = SymbolNet()

    # Resolve Scallop options: explicit arguments win, otherwise fall back to the
    # script-level `args` (only defined when this file runs as __main__).
    cli_args = globals().get("args")
    if use_hash is None:
      use_hash = not getattr(cli_args, "do_not_use_hash", False)
    if jit is None:
      jit = getattr(cli_args, "jit", False)
    if recompile is None:
      recompile = getattr(cli_args, "recompile", False)

    # Scallop context
    self.scallop_file = "hwf_eval.scl" if use_hash else "hwf_parser_wo_hash.scl"
    self.symbols = [str(i) for i in range(10)] + ["+", "-", "*", "/"]
    self.symbol_idx_map = { s : i for i, s in enumerate(self.symbols) }
    self.ctx = scallopy.ScallopContext(provenance=provenance, k=k)
    self.ctx.import_file(os.path.abspath(os.path.join(os.path.abspath(__file__), f"../scl/{self.scallop_file}")))
    self.ctx.set_non_probabilistic("length")
    # Position-major mapping: (0, "0"), (0, "1"), ..., (max_length - 1, "/").
    self.ctx.set_input_mapping("symbol", [(i, s) for i in range(max_length) for s in self.symbols])
    if self.debug:
      self.eval_formula = self.ctx.forward_function("result", dispatch="single", debug_provenance=True)
    else:
      self.eval_formula = self.ctx.forward_function("result", jit=jit, recompile=recompile)

  def forward(self, img_seq, img_seq_len, preimages=None, labels=None):
    """Evaluate a padded batch of image sequences.

    Args:
      img_seq: 5-D tensor unpacked as (batch, formula_length, C, H, W).
      img_seq_len: per-sample true lengths (tensor of ints).
      preimages: optional per-sample lists of candidate symbol tuples; only used
        in sampling mode to restrict the symbols each position may take.
      labels: optional per-sample ground-truth symbol sequences. When given, the
        number of positions where the CNN argmax equals the ground truth is
        appended to the return value.

    Returns:
      ``(output_values, probs)``, or ``(output_values, probs, cnn_correct)`` when
      ``labels`` is given.
    """
    batch_size, formula_length, _, _, _ = img_seq.shape
    length = [[(l.item(),)] for l in img_seq_len]

    # Per-symbol CNN accuracy. Ground-truth symbols are single characters from
    # self.symbols, whose class index is exactly self.symbol_idx_map.
    cnn_correct = 0
    if labels is not None:
      inp = img_seq.flatten(start_dim=0, end_dim=1)
      symbol_probs = self.symbol_cnn(inp).view(batch_size, -1, len(self.symbols))
      predictions = torch.argmax(symbol_probs, dim=2).tolist()
      for i in range(batch_size):
        seq_len = img_seq_len[i].item()
        cnn_correct += sum(
          1 for j in range(seq_len)
          if predictions[i][j] == self.symbol_idx_map[labels[i][j]]
        )

    if self.no_sample_k:
      result = self._forward_with_no_sampling(batch_size, img_seq, length, preimages)
    else:
      result = self._forward_with_sampling(batch_size, formula_length, img_seq, img_seq_len, length, preimages)

    if labels is not None:
      return result + (cnn_correct,)
    return result

  def _forward_with_no_sampling(self, batch_size, img_seq, length, preimages):
    """Feed the full per-position symbol distribution to Scallop (``preimages`` unused)."""
    # Flattened to (batch, formula_length * num_symbols), position-major like the
    # `symbol` input mapping.
    # NOTE(review): the mapping declares max_length positions while this tensor
    # has formula_length (the batch's padded length) positions; confirm scallopy
    # handles batches whose longest formula is shorter than max_length.
    symbol = self.symbol_cnn(img_seq.flatten(start_dim=0, end_dim=1)).view(batch_size, -1)
    (mapping, probs) = self.eval_formula(symbol=symbol, length=length)
    return ([v for (v,) in mapping], probs)

  def _forward_with_sampling(self, batch_size, formula_length, img_seq, img_seq_len, length, preimages):
    """Sample up to ``sample_k`` distinct symbols per position and feed them to Scallop.

    The sampled facts of each position form one disjunction (mutually exclusive
    alternatives). With ``preimages``, each position's distribution is first
    restricted to symbols occurring at that position in the sample's preimages;
    Categorical renormalizes the restricted probabilities for sampling, while the
    facts carry the unnormalized restricted probabilities.
    """
    symbol = self.symbol_cnn(img_seq.flatten(start_dim=0, end_dim=1)).view(batch_size, formula_length, -1)
    symbol_facts = [[] for _ in range(batch_size)]
    disjunctions = [[] for _ in range(batch_size)]
    for task_id in range(batch_size):
      for symbol_id in range(img_seq_len[task_id]):
        possible_symbol_idx = list(range(len(self.symbols)))
        if preimages is not None:
          possible_symbol_idx = list({
            self.symbol_idx_map[lbl_vector[symbol_id]] for lbl_vector in preimages[task_id]
          })

        symbols_distr = symbol[task_id, symbol_id]                      # Predicted distribution
        symbols_distr = symbols_distr[possible_symbol_idx]              # Restrict to allowed symbols

        categ = torch.distributions.Categorical(symbols_distr)          # Categorical over allowed symbols
        sample_ids = [k.item() for k in categ.sample((self.sample_k,))] # Draw sample_k indices
        sample_ids = list(dict.fromkeys(sample_ids))                    # Deduplicate, keep first-seen order
        possible_symbols = [self.symbols[k] for k in possible_symbol_idx]

        # Create facts: (probability, (position, symbol))
        curr_symbol_facts = [(symbols_distr[k], (symbol_id, possible_symbols[k])) for k in sample_ids]

        # This position's facts are mutually exclusive: record their indices as one disjunction.
        disjunctions[task_id].append([len(symbol_facts[task_id]) + i for i in range(len(curr_symbol_facts))])
        symbol_facts[task_id] += curr_symbol_facts
    (mapping, probs) = self.eval_formula(symbol=symbol_facts, length=length, disjunctions={"symbol": disjunctions})
    return ([v for (v,) in mapping], probs)


class Trainer():
  """Train and evaluate an HWFNet.

  Loss is binary cross-entropy between the model's output probabilities and a 0/1
  target marking which Scallop output values match the label (within
  ``eval_result_eq``'s 0.01 tolerance). The whole network is saved with
  ``torch.save`` whenever the summed test loss improves.

  With ``use_preimages``, every alternating digit/operator expression of odd length
  1..max_length is enumerated and grouped by its formatted value. During training,
  each sample's candidate expressions are sampled, the ground truth is added if it
  is missing, and the list is pruned with ``structural_pruning`` using features from
  a frozen ImageNet ResNet-18.
  """
  def __init__(self, train_loader, test_loader, device, model_root, model_name, learning_rate, no_sample_k, sample_k, provenance, k, max_length, use_preimages=False):
    self.network = HWFNet(no_sample_k, sample_k, provenance, k, max_length).to(device)
    self.device = device
    if use_preimages:
      self.preimages = {}
      for l in range(1, max_length+1, 2):
        t_begin = time.time()
        print(f"Preimage for length {l}")
        self.build_preimage(l)
        print(f"Time taken: {time.time() - t_begin}")
      # 1-channel batch -> 3 channels (expand), resize, ImageNet normalization.
      self.resnet_transform = transforms.Compose([
        transforms.Lambda(lambda x: x.expand(x.shape[0],3,*x.shape[2:]) ),
        transforms.Resize(224),  # smaller edge resized to 224
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet normalization
                            std=[0.229, 0.224, 0.225])
      ])
      model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
      for p in model.parameters():
        p.requires_grad = False
      # Drop the last child module (the classification head) to get features.
      self.feature_extractor = torch.nn.Sequential(*(list(model.children())[:-1]))
      self.feature_extractor = self.feature_extractor.to(self.device)
      self.feature_extractor.eval()
    else:
      self.preimages = None
    self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
    self.train_loader = train_loader
    self.test_loader = test_loader
    self.loss_fn = F.binary_cross_entropy
    self.model_root = model_root
    self.model_name = model_name
    self.min_test_loss = 100000000.0
    self.best_accuracy = 0.0
    self.best_cnn_accuracy = 0.0

  def eval_result_eq(self, a, b, threshold=0.01):
    """Return True when |a - b| < threshold."""
    result = abs(a - b) < threshold
    return result

  def build_preimage(self, length):
    """Cache, for ``length``, a map from formatted result to every expression tuple producing it.

    Expressions alternate digit/operator starting with a digit; ones that divide
    by zero are skipped.
    """
    if length in self.preimages:
      return
    digits = [str(i) for i in range(10)]
    ops = ["+", "-", "*", "/"]
    possibilities = [digits if i % 2 == 0 else ops for i in range(length)]
    preimages = {}
    for comb in itertools.product(*possibilities):
      # eval is applied only to strings generated here from fixed digit/operator
      # sets, never to external input.
      try:
        result = eval("".join(comb))
      except ZeroDivisionError:
        continue
      preimages.setdefault(self.format_result_string(result), []).append(comb)
    self.preimages[length] = preimages

  def get_preimage(self, length, result):
    """Return the cached list of expression tuples of ``length`` evaluating to ``result``.

    The returned list is the cached object itself; callers must not mutate it.
    Raises KeyError (after printing the known results) if ``result`` is unknown.
    """
    if length not in self.preimages:
      self.build_preimage(length)
    formatted_result = self.format_result_string(result)
    if formatted_result not in self.preimages[length]:
      print(f"Formatted result {formatted_result} not in preimages: {sorted(list(self.preimages[length].keys()))}")
    return self.preimages[length][formatted_result]

  def format_result_string(self, result):
    """Format an int as-is, or a float with 4 decimals (dropping a trailing ".0000")."""
    if isinstance(result, float):
      new_result = f"{result:.4f}"
      if new_result[-5:] == ".0000":
        new_result = new_result[:-5]
      return new_result
    elif isinstance(result, int):
      return f"{result}"
    else:
      raise ValueError(f"Unsupported type {type(result)} for result: {result}")

  def _target_matrix(self, label, output_mapping, batch_size):
    """Return the (batch, num_outputs) 0/1 target marking outputs equal to each label."""
    return torch.tensor([1.0 if self.eval_result_eq(l.item(), m) else 0.0 for l in label for m in output_mapping]).view(batch_size, -1)

  @staticmethod
  def _count_correct(y, y_pred, batch_size, num_outputs):
    """Count samples whose predicted argmax equals the target argmax (samples with no matching output count as wrong)."""
    if num_outputs == 0:
      return 0
    y_index = torch.argmax(y, dim=1)
    y_pred_index = torch.argmax(y_pred, dim=1)
    has_target = torch.sum(y, dim=1) > 0
    return torch.sum(torch.where(has_target, y_index == y_pred_index, torch.zeros(batch_size).bool())).item()

  def _batch_preimages(self, img_seq, img_seq_len, label, gts, sample_size):
    """Return the structurally pruned candidate expressions for each sample of a batch."""
    preimages = []
    for seq_len, gt, lbl in zip(img_seq_len, gts, label):
      preimg = self.get_preimage(seq_len.item(), lbl.item())
      # Randomly keep at most sample_size candidates (sample_size <= 0 keeps all).
      if sample_size > 0 and len(preimg) > sample_size:
        preimg = random.sample(preimg, sample_size)
      if gt not in preimg:
        # Concatenate rather than `+=`: without sampling, `preimg` is the list
        # cached in self.preimages and must not be mutated.
        preimg = preimg + [gt]
      preimages.append(preimg)

    # Per position: (formula_length, batch, 1, H, W) -> ResNet input -> feature lists.
    images = [self.resnet_transform(img) for img in img_seq.permute(1, 0, 2, 3, 4)]
    with torch.no_grad():
      features = [self.feature_extractor(img.to(self.device)).tolist() for img in images]

    images_for_pruning = []
    features_for_pruning = []
    labels_for_pruning = []
    for sample_idx in range(img_seq.shape[0]):
      seq_len = img_seq_len[sample_idx].item()
      images_for_pruning.append([img[sample_idx] for img in images[:seq_len]])
      features_for_pruning.append([feature[sample_idx] for feature in features[:seq_len]])
      labels_for_pruning.append(tuple(gts[sample_idx]))

    # NOTE: passes the module-level CLI namespace `args` (defined in __main__).
    return structural_pruning(images_for_pruning, preimages, labels_for_pruning, features_for_pruning, None, args)

  def train_epoch(self, epoch):
    """Run one training epoch, logging per-batch and total timing to wandb."""
    self.network.train()
    num_items = 0
    train_loss = 0
    total_correct = 0
    progress = tqdm(self.train_loader, total=len(self.train_loader))
    sample_size = 200
    t_begin_total_epoch = time.time()
    for (batch_idx, (img_seq, img_seq_len, label, gts)) in enumerate(progress):
      if self.preimages is not None:
        preimages = self._batch_preimages(img_seq, img_seq_len, label, gts, sample_size)
      else:
        preimages = None

      # NOTE: despite the metric name, this times a single batch.
      t_begin_epoch = time.time()
      (output_mapping, y_pred) = self.network(img_seq.to(self.device), img_seq_len.to(self.device), preimages=preimages)
      y_pred = y_pred.to("cpu")

      # Normalize label format
      batch_size, num_outputs = y_pred.shape
      y = self._target_matrix(label, output_mapping, batch_size)

      # Compute loss
      loss = self.loss_fn(y_pred, y)
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()
      if not math.isnan(loss.item()):
        train_loss += loss.item()

      correct_count = self._count_correct(y, y_pred, batch_size, num_outputs)

      # Stats
      num_items += batch_size
      total_correct += correct_count
      perc = 100. * total_correct / num_items
      avg_loss = train_loss / (batch_idx + 1)
      epoch_time = time.time() - t_begin_epoch
      wandb.log({
        "train_time_per_epoch": epoch_time,
        "epoch": epoch,
      })
      progress.set_description(f"[Train Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)")
    total_epoch_time = time.time() - t_begin_total_epoch
    wandb.log(
      {
        "epoch": epoch,
        "total_epoch_time": total_epoch_time,
      }
    )
    print(f"Total Epoch Time: {total_epoch_time}")

  def test_epoch(self, epoch):
    """Evaluate on the test set, log per-epoch metrics, save the model on improved loss."""
    self.network.eval()
    num_items = 0
    test_loss = 0
    total_correct = 0
    total_cnn_correct = 0
    total_symbols = 0
    # Defaults keep the bookkeeping below well-defined for an empty loader.
    perc = 0.0
    cnn_perc = 0.0
    avg_loss = 0.0
    with torch.no_grad():
      progress = tqdm(self.test_loader, total=len(self.test_loader))
      for batch_idx, (img_seq, img_seq_len, label, gts) in enumerate(progress):
        (output_mapping, y_pred, cnn_correct) = self.network(img_seq.to(self.device), img_seq_len.to(self.device), labels=gts)
        y_pred = y_pred.to("cpu")

        # Normalize label format
        batch_size, num_outputs = y_pred.shape
        y = self._target_matrix(label, output_mapping, batch_size)

        # Compute loss
        loss = self.loss_fn(y_pred, y)
        if not math.isnan(loss.item()):
          test_loss += loss.item()

        correct_count = self._count_correct(y, y_pred, batch_size, num_outputs)

        # Stats
        num_items += batch_size
        total_correct += correct_count
        total_cnn_correct += cnn_correct
        total_symbols += sum(length.item() for length in img_seq_len)

        perc = 100. * total_correct / num_items
        cnn_perc = 100. * total_cnn_correct / total_symbols if total_symbols > 0 else 0.0
        avg_loss = test_loss / (batch_idx + 1)

        progress.set_description(f"[Test Epoch {epoch}] Loss: {avg_loss:.4f}, Overall: {total_correct}/{num_items} ({perc:.2f}%), CNN: {total_cnn_correct}/{total_symbols} ({cnn_perc:.2f}%)")

    # Log once per epoch: logging running values per batch made the wandb
    # "max"/"min" summaries pick up partial-epoch numbers.
    wandb.log(
      {
        "epoch": epoch,
        "test_accuracy": perc,
        "test_cnn_accuracy": cnn_perc,
        "test_loss": avg_loss,
      }
    )

    # Save model
    if test_loss < self.min_test_loss:
      self.min_test_loss = test_loss
      torch.save(self.network, os.path.join(self.model_root, self.model_name))

    # Track best accuracies
    if perc > self.best_accuracy:
      self.best_accuracy = perc
    if cnn_perc > self.best_cnn_accuracy:
      self.best_cnn_accuracy = cnn_perc

  def train(self, n_epochs):
    """Alternate train_epoch and test_epoch for epochs 1..n_epochs."""
    for epoch in range(1, n_epochs + 1):
      self.train_epoch(epoch)
      self.test_epoch(epoch)


if __name__ == "__main__":
  # Command line arguments.
  # NOTE: `args` and `device` are bound at module level here; other code in this
  # file reads the `args` global (e.g. Scallop options, the structural_pruning call).
  parser = ArgumentParser("hwf")
  parser.add_argument("--model-name", type=str, default="hwf.pkl")
  parser.add_argument("--n-epochs", type=int, default=100)
  parser.add_argument("--no-sample-k", action="store_true")
  parser.add_argument("--sample-k", type=int, default=10)
  parser.add_argument("--dataset-prefix", type=str, default="expr")
  parser.add_argument("--batch-size", type=int, default=16)
  parser.add_argument("--learning-rate", type=float, default=0.0001)
  # NOTE(review): args.loss_fn is not referenced in this file; verify whether
  # structural_pruning uses it.
  parser.add_argument("--loss-fn", type=str, default="bce")
  parser.add_argument("--seed", type=int, default=1234)
  parser.add_argument("--do-not-use-hash", action="store_true")
  parser.add_argument("--provenance", type=str, default="difftopkproofs")
  parser.add_argument("--top-k", type=int, default=3)
  parser.add_argument("--cuda", action="store_true")
  parser.add_argument("--gpu", type=int, default=0)
  parser.add_argument("--jit", action="store_true")
  parser.add_argument("--recompile", action="store_true")
  parser.add_argument("--max-length", type=int, default=7)
  parser.add_argument("--structured-pruning", action="store_true")
  parser.add_argument('--structure-k', help='knn', default=10, type=int)
  # Fractional value (default 0.001): must be parsed as float, otherwise any
  # explicit value such as `--percent 0.01` is rejected by argparse.
  parser.add_argument('--percent', help='percent parameter for structural pruning', default=0.001, type=float)
  parser.add_argument("--mock_proximity", default = False, action="store_true")
  parser.add_argument("--num-training-samples", type=int, default=None)
  args = parser.parse_args()

  # Parameters
  torch.manual_seed(args.seed)
  random.seed(args.seed)
  if args.cuda:
    if not torch.cuda.is_available():
      raise Exception("No cuda available")
    device = torch.device(f"cuda:{args.gpu}")
  else:
    device = torch.device("cpu")

  config = {
    "hwf_n": args.max_length,
    "n_epochs": args.n_epochs,
    "batch_size": args.batch_size,
    "provenance": args.provenance,
    "seed": args.seed,
    "experiment_type": "scallop",
  }

  timestamp = datetime.now()
  run_id = f'scallop_hwf{args.max_length}_{args.seed}_{args.provenance}_{timestamp.strftime("%Y-%m-%d %H-%M-%S")}'

  wandb.init(
    project="HWF-N", config=config, id=run_id, mode='disabled'
  )

  wandb.define_metric("epoch")
  wandb.define_metric("total_epoch_time")
  wandb.define_metric("train_time_per_epoch", step_metric="epoch", summary="mean")
  wandb.define_metric("test_accuracy", step_metric="epoch", summary="max")
  wandb.define_metric("test_cnn_accuracy", step_metric="epoch", summary="max")
  wandb.define_metric("test_loss", step_metric="epoch", summary="min")

  # Data
  data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data"))
  model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../model/hwf"))
  os.makedirs(model_dir, exist_ok=True)
  train_loader, test_loader = hwf_loader(data_dir, batch_size=args.batch_size, prefix=args.dataset_prefix, formula_length=args.max_length, num_train_samples=args.num_training_samples)

  # Training
  trainer = Trainer(train_loader, test_loader, device, model_dir, args.model_name, args.learning_rate, args.no_sample_k, args.sample_k, args.provenance, args.top_k, args.max_length, use_preimages=args.structured_pruning)
  trainer.train(args.n_epochs)
  print(f"Best accuracy: {trainer.best_accuracy:.2f}%, Best CNN accuracy: {trainer.best_cnn_accuracy:.2f}%")
