function [label] = DCA(X_src,Y_src,X_tar_all,X_tar_ori,options)
%DCA  Class-wise domain adaptation with per-class source/target projections
%     (Ws, Wt), followed by 1-NN classification in the projected space.
%
% Inputs:
%   X_src     : source feature matrix, d * ns (samples are columns)
%   Y_src     : source labels (one-hot). Only size(Y_src,2) is used, as the
%               number of classes nnClass.
%               NOTE(review): an earlier header described Y_src as c * ns,
%               which would make size(Y_src,2) equal ns - confirm which
%               orientation callers pass.
%   X_tar_all : cell array of target feature subsets (d * nt/10 each). Only
%               the subset X_tar_all{mod(1,K)+1} (K = numel(X_tar_all)) is
%               used, to build the candidate set that drives the MMD terms.
%   X_tar_ori : full target feature matrix, d * nt
%   options   : struct with fields
%                 lambda1, lambda2 : regularisation weights
%                 dim              : projected dimension
%                 max_iter         : maximum number of iterations
%                 Train_Lab        : source labels, assumed to be 1..nnClass
%                 Test_Lab         : target ground-truth labels. The
%                                    projections from the iteration with the
%                                    highest accuracy against these labels
%                                    are kept.
%
% Output:
%   label     : 1-NN prediction for every column of X_tar_ori, in the same
%               column order as X_tar_ori. If max_iter < 1, this is the
%               baseline 1-NN prediction on the raw features.

    lambda1      = options.lambda1;
    lambda2      = options.lambda2;
    Test_Lab_ori = options.Test_Lab;
    Train_Lab    = options.Train_Lab;
    dim          = options.dim;
    max_iter     = options.max_iter;

    nnClass = size(Y_src,2);
    [m, n]  = size(X_src);
    nt      = size(X_tar_ori,2);

    % Keep the two rand calls in this order so the random stream matches.
    % Ws's initial value is overwritten before use; Wt's is used in the
    % very first Ws update.
    Ws = rand(m,dim);
    Wt = rand(m,dim);
    Ws_all = cell(1,nnClass);
    Wt_all = cell(1,nnClass);
    acc_list = zeros(max_iter,1);
    obj      = zeros(max_iter,1);

    % PCA bases used as the reconstruction matrices in the first iteration.
    [coeff, ~] = pca(X_src');     P1s = coeff(:,1:dim);
    [coeff, ~] = pca(X_tar_ori'); P1t = coeff(:,1:dim);

    % Baseline 1-NN on the raw features. It gives:
    %   - class_candidate: the class used to pick Wt for each X_tar_ori column;
    %   - label_candidate: initial pseudo-labels for the candidate subset.
    knn_model = fitcknn(X_src',Train_Lab,'NumNeighbors',1);
    class_candidate = knn_model.predict(X_tar_ori');
    label = class_candidate;   % fallback when the loop does not run

    % NOTE(review): only the subset selected at iteration 1 is ever used as
    % the candidate target set; confirm that cycling subsets was not intended.
    X_tar_candidate = X_tar_all{mod(1,length(X_tar_all))+1};
    label_candidate = knn_model.predict(X_tar_candidate');
    nt_cand = size(X_tar_candidate,2);

    acc_max = 0; X_src_da_all = []; X_tar_da_all = [];
    for iter = 1 : max_iter

        obj_src = 0; obj_tar = 0; obj_intra = 0;
        % Projected candidate-target features, kept in X_tar_candidate's
        % column order so the next pseudo-labels line up with its columns.
        X_tar_da = zeros(dim, nt_cand);

        for i = 1:nnClass
            src_index_i = Train_Lab == i;
            X_src_i = X_src(:,src_index_i);
            Y_src_i = Train_Lab(src_index_i);

            tar_index_i = label_candidate(:) == i;
            X_tar_i = X_tar_candidate(:,tar_index_i);
            label_candidate_i = label_candidate(tar_index_i);

            ns_i = size(X_src_i,2); nt_i = size(X_tar_i,2);

            % Per-class MMD coefficient blocks, mapped into feature space.
            [Ls_i,Lt_i,Lst_i,Lts_i] = construct_mmd(ns_i,nt_i,Y_src_i,label_candidate_i,nnClass);
            Ms_i  = X_src_i * Ls_i  * X_src_i';
            Mt_i  = X_tar_i * Lt_i  * X_tar_i';
            Mst_i = X_src_i * Lst_i * X_tar_i';
            Mts_i = X_tar_i * Lts_i * X_src_i';

            XXs = X_src_i*X_src_i';
            XXt = X_tar_i*X_tar_i';

            % ----------  Ps, Pt ----------- %
            % Iteration 1 uses the PCA bases. Later iterations use the
            % orthogonal factor U*V' of the SVD of XX*W, computed from the
            % current Ws/Wt before they are updated below.
            if iter == 1
                Ps = P1s;
                Pt = P1t;
            else
                [U1,~,V1] = svd(XXs*Ws,'econ');
                Ps = U1*V1';
                [U1,~,V1] = svd(XXt*Wt,'econ');
                Pt = U1*V1';
            end

            % ----------  Ws ----------- %
            % Solve W2 * Ws = W1. W2 is symmetric, so (W1'/W2)' equals W2\W1.
            W1 = XXs*Ps + lambda1/nnClass.*Wt - lambda2.*(Mts_i'+Mst_i)*Wt;
            W2 = XXs + lambda2.*(Ms_i+Ms_i') + lambda1/nnClass.*eye(m);
            Ws = (W1'/W2)';

            % ----------  Wt ----------- %
            % Uses the freshly updated Ws.
            W1 = XXt*Pt + lambda1/nnClass.*Ws - lambda2.*(Mts_i+Mst_i')*Ws;
            W2 = XXt + lambda2.*(Mt_i+Mt_i') + lambda1/nnClass.*eye(m);
            Wt = (W1'/W2)';

            X_tar_da(:,tar_index_i) = Wt'*X_tar_i;
            Ws_all{i} = Ws; Wt_all{i} = Wt;

            temp = [Ws',Wt']*[Ms_i, Mst_i; Mts_i, Mt_i]*[Ws;Wt];
            obj_src   = obj_src + norm(X_src_i-Ps*Ws'*X_src_i,'fro').^2;
            obj_tar   = obj_tar + norm(X_tar_i-Pt*Wt'*X_tar_i,'fro').^2;
            obj_intra = obj_intra + trace(temp);
        end

        % Project every source / target sample with its class's projection.
        % Assign by index so that columns keep the order of X_src / X_tar_ori,
        % matching Train_Lab and Test_Lab.
        X_src_da_temp = zeros(dim, n);
        X_tar_da_temp = zeros(dim, nt);
        for i = 1:nnClass
            src_index_i = Train_Lab == i;
            tar_index_i = class_candidate(:) == i;
            X_src_da_temp(:,src_index_i) = Ws_all{i}'*X_src(:,src_index_i);
            X_tar_da_temp(:,tar_index_i) = Wt_all{i}'*X_tar_ori(:,tar_index_i);
        end

        knn_model = fitcknn(X_src_da_temp',Train_Lab,'NumNeighbors',1);
        label_iter = knn_model.predict(X_tar_da_temp');
        label_candidate = knn_model.predict(X_tar_da');

        % Compare as column vectors so row/column orientation cannot broadcast.
        acc = mean(label_iter(:) == Test_Lab_ori(:));
        acc_list(iter) = acc;

        % Keep the projections of the best-scoring iteration.
        if iter == 1 || acc > acc_max
            X_src_da_all = X_src_da_temp;
            X_tar_da_all = X_tar_da_temp;
            acc_max = acc;
        end

        obj(iter) = obj_src + obj_tar + lambda1*norm(Ws-Wt ,'fro').^2 + lambda2*obj_intra;

        if iter > 10 && abs(obj(iter)-obj(iter-1)) < 1e-7
            iter   % display the iteration at which the objective converged
            break;
        end
    end

    % Final prediction uses the best projections found, whether the loop
    % converged early or ran to max_iter.
    if max_iter >= 1
        knn_model = fitcknn(X_src_da_all',Train_Lab,'NumNeighbors',1);
        label = knn_model.predict(X_tar_da_all');
    end

end


function [Ms,Mt,Mst,Mts] = construct_mmd(ns,nt,Y_src,Y_tar_pseudo,C)
%CONSTRUCT_MMD  Build the four blocks of the MMD coefficient matrix
%   [Ms Mst; Mts Mt] for ns source and nt target samples.
%
%   Marginal term (always included):
%     e = [es; et], es = 1/ns, et = -1/nt; scaled by C.
%
%   Conditional terms (only when Y_tar_pseudo has exactly nt entries):
%     one per class in unique(Y_src), with es = 1/n_c on that class's
%     source samples and et = -1/m_c on its target samples.
%
%   All blocks are finally divided by sum(e.^2), the squared norm of the
%   marginal vector.
%
% Inputs:
%   ns, nt       : number of source / target samples
%   Y_src        : source labels (ns entries)
%   Y_tar_pseudo : target pseudo-labels (nt entries), or empty to skip the
%                  conditional terms
%   C            : weight of the marginal term
%
% Outputs:
%   Ms (ns*ns), Mt (nt*nt), Mst (ns*nt), Mts (nt*ns)

    es = 1 / ns * ones(ns,1);
    et = -1 / nt * ones(nt,1);
    e  = [es;et];

    % Marginal term, weighted by C.
    Ms  = es * es' * C;
    Mt  = et * et' * C;
    Mst = es * et' * C;
    Mts = et * es' * C;

    if ~isempty(Y_tar_pseudo) && length(Y_tar_pseudo) == nt
        % reshape(...,1,[]) gives a (possibly empty) row vector, so the loop
        % visits each source class once. An empty Y_src skips the loop
        % instead of erroring in reshape.
        for c = reshape(unique(Y_src),1,[])
            es = zeros(ns,1);
            et = zeros(nt,1);
            es(Y_src == c) = 1 / nnz(Y_src == c);
            % If no target sample has pseudo-label c, the assignment selects
            % nothing and et stays zero.
            et(Y_tar_pseudo == c) = -1 / nnz(Y_tar_pseudo == c);

            Ms  = Ms  + es * es';
            Mt  = Mt  + et * et';
            Mst = Mst + es * et';
            Mts = Mts + et * es';
        end
    end

    % Normalise by the squared norm of the marginal vector.
    scale = sum(e.^2);
    Ms  = Ms  / scale;
    Mt  = Mt  / scale;
    Mst = Mst / scale;
    Mts = Mts / scale;

end