import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import optuna
import copy
import json
import argparse
import os
from torch.utils.data import Dataset
import random

from next_char_prediction_utils.char_data_utils import load_dataset_and_filter, CharacterProcessor
from next_char_prediction_utils.token_data_utils import TokenProcessor
from models import CharacterLSTM, CharacterTransformer
import pickle

from cifar10_training import get_optimizer
from utils import set_global_seed
import search_spaces


class LMDataset(Dataset):
    """Sliding-window next-token dataset over a list of texts.

    Each text is encoded once via ``processor.encode(text, add_special_tokens=False)``.
    Windows of ``sequence_length`` ids start every ``stride`` positions, and each
    item is the pair ``(ids[s:s+L], ids[s+1:s+L+1])`` as ``torch.long`` tensors,
    i.e. the target is the input shifted left by one id. A text with fewer than
    ``L + 1`` ids yields exactly one window, right-padded with
    ``processor.char_to_idx['<PAD>']``.
    """

    def __init__(self, texts, processor, sequence_length=50, stride=1):
        self.seq_len = sequence_length
        self.stride = stride
        self.texts = texts
        self.encoded_texts = [processor.encode(text, add_special_tokens=False) for text in texts]
        self.processor = processor

        # Entry i holds the total number of windows contributed by texts[0..i];
        # it maps a global item index back to its source text.
        self.cum_num_seq_in_text = []
        window_span = self.seq_len + 1  # input window plus one extra target id
        running_total = 0
        for ids in self.encoded_texts:
            n_ids = len(ids)
            if n_ids >= window_span:
                n_windows = (n_ids - window_span) // self.stride + 1
            else:
                n_windows = 1  # short text: a single padded window
            running_total += n_windows
            self.cum_num_seq_in_text.append(running_total)

    def __len__(self):
        if not self.cum_num_seq_in_text:
            return 0
        return self.cum_num_seq_in_text[-1]

    def _binSearch_text(self, idx):
        """Return the index of the text that owns global window ``idx``.

        Binary-searches ``cum_num_seq_in_text`` for the first entry >= ``idx + 1``.
        """
        wanted = idx + 1
        lo, hi = 0, len(self.cum_num_seq_in_text)
        while lo < hi:
            mid = lo + (hi - lo) // 2
            value = self.cum_num_seq_in_text[mid]
            if value == wanted:
                return mid
            if value < wanted:
                lo = mid + 1
            else:
                hi = mid
        return lo

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(f"Index {idx} out of range")
        text_idx = self._binSearch_text(idx)
        first_global = 0 if text_idx == 0 else self.cum_num_seq_in_text[text_idx - 1]
        start = (idx - first_global) * self.stride

        ids = self.encoded_texts[text_idx]
        shortfall = self.seq_len + 1 - len(ids)
        if shortfall > 0:
            # Pad a copy (never the stored encoding) so one full window fits.
            ids = ids + [self.processor.char_to_idx['<PAD>']] * shortfall

        inputs = ids[start:start + self.seq_len]
        targets = ids[start + 1:start + self.seq_len + 1]
        if len(inputs) != len(targets):
            print(inputs, targets)
            raise ValueError("Input and target sequences have different lengths.")
        return torch.tensor(inputs, dtype=torch.long), torch.tensor(targets, dtype=torch.long)

def _load_samples(data_path, token_data):
    """Load the per-language text mapping (``{lang: [text, ...]}``) used by get_data.

    An existing file at ``data_path`` is unpickled. If it is missing, token data
    is produced with ``load_dataset_and_filter(data_path)``; character data has no
    fallback.

    Raises:
        ValueError: if ``data_path`` does not exist and ``token_data`` is False.
    """
    if os.path.exists(data_path):
        # NOTE: pickle.load can execute arbitrary code; only point data_path at
        # trusted, locally produced files.
        with open(data_path, 'rb') as handle:
            return pickle.load(handle)
    if token_data:
        return load_dataset_and_filter(data_path)
    raise ValueError(f"Data file not found: {data_path}")


def get_data(data_path, batch_size, seed=42, texts_amount=5000, token_data=False,
             sequence_length=50, stride=None, num_workers=4):
    """Build train/val/test DataLoaders of next-token windows.

    Args:
        data_path: Pickle with a ``{lang: [text, ...]}`` mapping (see _load_samples).
        batch_size: Batch size for all three loaders.
        seed: Seed of the generator used for the 80/10/10 random split.
        texts_amount: Maximum number of texts kept per language.
        token_data: Use ``TokenProcessor`` instead of ``CharacterProcessor``.
        sequence_length: Window length passed to ``LMDataset``.
        stride: Window stride; defaults to ``sequence_length // 4``.
        num_workers: DataLoader worker processes.

    Returns:
        ``(train_loader, val_loader, test_loader, vocab_size)``.
    """
    samples = _load_samples(data_path, token_data)

    # Keep at most `texts_amount` texts per language. random.shuffle draws from
    # the global RNG, so the selection is reproducible only if the caller seeds it.
    for lang in samples:
        random.shuffle(samples[lang])
        samples[lang] = samples[lang][:texts_amount]
    data = [text for texts in samples.values() for text in texts]

    processor = TokenProcessor(data) if token_data else CharacterProcessor(data)
    vocab_size = processor.vocab_size
    if stride is None:
        stride = sequence_length // 4
    full_dataset = LMDataset(texts=data, processor=processor,
                             sequence_length=sequence_length, stride=stride)

    # NOTE(review): windows overlap (stride < sequence_length) and the split is
    # over windows, not texts, so val/test windows share text with train windows.
    n_total = len(full_dataset)
    train_len = int(0.8 * n_total)
    val_len = int(0.1 * n_total)
    test_len = n_total - val_len - train_len
    generator = torch.Generator().manual_seed(seed)
    train_subset, val_subset, test_subset = torch.utils.data.random_split(
        full_dataset, [train_len, val_len, test_len], generator=generator)

    # Subsets only store indices into the shared dataset, which is never mutated,
    # so no copies are needed.
    loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)
    train_loader = DataLoader(train_subset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_subset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_subset, shuffle=False, **loader_kwargs)
    return train_loader, val_loader, test_loader, vocab_size


def get_model(model_name, vocab_size=256):
    """Instantiate the model architecture named ``model_name``.

    Args:
        model_name: ``'char_lstm'`` or ``'char_transformer'``.
        vocab_size: Forwarded to the model constructor as ``vocab_size``.

    Raises:
        ValueError: for any other ``model_name``.
    """
    if model_name not in ('char_lstm', 'char_transformer'):
        raise ValueError(f"Unknown model name: {model_name}")
    model_cls = CharacterLSTM if model_name == 'char_lstm' else CharacterTransformer
    return model_cls(vocab_size=vocab_size)


def evaluate_model(model, loader, vocab_size, desc="Evaluating", device='cpu', pad_idx=0):
    """Compute next-token accuracy and perplexity of ``model`` over ``loader``.

    Args:
        model: Module whose forward returns a sequence whose first element is the logits.
        loader: Iterable of ``(input_seq, target_seq)`` batches.
        vocab_size: Size of the logits' last dimension, used to flatten them.
        desc: Progress-bar label.
        device: Device the batches are moved to.
        pad_idx: Target id excluded from the accuracy count. The default 0
            presumably matches the processor's ``<PAD>`` id; confirm per processor.

    Returns:
        ``(accuracy, perplexity)``. Accuracy is in percent over non-padding targets
        (0 if there are none). Perplexity is ``exp`` of the mean per-batch
        cross-entropy, reported as ``inf`` when that mean is >= 10 or the loader
        yields no batches.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    loss_sum = 0.0
    num_batches = 0
    with torch.no_grad():
        for input_seq, target_seq in tqdm(loader, desc=desc, leave=False):
            input_seq = input_seq.to(device)
            target_seq = target_seq.to(device).reshape(-1)
            mask = target_seq != pad_idx

            logits, *_ = model(input_seq)
            logits = logits.reshape(-1, vocab_size)
            predicted = logits.argmax(dim=1)

            total += mask.sum().item()
            correct += ((predicted == target_seq) & mask).sum().item()
            # Unlike the accuracy, the loss also covers padding targets (no
            # ignore_index), the same way the training criterion is applied.
            loss_sum += criterion(logits, target_seq).item()
            num_batches += 1

    accuracy = 100 * correct / total if total > 0 else 0
    if num_batches == 0:
        return accuracy, float('inf')
    mean_loss = loss_sum / num_batches
    # Very large losses are reported as inf rather than a huge finite number.
    perplexity = float(np.exp(mean_loss)) if mean_loss < 10 else float('inf')
    return accuracy, perplexity


def train(num_epoches, model, train_loader, optimizer, criterion, eval_loader, test_loader=None, device='cpu', trial=None, vocab_size=None, clipping=None, scheduler=None):
    """Train ``model`` for ``num_epoches`` epochs, evaluating after every epoch.

    Args:
        num_epoches: Number of passes over ``train_loader``.
        model: Module whose forward returns a sequence whose first element is the logits.
        train_loader: Iterable of ``(input_seq, target_seq)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss applied to logits flattened to ``(-1, vocab_size)`` and
            flattened targets.
        eval_loader: Validation batches; the accuracy on them is reported to ``trial``.
        test_loader: Optional test batches, evaluated after every epoch.
        device: Device the batches are moved to.
        trial: Optional Optuna trial used for intermediate reports and pruning.
        vocab_size: Size of the logits' last dimension. Inferred from the first
            batch's logits when None.
        clipping: If not None, gradients are clipped with ``clip_grad_norm_``
            (``norm_type='inf'``) to this maximum norm.
        scheduler: Optional LR scheduler, stepped after every optimizer step
            (per batch, not per epoch).

    Returns:
        ``(eval_accuracies, eval_perplexities, test_accuracies, test_perplexities)``
        as per-epoch lists. The test lists stay empty when ``test_loader`` is None.

    Raises:
        optuna.exceptions.TrialPruned: when ``trial`` decides to prune.
    """
    eval_accuracies, eval_perplexities = [], []
    test_accuracies, test_perplexities = [], []
    for epoch in range(num_epoches):
        model.train()  # Set model to training mode
        running_loss = 0.0
        num_batches = 0
        for input_seq, target_seq in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epoches} [Training]"):
            input_seq = input_seq.to(device)
            target_seq = target_seq.to(device)

            optimizer.zero_grad()
            # Tolerate any number of extra model outputs (e.g. a hidden state).
            logits, *_ = model(input_seq)
            if vocab_size is None:
                vocab_size = logits.size(-1)
            loss = criterion(logits.reshape(-1, vocab_size), target_seq.reshape(-1))
            loss.backward()
            if clipping is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), clipping, norm_type='inf')
            optimizer.step()
            if scheduler is not None:
                scheduler.step()
            running_loss += loss.item()
            num_batches += 1

        eval_acc, eval_perplexity = evaluate_model(model, eval_loader, vocab_size=vocab_size, desc=f"Epoch {epoch+1}/{num_epoches} [Evaluating]", device=device)
        test_acc, test_perplexity = -1, -1
        if test_loader is not None:
            test_acc, test_perplexity = evaluate_model(model, test_loader, vocab_size=vocab_size, desc=f"Epoch {epoch+1}/{num_epoches} [Testing]", device=device)
            test_accuracies.append(test_acc)
            test_perplexities.append(test_perplexity)
        if trial is not None:
            trial.report(eval_acc, epoch)
            if trial.should_prune():
                raise optuna.exceptions.TrialPruned()
        eval_accuracies.append(eval_acc)
        eval_perplexities.append(eval_perplexity)
        avg_loss = running_loss / num_batches if num_batches else 0.0
        print(f"Epoch [{epoch+1}/{num_epoches}], Loss: {avg_loss:.4f}, Eval Accuracy: {eval_acc:.2f}%, Test Accuracy: {test_acc:.2f}%")
    return eval_accuracies, eval_perplexities, test_accuracies, test_perplexities


def suggest_params(trial, search_space):
    """Sample one value per tunable entry of ``search_space`` from ``trial``.

    Only dict-valued entries are tunable. Those with ``type == 'float'`` use
    ``trial.suggest_float``; any other type uses ``trial.suggest_int``. Both get
    ``min``/``max`` and an optional ``log`` flag (default False). Non-dict entries
    (fixed values) are skipped.

    Returns:
        dict mapping each tunable parameter name to its sampled value.
    """
    sampled = {}
    for name, spec in search_space.items():
        if not isinstance(spec, dict):
            continue
        suggest = trial.suggest_float if spec['type'] == 'float' else trial.suggest_int
        sampled[name] = suggest(name, spec['min'], spec['max'], log=spec.get('log', False))
    return sampled


def tune(n_trials, data_path, model_name, search_space, optimizer_name, device, num_epoches, n_startup_trials, path, seed=42, token_data=False, texts_amount=5000):
    """Tune the optimizer's hyperparameters with Optuna, then retrain and test the best setting.

    Each trial trains ``model_name`` for ``num_epoches`` epochs and is scored by
    its best validation accuracy. The study is stored in
    ``<path>/study/<optimizer_name>.db`` and resumed if it already exists. Whenever
    a trial beats the best completed value so far, its params and metrics are
    written with ``save_result``. After the study, the best trial's params are used
    to retrain from scratch, and validation and test metrics are collected.

    Returns:
        dict of the best trial's params plus ``val_score``, ``val_perplexity``,
        ``test_score`` and ``test_perplexity`` (all taken at the epoch with the
        highest validation accuracy), with every value converted to float.
    """
    def objective(trial):
        set_global_seed(seed)
        # NOTE(review): trials use batch_size=32, while the final run below uses
        # search_space['batch_size']; confirm that tuned values should transfer.
        train_loader, eval_loader, _, vs = get_data(data_path, batch_size=32, seed=seed,
                                                   texts_amount=texts_amount, token_data=token_data)
        model = get_model(model_name=model_name, vocab_size=vs).to(device)
        n_iters = num_epoches * len(train_loader)
        criterion = nn.CrossEntropyLoss()
        optimizer, (clipping, scheduler) = get_optimizer(optimizer_name, model, search_space, trial=trial, n_iters=n_iters)

        # train() steps the scheduler after every batch, so the cosine period must
        # be the total number of optimizer steps, not the number of epochs.
        # NOTE(review): this replaces any scheduler returned by get_optimizer.
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_iters)

        val_accuracies, val_perplexities, *_ = train(num_epoches=num_epoches, model=model, train_loader=train_loader,
                                                     optimizer=optimizer, criterion=criterion, scheduler=scheduler,
                                                     eval_loader=eval_loader, test_loader=None, device=device,
                                                     trial=trial, clipping=clipping, vocab_size=vs)

        best_val = max(val_accuracies)
        try:
            is_new_best = trial.study.best_value < best_val
        except ValueError:
            # Optuna raises ValueError while no trial has completed yet (the first
            # trial, or every earlier trial failed / was pruned).
            is_new_best = True
        if is_new_best:
            best_epoch = int(np.argmax(val_accuracies))
            result = trial.params | {'val_score': val_accuracies[best_epoch], 'val_perplexity': val_perplexities[best_epoch], 'trial': trial.number}
            save_result(result, path, optimizer_name)

        return best_val

    pruner = optuna.pruners.MedianPruner()
    sampler = optuna.samplers.TPESampler(seed=seed, n_startup_trials=n_startup_trials)

    os.makedirs(os.path.join(path, 'study'), exist_ok=True)
    storage_path = f"sqlite:///{os.path.join(path, 'study', optimizer_name + '.db')}"

    study = optuna.create_study(
        study_name=f"{optimizer_name}",
        storage=storage_path,
        direction="maximize",
        pruner=pruner,
        sampler=sampler,
        load_if_exists=True
    )
    study.optimize(objective, n_trials=n_trials)

    best_trial = study.best_trial

    # Retrain with the best params on the same data selection and split as the trials.
    set_global_seed(seed)
    train_loader, eval_loader, test_loader, vs = get_data(data_path, batch_size=search_space['batch_size'], seed=seed,
                                                          texts_amount=texts_amount, token_data=token_data)
    model = get_model(model_name=model_name, vocab_size=vs).to(device)
    criterion = nn.CrossEntropyLoss()
    n_iters = num_epoches * len(train_loader)
    optimizer, (clipping, scheduler) = get_optimizer(optimizer_name, model, search_space, trial=None, optimizer_params=best_trial.params, n_iters=n_iters)

    # Stepped per batch by train(), hence T_max in optimizer steps.
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_iters)

    val_accuracies, val_perplexities, test_accuracies, test_perplexities = train(num_epoches=num_epoches, model=model, train_loader=train_loader,
                                                                                 optimizer=optimizer, criterion=criterion, scheduler=scheduler,
                                                                                 eval_loader=eval_loader, test_loader=test_loader,
                                                                                 device=device, clipping=clipping, vocab_size=vs)

    best_epoch = int(np.argmax(val_accuracies))
    result = best_trial.params | {
        'val_score': val_accuracies[best_epoch],
        'val_perplexity': val_perplexities[best_epoch],
        'test_score': test_accuracies[best_epoch],
        'test_perplexity': test_perplexities[best_epoch]
    }
    return {key: float(value) for key, value in result.items()}


def get_arguments():
    parser = argparse.ArgumentParser(description="Hyperparameter tuning for multilingual NLI character-level LSTM model.")
    parser.add_argument('--model', type=str, default='char_lstm', choices=['char_lstm', 'char_transformer'],
                        help='Model architecture to use.')
    parser.add_argument('--optimizer', type=str, default='AdamW', choices=['Muon',
                                                                           'SignumDL', 'SignumDLNesterov',
                                                                           'Signum', 'SignumLinearLR',
                                                                           'Signum_decoupled_wd', 'Signum_decoupled_wd_LinearLR',
                                                                           'AdamW', 'Adam', 'AdamEps', 'AdamWBetas',
                                                                           'AdamWClip', 'AdamEpsScheduling', 'AdamBetaScheduling', 'AdamPaLM2', 
                                                                           'SGD', 'SGDLinearLR', 'SGDCosineAnnealingLR', 'SGDClip', 'SGDLinearLR+Clip',
                                                                           'SoftSignumPT', 'SoftSignumPT-auto', 'SoftSignumPT-const',
                                                                           'SoftSignumPT_not_decoupled_wd', 'SoftSignumPT_not_decoupled_wd-auto', 'SoftSignumPT_not_decoupled_wd-const',
                                                                           'SoftSignumSGD', 'SoftSignumSGD-auto', 'SoftSignumSGD-const',
                                                                           'SoftSignumSGD_not_decoupled_wd', 'SoftSignumSGD_not_decoupled_wd-auto', 'SoftSignumSGD_not_decoupled_wd-const',
                                                                           'Signum+SGD', 'Signum+SGD_not_decoupled_wd'],
                        help='Optimizer to use for training.')
    parser.add_argument('--user', type=str, required=True, choices=['user1', 'user2', 'user3'],
                        help='Name of the user folder for saving the results.')
    parser.add_argument('--data_path', type=str, default='./data/raw_texts.pkl',
                        help='Root path to the MoritzLaurer/multilingual-NLI-26lang-2mil7 directory.')
    parser.add_argument('--n_trials', type=int, default=30, help='Number of Optuna trials.')
    parser.add_argument('--n_startup_trials', type=int, default=10, help='Number of Optuna startup trials.')
    parser.add_argument('--max_epochs', type=int, default=20, help='Epochs for training each trial.')
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help='Device to use for training.')
    parser.add_argument('--texts_amount', type=int, default=5000, help='Number of texts per language to use.')
    parser.add_argument('--token_data', type=bool, default=False, help='Type of data.')
    parser.add_argument('--seed', type=int, default=42, help='Global random seed.')
    return parser.parse_args()


def save_result(
    result: dict,
    path: str,
    optimizer_name: str
):
    """Write ``result`` as JSON to ``<path>/<optimizer_name>.json``.

    The directory is created if needed, and an existing file is overwritten.
    """
    os.makedirs(path, exist_ok=True)
    target_file = f'{path}/{optimizer_name}.json'
    with open(target_file, 'w') as out_file:
        json.dump(result, out_file)


if __name__ == '__main__':
    args = get_arguments()

    # Results go under a per-user, per-dataset, per-model directory.
    path = (f'tuning/{args.user}/heavy_tailed_data/{args.model}' if args.token_data
            else f'tuning/{args.user}/multilingual_nli/{args.model}')

    # tune() reads 'batch_size' for its final retraining run. This assignment
    # mutates the dict held by search_spaces.search_spaces_map in place.
    selected_search_space = search_spaces.search_spaces_map[args.optimizer]
    selected_search_space['batch_size'] = 64

    results = tune(
        n_trials=args.n_trials,
        n_startup_trials=args.n_startup_trials,
        num_epoches=args.max_epochs,
        data_path=args.data_path,
        token_data=args.token_data,
        texts_amount=args.texts_amount,
        model_name=args.model,
        optimizer_name=args.optimizer,
        search_space=selected_search_space,
        device=args.device,
        seed=args.seed,
        path=path,
    )

    save_result(results, path, args.optimizer)