# Adapted from the MAE implementation by xxx
# --------------------------------------------------------
# References:
# MAE:  https://github.com/facebookresearch/mae
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------

import math
import numpy as np
import sys
from typing import Iterable, Optional
import re
import torch

from timm.data import Mixup
from timm.utils import accuracy
from timm.data.mixup import mixup_target

import util.misc as misc
import util.lr_sched as lr_sched
from util.regression_utils import create_optimizer
import os, json
import wandb
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast

def pearson_per_trial(y_pred,
                      y_true,
                      eps=1e-8):
    """
    Compute the Pearson correlation of every trial and return the batch mean.

    For batched input the correlation is taken along dim=1 (one r per row of a
    (B, N) tensor). A single 1-D trial of shape (N,) is also accepted and is
    correlated along its only axis.

    Args:
        y_pred: Tensor of shape (B, N), or (N,) for a single trial.
        y_true: Tensor with the same shape as ``y_pred``.
        eps: Small constant added to the denominator. It keeps the division
            finite, so a zero-variance trial yields r = 0 instead of NaN.

    Returns:
        Scalar tensor holding the mean of the per-trial Pearson r values.
        (Only the mean is returned; the per-trial values are not exposed.)
    """
    # Batched tensors reduce over dim=1; a lone 1-D trial has only dim 0.
    dim = 1 if y_pred.dim() > 1 else 0

    # Zero-center each trial (keepdim so the mean broadcasts back).
    pred_z = y_pred - y_pred.mean(dim=dim, keepdim=True)
    true_z = y_true - y_true.mean(dim=dim, keepdim=True)

    # Un-normalised covariance and standard deviations; the 1/N factors
    # cancel in the ratio, so plain sums are sufficient.
    cov = (pred_z * true_z).sum(dim=dim)
    sigma_pred = torch.sqrt((pred_z ** 2).sum(dim=dim))
    sigma_true = torch.sqrt((true_z ** 2).sum(dim=dim))

    # Per-trial Pearson r, then the average across trials.
    r = cov / (sigma_pred * sigma_true + eps)
    return r.mean()


def wandb_log_stats(prefix: str,
                    epoch: int,
                    total_epoch: int,
                    train_stats: dict,
                    test_whole: dict,
                    individual: dict,
                    args,
                    n_parameters: int):
    """
    Log regression statistics for one epoch to Weights & Biases.

    On the final epoch (``epoch == total_epoch - 1``), and only when
    ``test_whole`` carries both "labels" and "preds":
      - scalar R² under ``{prefix}/r2``
      - scatter plot of ground truth vs. prediction
      - line plot of ground truth and prediction against the sample index
      - per-sample table of (index, ground_truth, prediction)

    On every call:
      - a summary table with one "whole" row plus one row per subject
      - flat scalars for the train, whole-test and per-subject metrics

    Args:
        prefix: Namespace prepended to every logged key.
        epoch: Current epoch; also used as the W&B step.
        total_epoch: Total number of epochs, used to detect the final epoch.
        train_stats: Mapping of training metric name -> value.
        test_whole: Whole-test-set metrics; may also hold "preds"/"labels"
            arrays, which are excluded from the scalar/table logging.
            May be empty.
        individual: Mapping of subject name -> metrics dict. May be empty.
        args: Unused here; kept for interface compatibility.
        n_parameters: Number of trainable parameters, logged as a scalar.

    Returns:
        The dict of flat scalars that was logged (without the table and the
        "epoch"/"n_parameters" entries).
    """
    array_keys = ("preds", "labels")

    # --- 1) Final-epoch preds vs. labels logging ---
    y_true_wh = test_whole.get("labels")
    y_pred_wh = test_whole.get("preds")
    is_final = (epoch == total_epoch - 1)
    if is_final and y_true_wh is not None and y_pred_wh is not None:
        y_true_arr = np.asarray(y_true_wh).ravel()
        y_pred_arr = np.asarray(y_pred_wh).ravel()

        # R² is undefined when the labels have no variance; report NaN
        # explicitly rather than dividing by zero.
        ss_res = np.sum((y_true_arr - y_pred_arr) ** 2)
        ss_tot = np.sum((y_true_arr - np.mean(y_true_arr)) ** 2)
        r2 = float(1.0 - ss_res / ss_tot) if ss_tot > 0 else float("nan")
        wandb.log({f"{prefix}/r2": r2}, step=epoch)

        gt_values = y_true_arr.tolist()
        pred_values = y_pred_arr.tolist()

        # Per-sample table (also the data source of the scatter plot).
        sample_table = wandb.Table(columns=["index", "ground_truth", "prediction"])
        for i, (gt, pr) in enumerate(zip(gt_values, pred_values)):
            sample_table.add_data(i, gt, pr)

        scatter = wandb.plot.scatter(
            sample_table,
            x="ground_truth",
            y="prediction",
            title=f"{prefix} GT vs Prediction"
        )

        # wandb.plot.line draws a single y column; two series on a shared
        # x-axis need line_series.
        n_points = min(len(gt_values), len(pred_values))
        line_plot = wandb.plot.line_series(
            xs=list(range(n_points)),
            ys=[gt_values[:n_points], pred_values[:n_points]],
            keys=["ground_truth", "prediction"],
            title=f"{prefix} GT vs Prediction Over Time",
            xname="index",
        )

        wandb.log({
            f"{prefix}_gt_pred_scatter": scatter,
            f"{prefix}_gt_pred_line": line_plot,
            f"{prefix}_preds_labels_table": sample_table
        }, step=epoch)

    # --- 2) Per-epoch summary table ---
    # Columns follow the whole-set metrics; subjects missing a metric get None.
    metric_keys = [k for k in test_whole if k not in array_keys]
    summary_table = wandb.Table(columns=["epoch", "stage", "test_subject"] + metric_keys)
    summary_table.add_data(
        epoch, prefix, "whole",
        *[test_whole[k] for k in metric_keys]
    )
    for subj, stats in individual.items():
        summary_table.add_data(
            epoch, prefix, subj,
            *[stats.get(k) for k in metric_keys]
        )

    # --- 3) Flat scalars ---
    flat = {f"{prefix}/train_{k}": v for k, v in train_stats.items()}
    flat.update({
        f"{prefix}/test_whole_{k}": v
        for k, v in test_whole.items() if k not in array_keys
    })
    for subj, stats in individual.items():
        flat.update({
            f"{prefix}/{subj}_{k}": v
            for k, v in stats.items() if k not in array_keys
        })

    # --- 4) Log scalars + summary table ---
    wandb.log({
        **flat,
        f"{prefix}_results": summary_table,
        "epoch": epoch,
        "n_parameters": n_parameters,
    }, step=epoch)

    return flat


def run_phase(
    model, train_loader, test_loader_whole, test_loaders, prefix, args,
    device, criterion, loss_scaler_class,
    log_writer=None, epochs=None, log_dir=None, mixup_fn=None, wandb_log_freq=5
):
    """
    Train for all epochs, log training stats, then evaluate ONCE at the end.

    Training stats go to TensorBoard every epoch and to W&B every
    ``wandb_log_freq`` epochs (and on the last epoch). After training, the
    model is evaluated on ``test_loader_whole``; per-subject metrics are
    derived by slicing those whole-set predictions, not by re-running the
    model on each subject loader.

    Args:
        model: Model to train; its trainable parameters are counted for logging.
        train_loader: Loader passed to ``train_one_epoch``.
        test_loader_whole: Loader evaluated once after the last epoch.
        test_loaders: A list/tuple of per-subject loaders (or a single loader).
            Per-subject metrics are computed only when there is more than one.
            Each loader's dataset must support ``len()`` and expose
            ``subjectName``.
            NOTE(review): the slicing assumes ``test_loader_whole`` yields the
            subjects' samples unshuffled and in the same order as
            ``test_loaders`` -- confirm against the data pipeline.
        prefix: Tag prefix for TensorBoard / W&B keys and the JSON log name.
        args: Config namespace; ``train_epochs`` and ``clip_grad`` are read
            here, and it is forwarded to the optimizer factory, the training
            loop and the evaluation.
        device: Device forwarded to training and evaluation.
        criterion: Loss callable forwarded to training and evaluation.
        loss_scaler_class: Zero-argument factory for the loss scaler.
        log_writer: Optional TensorBoard-style writer (``add_scalar``).
        epochs: Number of epochs; a falsy value falls back to
            ``args.train_epochs``.
        log_dir: If set (and on the main process), the flat final metrics are
            appended as one JSON line to ``log_<prefix>`` in this directory.
        mixup_fn: Forwarded to ``train_one_epoch``.
        wandb_log_freq: W&B logging period in epochs; non-positive values are
            treated as 1.

    Returns:
        None.
    """
    optimizer = create_optimizer(args, model)
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    loss_scaler = loss_scaler_class()
    epochs = epochs or args.train_epochs

    # A non-positive period would make the modulo below raise.
    log_every = wandb_log_freq if wandb_log_freq > 0 else 1
    # With zero training epochs the final logging still needs a valid step.
    last_epoch = max(epochs - 1, 0)
    total_for_logging = max(epochs, 1)

    last_train_stats = {}

    for epoch in range(epochs):
        train_stats = train_one_epoch(
            model, criterion, train_loader,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad,
            mixup_fn=mixup_fn,
            log_writer=log_writer, args=args, tag_base=prefix
        )
        last_train_stats = train_stats

        # TensorBoard train logging
        if log_writer:
            for k, v in train_stats.items():
                log_writer.add_scalar(f"{prefix}/train_{k}", v, epoch)

        # --- Wandb train logging every N epochs (no test metrics yet) ---
        if epoch % log_every == 0 or epoch == epochs - 1:
            wandb_log_stats(
                prefix=prefix,
                epoch=epoch,
                total_epoch=epochs,
                train_stats=train_stats,
                test_whole={},
                individual={},
                args=args,
                n_parameters=n_parameters,
            )

    # ---- EVALUATION (ONCE) ----
    test_whole = evaluate(test_loader_whole, model, criterion, device, args)
    y_pred = test_whole["preds"]
    y_true = test_whole["labels"]

    # Per-subject metrics: consecutive slices of the whole-set predictions,
    # one slice of len(dataset) samples per subject loader.
    individual = {}
    loaders = list(test_loaders) if isinstance(test_loaders, (list, tuple)) else [test_loaders]
    if len(loaders) > 1:
        offset = 0
        for loader in loaders:
            n = len(loader.dataset)
            subj_name = loader.dataset.subjectName
            p = y_pred[offset:offset + n]
            t = y_true[offset:offset + n]
            mse = float(((t - p) ** 2).mean())
            pearson = float(np.corrcoef(t, p)[0, 1])
            individual[subj_name] = {"mse": mse, "pearson": pearson}
            offset += n
            print(subj_name, mse, pearson)
        if offset != len(y_pred):
            # The slices no longer line up with the whole-set predictions, so
            # the per-subject numbers above are unreliable.
            print(
                f"Warning: per-subject sizes sum to {offset} but the whole test "
                f"set produced {len(y_pred)} predictions; per-subject metrics "
                f"may be misaligned.",
                flush=True,
            )

    # TensorBoard test/eval logging (final epoch only)
    if log_writer:
        for k, v in test_whole.items():
            if k not in ("preds", "labels"):
                log_writer.add_scalar(f"{prefix}/test_whole_{k}", v, last_epoch)
        for subj, stats in individual.items():
            for k, v in stats.items():
                if k not in ("preds", "labels"):
                    log_writer.add_scalar(f"{prefix}/{subj}_test_{k}", v, last_epoch)

    # --- Wandb final test/eval logging ---
    flat = wandb_log_stats(
        prefix=prefix,
        epoch=last_epoch,
        total_epoch=total_for_logging,
        train_stats=last_train_stats,
        test_whole=test_whole,
        individual=individual,
        args=args,
        n_parameters=n_parameters,
    )

    # JSON dump (optional): one line per phase, appended on the main process.
    if log_dir and misc.is_main_process():
        fname = os.path.join(log_dir, f"log_{prefix.replace('/', '_')}")
        with open(fname, "a", encoding="utf-8") as f:
            f.write(json.dumps(flat) + "\n")
                

def train_one_epoch(model: torch.nn.Module, criterion,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    mixup_fn: Optional[Mixup] = None, log_writer=None,
                    args=None, tag_base=None):
    """
    Run one training epoch over sliding windows cut from each long EEG segment.

    Every batch ``(eeg, label[, sensloc])`` is unfolded along time into
    windows of ``downstream_task_t * model_downstream_task_fs`` samples. Each
    window is paired with the label sample at the window's last time step,
    the windows are folded into the batch dimension for the forward pass, and
    the per-window outputs are reshaped back to (B, 1, n_windows) for the loss.

    Args:
        model: Network called as ``model(x)`` or ``model(x, sensloc)``.
        criterion: Callable returning ``(loss, loss_shape, loss_temporal)``.
        data_loader: Yields 2- or 3-element batches; ``eeg`` is 3-D
            (batch, channel, length) and ``label`` is indexed as a 3-D tensor.
        optimizer: Optimizer whose learning rate is adjusted per iteration.
        device: Target device for the batch tensors.
        epoch: Current epoch index (drives the LR schedule and log x-axis).
        loss_scaler: Callable performing backward/clip/step (see call below).
        max_norm: Gradient clipping value forwarded to ``loss_scaler``.
        mixup_fn: Unused; kept for interface compatibility.
        log_writer: Optional TensorBoard-style writer.
        args: Config namespace. Reads ``rank``, ``accum_iter``,
            ``half_precision``, ``downstream_task_t``,
            ``model_downstream_task_fs`` and ``distributed``.
        tag_base: Prefix for the TensorBoard tags.

    Returns:
        Dict mapping each meter name to its global average over the epoch.

    Raises:
        ValueError: If a batch has neither 2 nor 3 elements.
        SystemExit: If the loss becomes non-finite.
    """
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ", rank=args.rank)
    metric_logger.add_meter('lr', misc.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 200

    accum_iter = args.accum_iter
    use_amp = bool(args.half_precision)

    # Window length (in samples) that the model expects.
    model_eeg_L = int(args.downstream_task_t * args.model_downstream_task_fs)
    # Hop between consecutive windows: fs / 24 samples.
    # NOTE(review): an earlier comment here said "32 hz" while the divisor is
    # 24 -- confirm which prediction rate is intended.
    step = int(args.model_downstream_task_fs / 24)

    optimizer.zero_grad()

    if log_writer is not None:
        print('log_dir: {}'.format(log_writer.log_dir))

    for data_iter_step, samples in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
        if len(samples) == 3:
            eeg, label, sensloc = samples[0], samples[1], samples[2]
        elif len(samples) == 2:
            eeg, label = samples[0], samples[1]
            sensloc = None
        else:
            # Otherwise tensors from the previous iteration would be reused.
            raise ValueError(
                "Expected a batch of 2 or 3 elements, got {}".format(len(samples))
            )

        # Single-sample batches are skipped entirely: no forward pass, no
        # optimizer step and no metric update for this iteration.
        if eeg.shape[0] == 1:
            continue

        # we use a per iteration (instead of per epoch) lr scheduler
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)

        eeg = eeg.to(device, non_blocking=True)
        targets = label.to(device, non_blocking=True)
        if sensloc is not None:
            sensloc = sensloc.to(device, non_blocking=True)

        # (B, C, L) -> (B, C, n_windows, window) -> (B, n_windows, C, window)
        B, C, _ = eeg.shape
        windows = eeg.unfold(dimension=2, size=model_eeg_L, step=step).permute(0, 2, 1, 3)
        # Label at the last sample of every window.
        windowed_targets = targets[:, :, model_eeg_L - 1::step]

        # Fold the windows into the batch axis: row index = b * n_windows + w.
        N_window = windows.shape[1]
        x2 = windows.reshape(B * N_window, C, model_eeg_L)
        if sensloc is not None:
            # Repeat each sample's locations n_windows times *consecutively* so
            # row i of sensloc belongs to the same sample as row i of x2.
            # (Tiling the whole batch would pair windows with other samples'
            # locations.) Assumes dim 0 of sensloc is the batch axis.
            sensloc = sensloc.repeat_interleave(N_window, dim=0)

        # autocast(enabled=False) is a no-op, so one code path serves both modes.
        with autocast(enabled=use_amp):
            outputs = model(x2, sensloc) if sensloc is not None else model(x2)
            # (B * n_windows, 1) -> (B, n_windows, 1) -> (B, 1, n_windows)
            y = outputs.reshape(B, N_window, 1).permute(0, 2, 1)
            loss, loss_shape, loss_temporal = criterion(y, windowed_targets)

        # Monitoring-only correlation: no autograd graph needed. Only the
        # singleton channel axis is squeezed so that a batch with a single
        # window keeps its (B, N) layout.
        with torch.no_grad():
            correlation = pearson_per_trial(
                y.detach().float().squeeze(1),
                windowed_targets.float().squeeze(1),
            )

        batch_size = eeg.shape[0]
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            print(eeg.shape, eeg, targets, outputs)
            print("Loss is {}, stopping training".format(loss_value), flush=True)
            sys.exit(1)

        # Only the loss is scaled for gradient accumulation; the logged
        # metrics (loss_value, correlation) stay on their natural scale.
        loss = loss / accum_iter
        is_update_step = (data_iter_step + 1) % accum_iter == 0
        loss_scaler(loss, optimizer, clip_grad=max_norm,
                    parameters=model.parameters(), create_graph=False,
                    update_grad=is_update_step)
        if is_update_step:
            optimizer.zero_grad()

        if args.distributed:
            torch.cuda.synchronize()

        metric_logger.update(loss=loss_value)
        metric_logger.meters['loss_shape'].update(loss_shape.item(), n=batch_size)
        metric_logger.meters['loss_temporal'].update(loss_temporal.item(), n=batch_size)
        metric_logger.meters['correlation'].update(correlation.item(), n=batch_size)

        max_lr = max((group["lr"] for group in optimizer.param_groups), default=0.)
        metric_logger.update(lr=max_lr)

        if args.distributed:
            loss_value_reduce = misc.all_reduce_mean(loss_value)
        else:
            loss_value_reduce = loss_value
        if log_writer is not None and is_update_step:
            # We use epoch_1000x as the x-axis in tensorboard.
            # This calibrates different curves when batch size changes.
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar(tag_base + '/step_loss', loss_value_reduce, epoch_1000x)
            log_writer.add_scalar(tag_base + '/lr', max_lr, epoch_1000x)

    # gather the stats from all processes
    if args.distributed:
        metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger, flush=True)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}



@torch.no_grad()
def evaluate(data_loader, model, criterion, device, args=None):
    """
    Evaluate ``model`` on ``data_loader`` and compute whole-set regression metrics.

    Each batch is fed to the model as-is (no sliding windows are cut here) and
    the prediction is compared against the LAST time step of the label.

    Args:
        data_loader: Yields ``(eeg, label)`` or ``(eeg, label, sensloc)``.
        model: Network called as ``model(eeg)`` or ``model(eeg, sensloc)``.
            Presumably emits one value per sample (e.g. shape (B, 1)) --
            TODO confirm.
        criterion: Unused; kept for interface compatibility.
        device: Target device for the batch tensors.
        args: Optional config; only ``half_precision`` is read. When ``args``
            is None or lacks that attribute, full precision is used.

    Returns:
        Dict with the global average of every meter ('loss_shape',
        'loss_temporal', 'mse', 'pearson') plus 'preds' and 'labels' as
        flat float64 numpy arrays.
        NOTE(review): 'loss_shape' and 'loss_temporal' are registered but
        never updated in this function, so their values depend on how
        ``misc.SmoothedValue`` handles an empty meter.
    """
    model.eval()
    metric_logger = misc.MetricLogger(delimiter="  ")
    for meter_name in ('loss_shape', 'loss_temporal', 'mse', 'pearson'):
        metric_logger.add_meter(meter_name, misc.SmoothedValue(window_size=1, fmt='{value:.4f}'))

    use_amp = bool(getattr(args, "half_precision", False))

    all_targets, all_preds = [], []
    for samples in metric_logger.log_every(data_loader, 100, 'Test:'):
        eeg, label = samples[:2]
        sensloc = samples[2] if len(samples) == 3 else None
        eeg, targets = eeg.to(device), label.to(device)
        if sensloc is not None:
            sensloc = sensloc.to(device)

        # autocast(enabled=False) is a no-op, so one code path serves both modes.
        with autocast(enabled=use_amp):
            outputs = model(eeg, sensloc) if sensloc is not None else model(eeg)

        # Keep the label at the last time step. Index the last axis and
        # flatten instead of a bare squeeze(), which would also drop the
        # batch axis when a batch holds a single sample.
        all_targets.append(targets[..., -1].reshape(-1))
        all_preds.append(outputs.reshape(-1))

    # cat along the batch dimension; .float() so half-precision outputs
    # convert cleanly to numpy.
    y_pred = torch.cat(all_preds, dim=0).float().cpu().numpy().astype(np.float64).reshape(-1)
    y_true = torch.cat(all_targets, dim=0).float().cpu().numpy().astype(np.float64).reshape(-1)

    # compute regression metrics over the whole set
    mse = float(np.mean((y_true - y_pred) ** 2))
    pearson = float(np.corrcoef(y_true, y_pred)[0, 1])
    print(mse, pearson)
    metric_logger.update(mse=mse, pearson=pearson)

    return {
        **{k: m.global_avg for k, m in metric_logger.meters.items()},
        "preds": y_pred,
        "labels": y_true
    }