import random
import numpy as np
import os, json
import yaml
from argparse import ArgumentParser
from tqdm import tqdm
import multiprocessing as mp

import torch
from torch.utils.data import DataLoader

from train_utils.helper import construct_agent_sim

from train_utils.dataset import SimData, sample_data
from train_utils.bandit import LinearBandit, QuadBandit, LogisticBandit
from train_utils.bandit import QuadFormBandit, load_A, generate_and_save_A
from train_utils.bandit import DistBandit


try:
    import wandb
except ImportError:
    wandb = None

# Directory that receives every CSV produced by this script.
out_dir = "./data/linear"
# CSV holding the per-step regret of the most recent experiment; the
# aggregation code in the __main__ section reads it back after each run.
base_file = os.path.join(out_dir, "LinES.csv")
# Path `run` writes its regret history to. Despite the name it is a file, not
# a directory; it is derived from `base_file` so that the writer and the
# reader can never point at different files.
file_dir = base_file

def run(config, args):
    """Run one simulated bandit experiment and save its per-step regret.

    The function builds a reward model and an agent from ``config``, plays
    ``config['T']`` rounds, and writes the instantaneous regret of every round
    to the module-level path ``file_dir`` as a one-column CSV with the header
    ``instant_regret``.

    Args:
        config: Mapping of experiment settings. Keys read here are
            ``datapath``, ``sigma``, ``T``, ``func`` and ``num_iter``;
            ``index``, ``num_data`` and ``group`` are optional, and
            ``project`` is read only when logging is enabled. The mapping is
            also forwarded to ``construct_agent_sim``.
        args: Parsed command-line namespace; ``args.log`` and ``args.cpu``
            are read.

    Raises:
        ValueError: If ``config['func']`` names an unknown reward model.
    """
    seed = random.randint(1, 100000)
    print(f'Random seed: {seed}')
    # NOTE: only torch's global RNG is seeded here; the `random` and NumPy
    # generators keep whatever state they already had.
    torch.manual_seed(seed)

    log_to_wandb = bool(args.log and wandb)
    wandb_run = None
    if log_to_wandb:
        group = config['group'] if 'group' in config else None
        wandb_run = wandb.init(
            project=config['project'],
            group=group,
            config=config)
        # From here on the settings are read through wandb's config object.
        config = wandb.config

    use_cuda = (not args.cpu) and torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    # Load onto the CPU first so that a dataset saved from a GPU process can
    # still be opened with --cpu or on a machine without CUDA; the only tensor
    # used from it is moved to the target device explicitly.
    data = torch.load(config['datapath'], map_location='cpu')
    theta = data['theta'].to(device)
    sigma = config['sigma']
    T = config['T']

    # Create the context stream from the dataset.
    index = config['index'] if 'index' in config else 0
    num_data = config['num_data'] if 'num_data' in config else None
    dataset = SimData(config['datapath'], num_data=num_data, index=index)
    loader = sample_data(DataLoader(dataset, shuffle=False))

    # Reward models that are parameterised by (theta, sigma).
    theta_bandits = {
        'linear': LinearBandit,
        'quad': QuadBandit,
        'logistic': LogisticBandit,
        'dist': DistBandit,
    }
    func = config['func']
    if func in theta_bandits:
        bandit = theta_bandits[func](theta=theta, sigma=sigma)
    elif func == 'quadform':
        # The matrix is expected to already exist on disk; it was presumably
        # created with generate_and_save_A(dim_context=20,
        # out_path="data/A.pt", seed=42) -- TODO confirm.
        # NOTE(review): A is loaded on the CPU while contexts live on
        # `device`; verify QuadFormBandit handles this when running on CUDA.
        A = load_A("data/A.pt", map_location="cpu")
        bandit = QuadFormBandit(A, sigma=sigma)
    else:
        raise ValueError('Reward model not defined!')
    print(config)

    # ------------- construct strategy --------------------
    agent = construct_agent_sim(config, device)
    # ---------------------------------------------------
    pbar = tqdm(range(T), dynamic_ncols=True, smoothing=0.1)

    regret_history = []
    accum_regret = 0
    for step in pbar:
        # presumably [0] strips the batch dimension added by the DataLoader --
        # confirm against sample_data.
        context = next(loader)[0].to(device)
        arm_to_pull = agent.choose_arm(context)
        reward, regret = bandit.get_reward(context, arm_to_pull)
        agent.receive_reward(arm_to_pull, context[arm_to_pull], reward)
        # Early rounds use fewer update iterations than the configured cap.
        agent.update_model(num_iter=min(step + 1, config['num_iter']))

        # Convert to a Python float once per step instead of twice.
        instant_regret = regret.item()
        regret_history.append(instant_regret)
        accum_regret += instant_regret

        pbar.set_description(f'Accumulative regret: {accum_regret}')
        if log_to_wandb:
            wandb.log({'Regret': accum_regret})

    if log_to_wandb:
        wandb_run.finish()
    print('Done!')
    print("Writing regret history into file")
    print(file_dir)
    # Make sure the target directory exists so a finished experiment is not
    # lost to a FileNotFoundError when `run` is called on its own.
    os.makedirs(os.path.dirname(file_dir) or '.', exist_ok=True)
    np.savetxt(file_dir,
               np.asarray(regret_history, dtype=float),
               delimiter=",", header="instant_regret", comments="")
    print("Finished writing data!")


def _is_numeric_row(line, delimiter):
    """Return True if every field of *line* parses as a float.

    ``delimiter=None`` splits on arbitrary whitespace. An empty line counts
    as non-numeric.
    """
    fields = [field.strip() for field in line.strip().split(delimiter)]
    if not fields:
        return False
    try:
        for field in fields:
            float(field)
    except ValueError:
        return False
    return True


def _load_numeric_table(fpath):
    """Load a numeric table from a text/CSV file as a 2-D float array.

    The delimiter (comma, tab, or whitespace) is guessed from the first line.
    The first line is skipped only when it does not parse as numbers, so a
    header such as ``instant_regret`` is never read as a NaN data row and a
    header-less file never loses its first data row.

    Args:
        fpath: Path of the file to read.

    Returns:
        Array of shape ``(rows, columns)``. When the file mixes numeric and
        non-numeric columns, only the numeric ones are returned.

    Raises:
        ValueError: If no numeric data can be extracted from the file.
    """
    with open(fpath, 'r', encoding='utf-8', errors='ignore') as fh:
        first_line = fh.readline()

    if ',' in first_line:
        delimiter = ','
    elif '\t' in first_line:
        delimiter = '\t'
    else:
        delimiter = None  # whitespace-separated or single column

    header_rows = 0 if _is_numeric_row(first_line, delimiter) else 1
    try:
        # ndmin=2 keeps single-column and single-row files two-dimensional.
        return np.loadtxt(fpath, delimiter=delimiter,
                          skiprows=header_rows, ndmin=2)
    except ValueError:
        # Fall through: the body probably contains non-numeric columns.
        pass

    table = np.genfromtxt(fpath, delimiter=delimiter, names=True,
                          dtype=None, encoding='utf-8')
    names = getattr(table.dtype, "names", None)
    if names:
        numeric = [name for name in names
                   if np.issubdtype(table[name].dtype, np.number)]
        if not numeric:
            raise ValueError(f"No numeric columns found in {fpath}.")
        return np.column_stack([table[name] for name in numeric])
    raise ValueError(f"Could not parse numeric data from {fpath}.")


def _save_run_copy(path, arr):
    """Write the 2-D array *arr* to *path* with a leading 1-based step column."""
    steps = np.arange(1, arr.shape[0] + 1).reshape(-1, 1)
    header = "step," + ",".join(f"col{j+1}" for j in range(arr.shape[1]))
    np.savetxt(path, np.column_stack((steps, arr)),
               delimiter=",", header=header, comments="")


def _summarize_runs(runs):
    """Return ``(mean, variance)`` across runs, element-wise.

    All arrays in *runs* must be 2-D. They are truncated to the shortest run
    and the smallest column count before stacking. The variance is the sample
    variance (``ddof=1``), or all zeros when there is only one run.
    """
    n_steps = min(r.shape[0] for r in runs)
    if any(r.shape[0] != n_steps for r in runs):
        print(f"Note: run lengths differ; truncating all to length {n_steps}.")
    n_cols = min(r.shape[1] for r in runs)
    if any(r.shape[1] != n_cols for r in runs):
        print(f"Note: number of columns differ; truncating all to {n_cols} columns.")

    stacked = np.stack([r[:n_steps, :n_cols] for r in runs], axis=0)
    mean_curve = stacked.mean(axis=0)
    if stacked.shape[0] > 1:
        var_curve = stacked.var(axis=0, ddof=1)
    else:
        var_curve = np.zeros_like(mean_curve)
    return mean_curve, var_curve


def main():
    """Run the configured experiment several times and aggregate the regret.

    After every call to ``run`` the CSV at ``base_file`` is read back, copied
    to a per-run file, and kept in memory. Once all runs are done the
    per-step mean and variance are written next to it in ``out_dir``.
    """
    torch.backends.cudnn.benchmark = True
    parser = ArgumentParser(description="basic paser for bandit problem")
    parser.add_argument('--config_path', type=str,
                        default='sweep/sweep-default.yaml')
    parser.add_argument('--log', action='store_true', default=False)
    parser.add_argument('--cpu', action='store_true', default=False)
    parser.add_argument('--repeat', type=int, default=1)
    parser.add_argument('--n_exp', type=int, default=None,
                        help='Number of experiments to repeat; overrides --repeat if given')
    args = parser.parse_args()

    # NOTE(review): FullLoader is acceptable for trusted local configs; prefer
    # yaml.safe_load if config files can come from untrusted sources.
    with open(args.config_path, 'r') as stream:
        config = yaml.load(stream, yaml.FullLoader)

    # Determine number of experiments; anything below one is clamped to one.
    n_exp = args.n_exp if args.n_exp is not None else args.repeat
    n_exp = max(n_exp, 1)

    os.makedirs(out_dir, exist_ok=True)

    base_name = os.path.splitext(os.path.basename(base_file))[0]
    per_run_files = []
    all_runs = []

    for i in range(n_exp):
        print(f"=== Experiment {i+1}/{n_exp} ===")
        run(config, args)
        if not os.path.exists(base_file):
            print(f"Warning: expected output {base_file} not found after run {i+1}. Skipping aggregation for this run.")
            continue

        # Best effort: a run whose output cannot be parsed is reported and
        # left out of the aggregate instead of aborting the remaining runs.
        try:
            arr = _load_numeric_table(base_file)
        except Exception as e:
            print(f"Error reading {base_file}: {e}")
            continue

        # Save a copy per run, because the next run overwrites base_file.
        per_file = os.path.join(out_dir, f"{base_name}_run_{i+1}.csv")
        try:
            _save_run_copy(per_file, arr)
            per_run_files.append(per_file)
        except OSError as e:
            print("Warning: could not save per-run file:", e)
        all_runs.append(arr)

    if not all_runs:
        print("No runs produced outputs; nothing to aggregate.")
        return

    mean_curve, var_curve = _summarize_runs(all_runs)

    mean_file = os.path.join(out_dir, f"{base_name}_mean.csv")
    var_file = os.path.join(out_dir, f"{base_name}_var.csv")
    combo_file = os.path.join(out_dir, f"{base_name}_mean_var.csv")

    # Combined layout: step, col1_mean, col1_var, col2_mean, col2_var, ...
    n_steps, n_cols = mean_curve.shape
    combo = np.empty((n_steps, 1 + 2 * n_cols))
    combo[:, 0] = np.arange(1, n_steps + 1)
    combo[:, 1::2] = mean_curve
    combo[:, 2::2] = var_curve
    header_parts = ["step"]
    for c in range(n_cols):
        header_parts.append(f"col{c+1}_mean")
        header_parts.append(f"col{c+1}_var")
    header = ",".join(header_parts)

    np.savetxt(mean_file, mean_curve, delimiter=",")
    np.savetxt(var_file, var_curve, delimiter=",")
    np.savetxt(combo_file, combo, delimiter=",", header=header, comments="")

    print("Saved:")
    print("  Per-run copies:", per_run_files)
    print("  Mean:", mean_file)
    print("  Variance:", var_file)
    print("  Combined:", combo_file)


if __name__ == '__main__':
    main()

