# -*- coding: utf-8 -*-
"""Helper utilities for training with the f-VAT framework.

This module provides ``train_SemiMean``, which sets up data augmentation,
builds the Temporal Convolutional Network (TCN) backbone and delegates the
actual training logic to :class:`model.trainer.Model_SemiMean`.

The function is intentionally kept lightweight so that experiments can simply
import and call ``train_SemiMean`` with different datasets or hyperparameters.
"""

from datetime import datetime
import torch
import utils.transforms as transforms
from dataloader.ucr2018 import *
from torch.utils.data.sampler import SubsetRandomSampler
from model.trainer import Model_SemiMean
import numpy as np
from model.TCNmodel import TCN

import os
import pandas as pd

# Percentage grids 10..90 and 0..100 in steps of 10. They are not referenced
# by ``train_SemiMean``; presumably experiment scripts import them — confirm.
percentageArray = list(range(10, 91, 10))
maskedPercentages = list(range(0, 101, 10))

def _collect_transforms(transforms_list, aug_type):
    """Flatten the augmentations selected by ``aug_type`` into one list.

    ``aug_type`` may be a single key of ``transforms_list`` (e.g. ``"G0"``)
    or an iterable of keys (e.g. ``["jitter", "cutout"]``). A bare string is
    treated as one key instead of being iterated character by character.

    Raises
    ------
    KeyError
        If a requested name is not a key of ``transforms_list``.
    """
    names = [aug_type] if isinstance(aug_type, str) else list(aug_type)
    selected = []
    for name in names:
        try:
            selected.extend(transforms_list[name])
        except KeyError:
            raise KeyError(
                f"Unknown aug_type {name!r}; choose from {sorted(transforms_list)}"
            ) from None
    return selected


def _build_tcn(opt):
    """Create one TCN backbone on the GPU from the hyperparameters in ``opt``."""
    return TCN(
        input_size=1,
        output_size=opt.nb_class,
        num_channels=[opt.nhid] * opt.levels,
        kernel_size=opt.ksize,
        dropout=opt.dropout,
    ).cuda()


def train_SemiMean(x_train, y_train, x_val, y_val, x_test, y_test, opt):
    """Run semi-supervised training with the given dataset.

    Parameters
    ----------
    x_train, y_train : np.ndarray
        Training data and labels. Every sample feeds the unlabeled loader; a
        random ``opt.label_ratio`` fraction of them feeds the labeled loader.
    x_val, y_val : np.ndarray
        Validation split.
    x_test, y_test : np.ndarray
        Test split.
    opt : argparse.Namespace
        Options read here: ``K``, ``batch_size``, ``epochs``, ``ckpt_dir``,
        ``aug_type`` (one augmentation key or a list of keys),
        ``label_ratio``, ``model_select``, ``nhid``, ``levels``, ``ksize``,
        ``dropout``, ``nb_class``, ``model_name`` and ``use_flag``. The whole
        namespace is also forwarded to ``Model_SemiMean``.

    Returns
    -------
    tuple
        ``(test_acc, acc_unlabel, best_epoch)`` exactly as returned by
        ``Model_SemiMean.train``. NOTE(review): the precise meaning of each
        value (e.g. which split ``acc_unlabel`` is measured on) is defined by
        the trainer — confirm there.

    Raises
    ------
    KeyError
        If ``opt.aug_type`` names an unknown augmentation.
    ValueError
        If ``opt.model_select`` is not ``"TCN"``.

    Side effects: writes ``backbone_init_*.tar`` and ``backbone_last_*.tar``
    into ``opt.ckpt_dir``, creating that directory if it is missing.
    """

    # Short aliases for frequently used options
    K = opt.K
    batch_size = opt.batch_size
    tot_epochs = opt.epochs
    ckpt_dir = opt.ckpt_dir

    # ------------------------------------------------------------------
    # 1) Define the pool of possible time-series augmentations. The same
    #    selection feeds both the labeled and the unlabeled datasets.
    # ------------------------------------------------------------------
    prob = 0.2  # probability passed as ``p`` to every augmentation below
    raw = transforms.Raw()
    cutout = transforms.Cutout(sigma=0.1, p=prob)
    jitter = transforms.Jitter(sigma=0.2, p=prob)
    scaling = transforms.Scaling(sigma=0.4, p=prob)
    magnitude_warp = transforms.MagnitudeWrap(sigma=0.3, knot=4, p=prob)
    time_warp = transforms.TimeWarp(sigma=0.2, knot=8, p=prob)
    window_slice = transforms.WindowSlice(reduce_ratio=0.8, p=prob)
    window_warp = transforms.WindowWarp(window_ratio=0.3, scales=(0.5, 2), p=prob)

    # Named augmentation groups selectable through ``opt.aug_type``.
    transforms_list = {
        'jitter': [jitter],
        'cutout': [cutout],
        'scaling': [scaling],
        'magnitude_warp': [magnitude_warp],
        'time_warp': [time_warp],
        'window_slice': [window_slice],
        'window_warp': [window_warp],
        'G0': [jitter, magnitude_warp, window_slice],
        'G1': [jitter, time_warp, window_slice],
        'G2': [jitter, time_warp, window_slice, window_warp, cutout],
        'none': [raw]
    }

    transforms_targets = _collect_transforms(transforms_list, opt.aug_type)

    train_transform = transforms.Compose(transforms_targets)
    # The labeled branch additionally converts the augmented sample to a tensor.
    train_transform_label = transforms.Compose(transforms_targets + [transforms.ToTensor()])
    tensor_transform = transforms.ToTensor()

    # ------------------------------------------------------------------
    # 2) Construct the labeled, validation and test datasets, plus the
    #    unlabeled ``MultiUCR2018_Forecast`` dataset (parameterized by ``K``;
    #    presumably K augmented views per sample — confirm in dataloader).
    # ------------------------------------------------------------------
    train_set_labeled = UCR2018(
        data=x_train, targets=y_train, transform=train_transform_label
    )
    val_set = UCR2018(data=x_val, targets=y_val, transform=tensor_transform)
    test_set = UCR2018(data=x_test, targets=y_test, transform=tensor_transform)
    train_set = MultiUCR2018_Forecast(
        data=x_train,
        targets=y_train,
        K=K,
        transform=train_transform,
        totensor_transform=tensor_transform,
    )

    # Sample a random subset of the training indices as the labeled pool.
    # NOTE(review): a very small ``label_ratio`` can make ``partial_size`` 0,
    # leaving the labeled loader empty.
    train_dataset_size = len(train_set_labeled)
    partial_size = int(opt.label_ratio * train_dataset_size)
    train_ids = list(range(train_dataset_size))
    np.random.shuffle(train_ids)
    train_sampler = SubsetRandomSampler(train_ids[:partial_size])

    # Data loaders for labeled/unlabeled/validation/test splits
    train_loader_label = torch.utils.data.DataLoader(
        train_set_labeled, batch_size=batch_size, sampler=train_sampler
    )
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_set, batch_size=batch_size, shuffle=False
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False
    )

    # ------------------------------------------------------------------
    # 3) Instantiate the student model and its identically-configured
    #    ``ema_model`` counterpart. Only the TCN backbone is supported.
    # ------------------------------------------------------------------
    if opt.model_select != "TCN":
        raise ValueError(
            f"Unknown model_select: {opt.model_select}. Choose from [TCN]."
        )
    model = _build_tcn(opt)
    ema_model = _build_tcn(opt)

    # Wrap both backbones in the trainer that owns the training loop.
    trainer = Model_SemiMean(model, ema_model, opt).cuda()

    # Make sure the checkpoint directory exists before the first save.
    if ckpt_dir:
        os.makedirs(ckpt_dir, exist_ok=True)
    torch.save(
        trainer.model.state_dict(),
        f"{ckpt_dir}/backbone_init_{opt.model_name}_{opt.use_flag}.tar",
    )

    # ------------------------------------------------------------------
    # 4) Kick off training; ``Model_SemiMean.train`` receives all four
    #    loaders and runs for ``tot_epochs`` epochs.
    # ------------------------------------------------------------------
    test_acc, acc_unlabel, best_epoch = trainer.train(
        tot_epochs=tot_epochs,
        train_loader=train_loader,
        train_loader_label=train_loader_label,
        val_loader=val_loader,
        test_loader=test_loader,
        opt=opt,
    )

    # Save the final-epoch weights alongside the initial ones.
    torch.save(
        trainer.model.state_dict(),
        f"{ckpt_dir}/backbone_last_{opt.model_name}_{opt.use_flag}.tar",
    )
    return test_acc, acc_unlabel, best_epoch
