import torch
import torch.nn as nn
from spikingjelly.activation_based import surrogate, neuron, functional, layer
from spikingjelly.activation_based.model import spiking_resnet, train_classify, spiking_vgg
from spikingjelly.datasets import split_to_train_test_set, asl_dvs, fast_split_to_train_test_set

import spikingjelly.activation_based.layer as sj_layer

import argparse

import torchvision.transforms as transforms
from torchvision.transforms.functional import InterpolationMode
from torchvision.datasets import ImageFolder, CIFAR100
from data_utils import Cutout

from modelutils import *
from model_library import ASLDVS_CSNN, ASLDVS_4CONV_SNN

from pathlib import Path
from torch.utils.data import Subset

class ASLDVSTrainer(train_classify.Trainer):
    """spikingjelly classification Trainer specialised for the ASL-DVS dataset.

    The network runs in multi-step mode: batches arrive as ``[N, T, ...]`` and
    are transposed to ``[T, N, ...]`` before the forward pass, and the output
    is averaged over dimension 0 (time). On top of the base Trainer this adds
    an ``adam`` optimizer option, a ``multistep`` LR-scheduler option and the
    ``--T`` / ``--tau`` / ``--detach`` / ``--lr-milestones`` CLI arguments.
    """

    def preprocess_train_sample(self, args, x: torch.Tensor):
        # [N, T, C, H, W] -> [T, N, C, H, W] for multi-step mode
        return x.transpose(0, 1)

    def preprocess_test_sample(self, args, x: torch.Tensor):
        # [N, T, C, H, W] -> [T, N, C, H, W] for multi-step mode
        return x.transpose(0, 1)

    def process_model_output(self, args, y: torch.Tensor):
        # y: [T, N, num_classes] -> [N, num_classes] via firing rates
        return y.mean(0)  # return firing rate

    def set_optimizer(self, args, parameters):
        """Return Adam for ``--opt adam``; defer every other choice to the base Trainer."""
        opt_name = args.opt.lower()
        if opt_name == "adam":
            print("Adam optimizer chosen", args.lr)
            return torch.optim.Adam(parameters, lr=args.lr)
        return super().set_optimizer(args, parameters)

    def set_lr_scheduler(self, args, optimizer):
        """Return MultiStepLR for ``--lr-scheduler multistep``; otherwise defer to the base Trainer."""
        lr_scheduler = args.lr_scheduler.lower()
        if lr_scheduler == "multistep":
            print("multistep scheduler: " + str(args.lr_milestones), args.lr_gamma)
            return torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=args.lr_milestones, gamma=args.lr_gamma
            )
        return super().set_lr_scheduler(args, optimizer)

    def get_args_parser(self, add_help=True):
        """Extend the base Trainer's parser with the ASL-DVS / SNN options.

        Args:
            add_help: forwarded to the base parser so callers can build a
                parent parser without its own ``-h`` option.
        """
        parser = super().get_args_parser(add_help)
        # NOTE(review): --T has no default; the frame dataset below needs it, so pass it explicitly.
        parser.add_argument('--T', type=int, help="total time-steps")
        parser.add_argument('--tau', default=2.0, type=float, help="LIF time constant")
        parser.add_argument('--detach', action='store_true', help="detach")
        parser.add_argument("--lr-milestones", nargs="+", default=[150, 255], type=int, help="Epochs at which to decay the learning rate")
        return parser

    def load_model(self, args, num_classes):
        """Build an ``ASLDVS_4CONV_SNN`` in multi-step mode.

        The LIF time constant comes from ``--tau`` (default 2.0).
        """
        model = ASLDVS_4CONV_SNN(tau=args.tau, class_num=num_classes)
        functional.set_step_mode(model, 'm')
        return model

    def load_data(self, args):
        return self.loadASLDVS(args)

    @staticmethod
    def _atomic_torch_save(obj, path):
        """``torch.save`` ``obj`` to ``path`` through a temp file and rename.

        The rename (``Path.replace``) is atomic on POSIX, so a concurrent reader
        sees either no file or the complete file, never a partial write.
        """
        import os

        # A per-process temp name keeps concurrent writers from clobbering each other.
        tmp_path = path.with_name(f"{path.name}.{os.getpid()}.tmp")
        torch.save(obj, tmp_path)
        tmp_path.replace(path)

    def loadASLDVS(self, args):
        """Load ASL-DVS as frames and return ``(train, test, train_sampler, test_sampler)``.

        Events are integrated into ``args.T`` frames (split by event number).
        The 80/20 train/test index split is cached in ``./asldvs`` (relative
        to the working directory) and reused on later runs.
        """
        print("loadASLDVS")
        # ------------------------------------------------------------------
        # 1.  Data
        # ------------------------------------------------------------------
        full_dataset = asl_dvs.ASLDVS(
            root=args.data_path,
            data_type="frame",
            frames_number=args.T,
            split_by="number",
            transform=None,
        )
        print("split")

        root_dir = Path("asldvs")  # keep split files in one folder
        root_dir.mkdir(exist_ok=True)

        train_idx_path = root_dir / "train_idx.pt"
        test_idx_path = root_dir / "test_idx.pt"

        # ------------------------------------------------------------------
        # 2.  Re-use split if it exists, otherwise create & persist it
        # ------------------------------------------------------------------
        if train_idx_path.exists() and test_idx_path.exists():
            train_idx = torch.load(train_idx_path)
            test_idx = torch.load(test_idx_path)

            dataset_train = Subset(full_dataset, train_idx)
            dataset_test = Subset(full_dataset, test_idx)

            print(f"✔  Loaded cached split "
                  f"({len(dataset_train)} train / {len(dataset_test)} test samples).")
        else:
            # random_split=False -> the split is deterministic, so every
            # process that builds it concurrently produces the same indices.
            dataset_train, dataset_test = fast_split_to_train_test_set(
                train_ratio=0.8,
                origin_dataset=full_dataset,
                num_classes=24,
                random_split=False,
                batch_size=32,
            )

            # Atomic writes: with several processes (e.g. distributed ranks)
            # racing on a missing cache, no one may load a half-written file.
            self._atomic_torch_save(dataset_train.indices, train_idx_path)
            self._atomic_torch_save(dataset_test.indices, test_idx_path)

            print(f"   Created new split and saved index files to ‘{root_dir}’\n"
                  f"    ({len(dataset_train)} train / {len(dataset_test)} test samples).")

        print("Creating data loaders")
        if args.distributed:
            train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train, seed=args.seed)
            test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
        else:
            # Seeded generator -> reproducible shuffling order across runs.
            loader_g = torch.Generator()
            loader_g.manual_seed(args.seed)
            train_sampler = torch.utils.data.RandomSampler(dataset_train, generator=loader_g)
            test_sampler = torch.utils.data.SequentialSampler(dataset_test)

        print(full_dataset.classes)
        # Subset does not carry ``classes``; mirror it on both splits.
        # NOTE(review): presumably the base Trainer reads ``dataset_train.classes``
        # (e.g. to size the classifier) -- confirm.
        dataset_train.classes = full_dataset.classes
        dataset_test.classes = full_dataset.classes

        print(len(dataset_train), len(dataset_test))

        return dataset_train, dataset_test, train_sampler, test_sampler

    def get_tb_logdir_name(self, args):
        """Append the number of time-steps to the base TensorBoard log-dir name."""
        return super().get_tb_logdir_name(args) + f'_T{args.T}'

class ASLDVS4CONVTrainer(ASLDVSTrainer):
    """ASL-DVS trainer that warm-starts ``ASLDVS_4CONV_SNN`` from a checkpoint.

    Adds a ``cosa`` LR-scheduler option (implemented with ReduceLROnPlateau).
    Before training, it initialises the leading entries of ``model.conv_fc``
    from a previously trained checkpoint.
    """

    # Checkpoint whose ``['model']`` state dict supplies the transferred weights.
    # Override on a subclass or instance to warm-start from a different file.
    pretrained_path = "./models/asldvs/checkpoint_max_test_acc1.pth"
    # ``conv_fc`` state-dict entries whose leading module index is below this
    # value are copied from the checkpoint; the rest keep their fresh init.
    num_transferred_layers = 12

    def set_lr_scheduler(self, args, optimizer):
        """Return ReduceLROnPlateau for ``--lr-scheduler cosa``; otherwise defer to the parent.

        NOTE(review): ``ReduceLROnPlateau.step`` requires a metric argument
        (``mode='min'``, so presumably a loss), and a bare ``step()`` raises
        TypeError. Confirm the training loop calls ``lr_scheduler.step(<metric>)``
        for this option and does not wrap it in a warm-up scheduler.
        """
        lr_scheduler = args.lr_scheduler.lower()
        if lr_scheduler != "cosa":
            return super().set_lr_scheduler(args, optimizer)

        print("COSA cha: " + str(args.lr_milestones), args.lr_gamma)
        return torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=args.lr_gamma,  # LR is multiplied by this on a plateau
            patience=5,
            threshold=1e-2,
            cooldown=0,
            min_lr=1e-4,
        )

    def load_model(self, args, num_classes):
        """Build ``ASLDVS_4CONV_SNN`` and copy its leading ``conv_fc`` layers from a checkpoint.

        Args:
            args: parsed CLI arguments; ``args.tau`` sets the LIF time constant.
            num_classes: size of the classifier output.

        Returns:
            The model in multi-step mode. Entries of ``conv_fc`` whose leading
            index is below ``num_transferred_layers`` are taken from
            ``pretrained_path``; all other parameters keep their fresh
            initialisation.
        """
        model = ASLDVS_4CONV_SNN(tau=args.tau, class_num=num_classes)

        # The source weights are only copied into ``model`` (which is not
        # moved to any device here), so load them on CPU rather than spending
        # accelerator memory on a throwaway model.
        # weights_only=False unpickles arbitrary objects: only point
        # ``pretrained_path`` at trusted checkpoints.
        checkpoint = torch.load(self.pretrained_path, weights_only=False, map_location="cpu")
        src_model = ASLDVS_4CONV_SNN(tau=args.tau, class_num=num_classes)
        # Strict load (the default) checks the checkpoint matches the architecture.
        src_model.load_state_dict(checkpoint['model'])
        source_sd = src_model.conv_fc.state_dict()

        # Keys look like "<index>.<param>"; assumes conv_fc is an indexed
        # container such as nn.Sequential -- TODO confirm in model_library.
        target_sd = model.conv_fc.state_dict()
        target_sd.update({
            k: v
            for k, v in source_sd.items()
            if int(k.split('.', 1)[0]) < self.num_transferred_layers
        })
        model.conv_fc.load_state_dict(target_sd)

        functional.set_step_mode(model, 'm')
        return model

def main():
    """Command-line entry point: parse arguments and run the 4-conv ASL-DVS trainer."""
    trainer = ASLDVS4CONVTrainer()
    parser = trainer.get_args_parser()
    trainer.main(parser.parse_args())


if __name__ == "__main__":
    # Launch training only when executed as a script, not on import.
    main()