import os
import time
import datetime
import json

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

from dataloader import EndomondoDataReader, EndomondoDataset
from ours_dataloader import OursDataReader
from utils import get_device, set_random_seed, parse_arguments
from evaluation import compute_metrics
from models import EndoLSTMModel_Full1, contrastive_loss_func


def train_one_epoch(model, dataloader, criterion, optimizer, device, scheduler=None, global_step=0,
                    contrastive_loss=True, epoch=None):
    """Run one optimisation pass over ``dataloader``.

    Reads the module-level ``args`` namespace, which must already be set (e.g. by
    ``main``), for ``use_sport``, ``contrastive_weight`` and ``clip_grad_norm``.

    Args:
        model: Module called with keyword inputs (``main_input``, ``user_input``,
            ``sport_input``, ``gender_input``, ``device_input``, ``context_in1``,
            ``context_in2``, ``full_context_in1``, ``full_context_in2``, ``epoch``).
            When ``contrastive_loss`` is true it must return ``(preds, embeddings)``;
            otherwise just ``preds``.
        dataloader: Iterable of ``(inputs_dict, targets, workout_id)`` batches.
            ``inputs_dict['input']`` is required; every other key is optional.
        criterion: Loss callable applied as ``criterion(preds, targets)``.
        optimizer: Optimizer stepped once per batch.
        device: Device that batch tensors are moved to.
        scheduler: Optional. If it is a dict containing ``'warmup'``, that entry's
            ``step()`` is called per batch while ``global_step < scheduler['warmup_steps']``.
            Any other scheduler is left for the caller to step.
        global_step: Step counter carried across epochs.
        contrastive_loss: Whether to add contrastive term(s) computed on the embeddings.
        epoch: Current epoch number, forwarded to the model.

    Returns:
        ``(mean_loss, global_step)``. ``mean_loss`` is the sample-weighted mean of the
        per-batch losses. ``global_step`` has been incremented once per batch.
    """
    model.train()
    assert args is not None, "train_one_epoch requires the global `args` set by main()"
    using_sport_ids = args.use_sport
    total_loss = 0.0
    total_samples = 0

    def _long_input(inputs_dict, key):
        # Optional categorical input as a long tensor; squeeze(-1) drops a trailing size-1 dim if present.
        if key not in inputs_dict:
            return None
        return torch.as_tensor(inputs_dict[key], device=device, dtype=torch.long).squeeze(-1)

    def _float_input(inputs_dict, key):
        # Optional continuous input as a float32 tensor on `device`.
        if key not in inputs_dict:
            return None
        return torch.as_tensor(inputs_dict[key], dtype=torch.float32, device=device)

    for inputs_dict, targets, _ in dataloader:
        main_input = torch.as_tensor(inputs_dict['input'], dtype=torch.float32, device=device)
        targets = torch.as_tensor(targets, dtype=torch.float32, device=device)

        user_input = _long_input(inputs_dict, 'user_input')
        sport_input = _long_input(inputs_dict, 'sport_input')
        # The id at the first position of each sample labels the contrastive positives.
        user_ids = user_input[:, 0] if user_input is not None else None
        sport_ids = sport_input[:, 0] if (using_sport_ids and sport_input is not None) else None

        model_inputs = dict(
            main_input=main_input,
            user_input=user_input,
            sport_input=sport_input,
            gender_input=_long_input(inputs_dict, 'gender_input'),
            device_input=_long_input(inputs_dict, 'device_input'),
            context_in1=_float_input(inputs_dict, 'context_input_1'),
            context_in2=_float_input(inputs_dict, 'context_input_2'),
            full_context_in1=_float_input(inputs_dict, 'full_context_input_1'),
            full_context_in2=_float_input(inputs_dict, 'full_context_input_2'),
            epoch=epoch,
        )

        if contrastive_loss:
            preds, embeddings = model(**model_inputs)
            loss = criterion(preds, targets)
            # One contrastive term per available id type, each scaled by args.contrastive_weight.
            for ids in (user_ids, sport_ids):
                if ids is not None:
                    loss = loss + args.contrastive_weight * contrastive_loss_func(embeddings, ids)
        else:
            preds = model(**model_inputs)
            loss = criterion(preds, targets)

        optimizer.zero_grad()
        loss.backward()
        if args.clip_grad_norm:
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
        optimizer.step()

        # Per-batch warmup scheduling (only for the dict form of `scheduler`).
        if scheduler is not None and isinstance(scheduler, dict):
            if 'warmup' in scheduler and global_step < scheduler['warmup_steps']:
                scheduler['warmup'].step()

        batch_size = main_input.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size
        global_step += 1

    # The epsilon avoids division by zero on an empty dataloader.
    mean_loss = total_loss / float(total_samples + 1e-8)
    return mean_loss, global_step


def compute_user_mean_targets(train_loader):
    """Average each user's per-sample mean target over ``train_loader``.

    Every sample's target sequence is first reduced to its mean. Those per-sample
    means are then averaged per user. Batches whose inputs dict lacks a
    ``'user_input'`` key are skipped.

    Args:
        train_loader: Iterable of batches where ``batch[0]`` is the inputs dict and
            ``batch[1]`` the targets tensor. One user id per sample is taken from
            ``batch[0]['user_input'][:, 0, 0]``.

    Returns:
        dict mapping each user id (as produced by ``.numpy()``) to its mean target.
    """
    running_total = {}
    n_seen = {}

    for batch in train_loader:
        inputs_dict = batch[0]
        if 'user_input' not in inputs_dict:
            continue
        user_ids = inputs_dict['user_input'][:, 0, 0].cpu().numpy()
        per_sample_targets = batch[1].cpu().numpy()

        for uid, seq in zip(user_ids, per_sample_targets):
            seq_mean = np.mean(seq)
            if uid not in running_total:
                running_total[uid] = seq_mean
                n_seen[uid] = 1
            else:
                running_total[uid] += seq_mean
                n_seen[uid] += 1

    return {uid: total / n_seen[uid] for uid, total in running_total.items()}


def compute_baseline_mse(data_loader, user_mean_targets):
    """MSE of a constant per-user baseline predictor over ``data_loader``.

    Each sample's target sequence is predicted as the constant
    ``user_mean_targets[user]``. Users missing from the mapping fall back to the
    mean of all per-user values. Batches whose inputs dict lacks ``'user_input'``
    are skipped.

    Args:
        data_loader: Iterable of batches where ``batch[0]`` is the inputs dict and
            ``batch[1]`` the targets tensor. The user id per sample is
            ``batch[0]['user_input'][:, 0, 0]``.
        user_mean_targets: Mapping of user id to baseline value, e.g. from
            ``compute_user_mean_targets``.

    Returns:
        Mean squared error over all target elements, or ``0.0`` if no samples were seen.
    """
    # Fallback for unseen users. It is loop-invariant, so compute it once instead of
    # per sample. NaN when there are no known users, matching np.mean([]) without a
    # RuntimeWarning for every sample.
    fallback = np.mean(list(user_mean_targets.values())) if user_mean_targets else np.nan

    total_squared_error = 0.0
    total_elements = 0
    for batch in data_loader:
        inputs_dict = batch[0]
        if 'user_input' not in inputs_dict:
            continue
        users = inputs_dict['user_input'][:, 0, 0].cpu().numpy()
        targets = batch[1].cpu().numpy()

        for user, target_seq in zip(users, targets):
            baseline_value = user_mean_targets.get(user, fallback)
            baseline_pred = np.full_like(target_seq, fill_value=baseline_value)
            total_squared_error += np.sum((target_seq - baseline_pred) ** 2)
            total_elements += np.size(target_seq)

    return total_squared_error / total_elements if total_elements > 0 else 0.0


@torch.no_grad()
def evaluate_model(model, dataloader, criterion, device, base_dir, save_output=False, contrastive_loss=True,
                   metrics_list=None,
                   return_raw=False):
    """Evaluate ``model`` on ``dataloader`` with gradient tracking disabled.

    Args:
        model: Model run through ``model_forward``. When ``contrastive_loss`` is true
            its output is unpacked as ``(preds, embeddings)``.
        dataloader: Iterable of ``(inputs_dict, targets, workout_id)`` batches.
        criterion: Loss callable applied as ``criterion(preds, targets)``.
        device: Device that tensors are moved to.
        base_dir: Root under which ``output/eval_output`` is created when saving.
        save_output: If true, each workout's id, targets, predictions and the
            ``lat_seq``/``lon_seq`` inputs are written to
            ``<base_dir>/output/eval_output/evaluation_results_wo_aug.json``.
        contrastive_loss: Whether the model returns a ``(preds, embeddings)`` pair.
        metrics_list: If given, predictions and targets are collected and the
            result of ``compute_metrics`` is returned instead of the mean loss.
        return_raw: Only used with ``metrics_list``. Also return the concatenated
            predictions, targets and sport ids.

    Returns:
        With ``metrics_list``: ``compute_metrics(...)``, or
        ``(metric_dict, preds_all, targets_all, sports_all)`` when ``return_raw``.
        Otherwise: the sample-weighted mean ``criterion`` loss.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    results = []
    preds_collector = []
    targets_collector = []
    sport_collector = []  # model_forward appends the first sport id of each sample here

    def _unwrap(values):
        # Collapse single-element inner lists (e.g. from a trailing size-1 dim) to scalars.
        return [val[0] if isinstance(val, list) and len(val) == 1 else val for val in values]

    for batch in dataloader:
        inputs_dict, targets, workout_id = batch
        targets = torch.as_tensor(targets, dtype=torch.float32, device=device)
        output = model_forward(batch, device, model, sport_collector)
        if contrastive_loss:
            preds, _ = output
        else:
            preds = output

        # len() yields the batch dimension without copying the input to the device again.
        batch_size = len(inputs_dict['input'])
        total_samples += batch_size
        # Always accumulate the loss. It used to be skipped when save_output was set,
        # so the evaluation-only path reported a mean loss of 0.
        total_loss += criterion(preds, targets).item() * batch_size

        if save_output:
            for i in range(batch_size):
                results.append({
                    "workout_id": int(workout_id[i]),
                    "target_heart_rate": _unwrap(targets[i].cpu().tolist()),
                    "predicted_heart_rate": _unwrap(preds[i].cpu().tolist()),
                    "latitude_seq": inputs_dict["lat_seq"][i].cpu().tolist(),
                    "longitude_seq": inputs_dict["lon_seq"][i].cpu().tolist(),
                })
        if metrics_list is not None:
            preds_collector.append(preds.detach().cpu().numpy())
            targets_collector.append(targets.detach().cpu().numpy())

    if save_output:
        folder_path = os.path.join(base_dir, "./output/eval_output")
        os.makedirs(folder_path, exist_ok=True)
        with open(os.path.join(folder_path, "evaluation_results_wo_aug.json"), "w") as f:
            json.dump(results, f, indent=4)
        print("Evaluation results saved to evaluation_results_wo_aug.json")

    if metrics_list is not None:
        preds_all = np.concatenate(preds_collector, axis=0)
        targets_all = np.concatenate(targets_collector, axis=0)
        if return_raw:
            sports_all = np.concatenate(sport_collector, axis=0).astype(int) if sport_collector else None
            metric_dict = compute_metrics(preds_all, targets_all, metrics_list)
            return metric_dict, preds_all, targets_all, sports_all
        return compute_metrics(preds_all, targets_all, metrics_list)

    # The epsilon avoids division by zero on an empty dataloader.
    mean_loss = total_loss / (float(total_samples) + 1e-8)
    return mean_loss


def model_forward(batch, device, model, sport_collector):
    """Convert one batch's inputs to tensors on ``device`` and call ``model``.

    Args:
        batch: ``(inputs_dict, targets, workout_id)`` tuple. Only ``inputs_dict``
            is used. ``inputs_dict['input']`` is required; the other keys are optional.
        device: Target device for every tensor.
        model: Callable invoked with keyword arguments ``main_input``, ``user_input``,
            ``sport_input``, ``gender_input``, ``device_input``, ``context_in1``,
            ``context_in2``, ``full_context_in1`` and ``full_context_in2``. Absent
            inputs are passed as ``None``.
        sport_collector: List that receives, for every batch with ``'sport_input'``,
            a numpy int array holding the first sport id of each sample.

    Returns:
        Whatever ``model`` returns.
    """
    inputs_dict, _targets, _workout_ids = batch

    def as_long(key):
        # Categorical input as a long tensor with a trailing size-1 dim squeezed away, or None.
        if key not in inputs_dict:
            return None
        return torch.as_tensor(inputs_dict[key], device=device, dtype=torch.long).squeeze(-1)

    def as_float(key):
        # Continuous input as a float32 tensor, or None when absent.
        if key not in inputs_dict:
            return None
        return torch.as_tensor(inputs_dict[key], dtype=torch.float32, device=device)

    main_input = torch.as_tensor(inputs_dict['input'], dtype=torch.float32, device=device)

    user_input = as_long('user_input')
    sport_input = as_long('sport_input')
    if sport_input is not None:
        sport_collector.append(sport_input[:, 0].cpu().numpy().astype(int))
    gender_input = as_long('gender_input')
    device_input = as_long('device_input')

    context_in1 = as_float('context_input_1')
    context_in2 = as_float('context_input_2')
    full_context_in1 = as_float('full_context_input_1')
    full_context_in2 = as_float('full_context_input_2')

    return model(
        main_input=main_input,
        user_input=user_input,
        sport_input=sport_input,
        gender_input=gender_input,
        device_input=device_input,
        context_in1=context_in1,
        context_in2=context_in2,
        full_context_in1=full_context_in1,
        full_context_in2=full_context_in2
    )


def run_training(
        model,
        train_loader,
        valid_loader,
        test_loader,
        criterion,
        optimizer,
        max_epochs,
        patience,
        device,
        scheduler=None,
        base_dir="model_states",
        contrastive_loss=True
):
    """Train with early stopping on validation loss, then report test and baseline MSE.

    Each epoch runs ``train_one_epoch``, then evaluates on the validation and test
    loaders. Whenever validation loss improves, the weights are saved to
    ``<base_dir>/output/ckpt/model_states/HR_<timestamp>/best_model.pt``. Training
    stops after ``patience`` consecutive epochs without improvement.

    After training, the best checkpoint is reloaded (if one was written) and the
    test loss is printed. Finally, a per-user mean-target baseline MSE is printed
    for the validation and test sets.

    Args:
        scheduler: Optional. A dict may contain ``'plateau'``, which is stepped here
            with the validation loss, and/or ``'warmup'``, which is stepped per batch
            by ``train_one_epoch``. Any non-dict scheduler is stepped here with the
            validation loss.
        base_dir: Root directory for checkpoints.
        contrastive_loss: Forwarded to training/evaluation; whether the model returns
            ``(preds, embeddings)``.
    """
    now_str = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    model_save_dir = os.path.join(base_dir, "./output/ckpt/model_states", f"HR_{now_str}")
    os.makedirs(model_save_dir, exist_ok=True)
    best_model_path = os.path.join(model_save_dir, "best_model.pt")

    # Start at +inf so the first finite validation loss is always checkpointed.
    # The old 200000 sentinel could leave no checkpoint, which crashed the reload below.
    best_valid_loss = float('inf')
    best_epoch = 0
    current_patience = 0
    global_step = 0
    for epoch in range(1, max_epochs + 1):
        print("\n" + "-" * 50)
        print(f"Epoch {epoch}/{max_epochs}")

        start_time = time.time()
        train_loss, global_step = train_one_epoch(model, train_loader, criterion, optimizer, device, scheduler,
                                                  global_step, contrastive_loss,
                                                  epoch=epoch)
        valid_loss = evaluate_model(model, valid_loader, criterion, device, base_dir, contrastive_loss=contrastive_loss)
        test_loss = evaluate_model(model, test_loader, criterion, device, base_dir, contrastive_loss=contrastive_loss)

        if scheduler is not None:
            if isinstance(scheduler, dict):
                # Only the plateau entry is epoch-level. A dict without it (warmup only)
                # previously fell through to dict.step() and raised AttributeError.
                if 'plateau' in scheduler:
                    scheduler['plateau'].step(valid_loss)
            else:
                scheduler.step(valid_loss)

        elapsed = time.time() - start_time

        print(
            f"Train Loss: {train_loss:.4f},  Valid Loss: {valid_loss:.4f},  Test Loss: {test_loss:.4f}, Time: {elapsed:.2f}s")

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            best_epoch = epoch
            current_patience = 0
            torch.save(model.state_dict(), best_model_path)
            print(f"  [*] Best model so far. Saved to {best_model_path}.")
        else:
            current_patience += 1
            if current_patience >= patience:
                print(f"Early stopped at epoch={epoch}, best_valid_loss={best_valid_loss:.4f} at epoch={best_epoch}")
                break

    if os.path.exists(best_model_path):
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        # e.g. max_epochs == 0 or every validation loss was NaN.
        print(f"Warning: no checkpoint was saved to {best_model_path}; evaluating the final weights.")
    test_loss = evaluate_model(model, test_loader, criterion, device, base_dir, contrastive_loss=contrastive_loss)
    print("=" * 50)
    print(f"Best Epoch: {best_epoch},  Best Valid Loss: {best_valid_loss:.4f}")
    print(f"Test Loss: {test_loss:.4f}")

    user_mean_targets = compute_user_mean_targets(train_loader)
    baseline_valid_mse = compute_baseline_mse(valid_loader, user_mean_targets)
    baseline_test_mse = compute_baseline_mse(test_loader, user_mean_targets)
    print(f'Baseline Valid MSE: {baseline_valid_mse:.4f}, Baseline Test MSE: {baseline_test_mse:.4f}')

    print("Done!!!")


# Parsed command-line arguments; None until main() assigns parse_arguments()'s result.
# train_one_epoch reads this as a global (use_sport, contrastive_weight,
# clip_grad_norm) and asserts it is not None, so it must be set before training.
args = None


def _build_data_reader(cli_args):
    """Return the data reader selected by ``cli_args.dataset`` ('fitrec' or 'ours')."""
    if cli_args.dataset == 'fitrec':
        return EndomondoDataReader(
            inputAtts=cli_args.input_attributes.split(","),
            targetAtts=[cli_args.target],
            base_dir=cli_args.base_dir,
            includeUser=True,
            includeSport=True,
            includeGender=True,
            includeDevice=False,
            includeTemporal=cli_args.temporal,
            fn=cli_args.fn,
            scaleVals=True,
            trimmed_workout_len=450,
            scaleTargets=False,
            trainValidTestFN=cli_args.trainValidTestFN,
            zMultiple=5,
            includeFullTemporal=cli_args.full_temporal,
            fullTemporalLength=cli_args.limit_full_temporal_length,
        )
    if cli_args.dataset == 'ours':
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not cli_args.include_device:
            raise ValueError(f"Include device is not true! Check Reason! Args:{cli_args}")
        return OursDataReader(
            base_dir=cli_args.base_dir,
            includeUser=True,
            includeSport=True,
            includeGender=True,
            includeDevice=cli_args.include_device,
            includeTemporal=cli_args.temporal,
            fn=cli_args.fn,
            scaleVals=True,
            trimmed_workout_len=450,
            scaleTargets=False,
            trainValidTestFN=cli_args.trainValidTestFN,
            zMultiple=5,
            includeFullTemporal=cli_args.full_temporal,
            fullTemporalLength=cli_args.limit_full_temporal_length,
        )
    raise ValueError(f'Unsupported dataset: {cli_args.dataset}')


def _metadata_cardinalities(data_reader, dataset):
    """Return ``(num_users, num_sports, num_genders, num_devices)`` from the reader's one-hot maps."""
    if dataset == 'fitrec':
        print("Calculating type of metadata attributes...")
        one_hot = data_reader.oneHotMap
        num_users = len(one_hot['userId']) if "userId" in one_hot else 1
        num_sports = len(one_hot['sport']) if "sport" in one_hot else 1
        num_genders = len(one_hot['gender']) if "gender" in one_hot else 2
        return num_users, num_sports, num_genders, 1
    if dataset == 'ours':
        print("Calculating type of metadata attributes...")
        one_hot = data_reader.oneHotMap
        return (len(one_hot['user_id']), len(one_hot['sport_type']),
                len(one_hot['gender']), len(one_hot['device']))
    raise ValueError(f'Unsupported dataset: {dataset}')


def _load_pretrained(model, pretrain_file, device):
    """Load ``pretrain_file`` into ``model`` non-strictly, reporting any key mismatches."""
    # Guard against an unset path as well as a missing file; os.path.exists(None) raises.
    if pretrain_file and os.path.exists(pretrain_file):
        print(f"Loading pre-trained weights from {pretrain_file}...")
        checkpoint = torch.load(pretrain_file, map_location=device)
        # strict=False tolerates partial checkpoints. Report what was skipped so a
        # mismatched file does not silently leave layers at their initial values.
        incompatible = model.load_state_dict(checkpoint, strict=False)
        missing = getattr(incompatible, 'missing_keys', None)
        unexpected = getattr(incompatible, 'unexpected_keys', None)
        if missing:
            print(f"  Missing keys (left at init): {missing}")
        if unexpected:
            print(f"  Unexpected keys (ignored): {unexpected}")
    else:
        print(f"Warning: Pretrain file {pretrain_file} not found. Training from scratch.")


def main():
    """Parse CLI arguments, build data and model, then either evaluate or train.

    Also assigns the module-level ``args``, which ``train_one_epoch`` reads as a global.
    """
    global args
    args = parse_arguments()
    print("Args dict:" + str(args.__dict__))

    set_random_seed(114)

    if args.device is None or args.device == "auto":
        device = get_device()
    else:
        device = torch.device(args.device)
    print("Using device:", device)

    data_reader = _build_data_reader(args)
    data_reader.preprocess_data()
    num_users, num_sports, num_genders, num_device = _metadata_cardinalities(data_reader, args.dataset)

    train_dataset = EndomondoDataset(data_reader, mode='train')
    valid_dataset = EndomondoDataset(data_reader, mode='valid')
    test_dataset = EndomondoDataset(data_reader, mode='test')

    # Full-temporal batches use the dataset's pad_full_context collate function.
    collate_fn = EndomondoDataset.pad_full_context if args.full_temporal else None
    loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate_fn)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    model = EndoLSTMModel_Full1(
        input_dim=data_reader.input_dim,
        output_dim=data_reader.output_dim,
        num_users=num_users,
        num_sports=num_sports,
        num_genders=num_genders,
        user_dim=args.attribute_dim,
        sport_dim=args.attribute_dim,
        gender_dim=args.attribute_dim,
        hidden_dim=args.hidden_dim,
        includeUser=True,
        includeSport=True,
        includeGender=True,
        num_devices=num_device,
        includeDevice=args.include_device,
        includeTemporal=args.temporal,
        includeFullTemporal=args.full_temporal,
        fullTemporalLength=args.limit_full_temporal_length,
        feature_dropout=args.feature_dropout,
        advanced_feature_dropout=args.advanced_feature_dropout,
        contrastive_loss=args.contrastive_loss,
    ).to(device)

    if args.pretrain:
        print('Pretrain is set')
        _load_pretrained(model, args.pretrain_file, device)

    criterion = nn.MSELoss()
    if args.eval:
        overall_loss = evaluate_model(model, test_loader, criterion, device, args.base_dir, save_output=True,
                                      contrastive_loss=args.contrastive_loss)
        print(f'Overall loss:{overall_loss}')
    else:
        optimizer = optim.RMSprop(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.8, patience=3)
        run_training(
            model=model,
            train_loader=train_loader,
            valid_loader=valid_loader,
            test_loader=test_loader,
            criterion=criterion,
            optimizer=optimizer,
            max_epochs=args.epoch,
            patience=args.patience,
            device=device,
            base_dir=args.base_dir,
            scheduler=scheduler,
            contrastive_loss=args.contrastive_loss
        )


if __name__ == "__main__":
    main()
