import os
import torch
import numpy as np
import random
import json
import math
import sys
import re
from collections import defaultdict, namedtuple
from typing import Iterable
import argparse
import time
import datetime
from functools import reduce

from util import dist
from torch.utils.data import DataLoader, DistributedSampler
from dataset import HourHDVC_collate_fn_pred, build_HourHDVC_dataset_pred
from model import build_LOCO_model, _get_tokenizer
from util.misc import adjust_learning_rate
from args import get_args_parser
from util.metrics import MetricLogger
from util.prediction_post_process import json_to_pkl
from dvc_eval import eval_dvc, eval_soda
from model.lora import LoRALinear, set_lora_active_recursive, freeze_base_parameters, unfreeze_base_parameters

def convert_np_scalars(d):
    """Return a copy of ``d`` with NumPy scalars replaced by native Python values.

    Every ``np.generic`` instance (e.g. ``np.float64``, ``np.int64``) found as
    a dictionary value, or anywhere inside nested dictionaries and lists, is
    converted with ``.item()`` so the result can be passed to ``json.dump``.
    Lists are traversed recursively, so scalars inside dictionaries or lists
    that are themselves list elements are converted as well.

    Args:
        d: Dictionary to convert. It is not modified.

    Returns:
        A new dictionary with the same keys. Nested dictionaries and lists are
        rebuilt; every other value (including tuples and ``np.ndarray``
        objects) is passed through unchanged.
    """

    def _convert(value):
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, dict):
            return {key: _convert(item) for key, item in value.items()}
        if isinstance(value, list):
            return [_convert(item) for item in value]
        return value

    return {key: _convert(value) for key, value in d.items()}

def convert_linear_to_lora(module, r=4, lora_alpha=32, lora_dropout=0.1):
    """Swap every ``torch.nn.Linear`` beneath ``module`` for a ``LoRALinear``, in place.

    Each direct child that is a ``torch.nn.Linear`` is replaced by a
    ``LoRALinear`` built with the same ``in_features`` / ``out_features``;
    the original weight (and bias, when the layer has one) is copied into the
    replacement. Any other child is descended into recursively. ``module``
    itself is never replaced, only its descendants.

    Args:
        module: Root module whose children are rewritten in place.
        r: Forwarded to ``LoRALinear`` as ``r``.
        lora_alpha: Forwarded to ``LoRALinear`` as ``lora_alpha``.
        lora_dropout: Forwarded to ``LoRALinear`` as ``lora_dropout``.
    """
    for child_name, submodule in module.named_children():
        if not isinstance(submodule, torch.nn.Linear):
            # Not a linear layer itself: look for linear layers further down.
            convert_linear_to_lora(submodule, r, lora_alpha, lora_dropout)
            continue

        replacement = LoRALinear(
            submodule.in_features,
            submodule.out_features,
            r=r,
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
        )
        # Carry the existing parameters over to the replacement layer.
        replacement.weight.data.copy_(submodule.weight.data)
        if submodule.bias is not None:
            replacement.bias.data.copy_(submodule.bias.data)
        setattr(module, child_name, replacement)

def train_one_epoch(
    model: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    args,
):
    """Train ``model`` for one pass over ``data_loader``.

    Each batch holds several windows (axis 1 of ``batch_dict["video"]``).
    Windows are processed in chunks: the losses of the windows in a chunk are
    summed, back-propagated together and followed by one optimizer step, a
    renormalization of the last ``args.num_bins`` rows of the T5 embedding
    and ``lm_head`` matrices, and a learning-rate update. Non-final windows
    are grouped in chunks of at most ``args.accumulation_steps``; the final
    window always forms a chunk of its own, during which ``model.global_flag``
    is set, the base T5 parameters are frozen and LoRA is activated. For all
    other windows the base parameters are unfrozen and LoRA is deactivated.

    Args:
        model: Model exposing ``t5_model`` plus the memory-token helpers
            called below. Calling it returns a pair whose first element is a
            dict containing a ``"loss"`` entry.
        data_loader: Loader with a sized ``dataset`` attribute, yielding dicts
            with ``"video"``, ``"duration"``, ``"input_tokens"`` and
            ``"output_tokens"``.
        optimizer: Optimizer stepped once per chunk.
        device: Device that per-window tensors are moved to.
        epoch: Epoch index; only used in the log header.
        args: Run configuration. Reads ``seed``, ``epochs``,
            ``accumulation_step_ratio``, ``accumulation_steps``,
            ``print_freq``, ``window_memory``, ``accum``,
            ``train_middle_connect``, ``segment_length``, ``pad_id``,
            ``clip_max_norm`` and ``num_bins``; writes ``is_first_window``
            and, when no window memory is configured, ``mem_size``.

    Returns:
        Dict mapping each ``MetricLogger`` meter name to its ``global_avg``.

    Note:
        Calls ``sys.exit(1)`` when a chunk's average loss is not finite. The
        optimizer-step counter is stored as a function attribute, so it keeps
        growing across epochs within the same process.
    """
    # The seed depends on the rank but not on the epoch.
    # NOTE(review): this restarts the Python/NumPy/torch RNG streams from the
    # same state at the start of every epoch - confirm that is intended.
    seed = args.seed + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    model.train()
    metric_logger = MetricLogger(delimiter="  ")
    header = f"Epoch: [{epoch}]"
    total_samples = len(data_loader.dataset)
    print(f"Total training samples: {total_samples}")

    # Function attribute: counts optimizer steps across all calls so the LR
    # schedule continues from epoch to epoch.
    if not hasattr(train_one_epoch, 'optimizer_step_count'):
        train_one_epoch.optimizer_step_count = 0

    # Note: accumulation_step_ratio is used to correctly calculate total steps for LR scheduler.
    # NOTE(review): this uses the full dataset size, not the number of batches
    # seen by this rank - verify against adjust_learning_rate's expectations.
    num_training_steps = args.epochs * total_samples * args.accumulation_step_ratio

    # Running sum of the final-window loss of every batch (kept on CPU).
    global_loss = torch.tensor(0.0, device='cpu')
    for i_batch, batch_dict in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
        batch_size, num_segments = batch_dict["video"].shape[0], batch_dict["video"].shape[1]

        input_tokens_list = batch_dict.get("input_tokens", None)
        output_tokens_list = batch_dict.get("output_tokens", None)
        
        optimizer.zero_grad()
        model.window_number = num_segments

        s_start = 0
        duration = batch_dict["duration"]
        model.duration = duration[0] if isinstance(duration, list) else duration
        
        # Chunk size for the non-final windows; shrink it when there are fewer
        # non-final windows than accumulation_steps.
        split_point = args.accumulation_steps
        if num_segments > 1 and (num_segments - 1) < args.accumulation_steps:
            split_point = num_segments - 1

        # One iteration per chunk [s_start, s_end); each ends in an optimizer step.
        while s_start < num_segments:
            if 'visual' in args.window_memory or 'speech' in args.window_memory:
                if s_start == 0:
                    # First chunk of the batch: start from fresh memory tokens.
                    args.is_first_window = True
                    model.reset_memory_tokens()
                    if args.accum:
                        model.reset_accm_memory_tokens()
                else:
                    args.is_first_window = False
                    if args.train_middle_connect:
                        # presumably detaches the carried-over memory and then
                        # re-enables gradients on it - confirm in the model code
                        model.gradient_off_memory_tokens()
                        model.gradient_on_memory_tokens()
                    else:
                        model.reset_memory_tokens()
            else:
                args.mem_size=0

            # Non-final chunks stop before the last window, so the last window
            # is always handled alone (s_end == num_segments).
            s_end = num_segments
            if s_start != num_segments -1:
                 s_end = min(s_start + split_point, num_segments - 1)
            
            total_loss = 0.0
            for s in range(s_start, s_end):
                model.window_start_time = s * args.segment_length
                batched_video = batch_dict["video"][:, s].to(device)

                if s == num_segments - 1:
                    # Final window: train only through LoRA with the base T5 frozen.
                    # NOTE(review): global_flag is never set back to False in
                    # this function - confirm the model resets it itself.
                    model.global_flag = True
                    freeze_base_parameters(model.t5_model)
                    set_lora_active_recursive(model.t5_model, True)
                else:
                    unfreeze_base_parameters(model.t5_model)
                    set_lora_active_recursive(model.t5_model, False)
                    
                # Gather window s of every sample; pad positions are masked out.
                input_tokens = torch.stack([input_tokens_list[i][s] for i in range(batch_size)], dim=0).to(device)
                input_tokenized = {
                    'input_ids': input_tokens,
                    'attention_mask': input_tokens != args.pad_id,
                }
                
                output_tokens = torch.stack([output_tokens_list[i][s] for i in range(batch_size)], dim=0).to(device)
                output_tokenized = {
                    'input_ids': output_tokens,
                    'attention_mask': output_tokens != args.pad_id,
                }
                
                main_loss_dict, _ = model(
                    video=batched_video,
                    input_tokenized=input_tokenized,
                    output_tokenized=output_tokenized,
                )
                loss = main_loss_dict["loss"]
                total_loss += loss

                if s == num_segments - 1:
                    global_loss += loss.item()
                
                # Drop per-window tensors and release cached GPU memory before
                # the next window.
                del batched_video, input_tokens, input_tokenized, output_tokens, output_tokenized
                torch.cuda.empty_cache()

            # Single backward pass over the summed losses of the chunk.
            total_loss.backward()
            
            if args.clip_max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_max_norm)
    
            optimizer.step()
            optimizer.zero_grad()
            
            # Rescale the last num_bins rows of the embedding and lm_head
            # matrices so that their mean row norm equals the mean row norm of
            # the remaining rows. (presumably the last rows are the added
            # time-bin tokens - confirm against _get_tokenizer)
            with torch.no_grad():
                frozen_norm = torch.norm(model.t5_model.shared.weight[:-args.num_bins, :], dim=1).mean(0)
                trainable_weight = model.t5_model.shared.weight[-args.num_bins:, :]
                model.t5_model.shared.weight[-args.num_bins:, :].div_(
                    torch.norm(trainable_weight, dim=1).mean(0) / frozen_norm
                )
                frozen_norm = torch.norm(model.t5_model.lm_head.weight[:-args.num_bins, :], dim=1).mean(0)
                trainable_weight = model.t5_model.lm_head.weight[-args.num_bins:, :]
                model.t5_model.lm_head.weight[-args.num_bins:, :].div_(
                    torch.norm(trainable_weight, dim=1).mean(0) / frozen_norm
                )

            # The finiteness check runs after the optimizer step has already
            # been applied for this chunk.
            average_loss = total_loss.item() / (s_end - s_start)
            if not math.isfinite(average_loss):
                print(f"Loss is {average_loss}, stopping training")
                sys.exit(1)

            metric_logger.update(loss=average_loss, lr=optimizer.param_groups[0]["lr"])
            
            # Advance the LR schedule once per optimizer step.
            train_one_epoch.optimizer_step_count += 1
            adjust_learning_rate(
                optimizer,
                curr_step=train_one_epoch.optimizer_step_count,
                num_training_steps=num_training_steps,
                args=args,
            )
            s_start = s_end

    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    # NOTE(review): divides this rank's summed final-window loss by the full
    # dataset size - confirm this is the intended normalization.
    print("Averaged Global Loss:", global_loss / total_samples)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}

def _barrier_if_distributed():
    """Synchronize all ranks; do nothing when no process group is initialized."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()


@torch.no_grad()
def evaluate(
    model: torch.nn.Module,
    data_loader,
    device: torch.device,
    args,
    split="test",
    dataset_name="chapters",
):
    """Generate captions for every video window and score them.

    For each batch, all windows but the last ("local" windows) are captioned
    from the video alone. The last ("global") window is additionally
    conditioned on the concatenated local captions of the same video. Local
    predictions are written to ``<save_dir>/<dataset_name>_<split>_preds.json``
    (and converted to ``.pkl``); global predictions are written to
    ``<save_dir>/mad_eval_global_preds.json``. Both are scored on the main
    process with ``eval_dvc`` and ``eval_soda``.

    Args:
        model: Model exposing ``generate``, ``t5_model``, ``t5_tokenizer`` and
            the memory-token helpers called below.
        data_loader: Loader yielding dicts with ``"video_id"``, ``"video"``,
            ``"duration"`` and ``"duration_segments"``.
        device: Device that inputs are moved to.
        args: Run configuration (decoding options, ``num_bins``,
            ``segment_interval``, ``window_memory``, ``save_dir``, reference
            paths, ...). ``is_first_window`` and, when no window memory is
            configured, ``mem_size`` are written on it.
        split: Name of the split; used in the log header and file names.
        dataset_name: Dataset identifier; only ``"mad_sum"`` has references.

    Returns:
        Dict of metrics. Local metrics are merged across ranks; metrics of the
        global window are prefixed with ``global_`` and are only present on
        the main process.

    Raises:
        ValueError: If ``args.save_dir`` is not set. Predictions must be
            written to disk before they can be scored.
        NotImplementedError: On the main process when ``dataset_name`` is not
            ``"mad_sum"``.
    """
    # Scoring reads predictions from files, so fail before running inference
    # rather than with a NameError on an undefined path afterwards.
    if not args.save_dir:
        raise ValueError("evaluate() requires args.save_dir to store predictions for scoring")

    model.eval()
    metric_logger = MetricLogger(delimiter="  ")
    header = f"{split}:"
    res, e2e_res = {}, {}

    for batch_dict in metric_logger.log_every(data_loader, args.print_freq, header):
        video_ids = batch_dict["video_id"]
        for vid in video_ids:
            res[vid] = []
            e2e_res[vid] = []

        # Captions of the local windows, per video; they feed the global window.
        local_preds_map = defaultdict(list)

        if 'visual' in args.window_memory or 'speech' in args.window_memory:
            model.reset_memory_tokens()
            if args.accum:
                model.reset_accm_memory_tokens()
        else:
            args.mem_size = 0

        num_windows = len(batch_dict["video"][0])
        last_window = num_windows - 1
        model.window_number = num_windows

        for s in range(num_windows):
            is_global_window = s == last_window
            args.is_first_window = s == 0

            if is_global_window:
                # NOTE(review): global_flag is never set back to False here -
                # confirm the model resets it itself.
                model.global_flag = True
                set_lora_active_recursive(model.t5_model, True)
            else:
                set_lora_active_recursive(model.t5_model, False)

            window_start_time = s * args.segment_interval
            model.window_start_time = window_start_time

            duration = batch_dict["duration"]
            model.duration = duration[0] if isinstance(duration, list) else duration

            video = batch_dict["video"][:, s].to(device)
            duration_segments = batch_dict["duration_segments"]

            # 1) Prepare input tokens
            if is_global_window:
                pad_id = model.t5_tokenizer.pad_token_id
                eos_id = torch.LongTensor([model.t5_tokenizer.eos_token_id]).to(device)
                per_video_ids = []
                for vid in video_ids:
                    local_text_str = " ".join(local_preds_map[vid])

                    if not local_text_str.strip():
                        # No local captions: feed a single pad token instead.
                        tokens = model.t5_tokenizer(["<pad>"], add_special_tokens=False, return_tensors="pt")
                    else:
                        tokens = model.t5_tokenizer(
                            local_text_str, add_special_tokens=False, max_length=args.max_input_tokens,
                            truncation=True, return_tensors="pt"
                        )

                    per_video_ids.append(torch.cat([tokens["input_ids"][0].to(device), eos_id], dim=0))

                # Right-pad to a common length so batches with more than one
                # video can be stacked; padding is masked out below.
                input_ids = torch.nn.utils.rnn.pad_sequence(
                    per_video_ids, batch_first=True, padding_value=pad_id
                )
                input_tokenized = {
                    "input_ids": input_ids,
                    "attention_mask": (input_ids != pad_id).long(),
                }
            else:
                # Local window: a single placeholder token per video.
                input_tokens = torch.ones((video.shape[0], 1), dtype=torch.long).to(device)
                input_tokenized = {'input_ids': input_tokens, 'attention_mask': (input_tokens != 0)}

            # 2) Generate predictions
            output = model.generate(
                video=video, input_tokenized=input_tokenized, use_nucleus_sampling=args.num_beams == 0,
                num_beams=args.num_beams, max_length=args.max_output_tokens, min_length=1,
                top_p=args.top_p, repetition_penalty=args.repetition_penalty,
                length_penalty=args.length_penalty, num_captions=1, temperature=1,
            )

            # 3) Parse prediction results: each event is two consecutive
            # "<time=N>" tokens followed by the caption words.
            for i, vid in enumerate(video_ids):
                sequences = re.split(r'(?<!<)\s+(?!>)', output[i])
                indexes = [
                    j for j in range(len(sequences) - 1)
                    if sequences[j][:6] == '<time=' and sequences[j + 1][:6] == '<time='
                ]
                last_processed = -2

                for j, idx in enumerate(indexes):
                    # Skip a pair that overlaps the end token of the previous event.
                    if idx == last_processed + 1:
                        continue
                    text_end = indexes[j + 1] if j < len(indexes) - 1 else len(sequences)
                    seq = [sequences[k] for k in range(idx + 2, text_end) if sequences[k] != '<time=']
                    if not seq:
                        continue
                    text = ' '.join(seq)

                    start_re = re.search(r'\<time\=(\d+)\>', sequences[idx])
                    end_re = re.search(r'\<time\=(\d+)\>', sequences[idx + 1])
                    assert start_re and end_re, f"{sequences[idx]}, {sequences[idx+1]}"

                    # Map time-bin indices to seconds within the current window.
                    st = (float(start_re.group(1)) * float(duration_segments[i][s]) / float(args.num_bins - 1))
                    ed = (float(end_re.group(1)) * float(duration_segments[i][s]) / float(args.num_bins - 1))

                    if ed <= st:
                        continue

                    if not is_global_window:  # Local window
                        res[vid].append({'sentence': text, 'timestamp': [st + window_start_time, ed + window_start_time]})
                        local_preds_map[vid].append(text)
                    else:  # Global (e2e) window
                        e2e_res[vid].append({"sentence": text, "timestamp": [0, model.duration]})
                    last_processed = idx

    # Merge the per-rank prediction dicts.
    results = reduce(lambda a, b: a.update(b) or a, dist.all_gather(res), {})
    e2e_results = reduce(lambda a, b: a.update(b) or a, dist.all_gather(e2e_res), {})

    metrics, global_metrics = {}, {}
    if dist.is_main_process():
        pred_path = os.path.join(args.save_dir, dataset_name + f"_{split}_preds.json")
        e2e_pred_path = os.path.join(args.save_dir, "mad_eval_global_preds.json")
        # Close the files before the scorers read them back.
        with open(pred_path, "w") as f:
            json.dump({'results': results}, f)
        with open(e2e_pred_path, "w") as f:
            json.dump({'results': e2e_results}, f)

        if dataset_name == "mad_sum":
            references_local = [args.mad_val_e2e_local_json_path]
            references_global = [args.mad_val_e2e_global_json_path]
        else:
            raise NotImplementedError

        metrics.update(eval_dvc(pred_path, references_local, tious=[0.1, 0.3, 0.5, 0.7], max_proposals_per_video=1000, verbose=True, no_lang_eval=False))
        metrics.update(eval_soda(pred_path, references_local, verbose=True))

    metrics = convert_np_scalars(metrics)
    metrics = reduce(lambda a, b: a.update(b) or a, dist.all_gather(metrics), {})

    if dist.is_main_process():
        global_metrics.update(eval_dvc(e2e_pred_path, references_global, tious=[0.5], max_proposals_per_video=1000, verbose=True, no_lang_eval=False))
        global_metrics.update(eval_soda(e2e_pred_path, references_global, verbose=True))

        # Convert like the local metrics so the returned dict is JSON-friendly.
        global_metrics = convert_np_scalars(global_metrics)
        metrics.update({f"global_{k}": v for k, v in global_metrics.items()})

    _barrier_if_distributed()

    if dist.is_main_process():
        pred_path = os.path.join(args.save_dir, dataset_name + f"_{split}_preds.json")
        pkl_path = os.path.join(args.save_dir, dataset_name + f"_{split}_preds.pkl")
        json_to_pkl(pred_path, pkl_path)

        # Output final metrics
        for k, v in metrics.items():
            print(f"{k}: {v:.4f}")
        print("\nCIDEr METEOR ROUGE-L Recall Precision F1 Global-Bleu4 Global-Meteor Global-Rouge-L\n")
        print(
            f"{metrics.get('CIDEr', 0) * 100:.2f} ", f"{metrics.get('METEOR', 0) * 100:.2f} ",
            f"{metrics.get('Rouge-L', 0) * 100:.2f} ", f"{metrics.get('Recall', 0) * 100:.2f} ",
            f"{metrics.get('Precision', 0) * 100:.2f} ", f"{metrics.get('F1', 0) * 100:.2f} ",
            f"{metrics.get('global_Bleu_4', 0) * 100:.2f} ", f"{metrics.get('global_METEOR', 0) * 100:.2f} ",
            f"{metrics.get('global_Rouge-L', 0) * 100:.2f}"
        )
    _barrier_if_distributed()
    return metrics

def main(args):
    """Build data loaders and the LoRA-augmented model, then train and/or evaluate.

    For every dataset in ``args.combine_datasets`` a validation loader (also
    used as the test loader) and, unless ``args.eval`` is set, a training
    loader are created. The modules named in ``args.lora_apply`` have their
    linear layers replaced by LoRA layers. When training, validation runs on
    a fixed set of epochs, the checkpoint with the best ``METEOR`` is kept as
    ``best_model.pth`` and reloaded before the final evaluation, whose metrics
    are written to ``<dataset_name>summary.json`` in ``args.save_dir``.

    Args:
        args: Parsed command-line namespace. ``pad_id`` is written on it, and
            ``start_epoch`` is overwritten when resuming from a checkpoint.
    """
    dist.init_distributed_mode(args)

    if dist.is_main_process() and args.save_dir and not os.path.isdir(args.save_dir):
        os.makedirs(args.save_dir, exist_ok=True)

    print(args)
    device = torch.device(args.device)

    # Rank-dependent seed so each process draws different random numbers.
    seed = args.seed + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    tokenizer = _get_tokenizer(args.model_name, args.num_bins)
    args.pad_id = tokenizer.pad_token_id

    nt = namedtuple("data", ["dataset_name", "dataloader_val", "dataloader_train", "dataloader_test"])
    tuples = []

    for dset_name in args.combine_datasets:
        dataset_val = build_HourHDVC_dataset_pred(dset_name, "val", args, tokenizer)
        sampler_val = DistributedSampler(dataset_val, shuffle=False) if args.distributed else torch.utils.data.SequentialSampler(dataset_val)
        dataloader_val = DataLoader(
            dataset_val, batch_size=args.batch_size_val, sampler=sampler_val,
            collate_fn=HourHDVC_collate_fn_pred, num_workers=args.num_workers,
        )
        # The validation split doubles as the test split.
        dataloader_test = dataloader_val

        if not args.eval:
            dataset_train = build_HourHDVC_dataset_pred(dset_name, "train", args, tokenizer)
            sampler_train = DistributedSampler(dataset_train) if args.distributed else torch.utils.data.RandomSampler(dataset_train)
            dataloader_train = DataLoader(
                dataset_train, batch_size=args.batch_size, sampler=sampler_train,
                collate_fn=HourHDVC_collate_fn_pred, num_workers=args.num_workers,
            )
        else:
            dataloader_train = None

        tuples.append(nt(dataset_name=dset_name, dataloader_test=dataloader_test, dataloader_val=dataloader_val, dataloader_train=dataloader_train))

    model = build_LOCO_model(args, tokenizer)

    module_map = {
        'visual_encoder': model.visual_encoder,
        't5_encoder': model.t5_model.encoder,
        't5_decoder': model.t5_model.decoder,
    }

    for key in args.lora_apply:
        target_module = module_map.get(key)
        # Compare against None: container modules define __len__, so an empty
        # one would be falsy and wrongly reported as an unknown key.
        if target_module is not None:
            convert_linear_to_lora(target_module, r=4, lora_alpha=32, lora_dropout=0.1)
        else:
            print(f"[Warning] Unknown module key for LoRA: {key}")

    model.to(device)

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    if dist.is_main_process():
        print("Number of trainable parameters:", n_parameters)

    optimizer = torch.optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=args.lr, betas=(args.beta1, args.beta2), weight_decay=args.weight_decay,
    )

    if args.load:
        if dist.is_main_process():
            print("Loading from", args.load)
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(args.load, map_location="cpu")
        model.load_state_dict(checkpoint["model"], strict=False)
        if args.resume and not args.eval:
            optimizer.load_state_dict(checkpoint["optimizer"])
            args.start_epoch = checkpoint["epoch"] + 1

    # Validation is performed at specific epochs as defined in the original script
    validation_epochs = (0, 3, 9, 12, 17, 20, 22, 24, 27, 29, 32, 33, 34, 35, 36, 37, 38, 39)

    for item in tuples:
        if not args.eval:
            if dist.is_main_process():
                print("Start training")
            start_time = time.time()
            best_epoch, best_acc = args.start_epoch, 0

            for epoch in range(args.start_epoch, args.epochs):
                if dist.is_main_process():
                    print(f"Starting epoch {epoch}")
                if args.distributed:
                    # Reshuffle the sampler of the loader being trained on,
                    # not the one left over from the last dataset built above.
                    item.dataloader_train.sampler.set_epoch(epoch)

                train_stats = train_one_epoch(
                    model=model, data_loader=item.dataloader_train, optimizer=optimizer,
                    device=device, epoch=epoch, args=args,
                )

                val_stats = {}
                if epoch in validation_epochs and item.dataloader_val is not None:
                    print(f"Validating {item.dataset_name}")
                    out = evaluate(
                        model=model, data_loader=item.dataloader_val, device=device,
                        dataset_name=item.dataset_name, args=args, split="val",
                    )
                    val_stats.update({item.dataset_name + "_" + k: v for k, v in out.items()})

                    # Keep the checkpoint with the best METEOR seen so far.
                    current_meteor = out.get("METEOR", 0)
                    if current_meteor > best_acc:
                        best_epoch, best_acc = epoch, current_meteor
                        if dist.is_main_process() and args.save_dir:
                            dist.save_on_master(
                                {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch, "args": args},
                                os.path.join(args.save_dir, "best_model.pth"),
                            )

                log_stats = {
                    **{f"train_{k}": v for k, v in train_stats.items()},
                    **{f"val_{k}": v for k, v in val_stats.items()},
                    "epoch": epoch, "n_parameters": n_parameters,
                }

                if args.save_dir and dist.is_main_process():
                    with open(os.path.join(args.save_dir, "log.txt"), "a") as f:
                        f.write(json.dumps(log_stats) + "\n")
                    dist.save_on_master(
                        {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "epoch": epoch, "args": args},
                        os.path.join(args.save_dir, "ckpt.pth"),
                    )

            total_time_str = str(datetime.timedelta(seconds=int(time.time() - start_time)))
            print(f"Training time {total_time_str}")

            if args.save_dir:
                # Wait until the main process has finished writing checkpoints;
                # skipped in single-process runs, which have no process group.
                if torch.distributed.is_available() and torch.distributed.is_initialized():
                    torch.distributed.barrier()
                best_path = os.path.join(args.save_dir, "best_model.pth")
                if os.path.isfile(best_path):
                    if dist.is_main_process():
                        print(f"Loading best checkpoint from epoch {best_epoch}")
                    checkpoint = torch.load(best_path, map_location="cpu")
                    model.load_state_dict(checkpoint["model"], strict=False)
                elif dist.is_main_process():
                    # No validation epoch produced a checkpoint; evaluate the
                    # weights from the end of training instead of crashing.
                    print(f"[Warning] {best_path} not found; evaluating the current model weights")

        out = evaluate(
            model=model, data_loader=item.dataloader_test, device=device,
            dataset_name=item.dataset_name, args=args, split="test"
        )

        if args.save_dir and dist.is_main_process():
            with open(os.path.join(args.save_dir, item.dataset_name + "summary.json"), "w") as f:
                json.dump(out, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(parents=[get_args_parser()])
    args = parser.parse_args()

    # Place the run directory under the pre-save root; a falsy save_dir is
    # left untouched.
    args.save_dir = (
        os.path.join(args.presave_dir, args.save_dir) if args.save_dir else args.save_dir
    )

    # Fixed by the accompanying .sh launch script; not meant to be overridden
    # from the command line.
    args.gt_bound = args.intergrated_mem = False
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    main(args)
