import logging
import traceback
import copy
import torch
import os.path
import json
import time
import tqdm
import random
import math
from torch import nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from trainer.model_trainer import ModelTrainer
from torch.utils.data import DataLoader, TensorDataset
from inversion import Inversion

class MyModelTrainer(ModelTrainer):
    def __init__(self, model, args, config):
        self.model = model
        self.id = 0
        self.args = args
        self.config = config
        self.client_num = config['client_num']
        self.dataset = config['dataset']
        self.use_cuda = config['use_cuda']
        self.base_model = config['base_model']
        self.root = config['root']
        self.model_root = config['model_root']
        self.ratio = config['ratio']
        self.task = config['task']
        self.src = config['src']
        self.tgt = config['tgt']
        self.uid_src = config['src_tgt_pairs'][self.task]['uid_src']
        self.iid_src = config['src_tgt_pairs'][self.task]['iid_src']
        self.uid_tgt = config['src_tgt_pairs'][self.task]['uid_tgt']
        self.iid_tgt = config['src_tgt_pairs'][self.task]['iid_tgt']
        self.field_dims_src = {'uid_src': self.uid_src, "iid_src": self.iid_src}
        self.field_dims_tgt = {'uid_tgt': self.uid_tgt, "iid_tgt": self.iid_tgt}

        self.batchsize_src = config['src_tgt_pairs'][self.task]['batchsize_src']
        self.batchsize_tgt = config['src_tgt_pairs'][self.task]['batchsize_tgt']
        self.batchsize_meta = config['src_tgt_pairs'][self.task]['batchsize_meta']
        self.batchsize_map = config['src_tgt_pairs'][self.task]['batchsize_map']
        self.batchsize_test = config['src_tgt_pairs'][self.task]['batchsize_test']
        self.topk = config['src_tgt_pairs'][self.task]['topk']
        self.batchsize_aug = self.batchsize_src

        self.epoch = config['epoch']
        self.emb_dim = config['emb_dim']
        self.meta_dim = config['meta_dim']
        self.num_fields = config['num_fields']
        self.lr = config['lr']
        self.lr_prototype = config['lr_prototype']
        self.wd = config['wd']
        self.results = {'frjve_mae': 10, 'frjve_rmse': 10}
        self.num_all = 0

    def get_model_params(self):
        return self.model.cpu().state_dict()

    def set_model_params(self, model_parameters):
        self.model.load_state_dict(model_parameters)

    def read_log_data(self, path, batchsize):
        cols = ['uid', 'iid', 'y']
        x_col = ['uid', 'iid']
        y_col = ['y']
        data = pd.read_csv(path, header=None)
        data.columns = cols
        X = torch.tensor(data[x_col].values, dtype=torch.long)
        y = torch.tensor(data[y_col].values, dtype=torch.long)
        if self.use_cuda:
            X = X.cuda()
            y = y.cuda()
        dataset = TensorDataset(X, y)
        data_iter = DataLoader(dataset, batchsize, shuffle=True)
        return data, data_iter

    def read_log_data_con(self, path_all, batchsize):
        cols = ['uid', 'iid', 'y']
        x_col = ['uid', 'iid']
        y_col = ['y']
        data_list = []
        for path in path_all:
            data = pd.read_csv(path, header=None)
            data_list.append(data)
        data = pd.concat(data_list, axis=0, ignore_index=True)
        data.columns = cols
        X = torch.tensor(data[x_col].values, dtype=torch.long)
        y = torch.tensor(data[y_col].values, dtype=torch.long)
        if self.use_cuda:
            X = X.cuda()
            y = y.cuda()
        dataset = TensorDataset(X, y)
        data_iter = DataLoader(dataset, batchsize, shuffle=True)
        return data_iter

    def read_map_data(self, train_mapping_idx):
        X = torch.tensor(np.array(train_mapping_idx), dtype=torch.long)
        y = torch.tensor(np.array(range(X.shape[0])), dtype=torch.long)
        if self.use_cuda:
            X = X.cuda()
            y = y.cuda()
        dataset = TensorDataset(X, y)
        data_iter = DataLoader(dataset, self.batchsize_map, shuffle=True)
        return data_iter

    def read_map_data_con(self, train_mapping_idx):
        tensor_list = [torch.tensor(np.array(arr), dtype=torch.long).unsqueeze(0) for arr in train_mapping_idx]
        X = torch.cat(tensor_list, dim=1).squeeze(0)
        tensor_list = [torch.tensor(np.array(range(len(arr))), dtype=torch.long).unsqueeze(0) for arr in train_mapping_idx]
        y = torch.cat(tensor_list, dim=1).squeeze(0)
        if self.use_cuda:
            X = X.cuda()
            y = y.cuda()
        dataset = TensorDataset(X, y)
        data_iter = DataLoader(dataset, self.batchsize_map, shuffle=True)
        return data_iter

    def compute_dataframes_average(self, dataframes, num_list):
        print(num_list)
        if not dataframes:
            raise ValueError("At least one DataFrame is required.")
        if len(dataframes) == 1:
            return dataframes[0]
        if len(dataframes) != len(num_list):
            raise ValueError("The length of num_list must match the number of DataFrames.")
        total_samples = sum(num_list)
        weights = [num / total_samples for num in num_list]

        columns = dataframes[0].columns
        avg_data = {col: [0] * len(dataframes[0]) for col in columns}
        avg_data[columns[0]] = dataframes[0][columns[0]].tolist()

        for df, weight in zip(dataframes, weights):
            for col in columns[1:]:
                avg_data[col] = [sum(x) for x in zip(avg_data[col], df[col] * weight)]

        avg_df = pd.DataFrame(avg_data)

        return avg_df

    def read_rating_preference(self, path_all, src_map_num_all):
        data_df = []
        for path in path_all:
            src_rate_pre_file = path + '_train_src.csv'
            tgt_rate_pre_file = path + '_train_tgt.csv'
            co_uid_cols = ['co_uid']
            src_rate_pre_cols = ['src_rate_pre_'+str(x) for x in range(self.emb_dim)]
            src_rate_pre_cols = co_uid_cols + src_rate_pre_cols
            tgt_rate_pre_cols = ['tgt_rate_pre_'+str(x) for x in range(self.emb_dim)]
            tgt_rate_pre_cols = co_uid_cols + tgt_rate_pre_cols
            src_rate_pre = pd.read_csv(src_rate_pre_file, header=None, names=src_rate_pre_cols)
            tgt_rate_pre = pd.read_csv(tgt_rate_pre_file, header=None, names=tgt_rate_pre_cols)
            merged_rate_pre = pd.merge(src_rate_pre, tgt_rate_pre, on='co_uid')
            data_df.append(merged_rate_pre)

        avg_df = self.compute_dataframes_average(data_df, src_map_num_all)
        X = torch.tensor(avg_df.values, dtype=torch.float32)
        y = torch.tensor(np.array(range(avg_df.shape[0])), dtype=torch.long)
        print('map {} iter / batchsize = {} '.format(len(X), self.batchsize_meta))
        if self.use_cuda:
            X = X.cuda()
            y = y.cuda()
        dataset = TensorDataset(X, y)
        data_iter = DataLoader(dataset, self.batchsize_meta, shuffle=True)
        return data_iter

    def read_meta_test(self, per_iid_path_all, data_path, test_num_all):
        all_per_iid = []
        uid_cols = ['uid']
        iid_cols = ['iid']
        y_col = ['y']
        cols = uid_cols + iid_cols + y_col
        data_root = pd.read_csv(data_path, header=None, names=cols)
        for per_iid_path in per_iid_path_all:
            src_rate_pre_file = per_iid_path + '_test_src.csv'
            per_iid_cols = [str(x) for x in range(self.emb_dim)]
            cols = uid_cols + per_iid_cols
            rate_per = pd.read_csv(src_rate_pre_file, header=None, names=cols)
            rate_per['uid'] = rate_per['uid'].astype(int)
            all_per_iid.append(rate_per)
        avg_df = self.compute_dataframes_average(all_per_iid, test_num_all)
        test_data_with_per_iid = pd.merge(data_root, avg_df, on='uid')

        test_data_with_per_iid = test_data_with_per_iid.astype(float)
        # test data x is [test_co_uid, test_iid, per_iid_emb]
        X = torch.tensor(test_data_with_per_iid[iid_cols + per_iid_cols].values, dtype=torch.float32)
        y = torch.tensor(test_data_with_per_iid[y_col].values, dtype=torch.long)
        if self.use_cuda:
            X = X.cuda()
            y = y.cuda()
        dataset = TensorDataset(X, y)
        data_iter = DataLoader(dataset, self.batchsize_test, shuffle=True)
        return data_iter

    def get_data(self, inversion=False):
        test, data_test = self.read_log_data(self.test_path, self.batchsize_test)
        test_idx = pd.read_csv(self.test_idx_path, header=None, index_col=False).values.tolist()[0]
        co_idx = pd.read_csv(self.co_user_idx_path, header=None, index_col=False).values.tolist()[0]
        co_user_num = len(co_idx)
        mapping_idx = list(set(co_idx) - set(test_idx))
        data_map = self.read_map_data(mapping_idx)
        if inversion:
            return data_test, test_idx, mapping_idx
        else:
            src, data_src = self.read_log_data(self.src_path, self.batchsize_src)
            tgt, data_tgt = self.read_log_data(self.tgt_path, self.batchsize_tgt)
            src_map_num = len(src[src['uid'].isin(set(mapping_idx))])
            test_num = len(src[src['uid'].isin(set(test_idx))])
            return data_src, data_tgt, data_test, data_map, test_idx, mapping_idx, src_map_num, test_num

    def get_data_con(self, src_path, tgt_path, test_path, test_idx_path_all, co_user_idx_path_all, inversion=False):
        data_src = self.read_log_data_con(src_path, self.batchsize_src)
        data_tgt = self.read_log_data_con(tgt_path, self.batchsize_tgt)
        data_test = self.read_log_data_con(test_path, self.batchsize_test)
        mapping_idx_all = []
        for test_idx_path, co_user_idx_path in zip(test_idx_path_all, co_user_idx_path_all):
            test_idx = pd.read_csv(test_idx_path, header=None, index_col=False).values.tolist()[0]
            co_idx = pd.read_csv(co_user_idx_path, header=None, index_col=False).values.tolist()[0]
            co_user_num = len(co_idx)
            mapping_idx = list(set(list(range(co_user_num))) - set(test_idx))
            mapping_idx_all.append(mapping_idx)
        data_map = self.read_map_data_con(mapping_idx_all)
        return data_src, data_tgt, data_test, data_map

    def train(self, data_loader, model, criterion, optimizer, epochs, stage, mapping=False):
        epoch_loss = []

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                               mode="min",
                                                               factor=0.2,
                                                               patience=2)

        for epoch in range(epochs):
            batch_loss = []
            model.train()
            for X, y in data_loader:
                if mapping:
                    src_emb, tgt_emb = model(X, stage)
                    loss = criterion(src_emb, tgt_emb)
                else:
                    pred = model(X, stage)
                    loss = criterion(pred, y.squeeze().float())
                model.zero_grad()
                loss.backward()
                batch_loss.append(loss.item())
                optimizer.step()

            epoch_loss.append(sum(batch_loss) / len(batch_loss))
            scheduler.step(sum(epoch_loss) / len(epoch_loss))
            logging.info('Client Index = {}\tEpoch: {}\tLoss: {:.6f}'.format(
                self.id, epoch, sum(epoch_loss) / len(epoch_loss)))

        logging.info("--------------------------------------------------------------------------")

    def get_optimizer(self, model):
        optimizer_src = torch.optim.Adam(params=model.src_model.parameters(), lr=self.lr, weight_decay=self.wd)
        optimizer_tgt = torch.optim.Adam(params=model.tgt_model.parameters(), lr=self.lr, weight_decay=self.wd)
        optimizer_map = torch.optim.Adam(params=model.mapping.parameters(), lr=self.lr, weight_decay=self.wd)
        return optimizer_src, optimizer_tgt, optimizer_map

    def update_results(self, mae, rmse, phase):
        if mae < self.results[phase + '_mae']:
            self.results[phase + '_mae'] = mae
        if rmse < self.results[phase + '_rmse']:
            self.results[phase + '_rmse'] = rmse

    def FR_JVE_pre(self, model, inverse_idx, test_idx, src, tgt):
        print('==========FR_JVE pre==========')
        mapping = False
        inversion_config_path = 'inversion_config.json'
        with open(inversion_config_path, 'r') as f:
            inversion_config = json.load(f)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        per_iid_emb_root = inversion_config['syn_root'] + '/' + self.dataset + '/_' + str(self.client_num) + '_/tgtclient_' + str(tgt) + '_srcclient_' + str(src) + '/' + \
                           str(inversion_config['target_label']) + '_' + self.base_model
        inversion_config["task"] = self.task
        inversion_config['ratio'] = self.ratio
        inversion_config['base_model'] = self.base_model
        inversion_config['src_item_dims'] = self.iid_src
        inversion_config['tgt_item_dims'] = self.iid_tgt

        if not os.path.exists(per_iid_emb_root + '_train_src.csv') or \
                not os.path.exists(per_iid_emb_root + '_test_src.csv'):
            Inversion(inversion_config, self.dataset, self.client_num, model, inverse_idx, device, stage="train", inv_goal='src', src=src, tgt=tgt ).main()
            Inversion(inversion_config, self.dataset, self.client_num, model, test_idx, device, stage="test", inv_goal='src', src=src, tgt=tgt).main()
        if not os.path.exists(per_iid_emb_root + '_train_tgt.csv'):
            Inversion(inversion_config, self.dataset, self.client_num, model, inverse_idx, device, stage="train", inv_goal='tgt', src=src, tgt=tgt ).main()
            # During inversion, model parameters were set requires_grad = False
            for p in model.parameters():
                p.requires_grad = True
        return per_iid_emb_root

    def FR_JVE(self, model, criterion, optimizer, data_rate_pre_train, data_rate_pre_test, temp):
        print('==========FR_JVE==========')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        mapping = True
        if temp==0:
            self.train(data_rate_pre_train, model, criterion, optimizer, self.epoch, stage='train_source_free', mapping=True)
        if temp==1:
            metrics = self.test(model, data_rate_pre_test, stage='test_source_free', device= device)
            self.update_results(metrics["test_mae"]/metrics["test_total"], math.sqrt(metrics["test_loss"]/metrics["test_total"]), 'frjve')
            print('MAE: {} RMSE: {}'.format(metrics["test_mae"]/metrics["test_total"], math.sqrt(metrics["test_loss"]/metrics["test_total"])))
            return metrics

    def Train_tgt(self, model, data_tgt, criterion, optimizer,temp):
        print('=========Train_tgt========')
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if temp==0:
            self.train(data_tgt, model, criterion, optimizer, self.epoch, stage='train_tgt')

    def Train_src(self, model, data_src, data_map, criterion, optimizer_src, optimizer_map, temp):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        if temp ==0:
            print('==========Train_src==========')
            self.train(data_src, model, criterion, optimizer_src, self.epoch, stage='train_src')
            # self.train(data_map, model, criterion, optimizer_map, self.epoch, stage='train_map', mapping=True)

    def PF_train(self, temp, weight, client_idx, client_idx_other):
        """Run one local round for target client ``client_idx``.

        For every ``j`` in ``client_idx_other`` with ``j != client_idx`` the
        data of the pair (target ``client_idx``, source ``j``) is loaded via
        ``get_data``. Then, depending on ``temp``:

        * ``temp == 0``: ``Train_tgt`` and ``Train_src`` train, ``FR_JVE_pre``
          runs for every pair, and ``FR_JVE`` trains; nothing is returned.
        * ``temp == 1``: ``FR_JVE`` evaluates and its metrics are returned.

        Args:
            temp: Phase selector, see above.
            weight: Not used by this method.
            client_idx: Index of the target client.
            client_idx_other: Iterable of candidate source client indices.

        Returns:
            The metrics returned by ``FR_JVE`` when ``temp == 1``; otherwise
            ``None``. Any exception raised inside is logged with its traceback
            and swallowed, in which case ``None`` is returned as well.
        """
        logging.debug("-------model actually train------")
        try:
            i = client_idx
            # Per-pair paths, loaders and counts, keyed by (tgt, src).
            global_data = {}
            criterion = torch.nn.MSELoss()
            model = self.model
            # self.global_model = copy.deepcopy(weight)
            if self.use_cuda:
                model.cuda()
            for j in client_idx_other:
                if i != j:
                    # get_data() reads these self.*_path attributes, so they
                    # are overwritten for every pair; after the loop they
                    # hold the values of the LAST processed pair.
                    self.input_root = self.root + self.dataset + '/ready/_' + str(self.client_num) + \
                                      '_/tgtclient_' + str(i) + '_srcclient_' + str(j)
                    self.src_path = self.input_root + '/train_src.csv'
                    self.tgt_path = self.input_root + '/train_tgt.csv'
                    self.meta_path = self.input_root + '/train_meta.csv'
                    self.test_path = self.input_root + '/test.csv'
                    self.test_idx_path = self.input_root + '/test_list.csv'
                    self.co_user_idx_path = self.input_root + '/co_user_list.csv'
                    data_src, data_tgt, data_test, data_map, test_idx, inverse_idx, src_map_num, test_num = self.get_data()

                    global_data[(i, j)] = {
                        'src_path': self.src_path,
                        'tgt_path': self.tgt_path,
                        'meta_path': self.meta_path,
                        'test_path': self.test_path,
                        'test_idx_path': self.test_idx_path,
                        'co_user_idx_path': self.co_user_idx_path,
                        'data_src': data_src,
                        'data_tgt': data_tgt,
                        'data_test': data_test,
                        'data_map': data_map,
                        'test_idx': test_idx,
                        'inverse_idx': inverse_idx,
                        'src_num': src_map_num,
                        'test_num': test_num,
                        'src': j,
                        'tgt': i,
                    }

            optimizer_src, optimizer_tgt, optimizer_map = self.get_optimizer(model)
            # Separate optimizer (own learning rate) for model.rp_mapping,
            # used by FR_JVE below.
            optimizer_rp_map = torch.optim.Adam(params=model.rp_mapping.parameters(), lr=self.lr_prototype,
                                                weight_decay=self.wd)

            # Parallel lists, one entry per pair, in global_data order.
            emb_root = []
            src_path_all = []
            tgt_path_all = []
            test_path_all = []
            test_idx_path_all = []
            co_user_idx_path_all = []
            src_map_num_all = []
            test_num_all = []

            inversion_config_path = 'inversion_config.json'
            with open(inversion_config_path, 'r') as f:
                inversion_config = json.load(f)

            # NOTE: this loop rebinds the local names i and j.
            for (i, j), data_dict in global_data.items():
                # Same prefix formula as the one FR_JVE_pre builds and returns.
                emb_root_temp = inversion_config['syn_root'] + '/' + self.dataset + '/_' + str(
                self.client_num) + '_/tgtclient_' + str(data_dict['tgt']) + '_srcclient_' + str(data_dict['src']) + '/' + \
                               str(inversion_config['target_label']) + '_' + self.base_model
                emb_root.append(emb_root_temp)
                src_path_all.append(data_dict['src_path'])
                tgt_path_all.append(data_dict['tgt_path'])
                test_path_all.append(data_dict['test_path'])
                test_idx_path_all.append(data_dict['test_idx_path'])
                co_user_idx_path_all.append(data_dict['co_user_idx_path'])
                src_map_num_all.append(data_dict['src_num'])
                test_num_all.append(data_dict['test_num'])

            # Only concatenated_data_src is used below; the other three
            # concatenated loaders are built but not read in this method.
            concatenated_data_src, concatenated_data_tgt, concatenated_data_test, concatenated_data_map = (
                self.get_data_con(src_path_all, tgt_path_all, test_path_all, test_idx_path_all, co_user_idx_path_all))

            # NOTE(review): data_tgt and data_map are the loaders of the last
            # pair from the first loop. If no pair was processed they are
            # undefined and the resulting NameError is swallowed by the
            # except below - confirm this is acceptable.
            self.Train_tgt(model, data_tgt, criterion, optimizer_tgt, temp)
            self.Train_src(model, concatenated_data_src, data_map, criterion, optimizer_src, optimizer_map, temp)

            if temp == 0:
                for (i, j), data_dict in global_data.items():
                    self.FR_JVE_pre(model, data_dict['inverse_idx'], data_dict['test_idx'], data_dict['src'], data_dict['tgt'])

            average_train_tensor = self.read_rating_preference(emb_root, src_map_num_all)
            # NOTE(review): self.test_path is the test file of the last pair
            # only, while the embeddings are averaged over all pairs - verify.
            data_rate_pre_test_all = self.read_meta_test(emb_root, self.test_path, test_num_all)

            metrics_frjve = self.FR_JVE(model, criterion, optimizer_rp_map, average_train_tensor, data_rate_pre_test_all, temp)
            print("===========================================")
            print(self.results)
            print("===========================================")
            if temp == 1:
                return metrics_frjve

        except Exception as e:
            # Best-effort: the failure is logged and the caller receives None.
            logging.error(traceback.format_exc())

    def test(self, model, data_loader, stage, device):
        print('Evaluating MAE:')
        model.to(device)
        model.eval()
        metrics = {
            'test_loss': 0,
            'test_mae': 0,
            'test_total': 0,
        }

        targets, predicts = list(), list()
        loss = torch.nn.L1Loss()
        mse_loss = torch.nn.MSELoss()
        all_targets, all_predicts = list(), list()
        with torch.no_grad():
            for X, y in data_loader:
                pred = model(X, stage)
                all_targets.extend(y.squeeze(1).tolist())
                all_predicts.extend(pred.tolist())
                targets = y.squeeze(1).tolist()
                predicts = pred.tolist()
                targets = torch.tensor(targets).float()
                predicts = torch.tensor(predicts)
                metrics['test_loss'] += mse_loss(targets, predicts).item() * targets.size(0)
                metrics['test_mae'] += loss(targets, predicts).item() * targets.size(0)
                metrics['test_total'] += targets.size(0)

        return metrics

    def test_on_the_server(self, train_data_local_dict, test_data_local_dict, device, args=None) -> bool:
        return False
