import json
import os
import pickle
import time
from os.path import join

import torch
import torch.nn as nn
import utils
from torch.autograd import Variable
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader, Subset


def compute_score_with_logits(logits, labels):
    """Return the per-class label score selected by each row's top prediction.

    For every row the highest-scoring class in ``logits`` is turned into a
    one-hot vector, which is then multiplied element-wise with ``labels``.
    The result therefore has at most one non-zero entry per row: the label
    value of the predicted class.

    Args:
        logits: 2-D tensor of prediction scores, one row per example.
        labels: 2-D tensor of target scores with the same shape as ``logits``.
            Must live on the same device as ``logits``.

    Returns:
        A tensor shaped like ``labels`` holding the selected scores. Callers
        typically ``.sum()`` it to get the total score of a batch.
    """
    predicted = logits.argmax(dim=1)
    # Allocate on the labels' device instead of unconditionally on CUDA so the
    # function also works for CPU tensors.
    one_hots = torch.zeros(labels.size(), device=labels.device)
    one_hots.scatter_(1, predicted.view(-1, 1), 1)
    return one_hots * labels


def _tensor_to_json(obj):
    """JSON ``default`` hook that converts a torch tensor to a plain value.

    Zero-dimensional tensors become Python scalars; any other tensor becomes
    a (nested) list.

    Raises:
        TypeError: if ``obj`` is not a ``torch.Tensor``.
    """
    if isinstance(obj, torch.Tensor):
        return obj.item() if obj.dim() == 0 else obj.tolist()
    raise TypeError(f"Type {type(obj)} not serializable")


def _curriculum_fraction(epoch, num_epochs, start, end, pacing):
    """Return the fraction of training data for 0-based ``epoch``, in [0, 1].

    With ``pacing == 'linear'`` the fraction is interpolated between ``start``
    and ``end`` and reaches ``end`` in the final epoch. Any other pacing value
    falls back to using all data.
    """
    if pacing == 'linear':
        epoch_ratio = (epoch + 1) / num_epochs
        fraction = start + (end - start) * epoch_ratio
    else:
        fraction = 1.0
    return min(max(fraction, 0.0), 1.0)


def train(model, train_dset, eval_dset, args):
    """Train ``model`` on ``train_dset`` and periodically evaluate on ``eval_dset``.

    Side effects: creates ``args.output``, writes ``log.txt`` there, rewrites
    ``results.json`` after every evaluation and finally saves the model's
    ``state_dict`` to ``model.pth``.

    Args:
        model: module called as ``model(v, None, q, a, b)`` and returning
            ``(pred, loss)``. Inputs are moved to CUDA before the call.
        train_dset: dataset yielding ``(v, q, a, b)`` items. In curriculum
            mode only the first N items (in dataset order) are used each
            epoch -- presumably the caller has sorted the dataset from easy
            to hard; verify against the caller.
        eval_dset: dataset passed to ``evaluate``.
        args: namespace providing ``output``, ``epochs``, ``eval_each_epoch``,
            ``batch_size`` and ``use_curriculum``. When ``use_curriculum`` is
            truthy it must also provide ``curriculum_start_percent``,
            ``curriculum_end_percent`` and ``curriculum_pacing``.

    Returns:
        The result dict of the most recent evaluation (the final epoch is
        always evaluated), extended with ``epoch``, ``step``, ``train_loss``
        and ``train_score``; ``None`` if ``args.epochs`` is 0.

    Raises:
        ValueError: if the loss becomes NaN.
    """
    output = args.output
    num_epochs = args.epochs
    eval_each_epoch = args.eval_each_epoch
    batch_size = args.batch_size
    use_curriculum = args.use_curriculum
    # Curriculum settings are only required when the curriculum is enabled.
    if use_curriculum:
        curriculum_start = args.curriculum_start_percent
        curriculum_end = args.curriculum_end_percent
        curriculum_pacing = args.curriculum_pacing
    num_workers = 32

    utils.create_dir(output)
    optim = torch.optim.Adamax(model.parameters())
    logger = utils.Logger(os.path.join(output, 'log.txt'))
    all_results = []
    # Pre-bind so the final ``return`` is well-defined even with zero epochs.
    results = None

    total_step = 0
    # The full-dataset loader never changes, so it is built once (lazily) and
    # reused; it still reshuffles every time it is iterated.
    full_train_loader = None

    for epoch in range(num_epochs):
        total_loss = 0
        train_score = 0

        if use_curriculum:
            current_percent = _curriculum_fraction(
                epoch, num_epochs, curriculum_start, curriculum_end, curriculum_pacing)
            num_samples_this_epoch = int(len(train_dset) * current_percent)
            # Never schedule an empty epoch: it would yield no batches and
            # divide by zero when the epoch statistics are normalised.
            num_samples_this_epoch = min(len(train_dset), max(1, num_samples_this_epoch))
            current_dataset_size = num_samples_this_epoch

            logger.write(f"Curriculum Epoch {epoch+1}/{num_epochs}: Using {num_samples_this_epoch}/{len(train_dset)} samples ({current_percent*100:.1f}%)")

            # Take the leading samples in dataset order, then shuffle within
            # that subset while training.
            subset_indices = list(range(num_samples_this_epoch))
            current_train_loader = DataLoader(
                Subset(train_dset, subset_indices), batch_size,
                shuffle=True, num_workers=num_workers)
        else:
            if full_train_loader is None:
                full_train_loader = DataLoader(
                    train_dset, batch_size, shuffle=True, num_workers=num_workers)
            current_train_loader = full_train_loader
            current_dataset_size = len(train_dset)

        t = time.time()

        for v, q, a, b in tqdm(current_train_loader, ncols=100,
                               desc="Epoch %d" % (epoch+1), total=len(current_train_loader)):
            total_step += 1
            v = v.cuda()
            q = q.cuda()
            a = a.cuda()
            b = b.cuda()

            pred, loss = model(v, None, q, a, b)

            # NaN is the only value that compares unequal to itself.
            if (loss != loss).any():
                raise ValueError("NaN loss")
            loss.backward()
            # NOTE: gradient clipping (clip_grad_norm_ with max norm 0.25) is
            # currently disabled.
            optim.step()
            optim.zero_grad()

            batch_score = compute_score_with_logits(pred, a.data).sum()
            # Weight by batch size so the epoch average is per-sample even
            # when the last batch is smaller.
            total_loss += loss.item() * v.size(0)
            train_score += batch_score

        # Normalise by the number of samples actually used in this epoch.
        total_loss /= current_dataset_size
        train_score = 100 * train_score / current_dataset_size

        run_eval = eval_each_epoch or (epoch == num_epochs - 1)

        if run_eval:
            model.train(False)
            results = evaluate(model, eval_dset)
            results["epoch"] = epoch+1
            results["step"] = total_step
            results["train_loss"] = total_loss
            results["train_score"] = train_score
            all_results.append(results)

            # Rewrite the whole history each time so the file is always
            # complete; scores may be tensors, hence the custom default hook.
            with open(join(output, "results.json"), "w") as f:
                json.dump(all_results, f, default=_tensor_to_json, indent=2)

            model.train(True)
            eval_score = results["score"]
            bound = results["upper_bound"]

        # The reported time includes the evaluation pass, if one ran.
        logger.write('epoch %d, time: %.2f' % (epoch+1, time.time()-t))
        logger.write('\ttrain_loss: %.2f, score: %.2f' % (total_loss, train_score))

        if run_eval:
            logger.write('\teval score: %.2f (%.2f)' % (100 * eval_score, 100 * bound))

    model_path = os.path.join(output, 'model.pth')
    torch.save(model.state_dict(), model_path)

    return results


def evaluate(model, eval_dset, batch_size=512):
    """Score ``model`` on ``eval_dset`` without computing gradients.

    The model is switched to evaluation mode and is left in that mode; the
    caller is responsible for switching back to training mode.

    Args:
        model: module called as ``model(v, None, q, None, None)`` and
            returning ``(pred, _)``. Inputs are moved to CUDA before the call.
        eval_dset: dataset yielding ``(v, q, a, b)`` items; ``b`` is not used
            during evaluation.
        batch_size: number of examples per evaluation batch.

    Returns:
        A dict with two entries, both averaged over the dataset size:
        ``score`` (the label score of the top prediction) and ``upper_bound``
        (the maximum label score available per example).
    """
    dataloader = DataLoader(eval_dset, batch_size, shuffle=False, num_workers=32)

    score = 0
    upper_bound = 0

    model.eval()  # disables dropout and similar train-only behaviour

    with torch.no_grad():
        for v, q, a, _bias in tqdm(dataloader, ncols=100, total=len(dataloader), desc="eval"):
            v = v.cuda()
            q = q.cuda()
            a = a.cuda()

            pred, _ = model(v, None, q, None, None)

            # Only the running totals are kept; per-batch logits are not
            # stored, which avoids copying every prediction to host memory.
            score += compute_score_with_logits(pred, a).sum()
            upper_bound += (a.max(1)[0]).sum()

    num_examples = len(dataloader.dataset)
    results = dict(
        score=score / num_examples,
        upper_bound=upper_bound / num_examples,
    )
    return results