from collections import defaultdict
import itertools
import logging
import os
import json
import random
from argparse import ArgumentParser
from time import time
from tqdm import tqdm
import datetime
import math
import wandb

import torch
from torch import nn, optim
import torch.nn.functional as F
import torchvision
from torchvision import transforms, models
import numpy as np
from PIL import Image
from mipll_pruning_algorithm import structural_pruning

from torchql.database import Database
from torchql.query import Query
from dolphin.provenances import get_provenance
from dolphin.distribution import Distribution

# Module-level logger; timing samples are collected on it under a `stats` attribute.
logger = logging.getLogger("HWFLogger")

# Attach the stats container only once (getLogger returns the same object
# for the same name, so a previously attached container is kept).
if "stats" not in logger.__dict__:
  logger.stats = defaultdict(list)

# Logger named "dolphin"; train_epoch reads `stats` from it when that attribute exists.
symbolic_logger = logging.getLogger("dolphin")

class HWFDataset(torch.utils.data.Dataset):
  """Handwritten-formula dataset backed by a JSON metadata file.

  Every metadata entry is read through the keys "img_paths" (image files of
  the formula, in order), "res" (the formula's result) and "expr" (the
  per-image ground-truth symbols).
  """

  def __init__(self, root: str, prefix: str, split: str, l, exact_length=False, num_train_samples=None):
    """Load and filter the metadata of one split.

    Args:
      root: Data directory containing the "HWF" folder.
      prefix: Prefix of the metadata file name ("{prefix}_{split}.json").
      split: Split name; sub-sampling is only applied when it is "train".
      l: Formula length used for filtering; values <= 0 keep every sample.
      exact_length: Keep only formulas of length exactly `l` instead of
        formulas of length up to `l`.
      num_train_samples: Optional cap on the number of training samples,
        drawn with `random.sample`.
    """
    super().__init__()
    self.root = root
    self.split = split
    # Context manager so the metadata file handle is closed right after parsing.
    with open(os.path.join(root, f"HWF/{prefix}_{split}.json")) as metadata_file:
      md = json.load(metadata_file)

    # Filter metadata based on exact length or up-to length
    if l > 0:
      if exact_length:
        self.metadata = [m for m in md if len(m['img_paths']) == l]
        print(f"Using exact length filtering: {len(self.metadata)} samples with length exactly {l}")
      else:
        self.metadata = [m for m in md if len(m['img_paths']) <= l]
        print(f"Using up-to length filtering: {len(self.metadata)} samples with length up to {l}")
    else:
      self.metadata = md
      print(f"Using all samples: {len(self.metadata)} samples")

    # Limit training samples if specified
    if num_train_samples is not None and split == "train":
      original_count = len(self.metadata)
      self.metadata = random.sample(self.metadata, min(num_train_samples, len(self.metadata)))
      print(f"Limited training samples: {len(self.metadata)}/{original_count} samples")

    self.img_transform = torchvision.transforms.Compose([
      torchvision.transforms.ToTensor(),
      torchvision.transforms.Normalize((0.5,), (1,))
    ])

  def __getitem__(self, index):
    """Return `(img_seq, img_seq_len, res, gt)` for the sample at `index`.

    `img_seq` is the list of transformed grayscale images, `img_seq_len` its
    length, `res` the stored result and `gt` the tuple of ground-truth symbols.
    """
    sample = self.metadata[index]

    # Input is a sequence of images
    img_seq = []
    for img_path in sample["img_paths"]:
      img_full_path = os.path.join(self.root, "HWF/Handwritten_Math_Symbols", img_path)
      # Close the underlying image file once the pixels have been converted.
      with Image.open(img_full_path) as raw_img:
        img = self.img_transform(raw_img.convert("L"))
      img_seq.append(img)
    img_seq_len = len(img_seq)

    # Output is the "res" in the sample of metadata
    res = sample["res"]

    # GT is the ground truth label for each image
    gt = tuple(sample["expr"])

    # Return (input, output) pair
    return (img_seq, img_seq_len, res, gt)

  def __len__(self):
    """Number of samples kept after filtering and sub-sampling."""
    return len(self.metadata)

  @staticmethod
  def collate_fn(batch):
    """Collate samples into a batch, zero-padding shorter image sequences.

    Returns:
      `(img_seqs, img_seq_len, results, gts)` where `img_seqs` stacks the
      padded sequences, `img_seq_len` holds the true lengths as a long
      tensor, `results` stacks the per-sample results and `gts` is the
      list of ground-truth tuples.
    """
    max_len = max(img_seq_len for (_, img_seq_len, _, _) in batch)
    # Padding image shaped like the first image of the first sample.
    zero_img = torch.zeros_like(batch[0][0][0])
    pad_zero = lambda img_seq: img_seq + [zero_img] * (max_len - len(img_seq))
    img_seqs = torch.stack([torch.stack(pad_zero(img_seq)) for (img_seq, _, _, _) in batch])
    img_seq_len = torch.stack([torch.tensor(img_seq_len).long() for (_, img_seq_len, _, _) in batch])
    results = torch.stack([torch.tensor(res) for (_, _, res, _) in batch])
    gts = [gt for (_, _, _, gt) in batch]
    return (img_seqs, img_seq_len, results, gts)


def hwf_loader(data_dir, batch_size, prefix, l, exact_length=False, num_train_samples=None):
  """Build the shuffled train and test data loaders for the HWF task.

  The length filter (`l`, `exact_length`) is applied to both splits, while
  `num_train_samples` is forwarded to the training dataset only.

  Returns:
    A `(train_loader, test_loader)` tuple.
  """
  train_set = HWFDataset(
    data_dir, prefix, "train", l,
    exact_length=exact_length,
    num_train_samples=num_train_samples,
  )
  test_set = HWFDataset(data_dir, prefix, "test", l, exact_length=exact_length)

  # Both loaders share the same batching configuration.
  loader_options = {
    "collate_fn": HWFDataset.collate_fn,
    "batch_size": batch_size,
    "shuffle": True,
  }
  train_loader = torch.utils.data.DataLoader(train_set, **loader_options)
  test_loader = torch.utils.data.DataLoader(test_set, **loader_options)
  return (train_loader, test_loader)

class SymbolNet(nn.Module):
  """Small CNN mapping single-channel symbol images to a 14-way distribution."""

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
    self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
    # 30976 = 64 * 22 * 22: the flattened size after the 2x2 max-pool
    # (e.g. square inputs of side 44 or 45).
    self.fc1 = nn.Linear(in_features=30976, out_features=128)
    self.fc1_bn = nn.BatchNorm1d(num_features=128)
    self.fc2 = nn.Linear(in_features=128, out_features=14)

  def forward(self, x):
    """Return per-image class probabilities (softmax over dimension 1)."""
    # Convolutional feature extractor.
    feats = F.relu(self.conv1(x))
    feats = F.max_pool2d(self.conv2(feats), 2)
    feats = F.dropout(feats, p=0.25, training=self.training)

    # Fully connected classifier head.
    hidden = self.fc1(torch.flatten(feats, 1))
    hidden = F.relu(self.fc1_bn(hidden))
    hidden = F.dropout(hidden, p=0.5, training=self.training)
    logits = self.fc2(hidden)
    return F.softmax(logits, dim=1)

class HWFNet(nn.Module):
  """Symbol CNN combined with a symbolic query that evaluates the recognised formula."""

  def __init__(self, provenance, k, sample_k, debug=False):
    """Create the symbol classifier and configure the `Distribution` class.

    Args:
      provenance: Provenance name passed to `get_provenance`.
      k: Value assigned to `Distribution.provenance.k`.
      sample_k: Value assigned to `Distribution.k`.
      debug: Accepted for interface compatibility; not used here.

    Note that the provenance settings are stored on the `Distribution`
    class itself, not on this instance.
    """
    super().__init__()

    # Symbol embedding
    self.symbol_cnn = SymbolNet()
    self.operators = ["+", "-", "*", "/"]
    self.symbols = [str(i) for i in range(10)] + ["+", "-", "*", "/"]

    Distribution.provenance = get_provenance(provenance)
    Distribution.provenance.k = k
    Distribution.k = sample_k

  def forward(self, img_seq, img_seq_len, db):
    """Classify every image and run the symbolic query over the batch.

    Args:
      img_seq: 5-D image tensor; the first two dimensions are unpacked as
        batch size and (padded) formula length.
      img_seq_len: Iterable of scalar tensors holding the true formula lengths.
      db: Database object the query is executed against.

    Returns:
      The result of `Distribution.stack` over the query's rows.
    """
    batch_size, formula_length, _, _, _ = img_seq.shape
    length = [l.item() for l in img_seq_len]

    # Merge batch and sequence dimensions so the CNN sees one image per row.
    inp = img_seq.flatten(start_dim=0, end_dim=1)

    t = time()
    symbol = self.symbol_cnn(inp).view(batch_size, -1, 14)
    logger.stats["T_SymbolCNN"].append(time() - t)

    def eval_formula(s):
      # NOTE(review): eval is applied to strings assembled from self.symbols
      # (digits and the four operators); keep it away from external input.
      try:
        return eval(s)
      except Exception:
        # Best effort: formulas that cannot be evaluated (e.g. division by
        # zero) map to 0. Unlike a bare `except`, this no longer swallows
        # KeyboardInterrupt / SystemExit.
        return 0

    def infer_expression(length, *symbols):
      # Concatenate the per-position distributions, then evaluate the formula.
      t = time()
      res = symbols[0]
      for i in range(1, len(symbols)):
        res += symbols[i]

      x = (res.map(eval_formula), )
      logger.stats["T_Infer"].append(time() - t)
      return x

    def reorg(symbols, lengths):
      # Build one Distribution per formula position for a single sample.
      t = time()
      distrs = []
      # Placeholder for padded positions: mass on "+" and "0" only, so after
      # the digit/operator filter below a padded slot contributes "+0".
      default = torch.zeros(symbols.shape[1], device=symbols.device)
      default[[self.symbols.index("+"), self.symbols.index("0")]] = 1
      # formula_length equals symbol.shape[1], the padded sequence length.
      for i in range(formula_length):
        if i < lengths:
          distrs.append(Distribution(symbols[i, :].view(-1, 14), self.symbols))
        else:
          distrs.append(Distribution(default.view(-1, 14), self.symbols))
        # Even positions must be digits, odd positions operators.
        if i % 2 == 0:
          distrs[-1] = distrs[-1].filter(lambda s : s not in self.operators)
        else:
          distrs[-1] = distrs[-1].filter(lambda s : s in self.operators)

      res = (lengths, *distrs)
      logger.stats["T_Reorg"].append(time() - t)
      return res

    q = Query("hwf", base="symbols").join("lengths").project(lambda symbols, lengths: reorg(symbols, lengths)) \
      .project(infer_expression, batch_size=batch_size)

    t = time()
    res = q(db, tensors={"symbols": symbol, "lengths": length})

    stacked = Distribution.stack(res.rows)
    logger.stats["T_Query"].append(time() - t)
    return stacked

    # default = torch.zeros((batch_size, 1, symbol.shape[-1]), device=symbol.device)
    # default[:, 0, [self.symbols.index("+"), self.symbols.index("0")]] = 1
    # symbol1 = torch.cat([symbol, default], dim=1)

    # digits = []
    # for i in range(symbol.shape[1]):
    #   length_mask = [i if i < l else -1 for l in length]
    #   d = Distribution(symbol1[torch.arange(batch_size), length_mask], np.array(self.symbols))
    #   d = d.filter(lambda s : (s not in self.operators) if i % 2 == 0 else (s in self.operators))
    #   digits.append(d)

    # res = digits[0]
    # for i in range(1, len(digits)):
    #   res += digits[i]

    # x = res.map(eval_formula)

    # return [x]

class Trainer():
  """Training / evaluation driver for `HWFNet`, with optional preimage-based structural pruning."""

  def __init__(self, train_loader, test_loader, device, model_root, model_name, learning_rate, provenance, k, sample_k, max_length, step_size=10, gamma=0.1, use_preimages=False, args=None):
    """Create the network, optimizer and scheduler, and optionally the pruning state.

    Args:
      train_loader, test_loader: Loaders yielding `(img_seq, img_seq_len, label, gts)`.
      device: Torch device used for the network and all tensors created here.
      model_root, model_name: Directory and file name used when saving the model.
      learning_rate: Adam learning rate.
      provenance, k, sample_k: Forwarded to `HWFNet`.
      max_length: Largest formula length for which preimages are precomputed.
      step_size, gamma: `StepLR` parameters.
      use_preimages: Enable preimage construction and structural pruning.
      args: Parsed command-line arguments, forwarded to `structural_pruning`.
    """
    self.network = HWFNet(provenance, k, sample_k).to(device)
    self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
    self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=step_size, gamma=gamma)
    self.db = Database()
    self.train_loader = train_loader
    self.test_loader = test_loader
    self.device = device
    self.loss_fn = F.binary_cross_entropy
    self.model_root = model_root
    self.model_name = model_name
    self.min_test_loss = 100000000.0
    self.max_length = max_length
    self.args = args

    if use_preimages:
      self.preimages = {}
      # Formulas alternate digit/operator, so only odd lengths are built.
      for l in range(1, max_length+1, 2):
        t_begin = time()
        print(f"Preimage for length {l}")
        self.build_preimage(l)
        print(f"Time taken: {time() - t_begin}")
      self.resnet_transform = transforms.Compose([
        # Repeat the single channel three times for the ResNet input.
        transforms.Lambda(lambda x: x.expand(x.shape[0],3,*x.shape[2:]) ),
        transforms.Resize(224),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # ImageNet normalization
                            std=[0.229, 0.224, 0.225])
      ])
      model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
      for p in model.parameters():
        p.requires_grad = False
      # Drop the final classification layer to obtain a frozen feature extractor.
      self.feature_extractor = torch.nn.Sequential(*(list(model.children())[:-1]))
      self.feature_extractor = self.feature_extractor.to(self.device)
      self.feature_extractor.eval()
    else:
      self.preimages = None

  def eval_result_eq(self, a, b, threshold=0.01):
    """Return True when `a` and `b` are both set and differ by less than `threshold`."""
    if a is None or b is None:
      return False
    return abs(a - b) < threshold

  def retrieve_y(self, label, s, threshold=0.01):
    """Build the 0/1 target matrix: entry (i, j) is 1 when `label[i]` is within `threshold` of `s[j]`."""
    num_labels = len(label)
    num_symbols = len(s)

    label_2d = label.view(-1, 1).expand(-1, num_symbols)
    symbols_2d = s.view(1, -1).expand(num_labels, -1)

    return (torch.abs(label_2d - symbols_2d) < threshold).float()

  def build_preimage(self, length):
    """Enumerate every digit/operator formula of `length` symbols, grouped by formatted result.

    The mapping is stored in `self.preimages[length]`; lengths that are
    already present are left untouched.
    """
    if length in self.preimages:
      return

    digits = [str(i) for i in range(10)]
    ops = ["+", "-", "*", "/"]
    # Even positions hold digits, odd positions hold operators.
    possibilities = [digits if i % 2 == 0 else ops for i in range(length)]

    preimages = {}
    for comb in itertools.product(*possibilities):
      expr = "".join(comb)
      # evaluate the expression while ignoring division by zero
      try:
        result = eval(expr)
      except ZeroDivisionError:
        continue
      preimages.setdefault(self.format_result_string(result), []).append(comb)
    self.preimages[length] = preimages

  def get_preimage(self, length, result):
    """Return the cached list of formulas of `length` symbols that evaluate to `result`.

    The returned list is the cached object itself; copy it before modifying.

    Raises:
      KeyError: If no formula of that length produces `result`.
    """
    if length not in self.preimages:
      self.build_preimage(length)
    formatted_result = self.format_result_string(result)
    try:
      return self.preimages[length][formatted_result]
    except KeyError:
      print(f"Formatted result {formatted_result} not in preimages: {sorted(list(self.preimages[length].keys()))}")
      raise

  def format_result_string(self, result):
    """Format a numeric result as a preimage key.

    Floats are rendered with 4 decimal places; a trailing ".0000" is
    removed so whole-valued floats share the key of the matching int.

    Raises:
      ValueError: If `result` is neither a float nor an int.
    """
    if isinstance(result, float):
      new_result = f"{result:.4f}"
      if new_result[-5:] == ".0000":
        new_result = new_result[:-5]
      # Negative zero (or a tiny negative value that rounds to it) must
      # share the key of 0 rather than getting its own "-0" bucket.
      if new_result == "-0":
        new_result = "0"
      return new_result
    elif isinstance(result, int):
      return f"{result}"
    else:
      raise ValueError(f"Unsupported type {type(result)} for result: {result}")

  def train_epoch(self, epoch):
    """Run one training epoch, log statistics and advance the LR scheduler."""
    self.network.train()
    num_items = 0
    train_loss = 0
    total_correct = 0
    progress = tqdm(self.train_loader, total=len(self.train_loader))

    logger.debug(f"Epoch {epoch}")
    logger.debug(f"Learning rate: {self.scheduler.get_last_lr()}")

    t_begin_epoch = time()
    sample_size = 200
    num_og_preimages = 0
    num_pruned_preimages = 0

    for (i, (img_seq, img_seq_len, label, gts)) in enumerate(progress):
      t_begin = time()
      self.optimizer.zero_grad()

      batch_size = img_seq.shape[0]

      # Preimage processing for structural pruning
      if self.preimages is not None:
        preimages = []
        for seq_len, gt, lbl in zip(img_seq_len, gts, label):
          cached = self.get_preimage(seq_len.item(), lbl.item())

          # Sample preimages randomly if too many; otherwise work on a copy
          # so the cached list is never modified by this batch.
          if sample_size > 0 and len(cached) > sample_size:
            preimg = random.sample(cached, sample_size)
          else:
            preimg = list(cached)
          # The ground truth must always be among the candidates.
          if gt not in preimg:
            preimg.append(gt)
          preimages.append(preimg)

        # Structural pruning preprocessing
        # Iterate position-major so every position is transformed as one batch.
        img_seq_perm = img_seq.permute(1, 0, 2, 3, 4)

        images = [self.resnet_transform(img) for img in img_seq_perm]
        features = [self.feature_extractor(img.to(self.device)).tolist() for img in images]

        images_for_pruning = []
        features_for_pruning = []
        labels_for_pruning = []

        for batch_idx in range(batch_size):
          length = img_seq_len[batch_idx].item()
          # Keep only the real (non-padded) positions of each sample.
          images_for_pruning.append([img[batch_idx] for img in images[:length]])
          features_for_pruning.append([feature[batch_idx] for feature in features[:length]])
          labels_for_pruning.append(tuple(gts[batch_idx]))

        # Perform structural pruning
        pruned_preimages = structural_pruning(images_for_pruning, preimages, labels_for_pruning, features_for_pruning, None, self.args)

        # Track pruning statistics
        num_og_preimages += sum(len(preimg) for preimg in preimages)
        num_pruned_preimages += sum(len(preimg) for preimg in pruned_preimages)

        preimages = pruned_preimages
      else:
        preimages = None

      t = time()
      # Use the trainer's own device rather than a module-level global.
      img_seq, img_seq_len, label = img_seq.to(self.device), img_seq_len.to(self.device), label.to(self.device)
      d = self.network(img_seq, img_seq_len, self.db)[0]

      s, y_pred = d.symbols, d.get_probabilities()
      logger.stats['T_Forward'].append(time() - t)

      if len(y_pred.shape) == 1:
        y_pred = y_pred.view(1, -1)
      batch_size, num_outputs = y_pred.shape

      t = time()
      y = self.retrieve_y(label, torch.tensor(s.astype(float), device=self.device))
      logger.stats['T_CreateY'].append(time() - t)

      # Compute loss
      t = time()
      loss = self.loss_fn(y_pred, y)
      loss.backward()
      self.optimizer.step()
      loss_value = loss.item()
      if not math.isnan(loss_value):
        train_loss += loss_value
      logger.stats['T_Backward'].append(time() - t)

      # Collect index and compute accuracy
      if num_outputs > 0:
        y_index = torch.argmax(y, dim=1)
        y_pred_index = torch.argmax(y_pred, dim=1)
        # Rows without any matching output symbol never count as correct.
        correct_count = torch.sum(torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, torch.zeros(batch_size, device=self.device).bool())).item()
      else:
        correct_count = 0

      # Stats
      num_items += batch_size
      total_correct += correct_count
      perc = 100. * total_correct / num_items
      avg_loss = train_loss / (i + 1)

      # Prints
      logger.stats['T_Total'].append(time() - t_begin)

      # Create progress bar description with optional pruning stats.
      # get_last_lr() is the supported way to read the current rate outside step().
      desc = f"[Train Epoch {epoch}] Loss: {avg_loss:.4f}, LR: {self.scheduler.get_last_lr()}, Acc: {total_correct}/{num_items} ({perc:.2f}%)"
      if self.preimages is not None and num_og_preimages > 0:
        current_pruning_ratio = (num_og_preimages - num_pruned_preimages) / num_og_preimages
        desc += f", Pruned: {num_og_preimages - num_pruned_preimages:,}/{num_og_preimages:,} ({current_pruning_ratio:.1%})"
      progress.set_description(desc)
      # Log a plain float rather than a tensor that is attached to the graph.
      wandb.log({"epoch": epoch, "train/loss": loss_value})

    # Print epoch-level pruning statistics
    if self.preimages is not None and num_og_preimages > 0:
      total_pruning_ratio = (num_og_preimages - num_pruned_preimages) / num_og_preimages
      print(f"\nEpoch {epoch} Pruning Summary:")
      print(f"  Original proofs: {num_og_preimages:,}")
      print(f"  Pruned proofs: {num_pruned_preimages:,}")
      print(f"  Proofs removed: {num_og_preimages - num_pruned_preimages:,}")
      print(f"  Pruning ratio: {total_pruning_ratio:.1%}")
      wandb.log({
        "epoch": epoch,
        "original_proofs": num_og_preimages,
        "pruned_proofs": num_pruned_preimages,
        "pruning_ratio": total_pruning_ratio,
      })

    t_epoch = time() - t_begin_epoch
    # Guard against an empty loader when averaging the iteration time.
    wandb.log({"epoch": epoch, "train/it_time": t_epoch / max(len(self.train_loader), 1)})

    # Log timing statistics
    timing_stats = "\n".join([f"{k}: {sum(v) / len(v):.2f}" for k, v in logger.stats.items() if k.startswith("T_")])
    memory_stats = "\n".join([f"{k}: {v}" for k, v in logger.stats.items() if k.startswith("Memory_")])

    # Check if symbolic_logger has stats before accessing
    if hasattr(symbolic_logger, 'stats'):
      symbolic_stats = "\n".join([f"{k}: {v}" for k, v in symbolic_logger.stats.items()])
      logger.debug(f"Times: {symbolic_stats}\n{timing_stats}\n{memory_stats}")
      symbolic_logger.reset_stats()
    else:
      logger.debug(f"Times: {timing_stats}\n{memory_stats}")

    logger.stats = defaultdict(list)
    self.scheduler.step()

  def test_epoch(self, epoch):
    """Evaluate on the test loader, save the model on a new best total loss, and log to wandb."""
    self.network.eval()
    num_items = 0
    test_loss = 0
    total_correct = 0
    # Defined up front so logging below works even with an empty loader.
    perc = 0.0
    with torch.no_grad():
      progress = tqdm(self.test_loader, total=len(self.test_loader))
      for i, (img_seq, img_seq_len, label, gts) in enumerate(progress):
        d = self.network(img_seq.to(self.device), img_seq_len.to(self.device), self.db)[0]
        s, y_pred = d.symbols, d.get_probabilities()

        if len(y_pred.shape) == 1:
          y_pred = y_pred.view(1, -1)
        batch_size, num_outputs = y_pred.shape
        # Same vectorised target construction as in train_epoch.
        y = self.retrieve_y(label.to(self.device), torch.tensor(s.astype(float), device=self.device))

        # Compute loss
        loss = self.loss_fn(y_pred, y)
        loss_value = loss.item()
        if not math.isnan(loss_value):
          test_loss += loss_value

        # Collect index and compute accuracy
        if num_outputs > 0:
          y_index = torch.argmax(y, dim=1)
          y_pred_index = torch.argmax(y_pred, dim=1)
          correct_count = torch.sum(torch.where(torch.sum(y, dim=1) > 0, y_index == y_pred_index, torch.zeros(batch_size, device=self.device).bool())).item()
        else:
          correct_count = 0

        # Stats
        num_items += batch_size
        total_correct += correct_count
        perc = 100. * total_correct / num_items
        avg_loss = test_loss / (i + 1)

        # Prints
        progress.set_description(f"[Test Epoch {epoch}] Avg loss: {avg_loss:.4f}, Accuracy: {total_correct}/{num_items} ({perc:.2f}%)")

    # Save model
    if test_loss < self.min_test_loss:
      self.min_test_loss = test_loss
      torch.save(self.network, os.path.join(self.model_root, self.model_name))

    wandb.log({"epoch": epoch, "test/loss": test_loss, "test/acc": perc})

  def train(self, n_epochs):
    """Alternate training and test epochs for `n_epochs` epochs (numbered from 1)."""
    def compare_weights(w1, w2):
      # Debug helper: True when any pair of parameter tensors differs.
      for p1, p2 in zip(w1, w2):
        if not torch.equal(p1, p2):
          return True
      return False

    def get_weights(model):
      # Debug helper: snapshot (clone) of all model parameters.
      return [param.clone() for param in model.parameters()]

    # params_init = get_weights(self.network)
    # self.test_epoch(0)
    for epoch in range(1, n_epochs + 1):
      self.train_epoch(epoch)
      self.test_epoch(epoch)
      # logging.debug(f"Did the weights change? {compare_weights(params_init, get_weights(self.network))}")


if __name__ == "__main__":
  # Script entry point: parse arguments, set up logging, data and wandb, then train.

  # Command line arguments
  parser = ArgumentParser("hwf")
  parser.add_argument("--model-name", type=str, default="hwf.pkl")
  parser.add_argument("--n-epochs", type=int, default=20)
  parser.add_argument("--sample-k", type=int, default=7)
  parser.add_argument("--l", type=int, default=3)
  parser.add_argument("--dataset-prefix", type=str, default="expr")
  parser.add_argument("--batch-size", type=int, default=64)
  parser.add_argument("--learning-rate", type=float, default=0.0001)
  parser.add_argument("--step-size", type=int, default=10)
  parser.add_argument("--gamma", type=float, default=0.1)
  parser.add_argument("--loss-fn", type=str, default="bce")
  parser.add_argument("--seed", type=int, default=1234)
  parser.add_argument("--do-not-use-hash", action="store_true")
  parser.add_argument("--provenance", type=str, default="dtkp-am", choices=['damp', 'dmmp', 'dtkp-am'])
  parser.add_argument("--top-k", type=int, default=3)
  parser.add_argument("--jit", action="store_true")
  parser.add_argument("--recompile", action="store_true")
  parser.add_argument("--device", type=str, default="cuda", choices=["cpu", "cuda", "mps"])
  parser.add_argument("--gpu", type=int, default=0)
  parser.add_argument("--log-level", type=str, default="INFO", choices=["INFO", "DEBUG", "WARNING"])
  parser.add_argument("--log-file", type=str, default=None)
  parser.add_argument("--structured-pruning", action="store_true")
  parser.add_argument('--structure-k', help='knn', default=10, type=int)
  # The default is fractional, so the value must be parsed as a float; with
  # type=int any fractional value given on the command line was rejected.
  # NOTE(review): the help text looks copied from --structure-k — confirm wording.
  parser.add_argument('--percent', help='knn', default=0.001, type=float)
  parser.add_argument("--mock_proximity", default = False, action="store_true")
  parser.add_argument("--num-training-samples", type=int, default=None)
  parser.add_argument("--exact-length", action="store_true", help="Filter for exact length instead of up-to length")
  args = parser.parse_args()

  # Parameters
  torch.manual_seed(args.seed)
  random.seed(args.seed)

  # Fall back to the CPU when the requested accelerator is unavailable.
  if args.device == "cuda" and torch.cuda.is_available():
    device_name = f"cuda:{args.gpu}"
  elif args.device == "mps" and torch.backends.mps.is_available():
    device_name = "mps"
  else:
    device_name = "cpu"

  device = torch.device(device_name)

  # Log either to the given file or to the console, never both.
  handler = [logging.StreamHandler()]
  if args.log_file:
    handler = [logging.FileHandler(args.log_file, mode="w")]

  logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=args.log_level, handlers=handler)

  # Data
  # data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data"))
  data_dir = "../../data"
  model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../model/hwf"))
  os.makedirs(model_dir, exist_ok=True)
  train_loader, test_loader = hwf_loader(data_dir, batch_size=args.batch_size, prefix=args.dataset_prefix, l=args.l, exact_length=args.exact_length, num_train_samples=args.num_training_samples)

  config = {
    "length": args.l,
    "device": device,
    "provenance": args.provenance,
    "top_k": args.top_k,
    "sample_k": args.sample_k,
    "seed": args.seed,
    "n_epochs": args.n_epochs,
    "batch_size": args.batch_size,
    "learning_rate": args.learning_rate,
    "experiment_type": "dolphin",
  }

  timestamp = datetime.datetime.now()
  # Named run_id to avoid shadowing the builtin `id`.
  run_id = f'dolphin_hwf{args.l}_{args.provenance}({args.top_k})_{args.seed}_{timestamp.strftime("%Y-%m-%d %H-%M-%S")}'

  wandb.login()
  wandb.init(project="HWF-dtkp", config=config, id=run_id)
  # All summary metrics are plotted against the epoch counter.
  wandb.define_metric("epoch")
  wandb.define_metric("train/it_time", step_metric="epoch", summary="mean")
  wandb.define_metric("test/loss", step_metric="epoch", summary="min")
  wandb.define_metric("test/acc", step_metric="epoch", summary="max")
  wandb.define_metric("original_proofs", step_metric="epoch", summary="mean")
  wandb.define_metric("pruned_proofs", step_metric="epoch", summary="mean")
  wandb.define_metric("pruning_ratio", step_metric="epoch", summary="mean")

  print(args)

  # Training
  trainer = Trainer(train_loader, test_loader, device, model_dir, args.model_name, args.learning_rate, args.provenance, args.top_k, args.sample_k, args.l, step_size=args.step_size, gamma=args.gamma, use_preimages=args.structured_pruning, args=args)
  trainer.train(args.n_epochs)
