
import argparse
import torch
import os
import numpy as np
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
import aiohttp
from collections import defaultdict
import numpy as np
import torch
import wandb
from utils import chunk_by_sentences
from torch.nn.utils.rnn import pad_sequence
from dataloader import EmbeddingDataset, LateChunkingEmbeddingDataset
from utils import set_seed
from model import TextClassifier, SATPool, spectral_token_compression

# Configuration dictionary: maps the short names accepted by the --dataset and
# --model command-line options to hub identifiers and per-model chunk sizes.
config = {
    # Dataset short name -> identifier handed to load_dataset().
    "datasets": {
        "hyperpartisan": "pietrolesci/hyperpartisan_news_detection",
        "20news": "SetFit/20_newsgroups",
        "eurlex": "Muennighoff/multi_eurlex"
    },
    # Model short name -> checkpoint identifier handed to from_pretrained().
    "models": {
        "longformer": "allenai/longformer-base-4096",
        "longt5": "google/long-t5-tglobal-base",
        "longclip": "zer0int/LongCLIP-GmP-ViT-L-14",
        "bert": "bert-base-uncased",
        "roberta": "roberta-base"
    },
    # Model short name -> window length (in tokens) passed to
    # generate_embeddings() as chunk_size.
    "chunk_sizes": {
        "longformer": 4096,
        "longt5": 4096,
        "longclip": 248,
        "bert": 512,
        "roberta": 512
    }
}


# Training loop
def train(args, model, device, loader, optimizer, criterion,satpool=None,tokenizer=None):
    """Run one training epoch and return the mean loss per batch.

    Args:
        args: Namespace providing ``mode``, ``dataset``, ``K``, ``gate`` and
            ``gradient_norm``.
        model: Classifier applied to the pooled embeddings. In ``'lc'`` mode
            it must also expose ``late_chunk_pooling(embedding, spans)``.
        device: Device the batch tensors are moved to.
        loader: Iterable of batches. Each batch is a dict with
            ``'embeddings'`` and ``'labels'`` (plus ``'spans'`` in ``'lc'``
            mode).
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        satpool: Pooling callable applied to the embeddings in ``'sat'`` mode.
        tokenizer: Not used by this function; kept so existing callers that
            pass it keep working.

    Returns:
        Sum of the per-batch losses divided by the number of batches, or
        ``0.0`` when the loader yields no batches.
    """
    model.train()
    total_loss = 0
    num_batches = 0
    for batch in loader:
        embeddings = batch['embeddings'].to(device)  # [B, L, D]

        # Reduce the embeddings according to the selected mode. Any other
        # mode (e.g. 'cls') feeds the embeddings to the model unchanged.
        if args.mode == 'mean':
            embeddings = torch.mean(embeddings, dim=1)
        elif args.mode == 'stc':
            embeddings = spectral_token_compression(embeddings, K=args.K, gate=args.gate)
        elif args.mode == 'sat':
            embeddings = satpool(embeddings)
        elif args.mode == 'lc':
            spans = batch['spans']
            # Pool every document with its own span list, then stack the
            # per-document results back into one batch tensor.
            pooled_outputs = [
                model.late_chunk_pooling(doc_embedding, spans[i])
                for i, doc_embedding in enumerate(embeddings)
            ]
            embeddings = torch.stack(pooled_outputs)

        labels = batch['labels'].to(device)
        outputs = model(embeddings)

        # Give the loss the target dtype/shape it needs explicitly instead of
        # calling it, catching whatever it raises and retrying with long.
        if args.dataset == 'hyperpartisan':
            # Binary target as a float column matching a single-logit output.
            labels = labels.float().unsqueeze(1)
        elif args.dataset == 'eurlex':
            # Multi-label target as floats.
            labels = labels.float()
        else:
            # Class-index targets must be integer (long) for cross-entropy;
            # labels may arrive as floats depending on how they were collated.
            labels = labels.long()

        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        if args.gradient_norm:
            # NOTE(review): only the classifier's parameters are clipped here;
            # satpool parameters (if any) are left unclipped.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / max(num_batches, 1)

# Evaluation loop
def evaluate(args, model, device, loader, criterion, dataset_name,satpool=None):
    """Evaluate the classifier and return ``(mean_loss, accuracy, f1)``.

    Args:
        args: Namespace providing ``mode``, ``dataset``, ``K`` and ``gate``.
        model: Classifier applied to the pooled embeddings. In ``'lc'`` mode
            it must also expose ``late_chunk_pooling(embedding, spans)``.
        device: Device the batch tensors are moved to.
        loader: Iterable of batches. Each batch is a dict with
            ``'embeddings'`` and ``'labels'`` (plus ``'spans'`` in ``'lc'``
            mode).
        criterion: Loss called as ``criterion(outputs, labels)``.
        dataset_name: Dataset short name; selects how predictions are derived
            from the outputs and which F1 averaging is used.
        satpool: Pooling callable applied to the embeddings in ``'sat'`` mode.

    Returns:
        Tuple of the mean per-batch loss, ``accuracy_score`` and ``f1_score``
        (``micro`` for eurlex, ``binary`` for hyperpartisan, ``weighted``
        otherwise).
    """
    model.eval()
    total_loss = 0
    num_batches = 0
    predictions = []
    labels = []
    thresholded = dataset_name in ['hyperpartisan', 'eurlex']
    with torch.no_grad():
        for batch in loader:
            embeddings = batch['embeddings'].to(device)  # [B, L, D]

            # Same pooling dispatch as in training; unknown modes pass the
            # embeddings through unchanged.
            if args.mode == 'mean':
                embeddings = torch.mean(embeddings, dim=1)
            elif args.mode == 'stc':
                embeddings = spectral_token_compression(embeddings, K=args.K, gate=args.gate)
            elif args.mode == 'sat':
                embeddings = satpool(embeddings)
            elif args.mode == 'lc':
                spans = batch['spans']
                # Pool each document with its own span list, then re-stack.
                pooled_outputs = [
                    model.late_chunk_pooling(doc_embedding, spans[i])
                    for i, doc_embedding in enumerate(embeddings)
                ]
                embeddings = torch.stack(pooled_outputs)

            batch_labels = batch['labels'].to(device)
            outputs = model(embeddings)

            # Cast targets explicitly (mirrors train) rather than retrying
            # the loss after an arbitrary exception.
            if args.dataset == 'hyperpartisan':
                batch_labels = batch_labels.float().unsqueeze(1)
            elif args.dataset == 'eurlex':
                batch_labels = batch_labels.float()
            else:
                batch_labels = batch_labels.long()

            loss = criterion(outputs, batch_labels)
            total_loss += loss.item()
            num_batches += 1

            logits = outputs.detach().cpu().numpy()
            if thresholded:
                # The outputs are raw logits (this script trains these
                # datasets with BCEWithLogitsLoss), so a probability of 0.5
                # corresponds to a logit of 0 because sigmoid(0) == 0.5.
                # Thresholding the logit itself at 0.5 would require a
                # probability of roughly 0.62.
                pred = (logits > 0).astype(int)
            else:
                pred = logits.argmax(-1)
            predictions.extend(pred)
            labels.extend(batch_labels.detach().cpu().numpy())

    if dataset_name == 'eurlex':
        average = 'micro'
    elif dataset_name == 'hyperpartisan':
        average = 'binary'
    else:
        average = 'weighted'
    mean_loss = total_loss / max(num_batches, 1)
    return mean_loss, accuracy_score(labels, predictions), f1_score(labels, predictions, average=average)

def chunk_by_sentences(input_text: str, tokenizer: callable):
    """Split *input_text* into sentence spans expressed as token indices.

    A token is treated as a sentence boundary when it is the ``'.'`` token and
    either (a) the next token starts after a character gap (i.e. something
    such as whitespace follows the period), (b) the next token is the
    tokenizer's terminal separator token, or (c) it is the very last token.

    Args:
        input_text: Text to segment.
        tokenizer: Callable tokenizer returning ``'input_ids'`` and
            ``'offset_mapping'`` tensors, and providing
            ``convert_tokens_to_ids``.

    Returns:
        List of ``(start_token, end_token)`` tuples. The first span starts at
        token index 1 and every following span starts at the previous
        boundary token; ``end_token`` is the index of the boundary period.
        Empty when no boundary is found.
    """
    inputs = tokenizer(input_text, return_tensors='pt', return_offsets_mapping=True)
    period_id = tokenizer.convert_tokens_to_ids('.')

    # Resolve the separator from the tokenizer itself so tokenizers whose
    # end marker is not literally '[SEP]' are handled; fall back to the
    # literal lookup when the tokenizer exposes neither attribute.
    sep_id = getattr(tokenizer, 'sep_token_id', None)
    if sep_id is None:
        sep_id = getattr(tokenizer, 'eos_token_id', None)
    if sep_id is None:
        sep_id = tokenizer.convert_tokens_to_ids('[SEP]')

    # Plain Python lists avoid per-element tensor comparisons.
    token_offsets = inputs['offset_mapping'][0].tolist()
    token_ids = inputs['input_ids'][0].tolist()
    last_index = len(token_ids) - 1

    boundaries = []
    for i, token_id in enumerate(token_ids):
        if token_id != period_id:
            continue
        if i == last_index:
            # A period that ends the sequence has no successor to inspect.
            boundaries.append(i)
            continue
        gap = token_offsets[i + 1][0] - token_offsets[i][1]
        if gap > 0 or token_ids[i + 1] == sep_id:
            boundaries.append(i)

    span_starts = [1] + boundaries[:-1]
    return list(zip(span_starts, boundaries))


def generate_embeddings(hf_dataset, tokenizer, model, chunk_size, device,column_name, mode='cls',stride_ratio=0.5, num_labels=None):
    """Encode every document of *hf_dataset* with overlapping token windows.

    Each document is tokenized without truncation and cut into windows of
    ``chunk_size`` tokens (the last window is padded). Consecutive windows
    start ``int(chunk_size * (1 - stride_ratio))`` tokens apart.

    Args:
        hf_dataset: Iterable of samples indexable by column name.
        tokenizer: Tokenizer returning ``input_ids``, ``attention_mask`` and
            ``offset_mapping`` tensors.
        model: Encoder called on each window.
        chunk_size: Window length in tokens.
        device: Device the model and windows are moved to.
        column_name: Sequence whose first item is the text column and second
            item the label column.
        mode: ``'cls'`` stores one vector per document (mean over windows of
            each window's first-position vector from ``last_hidden_state``).
            Any other mode stores one row per token, taken from
            ``hidden_states[-2]`` and averaged over the windows containing
            that token. ``'lc'`` additionally records sentence spans.
        stride_ratio: Fraction of a window by which the step is shortened.
        num_labels: Length of the multi-hot vector built when a sample's
            label is a list. When omitted, a module-level ``num_labels`` is
            used if one exists (the previous implicit behaviour).

    Returns:
        ``(embeddings, labels, spans)`` lists. ``spans`` is only filled in
        ``'lc'`` mode.

    Raises:
        ValueError: A list-valued label is encountered and no label count is
            available.
    """
    if num_labels is None:
        # Backward compatibility: callers that do not pass the label count
        # rely on a module-level variable defined by the script entry point.
        num_labels = globals().get('num_labels')

    embeddings, labels, spans = [], [], []
    model.to(device)
    model.eval()
    # Never let the step collapse to 0 (range() rejects a zero step).
    stride = max(1, int(chunk_size * (1 - stride_ratio)))
    text_key, label_key = column_name[0], column_name[1]

    for sample in tqdm(hf_dataset, desc='processing'):
        text = sample[text_key]
        if mode == 'lc':
            # NOTE(review): these spans index the full token sequence, while
            # rows for tokens with (0, 0) offsets are dropped below - confirm
            # that late chunk pooling accounts for the shift.
            spans.append(chunk_by_sentences(text, tokenizer=tokenizer))

        raw_label = sample[label_key]
        if isinstance(raw_label, list):
            # Multi-label sample: build a multi-hot float vector.
            if num_labels is None:
                raise ValueError(
                    "num_labels is required to encode list-valued labels"
                )
            label = torch.zeros(num_labels, dtype=torch.float)
            if raw_label:
                label[raw_label] = 1.0
        else:
            label = torch.tensor(raw_label, dtype=torch.float)

        tokens = tokenizer(text, return_tensors='pt', truncation=False, return_offsets_mapping=True)
        input_ids = tokens['input_ids'][0]
        attention_mask = tokens['attention_mask'][0]
        offsets = tokens['offset_mapping'][0]

        window_vectors = []                    # 'cls' mode: one vector per window
        token_embeddings = defaultdict(list)   # other modes: token index -> vectors

        with torch.no_grad():
            for start in range(0, len(input_ids), stride):
                end = start + chunk_size
                input_chunk = input_ids[start:end]
                attn_chunk = attention_mask[start:end]
                offset_chunk = offsets[start:end]

                pad_len = chunk_size - len(input_chunk)
                if pad_len > 0:
                    # Right-pad the final window; padded positions get
                    # attention 0 and (0, 0) offsets.
                    input_chunk = torch.nn.functional.pad(input_chunk, (0, pad_len), value=tokenizer.pad_token_id)
                    attn_chunk = torch.nn.functional.pad(attn_chunk, (0, pad_len), value=0)
                    offset_chunk = torch.nn.functional.pad(offset_chunk, (0, 0, 0, pad_len), value=0)

                chunk_input = {
                    'input_ids': input_chunk.unsqueeze(0).to(device),
                    'attention_mask': attn_chunk.unsqueeze(0).to(device),
                }

                if mode == 'cls':
                    # Encode immediately so only one window lives on the
                    # device at a time.
                    out = model(**chunk_input)
                    window_vectors.append(out.last_hidden_state[:, 0, :].cpu())
                    continue

                out = model(**chunk_input, output_hidden_states=True)
                hs = out.hidden_states[-2].squeeze(0).cpu()
                # Skip positions whose offsets are (0, 0): padding added
                # above and any token the tokenizer maps to no characters.
                keep = (offset_chunk != 0).any(dim=1)
                for i in torch.nonzero(keep, as_tuple=True)[0].tolist():
                    token_embeddings[start + i].append(hs[i])

        if mode == 'cls':
            if window_vectors:
                doc_vec = torch.mean(torch.cat(window_vectors, dim=0), dim=0)
            else:
                # No tokens at all: fall back to a zero vector.
                doc_vec = torch.zeros(model.config.hidden_size)
            embeddings.append(doc_vec)
        else:
            # Average the vectors each token received from overlapping
            # windows, keeping tokens in document order.
            rows = [
                torch.stack(token_embeddings[idx], dim=0).mean(dim=0)
                for idx in sorted(token_embeddings)
            ]
            if rows:
                doc_embedding = torch.stack(rows)
            else:
                doc_embedding = torch.zeros((1, model.config.hidden_size))
            embeddings.append(doc_embedding)

        # label is already a fresh float tensor for this sample.
        labels.append(label)
    return embeddings, labels, spans


def token_collate_fn(batch):
    """Collate variable-length token embeddings and their labels.

    Each item of *batch* is a mapping with an ``'embeddings'`` tensor and a
    ``'labels'`` entry. Embeddings are padded along the sequence axis to the
    longest item in the batch. Labels that are tensors with at least one
    dimension are stacked as they are; otherwise they are gathered into a
    single long tensor.

    Returns:
        Dict with the padded ``'embeddings'`` and the batched ``'labels'``.
    """
    # Pad [L_i, D] sequences into one [B, L_max, D] tensor.
    padded = pad_sequence([sample['embeddings'] for sample in batch], batch_first=True)

    targets = [sample['labels'] for sample in batch]
    first = targets[0]
    scalar_targets = not isinstance(first, torch.Tensor) or first.dim() == 0
    if scalar_targets:
        # One class value per sample (e.g. 20news, hyperpartisan).
        label_batch = torch.tensor(targets).long()
    else:
        # One label vector per sample (e.g. Eurlex).
        label_batch = torch.stack(targets)

    return dict(embeddings=padded, labels=label_batch)

def token_collate_fn_latechunking(batch):
    """Collate token embeddings, labels and per-document spans.

    Each item of *batch* is a mapping with ``'embeddings'``, ``'labels'`` and
    ``'spans'`` entries. Embeddings are padded to the longest sequence in the
    batch, labels are batched into one tensor, and the span entries are
    returned as a plain list with one element per item.

    Returns:
        Dict with keys ``'embeddings'``, ``'labels'`` and ``'spans'``.
    """
    sequences = []
    for entry in batch:
        sequences.append(entry['embeddings'])
    # [L_i, D] tensors -> one zero-padded [B, L_max, D] tensor.
    padded_batch = pad_sequence(sequences, batch_first=True)

    raw_labels = [entry['labels'] for entry in batch]
    span_lists = [entry['spans'] for entry in batch]

    head = raw_labels[0]
    is_vector_label = isinstance(head, torch.Tensor) and head.dim() > 0
    # Vector labels (e.g. Eurlex) are stacked; scalar labels
    # (e.g. 20news, hyperpartisan) become one long tensor.
    label_tensor = torch.stack(raw_labels) if is_vector_label else torch.tensor(raw_labels).long()

    collated = {}
    collated['embeddings'] = padded_batch
    collated['labels'] = label_tensor
    collated['spans'] = span_lists
    return collated


if __name__ == '__main__':
    # Script entry point: parse options, build (or load cached) document
    # embeddings, train a classifier on them args.num_exp times and report
    # the mean and standard deviation of accuracy and F1.

    # ---- Command-line options -------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='20news')
    parser.add_argument('--model', default='bert')
    parser.add_argument('--mode', default='cls')
    parser.add_argument('--classifier', type=str, default='mlp')
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--num_exp', type=int, default=1)
    parser.add_argument('--num_epochs', type=int, default=50)
    parser.add_argument('--proj_dim', type=int, default=256)
    parser.add_argument('--d_attn', type=int, default=128)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--K', type=int, default=4, help='Number of frequency bins for STC')
    parser.add_argument('--gate', type=str, default='none', choices=['softmax', 'none', 'learned'], help='Frequency gate type')
    parser.add_argument('--use_norm', action='store_true', help='Use LayerNorm and dropout before classification')
    # The flag enables clipping in train(); the help text now says so.
    parser.add_argument('--gradient_norm', action='store_true', help='Clip the gradient norm to 1.0 during training')
    parser.add_argument('--normalize_positions', action='store_true', help='Normalize positional encoding to [0,1]')
    parser.add_argument('--wandb', action='store_true')

    args = parser.parse_args()
    set_seed(42)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset_name = config['datasets'][args.dataset]
    model_name = config['models'][args.model]
    chunk_size = config['chunk_sizes'][args.model]

    # ---- Dataset ---------------------------------------------------------
    if args.dataset == 'eurlex':
        dataset = load_dataset(dataset_name, 'en', trust_remote_code=True, storage_options={'client_kwargs': {'timeout': aiohttp.ClientTimeout(total=3600)}})
    else:
        dataset = load_dataset(dataset_name)

    # (text column, label column) for the chosen dataset.
    column_name = ('text', 'label')
    if args.dataset == 'hyperpartisan':
        train_data, test_data = dataset['train'].select(range(10000)), dataset['validation'].select(range(5000))
        column_name = ('text', 'hyperpartisan')
    elif args.dataset == 'eurlex':
        train_data, test_data = dataset['train'].select(range(10000)), dataset['test'].select(range(5000))
        column_name = ('text', 'labels')
    else:
        train_data, test_data = dataset['train'], dataset['test']

    # ---- Encoder and tokenizer -------------------------------------------
    if args.model == 'longt5':
        from transformers import LongT5Model
        tokenizer = AutoTokenizer.from_pretrained("google/long-t5-tglobal-base")
        base_model = LongT5Model.from_pretrained("google/long-t5-tglobal-base")
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        base_model = AutoModel.from_pretrained(model_name)
    # Keep only the text/encoder sub-module for these two models.
    if args.model == 'longclip':
        base_model = base_model.text_model
    elif args.model == 'longt5':
        base_model = base_model.encoder

    # ---- Number of output units ------------------------------------------
    if args.dataset == 'eurlex':
        from itertools import chain
        all_labels = set(chain.from_iterable([sample['labels'] for sample in train_data]))
        num_labels = max(all_labels) + 1  # class indices start at 0
    elif args.dataset == 'hyperpartisan':
        num_labels = 1
    else:
        num_labels = len(set(train_data['label']))

    # ---- Embedding cache -------------------------------------------------
    embed_dir = f"embeddings/{args.dataset}/{args.model}"
    os.makedirs(embed_dir, exist_ok=True)
    # 'cls' embeddings are cached separately; all other modes share the
    # token-level cache files.
    suffix = 'cls' if args.mode == 'cls' else 'token'
    train_embed_file = os.path.join(embed_dir, f"train_embeddings_{suffix}.pt")
    train_label_file = os.path.join(embed_dir, f"train_labels_{suffix}.pt")
    test_embed_file = os.path.join(embed_dir, f"test_embeddings_{suffix}.pt")
    test_label_file = os.path.join(embed_dir, f"test_labels_{suffix}.pt")
    train_span_file = os.path.join(embed_dir, "train_spans_lc.pt")
    test_span_file = os.path.join(embed_dir, "test_spans_lc.pt")
    files_exist = all(os.path.exists(f) for f in [train_embed_file, train_label_file, test_embed_file, test_label_file])
    if args.mode == 'lc':
        # Late chunking also needs the cached sentence spans.
        files_exist = files_exist and all(os.path.exists(f) for f in [train_span_file, test_span_file])

    if files_exist:
        print("Loading saved embeddings and data...")
        train_embeds = torch.load(train_embed_file)
        train_labels = torch.load(train_label_file)
        test_embeds = torch.load(test_embed_file)
        test_labels = torch.load(test_label_file)
        if args.mode == 'lc':
            train_spans = torch.load(train_span_file)
            test_spans = torch.load(test_span_file)
    else:
        print("Generating embeddings from scratch...")
        train_embeds, train_labels, train_spans = generate_embeddings(train_data, tokenizer, base_model, chunk_size, device, column_name, mode=args.mode)
        test_embeds, test_labels, test_spans = generate_embeddings(test_data, tokenizer, base_model, chunk_size, device, column_name, mode=args.mode)

        torch.save(train_embeds, train_embed_file)
        torch.save(train_labels, train_label_file)
        torch.save(test_embeds, test_embed_file)
        torch.save(test_labels, test_label_file)
        if args.mode == 'lc':
            torch.save(train_spans, train_span_file)
            torch.save(test_spans, test_span_file)

    # ---- Datasets and loaders --------------------------------------------
    if args.mode == 'lc':
        train_dataset = LateChunkingEmbeddingDataset(train_embeds, train_labels, spans=train_spans)
        test_dataset = LateChunkingEmbeddingDataset(test_embeds, test_labels, spans=test_spans)
        collate_fn = token_collate_fn_latechunking
    else:
        train_dataset = EmbeddingDataset(train_embeds, train_labels)
        test_dataset = EmbeddingDataset(test_embeds, test_labels)
        # 'cls' embeddings have a fixed size, so default collation is used.
        collate_fn = token_collate_fn if args.mode != 'cls' else None

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        collate_fn=collate_fn
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=1 if args.mode == 'lc' else args.batch_size, # Use batch size 1 for lc if model expects it
        shuffle=False,
        collate_fn=collate_fn
    )

    # ---- Run name --------------------------------------------------------
    # Every mode shares the same prefix, so run_name is always defined. The
    # K/gate suffix applies to 'stc' and 'sat'; previously the 'sat' branch
    # could never be reached because 'sat' was matched by an earlier test.
    run_name = f"train_{args.mode}_{args.dataset}_{args.model}_{args.lr}_{args.batch_size}_{args.num_epochs}_{args.proj_dim}"
    if args.mode in ('stc', 'sat'):
        run_name = run_name + f"_K{args.K}_gate{args.gate}"
    if args.mode == 'sat' and args.d_attn != 128:
        run_name = run_name + f"_d_attn{args.d_attn}"

    if args.wandb:
        wandb.init(project="long_text_mlp",name=run_name, config=vars(args))

    result_path = f'results/{run_name}.csv'

    # ---- Classifier input size and loss ----------------------------------
    hidden_size = base_model.config.hidden_size
    if args.mode == 'stc':
        # The classifier input is (2K + 1) times the encoder width in
        # this mode.
        hidden_size = (2 * args.K + 1) * hidden_size

    if args.dataset in ['hyperpartisan', 'eurlex']:
        criterion = nn.BCEWithLogitsLoss()
    else:
        criterion = nn.CrossEntropyLoss()

    # ---- Repeated experiments --------------------------------------------
    results = []
    satpool = None
    for i in range(args.num_exp):
        # Build a fresh pooling module for every experiment so that no
        # trained weights carry over from the previous run.
        satpool = None
        if args.mode == 'sat':
            satpool = SATPool(d_model=base_model.config.hidden_size,
                              K=args.K, d_attn=args.d_attn,
                              normalize_pos=args.normalize_positions).to(device)

        model = TextClassifier(args, hidden_size, num_labels, args.use_norm, args.classifier, args.proj_dim).to(device)

        trainable_params = list(model.parameters())
        if satpool is not None:
            trainable_params += list(satpool.parameters())
        optimizer = torch.optim.Adam(trainable_params, lr=args.lr)

        for _ in tqdm(range(args.num_epochs), desc=f'Training {i+1}/{args.num_exp} Experiment'):
            train(args, model, device, train_loader, optimizer, criterion,satpool,tokenizer)
        _, acc, f1 = evaluate(args, model, device, test_loader, criterion, args.dataset,satpool)
        results.append((acc, f1))

    # ---- Aggregate and report --------------------------------------------
    accuracies = [r[0] for r in results]
    f1_scores = [r[1] for r in results]
    acc_mean, acc_std = np.mean(accuracies), np.std(accuracies)
    f1_mean, f1_std = np.mean(f1_scores), np.std(f1_scores)

    print(f"Accuracy: {acc_mean:.4f} ± {acc_std:.4f}")
    print(f"F1 Score: {f1_mean:.4f} ± {f1_std:.4f}")

    os.makedirs("results", exist_ok=True)

    # Dictionary of results
    results = {
        'accuracy': acc_mean,
        'accuracy_std': acc_std,
        'f1': f1_mean,
        'f1_std': f1_std
    }
    if args.wandb:
        wandb.log(results)
