import mmcv
import decord
import numpy as np
import torch
import torch.nn as nn
import copy
import argparse
import os, sys
import os.path as osp
import glob
import scipy.io
import time
from mmcv.transforms import TRANSFORMS, BaseTransform, to_tensor
from mmaction.structures import ActionDataSample
import os.path as osp
from mmengine.fileio import list_from_file
from mmengine.dataset import BaseDataset
from mmaction.registry import DATASETS
from mmengine.runner import Runner
from mmaction.registry import MODELS
import pdb
from mmaction.utils import register_all_modules
from mmaction.registry import METRICS
from collections import OrderedDict
from mmengine.evaluator import BaseMetric
from mmaction.evaluation import top_k_accuracy
import torch.optim as optim
from mmengine import track_iter_progress
from tqdm import tqdm
from network import *


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='CAiDA')
    parser.add_argument('--gpu_id', type=str, nargs='?', default='3', help="device id to run")
    parser.add_argument('--dset', type=str, default='Sports-DA', choices=[])
    parser.add_argument('--net', type=str, default='swinl', help="")
    parser.add_argument('--output', type=str, default='')

    
    args = parser.parse_args()

    register_all_modules(init_default_scope=True)
    file_client_args = dict(io_backend='disk')

    dataset_type = 'VideoDataset'
    data_root_val = 
    ann_file_test = 


    if args.dset == 'Daily-DA':
        names = ['ARID', 'MIT', 'hmdb51']
        # names = ['ARID']
        args.class_num = 8

    if args.net == 'i3d':
        i3d_nldot_model['cls_head']['num_classes'] = args.class_num 
        model = MODELS.build(i3d_nldot_model)
        test_pipeline = [
            dict(type='DecordInit', **file_client_args),
            dict(
                type='SampleFrames',
                clip_len=32,
                frame_interval=2,
                num_clips=10,
                test_mode=True),
            dict(type='DecordDecode'),
            dict(type='Resize', scale=(-1, 256)),
            dict(type='Flip', flip_ratio=1),
            dict(type='ThreeCrop', crop_size=256),
            dict(type='FormatShape', input_format='NCTHW'),
            dict(type='PackActionInputs')
        ]
        test_dataloader = dict(
        batch_size=8,
        num_workers=8,
        persistent_workers=True,
        sampler=dict(type='DefaultSampler', shuffle=False),
        dataset=dict(
            type=dataset_type,
            ann_file=ann_file_test,
            data_prefix=dict(video=data_root_val),
            pipeline=test_pipeline,
            test_mode=True))
    elif args.net == 'c3d':
        c3d_model['cls_head']['num_classes'] = args.class_num 
        model = MODELS.build(c3d_model)
        test_pipeline = [
            dict(type='DecordInit', **file_client_args),
            dict(
                type='SampleFrames',
                clip_len=16,
                frame_interval=1,
                num_clips=10,
                test_mode=True),
            dict(type='DecordDecode'),
            dict(type='Resize', scale=(-1, 128)),
            # dict(type='Flip', flip_ratio=1),
            # dict(type='ThreeCrop', crop_size=128),
            dict(type='CenterCrop', crop_size=112),
            dict(type='FormatShape', input_format='NCTHW'),
            dict(type='PackActionInputs')
        ]
        test_dataloader = dict(
            batch_size=64,
            num_workers=8,
            persistent_workers=True,
            sampler=dict(type='DefaultSampler', shuffle=False),
            dataset=dict(
                type=dataset_type,
                ann_file=ann_file_test,
                data_prefix=dict(video=data_root_val),
                pipeline=test_pipeline,
                test_mode=True))
    elif args.net == 'slowfast':
        slow_fast_model['cls_head']['num_classes'] = args.class_num 
        model = MODELS.build(slow_fast_model)
        test_pipeline = [
            dict(type='DecordInit', **file_client_args),
            dict(
                type='SampleFrames',
                clip_len=32,
                frame_interval=2,
                num_clips=10,
                test_mode=True),
            dict(type='DecordDecode'),
            dict(type='Resize', scale=(-1, 256)),
            dict(type='Flip', flip_ratio=1),
            dict(type='ThreeCrop', crop_size=256),
            dict(type='FormatShape', input_format='NCTHW'),
            dict(type='PackActionInputs')
        ]
        test_dataloader = dict(
            batch_size=8,
            num_workers=8,
            persistent_workers=True,
            sampler=dict(type='DefaultSampler', shuffle=False),
            dataset=dict(
                type=dataset_type,
                ann_file=ann_file_test,
                data_prefix=dict(video=data_root_val),
                pipeline=test_pipeline,
                test_mode=True))
    elif args.net == 'slowonly':
        slowonly_model['cls_head']['num_classes'] = args.class_num 
        model = MODELS.build(slowonly_model)
        test_pipeline = [
            dict(type='DecordInit', **file_client_args),
            dict(
                type='SampleFrames',
                clip_len=4,
                frame_interval=16,
                num_clips=10,
                test_mode=True),
            dict(type='DecordDecode'),
            dict(type='Resize', scale=(-1, 256)),
            dict(type='Flip', flip_ratio=1),
            dict(type='ThreeCrop', crop_size=256),
            dict(type='FormatShape', input_format='NCTHW'),
            dict(type='PackActionInputs')
        ]
        test_dataloader = dict(
            batch_size=8,
            num_workers=8,
            persistent_workers=True,
            sampler=dict(type='DefaultSampler', shuffle=False),
            dataset=dict(
                type=dataset_type,
                ann_file=ann_file_test,
                data_prefix=dict(video=data_root_val),
                pipeline=test_pipeline,
                test_mode=True))

                test_mode=True))
    elif args.net == 'swint':
        swint_model['cls_head']['num_classes'] = args.class_num 
        model = MODELS.build(swint_model)
        test_pipeline = [
            dict(type='DecordInit', **file_client_args),
            dict(
                type='SampleFrames',
                clip_len=32,
                frame_interval=2,
                num_clips=4,
                test_mode=True),
            dict(type='DecordDecode'),
            dict(type='Resize', scale=(-1, 224)),
            dict(type='Flip', flip_ratio=1),
            dict(type='ThreeCrop', crop_size=224),
            dict(type='FormatShape', input_format='NCTHW'),
            dict(type='PackActionInputs')
        ]
        test_dataloader = dict(
            batch_size=4,
            num_workers=8,
            persistent_workers=True,
            sampler=dict(type='DefaultSampler', shuffle=False),
            dataset=dict(
                type=dataset_type,
                ann_file=ann_file_test,
                data_prefix=dict(video=data_root_val),
                pipeline=test_pipeline,
                test_mode=True))
    elif args.net == 'swins':
        swins_model['cls_head']['num_classes'] = args.class_num 
        model = MODELS.build(swins_model)
        test_pipeline = [
            dict(type='DecordInit', **file_client_args),
            dict(
                type='SampleFrames',
                clip_len=32,
                frame_interval=2,
                num_clips=4,
                test_mode=True),
            dict(type='DecordDecode'),
            dict(type='Resize', scale=(-1, 224)),
            dict(type='Flip', flip_ratio=1),
            dict(type='ThreeCrop', crop_size=224),
            dict(type='FormatShape', input_format='NCTHW'),
            dict(type='PackActionInputs')
        ]
        test_dataloader = dict(
            batch_size=4,
            num_workers=8,
            persistent_workers=True,
            sampler=dict(type='DefaultSampler', shuffle=False),
            dataset=dict(
                type=dataset_type,
                ann_file=ann_file_test,
                data_prefix=dict(video=data_root_val),
                pipeline=test_pipeline,
                test_mode=True))
    elif args.net == 'swinb':
        swinb_model['cls_head']['num_classes'] = args.class_num 
        model = MODELS.build(swinb_model)
        test_pipeline = [
            dict(type='DecordInit', **file_client_args),
            dict(
                type='SampleFrames',
                clip_len=32,
                frame_interval=2,
                num_clips=4,
                test_mode=True),
            dict(type='DecordDecode'),
            dict(type='Resize', scale=(-1, 224)),
            dict(type='Flip', flip_ratio=1),
            dict(type='ThreeCrop', crop_size=224),
            dict(type='FormatShape', input_format='NCTHW'),
            dict(type='PackActionInputs')
        ]
        test_dataloader = dict(
            batch_size=4,
            num_workers=8,
            persistent_workers=True,
            sampler=dict(type='DefaultSampler', shuffle=False),
            dataset=dict(
                type=dataset_type,
                ann_file=ann_file_test,
                data_prefix=dict(video=data_root_val),
                pipeline=test_pipeline,
                test_mode=True))
    elif args.net == 'swinl':
        swinl_model['cls_head']['num_classes'] = args.class_num
        model = MODELS.build(swinl_model)
        test_pipeline = [
            dict(type='DecordInit', **file_client_args),
            dict(
                type='SampleFrames',
                clip_len=32,
                frame_interval=2,
                num_clips=4,
                test_mode=True),
            dict(type='DecordDecode'),
            dict(type='Resize', scale=(-1, 224)),
            dict(type='Flip', flip_ratio=1),
            dict(type='ThreeCrop', crop_size=224),
            dict(type='FormatShape', input_format='NCTHW'),
            dict(type='PackActionInputs')
        ]
        test_dataloader = dict(
            batch_size=2,
            num_workers=8,
            persistent_workers=True,
            sampler=dict(type='DefaultSampler', shuffle=False),
            dataset=dict(
                type=dataset_type,
                ann_file=ann_file_test,
                data_prefix=dict(video=data_root_val),
                pipeline=test_pipeline,
                test_mode=True))
    
    if args.dset == 'Sports-DA':

    gpu_id = int(args.gpu_id)

    
    pth_files = glob.glob(os.path.join(base_path, '**/*.pth'), recursive=True)

  

    for file_path in pth_files:
        path_parts = file_path.split(os.sep)
        folder_name = path_parts[-2]
        # pdb.set_trace()
        model.load_state_dict(torch.load(file_path)["state_dict"])

        elif args.dset == 'Sports-DA':
            if folder_name =='sports1m':
                s = 0
                # continue
            elif folder_name =='UCF101':
                s = 1
                # continue

        for name in names:
            if folder_name == name:
                continue
          
            if name =='sports1m':
                # continue
                test_dataloader['dataset']['ann_file'] = 
                test_dataloader['dataset']['data_prefix']['video'] = 
                test_data_loader = Runner.build_dataloader(dataloader=test_dataloader)

                pool = nn.AdaptiveAvgPool2d((1, 1)).cuda()
                model.cls_head.num_classes = args.class_num
                # device = torch.device("cuda:1") 
                # model=model.to(device)
                model=model.cuda(gpu_id)
                tt = 0
                with torch.no_grad():
                    start_test = True
                    model.eval()
                    #for data_batch in test_data_loader:
                    for data_batch in tqdm(test_data_loader, desc='Processing batches'):

                        data = model.data_preprocessor(data_batch, training=False)
                        pdb.set_trace()
                        inputs = data['inputs']
                        labels = data['data_samples']

                        predictions,netF_feature = model(**data, mode='get_feat')

                        if args.net == 'slowfast':
                            slow_feature = netF_feature[0]
                            fast_feature = netF_feature[1]


                            pred_scores = torch.stack([sample.pred_score for sample in predictions])
                            gt_labels = torch.stack([sample.gt_label for sample in predictions])
                            if start_test:
                                    all_features_slow = slow_feature.float().cpu()
                                    all_features_fast = fast_feature.float().cpu()
                                    all_output = pred_scores.float().cpu()
                                    all_label = gt_labels.float().cpu()
                                    start_test = False

                            else:
                                all_features_slow = torch.cat((all_features_slow, slow_feature.float().cpu()), 0)
                                all_features_fast = torch.cat((all_features_fast, fast_feature.float().cpu()), 0)
                                all_output = torch.cat((all_output, pred_scores.float().cpu()), 0)
                                all_label = torch.cat((all_label, gt_labels.float().cpu()), 0)
                        else:
                            feature = netF_feature

                            pred_scores = torch.stack([sample.pred_score for sample in predictions])
                            gt_labels = torch.stack([sample.gt_label for sample in predictions])
                            # pdb.set_trace()                            if start_test:
                                    all_features = feature.float().cpu()
                                    all_output = pred_scores.float().cpu()
                                    all_label = gt_labels.float().cpu()
                                    start_test = False
                            else:
                                all_features = torch.cat((all_features, feature.float().cpu()), 0)
                                all_output = torch.cat((all_output, pred_scores.float().cpu()), 0)
                                all_label = torch.cat((all_label, gt_labels.float().cpu()), 0)

                    if args.net == 'slowfast':
                        scipy.io.savemat(f'{args.dset}_{args.net}_slow_{str(s)}_{str(tt)}.mat',{'ft':all_features_slow.numpy(),'output':all_output.numpy(),'label':all_label.numpy()})
                        scipy.io.savemat(f'{args.dset}_{args.net}_fast_{str(s)}_{str(tt)}.mat',{'ft':all_features_fast.numpy(),'output':all_output.numpy(),'label':all_label.numpy()})
                    else:
                        scipy.io.savemat(f'{args.dset}_{args.net}_{str(s)}_{str(tt)}.mat',{'ft':all_features.numpy(),'output':all_output.numpy(),'label':all_label.numpy()})

            elif name =='UCF101':
                # continue
                test_dataloader['dataset']['ann_file'] = 
                test_dataloader['dataset']['data_prefix']['video'] = 
                test_data_loader = Runner.build_dataloader(dataloader=test_dataloader)

                pool = nn.AdaptiveAvgPool2d((1, 1)).cuda()
                model.cls_head.num_classes = args.class_num
                model=model.cuda(gpu_id)

                tt = 1
                with torch.no_grad():
                    start_test = True
                    model.eval()
                    #for data_batch in test_data_loader:
                    for data_batch in tqdm(test_data_loader, desc='Processing batches'):
                        
                        data = model.data_preprocessor(data_batch, training=False)
                        inputs = data['inputs']
                        labels = data['data_samples']
                        predictions,netF_feature = model(**data, mode='get_feat')

                        if args.net == 'slowfast':
                            slow_feature = netF_feature[0]
                            fast_feature = netF_feature[1]
                            pred_scores = torch.stack([sample.pred_score for sample in predictions])
                            gt_labels = torch.stack([sample.gt_label for sample in predictions])
                            if start_test:
                                    all_features_slow = slow_feature.float().cpu()
                                    all_features_fast = fast_feature.float().cpu()
                                    all_output = pred_scores.float().cpu()
                                    all_label = gt_labels.float().cpu()
                                    start_test = False

                            else:
                                all_features_slow = torch.cat((all_features_slow, slow_feature.float().cpu()), 0)
                                all_features_fast = torch.cat((all_features_fast, fast_feature.float().cpu()), 0)
                                all_output = torch.cat((all_output, pred_scores.float().cpu()), 0)
                                all_label = torch.cat((all_label, gt_labels.float().cpu()), 0)

                        else:
                            feature = netF_feature

                            pred_scores = torch.stack([sample.pred_score for sample in predictions])
                            gt_labels = torch.stack([sample.gt_label for sample in predictions])
                            # pdb.set_trace()
                            if start_test:
                                    all_features = feature.float().cpu()
                                    all_output = pred_scores.float().cpu()
                                    all_label = gt_labels.float().cpu()
                                    start_test = False
                            else:
                                all_features = torch.cat((all_features, feature.float().cpu()), 0)
                                all_output = torch.cat((all_output, pred_scores.float().cpu()), 0)
                                all_label = torch.cat((all_label, gt_labels.float().cpu()), 0)

                    if args.net == 'slowfast':
                        scipy.io.savemat(f'{args.dset}_{args.net}_slow_{str(s)}_{str(tt)}.mat',{'ft':all_features_slow.numpy(),'output':all_output.numpy(),'label':all_label.numpy()})
                        scipy.io.savemat(f'{args.dset}_{args.net}_fast_{str(s)}_{str(tt)}.mat',{'ft':all_features_fast.numpy(),'output':all_output.numpy(),'label':all_label.numpy()})
                    else:
                        scipy.io.savemat(f'{args.dset}_{args.net}_{str(s)}_{str(tt)}.mat',{'ft':all_features.numpy(),'output':all_output.numpy(),'label':all_label.numpy()})
