import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import time
import random
from typing import *

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from argparse import ArgumentParser
from tqdm import tqdm

import wandb  # <-- Add wandb import
from datetime import datetime

from dolphin import Distribution
from dolphin.provenances import get_provenance

# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalize the single channel with mean 0.1307 and std 0.3081.
mnist_img_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Custom Dataset Class
class MNISTProdNDataset(torch.utils.data.Dataset):
    """Groups MNIST samples into `prod_n`-tuples labelled with the product of their digits.

    The underlying MNIST dataset is addressed through `index_map`, a permutation
    of its indices that is shuffled once (with the `random` module) when the
    dataset is constructed. Sample `idx` consists of the `prod_n` consecutive
    entries of `index_map` starting at `idx * prod_n`.
    """

    def __init__(
        self,
        root: str,
        prod_n: int,
        train: bool = True,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ):
        """Wrap a torchvision MNIST dataset.

        Args:
            root: Directory handed to `torchvision.datasets.MNIST`.
            prod_n: Number of digit images per sample.
            train: Forwarded to the MNIST dataset (train vs. test split).
            transform: Forwarded image transform.
            target_transform: Forwarded label transform.
            download: Forwarded download flag.
        """
        self.mnist_dataset = torchvision.datasets.MNIST(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download,
        )
        self.prod_n = prod_n

        # Fixed random permutation deciding which images are grouped together.
        permutation = list(range(len(self.mnist_dataset)))
        random.shuffle(permutation)
        self.index_map = permutation

    def __len__(self):
        """Return the number of complete `prod_n`-tuples available."""
        total_images = len(self.mnist_dataset)
        return int(total_images / self.prod_n)

    def __getitem__(self, idx):
        """Return `(img_1, ..., img_prod_n, product)` for sample `idx`."""
        base = idx * self.prod_n

        # The first image seeds both the image list and the running product.
        first_img, product = self.mnist_dataset[self.index_map[base]]
        images = [first_img]

        for offset in range(1, self.prod_n):
            next_img, next_digit = self.mnist_dataset[self.index_map[base + offset]]
            images.append(next_img)
            product *= next_digit

        return (*images, product)

    @staticmethod
    def collate_fn(batch):
        """Collate samples into `((imgs_1, ..., imgs_n), products)`.

        Each `imgs_i` stacks the i-th image of every sample along a new leading
        dimension; `products` is a 1-D long tensor of the per-sample labels.
        """
        # The label sits in the last slot of every sample tuple.
        label_slot = len(batch[0]) - 1

        stacked_imgs = tuple(
            torch.stack([sample[slot] for sample in batch])
            for slot in range(label_slot)
        )
        products = torch.stack(
            [torch.tensor(sample[label_slot]).long() for sample in batch]
        )
        return (stacked_imgs, products)

# Data loader function
def mnist_prod_n_loader(data_dir, prod_n, batch_size_train, batch_size_test):
    """Build the train and test data loaders for the product-of-n task.

    Both loaders wrap an `MNISTProdNDataset` rooted at `data_dir` (downloading
    MNIST if needed), apply `mnist_img_transform`, shuffle every epoch and use
    the dataset's own `collate_fn`.

    Args:
        data_dir: Directory where MNIST is stored / downloaded.
        prod_n: Number of digit images per sample.
        batch_size_train: Batch size of the training loader.
        batch_size_test: Batch size of the test loader.

    Returns:
        A `(train_loader, test_loader)` pair.
    """

    def make_loader(is_train, batch_size):
        dataset = MNISTProdNDataset(
            data_dir,
            prod_n,
            train=is_train,
            download=True,
            transform=mnist_img_transform,
        )
        return torch.utils.data.DataLoader(
            dataset,
            collate_fn=MNISTProdNDataset.collate_fn,
            batch_size=batch_size,
            shuffle=True,
        )

    # Build the training loader first so the dataset shuffles happen in the
    # same order as before.
    train_loader = make_loader(True, batch_size_train)
    test_loader = make_loader(False, batch_size_test)
    return train_loader, test_loader

# Base MNIST network
class MNISTNet(nn.Module):
    """Convolutional digit classifier that outputs a probability per class.

    Two 5x5 convolutions, each followed by 2x2 max pooling, feed a
    1024 -> 1024 -> 10 fully connected head. The output is softmax-normalized,
    so each row holds probabilities over the 10 classes.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.fc1 = nn.Linear(in_features=1024, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=10)

    def forward(self, x):
        """Map a batch of images to a `[batch, 10]` tensor of class probabilities."""
        features = F.max_pool2d(self.conv1(x), 2)
        features = F.max_pool2d(self.conv2(features), 2)

        # Flatten the feature maps to the 1024 inputs expected by fc1.
        flat = features.view(-1, 1024)

        hidden = F.relu(self.fc1(flat))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, p=0.5, training=self.training)

        logits = self.fc2(hidden)
        return F.softmax(logits, dim=1)

# Symbolic Product Network
class MNISTProdNNet(nn.Module):
    """Classifies each digit image and combines the per-digit distributions with `*`.

    Every image is passed through one shared `MNISTNet`; its output is wrapped
    in a `Distribution` over the symbols 0-9, and the distributions are combined
    left to right with the `*` operator.
    """

    def __init__(self, k=None):
        """Create the network.

        Args:
            k: Optional top-k parameter. When not None, `sample_top_k(k=k)` is
                applied to the running result after every multiplication.
        """
        super().__init__()

        # Single digit recognizer shared by all image positions.
        self.mnist_net = MNISTNet()
        self.k = k  # Top-k sampling parameter

    def forward(self, x: Tuple[torch.Tensor, ...]):
        """Return the combined `Distribution` for a tuple of image batches.

        Args:
            x: One image batch per digit position.

        Returns:
            The result of multiplying the per-position distributions together
            (presumably a distribution over the digits' product -- confirm
            against `dolphin.Distribution`).

        Raises:
            IndexError: If `x` is empty.
        """
        # One distribution over the digits 0-9 per image position.
        dist_list = [Distribution(self.mnist_net(imgs), range(10)) for imgs in x]

        # Fold the distributions together, left to right.
        combined = dist_list[0]
        for dist in dist_list[1:]:
            combined = combined * dist
            # Apply top-k sampling to limit memory usage
            if self.k is not None:
                combined = combined.sample_top_k(k=self.k)

        return combined

def nll_loss(output, ground_truth):
    """Return the negative log-likelihood loss of `output` against `ground_truth`.

    Thin wrapper around `torch.nn.functional.nll_loss` using its default
    settings.
    """
    loss = F.nll_loss(input=output, target=ground_truth)
    return loss

# Trainer class
class Trainer():
    """Trains and evaluates an `MNISTProdNNet`, logging metrics to wandb.

    NOTE: the loss selected through the `loss` argument is stored in
    `self.loss`, but the epoch loops always optimize and report
    `F.binary_cross_entropy` against a one-hot label matrix.
    """

    def __init__(self, train_loader, test_loader, model_dir, learning_rate, loss, provenance, device, k, prod_n, step_size=10, gamma=0.1):
        """Set up the model, optimizer, scheduler and bookkeeping.

        Args:
            train_loader: Iterable of `(images_tuple, target)` training batches.
            test_loader: Iterable of `(images_tuple, target)` test batches.
            model_dir: Directory associated with this run (only stored here).
            learning_rate: Learning rate for the Adam optimizer.
            loss: Loss name, either "nll" or "bce".
            provenance: Provenance name passed to `get_provenance`.
            device: Torch device for the model and the batches.
            k: Top-k parameter forwarded to `MNISTProdNNet`.
            prod_n: Number of digit images per sample (only stored here).
            step_size: `StepLR` period, in epochs.
            gamma: `StepLR` multiplicative decay factor.

        Raises:
            ValueError: If `loss` is not a known loss name.
        """
        self.device = device
        self.model_dir = model_dir

        # The provenance is configured globally on the Distribution class; its
        # `k` attribute is fixed to 3 independently of the `k` argument.
        Distribution.provenance = get_provenance(provenance)
        Distribution.provenance.k = 3

        self.prod_n = prod_n
        self.network = MNISTProdNNet(k=k).to(self.device)
        self.optimizer = optim.Adam(self.network.parameters(), lr=learning_rate)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=step_size, gamma=gamma)
        self.train_loader = train_loader
        self.test_loader = test_loader

        self.total_time = 0
        self.best_loss = 1e10
        self.provenance = provenance
        self.best_acc = 0

        # Track epoch times for average later
        self.epoch_times = []

        if loss == "nll":
            self.loss = nll_loss
        elif loss == "bce":
            self.loss = F.binary_cross_entropy
        else:
            raise ValueError(f"Unknown loss function `{loss}`")

    def _forward_batch(self, data, target):
        """Move one batch to the device and run the network on it.

        Returns:
            `(dist, symbols, target)`: the output probabilities, the output
            symbols as a tensor on `self.device`, and the target on
            `self.device`.
        """
        imgs = tuple(d.to(self.device) for d in data)
        target = target.to(self.device)
        output = self.network(imgs)
        dist = output.get_probabilities()
        symbols = torch.tensor(output.symbols, device=self.device)
        return dist, symbols, target

    def _build_labels(self, dist, symbols, target):
        """Build the one-hot label matrix matching the columns of `dist`.

        Row i has a 1.0 in the column of the first symbol equal to `target[i]`.
        If the target value does not occur in `symbols`, the row is all zeros.
        """
        target_indices = []
        for t in target:
            where = torch.where(symbols == t.item())
            if len(where[0]) == 0:
                target_indices.append(-1)
            else:
                target_indices.append(where[0][0].item())
        target_indices = torch.tensor(target_indices, device=self.device)

        # One spare trailing column absorbs the -1 index of missing targets;
        # dropping it afterwards leaves those rows entirely zero.
        y = torch.zeros(dist.shape[0], dist.shape[1] + 1, device=self.device)
        y[torch.arange(len(target), device=self.device), target_indices] = 1.0
        return y[:, :-1]

    @staticmethod
    def _count_correct(dist, symbols, target):
        """Return how many rows' most probable symbol equals the target."""
        pred_indices = torch.argmax(dist, dim=1)
        # `symbols` is already a tensor; index it directly instead of wrapping
        # it in `torch.tensor(...)` again, which copies and emits a warning.
        pred_values = symbols[pred_indices]
        return (pred_values == target).sum().item()

    def train_epoch(self, epoch):
        """Run one optimization pass over the training loader.

        The wall-clock duration of the epoch is appended to `self.epoch_times`,
        printed, and logged to wandb.
        """
        self.network.train()

        # Measure time
        t_begin = time.time()

        for data, target in tqdm(self.train_loader):
            self.optimizer.zero_grad()
            dist, symbols, target = self._forward_batch(data, target)
            y = self._build_labels(dist, symbols, target)

            # Loss and backprop
            loss = F.binary_cross_entropy(dist, y)
            loss.backward()
            self.optimizer.step()

        # End of epoch time
        epoch_time = time.time() - t_begin
        self.epoch_times.append(epoch_time)
        print(f"[Train Epoch {epoch}] Epoch time: {epoch_time:.4f}s")

        # Log training epoch time
        wandb.log({"epoch": epoch, "train_epoch_time": epoch_time})

    def test_epoch(self, epoch):
        """Evaluate on the test loader, then print and log accuracy and loss.

        Returns:
            The sum of the per-batch losses (the printed and logged loss is
            that sum divided by the number of batches).
        """
        self.network.eval()
        num_items = 0
        total_correct = 0
        test_loss = 0

        with torch.no_grad():
            for data, target in tqdm(self.test_loader):
                dist, symbols, target = self._forward_batch(data, target)
                y = self._build_labels(dist, symbols, target)

                test_loss += F.binary_cross_entropy(dist, y).item()
                total_correct += self._count_correct(dist, symbols, target)
                num_items += len(target)

        # Accuracy and avg loss
        acc = 100.0 * total_correct / num_items
        avg_loss = test_loss / len(self.test_loader)
        print(f"[Test Epoch {epoch}] Accuracy: {acc:.2f}%, Loss: {avg_loss:.4f}")

        # Update best metrics
        if acc > self.best_acc:
            self.best_acc = acc

        # Log to wandb
        wandb.log({"epoch": epoch, "accuracy": acc, "test_loss": avg_loss, "best_accuracy": self.best_acc})

        return test_loss

    def train(self, n_epochs):
        """Evaluate once, then alternate training and evaluation for `n_epochs` epochs.

        The learning-rate scheduler is stepped after each epoch's evaluation.
        Afterwards the average training-epoch time and the peak CUDA memory
        allocation are reported.
        """
        # Evaluate once at start
        self.test_epoch(0)

        for epoch in range(1, n_epochs + 1):
            self.train_epoch(epoch)
            self.test_epoch(epoch)
            self.scheduler.step()

        # Print average epoch time
        if self.epoch_times:
            average_epoch_time = sum(self.epoch_times) / len(self.epoch_times)
            print(f"Average epoch time: {average_epoch_time:.4f}s")
            wandb.log({"avg_epoch_time": average_epoch_time})

        print("Max memory allocated (MB):", torch.cuda.max_memory_allocated() / 1024 / 1024)

if __name__ == "__main__":
    # Command-line interface
    parser = ArgumentParser()
    parser.add_argument("--prod-n", type=int, default=5)
    parser.add_argument("--n-epochs", type=int, default=10)
    parser.add_argument("--batch-size-train", type=int, default=64)
    parser.add_argument("--batch-size-test", type=int, default=64)
    parser.add_argument("--learning-rate", type=float, default=0.001)
    parser.add_argument("--loss-fn", type=str, default="bce")
    parser.add_argument("--seed", type=int, default=1234)
    parser.add_argument("--provenance", type=str, default="damp", choices=['damp', 'dmmp', 'dtkp-am'])
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--topk", type=int, default=None, help="Top-k sampling parameter")
    parser.add_argument("--step-size", type=int, default=10)
    parser.add_argument("--gamma", type=float, default=0.1)
    args = parser.parse_args()

    print(args)

    # Parameters
    prod_n = args.prod_n
    n_epochs = args.n_epochs
    batch_size_train = args.batch_size_train
    batch_size_test = args.batch_size_test
    learning_rate = args.learning_rate
    loss_fn = args.loss_fn
    provenance = args.provenance

    # Seed before the datasets are built: their index shuffle uses `random`.
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Data and model directories, resolved relative to this file's parent directory
    data_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), "../../data"))
    model_dir = os.path.abspath(os.path.join(os.path.abspath(__file__), f'../../model/mnist_prod_{prod_n}'))
    os.makedirs(model_dir, exist_ok=True)

    # Dataloaders
    train_loader, test_loader = mnist_prod_n_loader(data_dir, prod_n, batch_size_train, batch_size_test)

    # Use the requested accelerator when it is available, otherwise the CPU.
    if args.device == "cuda" and torch.cuda.is_available():
        device_name = "cuda"
    elif args.device == "mps" and torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    device = torch.device(device_name)

    # Initialize W&B
    config = {
        "prod_n": prod_n,
        "n_epochs": n_epochs,
        "batch_size_train": batch_size_train,
        "batch_size_test": batch_size_test,
        "learning_rate": learning_rate,
        "loss_fn": loss_fn,
        "provenance": provenance,
        "seed": args.seed,
        "topk": args.topk,
        "step_size": args.step_size,
        "gamma": args.gamma,
    }
    timestamp = datetime.now()
    run_id = f"dolphin_prod{prod_n}_{args.seed}_{provenance}_{timestamp.strftime('%Y-%m-%d_%H-%M-%S')}"
    wandb.init(project="Prod n", config=config, id=run_id)

    # Forward the scheduler settings: --step-size and --gamma were previously
    # parsed but never reached the Trainer, so the flags had no effect.
    trainer = Trainer(
        train_loader,
        test_loader,
        model_dir,
        learning_rate,
        loss_fn,
        provenance,
        device,
        args.topk,
        prod_n,
        step_size=args.step_size,
        gamma=args.gamma,
    )
    trainer.train(n_epochs)
